import torch
import matplotlib.pyplot as plt
from torchtyping import TensorType
from typing import List
import os
import datetime, re
import numpy as np
import seaborn as sns

def _cumulative_best_stats(results, batch_size):
    """Return the mean and std (across runs) of the batch-wise running maximum.

    Args:
        results: Nested sequence indexed as ``results[run][step]``. Every
            entry must be a tensor holding exactly one value (it is reshaped
            with ``.view(1)``) whose second dimension has size 1.
        batch_size: Number of consecutive steps folded into one batch.

    Returns:
        A ``(mean, std)`` pair of 1-D arrays with one entry per complete
        batch. Trailing steps that do not fill a whole batch are ignored.

    Raises:
        NotImplementedError: If the first entry's second dimension is not 1.
        ValueError: If ``batch_size`` is not positive or there are fewer
            steps than ``batch_size``.
    """
    if results[0][0].shape[1] != 1:
        raise NotImplementedError("Only support 1D data")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Reshape every single-value entry to shape (1,); for equally long runs
    # the resulting array is (runs, steps, 1).
    values = np.array([[v.view(1) for v in run] for run in results])
    num_steps = values.shape[1]
    if num_steps < batch_size:
        # Without this check np.hstack([]) fails with an unhelpful message.
        raise ValueError(
            f"Need at least batch_size={batch_size} steps per run, "
            f"got {num_steps}"
        )

    # Best value inside each complete batch -> one (runs, 1) column per batch.
    batch_best = [
        np.max(values[:, end - batch_size:end], axis=1)
        for end in range(batch_size, num_steps + 1, batch_size)
    ]
    # Best value seen so far, per run, along the batch axis.
    running_best = np.maximum.accumulate(np.hstack(batch_best), axis=1)

    return np.mean(running_best, axis=0), np.std(running_best, axis=0)


def plot_resultses_ds(resultses, batch_size:int,name:List[str],img_name,save_path,ylim=(0, 100))->None:
    """Plot the mean cumulative-best curve (with a +/- 1 std band) per result set.

    For every result set the per-batch maximum is computed, turned into a
    running maximum per run, and then averaged across runs. All curves are
    drawn on one figure that is saved as ``<save_path>/<img_name>.png``.

    Args:
        resultses: Iterable of result sets; each is indexed as
            ``results[run][step]`` and holds single-value tensors.
        batch_size: Number of consecutive steps grouped into one batch.
        name: Legend labels, one per result set (``name[i]`` labels the
            i-th result set).
        img_name: Used in the figure title and as the PNG file stem.
        save_path: Output directory; created if it does not exist.
        ylim: ``(low, high)`` limits for the y axis. Defaults to the fixed
            range ``(0, 100)``; pass ``None`` to let matplotlib autoscale.

    Raises:
        NotImplementedError: If a result set's entries are not single-column.
        ValueError: If ``batch_size`` is not positive or a result set has
            fewer steps than ``batch_size``.
    """
    plt.close()

    # Compute all statistics before any figure exists so that invalid input
    # cannot leave an orphaned figure behind.
    avg_resultses = []
    error_bars = []
    for results in resultses:
        avg_results, error_bar = _cumulative_best_stats(results, batch_size)
        avg_resultses.append(avg_results)
        error_bars.append(error_bar)

    # Use a clean grid style for the figure.
    plt.style.use('seaborn-v0_8-whitegrid')

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        # Take colours from the active cycle so assignment is consistent.
        color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']

        for i, (avg_results, error_bar) in enumerate(zip(avg_resultses, error_bars)):
            # Mean curve: solid, slightly thicker line.
            line, = ax.plot(
                avg_results,
                label=name[i],
                linewidth=2.5,
                color=color_cycle[i % len(color_cycle)]
            )

            # Error band: same colour as the line, drawn without an outline.
            ax.fill_between(
                range(len(avg_results)),
                avg_results - error_bar,
                avg_results + error_bar,
                alpha=0.15,  # low opacity so overlapping bands stay readable
                color=line.get_color(),
                edgecolor=None,
                linewidth=0
            )

        # Title and axis labels.
        ax.set_title(f'Performance Comparison: {img_name}', fontsize=14, pad=20)
        ax.set_xlabel('Batch Iteration', fontsize=12)
        ax.set_ylabel('Cumulative Maximum Value', fontsize=12)

        # Fixed y range by default so different figures are comparable.
        if ylim is not None:
            ax.set_ylim(*ylim)

        ax.legend(
            loc='lower right' if len(avg_resultses) < 5 else 'best',
            frameon=True,
            shadow=True,
            fancybox=True
        )

        ax.grid(True, linestyle='--', alpha=0.7)

        fig.tight_layout()

        # Make sure the output directory exists before saving.
        os.makedirs(save_path, exist_ok=True)

        fig.savefig(
            os.path.join(save_path, f"{img_name}.png"),
            dpi=300,
            bbox_inches='tight'
        )
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)

def load_pt_file(file_path):
    try:
        data = np.array(torch.load(file_path))
    except:
        temp = torch.load(file_path)
        standard_length = len(temp[1][0])
        baseline_length = len(temp[0][0])
        if baseline_length > standard_length:
            temp[0][0] = temp[0][0][:standard_length]
        elif baseline_length < standard_length:
            sup_length = standard_length - baseline_length
            temp[0][0] = temp[0][0] + temp[0][0][-sup_length:]
        data = np.array(temp)
    
    return data

# Script entry point: pair every result file of the "no pretrain" run with the
# matching file of the "pretrain" run and plot one comparison figure per method.
if __name__ == "__main__":
    # NOTE(review): these two flags are not read anywhere in this block.
    use_diverse = False
    wo_pretrain = False

    results_dir = "compare_pretrain_results"

    # Alternative configuration: name_dir = ["no_pretrain_40", "pretrain_40"]
    # Index 0 is the "no pretrain" directory, index 1 the "pretrain" one.
    name_dir = ["no_pretrain_baseline_pseudo_40", "pretrain_baseline_pseudo_40"]

    DATA_DIRECTORY = [f'/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/code/{results_dir}/{dir_name}/suzuki' for dir_name in name_dir]

    BATCH_SIZE = 5

    # Alternative configuration: METHODS = ['sims_ex', 'expert', 'sci']
    # The position of a method in this list is used as the index into the
    # loaded file's first axis.
    METHODS = ['baseline_pseudo']

    SAVE_PATH = f'/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/code/plot_results/{results_dir}'

    os.makedirs(SAVE_PATH, exist_ok=True)

    # Candidate partner files from the "pretrain" directory; sorted so the
    # pairing does not depend on the file system's listing order.
    partner_files = sorted(
        f for f in os.listdir(DATA_DIRECTORY[1]) if f.endswith(".pt")
    )

    couple_data = {}
    for method in METHODS:
        couple_data[f"{method}_no_pretrain"] = []
        couple_data[f"{method}_pretrain"] = []

    for pt_file in sorted(os.listdir(DATA_DIRECTORY[0])):
        if not pt_file.endswith(".pt"):
            # Stray files (logs, hidden files, ...) cannot be loaded.
            print("Skipping non-.pt file:", pt_file)
            continue

        print("Processing file:", pt_file)

        # Everything before "_raw_results" identifies the run.
        signal = pt_file.split("_raw_results")[0]
        candidates = [f for f in partner_files if f.startswith(signal)]
        if not candidates:
            raise FileNotFoundError(
                f"No partner file starting with {signal!r} found in "
                f"{DATA_DIRECTORY[1]}"
            )
        # Prefer a partner whose identifier matches exactly, so that e.g.
        # "run_1" is not paired with "run_10..."; otherwise fall back to the
        # first prefix match.
        exact = [f for f in candidates if f.split("_raw_results")[0] == signal]
        partner = exact[0] if exact else candidates[0]

        data1 = load_pt_file(os.path.join(DATA_DIRECTORY[0], pt_file))
        data2 = load_pt_file(os.path.join(DATA_DIRECTORY[1], partner))

        for idx, method in enumerate(METHODS):
            couple_data[f"{method}_no_pretrain"].append(data1[idx])
            couple_data[f"{method}_pretrain"].append(data2[idx])

    for method in METHODS:
        method1 = f"{method}_no_pretrain"
        method2 = f"{method}_pretrain"
        # Build one ndarray first: torch.tensor on a list of ndarrays is slow
        # and emits a warning.
        data1 = torch.tensor(np.array(couple_data[method1])).squeeze(1)
        data2 = torch.tensor(np.array(couple_data[method2])).squeeze(1)

        # Leading axis: 0 = no pretrain, 1 = pretrain; second axis = files.
        sum_data = torch.stack([data1, data2], dim=0)

        plot_resultses_ds(
            resultses=sum_data,
            batch_size=BATCH_SIZE,
            name=[method1, method2],
            img_name=f"{method}_comparison",
            save_path=SAVE_PATH
        )