diff --git a/tcp_utility.hpp b/tcp_utility.hpp new file mode 100644 index 0000000..19c5323 --- /dev/null +++ b/tcp_utility.hpp @@ -0,0 +1,154 @@ +#ifndef TCP_UTILITY_HPP +#define TCP_UTILITY_HPP + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tcp { + +// 错误码枚举 +enum class ErrorCode { + Success = 0, + SocketCreationFailed, + BindFailed, + ListenFailed, + AcceptFailed, + ConnectFailed, + SendFailed, + ReceiveFailed, + Disconnected, + InvalidArgument +}; + +// 回调函数类型定义 +using MessageCallback = std::function&)>; +using ErrorCallback = std::function; +using ConnectionCallback = std::function)>; + +// 前向声明 +class TcpSocket; +class TcpClient; +class TcpServer; + +// TCP套接字类 +class TcpSocket { +public: + TcpSocket(int socket_fd = -1); + ~TcpSocket(); + + // 禁止拷贝,允许移动 + TcpSocket(const TcpSocket&) = delete; + TcpSocket& operator=(const TcpSocket&) = delete; + TcpSocket(TcpSocket&&) noexcept; + TcpSocket& operator=(TcpSocket&&) noexcept; + + // 连接到服务器 + ErrorCode connect(const std::string& host, uint16_t port); + + // 发送数据 + ErrorCode send(const std::vector& data); + ErrorCode send(const std::string& data); + + // 接收数据 + ErrorCode receive(std::vector& buffer, size_t max_size); + + // 关闭连接 + void close(); + + // 获取状态 + bool isConnected() const; + std::string getPeerAddress() const; + uint16_t getPeerPort() const; + + // 设置回调函数 + void setMessageCallback(const MessageCallback& callback); + void setErrorCallback(const ErrorCallback& callback); + + // 开始异步接收 + void startReceiving(); + +private: + int m_socket_fd; + bool m_connected; + std::string m_peer_address; + uint16_t m_peer_port; + MessageCallback m_message_callback; + ErrorCallback m_error_callback; +}; + +// TCP客户端类 +class TcpClient { +public: + TcpClient(); + ~TcpClient(); + + // 连接到服务器 + ErrorCode connect(const std::string& host, uint16_t port); + + // 断开连接 + void disconnect(); + + // 发送数据 + ErrorCode send(const std::vector& data); + ErrorCode send(const std::string& data); + + // 获取状态 + bool isConnected() const; + std::shared_ptr getSocket() const; + + // 设置回调函数 + void setMessageCallback(const MessageCallback& callback); + void setErrorCallback(const ErrorCallback& callback); + +private: + std::shared_ptr m_socket; + MessageCallback m_message_callback; + ErrorCallback m_error_callback; +}; + +// TCP服务器类 +class TcpServer { +public: + TcpServer(); + ~TcpServer(); + + // 启动服务器 + ErrorCode start(uint16_t port, size_t max_connections = 10); + + // 停止服务器 + void stop(); + + // 获取状态 + bool isRunning() const; + uint16_t getPort() const; + + // 设置回调函数 + void setConnectionCallback(const ConnectionCallback& callback); + void setMessageCallback(const MessageCallback& callback); + void setErrorCallback(const ErrorCallback& callback); + +private: + int m_listen_fd; + uint16_t m_port; + bool m_running; + std::vector> m_clients; + ConnectionCallback m_connection_callback; + MessageCallback m_message_callback; + ErrorCallback m_error_callback; + + // 接受新连接的线程函数 + void acceptThread(); +}; + +// 错误处理辅助函数 +std::string errorCodeToString(ErrorCode error); + +} // namespace tcp + +#endif // TCP_UTILITY_HPP