import torch.nn.functional as F
from gymnasium import spaces
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
from stable_baselines3.common.type_aliases import PyTorchObs

# Bounds for the actor's log standard deviation, i.e. the std is capped to [exp(-20), exp(2)].
LOG_STD_MAX, LOG_STD_MIN = 2, -20

import torch as th
import torch.nn as nn
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.sac.policies import Actor

# Module-wide default device: CUDA when it is available, otherwise the CPU.
device = th.device("cuda") if th.cuda.is_available() else th.device("cpu")


class ProtoActor(Actor):
    """SAC actor whose action distribution is a similarity-weighted mixture of learned prototypes.

    Each of the ``num_prototypes`` prototypes is a ``latent_size``-dimensional vector that owns a
    learnable action mean (``self.mean``), log standard deviation (``self.log_stds``) and
    temperature logit (``self._tau``). An input is compared with every prototype, the similarities
    are scaled by ``tau`` and turned into softmax weights, and those weights mix the per-prototype
    means / log-stds into the parameters of ``self.action_dist``.

    NOTE(review): inputs given to ``forward`` / ``action_log_prob`` are fed straight into the
    prototype layer, so they presumably must already be latent vectors with ``latent_size``
    elements per row (the features extractor is only used on the gSDE path) — confirm with callers.
    """

    prototypes: nn.Parameter

    def __init__(self,
                 num_prototypes: int,
                 features_dim: int,
                 observation_space: spaces.Space,
                 action_space: spaces.Box,
                 net_arch: list[int],
                 latent_size: int,
                 features_extractor: nn.Module,
                 activation_fn: type[nn.Module] = nn.ReLU,
                 normalize_images: bool = True,
                 use_sde: bool = False,
                 log_std_init: float = -3,
                 full_std: bool = True,
                 use_expln: bool = False,
                 clip_mean: float = 2.0,
                 ):
        """
        :param num_prototypes: Number of prototypes (rows of ``self.prototypes``).
        :param features_dim: Stored as ``self.features_dim`` once the parent constructor has run.
            Note the parent ``Actor`` itself is built with ``features_dim=latent_size``.
        :param observation_space: Forwarded to the parent ``Actor``.
        :param action_space: Forwarded to the parent ``Actor``; defines ``action_dim``.
        :param net_arch: Forwarded to the parent ``Actor``.
        :param latent_size: Dimensionality of each prototype vector.
        :param features_extractor: Forwarded to the parent ``Actor``.
        :param activation_fn: Forwarded to the parent ``Actor``.
        :param normalize_images: Forwarded to the parent ``Actor``.
        :param use_sde: Use a ``StateDependentNoiseDistribution`` instead of a squashed Gaussian.
        :param log_std_init: Forwarded to the parent ``Actor``.
        :param full_std: Forwarded to the parent ``Actor`` and the gSDE distribution.
        :param use_expln: Forwarded to the parent ``Actor`` and the gSDE distribution.
        :param clip_mean: Forwarded to the parent ``Actor``.
        """
        super().__init__(observation_space=observation_space,
                         action_space=action_space,
                         net_arch=net_arch,
                         features_extractor=features_extractor,
                         features_dim=latent_size,
                         activation_fn=activation_fn,
                         use_sde=use_sde,
                         log_std_init=log_std_init,
                         full_std=full_std,
                         use_expln=use_expln,
                         clip_mean=clip_mean,
                         normalize_images=normalize_images,
                         )
        self.use_sde = use_sde
        self.sde_features_extractor = None
        self.net_arch = net_arch
        # Overrides the ``latent_size`` value the parent stored under this name.
        self.features_dim = features_dim

        self.activation_fn = activation_fn
        self.log_std_init = log_std_init
        self.use_expln = use_expln
        self.full_std = full_std
        self.clip_mean = clip_mean

        self.num_prototypes = num_prototypes
        self.latent_size = latent_size
        self.action_dim = get_action_dim(self.action_space)

        if self.use_sde:
            self.action_dist = StateDependentNoiseDistribution(self.action_dim, full_std=full_std, use_expln=use_expln,
                                                               learn_features=True, squash_output=True)
        else:
            self.action_dist = SquashedDiagGaussianDistribution(self.action_dim)  # type: ignore[assignment]

        # [num_prototypes, latent_size]
        self.prototypes = nn.Parameter(th.randn((num_prototypes, self.latent_size), dtype=th.float32),
                                       requires_grad=True)  # in pw-net: randn

        # Per-prototype action mean and log-std: [num_prototypes, action_dim]
        self.mean = nn.Parameter(th.randn((num_prototypes, self.action_dim), dtype=th.float32),
                                 requires_grad=True)  # in pw-net: randn
        self.log_stds = nn.Parameter(th.randn((num_prototypes, self.action_dim), dtype=th.float32),
                                     requires_grad=True)  # in pw-net: randn

        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        self.CrossEntropyLoss = nn.CrossEntropyLoss()
        # Per-prototype temperature stored as a logit so that ``tau = sigmoid(_tau)`` lies in
        # (0, 1); initialised so that tau starts at 0.9.
        self._tau = nn.Parameter(th.logit(th.full((num_prototypes,), 0.9)), requires_grad=True)

    @property
    def tau(self) -> th.Tensor:
        """Per-prototype temperature ``sigmoid(_tau)``, shape ``[num_prototypes]``."""
        return self.sigmoid(self._tau)

    def prototype_layer(self, x) -> th.Tensor:
        """Return the similarity between every input row and every prototype.

        :param x: Batch of latent vectors; each row must contain ``latent_size`` elements.
        :return: Tensor ``[batch_size, num_prototypes]`` holding ``log((d + 1) / (d + 1e-5))``
            where ``d`` is the squared L2 distance (Chen et al. 2019). The value is non-negative
            and decreases as the distance grows.
        """
        # Work on the device that holds the prototypes rather than the module-level ``device``
        # global, which breaks when the model lives on the CPU while CUDA is available.
        x = x.to(self.prototypes.device)
        b_size = x.size(0)

        # [1, latent_size, num_prototypes]
        protos = self.prototypes.flatten(start_dim=1).T.unsqueeze(0)
        # [batch_size, latent_size, 1]; broadcasting replaces the explicit ``tile`` copies.
        rows = x.reshape(b_size, self.latent_size, 1)

        # Squared L2 distances: [batch_size, num_prototypes]
        l2s = ((rows - protos) ** 2).sum(dim=1)

        # Similarity function (Chen et al. 2019)
        return th.log((l2s + 1.0) / (l2s + 1e-5))

    def similarity(self, x) -> th.Tensor:
        """Return softmax mixing weights over prototypes, shape ``[batch_size, num_prototypes]``.

        Raw similarities are multiplied by the per-prototype temperature ``tau``; entries whose
        scaled similarity is ``<= 0`` get a large negative logit so their softmax weight is ~0.
        NOTE(review): since the raw similarity is mathematically > 0 and ``tau`` is a sigmoid,
        this mask only fires on floating-point underflow (very distant inputs or vanishing tau).
        """
        scaled_similarities = self.prototype_layer(x) * self.tau

        # Suppress non-positive entries before the softmax.
        mask = scaled_similarities <= 0
        adjusted_similarities = scaled_similarities.masked_fill(mask, float('-1e8'))
        return self.softmax(adjusted_similarities)

    def _mix_prototype_params(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tensor]:
        """Similarity-weighted mean and clamped log-std, each ``[batch_size, action_dim]``."""
        weights = self.similarity(obs)  # [batch, num_prototypes]
        # [batch, num_prototypes] @ [num_prototypes, action_dim] -> [batch, action_dim]
        mean_actions = weights @ self.mean
        log_std = th.clamp(weights @ self.log_stds, LOG_STD_MIN, LOG_STD_MAX)
        return mean_actions, log_std

    def get_proto_action_dist_params(self, obs: PyTorchObs):
        """Return ``(mean_actions, log_std, kwargs)`` to feed ``self.action_dist``.

        Without gSDE the log-std is the similarity-weighted per-prototype log-std clamped to
        ``[LOG_STD_MIN, LOG_STD_MAX]`` and ``kwargs`` is empty. With gSDE, ``self.log_std``
        (set up by the parent ``Actor``) is returned instead and ``kwargs`` carries
        ``latent_sde`` computed through the features extractor and ``self.latent_pi``.
        """
        mean_actions, log_std = self._mix_prototype_params(obs)

        if self.use_sde:
            features = self.extract_features(obs, self.features_extractor)
            latent_pi = self.latent_pi(features)
            return mean_actions, self.log_std, dict(latent_sde=latent_pi)

        return mean_actions, log_std, {}

    def action_log_prob(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tensor]:
        """Sample (squashed) actions and return them with their log-probabilities."""
        mean_actions, log_std, kwargs = self.get_proto_action_dist_params(obs)
        return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)

    def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
        """Return actions (squashed by the distribution), deterministic or sampled."""
        mean_actions, log_std, kwargs = self.get_proto_action_dist_params(obs)
        return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)

    def prototypes_update(self, obs: PyTorchObs, steps):
        """Replace the low-temperature prototypes with the least covered rows of ``obs``.

        Prototypes whose raw ``_tau`` is at or above its median are kept (ties may keep more than
        half). Every remaining slot is refilled, in order, with the rows of ``obs`` having the
        lowest total similarity to the current prototypes. A new prototype's mean / log-std is
        initialised from the current mixture's prediction at that row and its tau is drawn
        uniformly at random. Per-prototype log-stds are only written when gSDE is disabled.

        :param obs: Candidate latent vectors, presumably ``[n, latent_size]``; ``th.topk`` raises
            if ``n`` is smaller than the number of slots to refill — confirm with callers.
        :param steps: Unused.
        """
        was_training = self.training
        self.eval()
        try:
            with th.no_grad():
                # Keep everything on the parameters' device so indexing and ``cat`` line up.
                obs = obs.to(self.prototypes.device)

                tau = self._tau
                keep = tau >= th.quantile(tau, q=0.5)
                num_to_replace = self.num_prototypes - int(keep.sum())

                # Lowest total similarity == farthest from every current prototype.
                total_similarity = self.prototype_layer(obs).sum(dim=1)
                _, indices = th.topk(total_similarity, num_to_replace, largest=False)
                new_prototypes = obs[indices]

                # Initialise the new heads from the mixture predicted by the *current* prototypes.
                new_mean, new_log_stds = self._mix_prototype_params(new_prototypes)

                # ``rand`` can return exactly 0; ``eps`` keeps the logit finite.
                new_tau = th.logit(th.rand(num_to_replace, device=tau.device), eps=1e-6)

                self.prototypes.copy_(th.cat((self.prototypes[keep], new_prototypes), dim=0))
                self.mean.copy_(th.cat((self.mean[keep], new_mean), dim=0))
                self._tau.copy_(th.cat((tau[keep], new_tau), dim=0))
                if not self.use_sde:
                    self.log_stds.copy_(th.cat((self.log_stds[keep], new_log_stds), dim=0))
        finally:
            # Restore the caller's mode instead of unconditionally switching to training mode.
            self.train(was_training)

    def orthogonal_loss(self, loss_type='fro', identity_weight=1.0):
        """Orthogonality penalty on the prototype matrix.

        The prototypes ``W`` (``[num_prototypes, latent_size]``) are L2-normalised row-wise and
        ``W^T W`` (``[latent_size, latent_size]``) is compared with ``identity_weight * I``. With
        unit rows and ``identity_weight == 1``,
        ``||W^T W - I||_F^2 = latent_size - num_prototypes + sum_{i != j} (w_i . w_j)^2``,
        so the penalty pushes distinct prototypes towards orthogonality.

        :param loss_type: ``'fro'`` for the Frobenius norm of the difference, ``'mse'`` for the
            mean squared error.
        :param identity_weight: Scale of the target identity matrix.
        :return: Scalar loss tensor.
        :raises ValueError: If ``loss_type`` is neither ``'fro'`` nor ``'mse'``.
        """
        W = F.normalize(self.prototypes, p=2, dim=1)

        WT_W = th.matmul(W.T, W)
        I = th.eye(WT_W.size(0), device=W.device)

        if loss_type == 'fro':
            loss = th.norm(WT_W - identity_weight * I, p='fro')
        elif loss_type == 'mse':
            loss = F.mse_loss(WT_W, identity_weight * I)
        else:
            raise ValueError("Unsupported loss_type. Use 'fro' or 'mse'.")

        return loss
