
from data.bodata import BoData,Sample
from sampler.bo import BOSampler
from mcts.tree import MCTS
from utils.plot import plot_resultses
import torch
import time
import json
import argparse 
import pandas as pd
import copy
# Experimental reaction datasets (CSV files); read by return_dataset().
Buchwald_PATH = r'/mnt/petrelfs/handong/mas/data/exp/buchwald/processed.csv'
Suzuki_PATH = r'/mnt/petrelfs/handong/mas/data/exp/suzuki/experiment_index.csv'
Tandem_PATH = r'/mnt/petrelfs/handong/mas/data/exp/tandem/processed.csv'
Cpa_PATH = r'/mnt/petrelfs/handong/mas/data/exp/cpa/processed.csv'

# Expert partition files (JSON); parsed by return_expert_partition().
Buchwald_Expert_Partition = r'/mnt/petrelfs/handong/mas/data/partition/Buchwald_partition.json'
Suzuki_Expert_Partition = r'/mnt/petrelfs/handong/mas/data/partition/suzuki_partition.json'
# Yield-prediction files (.pt); handed to BoData.load_data_prediction() and
# BoData.use_test_set() in return_dataset().
Buchwald_Yield_Prediction = r'/mnt/petrelfs/handong/mas/data/pred_value/buchwald/yields.pt'
Suzuki_Yield_Prediction = r'/mnt/petrelfs/handong/mas/data/pred_value/suzuki/yields_suzuki-miyaura_4_llama_regression.pt'
# Name orderings returned alongside each expert partition.
# NOTE(review): presumably the dataset columns the partition applies to, in
# tree-level order -- confirm against MCTS.pseudo_init_tree.
Buchwald_Partition_Order = ['Reactant2','Ligand','Base','Additive']
Suzuki_Partition_Order = ['ligand','base','solvent']
def return_dataset(data_name: str):
    """Build the ``BoData`` object for the named dataset.

    Args:
        data_name: One of ``'buchwald'``, ``'suzuki'``, ``'tandem'`` or
            ``'cpa'``.

    Returns:
        A ``BoData`` instance. For ``'buchwald'`` and ``'suzuki'`` the
        matching yield-prediction file is additionally passed to
        ``load_data_prediction`` and ``use_test_set``; the other two
        datasets are returned as loaded.

    Raises:
        ValueError: If ``data_name`` is not one of the supported names.
    """
    # Datasets without a prediction file need no further setup.
    if data_name == 'tandem':
        return BoData(BoData.read_tandem_exp_data(Tandem_PATH))
    if data_name == 'cpa':
        return BoData(BoData.read_cpa_exp_data(Cpa_PATH))

    if data_name == 'buchwald':
        loaded = BoData(BoData.read_buchwald_data(Buchwald_PATH))
        prediction_file = Buchwald_Yield_Prediction
    elif data_name == 'suzuki':
        loaded = BoData(BoData.read_suzuki_exp_data(Suzuki_PATH))
        prediction_file = Suzuki_Yield_Prediction
    else:
        raise ValueError(f"Invalid data name: {data_name}")

    # The same prediction file feeds both calls, in this order.
    loaded.load_data_prediction(prediction_file)
    loaded.use_test_set(prediction_file)
    return loaded
def return_expert_partition(data_name: str):
    """Load the expert partition and its ordering for the named dataset.

    Args:
        data_name: ``'buchwald'`` or ``'suzuki'``; no other dataset has a
            partition file configured in this module.

    Returns:
        A ``(expert_partition, order)`` tuple: the parsed JSON content of the
        partition file and the corresponding ``*_Partition_Order`` list.

    Raises:
        ValueError: If ``data_name`` has no expert partition.
    """
    if data_name == 'buchwald':
        partition_path = Buchwald_Expert_Partition
        order = Buchwald_Partition_Order
    elif data_name == 'suzuki':
        partition_path = Suzuki_Expert_Partition
        order = Suzuki_Partition_Order
    else:
        raise ValueError(f"Invalid data name: {data_name}")

    # Use a context manager so the file handle is closed deterministically;
    # this function is called once per repeat by the experiment runners.
    with open(partition_path) as partition_file:
        expert_partition = json.load(partition_file)
    return expert_partition, order


def run_baseline(repeat_time: int, iteration: int, data_name: str):
    """Run the baseline search ``repeat_time`` times on a named dataset.

    Every repeat loads a fresh dataset, calls ``make_data_harder(10)`` on it,
    initialises an ``MCTS`` with ``dont_build_tree=True`` and then calls
    ``baseline_search()`` ``iteration - 1`` times.

    Args:
        repeat_time: Number of independent repeats.
        iteration: Total number of rounds; round 0 is the initial sampling.
        data_name: Dataset name understood by ``return_dataset``.

    Returns:
        A list with one entry per repeat, each being the flat list of all
        observations collected in that repeat.
    """
    all_runs = []
    for rep in range(repeat_time):
        print(f"repeat time: {rep}")
        data = return_dataset(data_name)
        data.make_data_harder(10)
        tree = MCTS(dataset=data, sampler=BOSampler(data))

        first_obs = tree.init_tree(dont_build_tree=True)
        print(f"    random sampling done, max value: {max(first_obs)}")
        trace = list(first_obs)

        for step in range(1, iteration):
            trace.extend(tree.baseline_search())
            print(f"    iteration {step} done, max value: {max(trace)}")
        all_runs.append(trace)

    return all_runs
def run_ood_baseline(samples, repeat_time: int, iteration: int):
    """Run the baseline search on an explicitly supplied sample collection.

    Each repeat wraps a deep copy of ``samples`` in a new ``BoData`` (so the
    caller's samples are not shared between repeats), initialises an ``MCTS``
    with ``dont_build_tree=True`` and calls ``baseline_search()``
    ``iteration - 1`` times.

    Args:
        samples: Sample collection accepted by the ``BoData`` constructor.
        repeat_time: Number of independent repeats.
        iteration: Total number of rounds; round 0 is the initial sampling.

    Returns:
        A list with one flat observation list per repeat.
    """
    all_runs = []
    for rep in range(repeat_time):
        print(f"repeat time: {rep}")
        # NOTE: unlike run_baseline, make_data_harder is not applied here
        # (the call is disabled in this variant).
        data = BoData(copy.deepcopy(samples))
        tree = MCTS(dataset=data, sampler=BOSampler(data))

        initial = tree.init_tree(dont_build_tree=True)
        print(f"    random sampling done, max value: {max(initial)}")

        collected = []
        collected += initial
        for round_no in range(1, iteration):
            collected += tree.baseline_search()
            print(f"    iteration {round_no} done, max value: {max(collected)}")
        all_runs.append(collected)

    return all_runs
def run_ood_all(samples, repeat_time: int, iteration: int):
    """Run the expert-partition + pseudo-label search on supplied samples.

    Each repeat wraps a deep copy of ``samples`` in a ``BoData``, calls
    ``make_data_harder(30)``, initialises the tree from the Suzuki expert
    partition and then calls ``search(pseudo_label=True)`` up to
    ``iteration - 1`` times. A repeat stops early as soon as a round's best
    observation equals ``dataset.max_value``.

    Args:
        samples: Sample collection accepted by the ``BoData`` constructor.
        repeat_time: Number of independent repeats.
        iteration: Maximum number of rounds; round 0 is the initialisation.

    Returns:
        A list with one entry per repeat holding the observations of that
        repeat. When the initialisation already hits the maximum, the entry
        is the object returned by ``pseudo_init_tree`` itself.
    """
    all_runs = []
    for rep in range(repeat_time):
        print(f"repeat time: {rep}")
        data = BoData(copy.deepcopy(samples))
        data.make_data_harder(30)
        print(f"    max value: {data.max_value}")
        tree = MCTS(dataset=data, sampler=BOSampler(data))
        # The partition is always the Suzuki one, whatever `samples` contains.
        partition, level_order = return_expert_partition('suzuki')
        round_obs = tree.pseudo_init_tree(partition, level_order, pseudo_label=True)

        print(f"    random sampling done, max value: {max(round_obs)}")

        if max(round_obs) == data.max_value:
            # Optimum found during initialisation: record it and skip searching.
            print("    max value reached")
            all_runs.append(round_obs)
        else:
            trace = list(round_obs)
            for step in range(1, iteration):
                round_obs = tree.search(pseudo_label=True)
                trace.extend(round_obs)
                print(f"    iteration {step} done, max value: {max(trace)}")
                if max(round_obs) == data.max_value:
                    print("    max value reached")
                    break
            all_runs.append(trace)

    return all_runs
def run_pseudo_label(repeat_time: int, iteration: int, data_name: str):
    """Run the baseline search with pseudo labels enabled.

    Identical in shape to the plain baseline run, except that every search
    round calls ``baseline_search(pseudo_label=True)`` and the dataset is
    used without ``make_data_harder``.

    Args:
        repeat_time: Number of independent repeats.
        iteration: Total number of rounds; round 0 is the initial sampling.
        data_name: Dataset name understood by ``return_dataset``.

    Returns:
        A list with one flat observation list per repeat.
    """
    per_repeat = []
    for rep in range(repeat_time):
        print(f"repeat time: {rep}")
        # NOTE: the make_data_harder(50) call is disabled in this variant.
        data = return_dataset(data_name)
        searcher = MCTS(dataset=data, sampler=BOSampler(data))

        seed_obs = searcher.init_tree(dont_build_tree=True)
        print(f"    random sampling done, max value: {max(seed_obs)}")
        history = [*seed_obs]

        for step in range(1, iteration):
            new_obs = searcher.baseline_search(pseudo_label=True)
            history.extend(new_obs)
            print(f"    iteration {step} done, max value: {max(history)}")
        per_repeat.append(history)
    return per_repeat

def run_expert_partition(repeat_time: int, iteration: int, data_name: str):
    """Run the tree search initialised from the expert partition.

    Each repeat loads the dataset, initialises the tree via
    ``pseudo_init_tree(..., pseudo_label=True)`` and then calls ``search()``
    (without the ``pseudo_label`` argument) ``iteration - 1`` times.

    Args:
        repeat_time: Number of independent repeats.
        iteration: Total number of rounds; round 0 is the initialisation.
        data_name: Dataset name; must be accepted by both ``return_dataset``
            and ``return_expert_partition``.

    Returns:
        A list with one flat observation list per repeat.
    """
    per_repeat = []
    for rep in range(repeat_time):
        print(f"repeat time: {rep}")
        # NOTE: the make_data_harder(50) call is disabled in this variant.
        data = return_dataset(data_name)
        searcher = MCTS(dataset=data, sampler=BOSampler(data))
        partition, level_order = return_expert_partition(data_name)

        init_obs = searcher.pseudo_init_tree(partition, level_order, pseudo_label=True)
        # Unlike the other runners, this prints the raw observations, not the max.
        print(f"    random sampling done, obs: {init_obs}")
        history = list(init_obs)

        for step in range(1, iteration):
            history.extend(searcher.search())
            print(f"    iteration {step} done, max value: {max(history)}")
        per_repeat.append(history)
    return per_repeat

def run_all(repeat_time: int, iteration: int, data_name: str):
    """Prepare the dataset for the combined experiment and report its maximum.

    NOTE(review): the actual search (expert-partition initialisation followed
    by ``search(pseudo_label=True)`` rounds) is commented out below, so this
    function currently only loads the data, applies ``make_data_harder(30)``
    and prints ``dataset.max_value``. ``iteration`` is therefore unused.

    Args:
        repeat_time: Number of repeats.
        iteration: Unused while the search loop is disabled.
        data_name: Dataset name; must be accepted by both ``return_dataset``
            and ``return_expert_partition``.

    Returns:
        An empty list (nothing is appended while the search is disabled).
    """
    collected_runs = []
    for rep in range(repeat_time):
        print(f"repeat time: {rep}")
        data = return_dataset(data_name)
        # Result is unused for now, but the call is kept: it raises ValueError
        # for datasets that have no expert partition.
        return_expert_partition(data_name)
        data.make_data_harder(30)
        print(f"    max value: {data.max_value}")
        # Disabled search loop, kept for reference:
        # mcts = MCTS(dataset=dataset,sampler=BOSampler(dataset))
        # obs = mcts.pseudo_init_tree(expert_partition,order,pseudo_label=True)
        # print(f"    random sampling done, max value: {max(obs)}")
        # result.extend(obs)
        # for i in range(1,iteration):
        #     obs = mcts.search(pseudo_label=True)
        #     result.extend(obs)
        #     print(f"    iteration {i} done, max value: {max(result)}")
        # results.append(result)
    return collected_runs

def run_ood(repeat_time:int=1,iteration:int=5):
    """Run the out-of-distribution Suzuki experiment, one search per substrate pair.

    Steps:
      1. Read the Suzuki experiment table and set its ``'yield'`` column to
         the values stored under ``'pred_yields_by_rxn'`` in the prediction
         file.
      2. Keep only the rows listed under ``'val_idx'`` and group them by
         ``(electrophile, nucleophile)``.
      3. Build one ``Sample`` per row (one-hot features over the category
         columns) and run ``run_ood_all`` on every group.
      4. Save the raw results and plot them under the results directory.

    Args:
        repeat_time: Repeats per group, forwarded to ``run_ood_all``.
        iteration: Rounds per repeat, forwarded to ``run_ood_all``.
    """
    category = ['electrophile','nucleophile','catalyst','ligand','base','solvent']
    objective = ['yield']

    raw_df = pd.read_csv(Suzuki_PATH)
    # Load the prediction file once; both entries come from the same file.
    checkpoint = torch.load(Suzuki_Yield_Prediction,weights_only=False)
    # NOTE(review): from here on 'yield' holds model predictions, so the
    # observed_value and predict_value of each Sample below are built from the
    # same number -- confirm this is intended.
    raw_df['yield'] = checkpoint['pred_yields_by_rxn']
    val_df = raw_df.iloc[checkpoint['val_idx']]

    name_simples = {}
    for name,group in val_df.groupby(['electrophile','nucleophile']):
        # One-hot encoding is computed per group, so the feature width can
        # differ between groups.
        feats = torch.tensor(pd.get_dummies(group[category]).values,dtype=torch.float64)
        samples = {}
        for i,(_,row) in enumerate(group.iterrows()):
            category_value = {col: row[col] for col in category}
            objective_value = torch.tensor(
                [row[col] for col in objective],dtype=torch.float64
            ).view(-1,len(objective)).clone().detach()
            samples[i] = Sample(
                feat=feats[i],
                observed_value=objective_value,
                category_value=category_value,
                predict_value=torch.tensor(row['yield']),
            )
        name_simples[name] = samples
        print(name,len(samples))

    resultses_baseline = []
    resultses = []
    for name,samples in name_simples.items():
        print(name)
        resultses.append(run_ood_all(samples,repeat_time,iteration))
        # Baseline comparison is currently disabled, so resultses_baseline
        # stays empty (it is still saved below):
        #resultses_baseline.append(run_ood_baseline(samples,repeat_time,iteration))

    now_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    # NOTE(review): 'ood11' is a file-name prefix here (no path separator
    # follows it) -- confirm that is the intended layout.
    out_prefix = r'/mnt/petrelfs/handong/mas/results/ood11'
    torch.save(resultses,out_prefix+r'raw_results_'+now_time+'.pt')
    torch.save(resultses_baseline,out_prefix+r'raw_results_baseline_'+now_time+'.pt')
    # NOTE(review): main() calls plot_resultses with four positional arguments
    # (results, 5, exp_name, path); here the path is the third -- verify
    # against plot_resultses' signature.
    plot_resultses(resultses,5,out_prefix+r'_'+now_time+'.png')
        
def main(args):
    """Run the experiments selected by ``args.exp_name`` and store the results.

    Args:
        args: Namespace with ``data_name`` and ``exp_name``; ``repeat_time``
            and ``iteration`` are used when present and default to 1
            otherwise. ``exp_name`` may be a list of names or a single
            string; selection uses the ``in`` operator, i.e. membership for
            a list and substring match for a string. Recognised names are
            ``'Baseline'``, ``'only_pseudo_label'``,
            ``'only_expert_partition'``, ``'all'`` and ``'ood'``.

    Side effects:
        Saves the raw results with ``torch.save`` and calls
        ``plot_resultses``. If ``'ood'`` is selected, ``run_ood`` handles its
        own saving/plotting and this function returns right after it.
    """
    data_name = args.data_name
    exp_name = args.exp_name
    # Honour the parsed --repeat_time / --iteration options instead of the
    # previous hard-coded 1/1 (which made every search loop empty). The
    # fallback keeps namespaces that lack these attributes working.
    repeat_time = getattr(args,'repeat_time',1)
    iteration = getattr(args,'iteration',1)
    print(exp_name)

    resultses = []
    if 'Baseline' in exp_name:
        resultses.append(run_baseline(repeat_time,iteration,data_name))
    if 'only_pseudo_label' in exp_name:
        resultses.append(run_pseudo_label(repeat_time,iteration,data_name))
    if 'only_expert_partition' in exp_name:
        resultses.append(run_expert_partition(repeat_time,iteration,data_name))
    if 'all' in exp_name:
        resultses.append(run_all(repeat_time,iteration,data_name))
    if 'ood' in exp_name:
        # The OOD run saves and plots on its own; results gathered above are
        # not saved in this case.
        run_ood(repeat_time,iteration)
        return

    now_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    result_dir = r'/mnt/petrelfs/handong/mas/results/'+data_name+r'/'
    raw_path = result_dir+r'raw_results_'+now_time+'.pt'
    plot_path = result_dir+r'_'+now_time+'.png'
    print(raw_path)
    print(plot_path)
    # save raw results
    torch.save(resultses,raw_path)
    plot_resultses(resultses,5,exp_name,plot_path)


    
if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_time',type=int,default=1)
    parser.add_argument('--iteration',type=int,default=10)
    parser.add_argument('--data_name',type=str,default='buchwald')
    # nargs='+' makes a CLI-supplied value a list, matching the list default;
    # without it a single string was passed on and experiment selection fell
    # back to substring matching.
    parser.add_argument(
        '--exp_name',type=str,nargs='+',default=['Baseline','all'],
        help="experiment names, e.g. --exp_name Baseline all",
    )
    args = parser.parse_args()
    # Also accept the comma-joined form (--exp_name Baseline,all).
    args.exp_name = [name for token in args.exp_name for name in token.split(',') if name]
    main(args)
