
from data.bodata import BoData,Sample
from sampler.bo import BOSampler
from mcts.tree import MCTS
from utils.plot import plot_resultses
import torch
import time
import json
import argparse 
import os, random
from loguru import logger
from tqdm import tqdm
import numpy as np

# Root directories for the input data and the precomputed model outputs.
# They default to the original cluster locations. Set CHEMBOMAS_WORKSPACE /
# CHEMBOMAS_RESULTS_ROOT to relocate them without editing each path below.
# Only the constants in this section honour these overrides.
_WORKSPACE_ROOT = os.environ.get(
    "CHEMBOMAS_WORKSPACE",
    "/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS",
)
_RESULTS_ROOT = os.environ.get(
    "CHEMBOMAS_RESULTS_ROOT",
    "/mnt/shared-storage-user/caipengxiang/H200-ai4chem",
)
_REGRESSION_DATA = f"{_WORKSPACE_ROOT}/train_regression/data4regression"
_RAG_JSON_DIR = f"{_WORKSPACE_ROOT}/Rag-Cluster/json_files"

# Suzuki: search space CSV, merged yield predictions, partition variable order,
# RAG-clustered partition JSON and the index file used by make_data_harder_from_npy.
Suzuki_PATH = f"{_REGRESSION_DATA}/suzuki_50/searchspace.csv"
Suzuki_Yield_Prediction = f"{_RESULTS_ROOT}/ChemBOMAS_results/eval_results/suzuki_50/cluster_pretrain_yields_merged.pt"
Suzuki_Partition_Order = ['ligand', 'base', 'solvent']
Suzuki_Rag_Json = f"{_RAG_JSON_DIR}/suzuki/dry_suzuki_rag_clustered_o3mini.json"
Suzuki_idx_npy = f"{_REGRESSION_DATA}/suzuki_50/handled_idx.npy"

# Arylation: same set of files as Suzuki.
Arylation_PATH = f"{_REGRESSION_DATA}/arylation/name_searchspace.csv"
Arylation_Yield_Prediction = f"{_RESULTS_ROOT}/ChemBOMAS_results/eval_results/arylation/cluster_pretrain_yields_merged.pt"
Arylation_Partition_Order = ["Ligand_SMILES", "Additive_SMILES", "Aryl_halide_SMILES", "Base_SMILES"]
Arylation_Rag_Json = f"{_RAG_JSON_DIR}/arylation/summary.json"
Arylation_idx_npy = f"{_REGRESSION_DATA}/arylation/handled_idx.npy"

# Buchwald subset "Cc1ccc(Nc2ccccn2)cc1" (predictions live under Exp_AB_results).
_BUCHWALD_2_NAME = "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv"
Buchwald_2_PATH = f"{_REGRESSION_DATA}/{_BUCHWALD_2_NAME}/name_searchspace.csv"
Buchwald_2_Yield_Prediction = f"{_RESULTS_ROOT}/Exp_AB_results/eval_results/{_BUCHWALD_2_NAME}/grouped_exp_yields_merged.pt"
Buchwald_2_Partition_Order = ["Ligand", "Base", "Reactant2", "Additive"]
Buchwald_2_Rag_Json = f"{_RAG_JSON_DIR}/{_BUCHWALD_2_NAME}/summary.json"
Buchwald_2_idx_npy = f"{_REGRESSION_DATA}/{_BUCHWALD_2_NAME}/handled_idx.npy"

# Buchwald subset "COc1ccc(Nc2ccc(C)cc2)cc1" (predictions live under Exp_AB_results).
_BUCHWALD_5_NAME = "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv"
Buchwald_5_PATH = f"{_REGRESSION_DATA}/{_BUCHWALD_5_NAME}/name_searchspace.csv"
Buchwald_5_Yield_Prediction = f"{_RESULTS_ROOT}/Exp_AB_results/eval_results/{_BUCHWALD_5_NAME}/grouped_exp_yields_merged.pt"
Buchwald_5_Partition_Order = ["Ligand", "Base", "Reactant2", "Additive"]
Buchwald_5_Rag_Json = f"{_RAG_JSON_DIR}/{_BUCHWALD_5_NAME}/summary.json"
Buchwald_5_idx_npy = f"{_REGRESSION_DATA}/{_BUCHWALD_5_NAME}/handled_idx.npy"

def return_rag_ab_json(dataset):
    """Look up the partition JSON files to compare for one dataset.

    Args:
        dataset: one of "suzuki", "arylation",
            "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv" or
            "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv".

    Returns:
        A dict mapping the partition source ("expert", "ours", "embed", in that
        order) to the path of its JSON file. "ours" is the dataset's RAG
        partition constant defined at the top of this file.

    Raises:
        KeyError: if ``dataset`` is not one of the names above.
    """
    partition_sources_by_dataset = {
        "suzuki": dict(
            expert="/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/data/partition/suzuki_partition.json",
            ours=Suzuki_Rag_Json,
            embed="/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/info_encoder/exp_embed_cluster_results/summary/suzuki_50_clusters.json",
        ),
        "arylation": dict(
            expert="/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/train_regression/data4regression/arylation/arylation_expert_partition.json",
            ours=Arylation_Rag_Json,
            embed="/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/info_encoder/exp_embed_cluster_results/summary/arylation_clusters.json",
        ),
        "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv": dict(
            expert="/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/train_regression/data4regression/buchwald_Cc1ccc(Nc2ccccn2)cc1.csv/buchwald-2_expert_partition.json",
            ours=Buchwald_2_Rag_Json,
            embed="/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/info_encoder/exp_embed_cluster_results/summary/buchwald_Cc1ccc(Nc2ccccn2)cc1.csv_clusters.json",
        ),
        "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv": dict(
            expert="/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/train_regression/data4regression/buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv/buchwald-5_expert_partition.json",
            ours=Buchwald_5_Rag_Json,
            embed="/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/info_encoder/exp_embed_cluster_results/summary/buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv_clusters.json",
        ),
    }
    selected = partition_sources_by_dataset[dataset]
    return selected

# Label for the "data harder" setting. main() overwrites this global for each
# run and it is embedded in the saved result file name.
data_harder = 0

# When True, return_dataset() calls make_data_harder_from_npy() with the
# dataset's index .npy file (partition_maps entry [3]).
RUN_DH = True

# Forwarded to MCTS(use_diverse_sample=...) and recorded in the output dir name.
DIVERSE = True

# NOTE(review): presumably read by the sampler/MCTS code, which is not visible
# in this file — confirm what it toggles.
os.environ["USE_WEIGHT_RANDOM_SAMPLE"] = "YES"

# Dataset to run; must be a key of partition_maps below. It can be overridden
# without editing the file through the CHEMBOMAS_DATA_NAME environment variable.
# These buchwald variants were also listed as candidates but have no paths
# configured in partition_maps yet:
#   buchwald_Cc1ccc(Nc2ccc(C(F)(F)F)cc2)cc1.csv
#   buchwald_Cc1ccc(Nc2cccnc2)cc1.csv
#   buchwald_CCc1ccc(Nc2ccc(C)cc2)cc1.csv
DATA_NAME = os.environ.get("CHEMBOMAS_DATA_NAME", "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv")

# Per-dataset settings, indexed positionally elsewhere in this file:
#   [0] partition JSON path read by return_partition()
#   [1] partition/variable order (its length is the MCTS variable count)
#   [2] number of initial samples
#   [3] index .npy passed to make_data_harder_from_npy()
#   [4] batch size used by main()
partition_maps = {
    "suzuki": [Suzuki_Rag_Json, Suzuki_Partition_Order, 50, Suzuki_idx_npy, 5],
    "arylation": [Arylation_Rag_Json, Arylation_Partition_Order, 34, Arylation_idx_npy, 3],
    "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv": [Buchwald_2_Rag_Json, Buchwald_2_Partition_Order, 7, Buchwald_2_idx_npy, 1],
    "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv": [Buchwald_5_Rag_Json, Buchwald_5_Partition_Order, 7, Buchwald_5_idx_npy, 1],
}

# Fail fast with a readable message instead of a bare KeyError below.
if DATA_NAME not in partition_maps:
    raise KeyError(
        f"Unsupported DATA_NAME {DATA_NAME!r}; expected one of {sorted(partition_maps)}"
    )

NUM_INIT_SAMPLE = partition_maps[DATA_NAME][2]

def return_dataset(data_name:str):
    """Build the BoData dataset for ``data_name`` with its yield predictions loaded.

    The dataset is chosen by substring match on ``data_name``, where the first
    match wins, in this order:
    "suzuki", "arylation", "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv",
    "COc1ccc(Nc2ccc(C)cc2)cc1.csv".

    The same prediction file is passed to both ``load_data_prediction`` and
    ``use_test_set``. When ``RUN_DH`` is set, ``make_data_harder_from_npy`` is
    applied with the index file of the matched dataset. The original used the
    global DATA_NAME here, which could pair a dataset with another one's index file.

    Raises:
        ValueError: if ``data_name`` matches none of the supported datasets.
            Previously this surfaced as an UnboundLocalError.
    """
    if "suzuki" in data_name:
        key, read_fn, data_path, prediction_path = (
            "suzuki", BoData.read_suzuki_exp_data, Suzuki_PATH, Suzuki_Yield_Prediction)
    elif "arylation" in data_name:
        key, read_fn, data_path, prediction_path = (
            "arylation", BoData.read_arylation_exp_data, Arylation_PATH, Arylation_Yield_Prediction)
    elif "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv" in data_name:
        key, read_fn, data_path, prediction_path = (
            "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv", BoData.read_buchwald_data,
            Buchwald_2_PATH, Buchwald_2_Yield_Prediction)
    elif "COc1ccc(Nc2ccc(C)cc2)cc1.csv" in data_name:
        key, read_fn, data_path, prediction_path = (
            "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv", BoData.read_buchwald_data,
            Buchwald_5_PATH, Buchwald_5_Yield_Prediction)
    else:
        raise ValueError(f"Unsupported dataset name: {data_name!r}")

    dataset = BoData(read_fn(data_path))
    dataset.load_data_prediction(prediction_path)
    dataset.use_test_set(prediction_path)

    if RUN_DH:
        dataset.make_data_harder_from_npy(partition_maps[key][3])
        logger.info("="*50+"Running DH"+"="*50)
    return dataset

def return_partition(data_name:str):
    """Load the partition JSON and variable order configured for ``data_name``.

    Returns:
        A ``(partition, order)`` tuple:
        - ``partition``: the parsed JSON from ``partition_maps[data_name][0]``.
        - ``order``: the list ``partition_maps[data_name][1]``.

    Raises:
        KeyError: if ``data_name`` is not a key of ``partition_maps``.
    """
    json_path, order = partition_maps[data_name][:2]
    # Use a context manager so the file handle is closed deterministically
    # (the bare open() previously leaked it until garbage collection).
    with open(json_path, encoding="utf-8") as f:
        expert_partition = json.load(f)
    return expert_partition, order

def run_all(repeat_time:int,iteration:int,data_name:str,batch_size:int, expert_partition):
    """Run the MCTS/BO loop ``repeat_time`` times on one dataset.

    Args:
        repeat_time: number of independent repetitions.
        iteration: rounds per repetition. Round 0 is ``pseudo_init_tree`` and
            rounds 1..iteration-1 call ``MCTS.search``.
        data_name: dataset key in ``partition_maps``.
        batch_size: batch size passed to ``MCTS``.
        expert_partition: path to the partition JSON used to initialise the tree.

    Returns:
        One list per repetition, each holding the concatenated observations
        returned by ``pseudo_init_tree`` and every ``search`` call.
    """
    # Read the settings for the requested dataset rather than the global
    # DATA_NAME / NUM_INIT_SAMPLE, so the function honours its argument.
    order = partition_maps[data_name][1]
    num_init_samples = partition_maps[data_name][2]

    results = []
    for _ in range(repeat_time):
        dataset = return_dataset(data_name)
        mcts = MCTS(
            dataset=dataset,
            sampler=BOSampler(dataset),
            batch_size=batch_size,
            variable_nums=len(order),
            n_candidates=20,
            use_diverse_sample=DIVERSE,
            num_init_samples=num_init_samples,
        )

        # Re-read the partition on every repetition so each run gets a fresh
        # object in case pseudo_init_tree mutates it (not visible here).
        with open(expert_partition, encoding="utf-8") as f:
            partition = json.load(f)

        result = []
        result.extend(mcts.pseudo_init_tree(partition, order, pseudo_label=True))
        for i in tqdm(range(1, iteration)):
            result.extend(mcts.search(pseudo_label=True, iteration_index=i))
        results.append(result)
    return results

def main(args, random_seed, dh, configs=("embed",)):
    """Run ChemBOMAS for the selected partition sources and save raw results.

    Args:
        args: parsed CLI namespace; ``repeat_time`` and ``iteration`` are read.
        random_seed: value embedded in the output file name. NOTE: it is only
            a label — this function does not seed any RNG.
        dh: stored in the module-level ``data_harder`` and embedded in the
            output file name.
        configs: partition sources (keys returned by ``return_rag_ab_json``)
            to run. Defaults to only "embed", which was previously hard-coded.

    For each selected source, ``torch.save`` writes a one-element list
    ``[run_all(...)]`` to
    ``.../mas/code/diff_cluster_results/<name_dir>/<data_name>/``.
    """
    global data_harder
    data_harder = dh
    data_name = DATA_NAME

    # Loop-invariant settings, read once.
    repeat_time = args.repeat_time
    iteration = args.iteration
    num_init_samples = partition_maps[data_name][2]
    batch_size = partition_maps[data_name][4]
    diverse_tag = "diverse" if DIVERSE else "origin"
    results_dir = "diff_cluster_results"

    for config, expert_partition in return_rag_ab_json(data_name).items():
        if config not in configs:
            continue
        logger.success(f"Running : {config}")
        # Single quotes inside the f-string: nesting identical quotes is only
        # legal from Python 3.12 (PEP 701) and is a SyntaxError before that.
        name_dir = f"exp_{iteration}_init_{num_init_samples}_{diverse_tag}_diff_cluster_{config}"

        results_path = os.path.join(
            "/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/code",
            results_dir,
            name_dir,
            data_name,
        )
        os.makedirs(results_path, exist_ok=True)

        method_name = f"chembomas_{config}"
        logger.info("="*50+f"Running {method_name}"+"="*50)
        # Keep the one-element outer list so the saved file layout is unchanged.
        resultses = [run_all(repeat_time, iteration, data_name, batch_size, expert_partition)]

        now_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
        # save raw results
        pt_path = os.path.join(
            results_path,
            f"{random_seed}_{repeat_time}+{iteration}_dh-{data_harder}_raw_results_{now_time}.pt",
        )
        torch.save(resultses, pt_path)
        logger.info(f"raw results saved to {pt_path}")

def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also configures cuDNN for reproducibility: deterministic algorithms on and
    the benchmark auto-tuner off. With benchmark mode on, cuDNN may select
    different kernels between runs, which undermines the deterministic flag.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_time', type=int, default=1)
    parser.add_argument('--iteration', type=int, default=40)
    # Seeds and data-harder values used to live in commented-out lists; they are
    # now CLI options whose defaults match the previously active values.
    parser.add_argument('--seeds', type=int, nargs='+', default=[100, 200, 300, 400, 500])
    parser.add_argument('--data_harder', type=int, nargs='+', default=[0])
    args = parser.parse_args()

    # perf_counter is monotonic, so the elapsed time is not skewed by
    # wall-clock adjustments during long runs.
    start_time = time.perf_counter()
    for dh in args.data_harder:
        for seed in args.seeds:
            # NOTE: setup_seed(seed) is not called, so runs are not reproducible
            # from the seed; the seed only labels the output file.
            main(args, seed, dh)
    end_time = time.perf_counter()
    logger.success(f"total time: {end_time-start_time}")