import torch
import matplotlib.pyplot as plt
from typing import List
import os
import datetime, re
import numpy as np

def cal_mean_std(resultses, batch_size:int,name:List[str],img_name,save_path, DATA_NAME=None, SAVE_NPY=True)->None:
    
    # 存储所有结果
    avg_resultses = []
    error_bars = []
    
    for idx, results in enumerate(resultses):
        if results[0][0].shape[1] != 1:
            raise NotImplementedError("Only support 1D data")
        
        # 批处理结果
        results = [[v.view(1) for v in result] for result in results]
        results = np.array(results)
        batched_results = []
        
        for i in range(batch_size, results.shape[1] + 1, batch_size):
            batched_results.append(np.max(results[:, i - batch_size:i], axis=1))
        
        batched_results = np.hstack(batched_results)
        max_till_now = np.maximum.accumulate(batched_results, axis=1)
        
        # 计算平均值和标准差
        avg_results = np.mean(max_till_now, axis=0)
        error_bar = np.std(max_till_now, axis=0)
        
        if SAVE_NPY:
            os.makedirs("weights_sample_mean_std_save/weights_sample_chembomas", exist_ok=True)
            np.save(f"weights_sample_mean_std_save/weights_sample_chembomas/weights_sample_chembomas_{DATA_NAME}_{name[idx]}_mean.npy", avg_results)
            np.save(f"weights_sample_mean_std_save/weights_sample_chembomas/weights_sample_chembomas_{DATA_NAME}_{name[idx]}_std.npy", error_bar)


if __name__ == "__main__":
    # Script entry: load every "dh-<N>" group of .pt result files for one
    # dataset, concatenate the runs of each group and hand them to
    # cal_mean_std (which optionally saves mean/std curves as .npy files).
    results_dir = "weights_sample_exp_results"

    # Dataset selection: the last uncommented assignment wins.
    DATA_NAME = "suzuki"
    # DATA_NAME = "arylation"
    # DATA_NAME = "buchwald_Cc1ccc(Nc2ccc(C(F)(F)F)cc2)cc1.csv"
    DATA_NAME = "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv"
    # DATA_NAME = "buchwald_Cc1ccc(Nc2cccnc2)cc1.csv"
    # DATA_NAME = "buchwald_CCc1ccc(Nc2ccc(C)cc2)cc1.csv"
    # DATA_NAME = "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv"

    SAVE_NPY = True

    # Per-dataset value used as NUM_INIT_SAMPLE and in the experiment directory name.
    partition_maps = {
        "suzuki": 50,
        "arylation": 34,
        "buchwald_Cc1ccc(Nc2ccc(C(F)(F)F)cc2)cc1.csv": 7,
        "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv": 7,
        "buchwald_Cc1ccc(Nc2cccnc2)cc1.csv": 7,
        "buchwald_CCc1ccc(Nc2ccc(C)cc2)cc1.csv": 7,
        "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv": 7
    }
    # Per-dataset batch size passed to cal_mean_std.
    BATCH_SIZE_MAPS = {
        "suzuki": 5,
        "arylation": 3,
        "buchwald_Cc1ccc(Nc2ccc(C(F)(F)F)cc2)cc1.csv": 1,
        "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv": 1,
        "buchwald_Cc1ccc(Nc2cccnc2)cc1.csv": 1,
        "buchwald_CCc1ccc(Nc2ccc(C)cc2)cc1.csv": 1,
        "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv": 1
    }

    NUM_INIT_SAMPLE = partition_maps[DATA_NAME]

    name_dir = f"exp_40_init_{NUM_INIT_SAMPLE}_diverse_weights_sample"

    DATA_DIRECTORY = f'/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/code/{results_dir}/{name_dir}/{DATA_NAME}'

    BATCH_SIZE = BATCH_SIZE_MAPS[DATA_NAME]

    METHODS = ['weights_sample_chembomas']

    SAVE_PATH = f'/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/code/plot_results/{results_dir}'

    SAVE_PATH = os.path.join(SAVE_PATH, name_dir, DATA_NAME)
    # SAVE_PATH = os.path.join(SAVE_PATH, "temp_test")

    os.makedirs(SAVE_PATH, exist_ok=True)

    # Group .pt files by the first "dh-<digits>" token in their file name.
    grouped_files = {}
    pattern = re.compile(r'dh-(\d+)')

    print(f"Scanning for .pt files in directory: '{os.path.abspath(DATA_DIRECTORY)}'")
    # os.listdir order is arbitrary; sort so file/group processing order is
    # reproducible (cal_mean_std writes the same .npy names for every group,
    # so the order decides which group's files remain on disk).
    for filename in sorted(os.listdir(DATA_DIRECTORY)):
        if not filename.endswith('.pt'):
            continue
        match = pattern.search(filename)
        if match:
            key = match.group(0)
            full_path = os.path.join(DATA_DIRECTORY, filename)
            grouped_files.setdefault(key, []).append(full_path)

    for file_group, file_paths in grouped_files.items():
        print(f"Processing group: {file_group} with {len(file_paths)} files")
        group_data = []
        for file_path in file_paths:
            # Load once; the conversion below may need to retry on the same object.
            temp = torch.load(file_path)
            try:
                data = np.array(temp)
            except ValueError:
                # numpy raises ValueError for ragged nested sequences: run 0 of
                # entry 0 has a different length than run 0 of entry 1.
                # TODO: Baseline has extra 4 data
                standard_length = len(temp[1][0])
                baseline_length = len(temp[0][0])
                if baseline_length > standard_length:
                    # Too long: truncate to the reference length.
                    temp[0][0] = temp[0][0][:standard_length]
                elif baseline_length < standard_length:
                    # Too short: pad by repeating its last sup_length items.
                    # NOTE(review): assumes temp[0][0] is a list (so `+`
                    # concatenates) and that sup_length <= baseline_length.
                    sup_length = standard_length - baseline_length
                    temp[0][0] = temp[0][0] + temp[0][0][-sup_length:]
                data = np.array(temp)
            group_data.append(data)

        # Concatenate the runs of all files along axis 1,
        # e.g. (1, 10, 229, 1, 1) per the original note.
        sum_data = torch.tensor(np.concatenate(group_data, axis=1))
        cal_mean_std(sum_data, BATCH_SIZE, METHODS, img_name=file_group, save_path=SAVE_PATH, DATA_NAME=DATA_NAME, SAVE_NPY=SAVE_NPY)