import torch.nn as nn
import torch
from group_utils import LieParameterization, EulerParameterization
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class Encoder(BaseFeaturesExtractor):
    """
    Convolutional features extractor (stable-baselines3 ``BaseFeaturesExtractor``).

    Three strided conv + ReLU stages are followed by a flatten and a single linear
    projection to ``features_dim`` features.

    :param observation_space: (gym.Space) Image observation space; its first
        dimension is used as the number of input channels (channels-first).
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of unit for the last layer.
    :param channels_dim: (int) Output channels of the last two conv layers; the
        first conv layer uses ``channels_dim // 2``.
    """

    def __init__(self, observation_space, features_dim, channels_dim):
        super().__init__(observation_space, features_dim)
        # Channels-first (C x H x W) input is expected; any re-ordering is left
        # to preprocessing or an environment wrapper.
        in_channels = observation_space.shape[0]
        half_width = channels_dim // 2
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, half_width, kernel_size=7, padding=2, stride=4),
            nn.ReLU(),
            nn.Conv2d(half_width, channels_dim, kernel_size=5, padding=2, stride=2),
            nn.ReLU(),
            nn.Conv2d(channels_dim, channels_dim, kernel_size=3, padding=1, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Infer the flattened conv output size from one sampled observation.
        with torch.no_grad():
            probe = torch.as_tensor(observation_space.sample()[None]).float()
            flat_dim = self.cnn(probe).shape[1]

        self.linear = nn.Sequential(nn.Linear(flat_dim, features_dim))
        print('Feature Extractor defined')

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map a batch of observations to ``features_dim`` features."""
        conv_features = self.cnn(observations)
        return self.linear(conv_features)

class EncoderDecoupled(nn.Module):
    """
    Convolutional feature extractor configured from an argument namespace.

    Same conv stack as ``Encoder`` but a plain ``nn.Module`` that takes its input
    geometry from ``args`` instead of a gym observation space.

    :param args: Namespace providing ``frame_stack``, ``grayscale``, ``img_w``
        and ``img_h``. Input channels are ``frame_stack * (1 if grayscale else 3)``.
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of unit for the last layer.
    :param channels_dim: (int) Output channels of the last two conv layers; the
        first conv layer uses ``channels_dim // 2``.
    """

    def __init__(self, args, features_dim, channels_dim):
        super().__init__()
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-preprocessing or wrapper
        n_input_channels = args.frame_stack * (1 if args.grayscale else 3)
        self.cnn = nn.Sequential(

            nn.Conv2d(n_input_channels, channels_dim//2, kernel_size=7, padding=2, stride=4),
            nn.ReLU(),

            nn.Conv2d(channels_dim//2, channels_dim, kernel_size=5, padding=2, stride=2),
            nn.ReLU(),

            nn.Conv2d(channels_dim, channels_dim, kernel_size=3, padding=1, stride=2),
            nn.ReLU(),

            nn.Flatten()
        )

        # Compute the flattened size with one dummy forward pass of a C x H x W
        # image. Every conv uses the same kernel/stride/padding on both spatial
        # axes, so the flattened size does not depend on which axis is listed
        # first. randn (rather than zeros) is kept so the global RNG stream,
        # and therefore the initialisation of self.linear, is unchanged. The
        # tensor already has the default dtype, so no cast is needed.
        with torch.no_grad():
            dummy = torch.randn(n_input_channels, args.img_h, args.img_w)[None]
            n_flatten = self.cnn(dummy).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim))
        print('Feature Extractor defined')

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map a batch of channels-first images to ``features_dim`` features."""
        return self.linear(self.cnn(observations))

class View(nn.Module):
    """Reshape the input to a fixed target shape via ``Tensor.view`` (``-1`` allowed)."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        target_shape = self.shape
        return x.view(target_shape)

    def __repr__(self):
        dims = ', '.join('%d' % dim for dim in self.shape)
        return 'View(' + dims + ')'

class NormalizeImageLayer(nn.Module):
    """Divide raw pixel intensities by 255 so 8-bit images land in [0, 1]."""

    _PIXEL_MAX = 255.0

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x / self._PIXEL_MAX

def get_encoder(args):
    """Build the convolutional encoder mapping images to latent codes.

    Pixels are scaled by 1/255 and passed through three stride-2 convolutions
    (kernel 3, padding 1), each followed by ReLU. The result is flattened and a
    two-layer MLP maps it to ``args.code_size`` values.

    :param args: Namespace providing ``image_channels``, ``num_channels``,
        ``img_w``, ``img_h``, ``mlp_hidden_dim`` and ``code_size``.
    :return: (nn.Sequential) The encoder network.
    """
    # NOTE(review): get_decoder derives its output channels from
    # frame_stack/grayscale; presumably image_channels equals that — confirm.
    IN_CHANNELS = args.image_channels

    def _downsampled(size):
        # Conv2d(kernel_size=3, padding=1, stride=2) maps n -> ceil(n / 2).
        # Three of them compose to ceil(n / 8). Floor division would
        # under-count for sizes that are not multiples of 8 (84 -> 11, not 10).
        # View() would then fail or fold spatial features into the batch axis.
        return -(-size // 8)

    hidden_dim = args.num_channels * _downsampled(args.img_w) * _downsampled(args.img_h)

    enc = nn.Sequential(
        NormalizeImageLayer(),

        nn.Conv2d(IN_CHANNELS, args.num_channels//4, kernel_size=3, padding=1, stride=2),
        nn.ReLU(),

        nn.Conv2d(args.num_channels//4, args.num_channels//2, kernel_size=3, padding=1, stride=2),
        nn.ReLU(),

        nn.Conv2d(args.num_channels//2, args.num_channels, kernel_size=3, padding=1, stride=2),
        nn.ReLU(),

        View(-1, hidden_dim),

        nn.Linear(hidden_dim, args.mlp_hidden_dim),
        nn.ReLU(),
        nn.Linear(args.mlp_hidden_dim, args.code_size, bias=True)
    )
    return enc

class ActionPredictorSO3(nn.Module):
    """Predict a 3x3 (SO(3)) action matrix from a pair of latent chunks.

    The two inputs are concatenated and passed through an MLP.

    - ``args.param_type == 'Lie'`` or ``'Euler'``: the MLP emits 3 parameters,
      which are turned into a group representation by the matching
      ``group_utils`` parameterization.
    - Any other value: the MLP emits the 9 entries of the 3x3 matrix directly.

    :param args: Namespace providing ``code_size``, ``decompositions`` and
        ``param_type``. Each input is expected to have
        ``code_size // decompositions`` features.
    """

    def __init__(self, args):
        super().__init__()
        in_dim = args.code_size // args.decompositions
        if args.param_type == 'Lie':
            out_dim = 3
            self.param = LieParameterization('SOn', 3, 1)
        elif args.param_type == 'Euler':
            out_dim = 3
            self.param = EulerParameterization('SOn', 3, 1)
        else:
            out_dim = 9
        self.mlp = nn.Sequential(
            nn.Linear(in_dim * 2, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, out_dim)
        )
        self.args = args

    def forward(self, x1, x2):
        """Return a ``(batch, 3, 3)`` matrix for each pair of rows in ``x1``/``x2``."""
        batch_size = x1.shape[0]
        y = self.mlp(torch.cat([x1, x2], dim=-1))
        if self.args.param_type in ('Lie', 'Euler'):
            # Presumably get_group_rep returns (batch, 1, 3, 3) for the
            # unsqueezed input — TODO confirm. A bare .squeeze() would also
            # drop the batch axis when batch_size == 1, so reshape explicitly
            # to match the matrix branch below.
            rep = self.param.get_group_rep(y.unsqueeze(1))
            return rep.reshape(batch_size, 3, 3)
        return y.reshape(batch_size, 3, 3)


def get_decoder(args):
    """Build the decoder mapping latent codes back to images.

    A two-layer MLP expands the code to ``num_channels`` feature maps of size
    ``(img_w // 8, img_h // 8)``. Three transposed convolutions (kernel 5,
    stride 2, padding 2, output_padding 1, no bias) then each double the
    spatial size, with ReLU after the first two.

    :param args: Namespace providing ``frame_stack``, ``grayscale``,
        ``num_channels``, ``img_w``, ``img_h``, ``mlp_hidden_dim`` and
        ``code_size``. Output channels are ``frame_stack * (1 if grayscale else 3)``.
    :return: (nn.Sequential) The decoder network.
    """
    out_channels = args.frame_stack * (1 if args.grayscale else 3)
    width = args.num_channels
    w8 = args.img_w // 8
    h8 = args.img_h // 8
    hidden_dim = width * w8 * h8

    def upsample(in_ch, out_ch):
        # Output size (n - 1) * 2 - 2 * 2 + 5 + 1 == 2 * n.
        return nn.ConvTranspose2d(
            in_ch, out_ch, kernel_size=5, stride=2,
            padding=2, output_padding=1, bias=False
        )

    return nn.Sequential(
        nn.Linear(args.code_size, args.mlp_hidden_dim),
        nn.ReLU(),
        nn.Linear(args.mlp_hidden_dim, hidden_dim),
        nn.ReLU(),
        # NOTE(review): the maps are laid out as (C, img_w // 8, img_h // 8),
        # so the output is (C, 8 * (img_w // 8), 8 * (img_h // 8)). For
        # non-square or non-multiple-of-8 images, confirm this matches the
        # layout of the reconstruction targets (other comments assume CxHxW).
        View(-1, width, w8, h8),
        upsample(width, width // 2),
        nn.ReLU(),
        upsample(width // 2, width // 4),
        nn.ReLU(),
        upsample(width // 4, out_channels),
    )

def count_parameters(net):
    """Return the total number of elements across ``net``'s trainable parameters."""
    total = 0
    for param in net.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

class TransModelNormal(nn.Module):
    """Unstructured latent transition model: an MLP over ``[code, flattened action]``.

    :param args: Namespace providing ``code_size``.
    :param num_actions: Number of action values per sample after flattening
        (9 by default, e.g. a 3x3 matrix).
    """

    def __init__(self, args, num_actions=9):
        super().__init__()
        code_size = args.code_size
        self.MLP = nn.Sequential(
            nn.Linear(code_size + num_actions, 256),
            nn.ELU(),
            nn.Linear(256, 256),
            nn.ELU(),
            nn.Linear(256, code_size),
        )

    def transition(self, x, act):
        """Predict the next code from codes ``x`` and per-sample actions ``act``."""
        batch = x.shape[0]
        flat_act = act.reshape(batch, -1)
        joint = torch.cat((x, flat_act), dim=-1)
        return self.MLP(joint)

class TransModelGroup(nn.Module):
    """Latent transition model applying a conjugated action matrix chunk-wise.

    The code is split into ``args.decompositions`` contiguous chunks of size
    ``mini_code_size = code_size // decompositions``. The same per-sample
    matrix, sandwiched between factors built from the learned square matrix
    ``U``, is applied to every chunk.

    :param args: Namespace providing ``code_size`` and ``decompositions``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.mini_code_size = args.code_size // args.decompositions
        self.U = nn.Parameter(torch.normal(mean=0, std=0.1, size=(self.mini_code_size, self.mini_code_size)), requires_grad=True)

    def _apply_blockwise(self, left, x, act):
        """Apply ``left @ act @ U`` to every chunk of ``x`` and re-concatenate."""
        # act is expected to be (batch, m, m). matmul broadcasts the unbatched
        # (m, m) factors over the batch, so no explicit .repeat is needed. The
        # product is the same for every chunk, so it is computed once.
        prod = left @ act @ self.U
        m = self.mini_code_size
        next_x = []
        for i in range(self.args.decompositions):
            chunk = x[:, i * m:(i + 1) * m].unsqueeze(-1)
            # squeeze(-1) removes only the column axis. A bare squeeze() would
            # also drop the batch axis (batch size 1) or the code axis (m == 1).
            next_x.append((prod @ chunk).squeeze(-1))
        return torch.cat(next_x, dim=-1)

    def transition(self, x, act):
        """Return ``U^T @ act @ U`` applied to each chunk of ``x``.

        NOTE: this equals ``group_action`` only when ``U`` is orthogonal.
        """
        return self._apply_blockwise(self.U.T, x, act)

    def group_action(self, x, act):
        """Return ``U^{-1} @ act @ U`` applied to each chunk of ``x``."""
        return self._apply_blockwise(torch.inverse(self.U), x, act)