import os
import numpy as np
import matplotlib.pyplot as plt
from policy_iteration_APF import load_policy

def load_configs():
    """Load the problem configuration mapping from ``problem_configs.txt``.

    The file holds a Python assignment of the form ``problem_configs = {...}``.
    Every occurrence of the ``problem_configs = `` prefix is removed and the
    remaining text is evaluated; the resulting object is returned.
    """
    prefix = "problem_configs = "
    with open('problem_configs.txt', 'r', encoding='utf-8') as config_file:
        raw_text = config_file.read()
    expression = raw_text.replace(prefix, "")
    # NOTE(review): eval() executes arbitrary code found in the file. This is
    # only acceptable while problem_configs.txt is a trusted, locally produced
    # file; if it only ever holds literals, ast.literal_eval would be safer.
    return eval(expression)

def get_trajectory_length(trajectory_file):
    """Return the number of data rows in a trajectory file.

    The first line is a header and is not counted, so a header-only file
    yields 0 and a completely empty file yields -1.
    """
    header_lines = 1
    with open(trajectory_file, 'r') as handle:
        line_count = sum(1 for _ in handle)
    return line_count - header_lines

def calculate_trajectory_cost_ratios():
    """Write each problem's trajectory-length ratios relative to the honest baseline.

    For every problem in ``problem_configs.txt`` the trajectory of action
    model 0 (``trajectorys/<problem>_action0_trajectory.txt``) is the baseline.
    For action models 1-3 whose trajectory file exists, ``length / base_length``
    is written to ``trajectory_cost/<problem>_trajectorycost.txt`` as
    ``<action_model>\\t<ratio>`` rows under a header line.

    Problems whose baseline file is missing, or has no data rows, are skipped
    with a printed warning. Missing action-model files are also reported and
    skipped.
    """
    problem_configs = load_configs()
    trajectorys_dir = "trajectorys"
    output_dir = "trajectory_cost"
    os.makedirs(output_dir, exist_ok=True)

    for problem_name in problem_configs:
        # The honest trajectory (action_model=0) is the cost baseline.
        base_trajectory_file = os.path.join(trajectorys_dir, f"{problem_name}_action0_trajectory.txt")
        if not os.path.exists(base_trajectory_file):
            print(f"Warning: Base trajectory file not found for {problem_name}")
            continue
        base_length = get_trajectory_length(base_trajectory_file)
        if base_length <= 0:
            # A header-only (0) or empty (-1) baseline would make every ratio a
            # division by zero or a meaningless negative value; skip the problem
            # instead of aborting the whole run.
            print(f"Warning: Base trajectory for {problem_name} has no states; skipping cost ratios")
            continue

        output_file = os.path.join(output_dir, f"{problem_name}_trajectorycost.txt")
        with open(output_file, 'w') as f:
            f.write("Action Model\tLength Ratio\n")
            for action_model in range(1, 4):  # action models 1-3
                trajectory_file = os.path.join(trajectorys_dir, f"{problem_name}_action{action_model}_trajectory.txt")
                if not os.path.exists(trajectory_file):
                    print(f"Warning: Trajectory file not found for {problem_name}, action_model {action_model}")
                    continue
                ratio = get_trajectory_length(trajectory_file) / base_length
                f.write(f"{action_model}\t{ratio:.4f}\n")

def calculate_state_value_probabilities(state, true_planner, fake_planners, start_state):
    """Compute goal-recognition probabilities for ``state``.

    Each planner (true goal first, then each fake goal) scores the state by the
    value gained since ``start_state``:
    ``value_function[y, x] - value_function[start_y, start_x]``.
    The scores are turned into probabilities with a softmax.

    Args:
        state: ``(x, y)`` position; its components are cast to ``int``.
        true_planner: object exposing a 2-D indexable ``value_function`` for
            the true goal.
        fake_planners: sequence of such objects, one per fake goal.
        start_state: ``(x, y)`` start position used as the reference value.

    Returns:
        ``(true_prob, fake_probs)``. ``fake_probs`` is a list ordered like
        ``fake_planners``.
    """
    state = tuple(map(int, state))
    x, y = state[0], state[1]
    start_x, start_y = start_state[0], start_state[1]

    planners = [true_planner, *fake_planners]
    value_diffs = np.array(
        [planner.value_function[y, x] - planner.value_function[start_y, start_x]
         for planner in planners],
        dtype=float,
    )
    # Shift by the maximum before exponentiating (the standard stable softmax).
    # The probabilities are mathematically unchanged, but large value gaps no
    # longer overflow np.exp to inf (or underflow every term to 0), which
    # previously turned all probabilities into nan.
    weights = np.exp(value_diffs - value_diffs.max())
    probabilities = weights / weights.sum()

    true_prob = probabilities[0]
    fake_probs = [probabilities[i] for i in range(1, len(planners))]
    return true_prob, fake_probs

def calculate_trajectory_probabilities():
    """Score every trajectory state and record the last deceptive state.

    For each problem in ``problem_configs.txt`` and each action model 0-3 whose
    trajectory file exists under ``trajectorys/``:

    * writes ``probability/<problem>_action<m>_probability.txt`` with the
      true-goal probability of every trajectory state;
    * adds a ``<m>\\t<steps>`` row to
      ``steps_after_LDS/<problem>_steps_after_LDS.txt``. Here ``steps`` is the
      number of states after the last deceptive state, i.e. the last state in
      which some fake goal is at least as probable as the true goal. It equals
      the state count when no state is deceptive.
    """
    from policy_iteration_APF import PolicyIterationPlanner

    problem_configs = load_configs()
    trajectorys_dir = "trajectorys"
    prob_dir = "probability"
    lds_dir = "steps_after_LDS"
    for directory in (prob_dir, lds_dir):
        os.makedirs(directory, exist_ok=True)

    def build_planner(map_name, start, goal, policy_file):
        """Create a planner for ``goal`` and fill it from a saved policy file."""
        planner = PolicyIterationPlanner(map_name, start, goal)
        planner.policy, planner.value_function, planner.q_table = load_policy(policy_file)
        return planner

    for problem_name, config in problem_configs.items():
        map_name, start = config["map"], config["start"]
        true_goal, fake_goals = config["true_goal"], config["fake_goals"]

        true_planner = build_planner(
            map_name, start, true_goal,
            f"{map_name.split('.')[0]}_true_goal_{true_goal[0]}_{true_goal[1]}_policy.npz",
        )
        fake_planners = [
            build_planner(
                map_name, start, fake_goal,
                f"{map_name.split('.')[0]}_fake_goal_{fake_goal[0]}_{fake_goal[1]}_policy.npz",
            )
            for fake_goal in fake_goals
        ]

        lds_file = os.path.join(lds_dir, f"{problem_name}_steps_after_LDS.txt")
        with open(lds_file, 'w') as lds_f:
            lds_f.write("Action Model\tSteps After Last Deceptive State\n")

            for action_model in range(4):  # action models 0-3
                trajectory_file = os.path.join(trajectorys_dir, f"{problem_name}_action{action_model}_trajectory.txt")
                if not os.path.exists(trajectory_file):
                    continue

                with open(trajectory_file, 'r') as traj_f:
                    body_lines = traj_f.readlines()[1:]  # drop the header row

                states = []
                true_probs = []
                last_deceptive_idx = -1
                for idx, line in enumerate(body_lines):
                    x, y = map(int, line.strip().split(','))
                    true_prob, fake_probs = calculate_state_value_probabilities((x, y), true_planner, fake_planners, start)
                    states.append((x, y))
                    true_probs.append(true_prob)
                    # Deceptive: some fake goal looks at least as likely as the true one.
                    if any(fake_prob >= true_prob for fake_prob in fake_probs):
                        last_deceptive_idx = idx

                prob_file = os.path.join(prob_dir, f"{problem_name}_action{action_model}_probability.txt")
                with open(prob_file, 'w') as prob_out:
                    prob_out.write("# State (x, y)\tTrue Goal Probability\n")
                    prob_out.writelines(
                        f"{px}, {py}\t{prob:.4f}\n" for (px, py), prob in zip(states, true_probs)
                    )

                steps_after = len(states) - 1 - last_deceptive_idx
                lds_f.write(f"{action_model}\t{steps_after}\n")

def calculate_steps_after_last_deceptive_state():
    """Write, per problem, how many trajectory states follow the last deceptive state.

    A state is deceptive when at least one fake goal is as probable as the true
    goal according to ``calculate_state_value_probabilities``. For every problem
    in ``problem_configs.txt`` this writes
    ``steps_after_LDS/<problem>_steps_after_LDS.txt``, with one row per action
    model 0-3 whose trajectory file exists. When no state is deceptive the
    count equals the number of states.
    """
    from policy_iteration_APF import PolicyIterationPlanner

    problem_configs = load_configs()
    trajectorys_dir = "trajectorys"
    output_dir = "steps_after_LDS"
    os.makedirs(output_dir, exist_ok=True)

    def planner_for(map_name, start, goal, policy_file):
        """Create a planner for ``goal`` and fill it from a saved policy file."""
        planner = PolicyIterationPlanner(map_name, start, goal)
        planner.policy, planner.value_function, planner.q_table = load_policy(policy_file)
        return planner

    def is_deceptive(true_prob, fake_probs):
        """True when some fake goal is at least as likely as the true goal."""
        return any(fake_prob >= true_prob for fake_prob in fake_probs)

    for problem_name, config in problem_configs.items():
        map_name = config["map"]
        start = config["start"]
        true_goal = config["true_goal"]
        fake_goals = config["fake_goals"]

        true_planner = planner_for(
            map_name, start, true_goal,
            f"{map_name.split('.')[0]}_true_goal_{true_goal[0]}_{true_goal[1]}_policy.npz",
        )
        fake_planners = [
            planner_for(
                map_name, start, fake_goal,
                f"{map_name.split('.')[0]}_fake_goal_{fake_goal[0]}_{fake_goal[1]}_policy.npz",
            )
            for fake_goal in fake_goals
        ]

        output_file = os.path.join(output_dir, f"{problem_name}_steps_after_LDS.txt")
        with open(output_file, 'w') as f:
            f.write("Action Model\tSteps After Last Deceptive State\n")

            for action_model in range(4):
                trajectory_file = os.path.join(trajectorys_dir, f"{problem_name}_action{action_model}_trajectory.txt")
                if not os.path.exists(trajectory_file):
                    continue

                with open(trajectory_file, 'r') as traj_f:
                    body_lines = traj_f.readlines()[1:]  # drop the header row
                states = [tuple(map(int, line.strip().split(','))) for line in body_lines]

                deceptive_indices = [
                    idx for idx, state in enumerate(states)
                    if is_deceptive(*calculate_state_value_probabilities(
                        state, true_planner, fake_planners, start))
                ]
                last_deceptive_idx = deceptive_indices[-1] if deceptive_indices else -1

                f.write(f"{action_model}\t{len(states) - 1 - last_deceptive_idx}\n")

def set_plot_style():
    """Apply the global matplotlib style shared by every figure in this module.

    Sets Times New Roman text with STIX math glyphs, a hyphen-minus sign,
    no tick marks, no axes spines on any side, and a very light grey axes
    background.
    """
    style = {
        'font.family': 'Times New Roman',
        'axes.unicode_minus': False,   # draw minus signs with the ASCII hyphen
        'mathtext.fontset': 'stix',    # math-text font that matches Times
        'xtick.major.size': 0,         # hide tick marks
        'ytick.major.size': 0,
        # All four spines are hidden. The previous comments claimed the
        # left/bottom spines were kept, but the values were always False.
        'axes.spines.top': False,
        'axes.spines.right': False,
        'axes.spines.left': False,
        'axes.spines.bottom': False,
        'axes.facecolor': '#f0f0f0',   # very light grey background
    }
    for key, value in style.items():
        plt.rcParams[key] = value

def get_model_name(action_model):
    """Return the plot label for an action-model index.

    Indices 0-3 map to fixed labels; any other value falls back to
    ``'Model <index>'``.
    """
    known_labels = dict(enumerate(('Honest', 'AM', 'A-VDM', 'E-VDM(σ=1)')))
    if action_model in known_labels:
        return known_labels[action_model]
    return f'Model {action_model}'

def create_combined_probability_evolution(all_problem_data):
    """Plot the mean true-goal probability against trajectory progress.

    Args:
        all_problem_data: mapping ``problem_name -> {action_model: values}``.
            ``values`` holds one probability per point of ``0, 10, ..., 100``
            percent progress (11 values). Action models 0-3 are plotted.

    For each action model, one line shows the mean across problems and a
    shaded band shows +/- one standard deviation. The figure is saved as
    ``Deceptiveness/combined_probability_evolution.{pdf,png}``.
    """
    set_plot_style()
    plt.figure(figsize=(8, 6))
    percentages = np.arange(0, 101, 10)
    line_styles = ['-', '--', ':', '-.', '-']
    markers = ['o', 's', '^', 'D', 'v']
    # Default matplotlib palette (blue, orange, green, red); the boxplots use
    # the same colour per action model.
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

    for action_model in range(4):
        # One sampled curve per problem that has data for this action model.
        all_probs = [problem_data[action_model]
                     for problem_data in all_problem_data.values()
                     if action_model in problem_data]
        if not all_probs:
            continue

        mean_probs = np.mean(all_probs, axis=0)
        std_probs = np.std(all_probs, axis=0)
        color = colors[action_model]

        plt.plot(percentages, mean_probs,
                 label=get_model_name(action_model),
                 linestyle=line_styles[action_model],
                 marker=markers[action_model],
                 color=color,
                 linewidth=2,
                 markersize=8)
        # Give the band the line's colour explicitly. Without it, fill_between
        # takes its colour from the axes' own patch colour cycle. That cycle is
        # independent of the explicit line colours and drifts out of sync
        # whenever an action model is missing.
        plt.fill_between(percentages,
                         mean_probs - std_probs,
                         mean_probs + std_probs,
                         color=color,
                         alpha=0.2)

    plt.xlabel('Trajectory Progress (%)', fontsize=20)
    plt.ylabel('Probability of real goal state', fontsize=20)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=20)
    plt.tick_params(axis='both', which='major', labelsize=20)

    plt.ylim(-0.05, 1.05)
    plt.xlim(-2, 102)

    output_dir = "Deceptiveness"
    os.makedirs(output_dir, exist_ok=True)
    # Save both a vector (PDF) and a raster (PNG) version.
    plt.savefig(os.path.join(output_dir, "combined_probability_evolution.pdf"),
                bbox_inches='tight', format='pdf')
    plt.savefig(os.path.join(output_dir, "combined_probability_evolution.png"),
                bbox_inches='tight', dpi=300)
    plt.close()

def create_trajectory_cost_boxplot():
    """Draw a boxplot of trajectory-cost ratios for action models 1-3.

    Reads ``trajectory_cost/<problem>_trajectorycost.txt`` for every problem in
    ``problem_configs.txt``. These are tab-separated ``<action_model>\\t<ratio>``
    rows after a header line. The plot is saved as
    ``trajectory_cost/trajectory_cost_boxplot.{pdf,png}``. Each box shows its
    median as a solid line and its mean as a dashed line; a legend explains both.
    """
    set_plot_style()
    problem_configs = load_configs()
    cost_dir = "trajectory_cost"
    model_ids = range(1, 4)
    cost_data = {i: [] for i in model_ids}  # ratios per action model

    for problem_name in problem_configs:
        cost_file = os.path.join(cost_dir, f"{problem_name}_trajectorycost.txt")
        if not os.path.exists(cost_file):
            continue
        with open(cost_file, 'r') as f:
            for line in f.readlines()[1:]:  # skip the header row
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                action_model, ratio = map(float, line.strip().split('\t'))
                cost_data[int(action_model)].append(ratio)

    plt.figure(figsize=(8, 6))
    box_data = [cost_data[i] for i in model_ids]

    ax = plt.gca()
    ax.set_facecolor('#f0f0f0')

    # whis=[0, 100] extends the whiskers to the min/max, so no outliers are drawn.
    bp = plt.boxplot(box_data, patch_artist=True,
                     medianprops=dict(color="black", linewidth=1.5), whis=[0, 100])

    # Dashed mean line per box. Empty groups are skipped: np.mean([]) warns
    # and returns nan.
    for position, data in enumerate(box_data, 1):
        if data:
            plt.hlines(y=np.mean(data), xmin=position - 0.3, xmax=position + 0.3,
                       color='black', linestyle='--', linewidth=1.5)

    # Box colours match the probability-evolution lines (orange, green, red for AM onward).
    colors = ['#ff7f0e', '#2ca02c', '#d62728']
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)

    # Proxy artists describing the median/mean line styles. They only appear
    # once plt.legend() is called; previously that call was missing.
    plt.plot([], [], color='black', linestyle='-', label='Median')
    plt.plot([], [], color='black', linestyle='--', label='Mean')
    plt.legend(fontsize=20)

    plt.xlabel('Models', fontsize=20)
    plt.ylabel('Trajectory cost ratio', fontsize=20)
    plt.grid(True, alpha=0.3)
    plt.xticks(range(1, len(model_ids) + 1), [get_model_name(i) for i in model_ids], fontsize=20)
    plt.tick_params(axis='y', which='major', labelsize=20)

    os.makedirs(cost_dir, exist_ok=True)
    # Save both a vector (PDF) and a raster (PNG) version.
    plt.savefig(os.path.join(cost_dir, "trajectory_cost_boxplot.pdf"),
                bbox_inches='tight', format='pdf')
    plt.savefig(os.path.join(cost_dir, "trajectory_cost_boxplot.png"),
                bbox_inches='tight', dpi=300)
    plt.close()

def create_steps_after_lds_boxplot():
    """Draw a boxplot of steps taken after the last deceptive state (LDS).

    Reads ``steps_after_LDS/<problem>_steps_after_LDS.txt`` for every problem in
    ``problem_configs.txt``. These are tab-separated ``<action_model>\\t<steps>``
    rows after a header line. The plot is saved as
    ``steps_after_LDS/steps_after_lds_boxplot.{pdf,png}``, with one box per
    action model 0-3. Each box shows its median as a solid line and its mean as
    a dashed line; a legend explains both.
    """
    set_plot_style()
    problem_configs = load_configs()
    steps_dir = "steps_after_LDS"
    model_ids = range(4)
    steps_data = {i: [] for i in model_ids}  # step counts per action model

    for problem_name in problem_configs:
        steps_file = os.path.join(steps_dir, f"{problem_name}_steps_after_LDS.txt")
        if not os.path.exists(steps_file):
            continue
        with open(steps_file, 'r') as f:
            for line in f.readlines()[1:]:  # skip the header row
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                action_model, steps = map(float, line.strip().split('\t'))
                steps_data[int(action_model)].append(steps)

    plt.figure(figsize=(8, 6))
    box_data = [steps_data[i] for i in model_ids]

    # whis=[0, 100] extends the whiskers to the min/max, so no outliers are drawn.
    bp = plt.boxplot(box_data, patch_artist=True,
                     medianprops=dict(color="black", linewidth=1.5), whis=[0, 100])

    # Dashed mean line per box. Empty groups are skipped: np.mean([]) warns
    # and returns nan.
    for position, data in enumerate(box_data, 1):
        if data:
            plt.hlines(y=np.mean(data), xmin=position - 0.3, xmax=position + 0.3,
                       color='black', linestyle='--', linewidth=1.5)

    # Box colours match the probability-evolution lines (blue, orange, green, red).
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)

    # Proxy artists describing the median/mean line styles. They only appear
    # once plt.legend() is called; previously that call was missing.
    plt.plot([], [], color='black', linestyle='-', label='Median')
    plt.plot([], [], color='black', linestyle='--', label='Mean')
    plt.legend(fontsize=20)

    plt.xlabel('Models', fontsize=20)
    plt.ylabel('Steps after LDS', fontsize=20)
    plt.grid(True, alpha=0.3)
    plt.xticks(range(1, len(model_ids) + 1), [get_model_name(i) for i in model_ids], fontsize=20)
    plt.tick_params(axis='y', which='major', labelsize=20)

    os.makedirs(steps_dir, exist_ok=True)
    # Save both a vector (PDF) and a raster (PNG) version.
    plt.savefig(os.path.join(steps_dir, "steps_after_lds_boxplot.pdf"),
                bbox_inches='tight', format='pdf')
    plt.savefig(os.path.join(steps_dir, "steps_after_lds_boxplot.png"),
                bbox_inches='tight', dpi=300)
    plt.close()

def analyze_probability_evolution():
    """Sample per-trajectory true-goal probabilities and plot their evolution.

    For each problem and action model 0-3, this reads
    ``probability/<problem>_action<m>_probability.txt``. Each data row ends in
    ``\\t<probability>``. The probability is sampled at 0%, 10%, ..., 100% of
    the trajectory. Each sample is the mean over a window of up to five states
    centred on that position, clipped at the ends. The samples are passed to
    ``create_combined_probability_evolution``.

    Files with no data rows are skipped.
    """
    set_plot_style()
    problem_configs = load_configs()
    prob_dir = "probability"
    output_dir = "Deceptiveness"
    os.makedirs(output_dir, exist_ok=True)

    percentages = np.arange(0, 101, 10)
    all_problem_data = {}

    for problem_name in problem_configs:
        all_problem_data[problem_name] = {}

        for action_model in range(4):
            prob_file = os.path.join(prob_dir, f"{problem_name}_action{action_model}_probability.txt")
            if not os.path.exists(prob_file):
                continue

            with open(prob_file, 'r') as f:
                lines = f.readlines()[1:]  # skip the header row
            true_probs = [float(line.strip().split('\t')[1]) for line in lines if line.strip()]
            if not true_probs:
                # A header-only file (empty trajectory) would produce empty
                # windows, i.e. nan samples that poison the cross-problem mean.
                continue

            last_idx = len(true_probs) - 1
            prob_values = []
            for percentage in percentages:
                if percentage == 100:
                    idx = last_idx
                else:
                    idx = int(percentage / 100 * last_idx)
                # Window of up to 5 states centred on idx, clipped to the trajectory.
                start_idx = max(0, idx - 2)
                end_idx = min(len(true_probs), idx + 3)
                prob_values.append(np.mean(true_probs[start_idx:end_idx]))

            all_problem_data[problem_name][action_model] = prob_values

    create_combined_probability_evolution(all_problem_data)

def main():
    """Run the whole analysis pipeline.

    The steps are: trajectory cost ratios and their boxplot; per-state goal
    probabilities with steps after the last deceptive state, and that boxplot;
    finally the combined probability-evolution plot.
    """
    # Make sure every output directory exists before any step runs.
    for directory in ("trajectory_cost", "probability", "steps_after_LDS", "Deceptiveness"):
        os.makedirs(directory, exist_ok=True)

    pipeline = (
        ("Calculating trajectory cost ratios...", calculate_trajectory_cost_ratios),
        ("Creating trajectory cost boxplot...", create_trajectory_cost_boxplot),
        ("Calculating trajectory probabilities and steps after LDS...", calculate_trajectory_probabilities),
        ("Creating steps after LDS boxplot...", create_steps_after_lds_boxplot),
        ("Creating combined probability evolution plot...", analyze_probability_evolution),
    )
    for message, step in pipeline:
        print(message)
        step()

    print("Analysis complete!")

if __name__ == "__main__":
    main()
