
import sys
import os

# Make the current working directory importable, so that the project-local
# `src.*` imports below resolve (presumably the script is launched from the
# project root — confirm against how it is invoked).
module_path = os.path.abspath(os.getcwd())
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

import argparse
import random
import numpy as np
import wandb
import traceback
import socket
import torch
import time
import psutil
import itertools

# TORCH
from pytorch_lightning import Trainer
from pytorch_lightning import seed_everything
from src.constants import LearningHyperParameter


import src.constants as cst
import src.models.model_callbacks as cbk
from src.config import Configuration

# MODELS
from src.models.mlp.mlp_param_search import HP_MLP, HP_MLP_FI_FIXED, HP_MLP_CHF_FIXED
from src.models.tabl.tabl_param_search import HP_TABL, HP_TABL_FI_FIXED, HP_TABL_CHF_FIXED
from src.models.translob.tlb_param_search import HP_TRANS, HP_TRANS_FI_FIXED, HP_TRANS_CHF_FIXED
from src.models.cnn1.cnn1_param_search import HP_CNN1, HP_CNN1_FI_FIXED, HP_CNN1_CHF_FIXED
from src.models.cnn2.cnn2_param_search import HP_CNN2, HP_CNN2_FI_FIXED, HP_CNN2_CHF_FIXED
from src.models.cnnlstm.cnnlstm_param_search import HP_CNNLSTM, HP_CNNLSTM_FI_FIXED, HP_CNNLSTM_CHF_FIXED
from src.models.dain.dain_param_search import HP_DAIN, HP_DAIN_FI_FIXED, HP_DAIN_CHF_FIXED
from src.models.deeplob.dlb_param_search import HP_DEEP, HP_DEEP_FI_FIXED, HP_DEEP_CHF_FIXED
from src.models.lstm.lstm_param_search import HP_LSTM, HP_LSTM_FI_FIXED, HP_LSTM_CHF_FIXED
from src.models.binctabl.binctabl_param_search import HP_BINTABL, HP_BINTABL_FI_FIXED, HP_BINTABL_CHF_FIXED
from src.models.deeplobatt.dlbatt_param_search import HP_DEEPATT, HP_DEEPATT_FI_FIXED, HP_DEEPATT_CHF_FIXED
from src.models.dla.dla_param_search import HP_DLA, HP_DLA_FI_FIXED, HP_DLA_CHF_FIXED
from src.models.tlonbof.tlonbof_param_search import HP_TLONBoF, HP_TLONBoF_FI_FIXED, HP_TLONBoF_CHF_FIXED

from src.utils.utils_dataset import pick_dataset
from src.utils.utils_models import pick_model
from collections import namedtuple

# Triple of hyperparameter dictionaries attached to each model:
# the sweep search space and the fixed parameters for the FI and CHF datasets.
HPSearchTypes = namedtuple("HPSearchTypes", "sweep fixed_fi fixed_chf")

# MAPS every model to 3 dictionaries of parameters:
#
# HPSearchTypes.sweep:     search space for the hyperparameter sweep (each entry
#                          is a WANDB parameter spec whose 'values' list is
#                          enumerated by `run`)
# HPSearchTypes.fixed_fi:  fixed parameters for the FI dataset
# HPSearchTypes.fixed_chf: fixed parameters for the INE dataset of CHF 2023
#
# When no hyperparameter search is done, `__run_training_loop` picks
# `fixed_fi` or `fixed_chf` according to `config.CHOSEN_DATASET`.
#
HP_DICT_MODEL = {
    cst.Models.MLP:  HPSearchTypes(HP_MLP, HP_MLP_FI_FIXED, HP_MLP_CHF_FIXED),
    cst.Models.CNN1: HPSearchTypes(HP_CNN1, HP_CNN1_FI_FIXED, HP_CNN1_CHF_FIXED),
    cst.Models.CNN2: HPSearchTypes(HP_CNN2, HP_CNN2_FI_FIXED, HP_CNN2_CHF_FIXED),
    cst.Models.LSTM: HPSearchTypes(HP_LSTM, HP_LSTM_FI_FIXED, HP_LSTM_CHF_FIXED),
    cst.Models.CNNLSTM: HPSearchTypes(HP_CNNLSTM, HP_CNNLSTM_FI_FIXED, HP_CNNLSTM_CHF_FIXED),
    cst.Models.DAIN: HPSearchTypes(HP_DAIN, HP_DAIN_FI_FIXED, HP_DAIN_CHF_FIXED),
    cst.Models.DEEPLOB: HPSearchTypes(HP_DEEP, HP_DEEP_FI_FIXED, HP_DEEP_CHF_FIXED),
    cst.Models.TRANSLOB: HPSearchTypes(HP_TRANS, HP_TRANS_FI_FIXED, HP_TRANS_CHF_FIXED),
    cst.Models.CTABL: HPSearchTypes(HP_TABL, HP_TABL_FI_FIXED, HP_TABL_CHF_FIXED),
    cst.Models.BINCTABL: HPSearchTypes(HP_BINTABL, HP_BINTABL_FI_FIXED, HP_BINTABL_CHF_FIXED),
    cst.Models.DEEPLOBATT: HPSearchTypes(HP_DEEPATT, HP_DEEPATT_FI_FIXED, HP_DEEPATT_CHF_FIXED),
    cst.Models.DLA: HPSearchTypes(HP_DLA, HP_DLA_FI_FIXED, HP_DLA_CHF_FIXED),
    cst.Models.TLONBoF: HPSearchTypes(HP_TLONBoF, HP_TLONBoF_FI_FIXED, HP_TLONBoF_CHF_FIXED),
}


def __run_training_loop(config: Configuration, model_params=None):
    """ Set the model hyperparameters on ``config`` and launch the training loop.

    The loop fits the model, re-validates it on the best checkpoint and, unless
    a hyperparameter search is running, also tests it.

    Args:
        config: the experiment configuration. It is mutated in place: the chosen
            hyperparameters are written into ``config.HYPER_PARAMETERS`` and
            ``config.dynamic_config_setup()`` is called.
        model_params: mapping from ``LearningHyperParameter`` values to the values
            to use (e.g. a WANDB sweep config or one combination of the local grid).
            Must be ``None`` when ``config.IS_TUNE_H_PARAMS`` is False; in that case
            the fixed parameters of the chosen model/dataset are used instead.

    Any ``Exception`` raised while training is printed on stderr and the process
    terminates with exit code 1.
    """

    def core(config, model_params):

        def hp(param):
            # Read lazily from config, in case the setup steps below replace the dict.
            return config.HYPER_PARAMETERS[param]

        # if no hyperparameter tuning must be done, use the fixed parameters
        if not config.IS_TUNE_H_PARAMS:
            if model_params is not None:
                raise ValueError("model_params must be None when IS_TUNE_H_PARAMS is False")

            if config.CHOSEN_DATASET == cst.DatasetFamily.FI:
                print("chose fixed fi params")
                model_params = HP_DICT_MODEL[config.CHOSEN_MODEL].fixed_fi

            elif config.CHOSEN_DATASET == cst.DatasetFamily.CHF:
                print("chose fixed chf params")
                model_params = HP_DICT_MODEL[config.CHOSEN_MODEL].fixed_chf

            else:
                # Otherwise model_params would stay None and the membership test
                # below would fail with an obscure TypeError.
                raise ValueError(f"No fixed hyperparameters for dataset {config.CHOSEN_DATASET}")

        if model_params is None:
            raise ValueError("No hyperparameters to apply: model_params is None")

        # SET hyperparameter in the config object
        for param in LearningHyperParameter:
            if param.value in model_params:
                config.HYPER_PARAMETERS[param] = model_params[param.value]

        print("Set model parameters!!!", model_params)

        config.dynamic_config_setup()

        # vvv TRAINING LOOP vvv
        print('picking dataset')
        data_module = pick_dataset(config)    # load the data
        print("input shape: {}, batches: {}".format(data_module.x_shape, data_module.batch_size))
        print('picking model')
        nn = pick_model(config, data_module)  # load the model
        print(
            f"Model: {config.CHOSEN_MODEL}, "
            f"lr: {hp(LearningHyperParameter.LEARNING_RATE)}, "
            f"optimizer: {hp(LearningHyperParameter.OPTIMIZER)}, "
            f"batch size: {hp(LearningHyperParameter.BATCH_SIZE)}, "
            f"snapshots: {hp(LearningHyperParameter.NUM_SNAPSHOTS)}"
        )

        if config.IS_WANDB:
            config.WANDB_INSTANCE.log({"learning_rate": hp(LearningHyperParameter.LEARNING_RATE)})
            config.WANDB_INSTANCE.log({"batch_size": hp(LearningHyperParameter.BATCH_SIZE)})
            config.WANDB_INSTANCE.log({"optimizer": hp(LearningHyperParameter.OPTIMIZER)})
            config.WANDB_INSTANCE.log({"num_snapshots": hp(LearningHyperParameter.NUM_SNAPSHOTS)})

        torch.cuda.empty_cache()
        trainer = Trainer(
            accelerator=cst.DEVICE_TYPE,
            devices=cst.NUM_GPUS,
            check_val_every_n_epoch=config.VALIDATE_EVERY,
            max_epochs=hp(LearningHyperParameter.EPOCHS_UB),
            callbacks=[
                cbk.callback_save_model(config, config.WANDB_RUN_NAME),
                cbk.early_stopping(config),
                cbk.new_progress_bar()
            ],
            # strategy="ddp"  # needed to prevent pickling
        )

        # RAM usage before training (one snapshot, read by field name)
        memory = psutil.virtual_memory()
        print('RAM memory % used:', memory.percent)
        print('RAM Used (GB):', memory.used / 1000000000)

        start = time.time()

        # TRAINING STEP
        print("training")
        nn.testing_mode = cst.ModelSteps.VALIDATION_EPOCH
        # The checkpoint path is only built when resuming (WANDB_SWEEP_NAME may be unset otherwise).
        ckpt_dir = cst.DIR_SAVED_MODEL + config.WANDB_SWEEP_NAME if config.RESUME_TRAINING else None
        if ckpt_dir is not None and os.path.exists(ckpt_dir):
            print("checkpoint loading...")
            files = [f for f in os.listdir(ckpt_dir) if not f.startswith('.')]
            print(files)
            if len(files) > 1:
                print("more than 1 checkpoint file...")
            trainer.fit(nn, data_module, ckpt_path="last")  # resume training
        else:
            print("no checkpoints")
            trainer.fit(nn, data_module)  # new training run

        # FINAL VALIDATION STEP (on the best checkpoint)
        print("validating")
        nn.testing_mode = cst.ModelSteps.VALIDATION_MODEL
        trainer.validate(nn, dataloaders=data_module.val_dataloader(), ckpt_path="best")

        # TEST STEP (skipped during hyperparameter search)
        if config.IS_TUNE_H_PARAMS:
            print("no testing because hpo")
        else:
            print("testing")
            nn.testing_mode = cst.ModelSteps.TESTING
            trainer.test(nn, dataloaders=data_module.test_dataloader(), ckpt_path="best")
            config.METRICS_JSON.close(cst.DIR_EXPERIMENTS)

        t = time.strftime("%H:%M:%S", time.gmtime(time.time() - start))
        print(
            f"Execution time for 1 training pass for model {config.CHOSEN_MODEL}, "
            f"horizon {hp(LearningHyperParameter.FI_HORIZON)}, "
            f"learning rate {hp(LearningHyperParameter.LEARNING_RATE)}, "
            f"batch size {hp(LearningHyperParameter.BATCH_SIZE)}: {t}"
        )

    try:
        core(config, model_params)
    except Exception:
        # Not a bare `except:` so that KeyboardInterrupt / SystemExit still propagate.
        print("The following error was raised:", file=sys.stderr)
        traceback.print_exc()  # writes to sys.stderr itself and returns None
        sys.exit(1)


def run(config: Configuration):
    """ Run the experiment(s) described by ``config``.

    * Hyperparameter search with WANDB: register a sweep over the model's search
      space and run an agent for at most as many runs as there are combinations.
    * Hyperparameter search without WANDB: iterate every combination locally.
    * No search: a single training run with the fixed parameters, opened as a
      WANDB run when ``config.IS_WANDB`` is set.
    """

    def _wandb_exe(config: Configuration):
        """ Open a WANDB run, log the experiment details and train inside it. """
        with wandb.init(project=cst.PROJECT_NAME, name=None) as wandb_instance:
            # experiment details shown in the WANDB console
            wandb_instance.log_code("src/")
            wandb_instance.log({"model": config.CHOSEN_MODEL.name})
            wandb_instance.log({"seed": config.SEED})
            wandb_instance.log({"features": config.CHOSEN_FEATURES.name})
            wandb_instance.log({"alpha": cst.ALPHA})
            wandb_instance.log({"fi-k": config.HYPER_PARAMETERS[cst.LearningHyperParameter.FI_HORIZON]})

            config.WANDB_RUN_NAME = wandb_instance.name
            config.WANDB_INSTANCE = wandb_instance

            # only a sweep run carries hyperparameters chosen by WANDB
            chosen_params = wandb_instance.config if config.IS_TUNE_H_PARAMS else None
            __run_training_loop(config, chosen_params)

    # initialize the simulation before picking the execution mode
    config.dynamic_config_setup()

    if not config.IS_TUNE_H_PARAMS:
        # NO SWEEP: a single run with the fixed parameters
        print("no sweep")
        if config.IS_WANDB:
            _wandb_exe(config)
        else:
            __run_training_loop(config, None)
        return

    print("sweep")
    search_space = HP_DICT_MODEL[config.CHOSEN_MODEL].sweep

    if config.IS_WANDB:
        sweep_name = "model={}-data={}-features={}-horizon={}-seed={}".format(
            config.CHOSEN_MODEL,
            config.CHOSEN_DATASET,
            config.CHOSEN_FEATURES,
            config.HYPER_PARAMETERS[cst.LearningHyperParameter.FI_HORIZON],
            config.SEED,
        )
        sweep_definition = {
            'command': ["${env}", "python3", "${program}", "${args}"],
            'program': "src/utils_training_loop.py",
            'name': sweep_name,
            'method': config.SWEEP_METHOD,
            'metric': config.SWEEP_METRIC,
            'parameters': dict(search_space),
        }
        sweep_id = wandb.sweep(
            sweep=sweep_definition,
            entity='step_cheng',    # specify desired entity
            project=cst.PROJECT_NAME,
        )

        # upper bound on the number of runs: size of the full grid
        max_count = 1
        for spec in search_space.values():
            max_count *= len(spec['values'])
        print(f"max runs: {max_count}")
        wandb.agent(sweep_id, function=lambda: _wandb_exe(config), count=max_count)
        return

    print('no wandb')
    names = list(search_space.keys())
    grids = [search_space[name]['values'] for name in names]
    for combination in itertools.product(*grids):
        __run_training_loop(config, dict(zip(names, combination)))


def set_seeds(config: Configuration):
    """ Seed every random generator used by the experiment with ``config.SEED``.

    Seeds Lightning (``seed_everything``), NumPy's global generator and Python's
    ``random`` module, then stores a dedicated ``np.random.RandomState`` built
    from the same seed on ``config.RANDOM_GEN_DATASET``.
    """
    seed = config.SEED
    seed_everything(seed)
    np.random.seed(seed)
    random.seed(seed)
    config.RANDOM_GEN_DATASET = np.random.RandomState(seed)


def experiment_preamble(run_name_prefix, servers):
    """
    Resolve which server this process runs on and return its experiment coordinates.

    The same code is launched on several machines, identified by their hostname;
    each machine then executes its share of the execution plan.

    Args:
        run_name_prefix: prefix for run names. When ``None`` it is read from the
            ``--run_name_prefix`` command-line option (which defaults to ``None``).
        servers: the servers taking part in the experiment, or ``None`` to use
            every server listed in ``cst.server2hostname``.

    Returns:
        ``(run_name_prefix, server_name, server_id, n_servers)``, where
        ``server_id`` is this machine's index in ``servers`` (0 when running as
        ``cst.Servers.ANY``) and ``n_servers`` is the number of servers.

    Raises:
        ValueError: if this machine's hostname is not among ``servers`` and
            ``cst.Servers.ANY`` is not in ``servers``.
    """

    if run_name_prefix is None:
        parser = argparse.ArgumentParser(description='Stock Price Experiment FI:')
        parser.add_argument('-run_name_prefix', '--run_name_prefix', default=None)
        args = vars(parser.parse_args())
        run_name_prefix = args["run_name_prefix"]

    hostname = socket.gethostname()

    if servers is None:
        servers = cst.server2hostname  # iterating the mapping yields its server keys
    n_servers = len(servers)
    # NOTE(review): every entry of `servers` (including cst.Servers.ANY, if present)
    # must have a key in cst.server2hostname, or this raises KeyError before the
    # ANY branch below is reached — confirm.
    servers_hostname = [cst.server2hostname[s] for s in servers]   # list of hostnames

    if hostname in servers_hostname:
        server_name = cst.hostname2server[hostname]
        server_id = servers_hostname.index(hostname)

    elif cst.Servers.ANY in servers:
        server_name = cst.Servers.ANY
        server_id = 0

    else:
        # Raising a plain string is itself a TypeError in Python 3; raise a real exception.
        raise ValueError(f"This SERVER is not handled for the experiment (hostname: {hostname}).")

    print("Running on server", server_name.name)
    return run_name_prefix, server_name, server_id, n_servers
