# 这个代码用于计算所有真假目标状态的价值，包括其动作Q值
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict, deque

def save_policy(policy, value_function, filename, q_table=None):
    """Save policy, value function and Q-table to a file"""
    if q_table is not None:
        np.savez(filename, 
                 policy=policy, 
                 value_function=value_function,
                 q_table=q_table)
    else:
        np.savez(filename, 
                 policy=policy, 
                 value_function=value_function)

def load_policy(filename):
    """Load policy, value function and Q-table from a file"""
    data = np.load(filename, allow_pickle=True)
    policy = data['policy']
    value_function = data['value_function']
    q_table = data['q_table'] if 'q_table' in data else None
    return policy, value_function, q_table

class PolicyIterationPlanner:
    def __init__(self, map_file, start, goal, gamma=0.99, theta=1e-3):
        """
        初始化基于策略迭代的路径规划器
        Args:
            map_file: 地图文件路径
            start: 起点坐标 (x, y)
            goal: 终点坐标 (x, y)
            gamma: 折扣因子
            theta: 策略评估的收敛阈值
        """
        # 基础设置
        self.gamma = gamma  # 折扣因子
        self.theta = theta  # 收敛阈值
        
        # 动作空间：8个方向及其代价
        self.actions = np.array([
            [0, 1], [1, 1], [1, 0], [1, -1],
            [0, -1], [-1, -1], [-1, 0], [-1, 1]
        ])
        self.costs = np.array([1.0, np.sqrt(2), 1.0, np.sqrt(2),
                             1.0, np.sqrt(2), 1.0, np.sqrt(2)])
        
        # 加载地图和设置起终点
        self.load_map(map_file)
        self.start = np.array(start)
        self.goal = np.array(goal)
        
        # 初始化状态空间
        self.initialize_state_space()
        
        # 初始化策略和价值函数
        self.initialize_policy()
        self.value_function = np.zeros((self.height, self.width))
        
        # 初始化Q表
        self.q_table = np.zeros((self.height, self.width, len(self.actions)))

    def load_map(self, map_file):
        """加载地图文件"""
        with open(map_file, 'r') as f:
            lines = f.readlines()
        self.height = int(lines[1].split()[1])
        self.width = int(lines[2].split()[1])
        self.map = np.zeros((self.height, self.width), dtype=bool)
        for i, line in enumerate(lines[4:4+self.height]):
            for j, char in enumerate(line.strip()):
                self.map[i, j] = (char == 'T')

    def initialize_state_space(self):
        """初始化状态空间，设置奖励函数"""
        # 默认所有状态的奖励设为0（后面会根据实际移动更新）
        self.rewards = np.zeros((self.height, self.width))
        
        # 设置障碍物和边界为不可达
        self.rewards[self.map] = -float('inf')  # 障碍物
        self.rewards[0, :] = -float('inf')  # 上边界
        self.rewards[-1, :] = -float('inf')  # 下边界
        self.rewards[:, 0] = -float('inf')  # 左边界
        self.rewards[:, -1] = -float('inf')  # 右边界
        
        # # 设置目标点奖励 (注意：self.goal是(x,y)形式)
        # goal_x, goal_y = self.goal
        # self.rewards[goal_y, goal_x] = 10  # 目标奖励

    def get_action_reward(self, action):
        """获取动作对应的奖励"""
        # 判断是否是对角线移动
        if abs(action[0]) == 1 and abs(action[1]) == 1:
            return -np.sqrt(2)  # 对角线移动
        return -1.0  # 直行移动

    def is_valid_state(self, state):
        """检查状态是否有效 (state是(x,y)形式)"""
        x, y = state
        return (0 <= x < self.width and 0 <= y < self.height 
                and not self.map[y, x])

    def get_next_state(self, state, action):
        """
        获取执行动作后的下一个状态
        state: (x,y)形式的当前状态
        action: [dx,dy]形式的动作
        returns: (x,y)形式的下一个状态
        """
        # 如果当前状态是目标点，保持不变（终止状态）
        if state == tuple(self.goal):
            return state
            
        next_state = np.array([state[0] + action[0], state[1] + action[1]])
        if self.is_valid_state(next_state):
            return tuple(next_state)
        print(f"Invalid action {action} from state {state} to {next_state}")
        return state  # 如果不可行则保持原位置

    def is_valid_action(self, state, action):
        """
        检查在给定状态下某个动作是否有效（不会导致碰撞或到达边界）
        Args:
            state: (x,y)形式的当前状态
            action: [dx,dy]形式的动作
        Returns:
            bool: 动作是否有效
        """
        x, y = state
        next_x = x + action[0]
        next_y = y + action[1]
        
        # 检查是否会碰到边界
        if (next_x <= 0 or next_x >= self.width-1 or 
            next_y <= 0 or next_y >= self.height-1):
            return False
        # 检查是否会碰到障碍物
        if self.map[next_y, next_x]:
            return False
        # 检查对角线移动时是否会碰到障碍物
        if abs(action[0]) == 1 and abs(action[1]) == 1:
            if (self.map[y, next_x] and self.map[next_y, x]):
                return False
        return True

    def get_valid_actions(self, state):
        """
        获取在给定状态下所有有效的动作
        Args:
            state: 当前状态 (x, y)
        Returns:
            list: 有效动作的索引列表
        """
        # 如果是目标点，返回空列表（目标点无有效动作）
        if state == tuple(self.goal):
            return []
            
        valid_actions = []
        for idx, action in enumerate(self.actions):
            if self.is_valid_action(state, action):
                valid_actions.append(idx)
        return valid_actions

    def get_transition_reward(self, state, action, next_state):
        """
        获取状态转移的奖励
        所有状态都是(x,y)形式
        """
        # 如果当前已经在目标点，则无奖励（吸收态）
        if state == tuple(self.goal):
            return 0
            
        # 获取动作奖励（移动代价）
        action_reward = self.get_action_reward(action)
        
        # 如果到达目标点，额外获得目标奖励
        if (next_state[0] == self.goal[0] and next_state[1] == self.goal[1]):
            return action_reward + 100
            
        return action_reward

    def initialize_policy(self):
        """
        初始化基于APF的确定性策略
        策略表示为每个状态下的最优动作索引
        """
        # 策略矩阵：(height, width)，存储最优动作的索引
        self.policy = np.zeros((self.height, self.width), dtype=int)
        
        # 对每个非障碍物状态初始化策略
        for x in range(self.width):
            for y in range(self.height):
                if self.map[y, x] or y == 0 or y == self.height-1 or x == 0 or x == self.width-1:
                    continue
                    
                # 计算到目标的方向向量（目标和当前位置都是(x,y)形式）
                to_goal = self.goal - np.array([x, y])
                if np.all(to_goal == 0):  # 如果是目标点
                    self.policy[y, x] = 0  # 默认向上
                    continue
                
                # 计算每个动作与目标方向的夹角
                goal_dir = to_goal / np.linalg.norm(to_goal)
                best_score = -float('inf')
                best_action_idx = 0
                
                for idx, action in enumerate(self.actions):
                    if np.all(action == 0):
                        continue
                    action_dir = action / np.linalg.norm(action)
                    # 计算cos夹角，范围[-1, 1]，越大说明方向越接近
                    score = np.dot(goal_dir, action_dir)
                    if score > best_score:
                        best_score = score
                        best_action_idx = idx
                
                self.policy[y, x] = best_action_idx

    def policy_evaluation(self):
        """策略评估：计算当前策略下的价值函数和Q表"""
        while True:
            delta = 0
            for x in range(self.width):
                for y in range(self.height):
                    if self.map[y, x] or y == 0 or y == self.height-1 or x == 0 or x == self.width-1:
                        continue
                    
                    # 如果是目标点，设置固定价值和Q值
                    if (x, y) == tuple(self.goal):
                        self.value_function[y, x] = 100  # 目标点状态价值固定为终止奖励
                        # 设置目标点所有动作的Q值为终止奖励
                        self.q_table[y, x, :] = 100
                        continue
                    
                    v = self.value_function[y, x]
                    state = (x, y)
                    
                    # 更新Q值表
                    for action_idx, action in enumerate(self.actions):
                        if not self.is_valid_action(state, action):
                            self.q_table[y, x, action_idx] = -float('inf')
                            continue
                            
                        next_state = self.get_next_state(state, action)
                        reward = self.get_transition_reward(state, action, next_state)
                        
                        # 如果下一个状态是目标点，只考虑即时奖励和固定的终止价值
                        if next_state == tuple(self.goal):
                            self.q_table[y, x, action_idx] = reward
                        else:
                            self.q_table[y, x, action_idx] = reward + self.gamma * \
                                self.value_function[next_state[1], next_state[0]]
                    
                    # 获取当前策略下的动作
                    action_idx = self.policy[y, x]
                    action = self.actions[action_idx]
                    
                    # 检查动作是否有效
                    if not self.is_valid_action(state, action):
                        valid_actions = self.get_valid_actions(state)
                        if valid_actions:
                            action_idx = valid_actions[0]
                            self.policy[y, x] = action_idx
                            action = self.actions[action_idx]
                        else:
                            continue
                    
                    next_state = self.get_next_state(state, action)
                    reward = self.get_transition_reward(state, action, next_state)
                    
                    # 计算新的价值（对目标点邻近状态特殊处理）
                    if next_state == tuple(self.goal):
                        new_v = reward  # 直接使用到达目标的奖励
                    else:
                        new_v = reward + self.gamma * self.value_function[next_state[1], next_state[0]]
                    
                    self.value_function[y, x] = new_v
                    delta = max(delta, abs(v - new_v))
            
            if delta < self.theta:
                break

    def policy_improvement(self):
        """策略改进：根据当前的价值函数更新策略"""
        policy_stable = True
        
        for x in range(self.width):
            for y in range(self.height):
                if self.map[y, x] or y == 0 or y == self.height-1 or x == 0 or x == self.width-1:
                    continue
                
                old_action = self.policy[y, x]
                state = (x, y)  # 使用(x,y)形式
                
                # 获取有效的动作
                valid_actions = self.get_valid_actions(state)
                if not valid_actions:
                    continue
                
                # 计算每个有效动作的价值
                action_values = np.zeros(len(self.actions)) - float('inf')
                for a_idx in valid_actions:
                    action = self.actions[a_idx]
                    next_state = self.get_next_state(state, action)
                    reward = self.get_transition_reward(state, action, next_state)
                    # 注意：value_function的索引是[y,x]
                    action_values[a_idx] = reward + self.gamma * \
                        self.value_function[next_state[1], next_state[0]]
                
                # 选择最佳有效动作
                new_action = np.argmax(action_values)
                self.policy[y, x] = new_action
                
                # 检查策略是否改变
                if old_action != new_action:
                    policy_stable = False
        
        return policy_stable

    def policy_iteration(self):
        """
        执行策略迭代算法
        交替进行策略评估和策略改进，直到策略稳定
        """
        iteration = 0
        while True:
            iteration += 1
            print(f"Iteration {iteration}")
            
            # 策略评估
            self.policy_evaluation()
            
            # 策略改进
            policy_stable = self.policy_improvement()
            
            # 如果策略稳定，则停止迭代
            if policy_stable:
                break

if __name__ == "__main__":
    # 从文件读取问题配置
    with open('problem_configs.txt', 'r') as f:
        config_content = f.read()
    # 去掉开头的"problem_configs = "部分
    config_content = config_content.replace("problem_configs = ", "")
    # 使用eval将字符串转换为字典（确保输入文件内容可信）
    problem_configs = eval(config_content)

    # 为每个地图生成策略
    for problem_name, config in problem_configs.items():
        map_name = config["map"]
        print(f"\nProcessing problem: {problem_name} with map {map_name}")
        start = config["start"]
        true_goal = config["true_goal"]
        
        # 为真目标生成策略
        print(f"Generating policy for true goal {true_goal}...")
        planner = PolicyIterationPlanner(
            map_file=map_name,
            start=start,
            goal=true_goal,
            gamma=0.99,
            theta=1e-3
        )
        planner.policy_iteration()
        
        # 保存真目标的策略和Q表
        true_goal_filename = f"{map_name.split('.')[0]}_true_goal_{true_goal[0]}_{true_goal[1]}_policy.npz"
        save_policy(
            policy=planner.policy,
            value_function=planner.value_function,
            filename=true_goal_filename,
            q_table=planner.q_table
        )
        print(f"Saved true goal policy to {true_goal_filename}")
        
        # 为每个假目标生成策略
        for fake_goal in config["fake_goals"]:
            print(f"Generating policy for fake goal {fake_goal}...")
            planner = PolicyIterationPlanner(
                map_file=map_name,
                start=start,
                goal=fake_goal,
                gamma=0.99,
                theta=1e-3
            )
            planner.policy_iteration()
            
            # 保存假目标的策略和Q表
            fake_goal_filename = f"{map_name.split('.')[0]}_fake_goal_{fake_goal[0]}_{fake_goal[1]}_policy.npz"
            save_policy(
                policy=planner.policy,
                value_function=planner.value_function,
                filename=fake_goal_filename,
                q_table=planner.q_table
            )
            print(f"Saved fake goal policy to {fake_goal_filename}")

    print("\nAll policies generated and saved successfully.")