from typing import List, Optional, Type, Iterable
from types import SimpleNamespace
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.utils.parametrizations import spectral_norm

import gymnasium as gym


def ortho_init(module: nn.Module, gain: float = 1.0):
    """Orthogonally initialise a ``Linear``/``Conv2d`` weight and zero its bias.

    Modules of any other type are left untouched.

    :param module: Module to initialise in place.
    :param gain: Scaling factor forwarded to ``nn.init.orthogonal_``.
    """
    if not isinstance(module, (nn.Linear, nn.Conv2d)):
        return
    nn.init.orthogonal_(module.weight, gain=gain)
    bias = module.bias
    if bias is not None:
        nn.init.zeros_(bias)


def _get_norm_layer(channels: int, norm: Optional[str] = None):
    if norm == 'layer':
        return CLayerNorm(channels)
    elif norm == 'batch':
        return nn.BatchNorm2d(channels)
    elif norm == 'group':
        return nn.GroupNorm(8, channels)
    else:
        return nn.Identity()


def _apply_spec_norm(module: nn.Module, norm: Optional[str] = None):
    if norm == 'spectral':
        return spectral_norm(module)
    else:
        return module


class CLayerNorm(nn.Module):
    """Layer normalisation across dimension 1 (channels) of the input.

    Every position is normalised to zero mean / unit variance over dim 1, then
    scaled and shifted by learnable per-channel parameters. The parameters are
    broadcast as ``weight[:, None, None]``, so the input is assumed to be
    ``(N, C, H, W)``.

    :param normalized_shape: Shape of ``weight`` / ``bias`` (typically the
        number of channels).
    :param eps: Constant added to the variance for numerical stability.
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(1, keepdim=True)
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift


class MinAtarCNNBase(nn.Module):
    """Small convolutional feature extractor for MinAtar-style observations.

    Two unpadded 3x3 convolutions (optionally spectrally normalised, each
    followed by a normalisation layer and ReLU) feed a lazily sized linear
    layer that projects to ``features_dim`` features.

    :param observation_space: Box space of a single observation; ``shape[0]``
        is used as the number of input channels and one ``sample()`` (with a
        batch dimension added) is used to size the ``LazyLinear`` layer.
    :param features_dim: Size of the output feature vector.
    :param multiplier: Width multiplier for the 32-channel convolutions.
    :param norm: ``'layer'``, ``'batch'`` or ``'group'`` insert the matching
        normalisation layer; ``'spectral'`` applies spectral norm to the
        convolutions; anything else (e.g. ``None``) uses no normalisation.
    """

    def __init__(
            self,
            observation_space: gym.spaces.Box,
            features_dim: int = 256,
            multiplier: int = 1,
            norm: Optional[str] = None,
    ):

        super().__init__()

        n_input_channels = observation_space.shape[0]
        width = 32 * multiplier
        # Conv biases are disabled when BatchNorm is used.
        bias = (norm != 'batch')
        self.seq = nn.Sequential(
            _apply_spec_norm(nn.Conv2d(n_input_channels, width, kernel_size=3, bias=bias), norm),
            _get_norm_layer(width, norm),
            nn.ReLU(),
            _apply_spec_norm(nn.Conv2d(width, width, kernel_size=3, bias=bias), norm),
            _get_norm_layer(width, norm),
            nn.ReLU(),
            nn.Flatten(),
            nn.LazyLinear(features_dim),
            nn.ReLU()
        )

        # One dummy forward pass lets LazyLinear infer its input size.
        # - Eval mode keeps BatchNorm running statistics untouched.
        # - no_grad avoids building an autograd graph that is thrown away.
        # - The previous mode is restored so `seq` is not silently left in
        #   eval mode while the parent module reports training=True.
        was_training = self.seq.training
        self.seq.eval()
        with torch.no_grad():
            self.seq(torch.as_tensor(observation_space.sample()[None]).float())
        self.seq.train(was_training)

        self.features_dim = features_dim

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Return ``(batch, features_dim)`` features; input is cast to float first."""
        return self.seq(observations.float())


class MinAtarCNN(MinAtarCNNBase):
    """Default-width MinAtar CNN: 32-channel convolutions, 256 output features, no normalisation."""

    def __init__(self, observation_space: gym.spaces.Box):
        super().__init__(observation_space, features_dim=256, multiplier=1)


class MinAtarCNN4X(MinAtarCNNBase):
    """4x-wide MinAtar CNN: 128-channel convolutions, 1024 output features, no normalisation."""

    def __init__(self, observation_space: gym.spaces.Box):
        super().__init__(observation_space, features_dim=1024, multiplier=4)


class Block(nn.Module):
    """Residual block: two (norm -> activation -> 3x3 conv) stages plus an identity skip.

    Padding of 1 preserves the spatial size and the channel count is kept,
    so the input can be added to the output.

    :param channels: Number of input and output channels.
    :param norm: Normalisation kind forwarded to ``_get_norm_layer``.
    :param activation: Activation class, instantiated with no arguments.
    """

    def __init__(self, channels: int, norm: Optional[str] = None, activation: Type[nn.Module] = nn.ReLU):
        super().__init__()

        # Conv biases are disabled when BatchNorm is used.
        use_bias = norm != 'batch'
        layers = []
        for _ in range(2):
            layers.append(_get_norm_layer(channels, norm))
            layers.append(activation())
            layers.append(nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=use_bias))
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.seq(x)
        return residual + x


class MinAtarAutoEncoder(nn.Module):
    """Convolutional encoder/decoder pair used by the transition model.

    The encoder maps observations to a 128-channel spatial latent (two unpadded
    3x3 convolutions, each followed by a residual ``Block``). The decoder takes
    such a latent concatenated channel-wise with a spatially broadcast code
    vector ``z`` and maps it back to ``in_channels`` output maps (logits, as
    consumed by the caller). Two unpadded 3x3 transposed convolutions undo the
    encoder's spatial shrinkage.

    :param observation_space: Box space whose ``shape[1]`` is the number of
        observation channels. ``sample()`` is fed to the encoder without adding
        a batch dimension, so this presumably expects a batched space of shape
        ``(n_envs, C, H, W)`` — TODO confirm.
    :param z_dim: Length of the code vector consumed by ``decode``.
    :param norm: Normalisation passed to every ``Block``.
    :param activation: Activation class passed to every ``Block``.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Box,
        z_dim: int,
        norm: str = 'batch',
        activation: Type[nn.Module] = nn.ReLU,
    ):
        super().__init__()

        in_channels = observation_space.shape[1]

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3),
            Block(64, norm=norm, activation=activation),
            nn.Conv2d(64, 128, kernel_size=3),
            Block(128, norm=norm, activation=activation),
        )

        self.decoder = nn.Sequential(
            nn.Conv2d(128 + z_dim, 128, kernel_size=3, padding=1),
            Block(128, norm=norm, activation=activation),
            nn.ConvTranspose2d(128, 64, kernel_size=3),
            Block(64, norm=norm, activation=activation),
            nn.ConvTranspose2d(64, in_channels, kernel_size=3),
        )

        # Probe the encoder once to record its per-sample output shape.
        # - Eval mode keeps BatchNorm running statistics untouched.
        # - no_grad avoids building a discarded autograd graph.
        # - The previous mode is restored so the encoder is not left stuck in
        #   eval mode (which would make BatchNorm ignore batch statistics
        #   during training).
        tmp = torch.as_tensor(observation_space.sample(), dtype=torch.float)
        was_training = self.encoder.training
        self.encoder.eval()
        with torch.no_grad():
            self.output_shape = self.encoder(tmp).shape[1:]
        self.encoder.train(was_training)

    def encode(self, obs: torch.Tensor):
        """Encode observations (cast to float) into the spatial latent."""
        return self.encoder(obs.float())

    def decode(self, latent: torch.Tensor, z: torch.Tensor):
        """Decode ``latent`` conditioned on the code ``z``.

        :param latent: Encoder-style latent of shape ``(B, 128, H, W)``.
        :param z: Code tensor of shape ``(B, z_dim)``; it is tiled over the
            latent's spatial grid and concatenated along the channel axis.
        :returns: Decoder output, one map per observation channel.
        """
        z = z[:, :, None, None].tile(1, 1, *latent.shape[2:])

        return self.decoder(torch.cat([latent, z], dim=1))


class Conditioner(nn.Module):
    """Fuse a per-sample condition vector (e.g. an action) into a spatial latent.

    The condition is broadcast over the latent's spatial grid, concatenated
    channel-wise with the latent, projected back to ``latent_shape[0]``
    channels by a 3x3 convolution and refined by ``depth`` residual blocks.

    :param n: Length of the condition vector (number of one-hot classes when
        integer indices are passed).
    :param latent_shape: Per-sample latent shape ``(C, H, W)``; ``C`` fixes the
        channel count of the projection and blocks.
    :param depth: Number of residual ``Block``s after the projection.
    :param norm: Normalisation passed to every ``Block``.
    :param activation: Activation class passed to every ``Block``.
    """

    def __init__(
        self,
        n: int,
        latent_shape: torch.Size,
        depth: int = 2,
        norm: str = 'batch',
        activation: Type[nn.Module] = nn.ReLU
    ):
        super().__init__()

        self.n = n
        self.latent_shape = latent_shape

        self.seq = nn.Sequential(
            nn.Conv2d(n + latent_shape[0], latent_shape[0], kernel_size=3, padding=1),
            *[Block(latent_shape[0], norm=norm, activation=activation) for _ in range(depth)]
        )

    def forward(self, latent: torch.Tensor, condition: torch.Tensor):
        """Return the conditioned latent (same shape as ``latent``).

        :param latent: Tensor of shape ``(B, C, H, W)``.
        :param condition: Either a 1-D tensor of ``B`` class indices (one-hot
            encoded to ``n`` classes here) or a ``(B, n)`` tensor.
        """
        if condition.dim() == 1:
            # F.one_hot only accepts int64 indices; cast so int32 / uint8
            # action tensors are accepted as well.
            condition = F.one_hot(condition.long(), self.n).float()

        # Broadcast over the spatial grid of the actual input rather than the
        # construction-time shape; torch.cat below materialises the copy.
        condition = condition[:, :, None, None].expand(-1, -1, *latent.shape[2:])

        return self.seq(torch.cat([latent, condition], dim=1))


class TransitionModel(nn.Module):
    """
    A transition model based on conditional VQ-VAE.

    ``obs`` and ``obs_next`` are encoded by a shared ``MinAtarAutoEncoder``,
    and the action is fused into the ``obs`` latent by a ``Conditioner``.
    A discrete code of size ``z_dim`` links that conditioned latent to ``obs_next``:

    * ``posterior`` predicts code logits from (conditioned latent, next latent);
    * ``prior`` predicts code logits from the (detached) conditioned latent;
    * the decoder reconstructs ``obs_next`` from the conditioned latent and a code.

    ``quantizer`` selects how the code reaches the decoder:

    * ``'vq'``: argmax one-hot with a straight-through softmax gradient;
    * ``'gumbel_hard'`` / ``'gumbel_soft'``: a Gumbel-softmax sample (hard / soft);
    * ``'exact'``: decode every code and weight the per-code reconstruction
      losses by the posterior probabilities.

    The buffers ``code_usage`` and ``learned_prior`` hold exponential moving
    averages (rate 0.01) of the batch-mean code distribution. ``code_usage``
    tracks the posterior / sampled codes and ``learned_prior`` tracks the
    prior. Both are updated on every ``forward`` call.

    :param env: Environment. ``env.action_space[0].n`` gives the number of
        actions and ``env.observation_space`` is forwarded to the autoencoder
        (presumably a batched ``(n_envs, C, H, W)`` space — TODO confirm).
        Only observation spaces whose last dimension is 10 (MinAtar) are
        supported; others raise ``NotImplementedError``.
    :param z_dim: Number of discrete codes.
    :param quantizer: One of ``'vq'``, ``'gumbel_hard'``, ``'gumbel_soft'``, ``'exact'``.
    :param norm: Normalisation used inside all residual blocks.
    :param activation: Activation class used inside all residual blocks.
    :param beta_commit: Weight of the commitment loss (``'vq'`` only).
    :param beta_entropy: Weight of the posterior-entropy bonus (gumbel / exact).
    """

    _QUANTIZERS = ('vq', 'gumbel_hard', 'gumbel_soft', 'exact')

    def __init__(
        self,
        env: gym.Env,
        z_dim: int = 16,
        quantizer: str = 'exact',
        norm: str = 'batch',
        activation: Type[nn.Module] = nn.SiLU,
        beta_commit: float = 0.,
        beta_entropy: float = 1e-4,
    ):
        super().__init__()

        # Validate before building any sub-network so a typo fails fast.
        assert quantizer in self._QUANTIZERS, \
            f"unknown quantizer {quantizer!r}; expected one of {self._QUANTIZERS}"
        self.quantizer = quantizer

        self.beta_commit = beta_commit
        self.beta_entropy = beta_entropy

        self.observation_space = env.observation_space
        self.n_action = env.action_space[0].n
        self.z_dim = z_dim

        obs_shape = self.observation_space.shape
        img_dim = obs_shape[-1]
        if img_dim == 10:
            self.autoencoder = MinAtarAutoEncoder(self.observation_space, z_dim, norm=norm, activation=activation)
            self.encoder_shape = self.autoencoder.output_shape
        else:
            raise NotImplementedError

        self.action_conditioner = Conditioner(
            self.n_action,
            self.encoder_shape,
            depth=2,
            norm=norm,
            activation=activation
        )

        channels = self.encoder_shape[0]
        self.posterior = nn.Sequential(
            nn.Conv2d(channels * 2, channels, kernel_size=3, padding=1),
            Block(channels, norm=norm, activation=activation),
            Block(channels, norm=norm, activation=activation),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(channels, z_dim),
        )
        # Small-gain init keeps the initial code logits close to uniform.
        ortho_init(self.posterior[-1], gain=0.01)

        self.prior = nn.Sequential(
            Block(channels, norm=norm, activation=activation),
            Block(channels, norm=norm, activation=activation),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(channels, z_dim),
        )
        ortho_init(self.prior[-1], gain=0.01)

        self.register_buffer('code_usage', torch.ones(z_dim, dtype=torch.float32) / z_dim)
        self.register_buffer('learned_prior', torch.ones(z_dim, dtype=torch.float32) / z_dim)

    def _update_code_usage(self, codes: torch.Tensor):
        """EMA-update ``code_usage`` with the batch mean of ``codes`` (``(B, z_dim)``), without gradient."""
        self.code_usage = 0.99 * self.code_usage + 0.01 * codes.mean(dim=0).detach()

    @staticmethod
    def _categorical_posterior(result: SimpleNamespace, latent_post: torch.Tensor):
        """Set ``posterior``, ``entropy_posterior`` and ``loss_kl`` on ``result`` from posterior logits.

        Requires ``result.log_prior`` to be set already.
        """
        log_posterior = F.log_softmax(latent_post, dim=-1)
        result.posterior = log_posterior.exp()
        result.entropy_posterior = -(log_posterior * result.posterior).sum(dim=-1).mean()
        # KL(posterior || prior) with the posterior side detached, so this
        # term only sends gradient into the prior.
        result.loss_kl = (
            result.posterior.detach() * (log_posterior.detach() - result.log_prior)).sum(dim=-1).mean()

    def forward(self, obs: torch.Tensor, obs_next: torch.Tensor, act: torch.Tensor, temperature: float = 1.):
        """Compute prior/posterior statistics and losses for a batch of transitions.

        :param obs: Current observations.
        :param obs_next: Next observations; also the reconstruction target
            (cast to float, used as binary cross-entropy targets).
        :param act: Actions, either 1-D indices or a ``(B, n_action)`` tensor.
        :param temperature: Gumbel-softmax temperature (gumbel quantizers only).
        :returns: ``SimpleNamespace`` with fields that depend on the quantizer.

            * Always: ``log_prior``, ``prior``, ``entropy_prior``, ``posterior``,
              ``loss_recon``, ``loss_model``.
            * ``'vq'`` adds: ``loss_commit``, ``loss_prior``.
            * gumbel / exact add: ``entropy_posterior``, ``loss_kl``.
        """
        result = SimpleNamespace()

        latent_obs = self.autoencoder.encode(obs)
        latent_obs_next = self.autoencoder.encode(obs_next)

        latent_sa = self.action_conditioner(latent_obs, act)

        # The prior sees a detached latent: its losses only train the prior head.
        prior_logits = self.prior(latent_sa.detach())
        result.log_prior = F.log_softmax(prior_logits, dim=-1)
        result.prior = result.log_prior.exp()
        result.entropy_prior = -(result.log_prior * result.prior).sum(dim=-1).mean()

        self.learned_prior = 0.99 * self.learned_prior + 0.01 * result.prior.detach().mean(dim=0)

        latent_post = self.posterior(torch.cat([latent_sa, latent_obs_next], dim=1))
        target = obs_next.float()

        if self.quantizer == 'vq':
            result.posterior = F.one_hot(latent_post.argmax(dim=-1), self.z_dim).float()
            z_soft = F.softmax(latent_post, dim=-1)
            # Straight-through: the forward value is the one-hot code, the
            # gradient is that of the softmax.
            z_sample = result.posterior + z_soft - z_soft.detach()
            result.loss_commit = (z_soft - result.posterior).square().mean()
            # Cross-entropy of the prior against the hard posterior code.
            result.loss_prior = - (result.log_prior * result.posterior).sum(dim=-1).mean()
            self._update_code_usage(z_sample)
            obs_pred_logits = self.autoencoder.decode(latent_sa, z_sample)
            result.loss_recon = F.binary_cross_entropy_with_logits(obs_pred_logits, target)
            result.loss_model = result.loss_recon + result.loss_prior + self.beta_commit * result.loss_commit

        elif 'gumbel' in self.quantizer:
            self._categorical_posterior(result, latent_post)
            z_sample = F.gumbel_softmax(latent_post, tau=temperature, hard='hard' in self.quantizer)
            self._update_code_usage(z_sample)
            obs_pred_logits = self.autoencoder.decode(latent_sa, z_sample)
            result.loss_recon = F.binary_cross_entropy_with_logits(obs_pred_logits, target)
            result.loss_model = result.loss_recon + result.loss_kl - self.beta_entropy * result.entropy_posterior

        elif 'exact' in self.quantizer:
            self._categorical_posterior(result, latent_post)
            self._update_code_usage(result.posterior)

            # Decode every (sample, code) pair: sample i is repeated z_dim
            # times (repeat_interleave) and paired with codes e_0..e_{z_dim-1}
            # (identity matrix tiled once per sample), so row i * z_dim + j
            # holds sample i decoded with code j.
            obs_pred_logits = self.autoencoder.decode(
                torch.repeat_interleave(latent_sa, self.z_dim, dim=0),
                torch.eye(self.z_dim, dtype=torch.float, device=latent_sa.device).tile((len(latent_sa), 1))
            )
            loss_recon_z = F.binary_cross_entropy_with_logits(
                obs_pred_logits,
                target.repeat_interleave(self.z_dim, dim=0),
                reduction='none',
            ).flatten(start_dim=1).mean(dim=-1).view(-1, self.z_dim)

            # Expected reconstruction loss under the posterior.
            result.loss_recon = (loss_recon_z * result.posterior).sum(dim=-1).mean()
            result.loss_model = result.loss_recon + result.loss_kl - self.beta_entropy * result.entropy_posterior

        return result
