support cmake download url: mysql, ssl, nghttps

update buffer
This commit is contained in:
jarodruan 2020-02-11 17:50:24 +08:00
parent bfd552b17d
commit 8693c840b5
19 changed files with 510 additions and 640 deletions

View File

@ -24,7 +24,7 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(TARS_VERSION "2.0.0")
add_definitions(-DTARS_VERSION="${TARS_VERSION}")
set(TARS_SSL 0)
set(TARS_SSL 1)
add_definitions(-DTARS_SSL=${TARS_SSL})
set(TARS_HTTP2 1)
add_definitions(-DTARS_HTTP2=${TARS_HTTP2})
@ -118,7 +118,7 @@ set(NGHTTP2_DIR_LIB "${THIRDPARTY_PATH}/nghttp2-lib/lib")
include_directories(${NGHTTP2_DIR_INC})
link_directories(${NGHTTP2_DIR_LIB})
set(SSL_DIR_INC "${THIRDPARTY_PATH}/openssl-lib/include/openssl")
set(SSL_DIR_INC "${THIRDPARTY_PATH}/openssl-lib/include/")
set(SSL_DIR_LIB "${THIRDPARTY_PATH}/openssl-lib")
include_directories(${SSL_DIR_INC})
link_directories(${SSL_DIR_LIB})
@ -126,16 +126,19 @@ link_directories(${SSL_DIR_LIB})
set(LIB_MYSQL)
set(LIB_NGHTTP2)
set(LIB_SSL)
set(LIB_CRYPTO)
IF (WIN32)
set(LIB_MYSQL "libmysql")
set(LIB_NGHTTP2 "libnghttp2_static")
set(LIB_SSL "libssl")
set(LIB_CRYPTO "libcrypto")
ELSE()
link_libraries(pthread dl)
set(LIB_MYSQL "mysqlclient")
set(LIB_NGHTTP2 "nghttp2_static")
set(LIB_SSL "ssl")
set(LIB_CRYPTO "crypto")
ENDIF()
link_libraries(${LIB_MYSQL})
@ -146,6 +149,7 @@ endif()
if(TARS_SSL)
link_libraries(${LIB_SSL})
link_libraries(${LIB_CRYPTO})
endif()
#-------------------------------------------------------------

View File

@ -54,7 +54,7 @@ public:
string buffer = response.encode();
shared_ptr<TC_EpollServer::SendContext> send = data->createSendContext();
send->buffer().assign(buffer.c_str(), buffer.c_str() + buffer.size());
send->buffer()->assign(buffer.c_str(), buffer.size());
sendResponse(send);
@ -141,7 +141,7 @@ public:
cout << "SocketHandle::handle : " << data->ip() << ":" << data->port() << endl;
shared_ptr<TC_EpollServer::SendContext> send = data->createSendContext();
send->buffer() = data->buffer();
send->buffer()->setBuffer(data->buffer());
sendResponse(send);
}

View File

@ -164,7 +164,7 @@ int AdapterProxy::invoke(ReqMessage * msg)
msg->sReqData->setBuffer(_objectProxy->getProxyProtocol().requestFunc(msg->request, _trans.get()));
//交给连接发送数据,连接连上,buffer不为空,直接发送数据成功
//链表是空的, 则直接发送当前这条数据, 如果链表非空或者发送失败了, 则放到队列中, 等待下次发送
if(_timeoutQueue->sendListEmpty() && _trans->sendRequest(msg->sReqData) != Transceiver::eRetError)
{
TLOGTARS("[TARS][AdapterProxy::invoke push (send) objname:" << _objectProxy->name() << ",desc:" << _endpoint.desc() << ",id:" << msg->request.iRequestId << endl);

View File

@ -32,7 +32,7 @@ namespace tars
{
// //TAFServer的协议解析器
// TC_NetWorkBuffer::PACKET_TYPE AppProtocol::parseAdmin(TC_NetWorkBuffer &in, shared_ptr<TC_NetWorkBuffer::SendBuffer> &out)
// TC_NetWorkBuffer::PACKET_TYPE AppProtocol::parseAdmin(TC_NetWorkBuffer &in, shared_ptr<TC_NetWorkBuffer::Buffer> &out)
// {
// return parse(in, out->getBuffer());
// }
@ -216,9 +216,9 @@ vector<char> ProxyProtocol::http2Request(RequestPacket& request, Transceiver *tr
TLOGERROR("[TARS]http2Request::Fatal error: nghttp2_session_send return: " << nghttp2_strerror(rv) << endl);
return vector<char>();
}
// cout << "nghttp2_session_send, id:" << request.iRequestId << ", buff size:" << session->sendBuffer().size() << endl;
// cout << "nghttp2_session_send, id:" << request.iRequestId << ", buff size:" << session->_buffer().size() << endl;
// if(session->sendBuffer().empty())
// if(session->_buffer().empty())
// {
// exit(0);
// }

View File

@ -616,7 +616,7 @@ void Application::main(const TC_Option &option)
try
{
#if TARS_SSL
SSLManager::GlobalInit();
TC_SSLManager::GlobalInit();
#endif
#if TARGET_PLATFORM_LINUX || TARGET_PLATFORM_IOS
TC_Common::ignorePipe();
@ -840,7 +840,7 @@ void Application::initializeClient()
string key = path + _conf.get("/tars/application/clientssl/<key>");
if (key == path) key.clear();
if (!SSLManager::getInstance()->AddCtx("client", ca, cert, key, false))
if (!TC_SSLManager::getInstance()->addCtx("client", ca, cert, key, false))
cout << "failed add client cert " << ca << endl;
else
cout << "succ add client cert " << ca << endl;
@ -1153,7 +1153,7 @@ void Application::initializeServer()
string key = path + _conf.get("/tars/application/serverssl/<key>");
bool verifyClient = (_conf.get("/tars/application/serverssl/<verifyclient>", "0") == "0") ? false : true;
if (!SSLManager::getInstance()->AddCtx("server", ca, cert, key, verifyClient))
if (!TC_SSLManager::getInstance()->addCtx("server", ca, cert, key, verifyClient))
cout << "failed add server cert " << ca << endl;
else
cout << "succ add server cert " << ca << ", verifyClient " << verifyClient << endl;

View File

@ -105,15 +105,15 @@ bool processAuth(TC_EpollServer::Connection *conn, const shared_ptr<TC_EpollServ
iHeaderLen = htonl((int)(os.getLength()));
sData->buffer().swap(os.getByteBuffer());
sData->buffer()->swap(os.getByteBuffer());
//重写头4个字节
memcpy(sData->buffer().data(), (const char *)&iHeaderLen, sizeof(iHeaderLen));
memcpy(sData->buffer()->buffer(), (const char *)&iHeaderLen, sizeof(iHeaderLen));
}
else
{
sData->buffer().assign(out.begin(), out.end());
sData->buffer()->assign(out.c_str(), out.size());
}
adapter->getEpollServer()->send(sData);

View File

@ -253,7 +253,7 @@ void TarsCurrent::sendResponse(const char* buff, uint32_t len)
{
// _servantHandle->sendResponse(_uid, string(buff, len), _ip, _port, _fd);
shared_ptr<TC_EpollServer::SendContext> send = _data->createSendContext();
send->buffer().assign(buff, buff + len);
send->buffer()->assign(buff, len);
_servantHandle->sendResponse(send);
}
@ -378,13 +378,13 @@ void TarsCurrent::sendResponse(int iRet, const vector<char> &buffer, const map<
response.writeTo(os);
}
os.swap(send->buffer());
assert(send->buffer()->length() >= 4);
assert(send->buffer().size() >= 4);
iHeaderLen = htonl((int)(send->buffer()->length()));
iHeaderLen = htonl((int)(send->buffer().size()));
memcpy(os.getByteBuffer().data(), (const char *)&iHeaderLen, sizeof(iHeaderLen));
memcpy(&send->buffer()[0], (const char *)&iHeaderLen, sizeof(iHeaderLen));
send->buffer()->swap(os.getByteBuffer());
_servantHandle->sendResponse(send);

View File

@ -42,6 +42,7 @@ Transceiver::Transceiver(AdapterProxy * pAdapterProxy,const EndpointInfo &ep)
, _connStatus(eUnconnected)
, _conTimeoutTime(0)
, _authState(AUTH_INIT)
, _sendBuffer(this)
, _recvBuffer(this)
{
_fdInfo.iType = FDInfo::ET_C_NET;
@ -172,7 +173,7 @@ void Transceiver::_onConnect()
if (isSSL())
{
// 分配ssl对象
SSL* ssl = NewSSL("client");
SSL* ssl = TC_SSLManager::getInstance()->newSSL("client");
if (!ssl)
{
ObjectProxy* obj = _adapterProxy->getObjProxy();
@ -183,21 +184,21 @@ void Transceiver::_onConnect()
_openssl.reset(new TC_OpenSSL());
_openssl->Init(ssl, false);
std::string out = _openssl->DoHandshake();
if (_openssl->HasError())
int ret = _openssl->DoHandshake(_sendBuffer);
if (ret != 0)
{
TLOGERROR("[TARS] SSL_connect failed " << endl);
this->close();
return;
}
_sendBuffer.addBuffer(out);
// _sendBuffer.addBuffer(out);
// send the encrypt data from write buffer
if (!out.empty())
if (!_sendBuffer.empty())
{
// this->sendRequest(out.data(), out.size(), true);
this->sendRequest(_sendBuffer);
this->doRequest();
// this->sendRequest(_sendBuffer);
}
return;
}
@ -247,13 +248,13 @@ bool Transceiver::sendAuthData(const BasicAuthInfo& info)
request.iMessageType = kAuthType;
request.sBuffer.assign(out.begin(), out.end());
// vector<char> toSend;
_sendBuffer->addBuffer(objPrx->getProxyProtocol().requestFunc(request, this));
_sendBuffer.addBuffer(objPrx->getProxyProtocol().requestFunc(request, this));
// _sendBuffer.addBuffer(toSend);
// if (sendRequest(toSend.data(), toSend.size(), true) == eRetError)
if (sendRequest(_sendBuffer, true) == eRetError)
// if (sendRequest(_sendBuffer, true) == eRetError)
int ret = doRequest();
if (ret != 0)
{
TLOGERROR("[TARS][Transceiver::setConnected failed sendRequest for Auth\n");
close();
@ -291,7 +292,7 @@ void Transceiver::close()
_fd = -1;
_sendBuffer.reset();
_sendBuffer.clearBuffers();
_recvBuffer.clearBuffers();
@ -324,20 +325,24 @@ int Transceiver::doRequest()
if(!isValid()) return -1;
//buf不为空,先发送buffer的内容
if(_sendBuffer && !_sendBuffer->empty())
while(!_sendBuffer.empty())
{
int iRet = this->send(_sendBuffer->buffer(), (uint32_t) _sendBuffer->length(), 0);
auto data = _sendBuffer.getBufferPointer();
assert(data.first != NULL && data.second != 0);
int iRet = this->send(data.first, (uint32_t) data.second, 0);
if (iRet < 0)
{
return -1;
}
_sendBuffer->add(iRet);
_sendBuffer.moveHeader(iRet);
// _sendBuffer->add(iRet);
}
//取adapter里面积攒的数据
if(!_sendBuffer || _sendBuffer->empty()) {
if(_sendBuffer.empty()) {
_adapterProxy->doInvoke();
}
@ -347,7 +352,7 @@ int Transceiver::doRequest()
return 0;
}
int Transceiver::sendRequest(const shared_ptr<TC_NetWorkBuffer::SendBuffer> &buff, bool forceSend)
int Transceiver::sendRequest(const shared_ptr<TC_NetWorkBuffer::Buffer> &buff, bool forceSend)
{
//空数据 直接返回成功
if(buff->empty())
@ -369,9 +374,9 @@ int Transceiver::sendRequest(const shared_ptr<TC_NetWorkBuffer::SendBuffer> &buf
return eRetError; // 需要鉴权但还没通过,不能发送非认证消息
}
//buf不为空,直接返回失败
//buf不为空, 表示之前的数据还没发送完, 直接返回失败
//等buffer可写了,epoll会通知写事件
if(_sendBuffer && !_sendBuffer->empty())
if(!_sendBuffer.empty())
return eRetError;
int iRet = this->send(buff->buffer(), (uint32_t)buff->length(), 0);
@ -383,15 +388,17 @@ int Transceiver::sendRequest(const shared_ptr<TC_NetWorkBuffer::SendBuffer> &buf
//没有全部发送完,写buffer 返回成功
if(iRet < (int)buff->length())
{
_sendBuffer = buff;
_sendBuffer->add(iRet);
buff->add(iRet);
_sendBuffer.addBuffer(buff);
// _sendBuffer = buff;
// _sendBuffer->add(iRet);
return eRetFull;
}
else
{
//全部发送完毕了
_sendBuffer.reset();
}
// else
// {
// //全部发送完毕了
// _sendBuffer.reset();
// }
return eRetOk;

View File

@ -233,7 +233,7 @@ struct ReqMessage : public TC_HandleBase
response = std::make_shared<ResponsePacket>();
// sReqData.clear();
sReqData = std::make_shared<TC_NetWorkBuffer::SendBuffer>();
sReqData = std::make_shared<TC_NetWorkBuffer::Buffer>();
pMonitor = NULL;
bMonitorFin = false;
@ -264,10 +264,9 @@ struct ReqMessage : public TC_HandleBase
ObjectProxy * pObjectProxy; //调用端的proxy对象
RequestPacket request; //请求消息体
// ResponsePacket response; //响应消息体
shared_ptr<ResponsePacket> response; //响应消息体
// string sReqData; //请求消息体
shared_ptr<TC_NetWorkBuffer::SendBuffer> sReqData; //请求消息体
shared_ptr<TC_NetWorkBuffer::Buffer> sReqData; //请求消息体
ReqMonitor * pMonitor; //用于同步的monitor
bool bMonitorFin; //同步请求timewait是否结束

View File

@ -119,7 +119,7 @@ public:
* fd缓冲区已满,
* ,
*/
int sendRequest(const shared_ptr<TC_NetWorkBuffer::SendBuffer> &pData, bool forceSend = false);
int sendRequest(const shared_ptr<TC_NetWorkBuffer::Buffer> &pData, bool forceSend = false);
/*
* Send BufferCache是否有完整的包
@ -285,7 +285,9 @@ protected:
/*
* buffer
*/
shared_ptr<TC_NetWorkBuffer::SendBuffer> _sendBuffer;
// shared_ptr<TC_NetWorkBuffer::Buffer> _sendBuffer;
TC_NetWorkBuffer _sendBuffer;
/*
* buffer

View File

@ -123,39 +123,6 @@ public:
class Handle;
typedef TC_AutoPtr<Handle> HandlePtr;
// class HandleGroup;
// typedef TC_AutoPtr<HandleGroup> HandleGroupPtr;
// ////////////////////////////////////////////////////////////////////////////
// /**定义数据队列中的结构*/
// struct tagRecvData
// {
// uint32_t uid; /**连接标示*/
// string buffer; /**需要发送的内容*/
// string ip; /**远程连接的ip*/
// uint16_t port; /**远程连接的端口*/
// int64_t recvTimeStamp; /**接收到数据的时间*/
// bool isOverload; /**是否已过载 */
// bool isClosed; /**是否已关闭*/
// int fd; /*保存产生该消息的fd用于回包时选择网络线程*/
// BindAdapterPtr adapter; /**标识哪一个adapter的消息*/
// int closeType; /*如果是关闭消息包,则标识关闭类型,0:表示客户端主动关闭1:服务端主动关闭;2:连接超时服务端主动关闭*/
// };
// struct tagSendData
// {
// char cmd; /**命令:'c',关闭fd; 's',有数据需要发送*/
// uint32_t uid; /**连接标示*/
// string buffer; /**需要发送的内容*/
// string ip; /**远程连接的ip*/
// uint16_t port; /**远程连接的端口*/
// };
// typedef TC_ThreadQueue<tagRecvData*, deque<tagRecvData*> > recv_queue;
// typedef TC_ThreadQueue<tagSendData*, deque<tagSendData*> > send_queue;
// typedef recv_queue::queue_type recv_queue_type;
class RecvContext;
/**
*
@ -167,8 +134,8 @@ public:
SendContext(const shared_ptr<RecvContext> &context, char cmd) : _context(context), _cmd(cmd) {}
const shared_ptr<RecvContext> &getRecvContext() { return _context; }
vector<char> &buffer() { return _sbuffer; }
const vector<char> &buffer() const { return _sbuffer; }
const shared_ptr<TC_NetWorkBuffer::Buffer> & buffer() { return _sbuffer; }
// const vector<char> &buffer() const { return _sbuffer; }
char cmd() const { return _cmd; }
uint32_t uid() const { return _context->uid(); }
int fd() const { return _context->fd(); }
@ -180,7 +147,7 @@ public:
protected:
shared_ptr<RecvContext> _context;
char _cmd; /**send包才有效, 命令:'c',关闭fd; 's',有数据需要发送*/
vector<char> _sbuffer; /**发送的内容*/
shared_ptr<TC_NetWorkBuffer::Buffer> _sbuffer; /**发送的内容*/
};
////////////////////////////////////////////////////////////////////////////
@ -244,19 +211,6 @@ public:
int iLastRefreshTime;
};
////////////////////////////////////////////////////////////////////////////
/**
* name对handle分组
* handle处理一个或多个Adapter消息
* handle对象一个线程
*/
// struct HandleGroup : public TC_HandleBase
// {
// string name;
// TC_ThreadLock monitor;
// vector<HandlePtr> handles;
// map<string, BindAdapterPtr> adapters;
// };
////////////////////////////////////////////////////////////////////////////
/**
* @brief
*
@ -289,18 +243,6 @@ public:
*/
TC_EpollServer* getEpollServer();
// /**
// * 设置所属的Group
// * @param pHandleGroup
// */
// void setHandleGroup(HandleGroupPtr& pHandleGroup);
// /**
// * 获取所属Group
// * @return HandleGroup*
// */
// HandleGroupPtr& getHandleGroup();
/**
* Handle的索引(0~handle个数-1)
* @return
@ -334,14 +276,12 @@ public:
* @param stRecvData
* @param sSendBuffer
*/
// void sendResponse(unsigned int uid, const string &sSendBuffer, const string &ip, int port, int fd);
void sendResponse(const shared_ptr<SendContext> &data);
/**
*
* @param stRecvData
*/
// void close(unsigned int uid, int fd);
void close(const shared_ptr<RecvContext> &data);
/**
@ -371,36 +311,6 @@ public:
*/
virtual void handleImp();
// /**
// * 处理函数
// * @param stRecvData: 接收到的数据
// */
// virtual void handle(const tagRecvData &stRecvData) = 0;
// /**
// * 处理超时数据, 即数据在队列中的时间已经超过
// * 默认直接关闭连接
// * @param stRecvData: 接收到的数据
// */
// virtual void handleTimeout(const tagRecvData &stRecvData);
// /**
// * 处理连接关闭通知,包括
// * 1.close by peer
// * 2.recv/send fail
// * 3.close by timeout or overload
// * @param stRecvData:
// */
// virtual void handleClose(const tagRecvData &stRecvData);
// /**
// * 处理overload数据 即数据队列中长度已经超过允许值
// * 默认直接关闭连接
// * @param stRecvData: 接收到的数据
// */
// virtual void handleOverload(const tagRecvData &stRecvData);
/**
*
* @param stRecvData:
@ -466,13 +376,6 @@ public:
*/
virtual bool allFilterIsEmpty();
// /**
// * 设置服务
// * @param pEpollServer
// */
// void setEpollServer(TC_EpollServer *pEpollServer);
/**
* Adapter
* @param pEpollServer
@ -840,26 +743,6 @@ public:
*/
bool waitForRecvQueue(uint32_t handleIndex, shared_ptr<RecvContext> &recv);
// /**
// * 增加数据到队列中
// * @param vtRecvData
// * @param bPushBack 后端插入
// * @param sBuffer
// */
// void insertRecvQueue(const recv_queue::queue_type &vtRecvData,bool bPushBack = true);
// /**
// * 通知等待在接收队列上面的线程醒过来
// */
// void notifyRecvQueue();
// /**
// * 等待数据
// * @return bool
// */
// bool waitForRecvQueue(tagRecvData* &recv, uint32_t iWaitTime);
/**
*
* @return size_t
@ -882,49 +765,11 @@ public:
*/
static TC_NetWorkBuffer::PACKET_TYPE echo_header_filter(TC_NetWorkBuffer::PACKET_TYPE i, vector<char> &o);
// /**
// * 默认的协议解析类, 直接echo
// * @param r
// * @param o
// * @return int
// */
// static int echo_protocol(string &r, string &o);
// /**
// * 默认的包头处理
// * @param i
// * @param o
// * @return int
// */
// static int echo_header_filter(int i, string &o);
/**
*
*/
int getHeaderFilterLen();
/**
* handle组名
* @param handleGroupName
*/
// void setHandleGroupName(const string& handleGroupName);
/**
* handle组名
* @return string
*/
// string getHandleGroupName() const;
/**
* handle
* @return HandleGroupPtr
*/
// HandleGroupPtr getHandleGroup() const
// {
// return _handleGroup;
// }
/**
* ServantHandle数目
* @param n
@ -937,16 +782,6 @@ public:
*/
int getHandleNum();
// /**
// * 绑定两个Adapter到同一个Group
// * @param otherAdapter
// */
// void setHandle(BindAdapterPtr& otherAdapter)
// {
// _pEpollServer->setHandleGroup(otherAdapter->getHandleGroupName(), this);
// }
/**
* 线,线
*/
@ -1252,11 +1087,6 @@ public:
*/
Connection(BindAdapter *pBindAdapter, int fd);
/**
*
*/
// Connection(BindAdapter *pBindAdapter);
/**
*
*/
@ -1266,7 +1096,6 @@ public:
* adapter
*/
BindAdapterPtr& getBindAdapter() { return _pBindAdapter; }
// BindAdapter* getBindAdapter() { return _pBindAdapter; }
/**
*
@ -1338,6 +1167,9 @@ public:
*/
EnumConnectionType getType() const { return _enType; }
/**
*
*/
bool isEmptyConn() const {return _bEmptyConn;}
/**
@ -1345,6 +1177,8 @@ public:
*/
void tryInitAuthState(int initState);
friend class NetThread;
protected:
/**
@ -1397,14 +1231,13 @@ public:
* @param o
* @return int: <0:, 0:, 1:
*/
// int parseProtocol(recv_queue::queue_type &o);
int parseProtocol();
/**
*
* @param vtRecvData
*/
void insertRecvQueue(const shared_ptr<RecvContext> &recv);//recv_queue::queue_type &vRecvData);
void insertRecvQueue(const shared_ptr<RecvContext> &recv);
/**
* udp方式的连接
@ -1418,28 +1251,6 @@ public:
*/
bool isTcp() const { return _lfd != -1; }
friend class NetThread;
// private:
// /**
// * tcp发送数据
// */
// int tcpSend(const void* data, size_t len);
// int tcpWriteV(const std::vector<iovec>& buffers);
// /**
// * 清空buffer-slices
// * @param slices
// */
// void clearSlices(std::vector<TC_Slice>& slices);
// /**
// * 整理buffer-slices
// * @param slices
// * @param toSkippedBytes
// */
// void adjustSlices(std::vector<TC_Slice>& slices, size_t toSkippedBytes);
public:
/**
*
@ -1452,7 +1263,6 @@ public:
*
*/
BindAdapterPtr _pBindAdapter;
// BindAdapter *_pBindAdapter;
/**
* TC_Socket
@ -1487,12 +1297,12 @@ public:
/**
* buffer
*/
TC_NetWorkBuffer _recvbuffer;
TC_NetWorkBuffer _recvBuffer;
/**
* buffer
*/
TC_NetWorkBuffer _sendbuffer;
TC_NetWorkBuffer _sendBuffer;
/**
*
@ -2105,20 +1915,6 @@ public:
*/
void send(const shared_ptr<SendContext> &data);
// /**
// * 关闭连接
// * @param uid
// */
// void close(unsigned int uid, int fd);
// /**
// * 发送数据
// * @param uid
// * @param s
// */
// void send(unsigned int uid, const string &s, const string &ip, uint16_t port, int fd);
/**
*
* @param lfd
@ -2134,13 +1930,6 @@ public:
*/
unordered_map<int, BindAdapterPtr> getListenSocketInfo();
// /**
// * 获取监听socket信息
// *
// * @return map<int,ListenSocket>
// */
// map<int, BindAdapterPtr> getListenSocketInfo();
/**
*
*

View File

@ -50,76 +50,64 @@ public:
/**
* buffer
*/
class SendBuffer
class Buffer
{
protected:
vector<char> sendBuffer;
uint32_t sendPos = 0;
public:
SendBuffer() { }
SendBuffer(const vector<char> &sBuffer) : sendBuffer(sBuffer) {}
SendBuffer(const char *sBuffer, size_t length) : sendBuffer(sBuffer, sBuffer+length) {}
Buffer() { }
Buffer(const vector<char> &sBuffer) : _buffer(sBuffer) {}
Buffer(const char *sBuffer, size_t length) : _buffer(sBuffer, sBuffer+length) {}
void swap(vector<char> &buff)
void swap(vector<char> &buff, size_t pos = 0)
{
sendPos = 0;
buff.swap(sendBuffer);
_pos = pos;
buff.swap(_buffer);
}
void clear()
{
sendBuffer.clear();
sendPos = 0;
_buffer.clear();
_pos = 0;
}
bool empty() const
{
return sendBuffer.size() <= sendPos;
return _buffer.size() <= _pos;
}
void addBuffer(const vector<char> &buffer)
{
sendBuffer.insert(sendBuffer.end(), buffer.begin(), buffer.end());// += buffer;
_buffer.insert(_buffer.end(), buffer.begin(), buffer.end());
}
void assign(const char *buffer, size_t length)
void assign(const char *buffer, size_t length, size_t pos = 0)
{
sendBuffer.assign(buffer, buffer + length);
sendPos = 0;
}
vector<char> &getBuffer()
{
return sendBuffer;
_buffer.assign(buffer, buffer + length);
_pos = pos;
}
void setBuffer(const vector<char> &buff, int pos = 0)
{
sendBuffer = buff;
sendPos = pos;
_buffer = buff;
_pos = pos;
}
char *buffer()
{
return sendBuffer.data() + sendPos;
}
char *buffer() { return _buffer.data() + _pos; }
const char *buffer() const
{
return sendBuffer.data() + sendPos;
}
const char *buffer() const { return _buffer.data() + _pos; }
uint32_t length() const
{
return (uint32_t)(sendBuffer.size() - sendPos);
}
size_t length() const { return _buffer.size() - _pos; }
void add(uint32_t ret)
{
sendPos += ret;
assert(sendPos <= sendBuffer.size());
_pos += ret;
assert(_pos <= _buffer.size());
}
protected:
vector<char> _buffer;
uint32_t _pos = 0;
};
/**
@ -162,13 +150,13 @@ public:
* buffer
* @param buff
*/
void addBuffer(const vector<char>& buff);
void addBuffer(const shared_ptr<Buffer> & buff);
/**
* add & swap, copy
* buffer
* @param buff
*/
void addSwapBuffer(vector<char>& buff);
void addBuffer(const vector<char>& buff);
/**
* buffer
@ -201,9 +189,9 @@ public:
/**
* buffer拼接起来
* @return string
* @return const char *, buffer的指针, NULL
*/
void mergeBuffers();
const char * mergeBuffers();
/**
* buffer(buffer拼接起来, )
@ -345,6 +333,9 @@ public:
static TC_NetWorkBuffer::PACKET_TYPE parseEcho(TC_NetWorkBuffer&in, vector<char> &out);
protected:
void getBuffers(char *buffer, size_t length) const;
template<typename T>
T getValue() const
{
@ -420,17 +411,13 @@ protected:
/**
* buffer list
*/
std::list<std::vector<char>> _bufferList;
std::list<std::shared_ptr<Buffer>> _bufferList;
/**
* buffer剩余没解析的字节总数
*/
size_t _length = 0;
/**
* buffer的位置
*/
size_t _pos = 0;
};
}

View File

@ -20,7 +20,9 @@
#if TARS_SSL
#include <string>
#include "tc_sslmgr.h"
#include <vector>
#include "util/tc_network_buffer.h"
#include "util/tc_sslmgr.h"
struct ssl_st;
typedef struct ssl_st SSL;
@ -45,19 +47,17 @@ public:
/**
* @brief .
*/
TC_OpenSSL() :
_ssl(NULL),
_bHandshaked(false),
_isServer(false),
_err(0)
{
}
TC_OpenSSL();
/**
* @brief .
*/
~TC_OpenSSL();
// static SSL* newSSL(const std::string& ctxName);
static void getMemData(BIO* bio, TC_NetWorkBuffer& buf);
// static void getSSLHead(const char* data, char& type, unsigned short& ver, unsigned short& len)
static int doSSLRead(SSL* ssl, TC_NetWorkBuffer& out);
private:
/**
* @brief
@ -83,22 +83,16 @@ public:
*/
bool IsHandshaked() const;
/**
* @brief
* @return
*/
bool HasError() const;
/**
* @brief
*/
string* RecvBuffer();
TC_NetWorkBuffer * RecvBuffer() { return &_plainBuf; }
/**
* @brief
* @return
*/
std::string DoHandshake(const void* data = NULL, size_t size = 0);
int DoHandshake(TC_NetWorkBuffer &out, const void* data = NULL, size_t size = 0);
/**
* @brief
@ -106,7 +100,7 @@ public:
* @param size
* @return
*/
std::string Write(const void* data, size_t size);
int Write(const char* data, size_t size, TC_NetWorkBuffer &out);
/**
* @brief
@ -115,7 +109,7 @@ public:
* @param out
* @return
*/
bool Read(const void* data, size_t size, std::string& out);
int Read(const void* data, size_t size, TC_NetWorkBuffer &out);
private:
/**
@ -136,11 +130,7 @@ private:
/**
*
*/
std::string _plainBuf;
/**
* errno
*/
int _err;
TC_NetWorkBuffer _plainBuf;
};
} // end namespace tars

View File

@ -21,7 +21,8 @@
#include <map>
#include <string>
#include "util/tc_buffer.h"
// #include "util/tc_buffer.h"
#include "util/tc_network_buffer.h"
#include "util/tc_singleton.h"
struct bio_st;
@ -49,29 +50,31 @@ namespace tars
static const size_t kSSLHeadSize = 5;
// new ssl conn
SSL* NewSSL(const std::string& ctxName);
// fetch data from mem bio
void GetMemData(BIO* bio, TC_Buffer& buf);
// void GetMemData(BIO* bio, TC_NetWorkBuffer& buf);
// fetch ssl head info
void GetSSLHead(const char* data, char& type, unsigned short& ver, unsigned short& len);
// void GetSSLHead(const char* data, char& type, unsigned short& ver, unsigned short& len);
// read from ssl
bool DoSSLRead(SSL*, std::string& out);
// bool DoSSLRead(SSL*, std::string& out);
class SSLManager : public TC_Singleton<SSLManager>
class TC_SSLManager : public TC_Singleton<TC_SSLManager>
{
public:
static void GlobalInit();
SSLManager();
~SSLManager();
TC_SSLManager();
bool AddCtx(const std::string& name,
~TC_SSLManager();
SSL* newSSL(const std::string& ctxName);
bool addCtx(const std::string& name,
const std::string& cafile,
const std::string& certfile,
const std::string& keyfile,
bool verifyClient);
SSL_CTX* GetCtx(const std::string& name) const;
SSL_CTX* getCtx(const std::string& name) const;
private:

View File

@ -701,8 +701,8 @@ TC_EpollServer::Connection::Connection(TC_EpollServer::BindAdapter *pBindAdapter
, _timeout(timeout)
, _ip(ip)
, _port(port)
, _recvbuffer(this)
, _sendbuffer(this)
, _recvBuffer(this)
, _sendBuffer(this)
, _iHeaderLen(0)
, _bClose(false)
, _enType(EM_TCP)
@ -720,8 +720,8 @@ TC_EpollServer::Connection::Connection(BindAdapter *pBindAdapter, int fd)
, _lfd(-1)
, _timeout(2)
, _port(0)
, _recvbuffer(this)
, _sendbuffer(this)
, _recvBuffer(this)
, _sendBuffer(this)
, _iHeaderLen(0)
, _bClose(false)
, _enType(EM_UDP)
@ -740,7 +740,7 @@ TC_EpollServer::Connection::~Connection()
_pRecvBuffer = NULL;
}
// clearSlices(_sendbuffer);
// clearSlices(_sendBuffer);
if (isTcp())
{
assert(!_sock.isValid());
@ -808,59 +808,63 @@ int TC_EpollServer::Connection::parseProtocol()
{
try
{
while (!_recvbuffer.empty())
while (!_recvBuffer.empty())
{
//需要过滤首包包头
if(_iHeaderLen > 0)
{
if(_recvbuffer.getBufferLength() >= (unsigned) _iHeaderLen)
if(_recvBuffer.getBufferLength() >= (unsigned) _iHeaderLen)
{
vector<char> header;
_recvbuffer.getHeader(_iHeaderLen, header);
_recvBuffer.getHeader(_iHeaderLen, header);
_pBindAdapter->getHeaderFilterFunctor()(TC_NetWorkBuffer::PACKET_FULL, header);
_recvbuffer.moveHeader(_iHeaderLen);
_recvBuffer.moveHeader(_iHeaderLen);
_iHeaderLen = 0;
}
else
{
vector<char> header = _recvbuffer.getBuffers();
vector<char> header = _recvBuffer.getBuffers();
_pBindAdapter->getHeaderFilterFunctor()(TC_NetWorkBuffer::PACKET_LESS, header);
_iHeaderLen -= (int)_recvbuffer.getBufferLength();
_recvbuffer.clearBuffers();
_iHeaderLen -= (int)_recvBuffer.getBufferLength();
_recvBuffer.clearBuffers();
break;
}
}
// std::string* rbuf = &_recvbuffer;
TC_NetWorkBuffer *rbuf = &_recvBuffer;
#if TARS_SSL
// ssl connection
if (_pBindAdapter->getEndpoint().isSSL())
{
char buffer[BUFFER_SIZE] = {0x00};
std::string out;
// if (!_openssl->Read(_recvbuffer.data(), _recvbuffer.size(), out))
if (!_openssl->Read(buffer, BUFFER_SIZE, out))
const char * data = _recvBuffer.mergeBuffers();
// std::string out;
int ret = _openssl->Read(data, _recvBuffer.getBufferLength(), _sendBuffer);
if (ret != 0)
// if (!_openssl->Read(buffer, BUFFER_SIZE, out))
{
_pBindAdapter->getEpollServer()->error("[TARS][SSL_read failed");
return -1;
}
else
{
if (!out.empty())
this->send(out, "", 0);
// rbuf = _openssl->RecvBuffer();
if (!_sendBuffer.empty())
{
this->sendBuffer();
}
_recvbuffer.clearBuffers();
// _recvbuffer.clear();
rbuf = _openssl->RecvBuffer();
}
_recvBuffer.clearBuffers();
// _recvBuffer.clear();
}
#endif
// string ro;
vector<char> ro;
TC_NetWorkBuffer::PACKET_TYPE b = _pBindAdapter->getProtocol()(_recvbuffer, ro);
TC_NetWorkBuffer::PACKET_TYPE b = _pBindAdapter->getProtocol()(*rbuf, ro);
if(b == TC_NetWorkBuffer::PACKET_LESS)
{
//包不完全
@ -937,10 +941,10 @@ int TC_EpollServer::Connection::recvTcp()
else
{
totalRecv += iBytesReceived;
_recvbuffer.addBuffer(buffer, iBytesReceived);
_recvBuffer.addBuffer(buffer, iBytesReceived);
//字符串太长时, 强制解析协议
if (_recvbuffer.getBufferLength() > 8192) {
if (_recvBuffer.getBufferLength() > 8192) {
parseProtocol();
}
@ -996,7 +1000,7 @@ int TC_EpollServer::Connection::recvUdp()
if (_pBindAdapter->isIpAllow(_ip) == true)
{
//保存接收到数据
_recvbuffer.addBuffer(_pRecvBuffer, iBytesReceived);
_recvBuffer.addBuffer(_pRecvBuffer, iBytesReceived);
parseProtocol();
}
@ -1005,7 +1009,7 @@ int TC_EpollServer::Connection::recvUdp()
//udp ip无权限
_pBindAdapter->getEpollServer()->debug( "accept [" + _ip + ":" + TC_Common::tostr(_port) + "] [" + TC_Common::tostr(_lfd) + "] not allowed");
}
_recvbuffer.clearBuffers();
_recvBuffer.clearBuffers();
if(++recvCount > 100)
{
@ -1026,9 +1030,9 @@ int TC_EpollServer::Connection::recv()
int TC_EpollServer::Connection::sendBuffer()
{
while(!_sendbuffer.empty())
while(!_sendBuffer.empty())
{
pair<const char*, size_t> data = _sendbuffer.getBufferPointer();
pair<const char*, size_t> data = _sendBuffer.getBufferPointer();
assert(data.first != NULL);
@ -1049,7 +1053,7 @@ int TC_EpollServer::Connection::sendBuffer()
if(iBytesSent > 0)
{
_sendbuffer.moveHeader(iBytesSent);
_sendBuffer.moveHeader(iBytesSent);
}
//发送的数据小于需要发送的,break, 内核会再通知你的
@ -1060,7 +1064,7 @@ int TC_EpollServer::Connection::sendBuffer()
}
//需要关闭链接
if (_bClose && _sendbuffer.empty())
if (_bClose && _sendBuffer.empty())
{
_pBindAdapter->getEpollServer()->debug("send [" + _ip + ":" + TC_Common::tostr(_port) + "] close connection by user.");
return -2;
@ -1073,8 +1077,8 @@ int TC_EpollServer::Connection::sendBuffer()
int TC_EpollServer::Connection::sendTcp(const shared_ptr<SendContext> &sc)
{
//tcp的, 将buffer放到队列末尾
if (!sc->buffer().empty()) {
_sendbuffer.addSwapBuffer(sc->buffer());
if (!sc->buffer()->empty()) {
_sendBuffer.addBuffer(sc->buffer());
}
return sendBuffer();
@ -1083,7 +1087,7 @@ int TC_EpollServer::Connection::sendTcp(const shared_ptr<SendContext> &sc)
int TC_EpollServer::Connection::sendUdp(const shared_ptr<SendContext> &sc)
{
//udp的直接发送即可
int iRet = _sock.sendto((const void *) sc->buffer().data(), sc->buffer().size(), sc->ip(), sc->port(), 0);
int iRet = _sock.sendto((const void *) sc->buffer()->buffer(), sc->buffer()->length(), sc->ip(), sc->port(), 0);
if (iRet < 0)
{
_pBindAdapter->getEpollServer()->error("[TC_EpollServer::Connection] send [" + _ip + ":" + TC_Common::tostr(_port) + "] error");
@ -1116,7 +1120,7 @@ bool TC_EpollServer::Connection::setRecvBuffer(size_t nSize)
bool TC_EpollServer::Connection::setClose()
{
_bClose = true;
if (_sendbuffer.empty())
if (_sendBuffer.empty())
return true;
else
return false;
@ -1470,28 +1474,31 @@ void TC_EpollServer::NetThread::addTcpConnection(TC_EpollServer::Connection *cPt
cPtr->getBindAdapter()->getEpollServer()->info("[TARS][addTcpConnection ssl connection");
// 分配ssl对象, ctxName 放在obj proxy里
SSL* ssl = NewSSL("server");
SSL* ssl = TC_SSLManager::getInstance()->newSSL("server");
if (!ssl)
{
cPtr->getBindAdapter()->getEpollServer()->error("[TARS][SSL_accept not find server cert");
this->close(uid);
cPtr->close();
// this->close(uid);
return;
}
cPtr->_openssl.reset(new TC_OpenSSL());
cPtr->_openssl->Init(ssl, true);
std::string out = cPtr->_openssl->DoHandshake();
if (cPtr->_openssl->HasError())
int ret = cPtr->_openssl->DoHandshake(cPtr->_sendBuffer);
if (ret != 0)
{
cPtr->getBindAdapter()->getEpollServer()->error("[TARS][SSL_accept error: " + cPtr->getBindAdapter()->getEndpoint().toString());
this->close(uid);
cPtr->close();
// this->close(uid);
return;
}
// send the encrypt data from write buffer
if (!out.empty())
if (!cPtr->_sendBuffer.empty())
{
this->sendBuffer(cPtr, out, "", 0);
cPtr->sendBuffer();
}
}
#endif
@ -1621,18 +1628,30 @@ void TC_EpollServer::NetThread::processPipe()
}
case 's':
{
int ret = 0;
#if TARS_SSL
if (cPtr->getBindAdapter()->getEndpoint().isSSL() && cPtr->_openssl->IsHandshaked())
{
std::string out = cPtr->_openssl->Write((*it)->buffer.data(), (*it)->buffer.size());
if (cPtr->_openssl->HasError())
// std::string out = cPtr->_openssl->Write((*it)->buffer.data(), (*it)->buffer.size());
// if (cPtr->_openssl->HasError())
// break; // should not happen
//
// (*it)->buffer = out;
ret = cPtr->_openssl->Write(sc->buffer()->buffer(), sc->buffer()->length(), cPtr->_sendBuffer);
if (ret != 0)
break; // should not happen
(*it)->buffer = out;
cPtr->sendBuffer();
// (*it)->buffer = out;
}
else
{
ret = cPtr->send(sc);
}
#else
ret = cPtr->send(sc);
#endif
int ret = cPtr->send(sc);
if(ret < 0)
{
delConnection(cPtr,true,(ret==-1)?EM_CLIENT_CLOSE:EM_SERVER_CLOSE);

View File

@ -683,7 +683,7 @@ int TC_Http2Client::settings(unsigned int maxCurrentStreams)
// }
// }
// string& TC_Http2Client::sendBuffer()
// string& TC_Http2Client::_buffer()
// {
// return _sendBuf;
// }

View File

@ -9,17 +9,23 @@ using namespace std;
namespace tars
{
void TC_NetWorkBuffer::addSwapBuffer(vector<char>& buff)
{
_length += buff.size();
//void TC_NetWorkBuffer::addSwapBuffer(vector<char>& buff)
//{
// _length += buff.size();
//
// _bufferList.push_back(std::make_shared<Buffervector<char>());
// _bufferList.back().swap(buff);
//}
_bufferList.push_back(vector<char>());
_bufferList.back().swap(buff);
void TC_NetWorkBuffer::addBuffer(const shared_ptr<TC_NetWorkBuffer::Buffer> & buff)
{
_bufferList.push_back(buff);
_length += buff->length();
}
void TC_NetWorkBuffer::addBuffer(const vector<char>& buff)
{
_bufferList.push_back(buff);
_bufferList.push_back(std::make_shared<Buffer>(buff));
_length += buff.size();
}
@ -33,7 +39,6 @@ void TC_NetWorkBuffer::clearBuffers()
{
_bufferList.clear();
_length = 0;
_pos = 0;
}
bool TC_NetWorkBuffer::empty() const
@ -55,22 +60,52 @@ pair<const char*, size_t> TC_NetWorkBuffer::getBufferPointer() const
auto it = _bufferList.begin();
return make_pair(it->data() + _pos, it->size() - _pos);
return make_pair((*it)->buffer(), (*it)->length());
}
void TC_NetWorkBuffer::mergeBuffers()
const char * TC_NetWorkBuffer::mergeBuffers()
{
//merge to one buffer
if(_bufferList.size() > 1)
{
vector<char> buffer = getBuffers();
_pos = 0;
_bufferList.clear();
_bufferList.push_back(buffer);
addBuffer(buffer);
// _bufferList.push_back(buffer);
}
assert(_bufferList.size() <= 1);
if(!_bufferList.empty())
{
return (*_bufferList.begin())->buffer();
}
return NULL;
}
void TC_NetWorkBuffer::getBuffers(char *buffer, size_t length) const
{
assert(length <= getBufferLength());
auto it = _bufferList.begin();
size_t left = length;
size_t pos = 0;
while(it != _bufferList.end() || left == 0)
{
size_t len = std::min(left, (*it)->length());
memcpy(buffer + pos, (*it)->buffer(), len);
left -= len;
pos += len;
++it;
}
}
string TC_NetWorkBuffer::getBuffersString() const
@ -78,23 +113,7 @@ string TC_NetWorkBuffer::getBuffersString() const
string buffer;
buffer.resize(_length);
auto it = _bufferList.begin();
size_t pos = 0;
while(it != _bufferList.end())
{
if(it == _bufferList.begin())
{
memcpy(&buffer[pos], it->data() + _pos, it->size() - _pos);
pos += it->size() - _pos;
}
else
{
memcpy(&buffer[pos], it->data(), it->size());
pos += it->size();
}
++it;
}
getBuffers(&buffer[0], _length);
return buffer;
}
@ -104,23 +123,7 @@ vector<char> TC_NetWorkBuffer::getBuffers() const
vector<char> buffer;
buffer.resize(_length);
auto it = _bufferList.begin();
size_t pos = 0;
while(it != _bufferList.end())
{
if(it == _bufferList.begin())
{
memcpy(&buffer[pos], it->data() + _pos, it->size() - _pos);
pos += it->size() - _pos;
}
else
{
memcpy(&buffer[pos], it->data(), it->size());
pos += it->size();
}
++it;
}
getBuffers(&buffer[0], _length);
return buffer;
}
@ -138,31 +141,32 @@ bool TC_NetWorkBuffer::getHeader(size_t len, std::string &buffer) const
}
buffer.reserve(len);
auto it = _bufferList.begin();
size_t left = len;
size_t cur = _pos;
getBuffers(&buffer[0], len);
//
// auto it = _bufferList.begin();
//
// size_t left = len;
//
// while(it != _bufferList.end())
// {
// if((*it)->length() >= left)
// {
// //当前buffer足够
// buffer.append((*it)->buffer(), left);
// return true;
// }
// else
// {
// //当前buffer不够
// buffer.append((*it)->buffer(), (*it)->length());
// left = left - (*it)->length();
// }
//
// ++it;
// }
while(it != _bufferList.end())
{
if(it->size() - cur >= left)
{
//当前buffer足够
buffer.append(it->data() + cur, left);
return true;
}
else
{
//当前buffer不够
buffer.append(it->data() + cur, it->size() - cur);
left = left - (it->size() - cur);
cur = 0;
}
++it;
}
assert(buffer.length() == len);
// assert(buffer.length() == len);
return true;
}
@ -181,34 +185,35 @@ bool TC_NetWorkBuffer::getHeader(size_t len, std::vector<char> &buffer) const
buffer.reserve(len);
auto it = _bufferList.begin();
size_t left = len;
size_t cur = _pos;
while(it != _bufferList.end())
{
if(it->size() - cur >= left)
{
//当前buffer足够
buffer.insert(buffer.end(), it->data() + cur, it->data() + cur + left);
return true;
}
else
{
//当前buffer不够
buffer.insert(buffer.end(), it->data() + cur, it->data() + it->size());
left = left - (it->size() - cur);
cur = 0;
}
++it;
}
assert(buffer.size() == len);
getBuffers(&buffer[0], len);
//
// auto it = _bufferList.begin();
//
// size_t left = len;
//
// while(it != _bufferList.end())
// {
// if((*it)->length() >= left)
// {
// //当前buffer足够
// buffer.insert(buffer.end(), (*it)->buffer(), (*it)->buffer() + left);
// return true;
// }
// else
// {
// //当前buffer不够
// buffer.insert(buffer.end(), (*it)->buffer(), (*it)->buffer() + (*it)->length());
// left = left - (*it)->length();
// }
//
// ++it;
// }
//
// assert(buffer.size() == len);
return true;
}
bool TC_NetWorkBuffer::moveHeader(size_t len)
{
if(getBufferLength() < len)
@ -219,24 +224,22 @@ bool TC_NetWorkBuffer::moveHeader(size_t len)
auto it = _bufferList.begin();
assert(it->size() >= _pos);
// assert(it->size() >= _pos);
size_t left = it->size() - _pos;
size_t left = (*it)->length();
if(left > len)
{
_pos += len;
(*it)->add(len);
_length -= len;
}
else if(left == len)
{
_pos = 0;
_length -= len;
_bufferList.erase(it);
}
else
{
_pos = 0;
_length -= left;
_bufferList.erase(it);
@ -308,6 +311,7 @@ TC_NetWorkBuffer::PACKET_TYPE TC_NetWorkBuffer::parseHttp(TC_NetWorkBuffer&in, v
if (b == PACKET_FULL)
{
out = in.getBuffers();
in.clearBuffers();
}
return b;
@ -318,12 +322,8 @@ TC_NetWorkBuffer::PACKET_TYPE TC_NetWorkBuffer::parseEcho(TC_NetWorkBuffer&in, v
{
try
{
if(in.empty())
{
return PACKET_LESS;
}
out = in.getBuffers();
in.clearBuffers();
return TC_NetWorkBuffer::PACKET_FULL;
}
catch (exception &ex)

View File

@ -20,12 +20,20 @@
#include <openssl/err.h>
#include "util/tc_openssl.h"
#include "util/tc_buffer.h"
//#include "util/tc_buffer.h"
namespace tars
{
TC_OpenSSL::TC_OpenSSL()
: _ssl(NULL)
, _bHandshaked(false)
, _isServer(false)
, _plainBuf(NULL)
{
}
TC_OpenSSL::~TC_OpenSSL()
{
Release();
@ -39,7 +47,7 @@ void TC_OpenSSL::Release()
_ssl = NULL;
}
_bHandshaked = false;
_err = 0;
// _err = 0;
}
void TC_OpenSSL::Init(SSL* ssl, bool isServer)
@ -48,7 +56,7 @@ void TC_OpenSSL::Init(SSL* ssl, bool isServer)
_ssl = ssl;
_bHandshaked = false;
_isServer = isServer;
_err = 0;
// _err = 0;
}
bool TC_OpenSSL::IsHandshaked() const
@ -56,17 +64,12 @@ bool TC_OpenSSL::IsHandshaked() const
return _bHandshaked;
}
bool TC_OpenSSL::HasError() const
{
return _err != 0;
}
//bool TC_OpenSSL::HasError() const
//{
// return _err != 0;
//}
string* TC_OpenSSL::RecvBuffer()
{
return &_plainBuf;
}
std::string TC_OpenSSL::DoHandshake(const void* data, size_t size)
int TC_OpenSSL::DoHandshake(TC_NetWorkBuffer &out, const void* data, size_t size)
{
assert (!_bHandshaked);
assert (_ssl);
@ -80,66 +83,70 @@ std::string TC_OpenSSL::DoHandshake(const void* data, size_t size)
ERR_clear_error();
int ret = _isServer ? SSL_accept(_ssl) : SSL_connect(_ssl);
int err = 0;
if (ret <= 0)
{
_err = SSL_get_error(_ssl, ret);
if (_err != SSL_ERROR_WANT_READ)
err = SSL_get_error(_ssl, ret);
if (err != SSL_ERROR_WANT_READ)
{
return std::string();
return err;
}
}
_err = 0;
if (ret == 1)
{
_bHandshaked = true;
}
// the encrypted data from write buffer
std::string out;
TC_Buffer outdata;
GetMemData(SSL_get_wbio(_ssl), outdata);
if (!outdata.IsEmpty())
{
out.assign(outdata.ReadAddr(), outdata.ReadableSize());
// vector<char> out;
// TC_Buffer outdata;
getMemData(SSL_get_wbio(_ssl), out);
// if (!outdata.IsEmpty())
// {
// out.assign(outdata.ReadAddr(), outdata.ReadableSize());
// }
// return out;
return 0;
}
return out;
}
std::string TC_OpenSSL::Write(const void* data, size_t size)
int TC_OpenSSL::Write(const char* data, size_t size, TC_NetWorkBuffer &out)
{
if (!_bHandshaked)
return std::string((const char*)data, size); //握手数据不用加密
{
//握手数据不用加密
out.addBuffer(data, size);
return 0;
}
// 会话数据需加密
ERR_clear_error();
int ret = SSL_write(_ssl, data, size);
if (ret <= 0)
{
_err = SSL_get_error(_ssl, ret);
return std::string();
return SSL_get_error(_ssl, ret);
}
// _err = 0;
// TC_Buffer toSend;
getMemData(SSL_get_wbio(_ssl), out);
return 0;
// return std::string(toSend.ReadAddr(), toSend.ReadableSize());
}
_err = 0;
TC_Buffer toSend;
GetMemData(SSL_get_wbio(_ssl), toSend);
return std::string(toSend.ReadAddr(), toSend.ReadableSize());
}
bool TC_OpenSSL::Read(const void* data, size_t size, std::string& out)
int TC_OpenSSL::Read(const void* data, size_t size, TC_NetWorkBuffer &out)
{
bool usedData = false;
if (!_bHandshaked)
{
usedData = true;
_plainBuf.clear();
std::string out2 = DoHandshake(data, size);
out.swap(out2);
_plainBuf.clearBuffers();
int ret = DoHandshake(out, data, size);
if (_err != 0)
if (ret != 0)
return false;
if (_bHandshaked)
@ -155,19 +162,86 @@ bool TC_OpenSSL::Read(const void* data, size_t size, std::string& out)
BIO_write(SSL_get_rbio(_ssl), data, size);
}
string data;
if (DoSSLRead(_ssl, data))
if (!doSSLRead(_ssl, _plainBuf))
{
_plainBuf.append(data.begin(), data.end());
}
else
{
_err = SSL_ERROR_SSL;
return false;
return SSL_ERROR_SSL;
}
}
return true;
return 0;
}
//
//SSL* TC_OpenSSL::newSSL(const std::string& ctxName)
//{
// SSL_CTX* ctx = TC_SSLManager::getInstance()->GetCtx(ctxName);
// if (!ctx)
// return NULL;
//
// SSL* ssl = SSL_new(ctx);
//
// SSL_set_mode(ssl, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); // allow retry ssl-write with different args
// SSL_set_bio(ssl, BIO_new(BIO_s_mem()), BIO_new(BIO_s_mem()));
//
// BIO_set_mem_eof_return(SSL_get_rbio(ssl), -1);
// BIO_set_mem_eof_return(SSL_get_wbio(ssl), -1);
//
// return ssl;
//}
void TC_OpenSSL::getMemData(BIO* bio, TC_NetWorkBuffer& buf)
{
while (true)
{
char data[8*1024];
int bytes = BIO_read(bio, data, sizeof(data));
if (bytes <= 0)
return;
buf.addBuffer(data, bytes);
}
}
//
//void TC_OpenSSL::getSSLHead(const char* data, char& type, unsigned short& ver, unsigned short& len)
//{
// type = data[0];
// ver = *(unsigned short*)(data + 1);
// len = *(unsigned short*)(data + 3);
//
// ver = ntohs(ver);
// len = ntohs(len);
//}
int TC_OpenSSL::doSSLRead(SSL* ssl, TC_NetWorkBuffer& out)
{
while (true)
{
char plainBuf[32 * 1024];
ERR_clear_error();
int bytes = SSL_read(ssl, plainBuf, sizeof plainBuf);
if (bytes > 0)
{
out.addBuffer(plainBuf, bytes);
}
else
{
int err = SSL_get_error(ssl, bytes);
// when peer issued renegotiation, here will demand us to send handshake data.
// write to mem bio will always success, only need to check whether has data to send.
//assert (err != SSL_ERROR_WANT_WRITE);
if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_ZERO_RETURN)
{
// printf("DoSSLRead err %d\n", err);
return err;
}
break;
}
}
return 0;
}
} // end namespace tars

View File

@ -17,19 +17,19 @@
#if TARS_SSL
#include "util/tc_sslmgr.h"
#include "util/tc_buffer.h"
#include <arpa/inet.h>
// #include "util/tc_buffer.h"
// #include <arpa/inet.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
namespace tars
{
SSLManager::SSLManager()
TC_SSLManager::TC_SSLManager()
{
}
void SSLManager::GlobalInit()
void TC_SSLManager::GlobalInit()
{
(void)SSL_library_init();
OpenSSL_add_all_algorithms();
@ -39,7 +39,7 @@ void SSLManager::GlobalInit()
}
SSLManager::~SSLManager()
TC_SSLManager::~TC_SSLManager()
{
for (CTX_MAP::iterator it(_ctxSet.begin());
it != _ctxSet.end();
@ -52,7 +52,7 @@ SSLManager::~SSLManager()
EVP_cleanup();
}
bool SSLManager::AddCtx(const std::string& name,
bool TC_SSLManager::addCtx(const std::string& name,
const std::string& cafile,
const std::string& certfile,
const std::string& keyfile,
@ -99,17 +99,15 @@ bool SSLManager::AddCtx(const std::string& name,
return _ctxSet.insert(std::make_pair(name, ctx)).second;
}
SSL_CTX* SSLManager::GetCtx(const std::string& name) const
SSL_CTX* TC_SSLManager::getCtx(const std::string& name) const
{
CTX_MAP::const_iterator it = _ctxSet.find(name);
return it == _ctxSet.end() ? NULL: it->second;
}
SSL* NewSSL(const std::string& ctxName)
SSL* TC_SSLManager::newSSL(const std::string& ctxName)
{
SSL_CTX* ctx = SSLManager::getInstance()->GetCtx(ctxName);
SSL_CTX* ctx = TC_SSLManager::getInstance()->getCtx(ctxName);
if (!ctx)
return NULL;
@ -123,64 +121,62 @@ SSL* NewSSL(const std::string& ctxName)
return ssl;
}
void GetMemData(BIO* bio, TC_Buffer& buf)
{
while (true)
{
buf.AssureSpace(16 * 1024);
int bytes = BIO_read(bio, buf.WriteAddr(), buf.WritableSize());
if (bytes <= 0)
return;
buf.Produce(bytes);
}
// never here
}
void GetSSLHead(const char* data, char& type, unsigned short& ver, unsigned short& len)
{
type = data[0];
ver = *(unsigned short*)(data + 1);
len = *(unsigned short*)(data + 3);
ver = ntohs(ver);
len = ntohs(len);
}
bool DoSSLRead(SSL* ssl, std::string& out)
{
while (true)
{
char plainBuf[32 * 1024];
ERR_clear_error();
int bytes = SSL_read(ssl, plainBuf, sizeof plainBuf);
if (bytes > 0)
{
out.append(plainBuf, bytes);
}
else
{
int err = SSL_get_error(ssl, bytes);
// when peer issued renegotiation, here will demand us to send handshake data.
// write to mem bio will always success, only need to check whether has data to send.
//assert (err != SSL_ERROR_WANT_WRITE);
if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_ZERO_RETURN)
{
printf("DoSSLRead err %d\n", err);
return false;
}
break;
}
}
return true;
}
//
//void GetMemData(BIO* bio, TC_NetWorkBuffer& buf)
//{
// while (true)
// {
// char data[8*1024];
// int bytes = BIO_read(bio, data, sizeof(data));
// if (bytes <= 0)
// return;
//
// buf.addBuffer(data, bytes);
// }
//}
//
//void GetSSLHead(const char* data, char& type, unsigned short& ver, unsigned short& len)
//{
// type = data[0];
// ver = *(unsigned short*)(data + 1);
// len = *(unsigned short*)(data + 3);
//
// ver = ntohs(ver);
// len = ntohs(len);
//}
//
//bool DoSSLRead(SSL* ssl, std::string& out)
//{
// while (true)
// {
// char plainBuf[32 * 1024];
//
// ERR_clear_error();
// int bytes = SSL_read(ssl, plainBuf, sizeof plainBuf);
// if (bytes > 0)
// {
// out.append(plainBuf, bytes);
// }
// else
// {
// int err = SSL_get_error(ssl, bytes);
//
// // when peer issued renegotiation, here will demand us to send handshake data.
// // write to mem bio will always success, only need to check whether has data to send.
// //assert (err != SSL_ERROR_WANT_WRITE);
//
// if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_ZERO_RETURN)
// {
// printf("DoSSLRead err %d\n", err);
// return false;
// }
//
// break;
// }
// }
//
// return true;
//}
} // end namespace tars