import gym
import torch
from einops import rearrange

from net.encoder.base import BaseEncoder
from net.memory.base import BaseMemory


def p_norm_penalty(param: torch.Tensor, p: float, weight: float) -> torch.Tensor:
    """Weighted p-norm (bridge) penalty reduced over the last axis.

    Computes ``weight * ||param||_p ** p`` along ``dim=-1``. The result has the
    shape of ``param`` with its last dimension removed.

    Args:
        param: Tensor whose last dimension is reduced.
        p: Norm order. It is passed as ``ord`` to ``torch.linalg.vector_norm``
            and is also used as the exponent applied to that norm.
        weight: Scalar multiplier applied to the powered norm.

    Returns:
        Tensor of penalties, one per vector along the last axis.
    """
    norm_along_last = torch.linalg.vector_norm(param, ord=p, dim=-1)
    return weight * norm_along_last.pow(p)


class FeaturesExtractor(torch.nn.Module):
    """Encoder -> memory feature pipeline that also reports activation penalties.

    Observations are encoded into a latent ``z``, which is multiplied by
    ``z_scale`` and fed through the memory module. The resulting context can
    optionally be extended in two ways:

    * concatenated with the scaled ``z`` (``add_z_skip``);
    * concatenated with its own flattened outer product (``add_outer``).

    Every forward call also returns p-norm penalties on the scaled encoder
    activations and on the final context features.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Box,
            latent_size: int,
            add_z_skip: bool,
            add_outer: bool,
            encoder_activation_penalty_norm_p: float,
            encoder_activation_penalty_weight: float,
            memory_activation_penalty_norm_p: float,
            memory_activation_penalty_weight: float,
            encoder_class: type[BaseEncoder],
            memory_class: type[BaseMemory],
            z_scale: float = 10.0,
    ):
        """Build the encoder and memory submodules and compute ``output_size``.

        Args:
            observation_space: Passed to ``encoder_class`` as its first argument.
            latent_size: Passed to both ``encoder_class`` and ``memory_class``.
                It is also added to ``output_size`` when ``add_z_skip`` is set,
                which presumes the encoder output's last dimension equals
                ``latent_size`` (TODO confirm against the encoder classes).
            add_z_skip: If True, concatenate the scaled latent onto the memory
                output.
            add_outer: If True, append the flattened outer product of the
                context to the context itself.
            encoder_activation_penalty_norm_p: Norm order for the encoder
                activation penalty.
            encoder_activation_penalty_weight: Weight of the encoder activation
                penalty.
            memory_activation_penalty_norm_p: Norm order for the memory/context
                activation penalty.
            memory_activation_penalty_weight: Weight of the memory/context
                activation penalty.
            encoder_class: Constructed as ``encoder_class(observation_space,
                latent_size)``.
            memory_class: Constructed as ``memory_class(latent_size)``. The
                instance must expose an ``output_size`` attribute.
            z_scale: Factor applied to the encoder output before it reaches the
                memory module. It defaults to the previously hard-coded 10.
        """
        super().__init__()

        self.observation_space = observation_space
        self.latent_size = latent_size
        self.add_z_skip = add_z_skip
        self.add_outer = add_outer
        self.encoder_activation_penalty_norm_p = encoder_activation_penalty_norm_p
        self.encoder_activation_penalty_weight = encoder_activation_penalty_weight
        self.memory_activation_penalty_norm_p = memory_activation_penalty_norm_p
        self.memory_activation_penalty_weight = memory_activation_penalty_weight
        self.encoder_class = encoder_class
        self.memory_class = memory_class
        self.z_scale = z_scale

        # Assigning an nn.Module to an attribute registers it as a submodule
        # under that name, so no explicit register_module call is needed.
        self.encoder = encoder_class(observation_space, latent_size)
        self.memory = memory_class(latent_size)

        # Feature width returned by forward(). It mirrors the concatenations
        # performed there.
        self.output_size = self.memory.output_size
        if self.add_z_skip:
            self.output_size += self.latent_size
        if self.add_outer:
            # The context of width n is extended with its flattened n x n
            # outer product, giving width n + n * n.
            self.output_size = self.output_size ** 2 + self.output_size

        # Unscaled encoder output from the most recent forward() call.
        self.z = None

    def forward(self, obs, h=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``obs``, run the memory module, and compute activation penalties.

        Args:
            obs: Input passed unchanged to the encoder.
            h: Optional memory state forwarded to the memory module.

        Returns:
            A tuple ``(ctx, h, penalties)``:

            * ``ctx``: context features whose last dimension is
              ``self.output_size``.
            * ``h``: the new state returned by the memory module.
            * ``penalties``: the encoder and memory penalties stacked on a new
              trailing axis (index 0 is the encoder penalty, index 1 is the
              memory penalty). All other axes match ``ctx`` without its last
              dimension.
        """
        self.z = self.encoder(obs)
        z = self.z * self.z_scale  # scale activation up
        ctx, h = self.memory(z, h)

        if self.add_z_skip:
            ctx = torch.cat([ctx, z], dim=-1)

        if self.add_outer:
            # The ellipsis keeps any number of leading dims. Only the trailing
            # (f, g) outer-product axes are flattened.
            ctx_outer = rearrange(torch.einsum('...f, ...g -> ...fg', ctx, ctx), '... f g -> ... (f g)')
            ctx = torch.cat([ctx, ctx_outer], dim=-1)

        # Encoder activation penalty. It is applied to the scaled z, not to
        # self.z.
        encoder_activation_penalties = p_norm_penalty(
            z,
            self.encoder_activation_penalty_norm_p,
            self.encoder_activation_penalty_weight,
        )

        # Memory activation penalty.
        # NOTE(review): when add_z_skip / add_outer are enabled, ctx already
        # contains the z skip and the outer-product features, not only the raw
        # memory output. Confirm this is the intended penalty target.
        memory_activation_penalties = p_norm_penalty(
            ctx,
            self.memory_activation_penalty_norm_p,
            self.memory_activation_penalty_weight,
        )

        # einops stacks the list on a new leading "p" axis. That axis is then
        # moved to the end.
        penalties = rearrange([encoder_activation_penalties, memory_activation_penalties], "p ... -> ... p")

        return ctx, h, penalties
