from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any

import torch
from torch import Tensor, nn, optim
from torch.optim.lr_scheduler import (
    ConstantLR,
    ExponentialLR,
    LinearLR,
    LRScheduler,
    ReduceLROnPlateau,
    StepLR,
)
from torch.utils.data import DataLoader
from tqdm import tqdm

from compression_autoencoder.utils.early_stopping import EarlyStopping
from compression_autoencoder.utils.history import History
from compression_autoencoder.utils.losses import (
    BetaVAELoss,
    CustomLoss,
    GaussianKLDivLoss,
    GaussianRenyiAlphaDivLoss,
    SameSTDGaussianRenyiAlphaDivLoss,
)


class CustomModule(nn.Module, ABC):
    """Base class for a custom PyTorch module.

    Provides utility functions for building and traversing layers, together
    with name-based registries for the supported activation functions, loss
    functions and learning rate schedulers.
    """

    # Supported activation functions, looked up by name. The values are module
    # instances created once when the class is defined, so every model that
    # selects an activation by name receives the same instance.
    activation_functions: dict[str, nn.Module] = {
        "relu": nn.ReLU(),
        "elu": nn.ELU(),
        "sigmoid": nn.Sigmoid(),
        "softmax": nn.Softmax(dim=-1),
        "tanh": nn.Tanh(),
    }
    # Supported loss functions, looked up by name. As above, the values are
    # instances shared by every model that selects a loss by name.
    loss_functions: dict[str, nn.Module | CustomLoss] = {
        "mse": nn.MSELoss(),
        "renyi2_same_std_gaussian": SameSTDGaussianRenyiAlphaDivLoss(alpha=2),
        "kl_same_std_gaussian": SameSTDGaussianRenyiAlphaDivLoss(alpha=1),
        "renyi2_gaussian": GaussianRenyiAlphaDivLoss(alpha=2),
        "kl_gaussian": GaussianKLDivLoss(),
        "vae": BetaVAELoss(beta=1.0),
        "categorical_crossentropy": nn.CrossEntropyLoss(),
    }
    # Supported learning rate schedulers, looked up by name. The values are
    # classes (not instances): `setup_scheduler` instantiates them with the
    # model optimizer and the user-provided keyword arguments.
    lr_schedulers: dict[str, type[LRScheduler]] = {
        "step": StepLR,
        "exponential": ExponentialLR,
        "linear": LinearLR,
        "constant": ConstantLR,
        "plateau": ReduceLROnPlateau,
    }

    def __init__(
        self,
        activation_func: nn.Module | str = "relu",
        loss_func: nn.Module | CustomLoss | str = "mse",
        last_activation_func: nn.Module | str | None = None,
        input_scaler: nn.Module | None = None,
        early_stopping_kwargs: dict[str, Any] | None = None,
        device: str | torch.device | None = None,
    ) -> None:
        """
        Initialize the state shared by all custom modules.

        Args:
            activation_func (nn.Module | str): the activation function, either
            as a module or as a key of `activation_functions`.
            loss_func (nn.Module | CustomLoss | str): the loss function, either
            as an object or as a key of `loss_functions`.
            last_activation_func (nn.Module | str | None): the activation
            function of the last layer. None or "linear" disable it.
            input_scaler (nn.Module | None): an optional module applied to the
            inputs by `scale_input`. It is moved to `device`.
            early_stopping_kwargs (dict[str, Any] | None): the keyword
            arguments used to build the early stopping. If None, early
            stopping is disabled.
            device (str | torch.device | None): the device of the model.
            Defaults to "cpu" when None.
        """
        super().__init__()
        self.device = torch.device("cpu" if device is None else device)

        activation_f = self._setup_activation_func(activation_func)
        # With require=True the setup never returns None; the check only
        # narrows the type before the module is registered.
        if activation_f is not None:
            self.activation_func = activation_f
        self.last_activation_func = self._setup_activation_func(
            last_activation_func, require=False
        )
        self.loss_func = self._setup_loss_func(loss_func)

        # Compare against None explicitly: container modules such as
        # nn.Sequential define __len__, so an empty container is falsy and
        # would otherwise be silently replaced by None.
        self.input_scaler = (
            input_scaler.to(self.device) if input_scaler is not None else None
        )
        self.early_stopping = self._setup_early_stopping(early_stopping_kwargs)

        # The optimizer is only declared here; it is not assigned by this
        # constructor.
        self.optimizer: optim.Optimizer
        self.scheduler: LRScheduler | None = None

    def setup_scheduler(self, kwargs: dict[str, Any] | None) -> LRScheduler | None:
        """
        Setup the learning rate scheduler. The kwargs must contain a key
        "type" that specifies the type of scheduler to use.

        The given dictionary is not modified, so the same configuration can be
        reused to build several schedulers.

        Args:
            kwargs (dict[str, Any] | None): the keyword arguments to pass to the
            scheduler. If None, no scheduler is created.

        Returns:
            LRScheduler | None: the learning rate scheduler built on
            `self.optimizer`. If None, no scheduler is created.

        Raises:
            KeyError: if `kwargs` has no "type" key.
        """
        if kwargs is None:
            return None
        # Work on a shallow copy: popping "type" from the caller's dictionary
        # would make a second call with the same configuration fail.
        options = dict(kwargs)
        scheduler_type = options.pop("type")  # Raises an error if not found
        assert scheduler_type in self.lr_schedulers, (
            f"Scheduler {scheduler_type} not supported."
        )
        return self.lr_schedulers[scheduler_type](self.optimizer, **options)

    @staticmethod
    def _setup_early_stopping(kwargs: dict[str, Any] | None) -> EarlyStopping | None:
        """
        Build the early stopping helper from its keyword arguments.

        Args:
            kwargs (dict[str, Any] | None): the keyword arguments forwarded to
            `EarlyStopping`. None disables early stopping.

        Returns:
            EarlyStopping | None: the early stopping object, or None when
            `kwargs` is None.
        """
        return EarlyStopping(**kwargs) if kwargs is not None else None

    @staticmethod
    def _setup_activation_func(
        func: nn.Module | str | None, require: bool = True
    ) -> nn.Module | None:
        """
        Resolve an activation function given as a module, a name or None.

        Args:
            func (nn.Module | str | None): the activation function. A module
            is returned as is, a string is looked up in
            `activation_functions`, and None means "no activation".
            require (bool): whether an activation function is mandatory. When
            False, both None and the name "linear" resolve to None.

        Returns:
            nn.Module | None: the resolved activation function, or None when
            no activation is requested and `require` is False.

        Raises:
            ValueError: if `func` is None while `require` is True.
            TypeError: if `func` is neither None, a string nor a module.
        """
        # Ready-made modules need no lookup
        if isinstance(func, nn.Module):
            return func

        if isinstance(func, str):
            # "linear" stands for the identity, which is only accepted when
            # the activation is optional
            if not require and func == "linear":
                return None
            registry = CustomModule.activation_functions
            assert func in registry, f"Activation function {func} not supported."
            return registry[func]

        if func is not None:
            raise TypeError(f"Activation function {func} not supported.")
        if require:
            raise ValueError("Activation function must be provided.")
        return None

    @staticmethod
    def _setup_loss_func(func: nn.Module | CustomLoss | str) -> nn.Module | CustomLoss:
        """
        Resolve a loss function given either as an object or by name.

        Args:
            func (nn.Module | CustomLoss | str): the loss function. A string
            is looked up in `loss_functions`; a module or custom loss is
            returned as is.

        Returns:
            nn.Module | CustomLoss: the resolved loss function.

        Raises:
            TypeError: if `func` is not a string, a module or a custom loss.
        """
        if not isinstance(func, str):
            if isinstance(func, (nn.Module, CustomLoss)):
                return func
            raise TypeError(f"Loss function {func} not supported.")

        known_losses = CustomModule.loss_functions
        assert func in known_losses, f"Loss function {func} not supported."
        return known_losses[func]

    def get_last_lr(self) -> float:
        """
        Return the current learning rate.

        Only the first value is reported: the first entry returned by the
        scheduler when one is set, otherwise the learning rate of the first
        parameter group of the optimizer.

        Returns:
            float: the last learning rate.
        """
        scheduler = self.scheduler
        if scheduler is not None:
            return scheduler.get_last_lr()[0]
        first_group = self.optimizer.param_groups[0]
        return first_group["lr"]

    def step_scheduler(self, loss: nn.Module | float, epoch: int | None = None) -> None:
        """
        Advance the learning rate scheduler by one step, if there is one.

        Args:
            loss (nn.Module | float): the monitored loss. It is only used by
            ReduceLROnPlateau, which receives it as its metric.
            epoch (int | None, optional): the current epoch, forwarded to
            every scheduler other than ReduceLROnPlateau. Defaults to None.
        """
        scheduler = self.scheduler
        if scheduler is None:
            return

        if not isinstance(scheduler, ReduceLROnPlateau):
            # Every other scheduler is stepped with the epoch.
            # NOTE(review): recent PyTorch releases appear to deprecate the
            # `epoch` argument of `step` - confirm against the pinned version.
            scheduler.step(epoch)
            return

        # ReduceLROnPlateau decides from the monitored loss
        scheduler.step(loss)  # type:ignore

    def step_early_stopping(self, val_loss: float) -> bool:
        """
        Perform a step of the early stopping if it exists.

        Args:
            val_loss (float): the validation loss to pass to the early stopping.

        Returns:
            bool: True if the training should be stopped, False otherwise.
            Always False when no early stopping is configured.
        """
        if self.early_stopping is None:
            return False
        return self.early_stopping.step(val_loss)

    def step_ealy_stopping(self, val_loss: float) -> bool:
        """
        Backward-compatible alias of `step_early_stopping`.

        The misspelled name is kept so that existing callers keep working;
        new code should call `step_early_stopping`.

        Args:
            val_loss (float): the validation loss to pass to the early stopping.

        Returns:
            bool: True if the training should be stopped, False otherwise.
        """
        return self.step_early_stopping(val_loss)

    def count_params(self) -> tuple[int, int]:
        """
        Count the parameters of the model in a single pass.

        Returns:
            tuple[int, int]: the total number of parameter elements, and the
            number of those belonging to parameters that require gradients.
        """
        total = 0
        trainable = 0
        for param in self.parameters():
            size = param.numel()
            total += size
            if param.requires_grad:
                trainable += size
        return total, trainable

    def scale_input(self, x: Tensor) -> Tensor:
        """
        Apply the input scaler to `x` when one is configured.

        Args:
            x (Tensor): the input to scale.

        Returns:
            Tensor: the scaled input, or `x` itself when there is no scaler.
        """
        scaler = self.input_scaler
        return x if scaler is None else scaler(x)

    @staticmethod
    def _forward_layers(
        x: Tensor,
        layers: nn.ModuleList,
        activation_func: nn.Module,
        activate_last_layer: bool = False,
    ) -> Tensor:
        """
        Run `x` through `layers`, with the activation between consecutive
        layers. The output of the last layer is only activated when
        `activate_last_layer` is True.

        Args:
            x (Tensor): the input.
            layers (nn.ModuleList): the layers to apply, in order.
            activation_func (nn.Module): the activation function.
            activate_last_layer (bool): whether to also activate the output
            of the last layer.

        Returns:
            Tensor: the output.
        """
        out = x
        for position, layer in enumerate(layers):
            # The activation belongs to the previous layer's output, so the
            # raw input (position 0) is passed through untouched.
            out = layer(out if position == 0 else activation_func(out))

        return activation_func(out) if activate_last_layer else out

    @staticmethod
    def _build_layers(
        layer_shapes: list[tuple[int, int, bool]],
        device: str | torch.device | None = None,
    ) -> nn.ModuleList:
        """
        Build one Linear layer per entry of `layer_shapes`.

        The layers are wrapped in a ModuleList so that they are registered as
        submodules of whichever module stores the list.

        Args:
            layer_shapes (list[tuple[int, int, bool]]): for each layer, its
            input size, its output size, and whether it has a bias.
            device (str | torch.device | None): the device forwarded to each
            Linear layer.

        Returns:
            nn.ModuleList: the list of layers.
        """
        linear_layers = [
            nn.Linear(n_in, n_out, bias=has_bias, device=device)
            for n_in, n_out, has_bias in layer_shapes
        ]
        return nn.ModuleList(linear_layers)

    @staticmethod
    def _freeze_layers(layers: nn.ModuleList, clear_grad: bool = False) -> None:
        """
        Stop the parameters of `layers` from requiring gradients.

        Args:
            layers (nn.ModuleList): the layers to freeze.
            clear_grad (bool, optional): whether to also reset the stored
            gradients of the parameters to None. Defaults to False.
        """
        frozen = list(layers.parameters())
        for tensor in frozen:
            tensor.requires_grad = False
        if clear_grad:
            for tensor in frozen:
                tensor.grad = None

    @staticmethod
    def _unfreeze_layers(layers: nn.ModuleList) -> None:
        """
        Make every parameter of `layers` require gradients again.

        Args:
            layers (nn.ModuleList): the layers to unfreeze.
        """
        for tensor in layers.parameters():
            tensor.requires_grad = True

    @staticmethod
    def _extract_weights(layers: nn.ModuleList) -> Tensor:
        """
        Extract the weights from the layers into a single tensor. Each
        parameter is flattened, and the results are concatenated in the order
        given by `layers.parameters()` (for Linear layers with a bias:
        [W1, B1, W2, B2, ...]).

        Args:
            layers (nn.ModuleList): the layers to extract the weights from.

        Returns:
            Tensor: the 1D tensor holding all the model parameters.
        """
        # reshape(-1) flattens like view(-1) for contiguous parameters, and
        # also works for non-contiguous ones, where view(-1) raises.
        return torch.cat([param.reshape(-1) for param in layers.parameters()])

    @staticmethod
    def _inject_weights(
        layers: nn.ModuleList,
        weights: Tensor,
        layer_shapes: list[tuple[int, int, bool]],
    ) -> None:
        """
        Copy a flattened parameter tensor into the weights and biases of the
        layers, consuming it layer by layer (weight first, then bias).

        Args:
            layers (nn.ModuleList): the Linear layers to overwrite.
            weights (Tensor): the flattened weights tensor (1D).
            layer_shapes (list[tuple[int, int, bool]]): for each layer, its
            input size, its output size, and whether it has a bias.
        """
        cursor = 0
        for layer, (n_in, n_out, has_bias) in zip(layers, layer_shapes, strict=False):
            assert isinstance(layer, nn.Linear)

            # The weight matrix occupies the next n_out * n_in values
            weight_end = cursor + n_out * n_in
            layer.weight.data.copy_(weights[cursor:weight_end].reshape(n_out, n_in))
            cursor = weight_end

            if not has_bias:
                continue

            # The bias, when present, follows its weight matrix
            bias_end = cursor + n_out
            layer.bias.data.copy_(weights[cursor:bias_end])
            cursor = bias_end

    @staticmethod
    def _count_params_from_shape(layer_shapes: list[tuple[int, int, bool]]) -> int:
        """
        Compute how many parameters the layers described by `layer_shapes`
        hold: `out * in` for each weight matrix, plus `out` for each bias.

        Args:
            layer_shapes (list[tuple[int, int, bool]]): for each layer, its
            input size, its output size, and whether it has a bias.

        Returns:
            int: the total number of parameters.
        """
        return sum(
            n_out * n_in + (n_out if has_bias else 0)
            for n_in, n_out, has_bias in layer_shapes
        )

    @staticmethod
    def _forward_with_weights(
        x: Tensor,
        weights: Tensor,
        layer_shapes: list[tuple[int, int, bool]],
        activation_func: nn.Module,
        activate_last_layer: bool = False,
    ) -> Tensor:
        """
        Flows the input `x` through a batch of models defined by `weights`.

        The weights and biases for each model in the batch are extracted from
        the `weights` tensor. `torch.bmm` is used for the core computation to
        support per-sample weights, which is not possible with a standard
        `nn.Linear` layer.

        The features are kept on the last dimension throughout, exactly as in
        `_forward_layers`, so that activation functions acting on a specific
        dimension (e.g. `nn.Softmax(dim=-1)`) normalize over the features and
        not over the sequence.

        This function supports symmetrical broadcasting:
        - **Many-to-Many**: `x` and `weights` have the same batch size `B`.
        - **One-to-Many**: A single model (`weights` B=1) is applied to a batch
        of inputs (`x` B>1).
        - **Many-to-One**: A single input (`x` B=1) is fed to a batch of models
        (`weights` B>1).

        It also handles 1D, 2D, and 3D inputs for `x`, and 1D and 2D inputs for
        `weights`.

        Args:
            x (Tensor): The input tensor. Supported shapes are `[B, S, in_f]`
                (batched sequence), `[B, in_f]` (batched state), or `[in_f]`
                (single unbatched state). Its batch dimension `B` can be 1 to
                enable broadcasting.
            weights (Tensor): The tensor of flattened model parameters, with
                shape `[B, P]` or `[P]`, where `P` is the total number of
                parameters. Its batch dimension `B` can be 1 to enable
                broadcasting.
            layer_shapes (list[tuple[int, int, bool]]): The layers' input and
                output sizes, and whether a bias is used.
            activation_func (nn.Module): The activation function to apply
                between layers.
            activate_last_layer (bool, optional): Whether to apply the
                activation function to the final layer's output. Defaults to
                False.

        Returns:
            Tensor: The network output. The batch dimension will be the result
            of broadcasting the batch sizes of `x` and `weights`. The remaining
            shape will be `[S, out_f]` or `[out_f]` depending on the input `x`.

        Raises:
            ValueError: If the batch sizes of `x` and `weights` differ and
                neither of them is 1.
        """
        # Ensure x and weights are at least 2D
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if weights.dim() == 1:
            weights = weights.unsqueeze(0)

        b_x, b_w = x.shape[0], weights.shape[0]

        # Negotiate batch sizes and handle broadcasting if needed
        if b_x != b_w:
            if b_x == 1:  # Broadcast input
                expand_shape = [b_w] + [-1] * (x.dim() - 1)
                x = x.expand(*expand_shape)
            elif b_w == 1:  # Broadcast weights
                weights = weights.expand(b_x, -1)
            else:  # In this case, the batch sizes are incompatible
                raise ValueError(
                    f"Incompatible batch sizes for x and weights: {b_x} and {b_w}. "
                )

        # If the input is 2D, we add a dummy sequence dimension of size 1.
        # This allows us to handle both 2D and 3D inputs uniformly.
        was_2d = x.dim() == 2
        if was_2d:
            x = x.unsqueeze(1)  # Shape becomes [B, 1, in_f]

        idx = 0
        for i, (in_f, out_f, use_bias) in enumerate(layer_shapes):
            # Linear stores its weight as [out_f, in_f], and the flattened
            # parameters keep that layout, so the slice is reshaped to
            # [B, out_f, in_f].
            w_len = out_f * in_f
            weight = weights[:, idx : idx + w_len].reshape(-1, out_f, in_f)
            idx += w_len

            bias = None
            if use_bias:
                bias = weights[:, idx : idx + out_f]
                idx += out_f

            # This avoids applying the activation to the input. Activating
            # BEFORE the layer also prevents double activation of the last
            # layer. x is [B, S, features] here, so the features are on the
            # last dimension as the activation expects.
            if i > 0:
                x = activation_func(x)

            # Apply the same layer of all the models to the input in parallel
            # using batch matrix multiplication. BMM takes two matrices of
            # shape [B, N, M] and [B, M, K], and returns a matrix of shape
            # [B, N, K]. x is [B, S, in_f] and the transposed weight is
            # [B, in_f, out_f], so the output is [B, S, out_f].
            x = torch.bmm(x, weight.transpose(1, 2))
            if bias is not None:
                # The bias [B, out_f] is unsqueezed to [B, 1, out_f]
                # and broadcasted over the S dimension.
                x = x + bias.unsqueeze(1)

        if activate_last_layer:
            x = activation_func(x)

        # If the input was 2D, we remove the dummy sequence dimension.
        if was_2d:
            x = x.squeeze(1)

        return x

    def fit(  # noqa: C901
        self,
        train_loader: DataLoader,
        val_loader: DataLoader | None = None,
        num_epochs: int = 100,
        verbose: int = 1,
    ) -> History:
        """
        Train the model.

        After each epoch the learning rate scheduler (if any) is stepped with
        the validation loss or, when no validation loader is given, with the
        training loss of the epoch. Early stopping is only evaluated when a
        validation loader is given.

        Args:
            train_loader (DataLoader): the training data loader.
            val_loader (DataLoader | None): the validation data loader.
            num_epochs (int): the number of epochs to train for.
            verbose (int): level of verbosity. If 0, no output is printed.
                If 1, track the progress of training with a progress bar, if
                2, print the training and validation losses at each epoch.

        Returns:
            History: the history of the training process.
        """
        history = History({"train_loss": [], "val_loss": []})

        r_epochs: range | tqdm[int]
        if verbose != 1:
            r_epochs = range(num_epochs)
        else:
            r_epochs = tqdm(range(num_epochs), desc="Training", unit="epoch")

        for epoch in r_epochs:
            if verbose == 2:
                print(f"Epoch {epoch + 1}/{num_epochs} ", end="")

            train_losses = self.train_epoch(train_loader)
            if isinstance(train_losses, float):
                history.append("train_loss", train_losses)
            else:
                history.append_from_dict(train_losses)

            if verbose == 2:
                print(f"train_loss: {history['train_loss'][-1]:.6f} ", end="")

            # The loss driving the scheduler; None means "do not step"
            monitored_loss: float | None = None

            if val_loader is not None:
                val_losses = self.test_epoch(val_loader)
                if isinstance(val_losses, float):
                    val_loss = val_losses
                    history.append("val_loss", val_losses)
                else:
                    val_loss = val_losses["val_loss"]
                    history.append_from_dict(val_losses)
                monitored_loss = val_loss

                if verbose == 2:
                    print(f"val_loss: {history['val_loss'][-1]:.6f} ", end="")

                if self.step_ealy_stopping(val_loss):
                    if verbose > 0:
                        print(f"\nEarly stopping at epoch {epoch + 1}")
                    break
            elif self.scheduler is not None:
                # Without validation data the scheduler must still advance,
                # otherwise a configured schedule would silently never run.
                # The training loss recorded for this epoch is monitored.
                monitored_loss = history["train_loss"][-1]

            if monitored_loss is not None:
                old_lr = self.get_last_lr()
                self.step_scheduler(monitored_loss, epoch)
                new_lr = self.get_last_lr()
                if verbose > 0 and old_lr != new_lr:
                    print(f"lr: {old_lr:.6f} -> {new_lr:.6f} ", end="")

            if verbose == 2:
                print()  # new line

        return history

    def save(self, path: str | Path) -> None:
        """
        Write the state dict of this model to `path`.

        Args:
            path (str | Path): the destination file.
        """
        state = self.state_dict()
        torch.save(state, path)

    def load(self, path: str | Path) -> None:
        """
        Restore the state dict of this model from the file at `path`.

        The file is read with `weights_only=True` and its tensors are mapped
        to the device of this model.

        Args:
            path (str | Path): the file to read the state dict from.
        """
        state = torch.load(path, weights_only=True, map_location=self.device)
        self.load_state_dict(state)

    @abstractmethod
    def train_epoch(self, data_loader: DataLoader) -> float | dict[str, float]:
        """
        Run one training epoch over `data_loader`.

        Subclasses must override this method. The base implementation only
        switches the module to training mode and returns a placeholder loss.

        Args:
            data_loader (DataLoader): the train data loader.

        Returns:
            float | dict[str, float]: the training loss of the epoch, or a
            dictionary of losses when several are tracked. Such a dictionary
            must contain the key "train_loss".
        """
        self.train()
        placeholder_loss = 0.0
        return placeholder_loss

    @abstractmethod
    def test_epoch(self, data_loader: DataLoader) -> float | dict[str, float]:
        """
        Run one evaluation epoch over `data_loader`.

        Subclasses must override this method. The base implementation only
        switches the module to evaluation mode and returns a placeholder
        loss.

        Args:
            data_loader (DataLoader): the test data loader.

        Returns:
            float | dict[str, float]: the val/test loss of the epoch, or a
            dictionary of losses when several are tracked. Such a dictionary
            must contain the key "val_loss".
        """
        self.eval()
        placeholder_loss = 0.0
        return placeholder_loss

    @abstractmethod
    def forward(self, x: Tensor) -> Tensor:
        """
        Flow the input through the model, and return the output.

        This method is abstract and has no base implementation: every
        concrete subclass must provide its own.

        Args:
            x (Tensor): the input.

        Returns:
            Tensor: the output.
        """
