"""
CRaWl experiments on the high school contact network.
"""
import torch
import torch.nn.functional as F
import torchmetrics
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from torch.nn import Linear, ReLU, Sequential

from scrawl.datasets import SchoolContactsDataModule
from scrawl.layers import PoolingMethod, SimplicialCRaWl
from scrawl.simplicial import SimplicialData

### CONFIGURATION ###
# Hyperparameters forwarded to ``SimplicialCRaWlLightning`` further down.
NUM_LAYERS = 4

# NOTE(review): these two lists have a single entry each; presumably one entry
# per simplex rank rather than per layer -- confirm against ``SimplicialCRaWl``.
FEATURE_SIZES = [0]
KERNEL_SIZES = [8]

# One inner list per layer; every layer uses the same window / embedding size.
# A comprehension is used so each layer gets its own (non-aliased) inner list.
LOCAL_WINDOW_SIZES = [[8] for _ in range(NUM_LAYERS)]
EMBEDDING_SIZES = [[128] for _ in range(NUM_LAYERS)]


class SimplicialCRaWlLightning(LightningModule):
    """
    SCRaWl model for the social contact datasets.

    The module combines a :class:`SimplicialCRaWl` backbone with a small MLP
    head (``Linear -> ReLU -> Linear``) that maps the embeddings at index 0 of
    the backbone output to class logits. Training and validation run on the
    same ``SimplicialData`` object: ``data.aux_tensor(0)`` selects the entries
    used for training, its complement the entries used for validation.

    Parameters
    ----------
    num_classes : int
        The number of classes to predict.
    num_layers : int
        The number of SCRaWl layers to use.
    feature_sizes : list[int]
        The feature sizes of the simplicial complexes.
    walk_steps : int | tuple[int, int]
        The number of walk steps to use during training and validation. If a tuple
        is given, the first value is used for training and the second for
        validation.
    local_window_sizes : list[list[int]]
        The local window sizes for each layer.
    embedding_sizes : list[list[int]]
        The embedding sizes for each layer.
    kernel_sizes : list[int]
        The kernel sizes for each layer.
    pooling : PoolingMethod, default=PoolingMethod.MEAN
        The pooling method to use.
    dropout : float, default=0.0
        The dropout probability to use.
    batch_running_stats : bool, default=True
        Whether to use batch running statistics for batch normalization.
    """

    scrawl: SimplicialCRaWl

    def __init__(
        self,
        num_classes: int,
        num_layers: int,
        feature_sizes: list[int],
        walk_steps: int | tuple[int, int],
        local_window_sizes: list[list[int]],
        embedding_sizes: list[list[int]],
        kernel_sizes: list[int],
        *,
        pooling: PoolingMethod = PoolingMethod.MEAN,
        dropout: float = 0.0,
        batch_running_stats: bool = True,
    ) -> None:
        """
        Build the SCRaWl backbone, the classification head and the metrics.

        The parameters are described in the class docstring.
        """
        super().__init__()
        # Must be called from ``__init__`` so Lightning can capture the
        # constructor arguments as hyperparameters.
        self.save_hyperparameters()

        if isinstance(walk_steps, tuple):
            self.walk_steps_train, self.walk_step_val = walk_steps
        else:
            self.walk_steps_train = self.walk_step_val = walk_steps

        self.scrawl = SimplicialCRaWl(
            num_layers,
            feature_sizes,
            local_window_sizes,
            embedding_sizes,
            kernel_sizes,
            pooling=pooling,
            dropout=dropout,
            batch_running_stats=batch_running_stats,
        )

        # The head consumes the first embedding size of the last layer.
        head_width = embedding_sizes[-1][0]
        self.output = Sequential(
            Linear(head_width, head_width),
            ReLU(),
            Linear(head_width, num_classes),
        )

        # metrics
        self.train_acc = torchmetrics.Accuracy(
            "multiclass", num_classes=num_classes, top_k=1
        )
        self.val_acc = torchmetrics.Accuracy(
            "multiclass", num_classes=num_classes, top_k=1
        )

    def configure_optimizers(self) -> dict[str, object]:
        """
        Configure the optimizer and learning rate scheduler.

        Returns
        -------
        dict[str, object]
            Lightning optimizer configuration holding the Adam optimizer, a
            ``ReduceLROnPlateau`` scheduler and the name of the metric
            (``"val_loss"``) the scheduler monitors.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # ``verbose`` is deliberately not passed: the keyword is deprecated
        # since PyTorch 2.2 and is not accepted by recent releases.
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.5, patience=10
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
            "monitor": "val_loss",
        }

    def forward(self, data: SimplicialData, walk_steps: int) -> torch.Tensor:
        """
        Forward pass of the model.

        Parameters
        ----------
        data : SimplicialData
            The simplicial data.
        walk_steps : int
            The number of walk steps to use.

        Returns
        -------
        torch.Tensor
            The output of the model.
        """
        # The network input is rebuilt from the domain alone, so the values
        # stored in ``data`` (which the steps use as targets) are presumably
        # not fed to the network -- TODO confirm against ``SimplicialData``.
        input_data = SimplicialData(
            data.domain, dtype=torch.float32, device=data.device
        )
        embedding_data = self.scrawl(input_data, walk_steps)
        # Only the embeddings at index 0 are classified.
        return self.output(embedding_data[0])

    def _masked_predictions(
        self, data: SimplicialData, walk_steps: int, *, train: bool
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the model and select the logits/targets of one data split.

        Parameters
        ----------
        data : SimplicialData
            The batch of data.
        walk_steps : int
            The number of walk steps to use.
        train : bool
            If ``True`` the entries selected by ``data.aux_tensor(0)`` are
            returned, otherwise the entries selected by its complement.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The selected logits and the matching targets.
        """
        logits = self(data, walk_steps)

        # Presumably a boolean mask marking the training entries; ``~`` then
        # yields the validation entries.
        mask = data.aux_tensor(0)
        if not train:
            mask = ~mask

        return logits[mask], data[0][mask].squeeze()

    def training_step(
        self, data: SimplicialData, batch_idx: int
    ) -> dict[str, torch.Tensor]:
        """
        Training step.

        Parameters
        ----------
        data : SimplicialData
            The batch of data.
        batch_idx : int
            Not used.

        Returns
        -------
        dict[str, torch.Tensor]
            The loss.
        """
        logits, targets = self._masked_predictions(
            data, self.walk_steps_train, train=True
        )

        loss = F.cross_entropy(logits, targets)
        self.log("train_loss", loss, batch_size=1)

        # metrics
        self.train_acc(logits, targets)
        self.log(
            "train_acc",
            self.train_acc,
            prog_bar=True,
            on_step=True,
            on_epoch=False,
            batch_size=1,
        )

        return {"loss": loss}

    def validation_step(
        self, data: SimplicialData, batch_idx: int
    ) -> dict[str, torch.Tensor]:
        """
        Validation step.

        Parameters
        ----------
        data : SimplicialData
            The batch of data.
        batch_idx : int
            Not used.

        Returns
        -------
        dict[str, torch.Tensor]
            The loss.
        """
        logits, targets = self._masked_predictions(
            data, self.walk_step_val, train=False
        )

        loss = F.cross_entropy(logits, targets)
        # ``val_loss`` is the metric monitored by the LR scheduler.
        self.log("val_loss", loss, batch_size=1, prog_bar=True)

        # metrics
        self.val_acc(logits, targets)
        self.log("val_acc", self.val_acc, prog_bar=True, batch_size=1)

        return {"loss": loss}


# NOTE(review): everything below runs at import time; consider moving it under
# an ``if __name__ == "__main__":`` guard if this module is ever imported.

# Data module for the "high-school" contact dataset.
contact_high_school = SchoolContactsDataModule(
    "high-school",
    train_size=0.6,
    max_rank=1,
)

# Model configured from the constants defined at the top of the file.
model = SimplicialCRaWlLightning(
    num_classes=contact_high_school.num_classes,
    num_layers=NUM_LAYERS,
    feature_sizes=FEATURE_SIZES,
    walk_steps=50,
    local_window_sizes=LOCAL_WINDOW_SIZES,
    embedding_sizes=EMBEDDING_SIZES,
    kernel_sizes=KERNEL_SIZES,
    batch_running_stats=False,
)

tensorboard = TensorBoardLogger("lightning_logs", name="new-contact-high-school-graph")

lr_monitor = LearningRateMonitor()
# Training is meant to end only once the learning rate has decayed below the
# threshold, so the monitored quantity is the logged learning rate and the
# patience is set high enough to never trigger on its own.
# NOTE(review): "lr-Adam" is presumably the key under which the
# ``LearningRateMonitor`` publishes the Adam learning rate -- confirm.
stop_on_low_lr = EarlyStopping(
    "lr-Adam",
    patience=10000,
    mode="min",
    stopping_threshold=1e-5,
    check_on_train_epoch_end=True,
)

trainer = Trainer(
    logger=tensorboard,
    callbacks=[lr_monitor, stop_on_low_lr],
    accelerator="cpu",
    log_every_n_steps=1,
)

trainer.fit(model, contact_high_school)
