import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from data_provider.data_factory import data_provider
from task_modules.ltf_module import LTFModule
from configs.ltf_config import *
from model.utils import get_csv_logger

def ltf_experiment(config, gpus, current_epoch, seed=2024):
    """Train, validate and test one long-term-forecasting (LTF) model.

    Builds train/val/test dataloaders with ``data_provider``, trains an
    ``LTFModule`` with checkpointing and early stopping on ``val_mse``, then
    evaluates the best checkpoint (lowest ``val_mse``) on the test split.

    Args:
        config: Experiment config object. It is passed to ``data_provider``
            and ``LTFModule``; this function itself reads ``config.name``,
            ``config.pred_len`` (for the log directory name) and
            ``config.grad_clip_val``.
        gpus: Forwarded to ``pl.Trainer(devices=...)``.
        current_epoch: Maximum number of training epochs (``max_epochs``).
            Early stopping on ``val_mse`` may end training sooner.
        seed: Global seed passed to ``pl.seed_everything`` (defaults to the
            previously hard-coded 2024).

    Returns:
        The metrics dict for the (first) test dataloader as returned by
        ``trainer.test``; it contains ``test_mse`` and ``test_mae``.
    """
    pl.seed_everything(seed)

    _, train_dl = data_provider(config, "train")
    _, val_dl = data_provider(config, "val")
    _, test_dl = data_provider(config, "test")

    model = LTFModule(config)

    monitor_metric = "val_mse"
    # Keep only the single best checkpoint by validation MSE.
    ckpt_callback = ModelCheckpoint(monitor=monitor_metric,
                                    save_top_k=1,
                                    mode="min")
    # Stop after 10 epochs without improvement in validation MSE.
    es_callback = EarlyStopping(monitor=monitor_metric,
                                mode="min",
                                patience=10)
    callbacks = [ckpt_callback, es_callback]

    logger = get_csv_logger("logs/ltf",
                            name=f"{config.name}_{config.pred_len}")

    trainer = pl.Trainer(devices=gpus,
                         accelerator="gpu",
                         precision=32,
                         callbacks=callbacks,
                         logger=logger,
                         max_epochs=current_epoch,
                         gradient_clip_val=config.grad_clip_val,
                         enable_progress_bar=True,
                         enable_model_summary=True,
                         strategy='ddp_find_unused_parameters_true')

    trainer.fit(model, train_dl, val_dl)

    # Evaluate the best checkpoint tracked by ModelCheckpoint rather than the
    # final in-memory weights: with early-stopping patience of 10, the last
    # epoch can be up to 10 epochs past the best validation point.
    test_results = trainer.test(model, dataloaders=test_dl, ckpt_path="best")
    metrics = test_results[0]
    print(metrics['test_mse'], metrics['test_mae'])
    return metrics



def run_ltf(args):
    """Run LTF experiments for every requested (dataset, prediction length) pair.

    Args:
        args: Parsed CLI namespace. This function reads:
            ``args.dataset``: iterable of dataset names (e.g. ``"etth1"``),
                case-insensitive, or ``"all"`` for every registered dataset;
            ``args.pred_len``: iterable of prediction lengths (e.g. ``"96"``
                or ``96``), or ``"all"`` for 96/192/336/720;
            ``args.gpus``: forwarded to ``ltf_experiment``.

    Each pair is run exactly once (max 200 epochs) with the config class
    registered under ``"<dataset>_<pred_len>"``, instantiated as
    ``config_cls(pred_len)``.

    Raises:
        KeyError: if a requested dataset/prediction-length combination has no
            registered config class (checked before any training starts).
        ValueError: if a prediction length is neither ``"all"`` nor an integer.
    """
    dataset_dict = {
        'exchange_96': Exchange_LTFConfig_96,
        'exchange_192': Exchange_LTFConfig_192,
        'exchange_336': Exchange_LTFConfig_336,
        'exchange_720': Exchange_LTFConfig_720,

        'etth1_96': ETTh1_LTFConfig_96,
        'etth1_192': ETTh1_LTFConfig_192,
        'etth1_336': ETTh1_LTFConfig_336,
        'etth1_720': ETTh1_LTFConfig_720,

        'etth2_96': ETTh2_LTFConfig_96,
        'etth2_192': ETTh2_LTFConfig_192,
        'etth2_336': ETTh2_LTFConfig_336,
        'etth2_720': ETTh2_LTFConfig_720,

        'ettm1_96': ETTm1_LTFConfig_96,
        'ettm1_192': ETTm1_LTFConfig_192,
        'ettm1_336': ETTm1_LTFConfig_336,
        'ettm1_720': ETTm1_LTFConfig_720,

        'ettm2_96': ETTm2_LTFConfig_96,
        'ettm2_192': ETTm2_LTFConfig_192,
        'ettm2_336': ETTm2_LTFConfig_336,
        'ettm2_720': ETTm2_LTFConfig_720,

        'weather_96': Weather_LTFConfig_96,
        'weather_192': Weather_LTFConfig_192,
        'weather_336': Weather_LTFConfig_336,
        'weather_720': Weather_LTFConfig_720,
    }

    # Case-insensitive. dict.fromkeys de-duplicates while keeping the user's
    # order; a set would make the run order depend on string hash randomization.
    dataset_args = list(dict.fromkeys(str(d).lower() for d in args.dataset))
    pred_len_args = {str(p).lower() for p in args.pred_len}

    if "all" in pred_len_args:
        pred_lens = [96, 192, 336, 720]
    else:
        pred_lens = sorted({int(p) for p in pred_len_args})

    if "all" in dataset_args:
        # Registry keys are "<dataset>_<pred_len>"; recover the dataset names.
        dataset_names = list(dict.fromkeys(
            key.rsplit("_", 1)[0] for key in dataset_dict))
    else:
        dataset_names = dataset_args

    # Resolve every (config class, pred_len) pair up front so a typo fails
    # before any (long) training run starts. Each config class is paired only
    # with its own prediction length -- previously every selected class was
    # crossed with every pred_len, running mismatched and duplicate experiments.
    experiments = []
    for name in dataset_names:
        for pred_len in pred_lens:
            key = f"{name}_{pred_len}"
            if key not in dataset_dict:
                raise KeyError(
                    f"No LTF config registered for {key!r}; "
                    f"available: {sorted(dataset_dict)}")
            experiments.append((dataset_dict[key], pred_len))

    for config_cls, pred_len in experiments:
        config_ = config_cls(pred_len)
        ltf_experiment(config_, args.gpus, 200)

