#include "tcp_utility.hpp"

#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <cstdint>
#include <limits>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>

// NOTE(review): the byte-container type is taken to be std::vector<uint8_t>;
// the template arguments were not legible in the reviewed copy of this file,
// so confirm this matches the declarations in tcp_utility.hpp.

namespace tcp {

namespace {

// Bytes requested per recv() by the background loop in TcpSocket::startReceiving().
constexpr size_t kReceiveChunkSize = 4096;

// Flags passed to ::send(). Where MSG_NOSIGNAL exists, writing to a connection
// the peer has already closed fails with EPIPE instead of raising SIGPIPE,
// whose default action terminates the whole process.
#ifdef MSG_NOSIGNAL
constexpr int kSendFlags = MSG_NOSIGNAL;
#else
constexpr int kSendFlags = 0;
#endif

// Serializes access to TcpServer::m_clients, which is modified by the accept
// thread, by every per-client receive thread (through the error callback) and
// by stop(). NOTE(review): this belongs in TcpServer as a member; it is kept
// file-local so the header does not have to change. It is shared by all
// TcpServer instances, which is correct but coarser than necessary.
std::mutex g_clients_mutex;

}  // namespace

// ---------------------------------------------------------------------------
// TcpSocket
// ---------------------------------------------------------------------------

// Wraps an already-open socket descriptor; -1 yields an unconnected socket.
// For a valid descriptor the peer's IPv4 address and port are recorded.
TcpSocket::TcpSocket(int socket_fd)
    : m_socket_fd(socket_fd), m_connected(socket_fd != -1) {
    // Give the port a defined value even when there is no peer to query
    // (unconnected socket, or getpeername() failing).
    m_peer_port = 0;
    if (!m_connected) {
        return;
    }

    sockaddr_in addr{};
    socklen_t addr_len = sizeof(addr);
    if (::getpeername(m_socket_fd, reinterpret_cast<sockaddr*>(&addr), &addr_len) == 0) {
        // inet_ntop writes into a caller-owned buffer; inet_ntoa returns a
        // shared static buffer and is not safe with concurrent callers.
        char text[INET_ADDRSTRLEN] = {};
        if (::inet_ntop(AF_INET, &addr.sin_addr, text, sizeof(text)) != nullptr) {
            m_peer_address = text;
        }
        m_peer_port = ntohs(addr.sin_port);
    }
}

// Closes the descriptor if this object still owns one.
TcpSocket::~TcpSocket() {
    close();
}

// Transfers the descriptor, peer information and callbacks. `other` is left
// unconnected with descriptor -1 so its destructor does not close ours.
TcpSocket::TcpSocket(TcpSocket&& other) noexcept
    : m_socket_fd(other.m_socket_fd),
      m_connected(other.m_connected),
      m_peer_address(std::move(other.m_peer_address)),
      m_peer_port(other.m_peer_port),
      m_message_callback(std::move(other.m_message_callback)),
      m_error_callback(std::move(other.m_error_callback)) {
    other.m_socket_fd = -1;
    other.m_connected = false;
}

// Closes our current descriptor, then takes over everything from `other`
// (which is left unconnected with descriptor -1). Self-assignment is a no-op.
TcpSocket& TcpSocket::operator=(TcpSocket&& other) noexcept {
    if (this != &other) {
        close();
        m_socket_fd = other.m_socket_fd;
        m_connected = other.m_connected;
        m_peer_address = std::move(other.m_peer_address);
        m_peer_port = other.m_peer_port;
        m_message_callback = std::move(other.m_message_callback);
        m_error_callback = std::move(other.m_error_callback);
        other.m_socket_fd = -1;
        other.m_connected = false;
    }
    return *this;
}

// Resolves `host` (IPv4 only, matching TcpServer) and connects to the first
// resolved address that accepts, replacing any connection this object held.
// Returns SocketCreationFailed if no socket could be created, ConnectFailed if
// the name does not resolve or every address refuses, Success otherwise.
ErrorCode TcpSocket::connect(const std::string& host, uint16_t port) {
    close();

    addrinfo hints{};
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;

    // getaddrinfo is thread-safe; gethostbyname uses shared static storage and
    // is obsolete in POSIX.
    addrinfo* raw_results = nullptr;
    const std::string service = std::to_string(port);
    if (::getaddrinfo(host.c_str(), service.c_str(), &hints, &raw_results) != 0) {
        return ErrorCode::ConnectFailed;
    }
    const std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> results(raw_results, &freeaddrinfo);

    bool created_socket = false;
    for (const addrinfo* candidate = results.get(); candidate != nullptr;
         candidate = candidate->ai_next) {
        m_socket_fd = ::socket(candidate->ai_family, candidate->ai_socktype, candidate->ai_protocol);
        if (m_socket_fd == -1) {
            continue;
        }
        created_socket = true;

        if (::connect(m_socket_fd, candidate->ai_addr, candidate->ai_addrlen) == 0) {
            m_connected = true;
            m_peer_address = host;
            m_peer_port = port;
            return ErrorCode::Success;
        }
        close();  // this address refused; release the socket and try the next one
    }

    return created_socket ? ErrorCode::ConnectFailed : ErrorCode::SocketCreationFailed;
}

// Sends every byte of `data`, looping over partial writes. On failure the
// socket is closed, the error callback (if any) receives SendFailed, and
// SendFailed is returned. Returns Disconnected if the socket is not connected.
ErrorCode TcpSocket::send(const std::vector<uint8_t>& data) {
    if (!m_connected || m_socket_fd == -1) {
        return ErrorCode::Disconnected;
    }

    size_t total_sent = 0;
    while (total_sent < data.size()) {
        const ssize_t sent = ::send(m_socket_fd, data.data() + total_sent,
                                    data.size() - total_sent, kSendFlags);
        if (sent == -1) {
            if (errno == EINTR) {
                continue;  // a signal interrupted the call before anything was written
            }
            close();
            if (m_error_callback) {
                m_error_callback(*this, ErrorCode::SendFailed);
            }
            return ErrorCode::SendFailed;
        }
        total_sent += static_cast<size_t>(sent);
    }
    return ErrorCode::Success;
}

// Convenience overload: sends the raw bytes of `data`.
ErrorCode TcpSocket::send(const std::string& data) {
    return send(std::vector<uint8_t>(data.begin(), data.end()));
}

// Performs one recv() of at most `max_size` bytes into `buffer`, which is
// resized to the number of bytes actually read. If the peer closed the
// connection (Disconnected) or recv() failed (ReceiveFailed), the socket is
// closed and the error callback (if any) is invoked with that code.
// A zero `max_size` is rejected with InvalidArgument: recv() would return 0,
// which is indistinguishable from the peer closing the connection.
ErrorCode TcpSocket::receive(std::vector<uint8_t>& buffer, size_t max_size) {
    if (!m_connected || m_socket_fd == -1) {
        return ErrorCode::Disconnected;
    }
    if (max_size == 0) {
        return ErrorCode::InvalidArgument;
    }

    buffer.resize(max_size);
    ssize_t received = 0;
    do {
        received = ::recv(m_socket_fd, buffer.data(), max_size, 0);
    } while (received == -1 && errno == EINTR);  // interrupted by a signal; retry

    if (received > 0) {
        buffer.resize(static_cast<size_t>(received));
        return ErrorCode::Success;
    }

    // 0 means the peer performed an orderly shutdown; -1 is a genuine error.
    const ErrorCode error = (received == 0) ? ErrorCode::Disconnected : ErrorCode::ReceiveFailed;
    close();
    if (m_error_callback) {
        m_error_callback(*this, error);
    }
    return error;
}

// Closes the descriptor (if any) and marks the socket as disconnected.
void TcpSocket::close() {
    if (m_socket_fd != -1) {
        ::close(m_socket_fd);
        m_socket_fd = -1;
        m_connected = false;
    }
}

bool TcpSocket::isConnected() const {
    return m_connected;
}

std::string TcpSocket::getPeerAddress() const {
    return m_peer_address;
}

uint16_t TcpSocket::getPeerPort() const {
    return m_peer_port;
}

void TcpSocket::setMessageCallback(const MessageCallback& callback) {
    m_message_callback = callback;
}

void TcpSocket::setErrorCallback(const ErrorCallback& callback) {
    m_error_callback = callback;
}

// Starts a detached thread that repeatedly calls receive() and hands each
// chunk to the message callback until the socket stops being connected or a
// receive fails. Does nothing if the socket is not connected.
// NOTE(review): the thread captures a raw `this`, so this TcpSocket must
// outlive it; nothing here enforces that.
void TcpSocket::startReceiving() {
    if (!m_connected || m_socket_fd == -1) {
        return;
    }

    std::thread([this]() {
        std::vector<uint8_t> buffer;
        while (isConnected()) {
            const ErrorCode result = receive(buffer, kReceiveChunkSize);
            if (result != ErrorCode::Success) {
                break;  // on failure the socket is (or already was) closed
            }
            if (m_message_callback) {
                m_message_callback(*this, buffer);
            }
        }
    }).detach();
}

// ---------------------------------------------------------------------------
// TcpClient
// ---------------------------------------------------------------------------

// Creates the client with an unconnected socket.
TcpClient::TcpClient() {
    m_socket = std::make_shared<TcpSocket>();
}

TcpClient::~TcpClient() {
    disconnect();
}

// Connects the underlying socket; on success installs the stored callbacks and
// starts the background receive loop.
ErrorCode TcpClient::connect(const std::string& host, uint16_t port) {
    const ErrorCode result = m_socket->connect(host, port);
    if (result == ErrorCode::Success) {
        m_socket->setMessageCallback(m_message_callback);
        m_socket->setErrorCallback(m_error_callback);
        m_socket->startReceiving();
    }
    return result;
}

void TcpClient::disconnect() {
    m_socket->close();
}

ErrorCode TcpClient::send(const std::vector<uint8_t>& data) {
    return m_socket->send(data);
}

ErrorCode TcpClient::send(const std::string& data) {
    return m_socket->send(data);
}

bool TcpClient::isConnected() const {
    return m_socket->isConnected();
}

std::shared_ptr<TcpSocket> TcpClient::getSocket() const {
    return m_socket;
}

// Stores the callback for future connections and forwards it to the current socket.
void TcpClient::setMessageCallback(const MessageCallback& callback) {
    m_message_callback = callback;
    if (m_socket) {
        m_socket->setMessageCallback(callback);
    }
}

// Stores the callback for future connections and forwards it to the current socket.
void TcpClient::setErrorCallback(const ErrorCallback& callback) {
    m_error_callback = callback;
    if (m_socket) {
        m_socket->setErrorCallback(callback);
    }
}

// ---------------------------------------------------------------------------
// TcpServer
// ---------------------------------------------------------------------------

TcpServer::TcpServer() : m_listen_fd(-1), m_port(0), m_running(false) {
}

TcpServer::~TcpServer() {
    stop();
}

// Binds to INADDR_ANY:`port`, listens, and starts the background accept loop.
// `max_connections` is passed to listen() as the pending-connection backlog
// (clamped to the int range listen() accepts); this code does not otherwise
// limit the number of accepted clients. Returns Success immediately if the
// server is already running.
ErrorCode TcpServer::start(uint16_t port, size_t max_connections) {
    if (m_running) {
        return ErrorCode::Success;
    }

    m_listen_fd = ::socket(AF_INET, SOCK_STREAM, 0);
    if (m_listen_fd == -1) {
        return ErrorCode::SocketCreationFailed;
    }

    int opt = 1;
    if (::setsockopt(m_listen_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1) {
        ::close(m_listen_fd);
        m_listen_fd = -1;
        return ErrorCode::SocketCreationFailed;
    }

    sockaddr_in server_addr{};
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    server_addr.sin_port = htons(port);
    if (::bind(m_listen_fd, reinterpret_cast<sockaddr*>(&server_addr), sizeof(server_addr)) == -1) {
        ::close(m_listen_fd);
        m_listen_fd = -1;
        return ErrorCode::BindFailed;
    }

    // listen() takes an int; clamp so a huge size_t cannot wrap to a negative backlog.
    const int backlog = static_cast<int>(
        std::min<size_t>(max_connections, static_cast<size_t>(std::numeric_limits<int>::max())));
    if (::listen(m_listen_fd, backlog) == -1) {
        ::close(m_listen_fd);
        m_listen_fd = -1;
        return ErrorCode::ListenFailed;
    }

    m_port = port;
    m_running = true;

    std::thread(&TcpServer::acceptThread, this).detach();
    return ErrorCode::Success;
}

// Stops accepting, closes the listening socket and closes every tracked client.
// NOTE(review): the accept thread is detached and uses `this`; if stop() runs
// from the destructor, that thread may still touch the destroyed server.
void TcpServer::stop() {
    if (!m_running) {
        return;
    }
    m_running = false;

    if (m_listen_fd != -1) {
        // close() alone does not wake a thread blocked in accept() on Linux;
        // shutdown() makes that accept() fail so the accept thread can observe
        // m_running == false and exit instead of lingering on the old socket.
        ::shutdown(m_listen_fd, SHUT_RDWR);
        ::close(m_listen_fd);
        m_listen_fd = -1;
    }

    // Take the client list under the lock, then close the sockets outside it.
    decltype(m_clients) clients;
    {
        std::lock_guard<std::mutex> lock(g_clients_mutex);
        clients.swap(m_clients);
    }
    for (auto& client : clients) {
        client->close();
    }
}

bool TcpServer::isRunning() const {
    return m_running;
}

uint16_t TcpServer::getPort() const {
    return m_port;
}

void TcpServer::setConnectionCallback(const ConnectionCallback& callback) {
    m_connection_callback = callback;
}

void TcpServer::setMessageCallback(const MessageCallback& callback) {
    m_message_callback = callback;
}

void TcpServer::setErrorCallback(const ErrorCallback& callback) {
    m_error_callback = callback;
}

// Accept loop run on the thread started by start(). Each accepted connection
// is wrapped in a TcpSocket, wired to the server callbacks, recorded in
// m_clients, announced through the connection callback, and given its own
// receive thread.
void TcpServer::acceptThread() {
    while (m_running) {
        sockaddr_in client_addr{};
        socklen_t client_addr_len = sizeof(client_addr);
        const int client_fd =
            ::accept(m_listen_fd, reinterpret_cast<sockaddr*>(&client_addr), &client_addr_len);
        if (client_fd == -1) {
            // EINTR is only a signal interrupting the wait, not an accept failure;
            // failures after stop() are expected and not reported either.
            if (errno != EINTR && m_running && m_error_callback) {
                // There is no TcpSocket for a failed accept, so pass an unconnected placeholder.
                TcpSocket empty_socket;
                m_error_callback(empty_socket, ErrorCode::AcceptFailed);
            }
            continue;
        }

        auto client_socket = std::make_shared<TcpSocket>(client_fd);

        // Forward data from this client to the server-level message callback.
        client_socket->setMessageCallback([this](TcpSocket& socket, const auto& data) {
            if (m_message_callback) {
                m_message_callback(socket, data);
            }
        });

        // NOTE(review): capturing client_socket makes the socket own a shared_ptr
        // to itself. That cycle keeps the object alive for its detached receive
        // thread (which holds a raw `this`), but it also means accepted sockets
        // are never freed. A real fix needs the receive thread to own a shared_ptr.
        client_socket->setErrorCallback([this, client_socket](TcpSocket& socket, ErrorCode error) {
            // Drop the failed connection from the client list.
            {
                std::lock_guard<std::mutex> lock(g_clients_mutex);
                const auto it = std::find_if(m_clients.begin(), m_clients.end(),
                                             [&socket](const auto& entry) { return entry.get() == &socket; });
                if (it != m_clients.end()) {
                    m_clients.erase(it);
                }
            }
            if (m_error_callback) {
                m_error_callback(socket, error);
            }
        });

        {
            std::lock_guard<std::mutex> lock(g_clients_mutex);
            m_clients.push_back(client_socket);
        }

        if (m_connection_callback) {
            m_connection_callback(*this, client_socket);
        }

        client_socket->startReceiving();
    }
}

// ---------------------------------------------------------------------------
// Error helpers
// ---------------------------------------------------------------------------

// Returns a short English description of `error`; unrecognised values map to
// "Unknown error".
std::string errorCodeToString(ErrorCode error) {
    switch (error) {
        case ErrorCode::Success:
            return "Success";
        case ErrorCode::SocketCreationFailed:
            return "Socket creation failed";
        case ErrorCode::BindFailed:
            return "Bind failed";
        case ErrorCode::ListenFailed:
            return "Listen failed";
        case ErrorCode::AcceptFailed:
            return "Accept failed";
        case ErrorCode::ConnectFailed:
            return "Connect failed";
        case ErrorCode::SendFailed:
            return "Send failed";
        case ErrorCode::ReceiveFailed:
            return "Receive failed";
        case ErrorCode::Disconnected:
            return "Disconnected";
        case ErrorCode::InvalidArgument:
            return "Invalid argument";
        default:
            return "Unknown error";
    }
}

}  // namespace tcp