import argparse
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set_style('white')

import torch

from benchmarks.lcbench import LCBench
from benchmarks.taskset import TaskSet
from benchmarks.pd1 import PD1
from data.meta_test_datasets import META_TEST_DATASET_DICT

from hpo_methods.our_hpo import OurAlgorithm

from tqdm import trange
import json

import warnings
warnings.filterwarnings("ignore")

def main(args):
    """Run one HPO experiment with ``OurAlgorithm`` on a single benchmark dataset.

    The algorithm is repeatedly asked for a (configuration, epoch) pair and the
    benchmark's performance for it is fed back, until ``args.budget_limit``
    epochs have been evaluated in total. The running incumbent, its utility and
    the log-regret with respect to the best utility attainable on the benchmark
    are written as JSON files to
    ``<args.output_dir>/<args.benchmark_name>/<args.dataset_name>/``.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).

    Raises:
        NotImplementedError: if ``args.benchmark_name`` is not one of
            ``'lcbench'``, ``'taskset'`` or ``'pd1'``.
    """
    def U(budget, graph):
        """Utility of reaching performance ``graph`` after spending ``budget`` epochs.

        Computes ``graph - args.decay_factor * budget``, except that a budget
        of zero always has utility 0. Accepts either a pair of scalars (Python
        or NumPy), a pair of NumPy arrays (elementwise), or a pair of torch
        tensors (elementwise).

        Raises:
            NotImplementedError: for any other combination of argument types.
        """
        if isinstance(budget, torch.Tensor) and isinstance(graph, torch.Tensor):
            utility = -args.decay_factor * budget + graph
            utility[budget == 0] = 0.
            return utility

        if isinstance(budget, np.ndarray) and isinstance(graph, np.ndarray):
            utility = -args.decay_factor * budget + graph
            utility[budget == 0] = 0.
            return utility

        # isinstance (not exact type equality) so NumPy scalars returned by the
        # benchmark, and int-valued performances, are accepted too.
        if (isinstance(budget, (int, np.integer))
                and isinstance(graph, (int, float, np.integer, np.floating))):
            if budget == 0:
                return 0
            return -args.decay_factor * budget + graph

        raise NotImplementedError

    # Per-benchmark settings: benchmark class, data file inside args.data_dir,
    # and the fixed initial configuration index used when --init is given.
    benchmark_specs = {
        'lcbench': (LCBench, "data_2k.json", 1738),
        'taskset': (TaskSet, 'taskset_chosen.json', 672),
        'pd1': (PD1, "pd1_preprocessed.json", 39),
    }

    # data
    benchmark_name = args.benchmark_name
    if benchmark_name not in benchmark_specs:
        raise NotImplementedError(f"Unsupported benchmark: {benchmark_name!r}")
    benchmark_cls, data_file, fixed_init_index = benchmark_specs[benchmark_name]
    benchmark_data_path = os.path.join(args.data_dir, data_file)

    output_dir = os.path.join(
        args.output_dir,
        f'{benchmark_name}',
        args.dataset_name
    )
    # This also creates output_dir itself, where the result JSON files are
    # written below. NOTE(review): the extra nested <dataset_name> directory
    # may be relied on by OurAlgorithm (it receives both output_dir and
    # dataset_name) — confirm before simplifying to os.makedirs(output_dir).
    os.makedirs(os.path.join(output_dir, args.dataset_name), exist_ok=True)

    benchmark = benchmark_cls(benchmark_data_path, args.dataset_name)
    hp_candidates = benchmark.get_hyperparameter_candidates()

    # Best utility attainable on this benchmark: the maximum over every
    # configuration and every point of its learning curve (epochs counted
    # from 1), floored at 0 -- the utility of spending no budget at all.
    best_utility = 0.
    for config_id in range(benchmark.nr_hyperparameters):
        curve = benchmark.get_curve(config_id, benchmark.max_budget)
        for epoch, performance in enumerate(curve, start=1):
            best_utility = max(best_utility, U(epoch, performance))

    # With --init, start from a fixed, benchmark-specific configuration.
    init_index = fixed_init_index if args.init else None

    algo = OurAlgorithm(
        init_index=init_index,
        mean=args.mean,
        eps=args.eps,
        U=U,
        y_0=benchmark.get_init_performance(),
        config_ckpt=args.config_ckpt,
        model_ckpt=args.model_ckpt,
        hp_candidates=hp_candidates,
        max_benchmark_epochs=benchmark.max_budget,
        total_budget=args.budget_limit,
        dataset_name=args.dataset_name,
        output_path=output_dir,
        seed=args.seed,
        benchmark=benchmark
    )

    incumbent = 0.
    trajectory = []
    utility_trajectory = []
    first_stop_sign = False
    # Stays at the limit if the algorithm never raises its stop sign.
    stop_budget = args.budget_limit
    budget = 0
    while budget < args.budget_limit:
        hp_index, t, stop_sign = algo.suggest()
        # Record the budget spent when the stop sign is first raised; the run
        # itself continues until the budget is exhausted.
        if stop_sign and not first_stop_sign:
            first_stop_sign = True
            stop_budget = budget

        # Evaluate up to budget_per_step consecutive epochs starting at t, but
        # never past the benchmark's last epoch nor past the overall budget
        # limit (the outer while only checks the limit between suggestions).
        for t_ in range(t, t + args.budget_per_step):
            if t_ > benchmark.max_budget or budget >= args.budget_limit:
                break

            budget += 1
            print(hp_index, t_)
            score = benchmark.get_performance(hp_index, t_)
            algo.observe(hp_index, t_, score)

            if score > incumbent:
                incumbent = score
            utility = U(budget, incumbent)
            trajectory.append(incumbent)
            utility_trajectory.append(utility)

            print(args.dataset_name, args.decay_factor, budget)

    print(stop_budget)

    utility_trajectory = np.array(utility_trajectory)
    # Log regret w.r.t. the best attainable utility. A step that reaches
    # best_utility exactly yields log(0) = -inf, which json.dump writes as the
    # non-standard literal -Infinity.
    log_regret_trajectory = np.log(best_utility - utility_trajectory)

    with open(os.path.join(output_dir, "trajectory.json"), "w") as fp:
        # float() turns NumPy scalar scores into JSON-serializable floats.
        json.dump([float(value) for value in trajectory], fp)

    with open(os.path.join(output_dir, "utility_trajectory.json"), "w") as fp:
        json.dump(utility_trajectory.tolist(), fp)

    with open(os.path.join(output_dir, "stop_budget.json"), "w") as fp:
        json.dump(stop_budget, fp)

    with open(os.path.join(output_dir, "log_regret_trajectory.json"), "w") as fp:
        json.dump(log_regret_trajectory.tolist(), fp)

if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description='DyHPO experiments.',
    )
    parser.add_argument(
        '--data_dir',
        type=str,
        default='./data',
        help='Directory containing the benchmark JSON data file.',
    )
    parser.add_argument(
        '--benchmark_name',
        type=str,
        default='lcbench',
        # Validate up front instead of failing deep inside main().
        choices=['lcbench', 'taskset', 'pd1'],
        help='Benchmark to run on.',
    )
    parser.add_argument(
        '--dataset_name',
        type=str,
        default='segment',
        help='Dataset within the benchmark; also names the output sub-directory.',
    )
    parser.add_argument(
        '--budget_limit',
        type=int,
        default=100,
        help='Total number of epochs to evaluate across all configurations.',
    )
    parser.add_argument(
        '--budget_per_step',
        type=int,
        default=1,
        help='Maximum number of consecutive epochs evaluated per suggestion.',
    )
    parser.add_argument(
        '--decay_factor',
        type=float,
        default=0.0,
        help='Per-epoch cost subtracted from performance in the utility.',
    )
    parser.add_argument(
        '--eps',
        type=float,
        default=0.5,
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        default='./output',
        help='Root directory; results go to <output_dir>/<benchmark>/<dataset>.',
    )
    parser.add_argument(
        '--model_ckpt',
        type=str,
        default=None,
    )
    parser.add_argument(
        '--config_ckpt',
        type=str,
        default=None,
    )
    parser.add_argument(
        '--mean',
        action="store_true",
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=1,
    )
    parser.add_argument(
        '--init',
        action="store_true",
        help='Start from a fixed, benchmark-specific configuration index.',
    )
    args = parser.parse_args()

    main(args)