from typing import Any

import torch
from torch import Tensor, nn, optim

from compression_autoencoder.utils.custom_module import CustomModule
from compression_autoencoder.utils.losses import CustomLoss


class Autoencoder(CustomModule):
    """Base autoencoder made of an encoder stack and a decoder stack.

    On construction the class builds both layer stacks, creates an Adam
    optimizer over ``self.parameters()`` and sets up a scheduler. It also
    exposes helpers to encode, decode and freeze/unfreeze the decoder.
    """

    def __init__(
        self,
        input_dim: int,
        latent_dim: int,
        encoder_layers_shapes: list[tuple[int, int, bool]],
        decoder_layers_shapes: list[tuple[int, int, bool]],
        activation_func: nn.Module | str = "relu",
        loss_func: nn.Module | CustomLoss | str = "mse",
        learning_rate: float = 1e-3,
        scheduler_kwargs: dict[str, Any] | None = None,
        early_stopping_kwargs: dict[str, Any] | None = None,
        input_scaler: nn.Module | None = None,
        device: str | torch.device | None = None,
    ) -> None:
        """
        Build the encoder/decoder stacks, the optimizer and the scheduler.

        Args:
            input_dim (int): stored as ``self.input_dim``.
            latent_dim (int): stored as ``self.latent_dim``.
            encoder_layers_shapes (list[tuple[int, int, bool]]): layer
                specification handed to ``_build_layers`` for the encoder.
                NOTE(review): presumably (in_features, out_features, flag);
                confirm against ``CustomModule._build_layers``.
            decoder_layers_shapes (list[tuple[int, int, bool]]): layer
                specification handed to ``_build_layers`` for the decoder.
            activation_func (nn.Module | str): forwarded to ``CustomModule``.
            loss_func (nn.Module | CustomLoss | str): forwarded to
                ``CustomModule``.
            learning_rate (float): learning rate of the Adam optimizer.
            scheduler_kwargs (dict[str, Any] | None): passed to
                ``setup_scheduler``.
            early_stopping_kwargs (dict[str, Any] | None): forwarded to
                ``CustomModule``.
            input_scaler (nn.Module | None): forwarded to ``CustomModule``.
            device (str | torch.device | None): forwarded to ``CustomModule``.
        """
        super().__init__(
            activation_func=activation_func,
            loss_func=loss_func,
            input_scaler=input_scaler,
            early_stopping_kwargs=early_stopping_kwargs,
            device=device,
        )
        self.input_dim, self.latent_dim = input_dim, latent_dim

        # Encoder first, then decoder: keeps the attribute creation order
        # (and therefore the order seen by ``self.parameters()``) unchanged.
        make_stack = self._build_layers
        self.encoder_layers = make_stack(encoder_layers_shapes, self.device)
        self.decoder_layers = make_stack(decoder_layers_shapes, self.device)

        # The optimizer is created only after both stacks exist.
        model_params = self.parameters()
        self.optimizer = optim.Adam(model_params, lr=learning_rate)
        self.scheduler = self.setup_scheduler(scheduler_kwargs)

    def encode(self, x: Tensor) -> Tensor:
        """
        Run the input through the encoder layers.

        Args:
            x (Tensor): the input.

        Returns:
            Tensor: the latent representation.
        """
        latent = self._forward_layers(
            x,
            self.encoder_layers,
            self.activation_func,
        )
        return latent

    def decode(self, x: Tensor) -> Tensor:
        """
        Run a latent representation through the decoder layers.

        Args:
            x (Tensor): the latent representation.

        Returns:
            Tensor: the reconstructed input.
        """
        reconstruction = self._forward_layers(
            x,
            self.decoder_layers,
            self.activation_func,
        )
        return reconstruction

    def freeze_decoder(self) -> None:
        """
        Freeze the decoder layers so that they are not updated during
        optimization.
        """
        decoder = self.decoder_layers
        self._freeze_layers(decoder)

    def unfreeze_decoder(self) -> None:
        """
        Unfreeze the decoder layers so that they can be updated again
        during training.
        """
        decoder = self.decoder_layers
        self._unfreeze_layers(decoder)
