"""
SCNN experiments on the social contact datasets.
"""
import torch
import torch.nn.functional as F
import torchmetrics
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from topomodelx.nn.simplicial.scnn_layer import SCNNLayer
from torch.nn import Linear, ReLU, Sequential

from scrawl.datasets import SchoolContactsDataModule
from scrawl.simplicial import SimplicialData

### CONFIGURATION ###
# Experiment-wide constants. Every list below has four entries (the nested
# ones are 4 x 4). In this file only FEATURE_SIZES is read: its length is used
# to derive the ``max_rank`` handed to the data module.
NUM_LAYERS = 4
FEATURE_SIZES = [0] * 4
# The nested lists are built with comprehensions so that every row is its own
# list object rather than four references to a single shared row.
LOCAL_WINDOW_SIZES = [[8] * 4 for _ in range(4)]
EMBEDDING_SIZES = [[128] * 4 for _ in range(4)]
KERNEL_SIZES = [8] * 4


class SCNNLightning(LightningModule):
    """
    SCNN model for the social contact datasets.

    One independent stack of ``SCNNLayer`` modules is built for each of the
    ranks 0, 1 and 2. The class logits are produced by a two-layer MLP head
    applied to the rank-0 features.

    Parameters
    ----------
    in_channels : int
        The number of input channels.
    intermediate_channels : int
        The number of channels produced by the first SCNN layer of each rank.
    out_channels : int
        The number of channels produced by every later SCNN layer and by the
        hidden layer of the output head.
    num_classes : int
        The number of classes to predict.
    conv_order_down : int
        The order of the convolution for the down pass.
    conv_order_up : int
        The order of the convolution for the up pass.
    aggr_norm : bool, default=False
        Whether to normalize the aggregation. Only forwarded to the layers
        after the first one of each rank.
    update_func : str, default=None
        The update function to use. Only forwarded to the layers after the
        first one of each rank.
    n_layers : int, default=2
        The number of SCNN layers to use.
    """

    def __init__(
        self,
        in_channels,
        intermediate_channels,
        out_channels,
        num_classes: int,
        conv_order_down,
        conv_order_up,
        aggr_norm=False,
        update_func=None,
        n_layers=2,
    ) -> None:
        """
        SCNN model for the social contact datasets.

        Parameters
        ----------
        in_channels : int
            The number of input channels.
        intermediate_channels : int
            The number of channels produced by the first SCNN layer of each
            rank.
        out_channels : int
            The number of channels produced by every later SCNN layer and by
            the hidden layer of the output head.
        num_classes : int
            The number of classes to predict.
        conv_order_down : int
            The order of the convolution for the down pass.
        conv_order_up : int
            The order of the convolution for the up pass.
        aggr_norm : bool, default=False
            Whether to normalize the aggregation.
        update_func : str, default=None
            The update function to use.
        n_layers : int, default=2
            The number of SCNN layers to use.
        """
        super().__init__()
        self.n_layers = n_layers
        self.in_channels = in_channels
        self.save_hyperparameters()

        # widths[d] is the feature width entering layer d; widths[d + 1] is the
        # width leaving it. Layers from the third one on consume
        # ``out_channels`` features (the output of their predecessor), not
        # ``intermediate_channels``.
        widths = [in_channels, intermediate_channels]
        widths.extend([out_channels] * (n_layers - 1))

        self.layers = torch.nn.ModuleDict()
        for rank in range(3):
            # forward() supplies no down Laplacian at rank 0 and no up
            # Laplacian at rank 2, so the matching convolution order is 0.
            down_order = conv_order_down if rank > 0 else 0
            up_order = conv_order_up if rank < 2 else 0

            self.layers[f"{rank}-0"] = SCNNLayer(
                in_channels=widths[0],
                out_channels=widths[1],
                conv_order_down=down_order,
                conv_order_up=up_order,
            )

            for depth in range(1, n_layers):
                self.layers[f"{rank}-{depth}"] = SCNNLayer(
                    in_channels=widths[depth],
                    out_channels=widths[depth + 1],
                    conv_order_down=down_order,
                    conv_order_up=up_order,
                    aggr_norm=aggr_norm,
                    update_func=update_func,
                )

        # The head consumes whatever the last executed layer emits.
        head_in_channels = widths[max(n_layers, 0)]
        self.output = Sequential(
            Linear(head_in_channels, out_channels),
            ReLU(),
            Linear(out_channels, num_classes),
        )

        # metrics
        self.train_acc = torchmetrics.Accuracy(
            "multiclass", num_classes=num_classes, top_k=1
        )
        self.val_acc = torchmetrics.Accuracy(
            "multiclass", num_classes=num_classes, top_k=1
        )

    def configure_optimizers(self) -> dict:
        """
        Configure the optimizer and learning rate scheduler.

        Returns
        -------
        dict
            A dictionary with the Adam optimizer (``"optimizer"``), the
            ``ReduceLROnPlateau`` scheduler (``"lr_scheduler"``) and the name
            of the metric the scheduler watches (``"monitor"``).
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # ``verbose`` is intentionally not passed: it is deprecated in
        # torch.optim.lr_scheduler. The current learning rate can be read from
        # the optimizer or logged by a learning-rate monitor callback.
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.5, patience=10
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
            "monitor": "val_loss",
        }

    def forward(self, data: SimplicialData) -> torch.Tensor:
        """
        Forward pass of the model.

        Fresh Xavier-normal random features are drawn for every rank on each
        call; no input features are read from ``data``.

        Parameters
        ----------
        data : SimplicialData
            The simplicial data.

        Returns
        -------
        torch.Tensor
            The output of the model: one row of class logits per rank-0 cell.
        """
        cell_counts = data.domain.shape

        # Random initial features, drawn in rank order 0, 1, 2.
        features = []
        for rank in range(3):
            x = torch.empty(
                (cell_counts[rank], self.in_channels),
                dtype=torch.float32,
                device=data.device,
            )
            torch.nn.init.xavier_normal_(x)
            features.append(x)
        x_0, x_1, x_2 = features

        # The Laplacians do not depend on the layer index, so fetch them once.
        up_0 = data.domain.up_laplacian_matrix(0)
        down_1 = data.domain.down_laplacian_matrix(1)
        up_1 = data.domain.up_laplacian_matrix(1)
        down_2 = data.domain.down_laplacian_matrix(2)

        # NOTE(review): only x_0 reaches the output head below; x_1 and x_2
        # are propagated but never read afterwards.
        for depth in range(self.n_layers):
            x_0 = self.layers[f"0-{depth}"](x_0, None, up_0)
            x_1 = self.layers[f"1-{depth}"](x_1, down_1, up_1)
            x_2 = self.layers[f"2-{depth}"](x_2, down_2, None)

        return self.output(x_0)

    def _predictions_and_targets(self, data: SimplicialData, *, train_split: bool):
        """
        Run the model and select the rows belonging to one split.

        Parameters
        ----------
        data : SimplicialData
            The batch of data.
        train_split : bool
            If True, rows are selected with ``data.aux_tensor(0)``; otherwise
            with its logical negation.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The selected logits and the matching squeezed targets.
        """
        logits = self(data)
        # Presumably a boolean mask over rank-0 cells marking the training
        # nodes -- confirm against the data module.
        mask = data.aux_tensor(0)
        if not train_split:
            mask = ~mask
        return logits[mask], data[0][mask].squeeze()

    def training_step(self, data: SimplicialData, batch_idx: int) -> dict:
        """
        Training step.

        Parameters
        ----------
        data : SimplicialData
            The batch of data.
        batch_idx : int
            Not used.

        Returns
        -------
        dict[str, torch.Tensor]
            The loss.
        """
        logits, targets = self._predictions_and_targets(data, train_split=True)

        loss = F.cross_entropy(logits, targets)
        self.log("train_loss", loss, batch_size=1)

        # metrics
        self.train_acc(logits, targets)
        self.log(
            "train_acc",
            self.train_acc,
            prog_bar=True,
            on_step=True,
            on_epoch=False,
            batch_size=1,
        )

        return {"loss": loss}

    def validation_step(self, data: SimplicialData, batch_idx: int) -> dict:
        """
        Validation step.

        Parameters
        ----------
        data : SimplicialData
            The batch of data.
        batch_idx : int
            Not used.

        Returns
        -------
        dict[str, torch.Tensor]
            The loss.
        """
        logits, targets = self._predictions_and_targets(data, train_split=False)

        loss = F.cross_entropy(logits, targets)
        self.log("val_loss", loss, batch_size=1, prog_bar=True)

        # metrics
        self.val_acc(logits, targets)
        # batch_size is given explicitly, as in every other log call.
        self.log("val_acc", self.val_acc, prog_bar=True, batch_size=1)

        return {"loss": loss}


# Data module for the "primary-school" contact dataset. ``max_rank`` is derived
# from the number of entries in FEATURE_SIZES.
contact_primary_school = SchoolContactsDataModule(
    "primary-school", train_size=0.6, max_rank=len(FEATURE_SIZES) - 1
)

# Five SCNN layers per rank, all 64 channels wide.
model = SCNNLightning(
    in_channels=64,
    intermediate_channels=64,
    out_channels=64,
    num_classes=contact_primary_school.num_classes,
    conv_order_down=1,
    conv_order_up=1,
    aggr_norm=False,
    update_func="relu",
    n_layers=5,
)


tensorboard = TensorBoardLogger("lightning_logs", name="contact-primary-school-scnn")
trainer = Trainer(
    logger=tensorboard,
    callbacks=[
        # Must come before the EarlyStopping that monitors "lr-Adam".
        LearningRateMonitor(),
        EarlyStopping(
            "lr-Adam",
            patience=10000,  # we only want to stop once lr is too low
            mode="min",
            stopping_threshold=0.00001,
            check_on_train_epoch_end=True,
        ),
        EarlyStopping(
            "val_acc",
            # Effectively unlimited patience, so only the accuracy threshold
            # ends training. Written as an int because ``patience`` counts
            # checks.
            patience=1_000_000,
            mode="max",
            stopping_threshold=0.995,
            verbose=True,
        ),
    ],
    accelerator="cpu",
    log_every_n_steps=1,
)

trainer.fit(model, contact_primary_school)
