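"""
SCRaWl experiments on the social contact datasets (primary school and high school).

Run an experiment with, for example:
```bash
python3 social_contact.py primary-school --walk-steps=50 --local-window-size=8 --use-lower-connections --use-upper-connections
```
Pass ``high-school`` instead of ``primary-school`` to use the other dataset, and
``--no-use-lower-connections`` / ``--no-use-upper-connections`` to disable a
connection type for random sampling.
"""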
import torch
import torch.nn.functional as F
import torchmetrics
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from torch.nn import Linear, ReLU, Sequential

from scrawl.datasets import SchoolContactsDataModule
from scrawl.layers import PoolingMethod, SimplicialCRaWl
from scrawl.simplicial import SimplicialData

### CONFIGURATION ###
# Number of stacked SCRaWl layers.
NUM_LAYERS = 4
# Input feature sizes; presumably one entry per simplex rank (three ranks here,
# cf. `max_rank=len(FEATURE_SIZES) + 1` in the entry point). Zero means no input
# features are supplied.
FEATURE_SIZES = [0, 0, 0]
# One row of embedding sizes per layer, derived from NUM_LAYERS so the two cannot
# drift apart. Each row presumably holds one width per rank (same length as
# FEATURE_SIZES) -- TODO confirm against SimplicialCRaWl.
EMBEDDING_SIZES = [[32, 32, 32] for _ in range(NUM_LAYERS)]
# Kernel sizes; the entry point caps each one at the number of walk steps.
KERNEL_SIZES = [8, 8, 8]


class SimplicialCRaWlLightning(LightningModule):
    """
    SCRaWl model for the social contact datasets.

    The complex is embedded by a :class:`SimplicialCRaWl` stack. The index-0
    embeddings (presumably the vertices) are then classified by a
    Linear -> ReLU -> Linear head. Both training and validation run the model
    on the whole batch. ``data.aux_tensor(0)`` selects the entries used for the
    training loss, and its complement selects the entries used for validation.

    Parameters
    ----------
    num_classes : int
        The number of classes to predict.
    num_layers : int
        The number of SCRaWl layers to use.
    feature_sizes : list[int]
        The input feature sizes, forwarded to :class:`SimplicialCRaWl`.
    walk_steps : int | tuple[int, int]
        The number of walk steps to use during training and validation. If a pair
        (tuple or list) is given, the first value is used for training and the
        second for validation.
    local_window_sizes : int | list[list[int]]
        The local window size(s), forwarded to :class:`SimplicialCRaWl`.
    embedding_sizes : list[list[int]]
        The embedding sizes for each layer. ``embedding_sizes[-1][0]`` also sets
        the hidden width of the classification head.
    kernel_sizes : list[int]
        The kernel sizes, forwarded to :class:`SimplicialCRaWl`.
    pooling : PoolingMethod, default=PoolingMethod.MEAN
        The pooling method to use.
    dropout : float, default=0.0
        The dropout probability to use.
    batch_running_stats : bool, default=True
        Whether to use batch running statistics for batch normalization.
    use_lower_connections : bool, default=True
        Whether to use lower connections for random sampling.
    use_upper_connections : bool, default=True
        Whether to use upper connections for random sampling.

    Attributes
    ----------
    walk_steps_train : int
        Walk steps used by :meth:`training_step`.
    walk_steps_val : int
        Walk steps used by :meth:`validation_step`. ``walk_step_val`` is kept
        as an alias for backward compatibility.
    """

    scrawl: SimplicialCRaWl

    def __init__(
        self,
        num_classes: int,
        num_layers: int,
        feature_sizes: list[int],
        walk_steps: int | tuple[int, int],
        local_window_sizes,
        embedding_sizes: list[list[int]],
        kernel_sizes: list[int],
        *,
        pooling: PoolingMethod = PoolingMethod.MEAN,
        dropout: float = 0.0,
        batch_running_stats: bool = True,
        use_lower_connections: bool = True,
        use_upper_connections: bool = True,
    ) -> None:
        """
        Build the SCRaWl backbone, the classification head and the metrics.

        All parameters are described in the class docstring.
        """
        super().__init__()
        self.save_hyperparameters()

        # A (train, val) pair selects separate walk lengths. A 2-element list is
        # accepted as well, so that configs written as plain sequences also work.
        if isinstance(walk_steps, (tuple, list)):
            self.walk_steps_train, self.walk_steps_val = walk_steps
        else:
            self.walk_steps_train = self.walk_steps_val = walk_steps

        self.scrawl = SimplicialCRaWl(
            num_layers,
            feature_sizes,
            local_window_sizes,
            embedding_sizes,
            kernel_sizes,
            pooling=pooling,
            dropout=dropout,
            batch_running_stats=batch_running_stats,
            use_lower_connections=use_lower_connections,
            use_upper_connections=use_upper_connections,
        )

        # The head consumes the index-0 embeddings produced by the last layer.
        hidden = embedding_sizes[-1][0]
        self.output = Sequential(
            Linear(hidden, hidden),
            ReLU(),
            Linear(hidden, num_classes),
        )

        # metrics
        self.train_acc = torchmetrics.Accuracy(
            "multiclass", num_classes=num_classes, top_k=1
        )
        self.val_acc = torchmetrics.Accuracy(
            "multiclass", num_classes=num_classes, top_k=1
        )

    @property
    def walk_step_val(self) -> int:
        """Alias of ``walk_steps_val``, kept for backward compatibility."""
        return self.walk_steps_val

    @walk_step_val.setter
    def walk_step_val(self, value: int) -> None:
        self.walk_steps_val = value

    def configure_optimizers(self) -> dict:
        """
        Configure the optimizer and learning rate scheduler.

        Adam starts at a learning rate of 1e-3. ``ReduceLROnPlateau`` halves
        the learning rate when ``val_loss`` has not improved for 10 scheduler
        steps.

        Returns
        -------
        dict
            Lightning optimizer config with the keys ``"optimizer"``,
            ``"lr_scheduler"`` and ``"monitor"`` (the metric the plateau
            scheduler watches).
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.5, patience=10, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
            "monitor": "val_loss",
        }

    def forward(self, data: SimplicialData, walk_steps: int) -> torch.Tensor:
        """
        Forward pass of the model.

        Parameters
        ----------
        data : SimplicialData
            The simplicial data.
        walk_steps : int
            The number of walk steps to use.

        Returns
        -------
        torch.Tensor
            Class logits for every index-0 entry of the complex.
        """
        # Only the domain of `data` is passed on: a fresh float32 SimplicialData
        # is built on the same device. Presumably this keeps the labels stored
        # in `data` out of the model input -- TODO confirm.
        input_data = SimplicialData(
            data.domain, dtype=torch.float32, device=data.device
        )
        embedding_data = self.scrawl(input_data, walk_steps)
        return self.output(embedding_data[0])

    @staticmethod
    def _split_outputs(
        output: torch.Tensor, data: SimplicialData, *, train: bool
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Select the logits and targets of the training or validation split.

        ``data.aux_tensor(0)`` marks the training entries; its complement marks
        the validation entries.
        """
        mask = data.aux_tensor(0)
        if not train:
            mask = ~mask
        # reshape(-1) instead of squeeze(): squeeze() would collapse a split
        # with a single entry into a 0-d target, which cross_entropy does not
        # accept alongside a (1, C) logit tensor.
        return output[mask], data[0][mask].reshape(-1)

    def training_step(self, data, batch_idx):
        """
        Training step.

        Parameters
        ----------
        data : SimplicialData
            The batch of data.
        batch_idx : int
            Not used.

        Returns
        -------
        dict[str, torch.Tensor]
            The loss.
        """
        output = self(data, self.walk_steps_train)
        logits, targets = self._split_outputs(output, data, train=True)

        loss = F.cross_entropy(logits, targets)
        self.log("train_loss", loss, batch_size=1)

        # metrics
        self.train_acc(logits, targets)
        self.log(
            "train_acc",
            self.train_acc,
            prog_bar=True,
            on_step=True,
            on_epoch=False,
            batch_size=1,
        )

        return {"loss": loss}

    def validation_step(self, data, batch_idx):
        """
        Validation step.

        Parameters
        ----------
        data : SimplicialData
            The batch of data.
        batch_idx : int
            Not used.

        Returns
        -------
        dict[str, torch.Tensor]
            The loss.
        """
        output = self(data, self.walk_steps_val)
        logits, targets = self._split_outputs(output, data, train=False)

        loss = F.cross_entropy(logits, targets)
        self.log("val_loss", loss, batch_size=1, prog_bar=True)

        # metrics
        self.val_acc(logits, targets)
        self.log("val_acc", self.val_acc, prog_bar=True, batch_size=1)

        return {"loss": loss}


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Train SCRaWl on a social contact dataset."
    )
    parser.add_argument("dataset", choices=["primary-school", "high-school"])
    parser.add_argument(
        "--walk-steps", type=int, default=50, help="Number of random walk steps."
    )
    parser.add_argument(
        "--local-window-size", type=int, default=8, help="Local window size."
    )
    parser.add_argument(
        "--use-lower-connections",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Use lower connections for random sampling.",
    )
    parser.add_argument(
        "--use-upper-connections",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Use upper connections for random sampling.",
    )
    args = parser.parse_args()

    # Cap every kernel size at the number of walk steps (presumably a kernel
    # cannot be longer than the walk it is applied to).
    kernel = [min(size, args.walk_steps) for size in KERNEL_SIZES]

    social_contacts = SchoolContactsDataModule(
        args.dataset, train_size=0.6, max_rank=len(FEATURE_SIZES) + 1
    )
    model = SimplicialCRaWlLightning(
        num_classes=social_contacts.num_classes,
        num_layers=NUM_LAYERS,
        feature_sizes=FEATURE_SIZES,
        walk_steps=args.walk_steps,
        local_window_sizes=args.local_window_size,
        embedding_sizes=EMBEDDING_SIZES,
        kernel_sizes=kernel,
        batch_running_stats=False,
        use_lower_connections=args.use_lower_connections,
        use_upper_connections=args.use_upper_connections,
    )

    # The run name encodes the dataset, walk length, window size and any
    # disabled connection type, so TensorBoard runs are distinguishable.
    logging_name = (
        f"contact-{args.dataset}-l{args.walk_steps}-s{args.local_window_size}"
    )
    if not args.use_lower_connections:
        logging_name += "-no-lower"
        print("Not using lower connections for random sampling.")
    if not args.use_upper_connections:
        logging_name += "-no-upper"
        print("Not using upper connections for random sampling.")

    tensorboard = TensorBoardLogger("lightning_logs", name=logging_name)
    trainer = Trainer(
        logger=tensorboard,
        callbacks=[
            LearningRateMonitor(),
            EarlyStopping(
                "lr-Adam",
                patience=10000,  # we only want to stop once lr is too low
                mode="min",
                stopping_threshold=0.00001,
                check_on_train_epoch_end=True,
            ),
            EarlyStopping(
                "val_acc",
                # Integer patience (was the float 10e5); large enough that
                # only reaching the stopping threshold ends training.
                patience=1_000_000,
                mode="max",
                stopping_threshold=0.995,
                verbose=True,
            ),
        ],
        accelerator="cpu",
        log_every_n_steps=1,
    )

    trainer.fit(model, social_contacts)
