import gym
import torch

from net.features_extractor import FeaturesExtractor


# FIXME: refactor to separate file
def p_norm_penalty(param: torch.Tensor, p: float, weight: float):
    """Bridge (p-norm) penalty along the last dimension of ``param``.

    Computes ``weight * ||param||_p ** p``, reducing over ``dim=-1``.

    :param param: tensor to penalize; the norm is taken over its last dimension.
    :param p: order of the norm; the norm is also raised to this power.
    :param weight: multiplier applied to the powered norm.
    :return: tensor with the last dimension of ``param`` reduced away.
    """
    last_dim_norm = torch.linalg.vector_norm(param, ord=p, dim=-1)
    powered = last_dim_norm.pow(p)
    return powered * weight


class DQN(torch.nn.Module):
    """Q-network: a features extractor followed by a single linear action head.

    The features extractor turns observations (plus a hidden state ``h``) into a
    context vector of size ``features_extractor.output_size``. A linear layer then
    maps that context to one output per discrete action.
    """

    def __init__(
            self,
            action_space: gym.spaces.Discrete,
            weight_penalty_norm_p: float,
            weight_penalty_weight: float,
            features_extractor_class: type[FeaturesExtractor],
            features_extractor_kwargs: dict
    ):
        """
        :param action_space: discrete action space; ``action_space.n`` is the head's output width.
        :param weight_penalty_norm_p: exponent ``p`` of the p-norm penalty on the head's weights.
        :param weight_penalty_weight: multiplier applied to that penalty.
        :param features_extractor_class: class instantiated to build the features extractor.
        :param features_extractor_kwargs: keyword arguments forwarded to ``features_extractor_class``.
        """
        super().__init__()

        self.action_space = action_space
        self.weight_penalty_norm_p = weight_penalty_norm_p
        self.weight_penalty_weight = weight_penalty_weight

        # Assigning an nn.Module to an attribute registers it as a submodule,
        # so no explicit register_module() call is needed.
        self.features_extractor = features_extractor_class(**features_extractor_kwargs)

        self.ctx_size = self.features_extractor.output_size

        # Linear action head: context -> one output per action.
        # (A variant with a hidden layer, ctx -> 12 -> n with ReLU, was tried previously.)
        self.net = torch.nn.Linear(self.ctx_size, action_space.n)

        # Populated by forward() with the most recent extractor outputs; None until then.
        self.z = None
        self.ctx = None

    def compute_weight_penalty(self) -> tuple[torch.Tensor, dict]:
        """Return the p-norm penalty on the action head's weight matrix, plus a logging dict.

        Only ``self.net.weight`` is penalized, not the bias. The matrix is flattened,
        so the penalty covers all of its entries.

        :return: ``(penalty, info)``. ``penalty`` is a 0-d tensor and is not detached,
            so it can be added to a loss. ``info["dqn_weights"]`` is its value as a
            Python float.
        """
        weight_penalty = p_norm_penalty(
            self.net.weight.flatten(),
            self.weight_penalty_norm_p,
            self.weight_penalty_weight,
        )

        info = {
            # .item() already returns a host-side Python number; no .cpu()/.detach() needed.
            "dqn_weights": weight_penalty.item(),
        }
        return weight_penalty, info

    def forward(
            self,
            obs: torch.Tensor,
            h: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, ...]]:
        """
        Given a series of observations, return action logits (e.g. action values, advantage values).

        :param obs: observations, passed straight to the features extractor.
        :param h: hidden state, passed straight to the features extractor.
        :return: ``(action_logits, h, penalties)``:
            - ``action_logits`` is the linear head applied to the extractor's context,
              with ``action_space.n`` outputs in the last dimension.
            - ``h`` is the hidden state returned by the extractor.
            - ``penalties`` is the third output of the extractor, passed through unchanged.

        Side effect: stores the extractor's context in ``self.ctx`` and the extractor's
        ``z`` attribute in ``self.z``.
        NOTE(review): these cached tensors are not detached, so they may keep the autograd
        graph alive until the next call. Detach them if they are only read for logging.
        """
        self.ctx, h, penalties = self.features_extractor(obs, h)
        self.z = self.features_extractor.z

        action_logits = self.net(self.ctx)

        return action_logits, h, penalties
