#ifndef DQN_AGENT_H #define DQN_AGENT_H #include #include // LibTorch集成 - 可选依赖 #ifdef USE_LIBTORCH #include #include #endif /** * @brief DQN智能体基类 * * 用于加载训练好的PyTorch模型并进行推理 * 支持两种模式: * 1. 有LibTorch时: 使用TorchScript模型进行实时推理 * 2. 无LibTorch时: 降级到基于规则的策略 */ class DQNAgent { public: /** * @brief 构造函数 * @param model_path TorchScript模型文件路径 */ explicit DQNAgent(const std::string& model_path); ~DQNAgent(); /** * @brief 加载模型 * @return 是否加载成功 */ bool loadModel(const std::string& model_path); /** * @brief 预测动作 * @param state 状态向量 * @return 选择的动作索引 */ int predict(const std::vector& state); /** * @brief 批量预测 * @param states 状态向量列表 * @return Q值列表 */ std::vector predictBatch(const std::vector>& states); /** * @brief 检查模型是否可用 * @return 是否已加载模型且可用 */ bool isAvailable() const { return model_loaded_; } /** * @brief 获取动作空间大小 */ int getActionDim() const { return action_dim_; } private: #ifdef USE_LIBTORCH torch::jit::script::Module model_; torch::Device device_; #endif bool model_loaded_; int action_dim_; std::string model_path_; }; /** * @brief 路径规划DQN智能体 * * 专门用于路径规划的DQN智能体 * 状态空间: 35维 * 动作空间: K (候选路径数量) */ class PathPlanningDQNAgent { public: PathPlanningDQNAgent(const std::string& model_path, int num_candidates = 5); /** * @brief 从完整状态中选择最优路径 */ int selectPath( const std::vector& state // 35维状态 ); /** * @brief 从编码后的状态中选择最优路径 */ int selectPathEncoded( const std::vector& task_features, // 3维 const std::vector& agv_features, // 4维 const std::vector>& path_features, // K×4维 const std::vector& global_congestion, // 4维 const std::vector& time_features, // 2维 const std::vector& urgency_features // 2维 ); bool isAvailable() const { return agent_->isAvailable(); } private: std::unique_ptr agent_; int num_candidates_; }; /** * @brief 任务分配DQN智能体 * * 专门用于任务分配的DQN智能体 * 状态空间: ~85维 * 动作空间: N+1 (可用AGV数 + 拒绝) */ class TaskAssignmentDQNAgent { public: TaskAssignmentDQNAgent(const std::string& model_path, int num_agvs = 10); /** * @brief 选择AGV分配 */ int selectAGV( const std::vector& state // ~85维状态 ); bool isAvailable() const { return agent_->isAvailable(); } private: std::unique_ptr agent_; int num_agvs_; }; /** * @brief 基于规则的降级策略 * * 当LibTorch不可用时使用的简单策略 */ class RuleBasedFallback { public: // 路径规划: 选择最短且冲突最少的路径 static int selectPathByRules( const std::vector>& path_features // K×4 (长度, 冲突, 速度, 平滑度) ); // 任务分配: 选择最近且电量充足的AGV static int selectAGVByRules( const std::vector>& agv_features // N×5 (距离, 电量, 速度, 状态, 负载) ); }; #endif // DQN_AGENT_H