import torch
import matplotlib.pyplot as plt
from typing import List
import os
import datetime, re
import numpy as np

def _batched_cummax_stats(results, batch_size: int):
    """Return the across-run mean and std of the running max of per-batch maxima.

    ``results[r][t]`` is the value of run ``r`` at step ``t``. Each value must be a
    single-element array/tensor whose second dimension is 1. Both returned arrays
    are 1-D with length ``n_steps // batch_size``.
    """
    if results[0][0].shape[1] != 1:
        raise NotImplementedError("Only support 1D data")

    # Flatten every single-element step value to shape (1,) -> array (runs, steps, 1).
    values = np.array([[v.reshape(1) for v in run] for run in results])
    n_runs, n_steps = values.shape[0], values.shape[1]
    n_batches = n_steps // batch_size
    if n_batches == 0:
        raise ValueError(
            f"Each run needs at least batch_size={batch_size} steps, got {n_steps}"
        )

    # Keep only complete batches (a trailing partial batch is discarded), then take
    # the best value inside each batch -> (runs, n_batches).
    batch_max = (
        values[:, : n_batches * batch_size, 0]
        .reshape(n_runs, n_batches, batch_size)
        .max(axis=2)
    )
    # Best value found so far after each batch, independently for each run.
    max_till_now = np.maximum.accumulate(batch_max, axis=1)
    return max_till_now.mean(axis=0), max_till_now.std(axis=0)


def plot_resultses_ds(resultses, batch_size: int, name: List[str], img_name, save_path,
                      ylim=(0, 100)) -> None:
    """Plot, per method, the mean cumulative-max curve with a +/-1 std band and save it.

    For every method in ``resultses``, each run's per-step values are grouped into
    consecutive batches of ``batch_size`` steps. The max of each batch is taken,
    then a running maximum over batches is computed. The plotted line is the mean
    of that running maximum across runs. The shaded band is one standard deviation.

    Args:
        resultses: Indexable of per-method results, where ``resultses[m][r][t]`` is
            method ``m``, run ``r``, step ``t``. Each step value must be a
            single-element tensor/array whose second dimension is 1.
            NOTE(review): the script passes a torch tensor built from ``.pt`` files;
            its exact rank is not established here.
        batch_size: Steps per batch (>= 1). Trailing steps that do not fill a
            complete batch are ignored.
        name: Legend labels, one per method. Must have at least ``len(resultses)``
            entries.
        img_name: Used in the title and as the file stem (``<img_name>.png``).
        save_path: Output directory; created if it does not exist.
        ylim: ``(bottom, top)`` y-axis limits. Defaults to the previously hard-coded
            ``(0, 100)``. Pass ``None`` to let matplotlib autoscale.

    Raises:
        NotImplementedError: If the per-step data is not 1-D.
        ValueError: If ``batch_size < 1``, a run has fewer than ``batch_size`` steps,
            or there are fewer labels than methods.

    The seaborn style is applied only while this figure is built and saved. It does
    not leak into the global matplotlib state.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Compute every curve first so malformed input fails before anything is drawn.
    stats = [_batched_cummax_stats(results, batch_size) for results in resultses]
    if len(name) < len(stats):
        raise ValueError(
            f"Got {len(stats)} result sets but only {len(name)} labels in `name`"
        )

    plt.close()
    with plt.style.context('seaborn-v0_8-whitegrid'):
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            # Use the style's color cycle so each method gets a consistent color.
            color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']

            for idx, ((avg_results, error_bar), label) in enumerate(zip(stats, name)):
                # Main curve: solid and thick.
                line, = ax.plot(
                    avg_results,
                    label=label,
                    linewidth=2.5,
                    color=color_cycle[idx % len(color_cycle)],
                )
                # +/-1 std band in the matching (lighter) color, without an outline.
                ax.fill_between(
                    range(len(avg_results)),
                    avg_results - error_bar,
                    avg_results + error_bar,
                    alpha=0.15,  # low opacity so overlapping bands stay readable
                    color=line.get_color(),
                    edgecolor=None,
                    linewidth=0,
                )

            ax.set_title(f'Performance Comparison: {img_name}', fontsize=14, pad=20)
            ax.set_xlabel('Batch Iteration', fontsize=12)
            ax.set_ylabel('Cumulative Maximum Value', fontsize=12)
            if ylim is not None:
                ax.set_ylim(*ylim)

            ax.legend(
                loc='lower right' if len(stats) < 5 else 'best',
                frameon=True,
                shadow=True,
                fancybox=True,
            )
            ax.grid(True, linestyle='--', alpha=0.7)
            fig.tight_layout()

            os.makedirs(save_path, exist_ok=True)
            fig.savefig(
                os.path.join(save_path, f"{img_name}.png"),
                dpi=300,
                bbox_inches='tight',
            )
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)

if __name__ == "__main__":
    use_diverse = False
    wo_pretrain = False
    use_cluster = False

    
    # results_dir = "sci_results"
    # results_dir = "four_method_results"
    # results_dir = "diff_batch_size"
    # results_dir = "new_method_results"
    # results_dir = "compare_pretrain_results"
    results_dir = "cluster_loss_results"
    # results_dir = "extra_results"
    # results_dir = "hallucinations"
    # results_dir = "rag_results"


    if results_dir == "diff_batch_size":
        name_dir = "new_1_10" if use_diverse else "no_diverse_1_10"
    elif results_dir == "new_method_results":
        name_dir = f"no_pretrain_30" if wo_pretrain else f"old_40"
    elif results_dir == "extra_results":
        name_dir = f"no_pretrain_30" if wo_pretrain else f"old_30"
    elif results_dir == "compare_pretrain_results":
        name_dir = f"no_pretrain_40" if wo_pretrain else f"pretrain_40"
    elif results_dir == "cluster_loss_results":
        name_dir = "cluster_40" if use_cluster else "no_cluster_40"
    else:
        name_dir = "new_40" if use_diverse else "no_diverse_40"

    DATA_DIRECTORY = f'/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/code/{results_dir}/{name_dir}/suzuki'
    BATCH_SIZE = int(name_dir.split("_")[-2]) if results_dir == "diff_batch_size" else 5
    if results_dir == "rag_results":
        METHODS = ['baseline', 'expert', 'rag']
    elif results_dir == "sci_results":
        METHODS = ['expert', 'rag', 'sci']
    elif results_dir == "four_method_results":
        METHODS = ['Baseline','expert','rag', 'sci']
    elif results_dir == "diff_batch_size":
        METHODS = ['Baseline','expert','rag', 'sci']
    elif results_dir == "new_method_results":
        METHODS = ['expert', 'sci', 'agglo', 'kmeans', 'sims']
    elif results_dir == "extra_results":
        METHODS = ['sims_ex', 'expert', 'sci', 'sims']
    elif results_dir == "compare_pretrain_results":
        METHODS = ['sims_ex', 'expert', 'sci']
    elif results_dir == "cluster_loss_results":
        METHODS = ['rag', 'sims_ex', 'sims_ex_cluster', 'expert']
        
    SAVE_PATH = f'/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/code/plot_results/{results_dir}'

    SAVE_PATH = os.path.join(SAVE_PATH, name_dir)
    # SAVE_PATH = os.path.join(SAVE_PATH, "temp_test")

    os.makedirs(SAVE_PATH, exist_ok=True)

    grouped_files = {}
    pattern = re.compile(r'dh-(\d+)')

    print(f"Scanning for .pt files in directory: '{os.path.abspath(DATA_DIRECTORY)}'")
    for filename in os.listdir(DATA_DIRECTORY):
        if filename.endswith('.pt'):
            match = pattern.search(filename)
            if match:
                key = match.group(0)
                full_path = os.path.join(DATA_DIRECTORY, filename)
                grouped_files.setdefault(key, []).append(full_path)

    for file_group, file_paths in grouped_files.items():
        print(f"Processing group: {file_group} with {len(file_paths)} files")
        group_data = []
        for file_path in file_paths:
            try:
                data = np.array(torch.load(file_path))
            except:
                # TODO: Baseline has extra 4 data
                # import pdb;pdb.set_trace()
                temp = torch.load(file_path)
                standard_length = len(temp[1][0])
                baseline_length = len(temp[0][0])
                if baseline_length > standard_length:
                    temp[0][0] = temp[0][0][:standard_length]
                elif baseline_length < standard_length:
                    sup_length = standard_length - baseline_length
                    temp[0][0] = temp[0][0] + temp[0][0][-sup_length:]
                data = np.array(temp)
            group_data.append(data)
        
        sum_data = torch.tensor(np.concatenate(group_data, axis=1))  # 3, n, 50
        # import pdb;pdb.set_trace()
        plot_resultses_ds(sum_data, BATCH_SIZE, METHODS, img_name=file_group, save_path=SAVE_PATH)