import torch
import matplotlib.pyplot as plt
from typing import List
import os
import datetime, re
import numpy as np

def _cumulative_best_stats(results, batch_size: int):
    """Return the across-run mean and std of the running best value per batch.

    Args:
        results: Indexed as ``results[run][step]``; every entry must be a torch
            tensor holding exactly one element, and ``results[0][0].shape[1]``
            must be 1.
        batch_size: Number of consecutive steps per batch (>= 1). A trailing
            partial batch is ignored.

    Returns:
        ``(mean, std)``, each a 1-D array with one value per complete batch.

    Raises:
        NotImplementedError: if ``results[0][0].shape[1] != 1``.
        ValueError: if ``batch_size < 1`` or a run has fewer than
            ``batch_size`` steps.
    """
    if results[0][0].shape[1] != 1:
        raise NotImplementedError("Only support 1D data")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # (runs, steps): one scalar per step. detach/cpu so grad-tracking or GPU
    # tensors can be converted to NumPy.
    runs = np.array([[v.detach().cpu().reshape(1) for v in run] for run in results])[..., 0]
    n_runs, n_steps = runs.shape
    n_batches = n_steps // batch_size
    if n_batches == 0:
        raise ValueError(
            f"need at least batch_size={batch_size} steps per run, got {n_steps}"
        )

    # Best value inside each complete batch of consecutive steps.
    per_batch = (
        runs[:, : n_batches * batch_size]
        .reshape(n_runs, n_batches, batch_size)
        .max(axis=2)
    )
    # Best value found so far, batch by batch, for each run.
    best_so_far = np.maximum.accumulate(per_batch, axis=1)
    return np.mean(best_so_far, axis=0), np.std(best_so_far, axis=0)


def plot_resultses_ds(
    resultses,
    batch_size: int,
    name: List[str],
    img_name,
    save_path,
    DATA_NAME=None,
    SAVE_NPY=True,
    *,
    npy_dir: str = "pre_dh_mean_std_save/baseline",
    ylim=(0, 100),
) -> None:
    """Plot mean +/- std of the running best value per batch for each result set.

    For each result set:
    1. Consecutive steps are grouped into batches of ``batch_size`` and the
       best value per batch is taken.
    2. The cumulative maximum over batches is computed.
    3. Its mean and (population) standard deviation across runs are computed.

    One line with a shaded +/- 1 std band is drawn per result set. The figure
    is saved to ``<save_path>/<img_name>.png``.

    Args:
        resultses: Sequence of result sets, each indexed as
            ``results[run][step]``, where every entry is a single-element torch
            tensor.
        batch_size: Number of consecutive steps per batch (>= 1).
        name: Legend label for each result set; also used in ``.npy`` names.
        img_name: Title suffix and output file stem.
        save_path: Directory for the PNG (created if missing).
        DATA_NAME: Tag inserted into the ``.npy`` file names.
        SAVE_NPY: If true, save each mean/std curve as ``.npy`` files.
        npy_dir: Directory for the ``.npy`` files (created if missing).
        ylim: ``(bottom, top)`` y-axis limits, or ``None`` to autoscale.

    Raises:
        NotImplementedError: if a result set is not 1-D (see
            ``_cumulative_best_stats``).
        ValueError: if ``batch_size < 1`` or a run is shorter than one batch.
    """
    # Compute (and optionally save) all statistics before touching pyplot so an
    # invalid input does not leave a half-built figure open.
    stats = []
    for idx, results in enumerate(resultses):
        avg_results, error_bar = _cumulative_best_stats(results, batch_size)
        if SAVE_NPY:
            # NOTE(review): file names do not include img_name, so repeated calls
            # with the same DATA_NAME and name[idx] overwrite earlier files.
            os.makedirs(npy_dir, exist_ok=True)
            prefix = os.path.join(npy_dir, f"baseline_{DATA_NAME}_{name[idx]}")
            np.save(f"{prefix}_mean.npy", avg_results)
            np.save(f"{prefix}_std.npy", error_bar)
        stats.append((avg_results, error_bar))

    plt.close()
    # plt.style.use updates the global rcParams (persists after this call).
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        # Read the colour cycle after applying the style so colours match it.
        color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']

        for i, (avg_results, error_bar) in enumerate(stats):
            # Solid, thick main line.
            line, = ax.plot(
                avg_results,
                label=name[i],
                linewidth=2.5,
                color=color_cycle[i % len(color_cycle)],
            )
            # Light band in the same colour for +/- 1 std, without an edge line,
            # so overlapping bands stay readable.
            ax.fill_between(
                range(len(avg_results)),
                avg_results - error_bar,
                avg_results + error_bar,
                alpha=0.15,
                color=line.get_color(),
                edgecolor=None,
                linewidth=0,
            )

        ax.set_title(f'Performance Comparison: {img_name}', fontsize=14, pad=20)
        ax.set_xlabel('Batch Iteration', fontsize=12)
        ax.set_ylabel('Cumulative Maximum Value', fontsize=12)
        if ylim is not None:
            # Fixed y range (default 0-100) keeps plots comparable.
            ax.set_ylim(*ylim)

        ax.legend(
            loc='lower right' if len(stats) < 5 else 'best',
            frameon=True,
            shadow=True,
            fancybox=True,
        )
        ax.grid(True, linestyle='--', alpha=0.7)
        fig.tight_layout()

        os.makedirs(save_path, exist_ok=True)
        fig.savefig(
            os.path.join(save_path, f"{img_name}.png"),
            dpi=300,
            bbox_inches='tight',
        )
    finally:
        # Always release the figure we created, even if saving fails.
        plt.close(fig)

def _load_results(file_path: str) -> np.ndarray:
    """Load one saved ``.pt`` results file and convert it to a NumPy array.

    The loaded object is converted with ``np.array``. If that fails because the
    nested sequences are ragged, the first run of the first entry
    (``data[0][0]``) is truncated or padded to the length of ``data[1][0]``,
    and the conversion is retried.

    Raises:
        ValueError: if ``data[0][0]`` is empty and would need padding, or if
            the data is still ragged after alignment.
    """
    # NOTE(review): torch.load unpickles the file; only use it on trusted result files.
    data = torch.load(file_path)
    try:
        return np.array(data)
    except ValueError:
        # np.array raises ValueError for inhomogeneous (ragged) nested sequences.
        pass

    # TODO (carried over): the baseline run can have a different number of
    # points than the reference run (e.g. 4 extra); align it to data[1][0].
    target_length = len(data[1][0])
    run = data[0][0]
    if len(run) > target_length:
        run = run[:target_length]
    elif len(run) < target_length:
        if len(run) == 0:
            raise ValueError(f"cannot pad empty run data[0][0] in {file_path}")
        # Pad by re-appending the run's own trailing points. They duplicate
        # existing values, so they cannot raise the cumulative maximum that
        # plot_resultses_ds plots. Loop in case more than len(run) are missing.
        while len(run) < target_length:
            run = run + run[-(target_length - len(run)):]
    data[0][0] = run
    return np.array(data)


def main() -> None:
    """Plot cumulative-best curves for every ``dh-<N>`` group of ``.pt`` files.

    The dataset and directories are configured in the constants below; change
    DATA_NAME to switch datasets.
    """
    results_dir = "pre_dh_diff_batch_baseline_results"

    # Any key of partition_maps / BATCH_SIZE_MAPS below is a valid choice.
    DATA_NAME = "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv"

    SAVE_NPY = True

    # Per-dataset value used as NUM_INIT_SAMPLE in the experiment directory name.
    partition_maps = {
        "suzuki": 50,
        "arylation": 34,
        "buchwald_Cc1ccc(Nc2ccc(C(F)(F)F)cc2)cc1.csv": 7,
        "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv": 7,
        "buchwald_Cc1ccc(Nc2cccnc2)cc1.csv": 7,
        "buchwald_CCc1ccc(Nc2ccc(C)cc2)cc1.csv": 7,
        "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv": 7,
    }
    # Per-dataset batch size passed to plot_resultses_ds.
    BATCH_SIZE_MAPS = {
        "suzuki": 5,
        "arylation": 3,
        "buchwald_Cc1ccc(Nc2ccc(C(F)(F)F)cc2)cc1.csv": 1,
        "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv": 1,
        "buchwald_Cc1ccc(Nc2cccnc2)cc1.csv": 1,
        "buchwald_CCc1ccc(Nc2ccc(C)cc2)cc1.csv": 1,
        "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv": 1,
    }

    NUM_INIT_SAMPLE = partition_maps[DATA_NAME]
    name_dir = f"exp_40_init_{NUM_INIT_SAMPLE}_diverse_pre_dh"
    DATA_DIRECTORY = f'/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/code/{results_dir}/{name_dir}/{DATA_NAME}'
    BATCH_SIZE = BATCH_SIZE_MAPS[DATA_NAME]
    METHODS = ['baseline']

    SAVE_PATH = f'/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/code/plot_results/{results_dir}'
    SAVE_PATH = os.path.join(SAVE_PATH, name_dir, DATA_NAME)
    os.makedirs(SAVE_PATH, exist_ok=True)

    # Group .pt files by their "dh-<N>" tag; each group becomes one plot.
    grouped_files = {}
    pattern = re.compile(r'dh-(\d+)')

    print(f"Scanning for .pt files in directory: '{os.path.abspath(DATA_DIRECTORY)}'")
    # Sorted so file order inside each group is deterministic.
    for filename in sorted(os.listdir(DATA_DIRECTORY)):
        if not filename.endswith('.pt'):
            continue
        match = pattern.search(filename)
        if match:
            grouped_files.setdefault(match.group(0), []).append(
                os.path.join(DATA_DIRECTORY, filename)
            )

    # Process groups in numeric dh order. With SAVE_NPY the .npy names do not
    # include the group, so each group overwrites the previous one's files and
    # the highest dh group's files are what remain.
    for file_group in sorted(grouped_files, key=lambda key: int(key.split('-', 1)[1])):
        file_paths = grouped_files[file_group]
        print(f"Processing group: {file_group} with {len(file_paths)} files")
        group_data = [_load_results(file_path) for file_path in file_paths]

        # Concatenate runs from all files along axis 1 (e.g. shape 1, 10, 229, 1, 1).
        sum_data = torch.tensor(np.concatenate(group_data, axis=1))
        plot_resultses_ds(
            sum_data,
            BATCH_SIZE,
            METHODS,
            img_name=file_group,
            save_path=SAVE_PATH,
            DATA_NAME=DATA_NAME,
            SAVE_NPY=SAVE_NPY,
        )


if __name__ == "__main__":
    main()