# Computes the strategy that goes from the start point s to the fake goal gi, then to dcstar, and finally to the real goal gr.
import os
import numpy as np
import matplotlib.pyplot as plt
from policy_iteration_APF import PolicyIterationPlanner, load_policy
from policy_iteration_dcstar import dcstar_PolicyIterationPlanner
import seaborn as sns
from matplotlib.gridspec import GridSpec

def save_trajectory_to_file(path, filename):
    """Write a trajectory to a text file, one ``x, y`` pair per line.

    The file starts with a ``#`` header line, followed by the first two
    components of every point in ``path``.

    Args:
        path: Iterable of points; each point must support ``point[0]`` and
            ``point[1]``.
        filename: Destination file path. Missing parent directories are
            created. A bare file name (no directory part) is written to the
            current working directory.
    """
    parent_dir = os.path.dirname(filename)
    # os.makedirs('') raises FileNotFoundError, so only create the parent
    # directory when the file name actually has one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(filename, 'w') as f:
        f.write("# Trajectory coordinates (x, y)\n")
        f.writelines(f"{point[0]}, {point[1]}\n" for point in path)

def load_states_from_file(filename):
    """Read ``D``, ``dcstar`` and ``gi_index`` from a text file.

    Recognised lines (after stripping surrounding whitespace):

    * ``D = ...`` starts the coordinate list. Coordinates may follow on the
      same line and/or on the following lines.
    * ``dcstar = (x, y)`` sets ``dcstar``; an empty value or ``[]`` leaves it
      as ``None``.
    * ``gi_index = n`` sets ``gi_index``.
    * Any other non-empty line seen after ``D =`` contributes its ``(x, y)``
      pairs to ``D``. Such lines seen before ``D =`` are ignored.

    Coordinates are accepted with or without whitespace around the comma
    (``(1,2)`` and ``(1, 2)``), and list brackets around them are tolerated.

    Args:
        filename: Path of the text file to read.

    Returns:
        A ``(D, dcstar, gi_index)`` tuple: ``D`` is a list of ``(int, int)``
        tuples, ``dcstar`` is an ``(int, int)`` tuple or ``None``, and
        ``gi_index`` is an ``int`` or ``None`` when the file does not set it.

    Raises:
        ValueError: If a coordinate token or the ``gi_index`` value cannot be
            parsed as integers.
    """
    import re  # local import: only this function needs it

    # One "(x, y)" pair with optional sign and optional inner whitespace.
    pair_pattern = re.compile(r'\(\s*([+-]?\d+)\s*,\s*([+-]?\d+)\s*\)')

    def find_pairs(text):
        """Return every parenthesised integer pair found in ``text``."""
        return [(int(x), int(y)) for x, y in pair_pattern.findall(text)]

    def parse_coord(s):
        """Strictly parse a single ``x,y`` token, with or without parentheses."""
        x, y = s.strip('()').split(',')
        return (int(x), int(y))

    D = []
    dcstar = None
    gi_index = None
    collecting = False  # becomes True once the "D =" header has been seen

    with open(filename, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue

            if line.startswith('D ='):
                collecting = True
                # Pick up coordinates written on the header line itself,
                # e.g. "D = [(1, 2), (3, 4)]".
                D.extend(find_pairs(line[len('D ='):]))
            elif line.startswith('dcstar ='):
                coords = line.replace('dcstar =', '').strip()
                if coords and coords != '[]':
                    pairs = find_pairs(coords)
                    # Fall back to the strict parser so unparenthesised
                    # values still work and malformed ones still raise.
                    dcstar = pairs[0] if pairs else parse_coord(coords)
            elif line.startswith('gi_index ='):
                gi_index = int(line.replace('gi_index =', '').strip())
            elif collecting:
                pairs = find_pairs(line)
                if not pairs and line.strip('[], '):
                    # No "(x, y)" pair but real content: parse it the strict
                    # way so bad input is reported instead of skipped.
                    pairs = [
                        parse_coord(token)
                        for token in line.strip(' ,').split(', ')
                        if token
                    ]
                D.extend(pairs)
    return D, dcstar, gi_index

class ValueDeceptiveModelPlannerExtended:
    """Three-stage planner: start -> fake goal ``gi`` -> ``dcstar`` -> true goal.

    Each stage greedily follows the value function of one underlying planner
    (a fake-goal planner, the dcstar planner, then the true-goal planner).
    After every executed move, the goal probabilities implied by the move are
    appended to ``probability_history``.

    The planners are used duck-typed. The attributes read here are
    ``actions``, ``value_function`` (indexed as ``[state[1], state[0]]``),
    ``goal``, ``start`` (true planner only), ``width``, ``height``, and the
    methods ``is_valid_action(state, action)`` and
    ``get_next_state(state, action)``.

    Attributes:
        true_planner: Planner for the true goal.
        fake_planners: Sequence of planners, one per fake goal.
        dcstar_planner: Planner whose goal is the dcstar state.
        delta: Stored as given; not used by any method of this class.
        gi_index: Stored as given; the planning methods take ``gi_index``
            as an explicit argument instead of reading this attribute.
        observation_sequence: List of ``(state, action_idx)`` pairs for the
            moves executed by the latest ``plan_path`` call.
        probability_history: Dict mapping ``'true'`` / ``'fake_<i>'`` to the
            list of probabilities recorded after each move.
        steps: Number of recorded observations.
    """

    def __init__(self, true_planner, fake_planners, dcstar_planner, delta=0.0, gi_index=0):
        self.true_planner = true_planner
        self.fake_planners = fake_planners
        self.dcstar_planner = dcstar_planner
        self.delta = delta
        # Previously accepted but dropped; keep it so the argument is not lost.
        self.gi_index = gi_index
        self.observation_sequence = []
        self.probability_history = self._empty_probability_history()
        self.steps = 0

    def _empty_probability_history(self):
        """Return a fresh history dict with one empty list per goal label."""
        return {'true': [], **{f'fake_{i}': [] for i in range(len(self.fake_planners))}}

    def _planner_for_stage(self, stage, gi_index):
        """Return the planner that drives the given stage (1, 2, or anything else)."""
        if stage == 1:  # heading to the fake goal gi
            return self.fake_planners[gi_index]
        if stage == 2:  # heading to dcstar
            return self.dcstar_planner
        return self.true_planner  # heading to the true goal gr

    def calculate_value_improvement(self, planner, state, action_idx):
        """Return ``V(next_state) - V(state)`` for taking ``action_idx`` in ``state``."""
        action = planner.actions[action_idx]
        next_state = planner.get_next_state(state, action)
        current_value = planner.value_function[state[1], state[0]]
        next_value = planner.value_function[next_state[1], next_state[0]]
        return next_value - current_value

    def prune_actions(self, state, planner):
        """Return the indices of valid actions that strictly increase the value."""
        return [
            action_idx
            for action_idx, action in enumerate(planner.actions)
            if planner.is_valid_action(state, action)
            and self.calculate_value_improvement(planner, state, action_idx) > 0
        ]

    def calculate_goal_probabilities_state_value(self, observation):
        """Compute goal probabilities from state-value differences.

        Only the last ``(state, action_idx)`` pair of ``observation`` is used.
        For every planner, the score is the value of ``state + action`` minus
        the value at the true planner's start; the scores are turned into
        probabilities with a softmax.

        Args:
            observation: Sequence of ``(state, action_idx)`` pairs.

        Returns:
            Dict mapping ``'true'`` and ``'fake_<i>'`` to probabilities. An
            empty ``observation`` yields a uniform distribution.
        """
        labelled_planners = [('true', self.true_planner)]
        labelled_planners.extend(
            (f'fake_{i}', fake_planner)
            for i, fake_planner in enumerate(self.fake_planners)
        )

        if not observation:
            uniform = 1.0 / len(labelled_planners)
            return {label: uniform for label, _ in labelled_planners}

        obs_state, obs_action = observation[-1]
        start = self.true_planner.start
        value_diffs = {}
        for label, planner in labelled_planners:
            next_state = tuple(np.array(obs_state) + planner.actions[obs_action])
            values = planner.value_function
            value_diffs[label] = (values[next_state[1], next_state[0]]
                                  - values[start[1], start[0]])

        # Subtract the largest score before exponentiating: the result is
        # mathematically the same softmax, but np.exp can no longer overflow
        # to inf (which used to produce inf / inf = nan).
        largest = max(value_diffs.values())
        weights = {label: np.exp(diff - largest) for label, diff in value_diffs.items()}
        total_weight = sum(weights.values())
        return {label: weight / total_weight for label, weight in weights.items()}

    def update_probability_history(self, state, action):
        """Record ``(state, action)`` and append the resulting goal probabilities."""
        self.observation_sequence.append((state, action))
        self.steps += 1
        probs = self.calculate_goal_probabilities_state_value(self.observation_sequence)
        for goal_type in self.probability_history:
            self.probability_history[goal_type].append(probs.get(goal_type, 0))

    def select_action(self, state, stage, gi_index):
        """Pick the action index with the largest value improvement for a stage.

        stage = 1: start -> gi, follow the value function of fake goal gi
        stage = 2: gi -> dcstar, follow the dcstar value function
        any other stage: dcstar -> gr, follow the true-goal value function

        Returns:
            The index of the best improving action (the first one on ties),
            or ``0`` after printing a message when no action improves the
            value.
        """
        planner = self._planner_for_stage(stage, gi_index)
        valid_actions = self.prune_actions(state, planner)
        if not valid_actions:
            print("No valid actions available.")
            return 0

        # max() returns the first maximal element, matching a strict ">" scan.
        return max(
            valid_actions,
            key=lambda action_idx: self.calculate_value_improvement(planner, state, action_idx),
        )

    def _run_stage(self, current_state, stage, target, gi_index, path):
        """Advance greedily towards ``target`` and return the final state.

        Appends every visited state to ``path`` (mutated in place). Stops
        when the target is reached, when the chosen action does not move the
        agent, or when ``path`` grows beyond ``width * height`` of the
        stage's planner.
        """
        planner = self._planner_for_stage(stage, gi_index)
        max_path_length = planner.width * planner.height
        while current_state != target:
            action_idx = self.select_action(current_state, stage=stage, gi_index=gi_index)
            next_state = planner.get_next_state(current_state, planner.actions[action_idx])
            if next_state == current_state:
                break
            # Record the observation only for moves that are actually taken;
            # a no-op fallback action would otherwise be scored at a cell the
            # agent never entered (possibly outside the grid).
            self.update_probability_history(current_state, action_idx)
            current_state = next_state
            path.append(current_state)
            if len(path) > max_path_length:
                break
        return current_state

    def plan_path(self, start_state, dcstar, gi_index):
        """Generate the three-segment path ``s -> gi -> dcstar -> gr``.

        Each call starts from a clean slate: the observation sequence, the
        probability history, and the step counter are all reset.

        Args:
            start_state: State the path starts from.
            dcstar: Target state of the second segment.
            gi_index: Index into ``fake_planners`` selecting the fake goal.

        Returns:
            List of visited states, beginning with ``start_state``. A segment
            that gets stuck ends early and the next segment continues from
            wherever the agent is.
        """
        self.observation_sequence = []
        self.probability_history = self._empty_probability_history()
        self.steps = 0

        path = [start_state]
        stage_targets = (
            (1, tuple(self.fake_planners[gi_index].goal)),
            (2, dcstar),
            (3, tuple(self.true_planner.goal)),
        )
        current_state = start_state
        for stage, target in stage_targets:
            current_state = self._run_stage(current_state, stage, target, gi_index, path)
        return path

def visualize_probability_evolution(planner, save_path='goal_probability_evolution.png'):
    """Plot the recorded goal probabilities as line charts and save the figure.

    The x axis spreads the recorded samples evenly over 0-100 ("path
    progress"); one line is drawn for the true goal and one per fake goal.

    Args:
        planner: Object exposing ``probability_history`` (dict with a
            ``'true'`` list and one ``'fake_<i>'`` list per fake planner) and
            ``fake_planners``.
        save_path: Output image path passed to ``plt.savefig``.
    """
    # matplotlib's default colour cycle
    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
               '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

    history = planner.probability_history
    total_steps = len(history['true'])
    x = np.linspace(0, 100, total_steps)

    plt.figure(figsize=(12, 8))
    try:
        plt.plot(x, history['true'],
                 label='True Goal', color=palette[0],
                 linewidth=2.5, marker='o', markersize=8)

        for i in range(len(planner.fake_planners)):
            # Wrap around the palette so more than nine fake goals do not
            # raise IndexError.
            plt.plot(x, history[f'fake_{i}'],
                     label=f'Fake Goal {i+1}',
                     color=palette[(i + 1) % len(palette)],
                     linewidth=2.5, marker='s', markersize=8)

        plt.xlabel('Path Progress (%)', fontsize=20)
        plt.ylabel('Probability', fontsize=20)
        plt.grid(True, alpha=0.3)
        plt.legend(fontsize=20)
        plt.tick_params(labelsize=20)
        plt.ylim(-0.05, 1.05)
        plt.xlim(-2, 102)
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close()

def visualize_path(map_env, path, start_point, true_goal, fake_goals, title="E-VDM(g*) Path", save_path='path.png'):
    """Draw the grid map, the path, and the start/goal markers, then save the figure.

    Args:
        map_env: 2-D array with a ``shape`` of ``(height, width)``; every
            truthy cell is drawn as a dark square (presumably an obstacle).
        path: Sequence of ``(x, y)`` points; drawn only when it holds more
            than one point.
        start_point: ``(x, y)`` of the start marker.
        true_goal: ``(x, y)`` of the true-goal marker.
        fake_goals: Iterable of ``(x, y)`` fake-goal markers.
        title: Figure title.
        save_path: Output image path passed to ``plt.savefig``.
    """
    # matplotlib's default colour cycle
    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
               '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
    height, width = map_env.shape

    plt.figure(figsize=(12, 10))
    try:
        ax = plt.gca()
        ax.set_facecolor('#f0f0f0')

        # Cell borders: lines sit on the half-integer cell boundaries.
        for i in range(width + 1):
            ax.axvline(i - 0.5, color='#dee2e6', linewidth=0.5)
        for i in range(height + 1):
            ax.axhline(i - 0.5, color='#dee2e6', linewidth=0.5)

        # np.argwhere yields (row, col) = (y, x) for every non-zero cell, in
        # the same row-major order as a nested y/x loop.
        for y, x in np.argwhere(map_env):
            ax.add_patch(plt.Rectangle(
                (x - 0.5, y - 0.5), 1, 1,
                facecolor='#343a40',
                edgecolor='#212529',
                linewidth=1,
                alpha=0.8,
            ))

        if path and len(path) > 1:
            path_array = np.array(path)
            plt.plot(path_array[:, 0], path_array[:, 1],
                     '-', color='#d62728', linewidth=2.5)
            plt.scatter(path_array[:, 0], path_array[:, 1],
                        color='#d62728', s=50, zorder=5)

        plt.plot(start_point[0], start_point[1], 'o', color='#ff7f0e',
                 markersize=15, label='Start', zorder=6)
        plt.plot(true_goal[0], true_goal[1], 'o', color='#1f77b4',
                 markersize=15, label='True Goal', zorder=6)

        for i, fake_goal in enumerate(fake_goals):
            # Wrap around the palette so more than nine fake goals do not
            # raise IndexError.
            plt.plot(fake_goal[0], fake_goal[1], 'o',
                     color=palette[(i + 1) % len(palette)],
                     markersize=15,
                     label=f'Fake Goal {i+1}',
                     zorder=6)

        plt.title(title, pad=15, fontsize=14)
        plt.grid(True, alpha=0.3)
        plt.legend(loc='upper right', fontsize=12)
        plt.xlim(-0.5, width - 0.5)
        plt.ylim(-0.5, height - 0.5)
        plt.gca().set_aspect('equal')
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close()

def main():
    """Plan and export an E-VDM(g*) path for every problem in ``problem_configs.txt``.

    For each problem this loads the pre-computed policies of the true goal,
    every fake goal, and the dcstar state. It then plans the three-stage path
    and writes the probability plot, the trajectory plot, and the trajectory
    text file into the ``trajectorys`` directory.

    Raises:
        ValueError: If a problem's ``<problem_name>_D_dcstar_gi.txt`` file
            does not define both ``dcstar`` and ``gi_index``.
    """
    # Read the problem configurations from file.
    with open('problem_configs.txt', 'r', encoding='utf-8') as f:
        config_content = f.read()
    config_content = config_content.replace("problem_configs = ", "")
    # SECURITY NOTE(review): eval() executes arbitrary Python found in the
    # config file. Only run this with a trusted problem_configs.txt; if the
    # file holds plain literals, ast.literal_eval would be a safer choice.
    problem_configs = eval(config_content)

    # Generate the strategy for every map.
    for problem_name, config in problem_configs.items():
        map_name = config["map"]
        print(f"\nProcessing problem: {problem_name} with map {map_name}")
        start = config["start"]
        true_goal = config["true_goal"]
        fake_goals = config["fake_goals"]
        # Policy file names are keyed by the map name up to its first dot.
        map_stem = map_name.split('.')[0]

        # Load the true-goal planner.
        true_policy_file = f"{map_stem}_true_goal_{true_goal[0]}_{true_goal[1]}_policy.npz"
        true_planner = PolicyIterationPlanner(map_name, start, true_goal)
        true_planner.policy, true_planner.value_function, true_planner.q_table = load_policy(true_policy_file)

        # Load one planner per fake goal.
        fake_planners = []
        for fake_goal in fake_goals:
            fake_policy_file = f"{map_stem}_fake_goal_{fake_goal[0]}_{fake_goal[1]}_policy.npz"
            fake_planner = PolicyIterationPlanner(map_name, start, fake_goal)
            fake_planner.policy, fake_planner.value_function, fake_planner.q_table = load_policy(fake_policy_file)
            fake_planners.append(fake_planner)

        # Load the dcstar planner.
        states_file = f"{problem_name}_D_dcstar_gi.txt"
        D, dcstar, gi_index = load_states_from_file(states_file)
        if dcstar is None or gi_index is None:
            # Fail with a clear message instead of a TypeError on None below.
            raise ValueError(
                f"{states_file} must define both 'dcstar' and 'gi_index' "
                f"(got dcstar={dcstar!r}, gi_index={gi_index!r})"
            )
        dcstar_policy_file = f"{problem_name}_dcstar_{dcstar[0]}_{dcstar[1]}_policy.npz"
        dcstar_planner = dcstar_PolicyIterationPlanner(map_name, start, dcstar)
        dcstar_planner.policy, dcstar_planner.value_function, dcstar_planner.q_table = load_policy(dcstar_policy_file)

        print("Creating E-VDM(g*) planner...")
        vdm_planner = ValueDeceptiveModelPlannerExtended(
            true_planner=true_planner,
            fake_planners=fake_planners,
            dcstar_planner=dcstar_planner,
            delta=0,
            gi_index=gi_index
        )

        # Create the output directory.
        trajectorys_dir = "trajectorys"
        os.makedirs(trajectorys_dir, exist_ok=True)

        print("Planning E-VDM(g*) path...")
        vdm_path = vdm_planner.plan_path(start, dcstar, gi_index)

        # Output file names use the "action4" suffix to mark the E-VDM(g*) method.
        base_filename = os.path.join(trajectorys_dir, f"{problem_name}_action4")
        visualize_probability_evolution(vdm_planner, f"{base_filename}_probability.png")
        visualize_path(true_planner.map, vdm_path, start, true_goal, fake_goals,
                       save_path=f"{base_filename}_trajectory.png")
        save_trajectory_to_file(vdm_path, f"{base_filename}_trajectory.txt")

        print(f"Path length: {len(vdm_path)}")
        print(f"Reached true goal: {tuple(vdm_path[-1]) == tuple(true_goal)}")

# Run the planning pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
