import gym
import torch

from net.features_extractor import FeaturesExtractor


def p_norm_penalty(param: torch.Tensor, p: float, weight: float) -> torch.Tensor:
    """Weighted p-norm (bridge) penalty, reduced over the last dimension of ``param``.

    For finite, non-zero ``p`` this is ``weight * sum(|param| ** p)``, computed as
    ``weight * ||param||_p ** p``.

    The limiting orders are special-cased because raising the norm to the power ``p``
    degenerates there:

    * ``p == 0``: ``||param||_0`` (the number of non-zero entries) ``** 0`` would always be 1.
      The count itself is the limit of ``sum(|param| ** p)`` as ``p -> 0+``, so it is returned
      (scaled by ``weight``).
    * ``p == +/-inf``: ``||param||_inf ** inf`` would be 0, 1 or inf depending only on whether
      the extreme magnitude is below, at or above 1. The norm itself (max / min absolute
      value) is returned instead (scaled by ``weight``).

    :param param: tensor whose last dimension is reduced; flatten it first to penalize all
        entries of a weight matrix.
    :param p: norm order.
    :param weight: scale factor applied to the penalty.
    :return: penalty tensor with the last dimension of ``param`` reduced away.
    """
    norm = torch.linalg.vector_norm(param, ord=p, dim=-1)
    if p == 0 or abs(p) == float("inf"):
        # Exponentiating would collapse these orders to a constant; use the norm directly.
        return weight * norm
    return weight * (norm ** p)


# FIXME: split into action module and value module (as in https://spinningup.openai.com/en/latest/algorithms/ppo.html)
class ActorCritic(torch.nn.Module):
    """Actor-critic model: a shared features extractor followed by two linear heads.

    The features extractor maps ``(obs, h)`` to a context vector ``ctx``. The actor head
    (``action_net``) maps ``ctx`` to one output per discrete action. The critic head
    (``state_net``) maps ``ctx`` to a single output. The weight matrices of both heads (not
    their biases) can be regularized with a p-norm penalty via ``compute_weight_penalty``.
    """

    def __init__(
            self,
            action_space: gym.spaces.Discrete,
            actor_weight_penalty_norm_p: float,
            actor_weight_penalty_weight: float,
            critic_weight_penalty_norm_p: float,
            critic_weight_penalty_weight: float,
            features_extractor_class: type[FeaturesExtractor],
            features_extractor_kwargs: dict
    ):
        """
        :param action_space: discrete action space; ``action_space.n`` sets the actor head's
            output size.
        :param actor_weight_penalty_norm_p: norm order for the actor weight penalty.
        :param actor_weight_penalty_weight: scale of the actor weight penalty.
        :param critic_weight_penalty_norm_p: norm order for the critic weight penalty.
        :param critic_weight_penalty_weight: scale of the critic weight penalty.
        :param features_extractor_class: class instantiated as the shared features extractor;
            the instance must expose ``output_size`` and ``z``.
        :param features_extractor_kwargs: keyword arguments for ``features_extractor_class``.
        """
        super().__init__()

        self.action_space = action_space
        self.actor_weight_penalty_norm_p = actor_weight_penalty_norm_p
        self.actor_weight_penalty_weight = actor_weight_penalty_weight
        self.critic_weight_penalty_norm_p = critic_weight_penalty_norm_p
        self.critic_weight_penalty_weight = critic_weight_penalty_weight

        self.features_extractor = features_extractor_class(**features_extractor_kwargs)
        # Attribute assignment already registers an nn.Module; the explicit call is kept because
        # it also raises TypeError if the extractor is not an nn.Module.
        self.register_module("features_extractor", self.features_extractor)

        self.ctx_size = self.features_extractor.output_size

        self.action_net = torch.nn.Linear(self.ctx_size, action_space.n)  # actor head
        self.state_net = torch.nn.Linear(self.ctx_size, 1)  # critic head

        # Populated by forward() with values from the most recent call; None until then.
        self.z = None
        self.ctx = None

    def compute_weight_penalty(self) -> tuple[torch.Tensor, dict]:
        """Return the p-norm penalty on the actor and critic weight matrices.

        Each weight matrix is flattened, so the penalty covers all of its entries; biases are
        not penalized.

        :return: ``(total, info)``.
            ``total`` is the summed penalty as a tensor. It is not detached, so it can be added
            to a loss.
            ``info`` maps ``"actor_weights"`` and ``"critic_weights"`` to the individual
            penalties as Python floats, for logging.
        """
        actor_weight_penalty = p_norm_penalty(
            self.action_net.weight.flatten(),
            self.actor_weight_penalty_norm_p,
            self.actor_weight_penalty_weight,
        )
        critic_weight_penalty = p_norm_penalty(
            self.state_net.weight.flatten(),
            self.critic_weight_penalty_norm_p,
            self.critic_weight_penalty_weight,
        )

        total_weight_penalty = actor_weight_penalty + critic_weight_penalty
        info = {
            # .item() already copies to host and drops autograd tracking
            "actor_weights": actor_weight_penalty.item(),
            "critic_weights": critic_weight_penalty.item(),
        }
        return total_weight_penalty, info

    def forward(self, obs: torch.Tensor, h: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the features extractor and both heads on a series of observations.

        Side effect: stores the extractor's context output in ``self.ctx`` and the extractor's
        ``z`` attribute in ``self.z``.

        :param obs: observations passed to the features extractor.
        :param h: state passed to the features extractor alongside ``obs``; presumably a
            recurrent hidden state (TODO confirm against the extractor).
        :return: ``(action_logits, state_logits, h, penalties)``:
            ``action_logits`` is the actor output, with ``action_space.n`` entries in the last
            dimension.
            ``state_logits`` is the critic output, with 1 entry in the last dimension.
            ``h`` is the state returned by the extractor.
            ``penalties`` is whatever penalty value the extractor returns.
        """
        self.ctx, h, penalties = self.features_extractor(obs, h)
        self.z = self.features_extractor.z

        action_logits = self.action_net(self.ctx)  # actor
        state_logits = self.state_net(self.ctx)  # critic
        return action_logits, state_logits, h, penalties
