"""
This module provides functionality for training molecular regression models using various graph neural network (GNN) architectures and positional encoding (PE) functions.

Classes:
    MTransformerTrainer: A PyTorch Lightning module for training a molecular regression model.

Functions:
    process(model, model_name, pe_func, pe_name): Sets up data loaders and trains the model.
    run(model_name, pe_name): CLI command to select model and PE function, and start the training process.
    main(): Entry point for the script.

Usage:
    Run the script with the appropriate command-line options to train a model.

Command-line options:
    -m, --model-name: The name of the model to use. Must be one of ["gcn", "gat", "gatedgcn", "transformer"].
    -p, --pe-name: The name of the positional encoding function to use. Must be one of ["laplacian", "laplacian_abs", "rw", "gape", "nope"].
"""

from typing import Literal
import click
import lightning as pl
import torch
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader

from .datasets import MoleculesBatch, setup_data
from .mtransformer import MGAT, MGCN, MGatedGCN, MTransformer
from .pe_generators import (
    laplacian_pe_func,
    laplacian_abs_pe_func,
    rw_pe_func,
    gape_gatedgcn_18_pe_func,
    nope_pe_func,
)


class MTransformerTrainer(pl.LightningModule):
    """Lightning module that fits a wrapped model with an L1 (mean absolute error) loss.

    The wrapped ``model`` is called directly on each batch, and its output is
    compared against ``batch.targets``. Optimisation uses AdamW driven by a
    one-cycle learning-rate schedule.

    Args:
        model: The network to train. It is invoked as ``model(batch)``.
        max_learning_rate: Peak learning rate handed to ``OneCycleLR``.
    """

    def __init__(self, model, max_learning_rate=1e-3):
        super().__init__()
        self.max_learning_rate = max_learning_rate
        self.model = model

    def forward(self, x):
        """Delegate the forward pass to the wrapped model."""
        return self.model(x)

    def criterion(self, y_hat, y):
        """Return the L1 loss between predictions ``y_hat`` and targets ``y``."""
        return torch.nn.functional.l1_loss(input=y_hat, target=y)

    def _batch_loss(self, batch):
        """Run the wrapped model on ``batch`` and score it against ``batch.targets``."""
        predictions = self.model(batch)
        return self.criterion(predictions, batch.targets)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train/loss``) and return the loss for one training batch."""
        train_loss = self._batch_loss(batch)
        self.log("train/loss", train_loss, on_epoch=True, prog_bar=True)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Compute, log (as ``val/loss``) and return the loss for one validation batch."""
        val_loss = self._batch_loss(batch)
        self.log("val/loss", val_loss, on_epoch=True, prog_bar=True)
        return val_loss

    def configure_optimizers(self):
        """Build the AdamW optimizer and its one-cycle learning-rate schedule.

        Returns:
            A pair ``([optimizer], [scheduler])``.
        """
        # lr=1 is only a placeholder: the effective learning rate is presumably
        # dictated by OneCycleLR (bounded by max_lr) -- confirm against PyTorch docs.
        adamw = torch.optim.AdamW(self.parameters(), lr=1, weight_decay=5e-6)
        # NOTE(review): total_steps is fixed at 300. The scheduler is returned
        # without an explicit stepping interval, so this presumably relies on
        # being stepped once per epoch over a 300-epoch run -- confirm.
        one_cycle = OneCycleLR(
            adamw,
            max_lr=self.max_learning_rate,
            total_steps=300,
            pct_start=0.05,
        )
        return [adamw], [one_cycle]


def process(model, model_name, pe_func, pe_name) -> None:
    """Build the ZINC data loaders and fit ``model`` with a Lightning trainer.

    Args:
        model: The network to train; it is wrapped in ``MTransformerTrainer``.
        model_name: Label for the model, used only in the W&B run name.
        pe_func: Positional-encoding function forwarded to ``setup_data``.
        pe_name: Positional-encoding label, forwarded to ``setup_data`` and
            also used in the W&B run name.
    """

    def build_loader(split, shuffle):
        # Both splits share every setting except the split name and shuffling.
        dataset = setup_data(
            "ZINC",
            split,
            pe_name=pe_name,
            pe_func=pe_func,
            pe_dim=32,
        )
        return DataLoader(
            dataset,
            batch_size=256,
            shuffle=shuffle,
            collate_fn=MoleculesBatch.collate_fn,
            num_workers=0,
            pin_memory=True,
        )

    train_dataloader = build_loader("train", shuffle=True)
    # NOTE(review): validation runs on the "test" split -- confirm this is
    # intended rather than a dedicated validation split.
    val_dataloader = build_loader("test", shuffle=False)

    lightning_module = MTransformerTrainer(model, max_learning_rate=1e-4)

    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    wandb_logger = WandbLogger(
        project="ZINC-regression",
        name=f"{model_name} + {pe_name}",
        log_model=True,
    )
    trainer = pl.Trainer(
        max_epochs=300,
        gradient_clip_algorithm="value",
        gradient_clip_val=0.1,
        callbacks=[lr_monitor],
        logger=wandb_logger,
    )
    trainer.fit(
        lightning_module,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
    )


# The allowed values are enforced with click.Choice. typing.Literal is a
# static-typing construct and is not a click parameter type, so it cannot
# validate or convert the option value.
@click.command()
@click.option(
    "-m",
    "--model-name",
    required=True,
    type=click.Choice(["gcn", "gat", "gatedgcn", "transformer"]),
    help="Model name",
)
@click.option(
    "-p",
    "--pe-name",
    required=True,
    type=click.Choice(["laplacian", "laplacian_abs", "rw", "gape", "nope"]),
    help="Positional encoding name",
)
def run(model_name, pe_name):
    """Train the selected model on ZINC with the selected positional encoding."""
    # Positional-encoding width shared by the model and the PE generator.
    # process() passes the same value (32) to setup_data.
    pe_dim = 32

    # Model class plus constructor arguments per CLI name. Only the selected
    # entry is instantiated.
    model_specs = {
        "gcn": (
            MGCN,
            dict(
                num_classes=29,
                pe_dim=pe_dim,
                d_model=254,
                num_layers=8,
                dropout=0.0005,
            ),
        ),
        "gat": (
            MGAT,
            dict(
                num_classes=29,
                pe_dim=pe_dim,
                d_model=248,
                num_layers=8,
                heads=8,
                dropout=0.0005,
            ),
        ),
        "gatedgcn": (
            MGatedGCN,
            dict(
                num_classes=29,
                pe_dim=pe_dim,
                d_model=128,
                num_layers=8,
                dropout=0.0005,
            ),
        ),
        "transformer": (
            MTransformer,
            dict(
                num_classes=29,
                pe_dim=pe_dim,
                d_model=128,
                nhead=16,
                num_encoder_layers=6,
                dim_feedforward=128,
                dropout=0.0005,
            ),
        ),
    }

    # Factory per CLI name. Each is called with the PE width to obtain the
    # positional-encoding function.
    pe_factories = {
        "laplacian": laplacian_pe_func,
        "laplacian_abs": laplacian_abs_pe_func,
        "rw": rw_pe_func,
        "gape": gape_gatedgcn_18_pe_func,
        "nope": nope_pe_func,
    }

    # click.Choice already rejects unknown names on the command line. These
    # checks cover direct calls to the callback, so an unknown name raises
    # instead of passing None on to process().
    if model_name not in model_specs:
        raise ValueError(f"Unknown model name: {model_name!r}")
    if pe_name not in pe_factories:
        raise ValueError(f"Unknown positional encoding name: {pe_name!r}")

    model_cls, model_kwargs = model_specs[model_name]
    model = model_cls(**model_kwargs)
    pe_func = pe_factories[pe_name](pe_dim)

    process(model, model_name, pe_func, pe_name)


def main():
    """Script entry point: invoke the ``run`` click command."""
    run()


# Start the CLI only when this file is executed as the main program, not on import.
if __name__ == "__main__":
    main()
