import argparse

import lightning
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

import tpp
from data import EasyTPPDataModule


# CLI name -> class name inside the ``tpp`` package. Class names are resolved
# with ``getattr`` only after both requested names have been validated, and the
# keys can be reused elsewhere (e.g. as argparse ``choices``).
ENCODERS = {
    'gru': 'GRUEncoder',
    'categoricalgru': 'GRUEncoderCategorical',
    'cnn': 'CNNEncoder',
    'transformer': 'TransformerEncoder',
}
DECODERS = {
    'lognormmix': 'MixtureLogNormalDecoder',
    'lognorm': 'LogNormalDecoder',
    'weibull': 'WeibullDecoder',
    'exponential': 'ExponentialDecoder',
    'weibullmix': 'WeibullMixtureDecoder',
    'categorical': 'CategoricalDecoder',
}


def _lookup(registry: dict[str, str], name: str, kind: str) -> str:
    """Return the ``tpp`` class name registered under ``name``.

    Raises:
        ValueError: if ``name`` is not a key of ``registry``; the message lists
            the accepted names.
    """
    try:
        return registry[name]
    except KeyError:
        choices = ', '.join(sorted(registry))
        raise ValueError(f"Unknown {kind}: {name}. Expected one of: {choices}") from None


def get_model(
    dim: int,
    hidden_dim: int,
    num_components: int,
    encoder_name: str,
    decoder_name: str,
    lr: float = 3e-4,
) -> dict:
    """Build the keyword arguments for ``tpp.TPPModule``.

    Args:
        dim: Number of event classes; forwarded as ``num_classes`` to both the
            encoder and the decoder.
        hidden_dim: Hidden size forwarded to both the encoder and the decoder.
        num_components: Forwarded to every decoder type, including the
            single-distribution ones. NOTE(review): confirm those decoders
            accept and ignore it.
        encoder_name: A key of ``ENCODERS``.
        decoder_name: A key of ``DECODERS``.
        lr: Learning rate stored under the ``'lr'`` key (default 3e-4).

    Returns:
        A dict with keys ``'encoder'`` and ``'decoder'`` (instantiated ``tpp``
        encoder/decoder objects) and ``'lr'``.

    Raises:
        ValueError: if ``encoder_name`` or ``decoder_name`` is unknown. Both
            names are checked before anything is instantiated.
    """
    encoder_attr = _lookup(ENCODERS, encoder_name, 'encoder')
    decoder_attr = _lookup(DECODERS, decoder_name, 'decoder')
    encoder_cls = getattr(tpp, encoder_attr)
    decoder_cls = getattr(tpp, decoder_attr)

    return dict(
        encoder=encoder_cls(num_classes=dim, hidden_dim=hidden_dim),
        decoder=decoder_cls(num_classes=dim, hidden_dim=hidden_dim, num_components=num_components),
        lr=lr,
    )


def train_model(
    dataset: str,
    data_dir: str,
    hidden_dim: int,
    num_components: int,
    encoder_name: str,
    decoder_name: str,
    max_epochs: int,
    gpu: int,
    *,
    batch_size: int = 64,
    patience: int = 10,
) -> tuple[EasyTPPDataModule, lightning.LightningModule]:
    """Train a TPP model and return it with its best checkpoint loaded.

    The checkpoint and CSV logs are written under
    ``experiments/{dataset}.{encoder_name}.{decoder_name}.{hidden_dim}.{num_components}``.

    Args:
        dataset: Dataset name passed to ``EasyTPPDataModule``.
        data_dir: Directory passed to ``EasyTPPDataModule``.
        hidden_dim: Hidden size of the encoder and decoder.
        num_components: Forwarded to the decoder by ``get_model``.
        encoder_name: Encoder name understood by ``get_model``.
        decoder_name: Decoder name understood by ``get_model``.
        max_epochs: Upper bound on training epochs.
        gpu: Device index handed to ``lightning.Trainer(devices=[gpu])``.
        batch_size: Batch size for the datamodule (default 64).
        patience: Epochs without ``val_loss`` improvement before early
            stopping (default 10).

    Returns:
        ``(datamodule, module)``. ``module`` carries the weights of the
        checkpoint with the lowest ``val_loss`` and is in eval mode.

    Raises:
        RuntimeError: if training finished without saving any checkpoint
            (e.g. validation never ran), so there is no best model to reload.
    """
    datamodule = EasyTPPDataModule(data_dir=data_dir, name=dataset, batch_size=batch_size)

    # One directory per hyper-parameter configuration.
    save_dir = f'experiments/{dataset}.{encoder_name}.{decoder_name}.{hidden_dim}.{num_components}'

    module_kwargs = get_model(
        dim=datamodule.dim,
        hidden_dim=hidden_dim,
        num_components=num_components,
        encoder_name=encoder_name,
        decoder_name=decoder_name,
    )
    module = tpp.TPPModule(**module_kwargs)

    early_stop = EarlyStopping(
        monitor='val_loss',
        patience=patience,
        mode='min',
    )
    # Keep only the single best checkpoint as ``<save_dir>/model.ckpt``. With the
    # version counter disabled, rerunning the same configuration overwrites it
    # instead of writing a versioned copy next to it.
    checkpoint_cb = ModelCheckpoint(
        dirpath=save_dir,
        filename='model',
        monitor='val_loss',
        mode='min',
        save_top_k=1,
        enable_version_counter=False,
    )

    trainer = lightning.Trainer(
        max_epochs=max_epochs,
        devices=[gpu],
        callbacks=[early_stop, checkpoint_cb],
        gradient_clip_val=5.0,
        logger=CSVLogger(save_dir=save_dir),
    )
    trainer.fit(module, datamodule)

    best_path = checkpoint_cb.best_model_path
    if not best_path:
        # An empty path means ModelCheckpoint never saved anything; loading ''
        # would otherwise fail with an unrelated-looking I/O error.
        raise RuntimeError(
            f"No checkpoint was saved in {save_dir!r}; cannot reload the best model "
            "(did validation run and log 'val_loss'?)."
        )

    # The same encoder/decoder instances held in ``module_kwargs`` are passed to
    # the constructor again, and the checkpoint weights are loaded into them.
    # NOTE(review): presumably required because TPPModule does not store its
    # module-valued init args as hyperparameters — confirm.
    module = tpp.TPPModule.load_from_checkpoint(best_path, **module_kwargs)
    module.eval()

    return datamodule, module


def build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this training script."""
    parser = argparse.ArgumentParser(
        description='Train a TPP model and keep its best checkpoint.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # These three have no sensible default; omitting them used to pass None
    # into training and fail far from the cause.
    parser.add_argument('--name', type=str, required=True,
                        help='dataset name passed to EasyTPPDataModule')
    parser.add_argument('--encoder_name', type=str, required=True,
                        help='encoder name (see get_model for accepted names)')
    parser.add_argument('--decoder_name', type=str, required=True,
                        help='decoder name (see get_model for accepted names)')
    parser.add_argument('--data_dir', type=str, default='../data',
                        help='directory passed to EasyTPPDataModule')
    parser.add_argument('--hidden_dim', type=int, default=256,
                        help='hidden size of the encoder and decoder')
    parser.add_argument('--num_components', type=int, default=8,
                        help='number of components forwarded to the decoder')
    parser.add_argument('--max_epochs', type=int, default=1000,
                        help='maximum number of training epochs')
    # NOTE(review): the default selects device index 1, not 0 — confirm intended.
    parser.add_argument('--gpu', type=int, default=1,
                        help='device index passed to lightning.Trainer(devices=[gpu])')
    return parser


def main(argv: 'list[str] | None' = None) -> None:
    """Parse ``argv`` (``sys.argv[1:]`` when None) and run ``train_model``."""
    args = build_parser().parse_args(argv)
    train_model(
        dataset=args.name,
        data_dir=args.data_dir,
        hidden_dim=args.hidden_dim,
        num_components=args.num_components,
        encoder_name=args.encoder_name,
        decoder_name=args.decoder_name,
        max_epochs=args.max_epochs,
        gpu=args.gpu,
    )


if __name__ == '__main__':
    main()
