153 lines
3.8 KiB
C++
153 lines
3.8 KiB
C++
#ifndef DQN_AGENT_H
|
||
#define DQN_AGENT_H
|
||
|
||
#include <string>
|
||
#include <vector>
|
||
|
||
// LibTorch集成 - 可选依赖
|
||
#ifdef USE_LIBTORCH
|
||
#include <torch/torch.h>
|
||
#include <torch/script.h>
|
||
#endif
|
||
|
||
/**
|
||
* @brief DQN智能体基类
|
||
*
|
||
* 用于加载训练好的PyTorch模型并进行推理
|
||
* 支持两种模式:
|
||
* 1. 有LibTorch时: 使用TorchScript模型进行实时推理
|
||
* 2. 无LibTorch时: 降级到基于规则的策略
|
||
*/
|
||
class DQNAgent {
|
||
public:
|
||
/**
|
||
* @brief 构造函数
|
||
* @param model_path TorchScript模型文件路径
|
||
*/
|
||
explicit DQNAgent(const std::string& model_path);
|
||
|
||
~DQNAgent();
|
||
|
||
/**
|
||
* @brief 加载模型
|
||
* @return 是否加载成功
|
||
*/
|
||
bool loadModel(const std::string& model_path);
|
||
|
||
/**
|
||
* @brief 预测动作
|
||
* @param state 状态向量
|
||
* @return 选择的动作索引
|
||
*/
|
||
int predict(const std::vector<float>& state);
|
||
|
||
/**
|
||
* @brief 批量预测
|
||
* @param states 状态向量列表
|
||
* @return Q值列表
|
||
*/
|
||
std::vector<float> predictBatch(const std::vector<std::vector<float>>& states);
|
||
|
||
/**
|
||
* @brief 检查模型是否可用
|
||
* @return 是否已加载模型且可用
|
||
*/
|
||
bool isAvailable() const { return model_loaded_; }
|
||
|
||
/**
|
||
* @brief 获取动作空间大小
|
||
*/
|
||
int getActionDim() const { return action_dim_; }
|
||
|
||
private:
|
||
#ifdef USE_LIBTORCH
|
||
torch::jit::script::Module model_;
|
||
torch::Device device_;
|
||
#endif
|
||
bool model_loaded_;
|
||
int action_dim_;
|
||
std::string model_path_;
|
||
};
|
||
|
||
/**
|
||
* @brief 路径规划DQN智能体
|
||
*
|
||
* 专门用于路径规划的DQN智能体
|
||
* 状态空间: 35维
|
||
* 动作空间: K (候选路径数量)
|
||
*/
|
||
class PathPlanningDQNAgent {
|
||
public:
|
||
PathPlanningDQNAgent(const std::string& model_path, int num_candidates = 5);
|
||
|
||
/**
|
||
* @brief 从完整状态中选择最优路径
|
||
*/
|
||
int selectPath(
|
||
const std::vector<float>& state // 35维状态
|
||
);
|
||
|
||
/**
|
||
* @brief 从编码后的状态中选择最优路径
|
||
*/
|
||
int selectPathEncoded(
|
||
const std::vector<float>& task_features, // 3维
|
||
const std::vector<float>& agv_features, // 4维
|
||
const std::vector<std::vector<float>>& path_features, // K×4维
|
||
const std::vector<float>& global_congestion, // 4维
|
||
const std::vector<float>& time_features, // 2维
|
||
const std::vector<float>& urgency_features // 2维
|
||
);
|
||
|
||
bool isAvailable() const { return agent_->isAvailable(); }
|
||
|
||
private:
|
||
std::unique_ptr<DQNAgent> agent_;
|
||
int num_candidates_;
|
||
};
|
||
|
||
/**
|
||
* @brief 任务分配DQN智能体
|
||
*
|
||
* 专门用于任务分配的DQN智能体
|
||
* 状态空间: ~85维
|
||
* 动作空间: N+1 (可用AGV数 + 拒绝)
|
||
*/
|
||
class TaskAssignmentDQNAgent {
|
||
public:
|
||
TaskAssignmentDQNAgent(const std::string& model_path, int num_agvs = 10);
|
||
|
||
/**
|
||
* @brief 选择AGV分配
|
||
*/
|
||
int selectAGV(
|
||
const std::vector<float>& state // ~85维状态
|
||
);
|
||
|
||
bool isAvailable() const { return agent_->isAvailable(); }
|
||
|
||
private:
|
||
std::unique_ptr<DQNAgent> agent_;
|
||
int num_agvs_;
|
||
};
|
||
|
||
/**
|
||
* @brief 基于规则的降级策略
|
||
*
|
||
* 当LibTorch不可用时使用的简单策略
|
||
*/
|
||
class RuleBasedFallback {
|
||
public:
|
||
// 路径规划: 选择最短且冲突最少的路径
|
||
static int selectPathByRules(
|
||
const std::vector<std::vector<float>>& path_features // K×4 (长度, 冲突, 速度, 平滑度)
|
||
);
|
||
|
||
// 任务分配: 选择最近且电量充足的AGV
|
||
static int selectAGVByRules(
|
||
const std::vector<std::vector<float>>& agv_features // N×5 (距离, 电量, 速度, 状态, 负载)
|
||
);
|
||
};
|
||
|
||
#endif // DQN_AGENT_H
|