上传文件至 include/dispatch/rl
This commit is contained in:
152
include/dispatch/rl/dqn_agent.h
Normal file
152
include/dispatch/rl/dqn_agent.h
Normal file
@@ -0,0 +1,152 @@
|
||||
#ifndef DQN_AGENT_H
|
||||
#define DQN_AGENT_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// LibTorch集成 - 可选依赖
|
||||
#ifdef USE_LIBTORCH
|
||||
#include <torch/torch.h>
|
||||
#include <torch/script.h>
|
||||
#endif
|
||||
|
||||
/**
|
||||
* @brief DQN智能体基类
|
||||
*
|
||||
* 用于加载训练好的PyTorch模型并进行推理
|
||||
* 支持两种模式:
|
||||
* 1. 有LibTorch时: 使用TorchScript模型进行实时推理
|
||||
* 2. 无LibTorch时: 降级到基于规则的策略
|
||||
*/
|
||||
class DQNAgent {
|
||||
public:
|
||||
/**
|
||||
* @brief 构造函数
|
||||
* @param model_path TorchScript模型文件路径
|
||||
*/
|
||||
explicit DQNAgent(const std::string& model_path);
|
||||
|
||||
~DQNAgent();
|
||||
|
||||
/**
|
||||
* @brief 加载模型
|
||||
* @return 是否加载成功
|
||||
*/
|
||||
bool loadModel(const std::string& model_path);
|
||||
|
||||
/**
|
||||
* @brief 预测动作
|
||||
* @param state 状态向量
|
||||
* @return 选择的动作索引
|
||||
*/
|
||||
int predict(const std::vector<float>& state);
|
||||
|
||||
/**
|
||||
* @brief 批量预测
|
||||
* @param states 状态向量列表
|
||||
* @return Q值列表
|
||||
*/
|
||||
std::vector<float> predictBatch(const std::vector<std::vector<float>>& states);
|
||||
|
||||
/**
|
||||
* @brief 检查模型是否可用
|
||||
* @return 是否已加载模型且可用
|
||||
*/
|
||||
bool isAvailable() const { return model_loaded_; }
|
||||
|
||||
/**
|
||||
* @brief 获取动作空间大小
|
||||
*/
|
||||
int getActionDim() const { return action_dim_; }
|
||||
|
||||
private:
|
||||
#ifdef USE_LIBTORCH
|
||||
torch::jit::script::Module model_;
|
||||
torch::Device device_;
|
||||
#endif
|
||||
bool model_loaded_;
|
||||
int action_dim_;
|
||||
std::string model_path_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief 路径规划DQN智能体
|
||||
*
|
||||
* 专门用于路径规划的DQN智能体
|
||||
* 状态空间: 35维
|
||||
* 动作空间: K (候选路径数量)
|
||||
*/
|
||||
class PathPlanningDQNAgent {
|
||||
public:
|
||||
PathPlanningDQNAgent(const std::string& model_path, int num_candidates = 5);
|
||||
|
||||
/**
|
||||
* @brief 从完整状态中选择最优路径
|
||||
*/
|
||||
int selectPath(
|
||||
const std::vector<float>& state // 35维状态
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief 从编码后的状态中选择最优路径
|
||||
*/
|
||||
int selectPathEncoded(
|
||||
const std::vector<float>& task_features, // 3维
|
||||
const std::vector<float>& agv_features, // 4维
|
||||
const std::vector<std::vector<float>>& path_features, // K×4维
|
||||
const std::vector<float>& global_congestion, // 4维
|
||||
const std::vector<float>& time_features, // 2维
|
||||
const std::vector<float>& urgency_features // 2维
|
||||
);
|
||||
|
||||
bool isAvailable() const { return agent_->isAvailable(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<DQNAgent> agent_;
|
||||
int num_candidates_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief 任务分配DQN智能体
|
||||
*
|
||||
* 专门用于任务分配的DQN智能体
|
||||
* 状态空间: ~85维
|
||||
* 动作空间: N+1 (可用AGV数 + 拒绝)
|
||||
*/
|
||||
class TaskAssignmentDQNAgent {
|
||||
public:
|
||||
TaskAssignmentDQNAgent(const std::string& model_path, int num_agvs = 10);
|
||||
|
||||
/**
|
||||
* @brief 选择AGV分配
|
||||
*/
|
||||
int selectAGV(
|
||||
const std::vector<float>& state // ~85维状态
|
||||
);
|
||||
|
||||
bool isAvailable() const { return agent_->isAvailable(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<DQNAgent> agent_;
|
||||
int num_agvs_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief 基于规则的降级策略
|
||||
*
|
||||
* 当LibTorch不可用时使用的简单策略
|
||||
*/
|
||||
class RuleBasedFallback {
|
||||
public:
|
||||
// 路径规划: 选择最短且冲突最少的路径
|
||||
static int selectPathByRules(
|
||||
const std::vector<std::vector<float>>& path_features // K×4 (长度, 冲突, 速度, 平滑度)
|
||||
);
|
||||
|
||||
// 任务分配: 选择最近且电量充足的AGV
|
||||
static int selectAGVByRules(
|
||||
const std::vector<std::vector<float>>& agv_features // N×5 (距离, 电量, 速度, 状态, 负载)
|
||||
);
|
||||
};
|
||||
|
||||
#endif // DQN_AGENT_H
|
||||
150
include/dispatch/rl/rl_enhanced_agv_manager.h
Normal file
150
include/dispatch/rl/rl_enhanced_agv_manager.h
Normal file
@@ -0,0 +1,150 @@
|
||||
#ifndef RL_ENHANCED_AGV_MANAGER_H
|
||||
#define RL_ENHANCED_AGV_MANAGER_H
|
||||
|
||||
#include "dispatch/enhanced_agv_manager.h"
|
||||
#include "dispatch/rl/dqn_agent.h"
|
||||
#include "dispatch/rl/state_encoder.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
/**
|
||||
* @brief RL增强的AGV管理器
|
||||
*
|
||||
* 集成DQN强化学习智能体的增强AGV管理器
|
||||
* 在原有EnhancedAGVManager基础上添加:
|
||||
* 1. 路径规划DQN优化
|
||||
* 2. 任务分配DQN优化
|
||||
* 3. 自动降级到传统算法
|
||||
*/
|
||||
class RLEnhancedAGVManager : public EnhancedAGVManager {
|
||||
public:
|
||||
RLEnhancedAGVManager(GraphMap* map, ResourceManager* rm = nullptr);
|
||||
|
||||
// ========== RL功能开关 ==========
|
||||
|
||||
/**
|
||||
* @brief 启用路径规划DQN
|
||||
* @param model_path TorchScript模型文件路径
|
||||
* @param num_candidates 候选路径数量
|
||||
* @return 是否成功启用
|
||||
*/
|
||||
bool enableRLForPath(const std::string& model_path, int num_candidates = 5);
|
||||
|
||||
/**
|
||||
* @brief 启用任务分配DQN
|
||||
* @param model_path TorchScript模型文件路径
|
||||
* @param num_agvs 最大AGV数量
|
||||
* @return 是否成功启用
|
||||
*/
|
||||
bool enableRLForAssignment(const std::string& model_path, int num_agvs = 10);
|
||||
|
||||
/**
|
||||
* @brief 禁用RL功能
|
||||
*/
|
||||
void disableRL();
|
||||
|
||||
// ========== 重写的方法 ==========
|
||||
|
||||
/**
|
||||
* @brief 智能任务分配 (重写)
|
||||
* 如果启用了任务分配DQN, 使用RL选择; 否则使用父类方法
|
||||
*/
|
||||
int assignTasksSmart() override;
|
||||
|
||||
/**
|
||||
* @brief 查找路径 (重写)
|
||||
* 如果启用了路径规划DQN, 使用RL选择路径; 否则使用传统A*
|
||||
*/
|
||||
std::vector<Path*> findPathForTask(AGV* agv, Task* task);
|
||||
|
||||
// ========== RL特定的方法 ==========
|
||||
|
||||
/**
|
||||
* @brief 使用RL选择最优路径
|
||||
* @param agv 执行任务的AGV
|
||||
* @param task 任务
|
||||
* @return 最优路径
|
||||
*/
|
||||
std::vector<Path*> selectOptimalPathWithRL(AGV* agv, Task* task);
|
||||
|
||||
/**
|
||||
* @brief 使用RL选择AGV分配
|
||||
* @param task 待分配任务
|
||||
* @param available_agvs 可用AGV列表
|
||||
* @return 选中的AGV索引
|
||||
*/
|
||||
int selectAGVWithRL(Task* task, const std::vector<AGV*>& available_agvs);
|
||||
|
||||
// ========== 配置和统计 ==========
|
||||
|
||||
/**
|
||||
* @brief 设置候选路径数量
|
||||
*/
|
||||
void setNumCandidates(int num) { num_candidates_ = num; }
|
||||
|
||||
/**
|
||||
* @brief 路径规划RL是否可用
|
||||
*/
|
||||
bool isRLPathAvailable() const { return use_rl_for_path_ && path_dqn_ && path_dqn_->isAvailable(); }
|
||||
|
||||
/**
|
||||
* @brief 任务分配RL是否可用
|
||||
*/
|
||||
bool isRLAssignmentAvailable() const { return use_rl_for_assignment_ && assignment_dqn_ && assignment_dqn_->isAvailable(); }
|
||||
|
||||
/**
|
||||
* @brief 打印RL统计信息
|
||||
*/
|
||||
void printRLStatistics();
|
||||
|
||||
private:
|
||||
// RL智能体
|
||||
std::unique_ptr<PathPlanningDQNAgent> path_dqn_;
|
||||
std::unique_ptr<TaskAssignmentDQNAgent> assignment_dqn_;
|
||||
std::unique_ptr<StateEncoder> state_encoder_;
|
||||
|
||||
// 配置
|
||||
bool use_rl_for_path_;
|
||||
bool use_rl_for_assignment_;
|
||||
int num_candidates_;
|
||||
int num_agvs_;
|
||||
|
||||
// 统计
|
||||
struct RLStatistics {
|
||||
int path_decisions = 0;
|
||||
int path_fallbacks = 0;
|
||||
int assignment_decisions = 0;
|
||||
int assignment_fallbacks = 0;
|
||||
|
||||
void reset() {
|
||||
path_decisions = 0;
|
||||
path_fallbacks = 0;
|
||||
assignment_decisions = 0;
|
||||
assignment_fallbacks = 0;
|
||||
}
|
||||
} rl_stats_;
|
||||
|
||||
// ========== 内部辅助方法 ==========
|
||||
|
||||
/**
|
||||
* @brief 执行RL任务分配
|
||||
*/
|
||||
int assignTasksWithRL();
|
||||
|
||||
/**
|
||||
* @brief 分配单个任务
|
||||
*/
|
||||
bool assignSingleTask(AGV* agv, Task* task);
|
||||
|
||||
/**
|
||||
* @brief 获取可用AGV列表
|
||||
*/
|
||||
std::vector<AGV*> getAvailableAGVs();
|
||||
|
||||
/**
|
||||
* @brief 获取待处理任务列表
|
||||
*/
|
||||
std::vector<Task*> getPendingTasks();
|
||||
};
|
||||
|
||||
#endif // RL_ENHANCED_AGV_MANAGER_H
|
||||
123
include/dispatch/rl/state_encoder.h
Normal file
123
include/dispatch/rl/state_encoder.h
Normal file
@@ -0,0 +1,123 @@
|
||||
#ifndef STATE_ENCODER_H
|
||||
#define STATE_ENCODER_H
|
||||
|
||||
#include "dispatch/graph_map.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
// 前向声明
|
||||
class ResourceManager;
|
||||
class AGV;
|
||||
|
||||
/**
|
||||
* @brief 任务结构体 (简化版)
|
||||
*/
|
||||
struct Task {
|
||||
int id;
|
||||
int start_point_id;
|
||||
int end_point_id;
|
||||
int priority;
|
||||
double deadline;
|
||||
double time_remaining;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief 状态编码器
|
||||
*
|
||||
* 将C++对象编码为DQN可用的状态向量
|
||||
*/
|
||||
class StateEncoder {
|
||||
public:
|
||||
StateEncoder(GraphMap* map, ResourceManager* rm);
|
||||
|
||||
// ========== 路径规划状态编码 (35维) ==========
|
||||
/**
|
||||
* @brief 编码路径规划状态
|
||||
*
|
||||
* 状态组成 (35维):
|
||||
* - 任务信息 (3维): 起点(x,y)、终点(x)
|
||||
* - AGV状态 (4维): 位置(x,y)、速度、电量
|
||||
* - 候选路径特征 (20维): 5条路径 × 4特征
|
||||
* - 全局拥堵 (4维): 各区域占用AGV数
|
||||
* - 时间信息 (2维): 当前时段、历史拥堵预测
|
||||
* - 紧急度 (2维): 任务优先级、剩余时间
|
||||
*/
|
||||
std::vector<float> encodePathPlanningState(
|
||||
AGV* agv,
|
||||
Task* task,
|
||||
const std::vector<std::vector<Path*>>& candidates
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief 编码单条路径特征
|
||||
*/
|
||||
std::vector<float> encodeSinglePathFeatures(
|
||||
const std::vector<Path*>& path
|
||||
);
|
||||
|
||||
// ========== 任务分配状态编码 (~85维) ==========
|
||||
/**
|
||||
* @brief 编码任务分配状态
|
||||
*
|
||||
* 状态组成 (~85维):
|
||||
* - 任务特征 (8维): 起终点坐标、距离、优先级、截止时间、类型
|
||||
* - 可用AGV池 (50维): 10辆AGV × 5特征
|
||||
* - 全局状态 (7维): 等待任务数、执行中任务数等
|
||||
* - 资源状态 (20维): 关键路径占用情况
|
||||
* - 时间特征 (2维): 时段、历史成功率
|
||||
*/
|
||||
std::vector<float> encodeDispatchState(
|
||||
Task* task,
|
||||
const std::vector<AGV*>& available_agvs
|
||||
);
|
||||
|
||||
/**
|
||||
* @brief 编码单个AGV特征
|
||||
*/
|
||||
std::vector<float> encodeSingleAGVFeatures(
|
||||
AGV* agv,
|
||||
Point* task_start
|
||||
);
|
||||
|
||||
// ========== 工具函数 ==========
|
||||
/**
|
||||
* @brief 归一化到[0, 1]
|
||||
*/
|
||||
static float normalize(float value, float min, float max) {
|
||||
if (max - min < 1e-6f) return 0.5f;
|
||||
return (value - min) / (max - min);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief 裁剪到[0, 1]
|
||||
*/
|
||||
static float clip(float value) {
|
||||
return std::max(0.0f, std::min(1.0f, value));
|
||||
}
|
||||
|
||||
private:
|
||||
GraphMap* map_;
|
||||
ResourceManager* resource_manager_;
|
||||
|
||||
// 地图边界
|
||||
float map_min_x_, map_max_x_;
|
||||
float map_min_y_, map_max_y_;
|
||||
float max_path_length_;
|
||||
|
||||
/**
|
||||
* @brief 初始化地图边界
|
||||
*/
|
||||
void initializeMapBounds();
|
||||
|
||||
/**
|
||||
* @brief 计算区域拥堵情况
|
||||
*/
|
||||
std::vector<int> calculateZoneOccupancy(int num_zones = 4);
|
||||
|
||||
/**
|
||||
* @brief 计算路径平滑度
|
||||
*/
|
||||
float calculatePathSmoothness(const std::vector<Path*>& path);
|
||||
};
|
||||
|
||||
#endif // STATE_ENCODER_H
|
||||
Reference in New Issue
Block a user