
import mptp.utils.utils_training_loop as tlu
import mptp.constants as cst
import src.utils.utils_generic as gen
import sys
import torch
import os
import time
import argparse

from src.config import Configuration
from src.constants import LearningHyperParameter

# Seeds substituted by run_experiment when a plan entry specifies 'seed': 'all'.
DEFAULT_SEEDS = {seed for seed in range(500, 505)}
# Horizons substituted by run_experiment when a plan entry specifies 'k': 'all' (every cst.Horizons member).
DEFAULT_HORIZONS = {horizon for horizon in cst.Horizons}


def _resolve_seeds(seeds):
    """Return the seeds of a plan entry as a list.

    The string ``'all'`` expands to ``DEFAULT_SEEDS`` in ascending order; any other value is
    materialised with ``list()`` so it can be iterated more than once.
    """
    if seeds == 'all':
        return sorted(DEFAULT_SEEDS)
    return list(seeds)


def _resolve_horizons(horizons):
    """Return the horizons of a plan entry as a list.

    The string ``'all'`` expands to the members of ``DEFAULT_HORIZONS``, listed in the order they are
    declared in ``cst.Horizons``. Iterating the set directly is not reproducible: enum members hash by
    their name string, so the set's iteration order changes with the per-process string-hash seed.
    Any other value (a list of members, or the ``cst.Horizons`` class itself) is materialised with
    ``list()`` so the same horizons are reused for every seed.
    """
    if horizons == 'all':
        return [h for h in cst.Horizons if h in DEFAULT_HORIZONS]
    return list(horizons)


def _format_elapsed(seconds):
    """Format a duration in seconds as ``HH:MM:SS``, truncating fractions of a second.

    Unlike ``time.strftime("%H:%M:%S", time.gmtime(seconds))`` the hour field does not wrap
    after 24 hours, so runs longer than a day are reported correctly.
    """
    minutes, secs = divmod(int(seconds), 60)
    hours, minutes = divmod(minutes, 60)
    return "{:02d}:{:02d}:{:02d}".format(hours, minutes, secs)


def _build_config(run_name_prefix, mod, seed, k, dataset, features, wb, hpo, resume, level, alpha):
    """Create and populate the Configuration object for a single (model, seed, horizon) run.

    Side effects: calls ``tlu.set_seeds(cf)`` (presumably seeding the RNGs from ``cf.SEED`` — confirm)
    and sets the ``WANDB_MODE`` environment variable to ``"online"`` if ``wb`` else ``"offline"``.
    """
    # configuration object used throughout the whole simulation
    cf: Configuration = Configuration(run_name_prefix)
    cf.SEED = seed
    cf.HYPER_PARAMETERS[LearningHyperParameter.FI_HORIZON] = k.value
    cf.RESUME_TRAINING = resume

    if mod == cst.Models.DEEPLOBATT:
        # for this model both the forward and the backward window are set to the horizon value
        cf.HYPER_PARAMETERS[LearningHyperParameter.FORWARD_WINDOW] = k.value
        cf.HYPER_PARAMETERS[LearningHyperParameter.BACKWARD_WINDOW] = k.value

    tlu.set_seeds(cf)

    cf.CHOSEN_DATASET = dataset
    cf.LOB_LEVELS = level
    cf.ALPHA = alpha
    cf.CHOSEN_FEATURES = features
    cf.CHOSEN_MODEL = mod
    cf.IS_TUNE_H_PARAMS = hpo  # True for sweep, False for fixed

    if cf.IS_TUNE_H_PARAMS:
        print("sweeping")

    cf.IS_WANDB = wb
    os.environ["WANDB_MODE"] = "online" if wb else "offline"
    return cf


def run_experiment(execution_plan, dataset, features, wb=False, hpo=False, resume=False, level=-1, alpha=0.00002, run_name_prefix="CHF-EXPERIMENTS"):
    """Build a Configuration for every (model, seed, horizon) assigned to this machine and run it.

    :param execution_plan: mapping server -> list of ``(model, plan)`` tuples, where ``plan`` is a dict
        with keys ``'seed'`` and ``'k'``; each value is either the string ``'all'`` or an iterable of
        seeds / horizon enum members respectively.
    :param dataset: stored in ``cf.CHOSEN_DATASET``.
    :param features: stored in ``cf.CHOSEN_FEATURES``.
    :param wb: stored in ``cf.IS_WANDB``; also selects ``WANDB_MODE`` = ``"online"`` / ``"offline"``.
    :param hpo: stored in ``cf.IS_TUNE_H_PARAMS`` (True for a sweep, False for fixed hyper-parameters).
    :param resume: stored in ``cf.RESUME_TRAINING``.
    :param level: stored in ``cf.LOB_LEVELS``.
    :param alpha: stored in ``cf.ALPHA``.
    :param run_name_prefix: passed to ``tlu.experiment_preamble``, which returns the prefix actually used.

    The plan for this machine is ``execution_plan[server_name]``, where ``server_name`` is returned by
    ``tlu.experiment_preamble``. On ``KeyboardInterrupt`` the interrupted run is reported and the
    process exits with status 130.
    """
    servers = list(execution_plan)

    run_name_prefix, server_name, _, _ = tlu.experiment_preamble(run_name_prefix, servers)
    launches_server = execution_plan[server_name]  # the execution plan for this machine

    # iterates over the models and the plans assigned to this machine (i.e. seeds and horizons)
    for mod, plan in launches_server:
        seeds = _resolve_seeds(plan['seed'])
        horizons = _resolve_horizons(plan['k'])

        for see in seeds:
            for k in horizons:
                print("Running {} experiment on {}, with K={}, features={}".format(dataset, mod, k, features))

                try:
                    cf = _build_config(run_name_prefix, mod, see, k, dataset, features,
                                       wb, hpo, resume, level, alpha)

                    start_time = time.time()

                    # run the experiment
                    tlu.run(cf)

                    t = _format_elapsed(time.time() - start_time)
                    print("Execution time for model {}, dataset {}, features {}, horizon {}, seed {}: {}".format(mod, dataset, features, k.value, see, t))

                except KeyboardInterrupt:
                    print("There was a problem running on", server_name.name, "{} experiment on {}, with K={}".format(dataset, mod, k))
                    # 128 + SIGINT: the conventional status of an interrupted process, so wrapper
                    # scripts do not mistake an aborted sweep for a successful one (plain exit() is 0)
                    sys.exit(130)


def main(argv=None):
    """Parse the command line, build the execution plan and run the experiments.

    :param argv: arguments to parse; ``None`` (the default) parses ``sys.argv[1:]``.

    Invalid model, ``--dataset`` or ``--features`` values are reported as argparse usage errors
    (exit status 2) before any CUDA probing or training starts.
    """
    # The execution plan lists all the desired sequential launches, with the format
    #   KEY:   server name in the cst.Servers enum,
    #   VALUE: list of tuples (model, plan) where
    #          model = member of the cst.Models enum,
    #          plan  = dict {'k':    'all' or list of cst.Horizons members,
    #                        'seed': 'all' or list of integers used as random seeds}
    # e.g. {cst.Servers.ANY: [(cst.Models.DLA, {'k': cst.Horizons, 'seed': [1]})]}

    parser = argparse.ArgumentParser(description='train models')
    parser.add_argument('models', type=str, nargs='+', help='list of models to train')
    parser.add_argument('--horizons', type=gen.list_of_int, default=[1, 2, 3, 5, 10], help='select from horizons 1,2,3,5,10')
    parser.add_argument('--dataset', type=str, default='CHF', help='either dataset FI or CHF')
    parser.add_argument('--features', type=str, default='basic', help='either basic, insens, or all features')
    parser.add_argument('--seeds', type=gen.list_of_int, default=[1], help='randomization seed')
    parser.add_argument('--save_dir', type=str, default='test', help='run name prefix / save directory of the experiments')
    parser.add_argument('--loblevel', type=int, default=-1, help='number of lob levels, -1 for all levels')
    # NOTE(review): alpha is forwarded as a string, whereas run_experiment's default is the float
    # 0.00002 — confirm which type the consumers of Configuration.ALPHA expect before changing it.
    parser.add_argument('--alpha', type=str, default='0.00002', help='set alpha of dataset')
    parser.add_argument('--hpo', action='store_true', help='do hyperparameter optimization or not')
    parser.add_argument('--resume', action='store_true', help='resume training')
    parser.add_argument('--no_wb', action='store_true', help='do not use wandb')
    args = parser.parse_args(argv)

    def to_enum(enum_cls, value, what):
        """Convert ``value`` to a member of ``enum_cls``; an unknown value becomes a usage error."""
        try:
            return enum_cls(value)
        except ValueError:
            parser.error("invalid {}: {!r}".format(what, value))

    # converted up front so a typo is reported immediately instead of as a traceback
    models = [to_enum(cst.Models, name, 'model') for name in args.models]
    dataset = to_enum(cst.DatasetFamily, args.dataset, '--dataset')
    features = to_enum(cst.Features, args.features, '--features')

    # probed only after parsing, so `--help` and argument errors never touch CUDA
    if torch.cuda.is_available():
        print("GPUs available: {}".format(torch.cuda.device_count()))
        torch.cuda.empty_cache()
    else:
        print("GPU not available")

    horizons = gen.hor_to_enum(args.horizons)
    exe_plan = {
        cst.Servers.ANY: [(model, {'k': horizons, 'seed': args.seeds}) for model in models]
    }

    exp_args = {
        'dataset': dataset,
        'features': features,
        'run_name_prefix': args.save_dir,
        'wb': not args.no_wb,
        'hpo': args.hpo,
        'resume': args.resume,
        'level': args.loblevel,
        'alpha': args.alpha,
    }

    print(exe_plan)
    print(exp_args)

    run_experiment(exe_plan, **exp_args)


if __name__ == '__main__':
    main()
