import numpy as np
import pandas as pd
from openpyxl import Workbook
from openpyxl.styles import PatternFill
from openpyxl.utils import get_column_letter
import matplotlib.pyplot as plt
import os
from policy_iteration_APF import load_policy

def create_value_excel(value_function, map_env, output_file):
    """
    Export a state-value grid to an Excel sheet with colour-coded cells.

    Obstacle cells are left blank and filled grey. Every other cell holds
    the value formatted to two decimals; non-zero finite values are also
    filled with a red-yellow-green colour scaled between the smallest and
    largest finite value found outside obstacles.

    Args:
        value_function: 2-D array of state values, indexed as [y, x].
        map_env: 2-D array of the same shape; truthy entries mark obstacles.
        output_file: Path of the Excel file to write.
    """
    wb = Workbook()
    ws = wb.active
    ws.title = "State Values"

    height, width = map_env.shape

    # Widen the data columns (plus one trailing column, as before).
    for col in range(1, width + 2):
        ws.column_dimensions[get_column_letter(col)].width = 12

    # Colour range: only finite values of non-obstacle cells take part, so
    # NaN/inf entries cannot distort the scale.
    masked_values = np.where(map_env, np.nan, value_function)
    finite_values = masked_values[np.isfinite(masked_values)]
    if finite_values.size:
        v_min = float(finite_values.min())
        v_span = float(finite_values.max()) - v_min
    else:
        # Nothing to colour (e.g. the map is all obstacles); avoid calling
        # min()/max() on an empty array, which raises ValueError.
        v_min, v_span = 0.0, 0.0

    # openpyxl style objects can be shared between cells, so build the
    # obstacle fill once instead of once per cell.
    obstacle_fill = PatternFill(start_color='808080',
                                end_color='808080',
                                fill_type='solid')

    for y in range(height):
        for x in range(width):
            cell = ws.cell(row=y + 1, column=x + 1)

            if map_env[y, x]:
                cell.value = ''
                cell.fill = obstacle_fill
                continue

            value = value_function[y, x]
            cell.value = f'{value:.2f}'

            # Zero values stay uncoloured; non-finite values cannot be
            # placed on the colour scale.
            if value == 0 or not np.isfinite(value):
                continue

            # When every value is identical the span is zero; use the
            # middle of the colormap rather than dividing by zero.
            norm_value = (value - v_min) / v_span if v_span > 0 else 0.5
            rgb = plt.cm.RdYlGn(norm_value)[:3]
            hex_color = ''.join(f'{int(255 * channel):02x}' for channel in rgb)
            cell.fill = PatternFill(start_color=hex_color,
                                    end_color=hex_color,
                                    fill_type='solid')

    wb.save(output_file)
    print(f"状态价值表已保存到 {output_file}")

def create_q_table_excel(q_table, height, width, output_file='q_values.xlsx'):
    """
    Export a Q-table to an Excel workbook with one sheet per action.

    NaN entries are written as blank grey cells. Every other entry is
    written formatted to two decimals; non-zero finite entries are also
    filled with a red-yellow-green colour scaled between the smallest and
    largest finite Q-value of the whole table, so colours are comparable
    across sheets.

    Args:
        q_table: Array indexed as [y, x, action].
        height: Number of grid rows to export.
        width: Number of grid columns to export.
        output_file: Path of the Excel file to write.
    """
    wb = Workbook()

    # Sheet names for the first eight actions; further actions, if any,
    # receive generic names so no action is dropped.
    action_names = ['Up', 'UpRight', 'Right', 'DownRight',
                    'Down', 'DownLeft', 'Left', 'UpLeft']
    num_actions = q_table.shape[2]
    sheet_names = [
        action_names[idx] if idx < len(action_names) else f'Action{idx}'
        for idx in range(num_actions)
    ]

    # Colour range over finite entries only; guard the all-NaN case, where
    # min()/max() on an empty array would raise ValueError.
    finite_q = q_table[np.isfinite(q_table)]
    if finite_q.size:
        q_min = float(finite_q.min())
        q_span = float(finite_q.max()) - q_min
    else:
        q_min, q_span = 0.0, 0.0

    # openpyxl style objects can be shared between cells, so build the
    # grey fill once instead of once per cell.
    blank_fill = PatternFill(start_color='808080',
                             end_color='808080',
                             fill_type='solid')

    for action_idx, action_name in enumerate(sheet_names):
        # A new workbook already has one sheet; reuse it for the first action.
        if action_idx == 0:
            ws = wb.active
            ws.title = action_name
        else:
            ws = wb.create_sheet(action_name)

        # Widen the data columns (plus one trailing column, as before).
        for col in range(1, width + 2):
            ws.column_dimensions[get_column_letter(col)].width = 12

        for y in range(height):
            for x in range(width):
                cell = ws.cell(row=y + 1, column=x + 1)
                value = q_table[y, x, action_idx]

                if np.isnan(value):
                    cell.value = ''
                    cell.fill = blank_fill
                    continue

                cell.value = f'{value:.2f}'

                # Zero values stay uncoloured; infinite values cannot be
                # placed on the colour scale.
                if value == 0 or not np.isfinite(value):
                    continue

                # When every value is identical the span is zero; use the
                # middle of the colormap rather than dividing by zero.
                norm_value = (value - q_min) / q_span if q_span > 0 else 0.5
                rgb = plt.cm.RdYlGn(norm_value)[:3]
                hex_color = ''.join(f'{int(255 * channel):02x}' for channel in rgb)
                cell.fill = PatternFill(start_color=hex_color,
                                        end_color=hex_color,
                                        fill_type='solid')

    wb.save(output_file)
    print(f"Q值表已保存到 {output_file}")

def process_npz_files():
    """
    Generate Excel reports for every ``*_policy.npz`` file in the current
    working directory.

    For each policy file the matching map file is looked up as
    ``<text before the first underscore>.map``. The map's height and width
    are read from the second token of its second and third lines, and the
    grid rows start at the fifth line, where the character ``'T'`` marks an
    obstacle. If the map cannot be read or parsed, a warning is printed and
    the policy file is skipped. Otherwise a state-value workbook is written,
    preceded by a Q-value workbook when the policy file provides a Q-table.
    """
    print("开始处理.npz文件...")

    # Collect the policy files to process.
    policy_files = []
    for entry in os.listdir('.'):
        if entry.endswith('_policy.npz'):
            policy_files.append(entry)

    if len(policy_files) == 0:
        print("未找到任何.npz文件")
        return

    for npz_file in policy_files:
        print(f"\n处理文件: {npz_file}")

        # Only the value function and the Q-table are used below.
        _policy, value_function, q_table = load_policy(npz_file)

        # The map file shares the prefix that precedes the first underscore.
        map_name = npz_file.partition('_')[0] + '.map'
        try:
            with open(map_name, 'r') as map_handle:
                lines = map_handle.readlines()
            height = int(lines[1].split()[1])
            width = int(lines[2].split()[1])
            map_env = np.zeros((height, width), dtype=bool)
            grid_rows = lines[4:4 + height]
            for row_idx, row_text in enumerate(grid_rows):
                for col_idx, symbol in enumerate(row_text.strip()):
                    map_env[row_idx, col_idx] = (symbol == 'T')
        except Exception as e:
            # Best effort: report the problem and move on to the next file.
            print(f"警告: 无法加载地图文件 {map_name}, 错误: {e}")
            continue

        # The Q-table workbook is optional; the state-value one is always written.
        if q_table is not None:
            q_excel_path = npz_file.replace('.npz', '_q_values.xlsx')
            create_q_table_excel(q_table, height, width, q_excel_path)

        v_excel_path = npz_file.replace('.npz', '_state_values.xlsx')
        create_value_excel(value_function, map_env, v_excel_path)

# Script entry point: when this file is run directly (not imported), process
# every policy file found in the current working directory.
if __name__ == "__main__":
    process_npz_files()