import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader

from compression_autoencoder.autoencoders.autoencoder import Autoencoder
from compression_autoencoder.policies.policy import Policy
from compression_autoencoder.utils.losses import CustomLoss


class RLNeuralContinuousAutoencoder(Autoencoder):
    """Autoencoder over flattened policy parameters, trained through a policy.

    The encoder/decoder operate on a flat parameter vector whose length is
    taken from ``sample_policy.count_params()``. Instead of comparing the
    reconstructed weights with the originals, the reconstructed weights are
    plugged into ``sample_policy`` together with the batch's states. The loss
    is then computed between the actions this produces and the reference
    actions stored in the batch.
    """

    def __init__(
        self,
        sample_policy: Policy,
        latent_dim: int,
        encoder_layers_shapes: list[tuple[int, int, bool]],
        decoder_layers_shapes: list[tuple[int, int, bool]],
        activation_func: nn.Module | str = "relu",
        loss_func: nn.Module | CustomLoss | str = "mse",
        learning_rate: float = 1e-3,
        scheduler_kwargs: dict[str, float] | None = None,
        early_stopping_kwargs: dict[str, float] | None = None,
        input_scaler: nn.Module | None = None,
        device: str | torch.device | None = None,
    ) -> None:
        """Build the autoencoder sized to ``sample_policy``'s parameter count.

        Args:
            sample_policy: Policy whose ``forward(states, weights)`` is used to
                turn reconstructed weights into actions during training and
                evaluation. Its ``count_params()`` fixes the input dimension.
            latent_dim: Size of the latent code.
            encoder_layers_shapes: Layer specification forwarded to the base
                ``Autoencoder``.
            decoder_layers_shapes: Layer specification forwarded to the base
                ``Autoencoder``.
            activation_func: Activation forwarded to the base class.
            loss_func: Loss applied to (reconstructed actions, actions).
            learning_rate: Optimizer learning rate, forwarded to the base class.
            scheduler_kwargs: Forwarded to the base class.
            early_stopping_kwargs: Forwarded to the base class.
            input_scaler: Optional scaler applied by ``scale_input``.
            device: Forwarded to the base class.
        """
        # count_params() returns a 2-tuple; only the first element is needed
        # here, as the autoencoder's input size.
        n_params, _ = sample_policy.count_params()

        super().__init__(
            input_dim=n_params,
            latent_dim=latent_dim,
            encoder_layers_shapes=encoder_layers_shapes,
            decoder_layers_shapes=decoder_layers_shapes,
            activation_func=activation_func,
            loss_func=loss_func,
            learning_rate=learning_rate,
            scheduler_kwargs=scheduler_kwargs,
            early_stopping_kwargs=early_stopping_kwargs,
            input_scaler=input_scaler,
            device=device,
        )
        self.sample_policy = sample_policy

    def forward(self, x: Tensor) -> Tensor:
        """Scale, encode and decode a batch of flattened policy weights.

        NOTE(review): the decoded output is returned as-is, without undoing
        ``scale_input``. If an ``input_scaler`` is configured, confirm that
        ``sample_policy`` is meant to consume weights in the scaled space.
        """
        x = self.scale_input(x)
        code = self.encode(x)
        return self.decode(code)

    def _policy_action_loss(self, batch: dict[str, Tensor]) -> Tensor:
        """Return the action-reconstruction loss for a single batch.

        Reconstructs the batch's weights, runs ``sample_policy`` on the
        batch's states with those weights, and compares the resulting actions
        with the batch's reference actions. Shared by ``train_epoch`` and
        ``test_epoch`` so both evaluate exactly the same objective.
        """
        weights: Tensor = batch["weights"]  # [B, P]
        states: Tensor = batch["states"]  # [B, S, in_f]
        actions: Tensor = batch["actions"]  # [B, S, out_f]

        reconstructed_weights = self.forward(weights)
        reconstructed_actions = self.sample_policy.forward(
            states, reconstructed_weights
        )
        return self.loss_func(reconstructed_actions, actions)

    @staticmethod
    def _mean_epoch_loss(total_loss: float, n_batches: int) -> float:
        """Average an accumulated loss over the number of batches seen.

        Raises:
            ValueError: If no batches were seen, since the mean is undefined.
        """
        if n_batches == 0:
            raise ValueError(
                "data_loader yielded no batches; cannot compute a mean loss"
            )
        return total_loss / n_batches

    def train_epoch(self, data_loader: DataLoader) -> float | dict[str, float]:
        """Run one optimization pass over ``data_loader``.

        Returns:
            Mean per-batch training loss.

        Raises:
            ValueError: If ``data_loader`` yields no batches.
        """
        # Base-class hook; its exact effect is defined in Autoencoder
        # (not visible here).
        super().train_epoch(data_loader)

        total_loss = 0.0
        # Count the batches actually iterated instead of calling
        # len(data_loader), which is unsupported for iterable-style datasets.
        n_batches = 0
        for batch in data_loader:
            self.optimizer.zero_grad()
            loss = self._policy_action_loss(batch)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            n_batches += 1

        return self._mean_epoch_loss(total_loss, n_batches)

    def test_epoch(self, data_loader: DataLoader) -> float | dict[str, float]:
        """Evaluate over ``data_loader`` without tracking gradients.

        Returns:
            Mean per-batch validation loss.

        Raises:
            ValueError: If ``data_loader`` yields no batches.
        """
        # Base-class hook; its exact effect is defined in Autoencoder
        # (not visible here).
        super().test_epoch(data_loader)

        total_loss = 0.0
        n_batches = 0
        with torch.no_grad():
            for batch in data_loader:
                total_loss += self._policy_action_loss(batch).item()
                n_batches += 1

        return self._mean_epoch_loss(total_loss, n_batches)
