import os
import sys
from typing import Any

import numpy as np
import torch
from gymnasium import spaces
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.type_aliases import PyTorchObs
from stable_baselines3.td3.policies import TD3Policy
from torch import Tensor, nn

sys.path.insert(0, ".")

from compression_autoencoder.autoencoders.autoencoder import Autoencoder
from compression_autoencoder.policies.policy import Policy


class LatentActor(BasePolicy):
    """
    Actor network (policy) for TD3 whose weights are decoded from a latent code.

    The actor owns a single trainable ``latent_code`` parameter (plus whatever
    parameters the features extractor has). On a forward pass the code is
    decoded by ``ae`` into policy weights, which ``sample_policy`` applies to
    the extracted features. ``ae`` and ``sample_policy`` are kept in a plain
    dict, so nn.Module does not register them as submodules of this actor.

    :param observation_space: Observation space
    :param action_space: Action space
    :param net_arch: Ignored; accepted so ``TD3Policy``'s actor kwargs can be
        passed through unchanged.
    :param features_extractor: Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param features_dim: Ignored; accepted for kwarg compatibility.
    :param latent_dim: Size of the latent code, created as a
        ``(1, latent_dim)`` zero tensor.
    :param sample_policy: Called as ``sample_policy(features, weights)`` with
        the decoded weights expanded to one row per batch element.
    :param ae: Autoencoder whose ``decode`` maps the latent code to policy
        weights. ``ae.freeze_decoder()`` is called on construction.
    :param activation_fn: Ignored; accepted for kwarg compatibility.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param device: Device on which the latent code is created.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        net_arch: list[int],  # noqa: ARG002
        features_extractor: nn.Module,
        features_dim: int,  # noqa: ARG002
        latent_dim: int,
        sample_policy: Policy,
        ae: Autoencoder,
        activation_fn: type[nn.Module] = nn.ReLU,  # noqa: ARG002
        normalize_images: bool = True,
        device: str | torch.device | None = None,
    ) -> None:
        super().__init__(
            observation_space,
            action_space,
            features_extractor=features_extractor,
            normalize_images=normalize_images,
            squash_output=True,
        )

        # Plain dict (not attributes) so these modules are not registered as
        # submodules: they stay out of this actor's parameters()/state_dict().
        self.policy_ae_modules = {"sample_policy": sample_policy, "ae": ae}

        ae.freeze_decoder()
        # The latent code is the actor's own trainable parameter.
        self.latent_code = nn.Parameter(
            torch.zeros(1, latent_dim, device=device), requires_grad=True
        )
        self.policy_weights = self.policy_ae_modules["ae"].decode(self.latent_code)  # type: ignore
        # True when the code may have changed since the last no-grad decode;
        # also read by LatentCodeCallback.
        self.changed = True
        # Snapshot of the latent code value used for the cached
        # ``policy_weights``. This is a plain attribute, not a buffer, so the
        # state_dict layout is unchanged.
        self._decoded_code: Tensor | None = None

    def _decoded_code_is_stale(self) -> bool:
        """Return True if ``latent_code`` no longer equals the last decoded value."""
        cached = self._decoded_code
        # The snapshot is not moved by ``.to(device)``; treat a device
        # mismatch as stale instead of comparing across devices.
        if cached is None or cached.device != self.latent_code.device:
            return True
        return not torch.equal(cached, self.latent_code)

    def forward(self, obs: Tensor) -> Tensor:
        features = self.extract_features(obs, self.features_extractor)

        grad_enabled = torch.is_grad_enabled()
        # Re-decode when:
        #  * gradients are enabled: the decoded weights must be in the autograd
        #    graph so the loss reaches ``latent_code``;
        #  * ``changed`` is set: the last pass was grad-enabled, so the
        #    optimizer has likely updated the code since;
        #  * the code's value differs from the last decoded one: catches
        #    in-place edits that never set ``changed``, e.g. the Polyak
        #    averaging SB3 applies to the target actor (which is only ever
        #    evaluated under no_grad) or ``load_state_dict``.
        if grad_enabled or self.changed or self._decoded_code_is_stale():
            self.policy_weights = self.policy_ae_modules["ae"].decode(self.latent_code)  # type: ignore
            self._decoded_code = self.latent_code.detach().clone()
            self.changed = grad_enabled

        return self.policy_ae_modules["sample_policy"](
            features, self.policy_weights.expand(features.size(0), -1)
        )

    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> Tensor:  # noqa: ARG002
        # Note: the deterministic parameter is ignored in the case of TD3.
        #   Predictions are always deterministic.
        return self(observation)


class LatentTD3Policy(TD3Policy):
    """
    TD3 policy whose actors are ``LatentActor`` instances.

    Positional and remaining keyword arguments are forwarded to ``TD3Policy``.
    The keyword-only arguments below configure every actor built by
    ``make_actor``.

    :param latent_dim: Size of each actor's latent code.
    :param sample_policy: Policy evaluated with the decoded weights.
    :param ae: Autoencoder used to decode the latent code.
    :param device: Device on which each actor's latent code is created.
    """

    def __init__(
        self,
        *args,
        latent_dim: int,
        sample_policy: Policy,
        ae: Autoencoder,
        device: str | None = None,
        **kwargs,
    ) -> None:
        # These attributes are read by ``make_actor``, which runs while
        # ``TD3Policy.__init__`` builds the networks, so they must be set
        # first. The modules live in a plain dict so nn.Module does not
        # register them as submodules of the policy.
        self.policy_ae_modules = dict(sample_policy=sample_policy, ae=ae)
        self.target_device = device
        self.latent_dim = latent_dim
        super().__init__(*args, **kwargs)

    def make_actor(
        self, features_extractor: BaseFeaturesExtractor | None = None
    ) -> Any:
        """
        Build a ``LatentActor`` from ``actor_kwargs`` plus the latent settings.

        Every actor created here receives the same ``sample_policy`` and
        ``ae`` objects.
        """
        modules = self.policy_ae_modules
        kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
        return LatentActor(
            **kwargs,
            latent_dim=self.latent_dim,
            sample_policy=modules["sample_policy"],  # type: ignore
            ae=modules["ae"],  # type: ignore
            device=self.target_device,
        )


class LatentCodeCallback(BaseCallback):
    """
    Record the actor's latent code during training and save it on completion.

    A snapshot of ``model.policy.actor.latent_code`` is taken at the start of
    every rollout for which the actor reports ``changed``. One final snapshot
    is taken when training ends. All snapshots are written to
    ``<log_path>/latent_codes.npz`` as the arrays ``timesteps`` and
    ``latent_codes``.

    :param verbose: Verbosity level forwarded to ``BaseCallback``.
    :param log_path: Directory for ``latent_codes.npz``; created if missing.
        An empty string (the default) saves into the current working directory.
    """

    def __init__(self, verbose: int = 0, log_path: str = "") -> None:
        super().__init__(verbose)
        self.log_path = log_path

        # Parallel lists: latent_codes[i] was captured at evaluation_timesteps[i].
        self.latent_codes: list[np.ndarray] = []
        self.evaluation_timesteps: list[int] = []

    def _init_callback(self) -> None:
        if not isinstance(self.model.policy.actor, LatentActor):
            raise TypeError(
                "LatentCodeCallback can only be used with LatentActor policies."
            )
        # os.makedirs("") raises FileNotFoundError, so only create a directory
        # when a path was actually given. An empty path means the current
        # working directory.
        if self.log_path:
            os.makedirs(self.log_path, exist_ok=True)

    def _record_latent_code(self) -> None:
        """Append a detached CPU copy of the latent code and the current timestep."""
        actor = self.model.policy.actor
        self.latent_codes.append(
            actor.latent_code.detach().cpu().numpy().copy()  # type: ignore
        )
        self.evaluation_timesteps.append(self.num_timesteps)

    def _on_step(self) -> bool:
        # Nothing to do per step; never request early stopping.
        return True

    def _on_rollout_start(self) -> None:
        # Only snapshot when the actor says its code may have changed since
        # its last no-grad forward pass.
        if self.model.policy.actor.changed:  # type: ignore
            self._record_latent_code()

    def _on_training_end(self) -> None:
        self._record_latent_code()
        np.savez(
            os.path.join(self.log_path or "", "latent_codes.npz"),
            timesteps=self.evaluation_timesteps,
            latent_codes=self.latent_codes,
        )
