import torch
import matplotlib.pyplot as plt
from typing import List
import os
import datetime, re
import numpy as np

def cal_mean_std(resultses, batch_size:int,name:List[str],img_name,save_path, DATA_NAME=None, SAVE_NPY=True, dir_name="diff_data_scale_mean_std_save", method_name="")->None:
    
    # 存储所有结果
    avg_resultses = []
    error_bars = []
    
    for idx, results in enumerate(resultses):
        if results[0][0].shape[1] != 1:
            raise NotImplementedError("Only support 1D data")
        
        # 批处理结果
        results = [[v.view(1) for v in result] for result in results]
        results = np.array(results)
        batched_results = []
        
        for i in range(batch_size, results.shape[1] + 1, batch_size):
            batched_results.append(np.max(results[:, i - batch_size:i], axis=1))
        
        batched_results = np.hstack(batched_results)
        max_till_now = np.maximum.accumulate(batched_results, axis=1)
        
        # 计算平均值和标准差
        avg_results = np.mean(max_till_now, axis=0)
        error_bar = np.std(max_till_now, axis=0)
        
        if SAVE_NPY:
            os.makedirs(f"{dir_name}/{name[idx]}", exist_ok=True)
            np.save(f"{dir_name}/{name[idx]}/{name[idx]}_{DATA_NAME}_{name[idx]}_mean.npy", avg_results)
            np.save(f"{dir_name}/{name[idx]}/{name[idx]}_{DATA_NAME}_{name[idx]}_std.npy", error_bar)
if __name__ == "__main__":
    results_dir = "diff_data_scale_exp_results_without_mcts"
    
    # DATA_NAME = "suzuki"
    # DATA_NAME = "arylation"
    # DATA_NAME = "buchwald_Cc1ccc(Nc2ccc(C(F)(F)F)cc2)cc1.csv"
    DATA_NAME = "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv"
    # DATA_NAME = "buchwald_Cc1ccc(Nc2cccnc2)cc1.csv"
    # DATA_NAME = "buchwald_CCc1ccc(Nc2ccc(C)cc2)cc1.csv"
    # DATA_NAME = "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv"
    
    SAVE_NPY = True
    
    data_scales = [0.25, 0.5, 1.0, 2, 4]
    
    partition_maps = {
        "suzuki": 50,
        "arylation": 34,
        "buchwald_Cc1ccc(Nc2ccc(C(F)(F)F)cc2)cc1.csv": 7,
        "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv": 7,
        "buchwald_Cc1ccc(Nc2cccnc2)cc1.csv": 7,
        "buchwald_CCc1ccc(Nc2ccc(C)cc2)cc1.csv": 7,
        "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv": 7
    }
    BATCH_SIZE_MAPS = {
        "suzuki": 5,
        "arylation": 3,
        "buchwald_Cc1ccc(Nc2ccc(C(F)(F)F)cc2)cc1.csv": 1,
        "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv": 1,
        "buchwald_Cc1ccc(Nc2cccnc2)cc1.csv": 1,
        "buchwald_CCc1ccc(Nc2ccc(C)cc2)cc1.csv": 1,
        "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv": 1
    }

    NUM_INIT_SAMPLE = partition_maps[DATA_NAME]

    for data_scale in data_scales:
        name_dir = f"exp_40_init_{partition_maps[DATA_NAME]}_diverse_diff_data_scale_without_mcts"

        DATA_DIRECTORY = f'/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/code/{results_dir}/{name_dir}/{DATA_NAME}_{data_scale}'

        BATCH_SIZE = BATCH_SIZE_MAPS[DATA_NAME]
        
        SAVE_PATH = f'/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/code/plot_results/{results_dir}'

        SAVE_PATH = os.path.join(SAVE_PATH, name_dir, DATA_NAME)
        # SAVE_PATH = os.path.join(SAVE_PATH, "temp_test")
        
        METHODS = [f'data_scale_{data_scale}']

        os.makedirs(SAVE_PATH, exist_ok=True)
        grouped_files = {}
        pattern = re.compile(r'scale_(\d+\.?\d*)')

        print(f"Scanning for .pt files in directory: '{os.path.abspath(DATA_DIRECTORY)}'")
        for filename in os.listdir(DATA_DIRECTORY):
            if filename.endswith('.pt'):
                match = pattern.search(filename)
                if match:
                    key = match.group(0)
                    full_path = os.path.join(DATA_DIRECTORY, filename)
                    grouped_files.setdefault(key, []).append(full_path)

        for file_group, file_paths in grouped_files.items():
            print(f"Processing group: {file_group} with {len(file_paths)} files")
            group_data = []
            for file_path in file_paths:
                try:
                    data = np.array(torch.load(file_path))
                except:
                    # TODO: Baseline has extra 4 data
                    # import pdb;pdb.set_trace()
                    temp = torch.load(file_path)
                    standard_length = len(temp[1][0])
                    baseline_length = len(temp[0][0])
                    if baseline_length > standard_length:
                        temp[0][0] = temp[0][0][:standard_length]
                    elif baseline_length < standard_length:
                        sup_length = standard_length - baseline_length
                        temp[0][0] = temp[0][0] + temp[0][0][-sup_length:]
                    data = np.array(temp)
                group_data.append(data)
            
            sum_data = torch.tensor(np.concatenate(group_data, axis=1))  # 1, 10, 229, 1, 1
            # import pdb;pdb.set_trace()
            cal_mean_std(sum_data, BATCH_SIZE, METHODS, img_name=file_group, save_path=SAVE_PATH, DATA_NAME=DATA_NAME, SAVE_NPY=SAVE_NPY)