import os
import sys
import torch

# Cluster preamble: put the current working directory on the import path
# (once) so the `src.*` imports below can be resolved from it.
module_path = os.path.abspath(os.getcwd())
if all(entry != module_path for entry in sys.path):
    sys.path.append(module_path)

from src.utils.utils_training_loop import *
from src.utils.utils_dataset import pick_dataset
from src.utils.utils_models import pick_model
from src.utils.utils_generic import make_dir
import src.models.model_callbacks as cbk


def _apply_dataset(cf, dataset):
    """Set dataset family, period and train/test stocks on ``cf`` for FI or CHF.

    ``dataset`` may be the string ``'FI'``/``'CHF'`` or the matching
    ``cst.DatasetFamily`` member, so callers may pass either form.
    Exits the process (non-zero status) for any other value.
    """
    if dataset == 'FI' or dataset == cst.DatasetFamily.FI:
        family, period, stocks = cst.DatasetFamily.FI, cst.Periods.FI, cst.Stocks.FI
    elif dataset == 'CHF' or dataset == cst.DatasetFamily.CHF:
        family, period, stocks = cst.DatasetFamily.CHF, cst.Periods.CHF, cst.Stocks.CHF
    else:
        # sys.exit(<message>) prints to stderr and exits with status 1; the
        # previous bare sys.exit() reported success (status 0) on this error.
        sys.exit("no dataset chosen: unsupported dataset {!r}".format(dataset))
    cf.CHOSEN_DATASET = family
    cf.CHOSEN_PERIOD = period
    cf.CHOSEN_STOCKS[cst.STK_OPEN.TRAIN] = stocks
    cf.CHOSEN_STOCKS[cst.STK_OPEN.TEST] = stocks


def _run_dir_name(cf):
    """Return the run sub-directory name (trailing '/') encoding the configuration in ``cf``."""
    hp = cf.HYPER_PARAMETERS
    return "model={}-seed={}-trst={}-test={}-data={}-features={}-peri={}-bw={}-fw={}-fiw={}/".format(
        cf.CHOSEN_MODEL.name,
        cf.SEED,
        cf.CHOSEN_STOCKS[cst.STK_OPEN.TRAIN].name,
        cf.CHOSEN_STOCKS[cst.STK_OPEN.TEST].name,
        cf.CHOSEN_DATASET.value,
        cf.CHOSEN_FEATURES,
        cf.CHOSEN_PERIOD.name,
        hp[cst.LearningHyperParameter.BACKWARD_WINDOW],
        hp[cst.LearningHyperParameter.FORWARD_WINDOW],
        hp[cst.LearningHyperParameter.FI_HORIZON],
    )


def core_test(seed, model, dataset, features, src_data, out_data, horizon=None, win_back=None, win_forward=None, target_dataset_meta=cst.DatasetFamily.FI):
    """Test a saved checkpoint and record its inference-time statistics.

    Builds a test ``Configuration`` for the given seed/model/dataset/features,
    locates the single checkpoint file in ``src_data/<run dir>/``, runs
    ``Trainer.test`` with it, then runs ``Trainer.predict`` with batch size 2
    and stores the mean/std of the returned values as ``inference_mean`` /
    ``inference_std`` in ``cf.METRICS_JSON``, which is finally closed into
    ``out_data``.

    Args:
        seed: random seed; stored in ``cf.SEED`` and part of the run dir name.
        model: ``cst.Models`` member selecting the architecture.
        dataset: ``'FI'``/``'CHF'`` or the matching ``cst.DatasetFamily`` member.
        features: stored in ``cf.CHOSEN_FEATURES`` (and in the run dir name).
        src_data: root directory holding the saved-model run directories.
        out_data: directory passed to ``make_dir`` and then to
            ``METRICS_JSON.close`` for writing the metrics.
        horizon: optional enum whose ``.value`` sets ``FI_HORIZON``.
        win_back: optional enum whose ``.value`` sets ``BACKWARD_WINDOW``.
        win_forward: optional enum whose ``.value`` sets ``FORWARD_WINDOW``.
        target_dataset_meta: currently unused; kept for call-site compatibility.

    Raises:
        SystemExit: if ``dataset`` is neither FI nor CHF.
        AssertionError: if the run directory does not hold exactly one
            (non-hidden) file.
    """
    cf: Configuration = Configuration(is_test=True)
    cf.SEED = seed
    set_seeds(cf)

    cf.CHOSEN_FEATURES = features

    # Each override is applied independently: previously win_forward was only
    # read inside the win_back branch and crashed when it was left as None.
    if win_back is not None:
        cf.HYPER_PARAMETERS[cst.LearningHyperParameter.BACKWARD_WINDOW] = win_back.value
    if win_forward is not None:
        cf.HYPER_PARAMETERS[cst.LearningHyperParameter.FORWARD_WINDOW] = win_forward.value
    if horizon is not None:
        cf.HYPER_PARAMETERS[cst.LearningHyperParameter.FI_HORIZON] = horizon.value

    _apply_dataset(cf, dataset)

    cf.IS_WANDB = 0
    cf.IS_TUNE_H_PARAMS = False
    cf.CHOSEN_MODEL = model

    # The run directory is named from the config *before* the model's fixed
    # hyper-parameters are applied below.
    dir_name = _run_dir_name(cf)
    run_dir = os.path.join(src_data, dir_name)

    files = [f for f in os.listdir(run_dir) if not f.startswith('.')]
    print(files)
    # Explicit raise instead of `assert` so the check also runs under `python -O`.
    if len(files) != 1:
        raise AssertionError(
            'We expect that in the folder there is only the checkpoint with the highest F1-score:\n{}'.format(files))

    print("OK")
    print(seed, model, horizon, dir_name)

    file_name = files[0]

    # Override hyper-parameters with the model's fixed FI settings.
    model_params = HP_DICT_MODEL[cf.CHOSEN_MODEL].fixed_fi
    for param in cst.LearningHyperParameter:
        if param.value in model_params:
            cf.HYPER_PARAMETERS[param] = model_params[param.value]
            print(model_params[param.value])

    datamodule = pick_dataset(cf)
    # Named `net` so the `model` argument (the enum) is not shadowed.
    net = pick_model(cf, datamodule)

    max_predict_batches = 500
    trainer = Trainer(
        accelerator=cst.DEVICE_TYPE,
        devices=cst.NUM_GPUS,
        limit_predict_batches=max_predict_batches,
        callbacks=[cbk.new_progress_bar()],
    )

    checkpoint_file_path = os.path.join(run_dir, file_name)
    print("opening", checkpoint_file_path)
    trainer.test(model=net, datamodule=datamodule, ckpt_path=checkpoint_file_path)

    print("done testing")
    print("start eval ")
    # Measure inference time of the best model with batches of size 2.
    datamodule.batch_size = 2
    if cf.CHOSEN_MODEL == cst.Models.TRANSLOB:
        # TransLOB is rebuilt after changing BATCH_SIZE — presumably its
        # construction depends on the batch size; confirm in pick_model.
        cf.HYPER_PARAMETERS[cst.LearningHyperParameter.BATCH_SIZE] = 2
        net = pick_model(cf, datamodule)
    prediction_time = trainer.predict(net, dataloaders=datamodule.test_dataloader(), ckpt_path=checkpoint_file_path)
    prediction_time_mean, prediction_time_std = np.mean(prediction_time), np.std(prediction_time)
    cf.METRICS_JSON.update_metrics(cf.CHOSEN_STOCKS[cst.STK_OPEN.TRAIN].name, {'inference_mean': prediction_time_mean,
                                                                               'inference_std': prediction_time_std})

    make_dir(out_data)
    print("writing data")
    cf.METRICS_JSON.close(out_data)


# def launch_lobster_test(seeds, model_todo, models_to_avoid, dataset_type, backwards, forwards, src_data, out_data, target_dataset_meta=None):
#     for s in seeds:
#         for model in model_todo:
#             for i in range(len(backwards)):
#                 assert len(backwards) == len(forwards)
#                 km, kp = backwards[i], forwards[i]

#                 if model in set(model_todo) - set(models_to_avoid):
#                     core_test(s, model, dataset_type, src_data, out_data, win_back=km, win_forward=kp, target_dataset_meta=cst.DatasetFamily.LOB)


def launch_FI_CHF_test(seeds, model_todo, models_to_avoid, kset, dataset, features, src_data, out_data):
    """Run ``core_test`` for every (seed, horizon, model) combination.

    Models listed in ``models_to_avoid`` are skipped. For
    ``cst.Models.DEEPLOBATT`` the horizon ``k`` is also passed as both the
    backward and forward window.

    Args:
        seeds: iterable of random seeds.
        model_todo: ``cst.Models`` members to evaluate, in order.
        models_to_avoid: ``cst.Models`` members to skip.
        kset: horizons to evaluate (each passed as ``horizon``).
        dataset, features, src_data, out_data: forwarded to ``core_test``.
    """
    # Built once: `model` always comes from model_todo, so the original
    # per-iteration `set(model_todo) - set(models_to_avoid)` test reduces to
    # "not in models_to_avoid".
    excluded = set(models_to_avoid)
    for s in seeds:
        for k in kset:
            for model in model_todo:
                if model in excluded:
                    continue
                windows = {'win_back': k, 'win_forward': k} if model == cst.Models.DEEPLOBATT else {}
                core_test(s, model, dataset, features, src_data, out_data,
                          target_dataset_meta=cst.DatasetFamily.CHF,
                          horizon=k, **windows)


if __name__ == "__main__":
    # Release any memory held by the CUDA caching allocator before starting.
    torch.cuda.empty_cache()

    model_todo = [cst.Models.BINCTABL]
    models_to_avoid = []
    seeds = [5]
    kset = [cst.Horizons.K1]

    args = {
        # core_test selects the dataset by comparing against 'FI' / 'CHF';
        # pass the string so the selection matches regardless of the enum type.
        'dataset': 'CHF',
        'features': cst.Features.nonlob,
        'src_data': "data/saved_models/LOB-CLASSIFIERS-(test2)/",
        'out_data': "final_jsons/",
    }

    launch_FI_CHF_test(seeds, model_todo, models_to_avoid, kset, **args)
