
import sys
sys.path.append('/fs-computility/ai4phys/aizhehong/ChemBOMAS/Tongji_WetExp/v1.0_ChemBOMAS_WetExp_TongJi/mas/code')
from data.bodata import BoData,Sample
from sampler.bo import BOSampler
from mcts.tree import MCTS
from utils.plot import plot_resultses
import torch
import time
import json
import argparse 
import pandas as pd
import copy

# ---------------------------------------------------------------------------
# Per-round file names. Edit these when starting a new design round.
# ---------------------------------------------------------------------------
pred_file = "yields_4rounds_BO_56.pt"
wet_exp_result_file = "4rounds_BO_250421.csv"
uncompleted_exp_file = "no_uncompleted_exp.csv"
expert_partition_file = "new_tongji_partition.json"
# (sic) the spelling is part of the path; keep it in sync with the real file name.
search_space_file = "v2_searchspeace.csv"
designed_exp_file = "4th_BO_design_20_0.01_24.csv"

# Root of the data tree; every path below is built beneath it.
directory_path = '/fs-computility/ai4phys/aizhehong/ChemBOMAS/Tongji_WetExp/v1.0_ChemBOMAS_WetExp_TongJi/mas/data'

# Inputs passed to BoData.read_tongji_exp_data by return_dataset().
Tongji_PATH = '/'.join((directory_path, 'exp', 'tongji', 'formal', wet_exp_result_file))
Tongji_Search_Space = '/'.join((directory_path, 'exp', 'tongji', 'formal', search_space_file))
# Note: unlike the two paths above, this one is not under the 'formal' sub-folder.
Tongji_Uncompleted = '/'.join((directory_path, 'exp', 'tongji', uncompleted_exp_file))

# Predictions loaded onto the dataset by run_all().
Tongji_Pred = '/'.join((directory_path, 'pred_value', 'tongji', 'formal', pred_file))

# Expert partition JSON and the variable order returned with it by return_expert_partition().
Tongji_Expert_Partition = '/'.join((directory_path, 'partition', 'tongji_formal', expert_partition_file))
# Alternative orders kept for reference:
#   ['Catalyst', 'Ligand', 'Base', 'Solvent', 'Water', 'Temperature']
#   ['Water', 'Temperature']
Tongji_Partition_Order = ['Catalyst', 'Ligand', 'Base', 'Solvent', 'Temperature', 'Water']

# Output CSV of newly designed experiments, written by main().
Tongji_Design_Exp = '/'.join((directory_path, 'exp_design', 'tongji', 'candidate', designed_exp_file))
def return_dataset(data_name:str):
    """Read the wet-lab experiment data for ``data_name`` and wrap it in a BoData.

    Args:
        data_name: Dataset identifier; only ``'tongji'`` is supported.

    Returns:
        A ``(BoData, mask)`` pair. ``mask`` is the second value returned by
        ``BoData.read_tongji_exp_data``, passed through unchanged.

    Raises:
        ValueError: If ``data_name`` is not a supported dataset.
    """
    if data_name != 'tongji':
        raise ValueError(f"Invalid data name: {data_name}")

    raw_data, mask = BoData.read_tongji_exp_data(Tongji_PATH, Tongji_Search_Space, Tongji_Uncompleted)
    return BoData(raw_data), mask
def return_expert_partition(data_name:str):
    """Load the expert partition and the variable split order for ``data_name``.

    Args:
        data_name: Dataset identifier; only ``'tongji'`` is supported.

    Returns:
        ``(expert_partition, order)``:

        - ``expert_partition`` is the dict parsed from the
          ``Tongji_Expert_Partition`` JSON file. Its ``'Water'`` and
          ``'Temperature'`` entries (lists of lists) are converted element-wise
          to ``str``.
        - ``order`` is ``Tongji_Partition_Order``.

    Raises:
        ValueError: If ``data_name`` is not a supported dataset.
        KeyError: If the JSON has no ``'Water'`` or ``'Temperature'`` entry.
    """
    if data_name != 'tongji':
        raise ValueError(f"Invalid data name: {data_name}")

    # Close the handle deterministically instead of leaking it until garbage
    # collection. JSON text is UTF-8 (RFC 8259), so don't depend on the
    # platform locale encoding.
    with open(Tongji_Expert_Partition, encoding='utf-8') as fh:
        expert_partition = json.load(fh)
    order = Tongji_Partition_Order

    # TODO: temporary workaround - convert the Water and Temperature levels to
    # strings. NOTE(review): presumably so they compare equal to string-typed
    # levels elsewhere (e.g. the search space); confirm.
    for key in ('Water', 'Temperature'):
        expert_partition[key] = [[str(item) for item in group] for group in expert_partition[key]]

    print(f'Expert partition: {expert_partition}\norder:{order}')
    return expert_partition, order


def run_all(data_name:str):
    """Run one MCTS-guided BO design round and return the proposed experiments.

    Steps:
        1. Load the dataset and mask via ``return_dataset``.
        2. Attach the predictions stored at ``Tongji_Pred``.
        3. Load the expert partition and split order.
        4. Search with ``MCTS`` using a ``BOSampler`` over the same dataset.

    Args:
        data_name: Dataset identifier forwarded to the loaders.

    Returns:
        Whatever ``MCTS.real_exp`` returns. ``main`` iterates it as a
        collection of dict-like samples.
    """
    bo_dataset, exp_mask = return_dataset(data_name)
    print(f"Running {data_name} experiment")
    # NOTE(review): the prediction file is always Tongji_Pred, whatever data_name is.
    bo_dataset.load_data_prediction(Tongji_Pred)
    partition, var_order = return_expert_partition(data_name)

    sampler = BOSampler(bo_dataset)
    tree = MCTS(dataset=bo_dataset, sampler=sampler)
    return tree.real_exp(partition, var_order, exp_mask)

        
def _dedupe_and_reorder(batch, columns):
    """Project each sample onto ``columns`` (in that order) and drop duplicate rows.

    The first occurrence of each row is kept and input order is preserved, so
    the result is reproducible across runs. A set of frozensets would instead
    iterate in hash order, which changes between processes because string
    hashing is randomized. Deduplication happens after projection, so the
    returned rows, and hence the CSV rows, are unique.

    Raises:
        KeyError: If a sample lacks one of ``columns``.
    """
    unique_rows = {}
    for sample in batch:
        row = {key: sample[key] for key in columns}
        unique_rows.setdefault(tuple(row.values()), row)
    return list(unique_rows.values())


def main(args):
    """Run a design round for ``args.data_name`` and write the proposals to CSV.

    The proposed samples are deduplicated and reordered to the output column
    order. The result is printed and written to ``Tongji_Design_Exp`` without
    an index column. If no samples are proposed, nothing is written, so an
    existing design file is not clobbered.

    Args:
        args: Parsed CLI namespace; only ``args.data_name`` is read.
    """
    data_name = args.data_name
    next_batch = run_all(data_name)

    # Column order of the written design file (differs from the partition order).
    desired_order = ['Catalyst', 'Ligand', 'Solvent', 'Base', 'Water', 'Temperature']
    sorted_unique_batch = _dedupe_and_reorder(next_batch, desired_order)

    print(sorted_unique_batch)
    if not sorted_unique_batch:
        # Previously this crashed with IndexError on sorted_unique_batch[0].
        print(f"No experiments proposed; {Tongji_Design_Exp} was not written.")
        return

    df = pd.DataFrame(sorted_unique_batch, columns=desired_order)
    df.to_csv(Tongji_Design_Exp, index=False)
    


    
if __name__ == '__main__':
    # Command-line entry point: choose the dataset, then run one design round.
    cli = argparse.ArgumentParser()
    cli.add_argument('--data_name', type=str, default='tongji')
    main(cli.parse_args())
