
from data.bodata import BoData,Sample
from sampler.bo import BOSampler
from mcts.tree import MCTS
from utils.plot import plot_resultses
import torch
import time
import json
import argparse 
import os, random
from loguru import logger
from tqdm import tqdm
import numpy as np

# ---------------------------------------------------------------------------
# Root directories; every dataset path below is composed from these.
# ---------------------------------------------------------------------------
_WORKSPACE_DIR = "/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS"
_REGRESSION_DIR = f"{_WORKSPACE_DIR}/train_regression/data4regression"
_RAG_JSON_DIR = f"{_WORKSPACE_DIR}/Rag-Cluster/json_files"
_EVAL_DIR = "/mnt/shared-storage-user/caipengxiang/H200-ai4chem/ChemBOMAS_results/eval_results"

# --- Suzuki ---------------------------------------------------------------
Suzuki_PATH = f"{_REGRESSION_DIR}/suzuki_50/searchspace.csv"
Suzuki_Yield_Prediction = f"{_EVAL_DIR}/suzuki_50/cluster_pretrain_yields_merged.pt"
Suzuki_Partition_Order = ['ligand', 'base', 'solvent']
Suzuki_Rag_Json = f"{_RAG_JSON_DIR}/suzuki/dry_suzuki_rag_clustered_o3mini.json"
Suzuki_idx_npy = f"{_REGRESSION_DIR}/suzuki_50/handled_idx.npy"

# --- Arylation --------------------------------------------------------------
# Earlier alternatives: search space "arylation/new_searchspace.csv",
# partition order ["Ligand_SMILES", "Base_SMILES", "Additive_SMILES",
# "Aryl_halide_SMILES"], RAG json "arylation_smiles/summary.json".
Arylation_PATH = f"{_REGRESSION_DIR}/arylation/name_searchspace.csv"
Arylation_Yield_Prediction = f"{_EVAL_DIR}/arylation/cluster_pretrain_yields_merged.pt"
Arylation_Partition_Order = ["Ligand_SMILES", "Additive_SMILES", "Aryl_halide_SMILES", "Base_SMILES"]
Arylation_Rag_Json = f"{_RAG_JSON_DIR}/arylation/summary.json"
Arylation_idx_npy = f"{_REGRESSION_DIR}/arylation/handled_idx.npy"

# --- Buchwald-Hartwig subsets (one directory per product SMILES) -------------
# All five share the same file layout and partition order; every dataset
# still gets its own fresh order list.
_BUCHWALD_ORDER = ("Ligand", "Base", "Reactant2", "Additive")
_BUCHWALD_1 = "buchwald_Cc1ccc(Nc2ccc(C(F)(F)F)cc2)cc1.csv"
_BUCHWALD_2 = "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv"
_BUCHWALD_3 = "buchwald_Cc1ccc(Nc2cccnc2)cc1.csv"
_BUCHWALD_4 = "buchwald_CCc1ccc(Nc2ccc(C)cc2)cc1.csv"
_BUCHWALD_5 = "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv"

Buchwald_1_PATH = f"{_REGRESSION_DIR}/{_BUCHWALD_1}/name_searchspace.csv"
Buchwald_1_Yield_Prediction = f"{_EVAL_DIR}/{_BUCHWALD_1}/cluster_pretrain_yields_merged.pt"
Buchwald_1_Partition_Order = list(_BUCHWALD_ORDER)
Buchwald_1_Rag_Json = f"{_RAG_JSON_DIR}/{_BUCHWALD_1}/summary.json"
Buchwald_1_idx_npy = f"{_REGRESSION_DIR}/{_BUCHWALD_1}/handled_idx.npy"

Buchwald_2_PATH = f"{_REGRESSION_DIR}/{_BUCHWALD_2}/name_searchspace.csv"
Buchwald_2_Yield_Prediction = f"{_EVAL_DIR}/{_BUCHWALD_2}/cluster_pretrain_yields_merged.pt"
Buchwald_2_Partition_Order = list(_BUCHWALD_ORDER)
Buchwald_2_Rag_Json = f"{_RAG_JSON_DIR}/{_BUCHWALD_2}/summary.json"
Buchwald_2_idx_npy = f"{_REGRESSION_DIR}/{_BUCHWALD_2}/handled_idx.npy"

Buchwald_3_PATH = f"{_REGRESSION_DIR}/{_BUCHWALD_3}/name_searchspace.csv"
Buchwald_3_Yield_Prediction = f"{_EVAL_DIR}/{_BUCHWALD_3}/cluster_pretrain_yields_merged.pt"
Buchwald_3_Partition_Order = list(_BUCHWALD_ORDER)
Buchwald_3_Rag_Json = f"{_RAG_JSON_DIR}/{_BUCHWALD_3}/summary.json"
Buchwald_3_idx_npy = f"{_REGRESSION_DIR}/{_BUCHWALD_3}/handled_idx.npy"

Buchwald_4_PATH = f"{_REGRESSION_DIR}/{_BUCHWALD_4}/name_searchspace.csv"
Buchwald_4_Yield_Prediction = f"{_EVAL_DIR}/{_BUCHWALD_4}/cluster_pretrain_yields_merged.pt"
Buchwald_4_Partition_Order = list(_BUCHWALD_ORDER)
Buchwald_4_Rag_Json = f"{_RAG_JSON_DIR}/{_BUCHWALD_4}/summary.json"
Buchwald_4_idx_npy = f"{_REGRESSION_DIR}/{_BUCHWALD_4}/handled_idx.npy"

Buchwald_5_PATH = f"{_REGRESSION_DIR}/{_BUCHWALD_5}/name_searchspace.csv"
Buchwald_5_Yield_Prediction = f"{_EVAL_DIR}/{_BUCHWALD_5}/cluster_pretrain_yields_merged.pt"
Buchwald_5_Partition_Order = list(_BUCHWALD_ORDER)
Buchwald_5_Rag_Json = f"{_RAG_JSON_DIR}/{_BUCHWALD_5}/summary.json"
Buchwald_5_idx_npy = f"{_REGRESSION_DIR}/{_BUCHWALD_5}/handled_idx.npy"

# ---------------------------------------------------------------------------
# Experiment switches
# ---------------------------------------------------------------------------
# Overwritten by main() with its ``dh`` argument and embedded in result file names.
data_harder = 0

# When True, return_dataset() calls make_data_harder_from_npy() with the
# dataset's handled_idx.npy file.
RUN_DH = True

# Forwarded to MCTS as ``use_diverse_sample`` and encoded in the results directory name.
DIVERSE = True

# Not read in this file; presumably consumed by the sampler/MCTS code — verify there.
os.environ["USE_WEIGHT_RANDOM_SAMPLE"] = "YES"

# Dataset to run. Must be a key of ``partition_maps``:
#   "suzuki", "arylation",
#   "buchwald_Cc1ccc(Nc2ccc(C(F)(F)F)cc2)cc1.csv",
#   "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv",
#   "buchwald_Cc1ccc(Nc2cccnc2)cc1.csv",
#   "buchwald_CCc1ccc(Nc2ccc(C)cc2)cc1.csv",
#   "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv"
DATA_NAME = "suzuki"

# Per-dataset configuration, accessed positionally in this file:
#   [0] expert (RAG-clustered) partition JSON      -> return_partition
#   [1] variable order for the partition           -> return_partition / MCTS
#   [2] number of initial samples                  -> NUM_INIT_SAMPLE
#   [3] index .npy for make_data_harder_from_npy   -> return_dataset (RUN_DH)
#   [4] batch size                                 -> main
#   [5] yield-prediction .pt; its "train_idx" entry seeds the tree -> run_all
partition_maps = {
    "suzuki": [Suzuki_Rag_Json, Suzuki_Partition_Order, 50, Suzuki_idx_npy, 5, Suzuki_Yield_Prediction],
    "arylation": [Arylation_Rag_Json, Arylation_Partition_Order, 34, Arylation_idx_npy, 3, Arylation_Yield_Prediction],
    _BUCHWALD_1: [Buchwald_1_Rag_Json, Buchwald_1_Partition_Order, 7, Buchwald_1_idx_npy, 1, Buchwald_1_Yield_Prediction],
    _BUCHWALD_2: [Buchwald_2_Rag_Json, Buchwald_2_Partition_Order, 7, Buchwald_2_idx_npy, 1, Buchwald_2_Yield_Prediction],
    _BUCHWALD_3: [Buchwald_3_Rag_Json, Buchwald_3_Partition_Order, 7, Buchwald_3_idx_npy, 1, Buchwald_3_Yield_Prediction],
    _BUCHWALD_4: [Buchwald_4_Rag_Json, Buchwald_4_Partition_Order, 7, Buchwald_4_idx_npy, 1, Buchwald_4_Yield_Prediction],
    _BUCHWALD_5: [Buchwald_5_Rag_Json, Buchwald_5_Partition_Order, 7, Buchwald_5_idx_npy, 1, Buchwald_5_Yield_Prediction],
}

NUM_INIT_SAMPLE = partition_maps[DATA_NAME][2]

def return_dataset(data_name:str):
    """Build the ``BoData`` search space for ``data_name`` and attach predictions.

    ``data_name`` is matched by substring against known dataset markers, in a
    fixed order; the first match selects the CSV reader, the search-space file
    and the yield-prediction file. The prediction file is passed to both
    ``load_data_prediction`` and ``use_test_set``.

    When the module flag ``RUN_DH`` is set, ``make_data_harder_from_npy`` is
    called with the index file registered for ``data_name`` in
    ``partition_maps`` (entry ``[3]``).

    Args:
        data_name: Dataset identifier, normally a key of ``partition_maps``.

    Returns:
        The configured ``BoData`` instance.

    Raises:
        ValueError: If ``data_name`` matches none of the known datasets.
    """
    # (marker substring, CSV reader, search-space CSV, yield-prediction file).
    # Order matters: the first marker contained in ``data_name`` wins.
    loaders = (
        ("suzuki", BoData.read_suzuki_exp_data,
         Suzuki_PATH, Suzuki_Yield_Prediction),
        ("arylation", BoData.read_arylation_exp_data,
         Arylation_PATH, Arylation_Yield_Prediction),
        ("Cc1ccc(Nc2ccc(C(F)(F)F)cc2)cc1.csv", BoData.read_buchwald_data,
         Buchwald_1_PATH, Buchwald_1_Yield_Prediction),
        ("buchwald_Cc1ccc(Nc2ccccn2)cc1.csv", BoData.read_buchwald_data,
         Buchwald_2_PATH, Buchwald_2_Yield_Prediction),
        ("Cc1ccc(Nc2cccnc2)cc1.csv", BoData.read_buchwald_data,
         Buchwald_3_PATH, Buchwald_3_Yield_Prediction),
        ("CCc1ccc(Nc2ccc(C)cc2)cc1.csv", BoData.read_buchwald_data,
         Buchwald_4_PATH, Buchwald_4_Yield_Prediction),
        ("COc1ccc(Nc2ccc(C)cc2)cc1.csv", BoData.read_buchwald_data,
         Buchwald_5_PATH, Buchwald_5_Yield_Prediction),
    )

    for marker, reader, csv_path, prediction_path in loaders:
        if marker in data_name:
            break
    else:
        # Previously fell through and crashed with UnboundLocalError on ``dataset``.
        raise ValueError(f"Unknown dataset name: {data_name!r}")

    dataset = BoData(reader(csv_path))
    dataset.load_data_prediction(prediction_path)
    dataset.use_test_set(prediction_path)

    if RUN_DH:
        # Use the caller's dataset (not the global DATA_NAME), matching return_partition.
        dataset.make_data_harder_from_npy(partition_maps[data_name][3])
        logger.info("="*50+"Running DH"+"="*50)
    return dataset

def return_partition(data_name:str):
    """Load the expert partition and the variable order for a dataset.

    Args:
        data_name: Key of ``partition_maps``.

    Returns:
        ``(expert_partition, order)``. ``expert_partition`` is the parsed JSON
        from ``partition_maps[data_name][0]``. ``order`` is the list object
        ``partition_maps[data_name][1]``, not a copy.

    Raises:
        KeyError: If ``data_name`` is not a key of ``partition_maps``.
    """
    rag_json_path, order = partition_maps[data_name][:2]
    # Close the handle deterministically instead of leaving it to the garbage
    # collector; JSON is UTF-8 by specification (RFC 8259).
    with open(rag_json_path, encoding="utf-8") as fh:
        expert_partition = json.load(fh)
    return expert_partition, order

def run_all(repeat_time:int,iteration:int,data_name:str,batch_size:int):
    """Run the ChemBOMAS MCTS loop ``repeat_time`` times on one dataset.

    Each repetition does the following:

    - Builds a fresh dataset, sampler and ``MCTS``.
    - Initialises the tree with the ``train_idx`` entry stored in the dataset's
      yield-prediction file (``partition_maps[data_name][5]``).
    - Calls ``mcts.search`` for iterations ``1 .. iteration - 1``.

    Args:
        repeat_time: Number of independent repetitions.
        iteration: Total iteration count, including the initialisation step.
        data_name: Key of ``partition_maps`` selecting the dataset.
        batch_size: Forwarded to ``MCTS``.

    Returns:
        A list with one entry per repetition. Each entry concatenates the
        observations returned by ``obersved_init_tree`` and by every
        ``search`` call.
    """
    # Use the caller's ``data_name`` (the original read the global DATA_NAME
    # here, so other datasets got the wrong train indices).
    prediction_path = partition_maps[data_name][5]

    results = []
    for _ in range(repeat_time):
        result = []
        dataset = return_dataset(data_name)
        expert_partition, order = return_partition(data_name)
        mcts = MCTS(
            dataset=dataset,
            sampler=BOSampler(dataset),
            batch_size=batch_size,
            # ``order`` is partition_maps[data_name][1]; one entry per variable.
            variable_nums=len(order),
            n_candidates=20,
            use_diverse_sample=DIVERSE,
            num_init_samples=NUM_INIT_SAMPLE,
        )

        # weights_only=False unpickles arbitrary objects: only point this at
        # trusted, locally produced files. Loaded inside the loop so every
        # repetition starts from a fresh object, in case the tree mutates it.
        observed_idxes = torch.load(
            prediction_path,
            weights_only=False,
            map_location=torch.device('cpu'),
        )['train_idx']

        result.extend(
            mcts.obersved_init_tree(expert_partition, order, observed_idxes, pseudo_label=True)
        )
        # Iteration 0 was the initialisation above.
        for i in tqdm(range(1, iteration)):
            result.extend(mcts.search(pseudo_label=True, iteration_index=i))
        results.append(result)
    return results

def main(args, random_seed, dh):
    """Run ChemBOMAS on ``DATA_NAME`` and save the raw results with ``torch.save``.

    Args:
        args: Parsed CLI namespace; ``repeat_time`` and ``iteration`` are read.
        random_seed: Only used to label the output file name.
        dh: Stored in the module-level ``data_harder`` and embedded in the
            output file name.
    """
    global data_harder
    data_harder = dh
    data_name = DATA_NAME

    # Computed outside the f-string: reusing the enclosing quote character
    # inside an f-string replacement field is only legal from Python 3.12
    # (PEP 701) and made the whole file a SyntaxError on older interpreters.
    init_mode = "diverse" if DIVERSE else "origin"
    name_dir = f"exp_{args.iteration}_init_{NUM_INIT_SAMPLE}_{init_mode}_observed_init"

    results_dir = "observed_init_exp_results"

    results_path = os.path.join(f'/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/code/{results_dir}/{name_dir}/', data_name)
    os.makedirs(results_path, exist_ok=True)

    repeat_time = args.repeat_time
    iteration = args.iteration
    batch_size = partition_maps[data_name][4]

    method_name = "chembomas"
    logger.info("="*50+f"Running {method_name}"+"="*50)
    # The saved object is a one-element list wrapping run_all's output.
    resultses = [run_all(repeat_time, iteration, data_name, batch_size)]

    now_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    # save raw results
    pt_path = os.path.join(
        results_path,
        f'{random_seed}_{repeat_time}+{iteration}_dh-{data_harder}_raw_results_{now_time}.pt',
    )
    torch.save(resultses, pt_path)
    logger.info(f"raw results saved to {pt_path}")

def setup_seed(seed):
    """Seed every random number generator this script relies on.

    The seed goes to torch (CPU and all CUDA devices), NumPy's global RNG and
    Python's ``random``, in that order. cuDNN is then switched to
    deterministic algorithms.

    Args:
        seed: Seed value passed unchanged to each generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True

    
if __name__ == '__main__':
    # Seeds to sweep. Other sets used before:
    #   [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
    #   [600, 700, 800, 900, 1000]
    # NOTE(review): setup_seed() is not called below (it was commented out),
    # so each seed only labels its output file and does not fix the RNG state.
    # Confirm this is intentional.
    random_seed = [100, 200, 300, 400, 500]

    cli = argparse.ArgumentParser()
    cli.add_argument('--repeat_time', type=int, default=1)
    cli.add_argument('--iteration', type=int, default=40)
    args = cli.parse_args()

    # Values forwarded to main() as ``dh`` ([50] was used previously).
    data_harder_list = [0]

    start_time = time.time()
    # Outer loop over data_harder values, inner loop over seeds.
    runs = ((dh, seed) for dh in data_harder_list for seed in random_seed)
    for dh, seed in runs:
        main(args, seed, dh)
    elapsed = time.time() - start_time
    logger.success(f"total time: {elapsed}")