415 lines
11 KiB
C++
415 lines
11 KiB
C++
#include "tcp_utility.hpp"

#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <chrono>
#include <cstring>
#include <iostream>
#include <limits>
#include <string>
#include <thread>
#include <utility>
|
|
|
|
namespace tcp {
|
|
|
|
// TcpSocket implementation
|
|
// Wraps an existing socket descriptor (e.g. one returned by accept()).
//
// A descriptor other than -1 is treated as connected; the remote address and
// port are then looked up with getpeername(). If that lookup fails, the peer
// address/port members are left untouched.
//
// @param socket_fd  descriptor to take ownership of, or -1 for an empty socket.
TcpSocket::TcpSocket(int socket_fd)
    : m_socket_fd(socket_fd), m_connected(socket_fd != -1) {
    if (!m_connected) {
        return;
    }

    // sockaddr_storage is large enough for any address family, so
    // getpeername() never truncates an IPv6 peer address.
    sockaddr_storage addr{};
    socklen_t addr_len = sizeof(addr);
    if (getpeername(m_socket_fd, reinterpret_cast<sockaddr*>(&addr), &addr_len) != 0) {
        return;
    }

    // inet_ntop() writes into a caller-owned buffer. inet_ntoa() returns a
    // pointer to a static buffer shared by all threads, which is unsafe when
    // several accept/receive threads construct sockets concurrently.
    char text[INET6_ADDRSTRLEN] = {};
    if (addr.ss_family == AF_INET) {
        const auto* v4 = reinterpret_cast<const sockaddr_in*>(&addr);
        if (inet_ntop(AF_INET, &v4->sin_addr, text, sizeof(text)) != nullptr) {
            m_peer_address = text;
        }
        m_peer_port = ntohs(v4->sin_port);
    } else if (addr.ss_family == AF_INET6) {
        const auto* v6 = reinterpret_cast<const sockaddr_in6*>(&addr);
        if (inet_ntop(AF_INET6, &v6->sin6_addr, text, sizeof(text)) != nullptr) {
            m_peer_address = text;
        }
        m_peer_port = ntohs(v6->sin6_port);
    }
}
|
|
|
|
// Destructor: releases the descriptor (if one is still open) via close().
TcpSocket::~TcpSocket() {
    this->close();
}
|
|
|
|
// Move constructor: takes over other's descriptor, connection flag, peer
// info and callbacks. other is left with descriptor -1 and not connected, so
// its destructor will not close the transferred descriptor.
TcpSocket::TcpSocket(TcpSocket&& other) noexcept
    : m_socket_fd(std::exchange(other.m_socket_fd, -1)),
      m_connected(std::exchange(other.m_connected, false)),
      m_peer_address(std::move(other.m_peer_address)),
      m_peer_port(other.m_peer_port),
      m_message_callback(std::move(other.m_message_callback)),
      m_error_callback(std::move(other.m_error_callback)) {}
|
|
|
|
// Move assignment: closes this socket's current descriptor, then takes over
// other's descriptor, connection flag, peer info and callbacks. other is left
// with descriptor -1 and not connected. Self-assignment is a no-op.
TcpSocket& TcpSocket::operator=(TcpSocket&& other) noexcept {
    if (this == &other) {
        return *this;
    }

    close();
    m_socket_fd = std::exchange(other.m_socket_fd, -1);
    m_connected = std::exchange(other.m_connected, false);
    m_peer_address = std::move(other.m_peer_address);
    m_peer_port = other.m_peer_port;
    m_message_callback = std::move(other.m_message_callback);
    m_error_callback = std::move(other.m_error_callback);
    return *this;
}
|
|
|
|
// Opens a TCP connection to host:port, closing any connection this socket
// already had.
//
// host may be a host name or a numeric IPv4/IPv6 address. Every address
// returned by getaddrinfo() is tried in order until one accepts.
//
// @return Success when connected; SocketCreationFailed if no socket could be
//         created for any resolved address; ConnectFailed if name resolution
//         fails or no resolved address accepted the connection.
ErrorCode TcpSocket::connect(const std::string& host, uint16_t port) {
    close();

    // getaddrinfo() replaces gethostbyname(). gethostbyname() returns a
    // pointer to static storage shared by all threads, and it only resolves
    // IPv4.
    addrinfo hints{};
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_flags = AI_NUMERICSERV;  // the service string is always a numeric port

    const std::string service = std::to_string(port);
    addrinfo* results = nullptr;
    if (getaddrinfo(host.c_str(), service.c_str(), &hints, &results) != 0) {
        return ErrorCode::ConnectFailed;
    }

    bool created_any_socket = false;
    for (const addrinfo* ai = results; ai != nullptr; ai = ai->ai_next) {
        const int fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
        if (fd == -1) {
            continue;
        }
        created_any_socket = true;
        if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == 0) {
            m_socket_fd = fd;
            break;
        }
        ::close(fd);  // raw descriptor for this attempt only, not the member
    }
    freeaddrinfo(results);

    // close() above left m_socket_fd at -1; it is set only on success.
    if (m_socket_fd == -1) {
        return created_any_socket ? ErrorCode::ConnectFailed
                                  : ErrorCode::SocketCreationFailed;
    }

    m_connected = true;
    m_peer_address = host;
    m_peer_port = port;
    return ErrorCode::Success;
}
|
|
|
|
// Sends every byte of data, looping until the whole buffer is written.
//
// @return Disconnected if the socket is not connected (no callback fired);
//         SendFailed on a send error, in which case the socket is closed and
//         the error callback (if set) is invoked with SendFailed;
//         Success once all bytes have been handed to the kernel.
ErrorCode TcpSocket::send(const std::vector<unsigned char>& data) {
    if (!m_connected || m_socket_fd == -1) {
        return ErrorCode::Disconnected;
    }

#ifdef MSG_NOSIGNAL
    // Without this flag, writing to a connection the peer has reset raises
    // SIGPIPE, whose default action terminates the whole process.
    constexpr int kSendFlags = MSG_NOSIGNAL;
#else
    // TODO(review): platforms without MSG_NOSIGNAL (e.g. macOS) need
    // SO_NOSIGPIPE on the socket or SIGPIPE ignored process-wide.
    constexpr int kSendFlags = 0;
#endif

    size_t total_sent = 0;
    while (total_sent < data.size()) {
        const ssize_t sent = ::send(m_socket_fd, data.data() + total_sent,
                                    data.size() - total_sent, kSendFlags);
        if (sent == -1) {
            if (errno == EINTR) {
                continue;  // interrupted by a signal before sending anything; retry
            }
            close();
            if (m_error_callback) {
                m_error_callback(*this, ErrorCode::SendFailed);
            }
            return ErrorCode::SendFailed;
        }
        total_sent += static_cast<size_t>(sent);
    }

    return ErrorCode::Success;
}
|
|
|
|
// Convenience overload: copies the string's bytes into a byte vector and
// forwards to the vector overload of send().
ErrorCode TcpSocket::send(const std::string& data) {
    const std::vector<unsigned char> bytes(data.begin(), data.end());
    return send(bytes);
}
|
|
|
|
// Performs one blocking recv() of at most max_size bytes into buffer.
//
// On success, buffer is resized to exactly the number of bytes received.
// On any failure, buffer is left empty.
//
// @return Disconnected if the socket is not connected (no callback fired);
//         InvalidArgument if max_size is 0 (the connection is left open);
//         Disconnected if the peer closed the connection, or ReceiveFailed on a
//         recv() error — in both of these cases the socket is closed and the
//         error callback (if set) is invoked with that code;
//         Success otherwise.
ErrorCode TcpSocket::receive(std::vector<unsigned char>& buffer, size_t max_size) {
    if (!m_connected || m_socket_fd == -1) {
        return ErrorCode::Disconnected;
    }

    if (max_size == 0) {
        // recv() into a zero-length buffer returns 0, which would otherwise be
        // mistaken for an orderly peer shutdown and close a healthy connection.
        buffer.clear();
        return ErrorCode::InvalidArgument;
    }

    buffer.resize(max_size);
    ssize_t received = -1;
    do {
        received = ::recv(m_socket_fd, buffer.data(), max_size, 0);
    } while (received == -1 && errno == EINTR);  // signal interruption is not an error

    if (received > 0) {
        buffer.resize(static_cast<size_t>(received));
        return ErrorCode::Success;
    }

    // 0 means the peer closed the connection; -1 is a receive error.
    buffer.clear();
    const ErrorCode error =
        (received == 0) ? ErrorCode::Disconnected : ErrorCode::ReceiveFailed;
    close();
    if (m_error_callback) {
        m_error_callback(*this, error);
    }
    return error;
}
|
|
|
|
void TcpSocket::close() {
|
|
if (m_socket_fd != -1) {
|
|
::close(m_socket_fd);
|
|
m_socket_fd = -1;
|
|
m_connected = false;
|
|
}
|
|
}
|
|
|
|
bool TcpSocket::isConnected() const {
|
|
return m_connected;
|
|
}
|
|
|
|
std::string TcpSocket::getPeerAddress() const {
|
|
return m_peer_address;
|
|
}
|
|
|
|
uint16_t TcpSocket::getPeerPort() const {
|
|
return m_peer_port;
|
|
}
|
|
|
|
void TcpSocket::setMessageCallback(const MessageCallback& callback) {
|
|
m_message_callback = callback;
|
|
}
|
|
|
|
void TcpSocket::setErrorCallback(const ErrorCallback& callback) {
|
|
m_error_callback = callback;
|
|
}
|
|
|
|
void TcpSocket::startReceiving() {
|
|
if (!m_connected || m_socket_fd == -1) {
|
|
return;
|
|
}
|
|
|
|
std::thread([this]() {
|
|
std::vector<unsigned char> buffer(4096);
|
|
while (isConnected()) {
|
|
ErrorCode result = receive(buffer, buffer.capacity());
|
|
if (result == ErrorCode::Success && m_message_callback) {
|
|
m_message_callback(*this, buffer);
|
|
} else if (result != ErrorCode::Success) {
|
|
break;
|
|
}
|
|
}
|
|
}).detach();
|
|
}
|
|
|
|
// TcpClient implementation
|
|
// Creates the client with a fresh, unconnected TcpSocket.
TcpClient::TcpClient() : m_socket(std::make_shared<TcpSocket>()) {}
|
|
|
|
// Destructor: closes the underlying socket via disconnect().
TcpClient::~TcpClient() {
    this->disconnect();
}
|
|
|
|
// Connects the underlying socket to host:port. On success, installs the
// client-level message/error callbacks on the socket and starts its receive
// thread. Returns the socket's connect() result unchanged.
ErrorCode TcpClient::connect(const std::string& host, uint16_t port) {
    const ErrorCode status = m_socket->connect(host, port);
    if (status != ErrorCode::Success) {
        return status;
    }

    m_socket->setMessageCallback(m_message_callback);
    m_socket->setErrorCallback(m_error_callback);
    m_socket->startReceiving();
    return status;
}
|
|
|
|
void TcpClient::disconnect() {
|
|
m_socket->close();
|
|
}
|
|
|
|
// Forwards the byte buffer to the underlying socket's send().
ErrorCode TcpClient::send(const std::vector<unsigned char>& data) {
    const auto& socket = m_socket;
    return socket->send(data);
}
|
|
|
|
// Forwards the string to the underlying socket's send().
ErrorCode TcpClient::send(const std::string& data) {
    const auto& socket = m_socket;
    return socket->send(data);
}
|
|
|
|
bool TcpClient::isConnected() const {
|
|
return m_socket->isConnected();
|
|
}
|
|
|
|
std::shared_ptr<TcpSocket> TcpClient::getSocket() const {
|
|
return m_socket;
|
|
}
|
|
|
|
void TcpClient::setMessageCallback(const MessageCallback& callback) {
|
|
m_message_callback = callback;
|
|
if (m_socket) {
|
|
m_socket->setMessageCallback(callback);
|
|
}
|
|
}
|
|
|
|
void TcpClient::setErrorCallback(const ErrorCallback& callback) {
|
|
m_error_callback = callback;
|
|
if (m_socket) {
|
|
m_socket->setErrorCallback(callback);
|
|
}
|
|
}
|
|
|
|
// TcpServer implementation
|
|
TcpServer::TcpServer()
|
|
: m_listen_fd(-1), m_port(0), m_running(false) {
|
|
}
|
|
|
|
// Destructor: stops the server (closes the listener and all client sockets).
TcpServer::~TcpServer() {
    this->stop();
}
|
|
|
|
// Binds an IPv4 listening socket on all interfaces at `port` and starts the
// detached accept thread.
//
// @param port             port to listen on; 0 lets the kernel choose one, and
//                         the chosen port is then reported by getPort().
// @param max_connections  passed to listen() as the backlog of pending
//                         (not yet accepted) connections. It does not cap the
//                         number of simultaneously connected clients.
// @return Success (also when already running); SocketCreationFailed,
//         BindFailed or ListenFailed on the corresponding failure, after which
//         the listening descriptor is closed.
//
// NOTE(review): the accept thread is detached and captures raw `this`;
// destroying the server while that thread is still inside accept() is a
// use-after-free. A joinable thread member (header change) would fix it.
ErrorCode TcpServer::start(uint16_t port, size_t max_connections) {
    if (m_running) {
        return ErrorCode::Success;
    }

    // Closes the partially set-up listening socket and returns `error`.
    auto fail = [this](ErrorCode error) {
        ::close(m_listen_fd);
        m_listen_fd = -1;
        return error;
    };

    m_listen_fd = socket(AF_INET, SOCK_STREAM, 0);
    if (m_listen_fd == -1) {
        return ErrorCode::SocketCreationFailed;
    }

    // Allow rebinding the port while old connections linger in TIME_WAIT.
    int opt = 1;
    if (setsockopt(m_listen_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1) {
        return fail(ErrorCode::SocketCreationFailed);
    }

    sockaddr_in server_addr{};
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    server_addr.sin_port = htons(port);
    if (bind(m_listen_fd, reinterpret_cast<sockaddr*>(&server_addr), sizeof(server_addr)) == -1) {
        return fail(ErrorCode::BindFailed);
    }

    // listen() takes an int backlog. Clamp instead of letting a large size_t
    // narrow to an implementation-defined (possibly negative) value.
    const int backlog = static_cast<int>(std::min<size_t>(
        max_connections, static_cast<size_t>(std::numeric_limits<int>::max())));
    if (listen(m_listen_fd, backlog) == -1) {
        return fail(ErrorCode::ListenFailed);
    }

    // Record the port actually bound, which differs from `port` when it was 0.
    sockaddr_in bound_addr{};
    socklen_t bound_len = sizeof(bound_addr);
    if (getsockname(m_listen_fd, reinterpret_cast<sockaddr*>(&bound_addr), &bound_len) == 0) {
        m_port = ntohs(bound_addr.sin_port);
    } else {
        m_port = port;
    }
    m_running = true;

    std::thread(&TcpServer::acceptThread, this).detach();

    return ErrorCode::Success;
}
|
|
|
|
void TcpServer::stop() {
|
|
if (!m_running) {
|
|
return;
|
|
}
|
|
|
|
m_running = false;
|
|
|
|
// 关闭监听套接字
|
|
if (m_listen_fd != -1) {
|
|
close(m_listen_fd);
|
|
m_listen_fd = -1;
|
|
}
|
|
|
|
// 关闭所有客户端连接
|
|
for (auto& client : m_clients) {
|
|
client->close();
|
|
}
|
|
m_clients.clear();
|
|
}
|
|
|
|
bool TcpServer::isRunning() const {
|
|
return m_running;
|
|
}
|
|
|
|
uint16_t TcpServer::getPort() const {
|
|
return m_port;
|
|
}
|
|
|
|
void TcpServer::setConnectionCallback(const ConnectionCallback& callback) {
|
|
m_connection_callback = callback;
|
|
}
|
|
|
|
void TcpServer::setMessageCallback(const MessageCallback& callback) {
|
|
m_message_callback = callback;
|
|
}
|
|
|
|
void TcpServer::setErrorCallback(const ErrorCallback& callback) {
|
|
m_error_callback = callback;
|
|
}
|
|
|
|
void TcpServer::acceptThread() {
|
|
while (m_running) {
|
|
struct sockaddr_in client_addr;
|
|
socklen_t client_addr_len = sizeof(client_addr);
|
|
|
|
int client_fd = accept(m_listen_fd, (struct sockaddr*)&client_addr, &client_addr_len);
|
|
if (client_fd == -1) {
|
|
if (m_running) {
|
|
if (m_error_callback) {
|
|
// 这里没有具体的 TcpSocket 对象,所以传递一个空的
|
|
TcpSocket empty_socket;
|
|
m_error_callback(empty_socket, ErrorCode::AcceptFailed);
|
|
}
|
|
}
|
|
continue;
|
|
}
|
|
|
|
// 创建新的客户端套接字
|
|
auto client_socket = std::make_shared<TcpSocket>(client_fd);
|
|
|
|
// 设置回调函数
|
|
client_socket->setMessageCallback([this](TcpSocket& socket, const std::vector<unsigned char>& data) {
|
|
if (m_message_callback) {
|
|
m_message_callback(socket, data);
|
|
}
|
|
});
|
|
|
|
client_socket->setErrorCallback([this, client_socket](TcpSocket& socket, ErrorCode error) {
|
|
// 从客户端列表中移除断开的连接
|
|
for (auto it = m_clients.begin(); it != m_clients.end(); ++it) {
|
|
if (it->get() == &socket) {
|
|
m_clients.erase(it);
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (m_error_callback) {
|
|
m_error_callback(socket, error);
|
|
}
|
|
});
|
|
|
|
// 添加到客户端列表
|
|
m_clients.push_back(client_socket);
|
|
|
|
// 通知有新连接
|
|
if (m_connection_callback) {
|
|
m_connection_callback(*this, client_socket);
|
|
}
|
|
|
|
// 开始接收该客户端的数据
|
|
client_socket->startReceiving();
|
|
}
|
|
}
|
|
|
|
// Error-handling helpers
|
|
// Maps an ErrorCode to a short human-readable description. Any value not
// listed below yields "Unknown error".
std::string errorCodeToString(ErrorCode error) {
    const char* description = "Unknown error";
    switch (error) {
        case ErrorCode::Success:              description = "Success"; break;
        case ErrorCode::SocketCreationFailed: description = "Socket creation failed"; break;
        case ErrorCode::BindFailed:           description = "Bind failed"; break;
        case ErrorCode::ListenFailed:         description = "Listen failed"; break;
        case ErrorCode::AcceptFailed:         description = "Accept failed"; break;
        case ErrorCode::ConnectFailed:        description = "Connect failed"; break;
        case ErrorCode::SendFailed:           description = "Send failed"; break;
        case ErrorCode::ReceiveFailed:        description = "Receive failed"; break;
        case ErrorCode::Disconnected:         description = "Disconnected"; break;
        case ErrorCode::InvalidArgument:      description = "Invalid argument"; break;
        default:                              break;
    }
    return description;
}
|
|
|
|
} // namespace tcp
|