"""
Credits to https://github.com/CompVis/taming-transformers
"""

from dataclasses import dataclass
from typing import Any, Tuple

from einops import rearrange
import torch
import torch.nn as nn

from data import Batch
from .lpips import LPIPS
from .nets import Encoder, Decoder


@dataclass
class TokenizerEncoderOutput:
    """Result of ``Tokenizer.encode``.

    Fields are annotated as ``torch.Tensor``: the legacy ``torch.FloatTensor`` /
    ``torch.LongTensor`` types denote CPU tensors of a fixed dtype and do not
    describe tensors on other devices.
    """
    z: torch.Tensor  # pre-quantization features, shape (..., E, h, w)
    z_q: torch.Tensor  # codebook vectors selected for each position, same shape as z
    tokens: torch.Tensor  # integer codebook indices, shape (..., h * w)


class Tokenizer(nn.Module):
    """Discrete image tokenizer: encoder -> nearest-codebook quantization -> decoder.

    Images are encoded, projected to ``embed_dim`` channels with a 1x1 conv, and each
    spatial feature vector is replaced by its nearest (squared-L2) codebook entry.
    The quantized features are projected back with another 1x1 conv and decoded.

    Args:
        vocab_size: Number of codebook entries; token ids lie in ``[0, vocab_size)``.
        embed_dim: Dimensionality of each codebook vector.
        encoder: Encoder network; ``encoder.config.z_channels`` is its output channel count.
        decoder: Decoder network; ``decoder.config.z_channels`` is its input channel count.
        with_lpips: If True, build an ``LPIPS`` module used for the perceptual term of
            ``compute_loss``; if False, that term is skipped (reported as zero).
    """

    def __init__(self, vocab_size: int, embed_dim: int, encoder: Encoder, decoder: Decoder, with_lpips: bool = True) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.encoder = encoder
        self.pre_quant_conv = torch.nn.Conv2d(encoder.config.z_channels, embed_dim, 1)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, decoder.config.z_channels, 1)
        self.decoder = decoder
        # Small symmetric uniform initialization of the codebook vectors.
        self.embedding.weight.data.uniform_(-1.0 / vocab_size, 1.0 / vocab_size)
        self.lpips = LPIPS().eval() if with_lpips else None

    def __repr__(self) -> str:
        return "tokenizer"

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode, quantize and reconstruct ``obs`` of shape ``(..., C, H, W)``.

        Returns:
            ``(z, z_q, rec)``: pre-quantization features, quantized features and the
            reconstruction, each keeping the leading dims of ``obs``.
        """
        outputs = self.encode(obs)
        # Straight-through estimator: the decoder receives z_q's values, while gradients
        # flow back into z as if quantization were the identity.
        decoder_input = outputs.z + (outputs.z_q - outputs.z).detach()
        rec = self.decode(decoder_input)
        return outputs.z, outputs.z_q, rec

    def compute_loss(self, batch: Batch) -> Tuple[torch.Tensor, dict]:
        """Compute the tokenizer training loss on ``batch.obs``.

        ``batch.obs`` is unpacked as ``(b, t, c, h, w)`` and flattened to ``(b * t, c, h, w)``.

        Returns:
            ``(total_loss, logs)``: ``total_loss`` is the sum of the codebook/commitment
            loss, the L1 reconstruction loss and the LPIPS perceptual loss; ``logs`` maps
            each component name to its detached value. If the tokenizer was built with
            ``with_lpips=False`` the perceptual component is a zero tensor.
        """
        b, t, c, h, w = batch.obs.shape
        obs = batch.obs.reshape(b * t, c, h, w)
        z, z_q, rec = self(obs)

        # Codebook loss. Notes:
        # - beta position is different from taming and identical to original VQVAE paper
        # - VQVAE uses 0.25 by default
        beta = 1.0
        # First term: gradient reaches only z_q (the codebook), pulling codes toward z.
        # Second term (commitment): gradient reaches only z, pulling it toward its code.
        loss_commitment = (z.detach() - z_q).pow(2).mean() + beta * (z - z_q.detach()).pow(2).mean()

        loss_l1 = torch.abs(obs - rec).mean()
        if self.lpips is not None:
            loss_lpips = torch.mean(self.lpips(obs, rec))
        else:
            # No perceptual network was built; contribute nothing but keep the log key.
            loss_lpips = torch.zeros((), dtype=loss_l1.dtype, device=loss_l1.device)

        total_loss = loss_commitment + loss_l1 + loss_lpips
        return total_loss, {'loss_commitment': loss_commitment.detach(), 'loss_l1': loss_l1.detach(), 'loss_lpips': loss_lpips.detach()}

    def encode(self, x: torch.Tensor) -> TokenizerEncoderOutput:
        """Encode images and quantize them against the codebook.

        Args:
            x: Tensor of shape ``(..., C, H, W)`` with any number of leading dims.

        Returns:
            TokenizerEncoderOutput whose ``z`` and ``z_q`` have shape ``(..., E, h, w)``
            (``E`` = ``embed_dim``, ``h, w`` = encoder output size) and whose ``tokens``
            have shape ``(..., h * w)``.
        """
        shape = x.shape  # (..., C, H, W)
        # reshape rather than view so non-contiguous inputs (e.g. permuted) also work.
        x = x.reshape(-1, *shape[-3:])
        z = self.encoder(x)
        z = self.pre_quant_conv(z)
        b, e, h, w = z.shape
        z_flattened = rearrange(z, 'b e h w -> (b h w) e')

        # Squared L2 distance to every codebook vector: ||a||^2 + ||c||^2 - 2 a.c
        codebook = self.embedding.weight
        dist_to_embeddings = (
            torch.sum(z_flattened ** 2, dim=1, keepdim=True)
            + torch.sum(codebook ** 2, dim=1)
            - 2 * torch.matmul(z_flattened, codebook.t())
        )

        tokens = dist_to_embeddings.argmin(dim=-1)
        z_q = rearrange(self.embedding(tokens), '(b h w) e -> b e h w', b=b, e=e, h=h, w=w).contiguous()

        # Restore the caller's leading dims.
        z = z.reshape(*shape[:-3], *z.shape[1:])
        z_q = z_q.reshape(*shape[:-3], *z_q.shape[1:])
        tokens = tokens.reshape(*shape[:-3], -1)

        return TokenizerEncoderOutput(z, z_q, tokens)

    def decode(self, z_q: torch.Tensor) -> torch.Tensor:
        """Decode features of shape ``(..., E, h, w)`` into images.

        Leading dims are flattened for the decoder and restored on the output.
        """
        shape = z_q.shape  # (..., E, h, w)
        # reshape rather than view so non-contiguous inputs also work.
        z_q = z_q.reshape(-1, *shape[-3:])
        z_q = self.post_quant_conv(z_q)
        rec = self.decoder(z_q)
        rec = rec.reshape(*shape[:-3], *rec.shape[1:])
        return rec

    @torch.no_grad()
    def encode_decode(self, x: torch.Tensor) -> torch.Tensor:
        """Reconstruct ``x`` through the quantized bottleneck without tracking gradients."""
        z_q = self.encode(x).z_q
        return self.decode(z_q)

