
from data.bodata import BoData,Sample
from sampler.bo import BOSampler
from mcts.tree import MCTS
from utils.plot import plot_resultses
import torch
import time
import json
import argparse 
import os, random
from loguru import logger
from tqdm import tqdm
import numpy as np
# Experiment table for the Suzuki dataset; parsed by
# ``BoData.read_suzuki_exp_data`` inside ``return_dataset``.
Suzuki_PATH = r'/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/data/exp/suzuki/experiment_index.csv'

# Expert-defined partition (JSON) returned by ``return_expert_partition``.
# NOTE(review): this path is under /fs-computility while every other input is
# under /mnt/shared-storage-user -- confirm it is reachable on the run host.
Suzuki_Expert_Partition = r'/fs-computility/ai4phys/shared/caipengxiang/suzuki_partition.json'
# Predicted-yield file (.pt) passed to both ``load_data_prediction`` and
# ``use_test_set`` in ``return_dataset``.
Suzuki_Yield_Prediction = r'/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/data/pred_value/suzuki/yields_suzuki-miyaura_4_llama_regression.pt'
# Variable order returned alongside every partition; its length is also given
# to ``MCTS`` as ``variable_nums``.
# TODO(review): previously marked "tofix" -- presumably the order should come
# from the partition files instead of being hard-coded; confirm.
Suzuki_Partition_Order = ['ligand','base','solvent']
# RAG-clustered partition loaded by ``return_rag_partition``. ``main`` uses the
# file-name suffix after the last "_" (here "o3mini") as ``rag_model_name``.
Suzuki_Rag_Json = '/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/Rag-Cluster/json_files/suzuki/dry_suzuki_rag_clustered_o3mini.json'
# Partition loaded by ``return_sci_partition``.
Suzuki_Sci_Json = '/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/Rag-Cluster/json_files/suzuki/dry_suzuki_rag_clustered_sci.json'
# Alternative RAG clustering file ("o3" variant); swap in to compare.
# Suzuki_Rag_Json = '/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/Rag-Cluster/json_files/suzuki/dry_suzuki_rag_clustered_o3.json'

# Difficulty passed to ``BoData.make_data_harder`` in ``return_dataset``.
# ``main`` overwrites this global with its ``dh`` argument before each run.
data_harder = 10

def return_dataset(data_name:str):
    """Build a fresh Suzuki ``BoData`` dataset for one optimisation run.

    The experiment table at ``Suzuki_PATH`` is parsed and wrapped in
    ``BoData``. The file ``Suzuki_Yield_Prediction`` is then handed to both
    ``load_data_prediction`` and ``use_test_set``. Finally the dataset is
    passed through ``make_data_harder`` with the current module-level
    ``data_harder`` value, which ``main`` overwrites per run.

    Args:
        data_name: Dataset name. Currently unused -- the Suzuki files are
            always loaded.

    Returns:
        The prepared ``BoData`` instance.
    """
    raw_table = BoData.read_suzuki_exp_data(Suzuki_PATH)
    prepared = BoData(raw_table)
    prediction_file = Suzuki_Yield_Prediction
    prepared.load_data_prediction(prediction_file)
    prepared.use_test_set(prediction_file)
    prepared.make_data_harder(data_harder)
    return prepared

def return_expert_partition(data_name:str):
    """Load the expert-defined partition together with its variable order.

    Args:
        data_name: Dataset name. Currently unused -- the Suzuki expert
            partition is always returned.

    Returns:
        Tuple ``(expert_partition, order)``. ``expert_partition`` is the
        parsed JSON content of ``Suzuki_Expert_Partition``; ``order`` is
        ``Suzuki_Partition_Order``.
    """
    # A context manager closes the handle deterministically; the previous
    # ``json.load(open(...))`` left it to the garbage collector.
    # JSON text is UTF-8 by spec, so do not depend on the locale encoding.
    with open(Suzuki_Expert_Partition, encoding="utf-8") as fh:
        expert_partition = json.load(fh)
    order = Suzuki_Partition_Order

    return expert_partition,order

def return_rag_partition(data_name:str):
    """Load the RAG-clustered partition together with its variable order.

    Args:
        data_name: Dataset name. Currently unused -- the Suzuki RAG
            partition is always returned.

    Returns:
        Tuple ``(rag_json, order)``. ``rag_json`` is the parsed JSON content
        of ``Suzuki_Rag_Json``; ``order`` is ``Suzuki_Partition_Order``.
    """
    # Close the file deterministically instead of leaking the handle;
    # JSON text is UTF-8 by spec.
    with open(Suzuki_Rag_Json, encoding="utf-8") as fh:
        rag_json = json.load(fh)
    order = Suzuki_Partition_Order

    return rag_json,order

def return_sci_partition(data_name:str):
    """Load the "sci" clustered partition together with its variable order.

    Args:
        data_name: Dataset name. Currently unused -- the Suzuki sci
            partition is always returned.

    Returns:
        Tuple ``(sci_json, order)``. ``sci_json`` is the parsed JSON content
        of ``Suzuki_Sci_Json``; ``order`` is ``Suzuki_Partition_Order``.
    """
    # Close the file deterministically instead of leaking the handle;
    # JSON text is UTF-8 by spec.
    with open(Suzuki_Sci_Json, encoding="utf-8") as fh:
        sci_json = json.load(fh)
    order = Suzuki_Partition_Order

    return sci_json,order

def run_baseline(repeat_time:int,iteration:int,data_name:str,batch_size:int,use_diverse_sample:bool=False):
    """Run the plain baseline search (no partition tree) ``repeat_time`` times.

    Every repetition starts from a freshly built dataset and ``MCTS`` object.
    The first observations come from ``init_tree(dont_build_tree=True)``.
    ``baseline_search`` is then called ``iteration - 1`` more times, and each
    returned batch is appended.

    Args:
        repeat_time: Number of independent repetitions.
        iteration: Rounds per repetition, counting the initial one.
        data_name: Forwarded to ``return_dataset``.
        batch_size: Forwarded to ``MCTS``.
        use_diverse_sample: Forwarded to ``MCTS``.

    Returns:
        One flat list of observations per repetition.
    """
    print("="*50, "Baseline", "="*50)
    collected_runs = []
    for _repetition in range(repeat_time):
        fresh_data = return_dataset(data_name)
        searcher = MCTS(
            dataset=fresh_data,
            sampler=BOSampler(fresh_data),
            batch_size=batch_size,
            variable_nums=len(Suzuki_Partition_Order),
            n_candidates=20,
            use_diverse_sample=use_diverse_sample,
        )
        history = []
        history.extend(searcher.init_tree(dont_build_tree=True))
        for _round in tqdm(range(1, iteration)):
            history.extend(searcher.baseline_search())
        collected_runs.append(history)
    return collected_runs

def _run_partitioned_search(title, load_partition, repeat_time, iteration, data_name, batch_size, use_diverse_sample):
    """Shared driver for the partition-guided MCTS experiments.

    Each of the ``repeat_time`` repetitions proceeds as follows:

    1. Build a fresh dataset and ``MCTS`` instance.
    2. Initialise the tree from the partition returned by ``load_partition``.
       This uses ``diverse_init_tree`` when ``use_diverse_sample`` is set,
       otherwise ``pseudo_init_tree(..., pseudo_label=True)``.
    3. Call ``mcts.search(pseudo_label=True, ...)`` ``iteration - 1`` more times.

    Args:
        title: Label printed in the banner line.
        load_partition: Callable ``data_name -> (partition, order)``, e.g.
            ``return_expert_partition``.
        repeat_time: Number of independent repetitions.
        iteration: Rounds per repetition, counting the initialisation.
        data_name: Forwarded to ``return_dataset`` and ``load_partition``.
        batch_size: Forwarded to ``MCTS``.
        use_diverse_sample: Forwarded to ``MCTS``; also selects the init method.

    Returns:
        One flat list of observations per repetition.
    """
    print("="*50, title, "="*50)
    results = []
    for _ in range(repeat_time):
        dataset = return_dataset(data_name)
        partition, order = load_partition(data_name)
        mcts = MCTS(
            dataset=dataset,
            sampler=BOSampler(dataset),
            batch_size=batch_size,
            variable_nums=len(Suzuki_Partition_Order),
            n_candidates=20,
            use_diverse_sample=use_diverse_sample,
        )
        if use_diverse_sample:
            obs = mcts.diverse_init_tree(partition, order)
        else:
            obs = mcts.pseudo_init_tree(partition, order, pseudo_label=True)
        result = []
        result.extend(obs)
        for i in tqdm(range(1, iteration)):
            result.extend(mcts.search(pseudo_label=True, iteration_index=i))
        results.append(result)
    return results


def run_all(repeat_time:int,iteration:int,data_name:str,batch_size:int,use_diverse_sample:bool=False):
    """Run MCTS guided by the expert partition (``return_expert_partition``).

    See ``_run_partitioned_search`` for the per-repetition procedure and the
    shape of the returned results.
    """
    return _run_partitioned_search(
        "All", return_expert_partition,
        repeat_time, iteration, data_name, batch_size, use_diverse_sample,
    )


def run_all_with_rag(repeat_time:int,iteration:int,data_name:str,batch_size:int,use_diverse_sample:bool=False):
    """Run MCTS guided by the RAG-clustered partition (``return_rag_partition``).

    See ``_run_partitioned_search`` for the per-repetition procedure and the
    shape of the returned results.
    """
    return _run_partitioned_search(
        "RAG", return_rag_partition,
        repeat_time, iteration, data_name, batch_size, use_diverse_sample,
    )


def run_all_with_sci(repeat_time:int,iteration:int,data_name:str,batch_size:int,use_diverse_sample:bool=False):
    """Run MCTS guided by the "sci" partition (``return_sci_partition``).

    See ``_run_partitioned_search`` for the per-repetition procedure and the
    shape of the returned results.
    """
    return _run_partitioned_search(
        "Sci", return_sci_partition,
        repeat_time, iteration, data_name, batch_size, use_diverse_sample,
    )

def main(args, random_seed, dh, use_diverse_sample:bool=False):
    """Run the selected experiments for one seed/difficulty and save outputs.

    Steps:

    1. Store ``dh`` into the module-level ``data_harder``, which
       ``return_dataset`` reads.
    2. Run every experiment named in ``args.exp_name``.
    3. Save the raw results with ``torch.save``.
    4. Plot the results with ``plot_resultses``.

    Args:
        args: Parsed CLI namespace. Reads ``data_name``, ``repeat_time``,
            ``iteration``, ``exp_name`` and ``batch_size``.
        random_seed: Seed of this run. Only embedded in the output file names
            here; the actual seeding is done by the caller.
        dh: Difficulty value written to the global ``data_harder``.
        use_diverse_sample: Forwarded to every experiment runner. Also selects
            the output sub-directory ("new_..." vs "no_diverse_...").

    Raises:
        ValueError: If ``args.exp_name`` contains an unknown experiment name.
    """
    global data_harder
    # ``return_dataset`` reads this global, so set it before any runner starts.
    data_harder = dh
    data_name = args.data_name
    repeat_time = args.repeat_time
    iteration = args.iteration
    batch_size = args.batch_size

    exp_name = args.exp_name
    # A bare string (e.g. "--exp_name all" parsed without nargs) would be
    # iterated one character at a time below; treat it as a single name.
    if isinstance(exp_name, str):
        exp_name = [exp_name]

    func_map = {
        "Baseline":run_baseline,
        "all":run_all,
        "rag":run_all_with_rag,
        # "agent":run_all_with_agent,
        "sci":run_all_with_sci,
    }
    # Validate every name before running anything. Experiments can take hours,
    # so a typo in the last name must not surface only after the others finish.
    unknown = [exp for exp in exp_name if exp not in func_map]
    if unknown:
        raise ValueError(f"Unknown experiment name(s) {unknown}; expected one of {list(func_map)}")

    name_dir = f"new_{batch_size}_{iteration}" if use_diverse_sample else f"no_diverse_{batch_size}_{iteration}"
    # Other result directories used previously: rag_results, sci_results,
    # four_method_results, hallucinations.
    results_dir = "diff_batch_size"
    results_path = os.path.join(f'/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/mas/code/{results_dir}/{name_dir}/', data_name)
    os.makedirs(results_path,exist_ok=True)

    logger.info(f"Exp_name: {exp_name}")
    resultses = []
    for exp in exp_name:
        logger.info(f"Running {exp}")
        resultses.append(func_map[exp](repeat_time,iteration,data_name,batch_size,use_diverse_sample=use_diverse_sample))

    rag_model_name = "no_rag"
    if 'rag' in exp_name or 'agent' in exp_name:
        # ".../dry_suzuki_rag_clustered_o3mini.json" -> "o3mini"
        rag_model_name = Suzuki_Rag_Json.split('/')[-1].split('.')[0].split("_")[-1]

    now_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    # Shared file-name stem for the raw-results and plot outputs.
    stem = f'{random_seed}_{rag_model_name}_{repeat_time}+{iteration}_dh-{data_harder}'
    pt_path = os.path.join(results_path, f'{stem}_raw_results_{now_time}.pt')
    png_path = os.path.join(results_path, f'{stem}_{now_time}.png')
    torch.save(resultses,pt_path)
    logger.info(f"raw results saved to {pt_path}")
    plot_resultses(resultses,batch_size,exp_name,png_path)
    logger.info(f"plot saved to {png_path}")

def setup_seed(seed):
    """Seed every RNG used by the experiments so runs are reproducible.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global RNG and Python's
    ``random``, and puts cuDNN into deterministic, non-autotuning mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # ``deterministic`` alone is not enough. With ``benchmark`` enabled, cuDNN
    # autotunes and may pick different algorithms from run to run.
    torch.backends.cudnn.benchmark = False

    
if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_time',type=int,default=1)
    parser.add_argument('--iteration',type=int,default=20)
    parser.add_argument('--data_name',type=str,default='suzuki')
    # nargs='+' makes "--exp_name all rag" parse to ['all', 'rag']. With a plain
    # type=str, a CLI value arrived as one string and main() iterated it per
    # character. Keep ``choices`` in sync with ``func_map`` in main().
    parser.add_argument('--exp_name',type=str,nargs='+',default=['Baseline','all','rag', 'sci'],
                        choices=['Baseline','all','rag','sci'])
    parser.add_argument('--batch_size',type=int,default=5)
    # Sweep grid; the defaults reproduce the previously hard-coded values.
    parser.add_argument('--seeds',type=int,nargs='+',default=[100, 200, 300, 400, 500, 600, 700, 800, 900, 1000])
    parser.add_argument('--data_harder_list',type=int,nargs='+',default=[30, 40, 50])
    parser.add_argument('--no_diverse_sample',action='store_true',
                        help='disable diverse sampling (enabled by default)')
    args = parser.parse_args()

    use_diverse_sample = not args.no_diverse_sample

    # perf_counter is monotonic, so the elapsed time is immune to clock changes.
    start_time = time.perf_counter()
    for dh in args.data_harder_list:
        for seed in args.seeds:
            setup_seed(seed)
            main(args, seed, dh, use_diverse_sample=use_diverse_sample)

    end_time = time.perf_counter()
    logger.success(f"total time: {end_time-start_time}")