#!/usr/bin/env python
# -*- coding=utf8 -*-

import numpy as np
from src.utils.utils_templates import load_json
from src.utils.wandb_helpers import wandb_wrapper

def obtain_llm_bo_object(bo_type):
    """Return the BO entry point associated with ``bo_type``.

    The implementing module is imported inside the matching branch, so it is
    only imported when that algorithm is actually requested.

    Args:
        bo_type: Algorithm identifier, either ``'lapeft'`` or ``'llmat'``.

    Returns:
        ``bo_lapeft`` for ``'lapeft'`` or ``LLMAT_bo`` for ``'llmat'``, as
        imported from their respective modules.

    Raises:
        NotImplementedError: If ``bo_type`` is not one of the identifiers
            listed above.
    """
    # NOTE: the 'llambo' baseline is currently disabled; it used to be
    # imported from src.algos.baselines.llambo.bo_llambo (bo_llambo).

    if bo_type == 'lapeft':
        from src.algos.baselines.lapeft.bo_lapeft import bo_lapeft as bo_llm
    elif bo_type == 'llmat':
        from src.algos.LLMAT.LLMAT_bo import LLMAT_bo as bo_llm
    else:
        # Without this branch the function fell through to ``return bo_llm``
        # with the name unbound and failed with an opaque UnboundLocalError.
        raise NotImplementedError(
            "Unsupported bo_type {!r}; expected 'lapeft' or 'llmat'.".format(bo_type)
        )
    return bo_llm


def run_mat(args):
    """Run the selected LLM-based BO algorithm on every dataset/feature-model pair.

    For each dataset name in ``args.datasets`` and each feature model listed
    under ``args.benchmarks[args.benchmark]['feature_models']``, a ``MATBench``
    is built and handed, together with a wandb wrapper, to the BO callable
    chosen by ``args.algorithm``.

    Args:
        args: Run configuration object. Attributes read here are
            ``algorithm``, ``benchmarks``, ``benchmark``, ``datasets``,
            ``finetuning``, ``iupac``, ``run_subset_only``, ``prompt_type``,
            ``seed``, ``clustering_type`` and ``feature_reduction``. The whole
            object is also forwarded to ``wandb_wrapper`` and the BO callable.

    Returns:
        None. The value returned by the BO callable is discarded.

    Raises:
        NotImplementedError: If ``args.algorithm`` is neither ``'llmat'`` nor
            ``'lapeft'``. Raised before any benchmark is constructed.
        ValueError: If ``args.iupac`` is truthy for a dataset other than
            ``"redox-mer"`` or ``"solvation"``.
    """
    from benchmarks.MAT.material_benchmarks import MATBench

    # Fail fast: validate and resolve the algorithm once, before building any
    # benchmark object, instead of re-checking inside the innermost loop.
    if args.algorithm not in ('llmat', 'lapeft'):
        raise NotImplementedError(
            "Unsupported algorithm {!r}; expected 'llmat' or 'lapeft'.".format(
                args.algorithm
            )
        )
    llm_bo = obtain_llm_bo_object(args.algorithm)

    list_f_model = args.benchmarks[args.benchmark]['feature_models']
    # NOTE(review): these are iterated as dataset *names*, not ids -- confirm
    # that this is what callers put into ``args.datasets``.
    list_data = args.datasets
    iupac = args.iupac

    for data_name in list_data:
        for f_model in list_f_model:
            # The "fingerprints" feature model is skipped when finetuning.
            if args.finetuning and f_model == "fingerprints":
                continue
            print("feature_model: ", f_model)

            # Explicit raise rather than ``assert`` so that the check is not
            # stripped when Python runs with optimisation (-O) enabled.
            if iupac and data_name not in ("redox-mer", "solvation"):
                raise ValueError(
                    "iupac is only supported for the 'redox-mer' and "
                    "'solvation' datasets, got {!r}.".format(data_name)
                )

            mat_bench = MATBench(
                data_name=data_name,
                run_subset_only=args.run_subset_only,
                feature_type=f_model,
                finetuning=args.finetuning,
                iupac=iupac,
                prompt_type=args.prompt_type,
                randseed=args.seed,
                clustering_type=args.clustering_type,
                feature_reduction=args.feature_reduction,
            )

            # Run BO. The callable's return value is intentionally not kept;
            # presumably results are reported through the wandb wrapper --
            # TODO confirm.
            wandb_run = wandb_wrapper(
                args, run=args.seed, data_name=data_name, model_name=f_model
            )
            llm_bo(args, mat_bench, wandb_run)
