#这个代码用于计算每一个确定性欺骗马尔可夫决策过程的欺骗状态集D、临界欺骗状态dc*以及最有欺骗性的假目标索引gi，为下一步的轨迹规划做准备
import numpy as np
import matplotlib.pyplot as plt
from policy_iteration_APF import PolicyIterationPlanner, load_policy
import os
import seaborn as sns
from matplotlib.gridspec import GridSpec
import json

def save_states_to_file(D, dcstar, gi_index, filename):
    """
    Write the deceptive state set, dc* and the chosen fake-goal index to a text file.

    States in D are sorted by (y, x) and written one grid row (equal y) per
    line, each formatted as ``(x,y)`` and followed by a trailing comma.
    ``dcstar`` and ``gi_index`` are written via f-string formatting, so a
    ``None`` value appears literally as ``None``.

    Args:
        D: iterable of (x, y) deceptive states.
        dcstar: the critical deceptive state.
        gi_index: index of the most deceptive fake goal.
        filename: output path; any existing file is overwritten.
    """
    with open(filename, 'w') as out:
        out.write("D = [\n")

        # Bucket states by row. They are visited in (y, x) order, so the dict's
        # insertion order is ascending y and every bucket is ascending x.
        rows = {}
        for px, py in sorted(D, key=lambda state: (state[1], state[0])):
            rows.setdefault(py, []).append((px, py))

        for row in rows.values():
            out.write("    " + ", ".join(f"({sx},{sy})" for sx, sy in row) + ",\n")
        out.write("]\n\n")

        out.write(f"dcstar = {dcstar}\n\n")

        out.write(f"gi_index = {gi_index}\n\n")

    print(f"States saved to {filename}")

class dcstarSearchPlanner:
    """Search for the critical deceptive state dc* and the most deceptive fake goal.

    For every goal g, a free cell s is scored by its value gain
    ``V_g(s) - V_g(start_g)``. Under a Boltzmann (softmax) model over these
    gains, s is *deceptive* when at least one fake goal is at least as
    probable as the true goal.
    """

    def __init__(self, true_planner, fake_planners):
        """
        Initialize the dc* search planner.

        Args:
            true_planner: planner for the true goal. Must expose ``map``
                (indexed ``map[y, x]``; truthy cells are skipped as
                obstacles), ``start`` given as ``(x, y)``, and
                ``value_function`` indexed ``value_function[y, x]``.
            fake_planners: list of planners for the fake goals, exposing the
                same ``start`` and ``value_function`` attributes.
        """
        self.true_planner = true_planner
        self.fake_planners = fake_planners

    @staticmethod
    def _start_value(planner):
        """Return the planner's value function evaluated at its start cell."""
        return planner.value_function[planner.start[1], planner.start[0]]

    def search_dcstar(self):
        """
        Compute the deceptive state set, the critical deceptive state and the
        most deceptive fake goal.

        Returns:
            tuple: (D, dcstar, gi_index)
            D: list of all deceptive states (x, y), in row-major order.
            dcstar: the deceptive state with the highest true-goal value
                (ties resolved in favour of the last one in D), or None if
                D is empty.
            gi_index: index into ``fake_planners`` of the fake goal with the
                largest value gain at dcstar (ties resolved in favour of the
                first), or None if there is no dcstar.
        """
        D = self._find_deceptive_states()
        dcstar = self._select_dcstar(D)
        if dcstar is None:
            # No deceptive state, e.g. no fake goals were given or the start
            # cell is an obstacle. There is no dc* and so no gi either.
            return D, None, None
        gi_index = self._select_gi_index(dcstar)
        return D, dcstar, gi_index

    def _find_deceptive_states(self):
        """Return every free cell (x, y), in row-major order, that is deceptive."""
        grid = self.true_planner.map
        height, width = grid.shape
        true_vf = self.true_planner.value_function
        # Start-cell values are loop invariants; compute them once.
        true_start = self._start_value(self.true_planner)
        fakes = [(fp.value_function, self._start_value(fp)) for fp in self.fake_planners]

        D = []
        for y in range(height):
            for x in range(width):
                if grid[y, x]:
                    continue
                true_gain = true_vf[y, x] - true_start
                # P(g) = exp(gain_g) / sum_k exp(gain_k). The normaliser is
                # shared, so P(fake) >= P(true)  <=>  gain_fake >= gain_true.
                # Comparing the gains directly avoids exp() overflow or
                # underflow, which would turn the probabilities into NaN
                # (inf/inf or 0/0) and silently misclassify the cell.
                if any(vf[y, x] - start >= true_gain for vf, start in fakes):
                    D.append((x, y))
        return D

    def _select_dcstar(self, D):
        """Return the state in D with the highest true-goal value (last wins ties), or None."""
        true_vf = self.true_planner.value_function
        best_value = float('-inf')
        best_state = None
        for state in D:
            value = true_vf[state[1], state[0]]
            if value >= best_value:
                best_value = value
                best_state = state
        return best_state

    def _select_gi_index(self, dcstar):
        """Return the index of the fake goal with the largest value gain at dcstar (first wins ties)."""
        x, y = dcstar
        best_gain = float('-inf')
        best_index = 0
        for i, fake_planner in enumerate(self.fake_planners):
            gain = fake_planner.value_function[y, x] - self._start_value(fake_planner)
            if gain > best_gain:
                best_gain = gain
                best_index = i
        return best_index

def _load_planner(map_name, start, goal, goal_kind):
    """Create a PolicyIterationPlanner for ``goal`` and attach its saved policy.

    The policy file is ``<stem>_<goal_kind>_goal_<goal[0]>_<goal[1]>_policy.npz``,
    where ``<stem>`` is ``map_name`` up to its first '.'.

    Args:
        map_name: map file name passed to PolicyIterationPlanner.
        start: start cell passed to PolicyIterationPlanner.
        goal: goal cell; goal[0] and goal[1] are embedded in the file name.
        goal_kind: "true" or "fake", selecting the policy file name.

    Returns:
        The planner with ``policy``, ``value_function`` and ``q_table`` set
        from ``load_policy``.
    """
    policy_file = f"{map_name.split('.')[0]}_{goal_kind}_goal_{goal[0]}_{goal[1]}_policy.npz"
    planner = PolicyIterationPlanner(map_name, start, goal)
    planner.policy, planner.value_function, planner.q_table = load_policy(policy_file)
    return planner


def main():
    """Compute D, dc* and the most deceptive fake goal for every configured problem.

    Reads ``problem_configs.txt``, which is expected to hold
    ``problem_configs = {...}`` mapping problem names to configs with the keys
    "map", "start", "true_goal" and "fake_goals". For each problem it loads
    the saved policies of the true and fake goals, runs the dc* search and
    writes ``<problem_name>_D_dcstar_gi.txt``.
    """
    # Read the problem configuration from file.
    with open('problem_configs.txt', 'r') as f:
        config_content = f.read()
    # Strip only the leading "problem_configs = " assignment. Replacing every
    # occurrence would also corrupt that text if it appeared elsewhere.
    config_content = config_content.replace("problem_configs = ", "", 1)
    # SECURITY: eval() executes arbitrary code. This is acceptable only
    # because problem_configs.txt is a trusted, locally authored file. If it
    # may ever come from elsewhere, switch to ast.literal_eval (sufficient
    # for plain dict/list/tuple/number/string literals).
    problem_configs = eval(config_content)

    # Run the search for every problem / map.
    for problem_name, config in problem_configs.items():
        map_name = config["map"]
        print(f"\nProcessing problem: {problem_name} with map {map_name}")
        start = config["start"]
        true_goal = config["true_goal"]
        fake_goals = config["fake_goals"]

        # Planners for the true goal and each fake goal, with saved policies loaded.
        true_planner = _load_planner(map_name, start, true_goal, "true")
        fake_planners = [_load_planner(map_name, start, fake_goal, "fake") for fake_goal in fake_goals]

        # Create the dc* search planner.
        print("Creating dc* search planner...")
        dcstar_planner = dcstarSearchPlanner(
            true_planner=true_planner,
            fake_planners=fake_planners
        )
        D, dcstar, gi_index = dcstar_planner.search_dcstar()
        print(f"Deceptive states: {D}")
        print(f"Most critical deceptive state: {dcstar}")
        print(f"Most deceptive fake goal index: {gi_index}")

        # Save the results to a text file.
        filename = f"{problem_name}_D_dcstar_gi.txt"
        save_states_to_file(D, dcstar, gi_index, filename)

if __name__ == "__main__":
    # Run the batch computation only when executed as a script, not on import.
    main()