Files
RCS-3000/include/dispatch/rl/dqn_agent.h

153 lines
3.8 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#ifndef DQN_AGENT_H
#define DQN_AGENT_H
#include <string>
#include <vector>
// LibTorch集成 - 可选依赖
#ifdef USE_LIBTORCH
#include <torch/torch.h>
#include <torch/script.h>
#endif
/**
* @brief DQN智能体基类
*
* 用于加载训练好的PyTorch模型并进行推理
* 支持两种模式:
* 1. 有LibTorch时: 使用TorchScript模型进行实时推理
* 2. 无LibTorch时: 降级到基于规则的策略
*/
class DQNAgent {
public:
/**
* @brief 构造函数
* @param model_path TorchScript模型文件路径
*/
explicit DQNAgent(const std::string& model_path);
~DQNAgent();
/**
* @brief 加载模型
* @return 是否加载成功
*/
bool loadModel(const std::string& model_path);
/**
* @brief 预测动作
* @param state 状态向量
* @return 选择的动作索引
*/
int predict(const std::vector<float>& state);
/**
* @brief 批量预测
* @param states 状态向量列表
* @return Q值列表
*/
std::vector<float> predictBatch(const std::vector<std::vector<float>>& states);
/**
* @brief 检查模型是否可用
* @return 是否已加载模型且可用
*/
bool isAvailable() const { return model_loaded_; }
/**
* @brief 获取动作空间大小
*/
int getActionDim() const { return action_dim_; }
private:
#ifdef USE_LIBTORCH
torch::jit::script::Module model_;
torch::Device device_;
#endif
bool model_loaded_;
int action_dim_;
std::string model_path_;
};
/**
* @brief 路径规划DQN智能体
*
* 专门用于路径规划的DQN智能体
* 状态空间: 35维
* 动作空间: K (候选路径数量)
*/
class PathPlanningDQNAgent {
public:
PathPlanningDQNAgent(const std::string& model_path, int num_candidates = 5);
/**
* @brief 从完整状态中选择最优路径
*/
int selectPath(
const std::vector<float>& state // 35维状态
);
/**
* @brief 从编码后的状态中选择最优路径
*/
int selectPathEncoded(
const std::vector<float>& task_features, // 3维
const std::vector<float>& agv_features, // 4维
const std::vector<std::vector<float>>& path_features, // K×4维
const std::vector<float>& global_congestion, // 4维
const std::vector<float>& time_features, // 2维
const std::vector<float>& urgency_features // 2维
);
bool isAvailable() const { return agent_->isAvailable(); }
private:
std::unique_ptr<DQNAgent> agent_;
int num_candidates_;
};
/**
* @brief 任务分配DQN智能体
*
* 专门用于任务分配的DQN智能体
* 状态空间: ~85维
* 动作空间: N+1 (可用AGV数 + 拒绝)
*/
class TaskAssignmentDQNAgent {
public:
TaskAssignmentDQNAgent(const std::string& model_path, int num_agvs = 10);
/**
* @brief 选择AGV分配
*/
int selectAGV(
const std::vector<float>& state // ~85维状态
);
bool isAvailable() const { return agent_->isAvailable(); }
private:
std::unique_ptr<DQNAgent> agent_;
int num_agvs_;
};
/**
* @brief 基于规则的降级策略
*
* 当LibTorch不可用时使用的简单策略
*/
class RuleBasedFallback {
public:
// 路径规划: 选择最短且冲突最少的路径
static int selectPathByRules(
const std::vector<std::vector<float>>& path_features // K×4 (长度, 冲突, 速度, 平滑度)
);
// 任务分配: 选择最近且电量充足的AGV
static int selectAGVByRules(
const std::vector<std::vector<float>>& agv_features // N×5 (, , , , )
);
};
#endif // DQN_AGENT_H