"""
SCRaWl experiments on the semantic scholar dataset.
"""
from argparse import ArgumentParser

import numpy as np
import torch
import torch.nn.functional as F
import torchmetrics
from lightning.pytorch import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from toponetx.classes import SimplicialComplex
from torch.nn import Linear, ModuleList, ReLU, Sequential
from torch.utils.data import DataLoader

from scrawl.layers import PoolingMethod, SimplicialCRaWl
from scrawl.metrics import RegressionAccuracy
from scrawl.simplicial import SimplicialData
from scrawl.transformers import toponetx_to_data

### CONFIGURATION ###
# Number of stacked SCRaWl layers.
NUM_LAYERS = 3

# Both tables below hold one row per layer and, within a row, one entry per
# simplex rank (seven ranks, matching the `feature_sizes` passed in `run`).
LOCAL_WINDOW_SIZES = [[4] * 7 for _ in range(NUM_LAYERS)]
EMBEDDING_SIZES = [[32] * 7 for _ in range(NUM_LAYERS)]


### CODE ###
class SemanticScholarDataModule(LightningDataModule):
    """
    Data module for the semantic scholar dataset.

    Parameters
    ----------
    percentage_missing : int
        The percentage of missing values to use.
    """

    def __init__(self, percentage_missing: int) -> None:
        """
        Data module for the semantic scholar dataset.

        Parameters
        ----------
        percentage_missing : int
            The percentage of missing values to use.
        """
        super().__init__()
        # Selects which pair of pre-generated data files `setup` loads.
        self.percentage_missing = percentage_missing

    def setup(self, stage: str | None = None) -> None:
        """
        Set-up the semantic scholar data module by loading the data.

        Two files are read: the damaged input values and the true values of
        the simplices that are missing. From them a single simplicial complex
        is built and exported twice: ``train_data`` carries the damaged
        values, ``val_data`` carries the true values. Both share, per rank, a
        boolean auxiliary tensor that is ``True`` for simplices whose value is
        unknown.

        Parameters
        ----------
        stage : str, optional
            Not used.
        """
        # NOTE(review): allow_pickle=True unpickles arbitrary Python objects.
        # Only point these paths at trusted, locally generated files.
        damaged_simplices = np.load(
            f"data/S2AG/s2_3_collaboration_complex/150250_percentage_{self.percentage_missing}_input_damaged.npy",
            allow_pickle=True,
        )

        unknown_simplices = np.load(
            f"data/S2AG/s2_3_collaboration_complex/150250_percentage_{self.percentage_missing}_missing_values.npy",
            allow_pickle=True,
        )
        # Flatten the per-entry mappings into one lookup keyed by frozenset.
        unknown_simplices = {
            frozenset(simplex): data
            for simplices in unknown_simplices
            for simplex, data in simplices.items()
        }

        simplicial_complex = SimplicialComplex()
        for simplices in damaged_simplices:
            for simplex, data in simplices.items():
                # Normalise the key the same way the lookup table was built;
                # otherwise a non-frozenset key would never be found and the
                # simplex would silently be treated as known.
                key = frozenset(simplex)
                is_unknown = key in unknown_simplices
                simplicial_complex.add_simplex(
                    simplex,
                    damaged_data=data,
                    # Known simplices keep their (undamaged) input value.
                    true_data=unknown_simplices[key] if is_unknown else data,
                    is_unknown=is_unknown,
                )

        self.train_data = toponetx_to_data(
            simplicial_complex, "damaged_data", torch.float32
        )
        self.val_data = toponetx_to_data(simplicial_complex, "true_data", torch.float32)

        # Attach, for every rank, the mask of unknown simplices to both views.
        for rank in range(simplicial_complex.dim + 1):
            attributes = simplicial_complex.get_simplex_attributes("is_unknown", rank)
            unknown_mask = torch.zeros(simplicial_complex.shape[rank], dtype=torch.bool)
            for i, simplex in enumerate(simplicial_complex.skeleton(rank)):
                if attributes[simplex.elements]:
                    unknown_mask[i] = True

            self.train_data.set_aux_tensor(rank, unknown_mask)
            self.val_data.set_aux_tensor(rank, unknown_mask)

    def train_dataloader(self) -> DataLoader:
        """
        Return the training dataloader.

        Returns
        -------
        DataLoader
            The training dataloader.
        """
        # batch_size=None disables automatic batching: the loader yields the
        # single data object as-is.
        return DataLoader([self.train_data], batch_size=None)

    def val_dataloader(self) -> DataLoader:
        """
        Return the validation dataloader.

        Returns
        -------
        DataLoader
            The validation dataloader.
        """
        # batch_size=None disables automatic batching: the loader yields the
        # single data object as-is.
        return DataLoader([self.val_data], batch_size=None)


class SimplicialCRaWlLightning(LightningModule):
    """
    SCRaWl model for the semantic scholar dataset.

    Parameters
    ----------
    num_layers : int
        The number of SCRaWl layers to use.
    feature_sizes : list[int]
        The feature sizes of the simplicial complexes.
    walk_steps : int | tuple[int, int]
        The number of walk steps to use during training and validation. If a tuple
        is given, the first value is used for training and the second for
        validation.
    local_window_sizes : list[list[int]]
        The local window sizes for each layer.
    embedding_sizes : list[list[int]]
        The embedding sizes for each layer.
    kernel_sizes : list[int]
        The kernel sizes for each layer.
    pooling : PoolingMethod, default=PoolingMethod.MEAN
        The pooling method to use.
    dropout : float, default=0.0
        The dropout probability to use.
    batch_running_stats : bool, default=True
        Whether to use batch running statistics for batch normalization.
    """

    scrawl: SimplicialCRaWl

    def __init__(
        self,
        num_layers: int,
        feature_sizes: list[int],
        walk_steps: int | tuple[int, int],
        local_window_sizes: list[list[int]],
        embedding_sizes: list[list[int]],
        kernel_sizes: list[int],
        *,
        pooling: PoolingMethod = PoolingMethod.MEAN,
        dropout: float = 0.0,
        batch_running_stats: bool = True,
    ) -> None:
        """
        SCRaWl model for the semantic scholar dataset.

        Parameters
        ----------
        num_layers : int
            The number of SCRaWl layers to use.
        feature_sizes : list[int]
            The feature sizes of the simplicial complexes.
        walk_steps : int | tuple[int, int]
            The number of walk steps to use during training and validation. If a tuple
            is given, the first value is used for training and the second for
            validation.
        local_window_sizes : list[list[int]]
            The local window sizes for each layer.
        embedding_sizes : list[list[int]]
            The embedding sizes for each layer.
        kernel_sizes : list[int]
            The kernel sizes for each layer.
        pooling : PoolingMethod, default=PoolingMethod.MEAN
            The pooling method to use.
        dropout : float, default=0.0
            The dropout probability to use.
        batch_running_stats : bool, default=True
            Whether to use batch running statistics for batch normalization.
        """
        super().__init__()
        self.save_hyperparameters()

        if isinstance(walk_steps, tuple):
            self.walk_steps_train, self.walk_step_val = walk_steps
        else:
            self.walk_steps_train = self.walk_step_val = walk_steps

        self.scrawl = SimplicialCRaWl(
            num_layers,
            feature_sizes,
            local_window_sizes,
            embedding_sizes,
            kernel_sizes,
            pooling=pooling,
            dropout=dropout,
            batch_running_stats=batch_running_stats,
        )

        # One regression head per rank, mapping the last layer's embedding of
        # that rank to a single scalar.
        self.outputs = ModuleList(
            Sequential(
                Linear(embedding_sizes[-1][i], embedding_sizes[-1][i]),
                ReLU(),
                Linear(embedding_sizes[-1][i], 1),
            )
            for i in range(len(feature_sizes))
        )

        # metrics
        self.train_acc = RegressionAccuracy(0.05)
        self.val_acc = RegressionAccuracy(0.05)
        self.train_mae = torchmetrics.MeanAbsoluteError()
        self.val_mae = torchmetrics.MeanAbsoluteError()

        self.train_acc_per_rank = ModuleList(
            RegressionAccuracy(0.05) for _ in range(len(feature_sizes))
        )
        self.val_acc_per_rank = ModuleList(
            RegressionAccuracy(0.05) for _ in range(len(feature_sizes))
        )

    def forward(self, data: SimplicialData, walk_steps: int) -> list[torch.Tensor]:
        """
        Forward pass of the model.

        Parameters
        ----------
        data : SimplicialData
            The simplicial data.
        walk_steps : int
            The number of walk steps to use.

        Returns
        -------
        list[torch.Tensor]
            One prediction tensor per output head, in the order of
            ``self.outputs``.
        """
        embedding_data = self.scrawl(data, walk_steps)
        return [
            output_module(embedding_data[i])
            for i, output_module in enumerate(self.outputs)
        ]

    def configure_optimizers(self) -> dict[str, object]:
        """
        Configure the optimizer and learning rate scheduler.

        Returns
        -------
        dict[str, object]
            The Adam optimizer, a ``ReduceLROnPlateau`` scheduler and the name
            of the metric (``"val_loss"``) the scheduler monitors.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # `verbose` is intentionally not passed: it is deprecated in current
        # PyTorch releases. The learning rate is logged by the
        # LearningRateMonitor callback that `run` attaches to the trainer.
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.5, patience=10
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
            "monitor": "val_loss",
        }

    def _predict_masked(
        self, batch, walk_steps: int, *, unknown: bool
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """
        Run the model and pair predictions with targets on a subset of simplices.

        Parameters
        ----------
        batch : SimplicialData
            The batch of data.
        walk_steps : int
            The number of walk steps to use.
        unknown : bool
            If ``True``, keep the entries selected by the batch's auxiliary
            mask; if ``False``, keep the complementary entries.

        Returns
        -------
        list[tuple[torch.Tensor, torch.Tensor]]
            One ``(prediction, target)`` pair per model output.
        """
        pairs = []
        for i, prediction in enumerate(self(batch, walk_steps)):
            mask = batch.aux_tensor(i)
            if not unknown:
                mask = ~mask
            pairs.append((prediction[mask], batch[i][mask]))
        return pairs

    @staticmethod
    def _track_metrics(pairs, acc, mae, acc_per_rank) -> None:
        """
        Update the overall and the per-rank metrics with the given pairs.

        Parameters
        ----------
        pairs : list[tuple[torch.Tensor, torch.Tensor]]
            ``(prediction, target)`` pairs, one per model output.
        acc : torchmetrics.Metric
            Accuracy metric accumulated over all pairs.
        mae : torchmetrics.Metric
            Mean absolute error metric accumulated over all pairs.
        acc_per_rank : ModuleList
            Accuracy metrics, the i-th one being updated with the i-th pair.
        """
        for i, (prediction, target) in enumerate(pairs):
            acc(prediction, target)
            mae(prediction, target)

            if i < len(acc_per_rank):
                acc_per_rank[i](prediction, target)

    def training_step(self, batch, batch_idx):
        """
        Training step.

        The loss is the sum over all model outputs of the L1 loss, computed
        only on entries that the batch's auxiliary mask does not select.

        Parameters
        ----------
        batch : SimplicialData
            The batch of data.
        batch_idx : int
            Not used.

        Returns
        -------
        dict[str, torch.Tensor]
            The loss.
        """
        pairs = self._predict_masked(batch, self.walk_steps_train, unknown=False)
        loss = sum(F.l1_loss(prediction, target) for prediction, target in pairs)
        self.log("train_loss", loss)

        # metrics
        self._track_metrics(
            pairs, self.train_acc, self.train_mae, self.train_acc_per_rank
        )

        self.log(
            "train_acc", self.train_acc, prog_bar=True, on_step=True, on_epoch=True
        )
        self.log("train_mae", self.train_mae, on_step=True, on_epoch=True)
        for i, acc in enumerate(self.train_acc_per_rank):
            self.log(f"train_acc_rank_{i}", acc, on_step=False, on_epoch=True)

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """
        Validation step.

        The loss is the sum over all model outputs of the L1 loss, computed
        only on entries that the batch's auxiliary mask selects.

        Parameters
        ----------
        batch : SimplicialData
            The batch of data.
        batch_idx : int
            Not used.

        Returns
        -------
        dict[str, torch.Tensor]
            The loss.
        """
        pairs = self._predict_masked(batch, self.walk_step_val, unknown=True)
        loss = sum(F.l1_loss(prediction, target) for prediction, target in pairs)
        # The batch size is given explicitly rather than left for Lightning
        # to infer from the batch object.
        self.log("val_loss", loss, batch_size=1)

        # metrics
        self._track_metrics(pairs, self.val_acc, self.val_mae, self.val_acc_per_rank)

        self.log("val_acc", self.val_acc, on_step=True, on_epoch=True)
        self.log("val_mae", self.val_mae, on_step=True, on_epoch=True)
        for i, acc in enumerate(self.val_acc_per_rank):
            self.log(f"val_acc_rank_{i}", acc, on_step=False, on_epoch=True)

        return {"loss": loss}


def run(percentage_missing: int):  # numpydoc ignore=GL08
    """
    Train one SCRaWl model on the semantic scholar dataset.

    Parameters
    ----------
    percentage_missing : int
        The percentage of missing values; selects the data files to load.
    """
    datamodule = SemanticScholarDataModule(percentage_missing)

    # Seven entries: one per simplex rank.
    rank_feature_sizes = [1] * 7
    rank_kernel_sizes = [8] * 7
    model = SimplicialCRaWlLightning(
        num_layers=NUM_LAYERS,
        feature_sizes=rank_feature_sizes,
        walk_steps=5,
        local_window_sizes=LOCAL_WINDOW_SIZES,
        embedding_sizes=EMBEDDING_SIZES,
        kernel_sizes=rank_kernel_sizes,
        batch_running_stats=True,
    )

    logger = TensorBoardLogger(
        "lightning_logs",
        name=f"new_semantic_scholar_{datamodule.percentage_missing}",
    )

    # Stop as soon as the learning rate drops below the threshold; the huge
    # patience keeps the "no improvement" criterion from ever triggering.
    stop_on_low_lr = EarlyStopping(
        "lr-Adam",
        patience=10_000,
        mode="min",
        stopping_threshold=1e-5,
        check_on_train_epoch_end=True,
    )
    callbacks = [LearningRateMonitor(), stop_on_low_lr]

    trainer = Trainer(
        logger=logger,
        callbacks=callbacks,
        # The network improves only marginally afterwards and already has
        # almost perfect accuracy.
        max_epochs=200,
        accelerator="cpu",
        log_every_n_steps=1,
    )
    trainer.fit(model, datamodule)


if __name__ == "__main__":
    # Command line: the missing-value percentage (required) and how many
    # times the whole experiment is repeated (optional, defaults to once).
    cli = ArgumentParser()
    cli.add_argument("percentage_missing", type=int)
    cli.add_argument("--iterations", type=int, default=1)
    options = cli.parse_args()

    for _repetition in range(options.iterations):
        run(options.percentage_missing)
