vision/tcp_utility.cpp
2025-06-25 13:14:18 +08:00

415 lines
11 KiB
C++

#include "tcp_utility.hpp"
#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>
#include <cerrno>
#include <chrono>
#include <cstring>
#include <iostream>
#include <limits>
#include <thread>
namespace tcp {
// TcpSocket implementation

/**
 * Wraps an already-open socket descriptor (for example one returned by accept()).
 *
 * @param socket_fd Descriptor this object takes ownership of, or -1 for an
 *                  unconnected socket.
 *
 * For a valid descriptor the peer's IPv4 address and port are looked up with
 * getpeername(). The address is formatted with inet_ntop() rather than
 * inet_ntoa(): inet_ntoa() returns a pointer to a single static buffer, so it
 * is not safe when sockets are constructed on several threads at once (this
 * file constructs them on the server's accept thread).
 */
TcpSocket::TcpSocket(int socket_fd)
    : m_socket_fd(socket_fd), m_connected(socket_fd != -1) {
    if (!m_connected) {
        return;
    }

    sockaddr_in addr{};
    socklen_t addr_len = sizeof(addr);
    if (getpeername(m_socket_fd, reinterpret_cast<sockaddr*>(&addr), &addr_len) != 0) {
        return;  // peer info is best-effort; leave it unset on failure
    }

    char text[INET_ADDRSTRLEN] = {};
    if (inet_ntop(AF_INET, &addr.sin_addr, text, sizeof(text)) != nullptr) {
        m_peer_address = text;
    }
    m_peer_port = ntohs(addr.sin_port);
}
/// Destructor: releases the owned descriptor, if any, through close().
TcpSocket::~TcpSocket() {
    this->close();
}
/**
 * Move constructor: takes over @p other's descriptor, connection flag, peer
 * info and callbacks. The source is then disarmed (descriptor -1, not
 * connected) so its destructor will not close the transferred descriptor.
 */
TcpSocket::TcpSocket(TcpSocket&& other) noexcept
    : m_socket_fd(other.m_socket_fd),
      m_connected(other.m_connected),
      m_peer_address(std::move(other.m_peer_address)),
      m_peer_port(other.m_peer_port),
      m_message_callback(std::move(other.m_message_callback)),
      m_error_callback(std::move(other.m_error_callback)) {
    // Disarm the moved-from object.
    other.m_connected = false;
    other.m_socket_fd = -1;
}
/**
 * Move assignment: first closes whatever descriptor this object owns, then
 * adopts @p other's descriptor, connection flag, peer info and callbacks.
 * Self-assignment is a no-op. @p other ends with descriptor -1, disconnected.
 */
TcpSocket& TcpSocket::operator=(TcpSocket&& other) noexcept {
    if (this == &other) {
        return *this;
    }
    close();

    m_socket_fd = other.m_socket_fd;
    other.m_socket_fd = -1;
    m_connected = other.m_connected;
    other.m_connected = false;

    m_peer_address = std::move(other.m_peer_address);
    m_peer_port = other.m_peer_port;
    m_message_callback = std::move(other.m_message_callback);
    m_error_callback = std::move(other.m_error_callback);
    return *this;
}
/**
 * Opens an IPv4 TCP connection to @p host:@p port. Any connection this object
 * already holds is closed first.
 *
 * @param host Hostname or dotted-quad IPv4 address.
 * @param port Remote port in host byte order.
 * @return Success on connection.
 *         SocketCreationFailed if socket() fails.
 *         ConnectFailed if @p host cannot be resolved, or if no resolved
 *         address accepts the connection.
 *
 * Name resolution uses getaddrinfo() instead of gethostbyname(). The latter
 * is obsolete and returns a pointer to static storage, so it is not
 * thread-safe. Every IPv4 address the resolver returns is tried in order,
 * not only the first one.
 */
ErrorCode TcpSocket::connect(const std::string& host, uint16_t port) {
    close();

    addrinfo hints{};
    hints.ai_family = AF_INET;  // this class works with sockaddr_in only
    hints.ai_socktype = SOCK_STREAM;
    addrinfo* results = nullptr;
    if (::getaddrinfo(host.c_str(), nullptr, &hints, &results) != 0 || results == nullptr) {
        return ErrorCode::ConnectFailed;
    }

    ErrorCode status = ErrorCode::ConnectFailed;
    for (const addrinfo* ai = results; ai != nullptr; ai = ai->ai_next) {
        if (ai->ai_family != AF_INET || ai->ai_addrlen < sizeof(sockaddr_in)) {
            continue;
        }
        sockaddr_in server_addr{};
        std::memcpy(&server_addr, ai->ai_addr, sizeof(server_addr));
        server_addr.sin_port = htons(port);

        // POSIX leaves a socket in an unspecified state after a failed
        // connect(), so each candidate address gets a fresh descriptor.
        const int fd = ::socket(AF_INET, SOCK_STREAM, 0);
        if (fd == -1) {
            status = ErrorCode::SocketCreationFailed;
            break;
        }
        if (::connect(fd, reinterpret_cast<const sockaddr*>(&server_addr),
                      sizeof(server_addr)) == 0) {
            m_socket_fd = fd;
            status = ErrorCode::Success;
            break;
        }
        ::close(fd);
    }
    ::freeaddrinfo(results);

    if (status != ErrorCode::Success) {
        return status;
    }
    m_connected = true;
    m_peer_address = host;
    m_peer_port = port;
    return ErrorCode::Success;
}
/**
 * Sends every byte of @p data, looping over partial writes.
 *
 * @return Success when the whole buffer was written.
 *         Disconnected if the socket is not connected.
 *         SendFailed on a socket error; the socket is then closed and the
 *         error callback (if any) is invoked.
 *
 * MSG_NOSIGNAL, where the platform defines it, stops a write to a peer that
 * has already closed from raising SIGPIPE. The default action of SIGPIPE
 * terminates the whole process; with the flag, the write just fails with
 * EPIPE. A send() interrupted by a signal (EINTR) is retried rather than
 * treated as a fatal error.
 */
ErrorCode TcpSocket::send(const std::vector<unsigned char>& data) {
    if (!m_connected || m_socket_fd == -1) {
        return ErrorCode::Disconnected;
    }
#ifdef MSG_NOSIGNAL
    constexpr int kSendFlags = MSG_NOSIGNAL;
#else
    constexpr int kSendFlags = 0;
#endif
    size_t total_sent = 0;
    while (total_sent < data.size()) {
        const ssize_t sent = ::send(m_socket_fd, data.data() + total_sent,
                                    data.size() - total_sent, kSendFlags);
        if (sent == -1) {
            if (errno == EINTR) {
                continue;  // interrupted before anything was sent; try again
            }
            close();
            if (m_error_callback) {
                m_error_callback(*this, ErrorCode::SendFailed);
            }
            return ErrorCode::SendFailed;
        }
        total_sent += static_cast<size_t>(sent);
    }
    return ErrorCode::Success;
}
/// Convenience overload: copies the characters of @p data into a byte vector
/// and delegates to the vector overload.
ErrorCode TcpSocket::send(const std::string& data) {
    const std::vector<unsigned char> bytes(data.begin(), data.end());
    return send(bytes);
}
/**
 * Performs one blocking recv() of at most @p max_size bytes.
 *
 * @param buffer   On success, holds exactly the bytes read. Cleared on any
 *                 failure other than "not connected".
 * @param max_size Maximum number of bytes to read; must be greater than 0.
 * @return Success with data in @p buffer.
 *         InvalidArgument if @p max_size is 0.
 *         Disconnected if the socket is not connected, or if the peer closed
 *         the connection.
 *         ReceiveFailed on a socket error.
 *         On a peer close and on ReceiveFailed, the socket is closed and the
 *         error callback (if any) is invoked.
 *
 * A zero-length request is rejected up front. recv() with length 0 returns 0,
 * which would be indistinguishable from an orderly shutdown by the peer and
 * would wrongly close a healthy connection. recv() interrupted by a signal
 * (EINTR) is retried.
 */
ErrorCode TcpSocket::receive(std::vector<unsigned char>& buffer, size_t max_size) {
    if (!m_connected || m_socket_fd == -1) {
        return ErrorCode::Disconnected;
    }
    if (max_size == 0) {
        buffer.clear();
        return ErrorCode::InvalidArgument;
    }

    buffer.resize(max_size);
    ssize_t received = -1;
    do {
        received = ::recv(m_socket_fd, buffer.data(), max_size, 0);
    } while (received == -1 && errno == EINTR);

    if (received > 0) {
        buffer.resize(static_cast<size_t>(received));
        return ErrorCode::Success;
    }

    // 0 => orderly shutdown by the peer; -1 => socket error.
    buffer.clear();
    const ErrorCode error =
        (received == 0) ? ErrorCode::Disconnected : ErrorCode::ReceiveFailed;
    close();
    if (m_error_callback) {
        m_error_callback(*this, error);
    }
    return error;
}
/**
 * Closes the owned descriptor, if there is one, and marks the socket as
 * disconnected. Once the descriptor is -1, further calls do nothing.
 *
 * NOTE(review): per close(2), on Linux closing a descriptor does not wake a
 * recv() that another thread (e.g. the startReceiving() thread) is already
 * blocked in. That call keeps waiting until the peer sends data or closes.
 */
void TcpSocket::close() {
    if (m_socket_fd == -1) {
        return;
    }
    const int fd = m_socket_fd;
    m_socket_fd = -1;
    m_connected = false;
    ::close(fd);
}
bool TcpSocket::isConnected() const {
return m_connected;
}
std::string TcpSocket::getPeerAddress() const {
return m_peer_address;
}
uint16_t TcpSocket::getPeerPort() const {
return m_peer_port;
}
void TcpSocket::setMessageCallback(const MessageCallback& callback) {
m_message_callback = callback;
}
void TcpSocket::setErrorCallback(const ErrorCallback& callback) {
m_error_callback = callback;
}
void TcpSocket::startReceiving() {
if (!m_connected || m_socket_fd == -1) {
return;
}
std::thread([this]() {
std::vector<unsigned char> buffer(4096);
while (isConnected()) {
ErrorCode result = receive(buffer, buffer.capacity());
if (result == ErrorCode::Success && m_message_callback) {
m_message_callback(*this, buffer);
} else if (result != ErrorCode::Success) {
break;
}
}
}).detach();
}
// TcpClient implementation

/// Creates a client backed by a fresh, unconnected TcpSocket.
TcpClient::TcpClient()
    : m_socket(std::make_shared<TcpSocket>()) {}
/// Destructor: closes the underlying socket via disconnect().
TcpClient::~TcpClient() {
    this->disconnect();
}
/**
 * Connects the underlying socket to @p host:@p port. On success, installs the
 * client's stored message and error callbacks on the socket and starts its
 * background receive thread. Returns the socket's connect() result unchanged.
 *
 * NOTE(review): calling this while already connected reuses the same
 * TcpSocket object. The receive thread from the previous connection is never
 * joined and may still be running against that object — confirm this is
 * intended.
 */
ErrorCode TcpClient::connect(const std::string& host, uint16_t port) {
    const ErrorCode status = m_socket->connect(host, port);
    if (status != ErrorCode::Success) {
        return status;
    }
    m_socket->setMessageCallback(m_message_callback);
    m_socket->setErrorCallback(m_error_callback);
    m_socket->startReceiving();
    return status;
}
/// Closes the underlying socket. The socket object itself is kept for reuse.
void TcpClient::disconnect() {
    (*m_socket).close();
}
/// Forwards @p data to the underlying socket's send() and returns its result.
ErrorCode TcpClient::send(const std::vector<unsigned char>& data) {
    return (*m_socket).send(data);
}
/// Forwards the string @p data to the underlying socket's send() overload.
ErrorCode TcpClient::send(const std::string& data) {
    return (*m_socket).send(data);
}
/// Reports whether the underlying socket is connected.
bool TcpClient::isConnected() const {
    return (*m_socket).isConnected();
}
std::shared_ptr<TcpSocket> TcpClient::getSocket() const {
return m_socket;
}
/// Remembers @p callback for future connect() calls. If a socket exists, also
/// installs the callback on it immediately.
void TcpClient::setMessageCallback(const MessageCallback& callback) {
    m_message_callback = callback;
    if (!m_socket) {
        return;
    }
    m_socket->setMessageCallback(callback);
}
/// Remembers @p callback for future connect() calls. If a socket exists, also
/// installs the callback on it immediately.
void TcpClient::setErrorCallback(const ErrorCallback& callback) {
    m_error_callback = callback;
    if (!m_socket) {
        return;
    }
    m_socket->setErrorCallback(callback);
}
// TcpServer 实现
TcpServer::TcpServer()
: m_listen_fd(-1), m_port(0), m_running(false) {
}
/// Destructor: stops the server (closes the listener and all client sockets).
TcpServer::~TcpServer() {
    this->stop();
}
/**
 * Starts listening on all IPv4 interfaces and launches the detached accept
 * thread.
 *
 * @param port            Port to bind, in host byte order. 0 asks the OS for
 *                        an ephemeral port; getPort() then reports the port
 *                        actually bound.
 * @param max_connections Listen backlog. Values above INT_MAX are clamped,
 *                        because listen() takes an int.
 * @return Success (also when the server is already running).
 *         SocketCreationFailed if the socket cannot be created or
 *         SO_REUSEADDR cannot be set.
 *         BindFailed if bind() fails.
 *         ListenFailed if listen() fails.
 */
ErrorCode TcpServer::start(uint16_t port, size_t max_connections) {
    if (m_running) {
        return ErrorCode::Success;
    }

    // Releases the partially set-up listening socket and passes @p error back.
    auto fail = [this](ErrorCode error) {
        ::close(m_listen_fd);
        m_listen_fd = -1;
        return error;
    };

    m_listen_fd = ::socket(AF_INET, SOCK_STREAM, 0);
    if (m_listen_fd == -1) {
        return ErrorCode::SocketCreationFailed;
    }

    // SO_REUSEADDR lets the port be re-bound without waiting for connections
    // from a previous run to leave TIME_WAIT.
    int opt = 1;
    if (setsockopt(m_listen_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1) {
        return fail(ErrorCode::SocketCreationFailed);
    }

    sockaddr_in server_addr{};
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    server_addr.sin_port = htons(port);
    if (bind(m_listen_fd, reinterpret_cast<sockaddr*>(&server_addr), sizeof(server_addr)) == -1) {
        return fail(ErrorCode::BindFailed);
    }

    // listen() takes an int; avoid an implementation-defined narrowing.
    const size_t int_max = static_cast<size_t>(std::numeric_limits<int>::max());
    const int backlog = static_cast<int>(max_connections < int_max ? max_connections : int_max);
    if (listen(m_listen_fd, backlog) == -1) {
        return fail(ErrorCode::ListenFailed);
    }

    // Record the port actually bound (differs from `port` when it was 0).
    sockaddr_in bound_addr{};
    socklen_t bound_len = sizeof(bound_addr);
    if (getsockname(m_listen_fd, reinterpret_cast<sockaddr*>(&bound_addr), &bound_len) == 0) {
        m_port = ntohs(bound_addr.sin_port);
    } else {
        m_port = port;
    }
    m_running = true;

    std::thread(&TcpServer::acceptThread, this).detach();
    return ErrorCode::Success;
}
void TcpServer::stop() {
if (!m_running) {
return;
}
m_running = false;
// 关闭监听套接字
if (m_listen_fd != -1) {
close(m_listen_fd);
m_listen_fd = -1;
}
// 关闭所有客户端连接
for (auto& client : m_clients) {
client->close();
}
m_clients.clear();
}
bool TcpServer::isRunning() const {
return m_running;
}
uint16_t TcpServer::getPort() const {
return m_port;
}
void TcpServer::setConnectionCallback(const ConnectionCallback& callback) {
m_connection_callback = callback;
}
void TcpServer::setMessageCallback(const MessageCallback& callback) {
m_message_callback = callback;
}
void TcpServer::setErrorCallback(const ErrorCallback& callback) {
m_error_callback = callback;
}
void TcpServer::acceptThread() {
while (m_running) {
struct sockaddr_in client_addr;
socklen_t client_addr_len = sizeof(client_addr);
int client_fd = accept(m_listen_fd, (struct sockaddr*)&client_addr, &client_addr_len);
if (client_fd == -1) {
if (m_running) {
if (m_error_callback) {
// 这里没有具体的 TcpSocket 对象,所以传递一个空的
TcpSocket empty_socket;
m_error_callback(empty_socket, ErrorCode::AcceptFailed);
}
}
continue;
}
// 创建新的客户端套接字
auto client_socket = std::make_shared<TcpSocket>(client_fd);
// 设置回调函数
client_socket->setMessageCallback([this](TcpSocket& socket, const std::vector<unsigned char>& data) {
if (m_message_callback) {
m_message_callback(socket, data);
}
});
client_socket->setErrorCallback([this, client_socket](TcpSocket& socket, ErrorCode error) {
// 从客户端列表中移除断开的连接
for (auto it = m_clients.begin(); it != m_clients.end(); ++it) {
if (it->get() == &socket) {
m_clients.erase(it);
break;
}
}
if (m_error_callback) {
m_error_callback(socket, error);
}
});
// 添加到客户端列表
m_clients.push_back(client_socket);
// 通知有新连接
if (m_connection_callback) {
m_connection_callback(*this, client_socket);
}
// 开始接收该客户端的数据
client_socket->startReceiving();
}
}
// Error-handling helpers

/**
 * Maps @p error to a short English description. Any value without a
 * dedicated case yields "Unknown error".
 */
std::string errorCodeToString(ErrorCode error) {
    const char* text = "Unknown error";
    switch (error) {
        case ErrorCode::Success:              text = "Success"; break;
        case ErrorCode::SocketCreationFailed: text = "Socket creation failed"; break;
        case ErrorCode::BindFailed:           text = "Bind failed"; break;
        case ErrorCode::ListenFailed:         text = "Listen failed"; break;
        case ErrorCode::AcceptFailed:         text = "Accept failed"; break;
        case ErrorCode::ConnectFailed:        text = "Connect failed"; break;
        case ErrorCode::SendFailed:           text = "Send failed"; break;
        case ErrorCode::ReceiveFailed:        text = "Receive failed"; break;
        case ErrorCode::Disconnected:         text = "Disconnected"; break;
        case ErrorCode::InvalidArgument:      text = "Invalid argument"; break;
        default:                              break;
    }
    return text;
}
} // namespace tcp