import os
import torch
import argparse


import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split


import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.tuner import Tuner

# import random
# L.seed_everything(random.randint(0, 10000))

from src.ts_model import TSHOT
from src.ts_data import TSDataModule


# Report CUDA availability and the visible device count when the module loads.
_cuda_status = "IS" if torch.cuda.is_available() else "is NOT"
print("Cuda " + _cuda_status + " Available")
print(f"NUM CUDA DEVICES: {torch.cuda.device_count()}")


def _build_arg_parser():
    """Create the command-line parser for a TSHOT training run."""
    # Help stays enabled (argparse default) so the option descriptions below
    # are actually reachable through -h/--help.
    parser = argparse.ArgumentParser("TSHOT")
    parser.add_argument(
        "--name", default='electricity', type=str, help="name of the dataset"
    )
    parser.add_argument(
        "--attention", default=2, type=int, help="attention order"
    )
    parser.add_argument(
        "--linear", default=1, type=int, help="use linear attention"
    )
    parser.add_argument(
        "--pooling", default='mean', type=str, help="pooling function"
    )
    parser.add_argument(
        "--kernel", default=4, type=int, help="patch size"
    )
    parser.add_argument(
        "--hidden", default=1, type=int, help="number of hidden dimensions"
    )
    parser.add_argument(
        "--nhead", default=1, type=int, help="number of heads"
    )
    parser.add_argument(
        "--nblocks", default=1, type=int, help="number of blocks"
    )
    parser.add_argument(
        "--dropout", default=0.1, type=float, help="dropout"
    )
    return parser


def _attention_ignore_list(attention):
    """Translate the --attention value into TSHOT's ``attention_ignore_list``."""
    # 0 -> [1], 1 -> [2]; every other value (including the default 2) -> None.
    # The dict is rebuilt on each call, so callers never share a list object.
    return {0: [1], 1: [2]}.get(attention)


def main(argv=None):
    """Train a TSHOT model on one dataset, then test the best checkpoint.

    Parses the command line, builds the data module, model and Lightning
    trainer, fits the model, and finally evaluates the checkpoint with the
    lowest ``val-mae-720`` on the test dataloader.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.
    """
    args = _build_arg_parser().parse_args(argv)

    ## Data
    datamod = TSDataModule(
        data_path='data/timeseries/processed',
        name=args.name,
        split_sizes=[0.7, 0.1, 0.2],
        context_length=96,
        batch_size=8,
        prediction_length=720,
        normalize=True,
        num_workers=1
    )

    ## Model
    model = TSHOT(
        d_hidden=args.hidden,
        n_blocks=args.nblocks,
        n_head=args.nhead,
        kernel=args.kernel,
        pooling=args.pooling,
        context_length=datamod.context_length,
        normalize=True,
        dropout=args.dropout,
        # --linear is an int flag on the command line; any non-zero value enables it.
        use_linear_att=bool(args.linear),
        feature_map='SMReg',
        lr=1e-3,
        weight_decay=0.,
        attention_ignore_list=_attention_ignore_list(args.attention)
    )

    ## Trainer
    # NOTE(review): the run name does not encode --attention or --linear, so
    # runs that differ only in those options get the same checkpoint prefix.
    name = f'pooling={args.pooling}-kernel={args.kernel}-hidden={args.hidden}-heads={args.nhead}-block={args.nblocks}-dropout={args.dropout}'
    proj = 'TSHOT-' + args.name
    # Keep only the single best checkpoint, ranked by the lowest val-mae-720.
    checkpoint_callback = ModelCheckpoint(
        dirpath=f"experiments/{proj}/weights",
        filename=name + "-{epoch:02d}-{val-mae-720:.2f}",
        save_top_k=1,
        monitor="val-mae-720",
        mode='min'
    )
    early_stop_callback = EarlyStopping(
        monitor="val-mae-720", min_delta=0.01, patience=50, verbose=False, mode="min"
    )
    trainer = L.Trainer(
        max_epochs=100,
        devices=1,
        accelerator="gpu",
        num_nodes=1,
        callbacks=[checkpoint_callback, early_stop_callback],
        # Gradients are accumulated over 4 batches before each optimizer step.
        accumulate_grad_batches=4,
        gradient_clip_val=5.,
        enable_progress_bar=False
    )

    ## Experiment
    # NOTE(review): the dataloaders are requested directly from the data module;
    # this assumes they are usable without an explicit setup() call — TODO confirm.
    trainer.fit(model, datamod.train_dataloader(), datamod.val_dataloader())
    # No model is passed here; ckpt_path='best' selects the checkpoint tracked
    # by checkpoint_callback during the fit above.
    trainer.test(ckpt_path='best', dataloaders=datamod.test_dataloader())


# Run the training pipeline only when this file is executed as a script.
if "__main__" == __name__:
    main()

