import torch
from torch import Tensor, nn, optim
from torch.utils.data import DataLoader

from compression_autoencoder.utils.custom_module import CustomModule
from compression_autoencoder.utils.losses import CustomLoss


class Policy(CustomModule):
    """A policy network whose layers are built from a list of layer shapes.

    Besides a normal forward pass through its own parameters, the policy can
    run a forward pass using an externally supplied flat weight vector (or a
    batch of them), and it can be fitted to fixed targets with its own Adam
    optimizer.
    """

    def __init__(
        self,
        layer_shapes: list[tuple[int, int, bool]],
        activation_func: nn.Module | str = "relu",
        last_activation_func: nn.Module | str | None = None,
        loss_func: nn.Module | CustomLoss | str = "mse",
        learning_rate: float = 0.01,
        input_scaler: nn.Module | None = None,
        weight_decay: float = 0,
        device: str | torch.device | None = None,
    ) -> None:
        """
        Build the layers and the Adam optimizer of the policy.

        Args:
            layer_shapes (list[tuple[int, int, bool]]): one tuple per layer,
            consumed by `_build_layers`. Presumably
            `(in_features, out_features, bias)`; confirm against
            `CustomModule._build_layers`.
            activation_func (nn.Module | str): activation (or its name) passed
            to `CustomModule`. Defaults to "relu".
            last_activation_func (nn.Module | str | None): optional activation
            applied to the final output. Defaults to None.
            loss_func (nn.Module | CustomLoss | str): loss (or its name)
            passed to `CustomModule`. Defaults to "mse".
            learning_rate (float): Adam learning rate. Defaults to 0.01.
            input_scaler (nn.Module | None): optional scaler passed to
            `CustomModule`. Defaults to None.
            weight_decay (float): Adam weight decay. Defaults to 0.
            device (str | torch.device | None): device passed to
            `CustomModule`. Defaults to None.
        """
        super().__init__(
            activation_func=activation_func,
            loss_func=loss_func,
            last_activation_func=last_activation_func,
            input_scaler=input_scaler,
            device=device,
        )
        self.layer_shapes = layer_shapes
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        # Used by `forward` to validate externally supplied weight vectors.
        self.num_expected_params = self._count_params_from_shape(layer_shapes)
        self.layers = self._build_layers(layer_shapes, self.device)
        self.optimizer = optim.Adam(
            self.parameters(), lr=learning_rate, weight_decay=weight_decay
        )

    @staticmethod
    def init_like(other: "Policy") -> "Policy":
        """Initialize a new Policy with the same configuration as another
        Policy.

        The new policy builds its own layers (the weights of `other` are NOT
        copied). It reuses the layer shapes, activations, loss, optimizer
        settings (learning rate and weight decay) and device of `other`.

        Args:
            other (Policy): the Policy to copy the configuration from.

        Returns:
            Policy: a new Policy with the same configuration as `other`.
        """
        # NOTE(review): the input scaler is not carried over; confirm whether
        # the new policy should share `other`'s scaler.
        return Policy(
            layer_shapes=other.layer_shapes,
            activation_func=other.activation_func,
            last_activation_func=other.last_activation_func,
            loss_func=other.loss_func,
            learning_rate=other.learning_rate,
            weight_decay=other.weight_decay,
            device=other.device,
        )

    def extract_weights(self) -> Tensor:
        """
        Extract the weights from the layers into a single tensor. The flattened
        weights are concatenated into a single tensor [W1, B1, W2, B2, ...].

        Returns:
            Tensor: the tensor holding all the model parameters.
        """
        return self._extract_weights(self.layers)

    def inject_weights(self, weights: Tensor) -> None:
        """
        Inject the weights into the layers from a single tensor. The flattened
        weights are expected to be in the same order as extracted by
        `extract_weights`.

        Args:
            weights (Tensor): the tensor holding the model parameters.
            It is assumed to have shape `[P]`, where `P` is the number of
            parameters.
        """
        self._inject_weights(self.layers, weights, self.layer_shapes)

    def forward(self, x: Tensor, weights: Tensor | None = None) -> Tensor:  # type: ignore
        """
        Flow the input through the policy. If weights are provided, they are
        used instead of the actual model parameters. This effectively runs
        many forward passes in parallel on many models with the same
        architecture.

        Args:
            x (Tensor): the input.
            weights (Tensor | None): the tensor holding the models' parameters.
            It is assumed to have shape `[B, P]` or `[P]`, where `B` is the
            batch size and `P` is the number of parameters. If None, the
            actual model parameters are used. Defaults to None.

        Returns:
            Tensor: the output of the policy.

        Raises:
            AssertionError: if the last dimension of `weights` does not match
            the number of parameters implied by `layer_shapes`.
        """
        x = self.scale_input(x)
        if weights is None:  # Normal forward pass
            out = self._forward_layers(x, self.layers, self.activation_func)
        else:
            # Parallel forward pass with externally supplied weights
            assert weights.shape[-1] == self.num_expected_params, (
                f"Weights tensor has shape {weights.shape}, "
                f"but the last dimension should be "
                f"{self.num_expected_params}."
            )
            out = self._forward_with_weights(
                x, weights, self.layer_shapes, self.activation_func
            )

        if self.last_activation_func is not None:
            out = self.last_activation_func(out)
        return out

    def predict(
        self, x: Tensor, weights: Tensor | None = None, deterministic: bool = False
    ) -> Tensor:
        """
        Make a prediction using the policy.

        Args:
            x (Tensor): the input tensor.
            weights (Tensor | None): the weights to use for the prediction.
            If None, the policy's own weights are used. Defaults to None.
            deterministic (bool): accepted for interface compatibility. This
            policy is deterministic, so the flag is currently ignored.
            Defaults to False.

        Returns:
            Tensor: the predicted output (same as `forward(x, weights)`).
        """
        return self.forward(x, weights)

    def train_epoch(self, data_loader: DataLoader) -> float | dict[str, float]:
        """Run one training epoch; delegates to `CustomModule.train_epoch`."""
        return super().train_epoch(data_loader)

    def test_epoch(self, data_loader: DataLoader) -> float | dict[str, float]:
        """Run one evaluation epoch; delegates to `CustomModule.test_epoch`."""
        return super().test_epoch(data_loader)

    def should_stop(
        self,
        predictions: Tensor,
        epoch: int,
        targets: Tensor,
        epochs: int | None = None,
        epsilons: Tensor | None = None,
    ) -> bool:
        """
        Check if the training should stop. Either the number of epochs
        is reached or the values are ALL epsilon-close to the targets.

        Args:
            predictions (Tensor): the predictions of the model.
            epoch (int): the current epoch.
            targets (Tensor): the target tensor.
            epochs (int | None): the number of epochs to train for. Defaults to
            None.
            epsilons (Tensor | None): the tolerance(s) for the element-wise
            absolute error; must broadcast against `predictions - targets`.
            Defaults to None.

        Returns:
            bool: True if the training should stop, False otherwise.
        """
        if epochs is not None and epoch >= epochs:
            return True
        if epsilons is not None:
            with torch.no_grad():
                return bool(
                    torch.all(torch.abs(predictions - targets) <= epsilons).item()
                )
        return False

    def fit_on_targets(
        self,
        targets: Tensor,
        epochs: int | None = None,
        epsilons: Tensor | None = None,
        threshold: float | None = None,
    ) -> tuple[int, Tensor, float]:
        """
        Train the policy on a target, using gradient descent. The input is a
        single 1, placed on the policy's device.

        At least one of epochs, epsilons, and threshold must be provided.
        If all are None, the function raises an error. Otherwise, it stops
        when the first of these conditions is met:
        - the number of epochs is reached
        - the values are ALL epsilon-close to the targets
        - the loss is below the threshold
        This works with an arbitrary number of targets and any loss function.
        Because the input has a first dimension of 1, `targets` must also
        have a first dimension of 1.

        Args:
            targets (Tensor): the target tensor.
            epochs (int | None): the number of epochs to train for. Defaults to
            None.
            epsilons (Tensor | None): the element-wise tolerance(s) to train
            for. Defaults to None.
            threshold (float | None): the loss threshold to train for. Defaults
            to None.

        Returns:
            tuple[int, Tensor, float]: the number of epochs trained, the final
            absolute errors, and the final loss.

        Raises:
            ValueError: if no stopping criterion is given, or if the first
            dimension of `targets` is not 1.
        """
        return self.fit_on_dataset(
            # Build the constant input on the policy's device so that
            # GPU-resident policies work.
            features=torch.tensor([[1.0]], device=self.device),
            targets=targets,
            epochs=epochs,
            epsilons=epsilons,
            threshold=threshold,
        )

    def fit_on_dataset(
        self,
        features: Tensor,
        targets: Tensor,
        epochs: int | None = None,
        epsilons: Tensor | None = None,
        threshold: float | None = None,
    ) -> tuple[int, Tensor, float]:
        """
        Train the policy on a features->targets dataset, using gradient descent.

        At least one of epochs, epsilons, and threshold must be provided.
        If all are None, the function raises an error. Otherwise, it stops
        when the first of these conditions is met:
        - the number of epochs is reached
        - the values are ALL epsilon-close to the targets
        - the loss is below the threshold
        This works with arbitrary shapes of features and targets and any loss
        function. The feature and target tensors must match in the first
        dimension. With only a `threshold` that is never reached, training
        does not terminate.

        Args:
            features (Tensor): the features tensor.
            targets (Tensor): the target tensor.
            epochs (int | None): the number of epochs to train for. Defaults to
            None.
            epsilons (Tensor | None): the element-wise tolerance(s) to train
            for. Defaults to None.
            threshold (float | None): the loss threshold to train for. Defaults
            to None.

        Returns:
            tuple[int, Tensor, float]: the number of epochs (optimizer steps)
            trained, and the absolute errors and loss, both evaluated with the
            final parameters.

        Raises:
            ValueError: if no stopping criterion is given, or if the first
            dimensions of `features` and `targets` differ.
        """
        if epochs is None and epsilons is None and threshold is None:
            raise ValueError("Either epochs, epsilons, or threshold must be provided.")
        if features.shape[0] != targets.shape[0]:
            raise ValueError(
                "The features and targets tensors must have the same first dimension."
            )

        self.train()

        epoch = 0
        self.optimizer.zero_grad()
        predictions = self.forward(features)
        loss_tensor = self.loss_func(predictions, targets)
        loss = loss_tensor.item()
        while not self.should_stop(predictions, epoch, targets, epochs, epsilons):
            # Check the threshold before backpropagating, so a converged model
            # does not pay for a wasted backward pass.
            if threshold is not None and loss < threshold:
                break
            loss_tensor.backward()
            self.optimizer.step()
            epoch += 1
            # Re-evaluate with the updated parameters. `predictions` and `loss`
            # then always describe the same (current) parameters, including
            # when the loop exits through `should_stop`.
            self.optimizer.zero_grad()
            predictions = self.forward(features)
            loss_tensor = self.loss_func(predictions, targets)
            loss = loss_tensor.item()

        return epoch, torch.abs(predictions - targets), loss
