
from data.bodata import BoData,Sample
from sampler.bo import BOSampler
from mcts.tree import MCTS
from utils.plot import plot_resultses
import torch
import time
import json
import argparse 
import os, random
from loguru import logger
from tqdm import tqdm
import numpy as np

# Path templates; "{dataset}" is filled in with a dataset key such as "redoxmers".
Searchspace_Path = "/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/train_regression/data4regression/{dataset}/name_searchspace.csv"
Rag_Json = "/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/Rag-Cluster/sci_files/{dataset}/partition.json"
Yield_prediction = "/mnt/shared-storage-user/caipengxiang/H200-ai4chem/Sci_results/eval_results/{dataset}/cluster_no_pretrain_yields_merged.pt"

# Input (x) columns per dataset. The same list is handed to MCTS.pseudo_init_tree
# as the variable order, alongside the expert partition loaded from Rag_Json.
# NOTE(review): "surfractant_input" looks misspelled, but it is used as a column
# name -- confirm against the search-space CSV header before renaming it.
Partition_order = {
    "crossed_barrel": ["theta", "t", "r", "n"],
    "dye_lasers": ["frag_c", "frag_b", "frag_a"],
    "lnp3": ["solid_lipid", "solid_lipid_input", "liquid_lipid_input", "surfractant_input", "drug_input"],
    "perovskites": ["anion", "cation", "organic"],
    "redoxmers": ["r5_smiles", "r3_smiles", "r4_smiles", "r1_smiles"],
}

# Target (y) column for each dataset.
_TARGET_COLUMN = {
    "crossed_barrel": "toughness",
    "dye_lasers": "fluo_rate",
    "lnp3": "encap_efficiency",
    "perovskites": "hse_gap",
    "redoxmers": "abs_lam",
}

# [x columns, [y column]] per dataset, passed as x_y_split to BoData.read_sci_data.
# The x lists are the very same list objects stored in Partition_order.
x_y_split_cols = {
    name: [Partition_order[name], [target]]
    for name, target in _TARGET_COLUMN.items()
}

# --- Experiment switches ---------------------------------------------------
# "Data harder" level. main() overwrites it (via `global`) with the current
# value from its `dh` argument and embeds it in the results file name.
data_harder = 0

# When True, return_dataset() calls dataset.make_data_harder_from_npy(...).
RUN_DH = False

# Forwarded to MCTS as use_diverse_sample; also selects the results dir name.
DIVERSE = True

# Set at import time. NOTE(review): presumably read by a project module that is
# not visible here -- confirm where it is consumed.
os.environ["USE_WEIGHT_RANDOM_SAMPLE"] = "YES"

# Dataset to run. Available keys: "crossed_barrel", "dye_lasers", "lnp3",
# "perovskites", "redoxmers".
DATA_NAME = "redoxmers"

# Per-dataset settings, indexed by position:
#   [2] number of initial samples (becomes NUM_INIT_SAMPLE)
#   [3] argument passed to make_data_harder_from_npy when RUN_DH is True
#   [4] batch size used by main()
# Slots [0] and [1] are not read by this script.
partition_maps = dict(
    crossed_barrel=[None, None, 6, None, 1],
    dye_lasers=[None, None, 35, None, 3],
    lnp3=[None, None, 8, None, 1],
    perovskites=[None, None, 2, None, 1],
    redoxmers=[None, None, 17, None, 2],
)

NUM_INIT_SAMPLE = partition_maps[DATA_NAME][2]

def return_dataset(data_name:str):
    """Build a fresh BoData instance for ``data_name``.

    Reads the search-space CSV (split into x/y columns via ``x_y_split_cols``),
    loads predictions from the ``Yield_prediction`` file, and passes that same
    file to ``use_test_set``. When ``RUN_DH`` is True, the dataset is further
    modified with ``make_data_harder_from_npy``.

    Args:
        data_name: Dataset key into the module-level path templates and config dicts.

    Returns:
        The configured BoData object.
    """
    prediction_path = Yield_prediction.format(dataset=data_name)
    dataset = BoData(BoData.read_sci_data(Searchspace_Path.format(dataset=data_name), x_y_split=x_y_split_cols[data_name]))
    dataset.load_data_prediction(prediction_path)
    dataset.use_test_set(prediction_path)

    if RUN_DH:
        # Use the requested dataset's settings rather than the global DATA_NAME,
        # so this stays correct when called for a different dataset.
        # NOTE(review): slot [3] is None for every entry in partition_maps;
        # confirm make_data_harder_from_npy accepts that before enabling RUN_DH.
        dataset.make_data_harder_from_npy(partition_maps[data_name][3])
        logger.info("="*50+"Running DH"+"="*50)
    return dataset

def return_partition(data_name:str):
    """Load the expert partition and the variable order for ``data_name``.

    Args:
        data_name: Dataset key used to fill the ``Rag_Json`` path template and to
            index ``Partition_order``.

    Returns:
        A tuple ``(expert_partition, order)``. The first item is the parsed JSON
        from the dataset's partition file. The second is the dataset's
        ``Partition_order`` list.
    """
    # Use a context manager so the file handle is closed deterministically.
    # JSON text is UTF-8 (RFC 8259), so do not depend on the platform locale.
    with open(Rag_Json.format(dataset=data_name), encoding="utf-8") as fh:
        expert_partition = json.load(fh)
    order = Partition_order[data_name]
    return expert_partition,order

def run_all(repeat_time:int,iteration:int,data_name:str,batch_size:int):
    """Run the MCTS-guided BO loop ``repeat_time`` times on ``data_name``.

    Each repetition builds a fresh dataset and MCTS tree. It seeds the tree from
    the expert partition (iteration 0), then calls ``search`` once for each of
    iterations ``1 .. iteration - 1``.

    Returns:
        A list with one entry per repetition. Each entry is the concatenation of
        everything returned by ``pseudo_init_tree`` and the subsequent ``search`` calls.
    """
    runs = []
    for _ in range(repeat_time):
        dataset = return_dataset(data_name)
        expert_partition, order = return_partition(data_name)
        tree = MCTS(
            dataset=dataset,
            sampler=BOSampler(dataset),
            batch_size=batch_size,
            variable_nums=len(Partition_order[data_name]),
            n_candidates=20,
            use_diverse_sample=DIVERSE,
            num_init_samples=NUM_INIT_SAMPLE,
        )

        # Iteration 0: initialise the tree from the expert partition.
        trajectory = list(tree.pseudo_init_tree(expert_partition, order, pseudo_label=True))
        # Remaining iterations: one search step each.
        for step in tqdm(range(1, iteration)):
            trajectory.extend(tree.search(pseudo_label=True, iteration_index=step))
        runs.append(trajectory)
    return runs

def main(args, random_seed, dh):
    """Run the experiment once and save the raw results with ``torch.save``.

    Args:
        args: Parsed CLI namespace. ``repeat_time`` and ``iteration`` are read.
        random_seed: Seed value. Within this function it is used only as a label
            in the output file name.
        dh: "Data harder" level. It is stored in the module-level ``data_harder``
            and embedded in the output file name.
    """
    global data_harder
    data_harder = dh
    data_name = DATA_NAME

    # Compute the mode outside the f-string: nesting the enclosing quote type inside
    # a replacement field is only legal from Python 3.12 (PEP 701).
    sample_mode = "diverse" if DIVERSE else "origin"
    name_dir = f"exp_{args.iteration}_init_{NUM_INIT_SAMPLE}_{sample_mode}_sci_run"

    results_dir = "sci_run_exp_results"

    results_path = os.path.join(f'/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/code/{results_dir}/{name_dir}/', data_name)
    os.makedirs(results_path,exist_ok=True)

    # Outer list has one entry per method; only "chembomas" is run here.
    resultses = []

    repeat_time = args.repeat_time
    iteration = args.iteration
    batch_size = partition_maps[data_name][4]

    method_name = "chembomas"
    logger.info("="*50+f"Running {method_name}"+"="*50)
    resultses.append(run_all(repeat_time,iteration,data_name,batch_size))

    now_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    # save raw results
    pt_path = os.path.join(results_path, f'{random_seed}_{repeat_time}+{iteration}_dh-{data_harder}_raw_results_{now_time}.pt')
    torch.save(resultses,pt_path)
    logger.info(f"raw results saved to {pt_path}")

def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also requests deterministic cuDNN behaviour.

    Args:
        seed: Seed value applied to every RNG.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Also disable the cuDNN autotuner. With benchmark=True, cuDNN may pick a
    # different algorithm on each run, which undermines reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    
if __name__ == '__main__':
    # Seeds to sweep over.
    # NOTE(review): setup_seed() is not called below, so these values only label
    # the output files and do not seed any RNG -- confirm this is intended.
    random_seed = [100, 200, 300, 400, 500]

    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_time',type=int,default=1)
    parser.add_argument('--iteration',type=int,default=40)
    args = parser.parse_args()

    # "Data harder" levels, each passed to main() as `dh`.
    data_harder_list = [0]

    # Every (dh, seed) pair, dh-major -- the same order as two nested loops.
    schedule = [(dh, seed) for dh in data_harder_list for seed in random_seed]

    start_time = time.time()
    for dh, seed in schedule:
        main(args, seed, dh)
    end_time = time.time()
    logger.success(f"total time: {end_time-start_time}")