
from data.bodata import BoData,Sample
from sampler.bo import BOSampler
from mcts.tree import MCTS
from utils.plot import plot_resultses
import torch
import time
import json
import argparse 
import os, random
from loguru import logger
from tqdm import tqdm
import numpy as np
# Root of the ChemBOMAS workspace on shared storage; every data artifact below
# except the expert partition file lives under it.
_CHEMBOMAS_ROOT = "/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS"

# Suzuki experiment index, read via BoData.read_suzuki_exp_data.
Suzuki_PATH = f"{_CHEMBOMAS_ROOT}/mas/data/exp/suzuki/experiment_index.csv"

# Merged prediction files; the module-level `wo_pretrain` flag selects one.
_SUZUKI_REGRESSION_DIR = f"{_CHEMBOMAS_ROOT}/train_regression/data4regression/suzuki_60"
Suzuki_Yield_Prediction = f"{_SUZUKI_REGRESSION_DIR}/pretrain_yields_merged.pt"
Suzuki_Yield_Prediction_wo_pretrain = f"{_SUZUKI_REGRESSION_DIR}/no_pretrained_yield_merged.pt"

# Variable order handed to MCTS.pseudo_init_tree; its length is also passed
# to MCTS as `variable_nums`.
Suzuki_Partition_Order = ['ligand', 'base', 'solvent']

# Candidate partition JSON files (possible values in `partition_maps`).
Suzuki_Sci_Json = f"{_CHEMBOMAS_ROOT}/Rag-Cluster/json_files/suzuki/dry_suzuki_rag_clustered_sci.json"
Suzuki_sims_Ex_Json = f"{_CHEMBOMAS_ROOT}/info_encoder/json_files/summary/similarity_clusters_Ex.json"
Suzuki_Expert_Partition = '/fs-computility/ai4phys/shared/caipengxiang/suzuki_partition.json'

# Run-level settings; main() overwrites both before each run.
data_harder = 40
wo_pretrain = True

# Methods to run, mapped to their partition JSON (None = no partition file).
# The "baseline_with_peseudo" spelling is compared verbatim in run_all(), so it
# must not be corrected here alone. Uncomment entries to run more methods.
partition_maps = {
    "baseline_with_peseudo": None,
    # "sims_ex": Suzuki_sims_Ex_Json,
    # "expert": Suzuki_Expert_Partition,
    # "sci": Suzuki_Sci_Json,
}

def return_dataset(data_name:str):
    """Build a fresh BoData instance from the Suzuki experiment index.

    Args:
        data_name: accepted for interface compatibility; not used — the data
            at ``Suzuki_PATH`` is always loaded.

    The module-level ``wo_pretrain`` flag chooses between
    ``Suzuki_Yield_Prediction_wo_pretrain`` and ``Suzuki_Yield_Prediction``;
    the chosen file is passed to both ``load_data_prediction`` and
    ``use_test_set``. Finally ``make_data_harder`` is called with the
    module-level ``data_harder`` value.

    Returns:
        The configured BoData object.
    """
    prediction_file = (
        Suzuki_Yield_Prediction_wo_pretrain if wo_pretrain else Suzuki_Yield_Prediction
    )
    dataset = BoData(BoData.read_suzuki_exp_data(Suzuki_PATH))
    dataset.load_data_prediction(prediction_file)
    dataset.use_test_set(prediction_file)
    dataset.make_data_harder(data_harder)
    return dataset

def return_partition(method_name:str):
    """Load the partition JSON registered for *method_name* in ``partition_maps``.

    Args:
        method_name: a key of ``partition_maps`` whose value is a JSON file
            path. Entries mapped to ``None`` (e.g. the baseline methods) have
            no partition file and are not meant to be passed here.

    Returns:
        ``(expert_partition, order)`` — the parsed JSON object and
        ``Suzuki_Partition_Order``.

    Raises:
        KeyError: if *method_name* is not a key of ``partition_maps``.
    """
    # A context manager closes the handle deterministically instead of leaving
    # it to the garbage collector; JSON is UTF-8 by specification, so do not
    # depend on the platform's default encoding.
    with open(partition_maps[method_name], encoding="utf-8") as fh:
        expert_partition = json.load(fh)
    order = Suzuki_Partition_Order
    return expert_partition, order

def _run_baseline_repeats(repeat_time:int,iteration:int,method_name:str,batch_size:int,search_kwargs:dict):
    """Shared driver for the baseline runners (no partition tree is built).

    Each of the *repeat_time* repeats builds a fresh dataset and MCTS. It
    collects the observations returned by ``init_tree(dont_build_tree=True)``,
    then those of ``iteration - 1`` calls to ``baseline_search(**search_kwargs)``.

    Args:
        search_kwargs: keyword arguments forwarded verbatim to
            ``mcts.baseline_search``; ``{}`` calls it with no arguments.

    Returns:
        A list with one list of observations per repeat.
    """
    results = []
    for _ in range(repeat_time):
        dataset = return_dataset(method_name)
        mcts = MCTS(dataset=dataset,sampler=BOSampler(dataset),batch_size=batch_size,variable_nums=len(Suzuki_Partition_Order),n_candidates=20)
        result = []
        result.extend(mcts.init_tree(dont_build_tree=True))
        # The range starts at 1, so initialisation plus iteration - 1 search
        # calls are performed in total.
        for _step in tqdm(range(1,iteration)):
            result.extend(mcts.baseline_search(**search_kwargs))
        results.append(result)
    return results

def run_baseline(repeat_time:int,iteration:int,method_name:str,batch_size:int):
    """Run the plain baseline: ``baseline_search()`` with its own defaults.

    Returns:
        A list with one list of observations per repeat.
    """
    # Empty kwargs reproduces the bare ``baseline_search()`` call exactly.
    return _run_baseline_repeats(repeat_time,iteration,method_name,batch_size,search_kwargs={})

def run_baseline_with_pseudo(repeat_time:int,iteration:int,method_name:str,batch_size:int):
    """Run the baseline with ``baseline_search(pseudo_label=True)``.

    Returns:
        A list with one list of observations per repeat.
    """
    return _run_baseline_repeats(repeat_time,iteration,method_name,batch_size,search_kwargs={"pseudo_label": True})

def run_all(repeat_time:int,iteration:int,method_name:str,batch_size:int):
    """Run *repeat_time* independent campaigns of *method_name*.

    ``"baseline"`` and ``"baseline_with_peseudo"`` are delegated to their
    dedicated runners. Any other method loads its partition via
    ``return_partition``. It initialises the tree with
    ``pseudo_init_tree(..., pseudo_label=True)`` and then calls
    ``search(pseudo_label=True, iteration_index=step)`` for
    ``step = 1 .. iteration - 1``.

    Returns:
        A list with one list of observations per repeat.
    """
    baseline_runners = {
        "baseline": run_baseline,
        "baseline_with_peseudo": run_baseline_with_pseudo,
    }
    runner = baseline_runners.get(method_name)
    if runner is not None:
        return runner(repeat_time,iteration,method_name,batch_size)

    all_runs = []
    for _ in range(repeat_time):
        dataset = return_dataset(method_name)
        expert_partition, order = return_partition(method_name)
        mcts = MCTS(dataset=dataset,sampler=BOSampler(dataset),batch_size=batch_size,variable_nums=len(Suzuki_Partition_Order),n_candidates=20)
        trajectory = []
        trajectory.extend(mcts.pseudo_init_tree(expert_partition,order,pseudo_label=True))
        for step in tqdm(range(1,iteration)):
            trajectory.extend(mcts.search(pseudo_label=True, iteration_index=step))
        all_runs.append(trajectory)
    return all_runs

def main(args, random_seed, dh):
    """Run every method listed in ``partition_maps`` and save the raw results.

    Args:
        args: parsed CLI namespace; reads ``wo_pretrain``, ``data_name``,
            ``repeat_time``, ``iteration`` and ``batch_size``.
        random_seed: used only to tag the output file name; the caller is
            responsible for actually seeding the RNGs.
        dh: value stored into the module-level ``data_harder`` setting.

    Side effects:
        Overwrites the module globals ``wo_pretrain`` and ``data_harder``. It
        creates the results directory and writes one ``.pt`` file holding a
        list with one ``run_all`` result per method.
    """
    global wo_pretrain, data_harder
    # return_dataset() reads these globals, so set them before any run starts.
    wo_pretrain = args.wo_pretrain
    data_harder = dh

    data_name = args.data_name
    repeat_time = args.repeat_time
    iteration = args.iteration
    batch_size = args.batch_size

    pretrain_tag = "no_pretrain" if wo_pretrain else "pretrain"
    name_dir = f"{pretrain_tag}_baseline_pseudo_{iteration}"
    results_dir = "compare_pretrain_results"
    results_path = os.path.join(
        f'/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/code/{results_dir}/{name_dir}/',
        data_name,
    )
    os.makedirs(results_path, exist_ok=True)

    banner = "=" * 50
    resultses = []
    for method_name in partition_maps:
        logger.info(banner + f"Running {method_name}" + banner)
        resultses.append(run_all(repeat_time, iteration, method_name, batch_size))

    now_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    file_name = f'{random_seed}_{repeat_time}+{iteration}_dh-{data_harder}_raw_results_{now_time}.pt'
    pt_path = os.path.join(results_path, file_name)
    torch.save(resultses, pt_path)
    logger.info(f"raw results saved to {pt_path}")

def setup_seed(seed):
    """Seed PyTorch (CPU and all CUDA devices), NumPy and ``random`` with *seed*.

    Also sets ``torch.backends.cudnn.deterministic = True`` to request
    deterministic cuDNN algorithms.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True

    
if __name__ == '__main__':
    # Seeds run in sequence: 100, 200, ..., 1000.
    random_seed = list(range(100, 1001, 100))

    parser = argparse.ArgumentParser()
    # (flag, type, default) for each CLI option. --wo_pretrain is an int that
    # the rest of the script treats as a truthy/falsy flag.
    cli_options = (
        ('--repeat_time', int, 1),
        ('--iteration', int, 40),
        ('--data_name', str, 'suzuki'),
        ('--batch_size', int, 5),
        ('--wo_pretrain', int, 1),
    )
    for flag, flag_type, flag_default in cli_options:
        parser.add_argument(flag, type=flag_type, default=flag_default)
    args = parser.parse_args()

    # data_harder values to sweep; each one is paired with every seed.
    data_harder_list = [40]

    start_time = time.time()
    for dh in data_harder_list:
        for seed in random_seed:
            # Reseed before every run.
            setup_seed(seed)
            main(args, seed, dh)
    end_time = time.time()
    logger.success(f"total time: {end_time-start_time}")