import numpy as np
import inspect
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
from einops.layers.torch import Rearrange
from tqdm import tqdm, trange
from PIL import Image
from src.modules import STTransformer


class VectorQuantizer(nn.Module):
    """
    Nearest-neighbour vector quantizer with a learnable codebook.

    Every ``e_dim``-sized vector of the input is replaced by the codebook entry
    with the smallest squared Euclidean distance to it. Gradients are passed
    back to the input with a straight-through estimator.

    Args:
        n_e: number of codebook entries.
        e_dim: size of each codebook entry; the last axis of the input to
            ``forward`` must have this size.
        beta: weight of the commitment term ``||z - sg[z_q]||^2``.
    """
    # NOTE: beta multiplies the commitment term (z vs. detached z_q) and the
    # codebook term (detached z vs. z_q) is unweighted. This class has no
    # ``legacy`` switch.
    def __init__(self, n_e, e_dim, beta):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        # Small uniform init: every code starts within +-1/n_e of the origin.
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        # Kept for interface compatibility; not read inside this class.
        self.re_embed = n_e

    def forward(self, z):
        """
        Quantize ``z`` against the codebook.

        Args:
            z: tensor whose last dimension is ``e_dim`` (any leading shape).

        Returns:
            z_q: tensor with the same shape as ``z`` holding the selected
                codebook entries, with straight-through gradients to ``z``.
            loss: scalar ``beta * commitment + codebook`` loss.
            min_encoding_indices: 1-D tensor of selected code indices, one per
                ``e_dim``-vector of ``z`` in flattened (row-major) order.
        """
        z_flattened = z.reshape(-1, self.e_dim)
        codebook = self.embedding.weight

        # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 <z, e>; the expansion avoids
        # materialising an (N, n_e, e_dim) difference tensor.
        d = (
            torch.sum(z_flattened ** 2, dim=1, keepdim=True)
            + torch.sum(codebook ** 2, dim=1)
            - 2 * torch.einsum('bd,nd->bn', z_flattened, codebook)
        )

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).reshape(z.shape)

        # Commitment term (weighted by beta) trains the input producer;
        # codebook term trains the embedding.
        loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
            torch.mean((z_q - z.detach()) ** 2)

        # Straight-through estimator: forward value is z_q, gradient goes to z.
        z_q = z + (z_q - z).detach()

        return z_q, loss, min_encoding_indices


class NormVectorQuantizer(nn.Module):
    """
    Vector quantizer that matches codes by direction (L2-normalised vectors).

    Inputs and codebook entries are L2-normalised along the last axis before
    the nearest-neighbour search. The returned quantized vectors are the
    normalised codebook entries.

    Args:
        n_e: number of codebook entries.
        e_dim: size of each codebook entry; the last axis of the input to
            ``forward`` must have this size.
        beta: weight of the commitment term (normalised z vs. detached
            normalised z_q).
    """
    def __init__(self, n_e, e_dim, beta):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.normal_()

        # Kept for interface compatibility; not read inside this class.
        self.re_embed = n_e

    @staticmethod
    def norm(x):
        """L2-normalise ``x`` along its last dimension."""
        # A staticmethod rather than an instance-level lambda keeps the module
        # picklable (e.g. ``torch.save(model)``); ``self.norm(x)`` still works.
        return F.normalize(x, dim=-1)

    def forward(self, z):
        """
        Quantize ``z`` against the normalised codebook.

        Args:
            z: tensor whose last dimension is ``e_dim`` (any leading shape).

        Returns:
            z_qnorm: normalised selected codebook entries, same shape as ``z``.
                In training mode they carry straight-through gradients to the
                normalised ``z``.
            loss: scalar ``beta * commitment + codebook`` loss in training mode;
                a zero scalar tensor in eval mode.
            min_encoding_indices: 1-D tensor of selected code indices, one per
                ``e_dim``-vector of ``z`` in flattened (row-major) order.
        """
        z_flattened_norm = self.norm(z.reshape(-1, self.e_dim))
        embedding_norm = self.norm(self.embedding.weight)

        # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 <z, e> on the normalised vectors.
        d = torch.sum(z_flattened_norm ** 2, dim=1, keepdim=True) + \
            torch.sum(embedding_norm ** 2, dim=1) - 2 * \
            torch.einsum('bd,nd->bn', z_flattened_norm, embedding_norm)

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).reshape(z.shape)
        z_qnorm, z_norm = self.norm(z_q), self.norm(z)

        if self.training:
            # Commitment term (weighted by beta) + codebook term.
            loss = self.beta * torch.mean((z_qnorm.detach() - z_norm) ** 2) + \
                torch.mean((z_qnorm - z_norm.detach()) ** 2)

            # Straight-through estimator.
            z_qnorm = z_norm + (z_qnorm - z_norm).detach()
        else:
            # Eval skips the loss. Return a zero *tensor* (not the int 0) so
            # callers can still use tensor methods such as ``loss.mean()``.
            loss = z_qnorm.new_zeros(())

        return z_qnorm, loss, min_encoding_indices


class VideoTokenizer(nn.Module):
    """
    Patch-based video tokenizer: patch embed -> encoder -> normalised VQ ->
    decoder -> pixels.

    Each frame is split into non-overlapping ``patch_size`` x ``patch_size``
    patches with a strided ``Conv2d`` (3 input channels). All
    ``seq_len * (img_size // patch_size) ** 2`` tokens of a clip go through an
    ``STTransformer`` encoder, are projected to ``cb_config.embed_dim`` and are
    quantized by ``NormVectorQuantizer``. Decoding mirrors this path with a
    second ``STTransformer`` and a ``ConvTranspose2d``.

    NOTE: ``encoder_config`` and ``decoder_config`` are mutated in place
    (``n_tokens_per_frame``, ``block_size`` and ``vocab_size`` are overwritten).
    """
    def __init__(self, encoder_config, decoder_config, cb_config, img_size, patch_size, seq_len):
        super().__init__()
        grid = img_size // patch_size  # patches per image side
        n_tokens_per_frame = grid ** 2
        for config in (encoder_config, decoder_config):
            config.n_tokens_per_frame = n_tokens_per_frame
            config.block_size = seq_len * n_tokens_per_frame
            config.vocab_size = None  # continuous inputs, no token embedding table

        # Submodules are created in the original order so parameter init and
        # state_dict layout are unchanged.
        self.encoder = STTransformer(encoder_config)
        self.decoder = STTransformer(decoder_config)
        self.to_patch_embed = nn.Sequential(
            nn.Conv2d(3, encoder_config.n_embd, kernel_size=patch_size, stride=patch_size),
            Rearrange('b c h w -> b (h w) c'),
        )
        self.encoder_output_proj = nn.Linear(encoder_config.n_embd, cb_config.embed_dim, bias=encoder_config.bias)
        self.decoder_input_proj = nn.Linear(cb_config.embed_dim, decoder_config.n_embd, bias=decoder_config.bias)
        self.to_pixel = nn.Sequential(
            Rearrange('b (h w) c -> b c h w', h=grid, w=grid),
            nn.ConvTranspose2d(decoder_config.n_embd, 3, kernel_size=patch_size, stride=patch_size),
        )
        # cb_config.n_embd is the codebook size, cb_config.embed_dim the code width.
        self.quantizer = NormVectorQuantizer(cb_config.n_embd, cb_config.embed_dim, cb_config.beta)

        self.img_size = img_size
        self.patch_size = patch_size
        self.seq_len = seq_len

    def encode(self, x):
        """
        Encode a video batch into quantized latents.

        Args:
            x: tensor of shape (B, T, C, H, W) with C == 3.

        Returns:
            ``(quant, embed_loss, indices)`` as returned by the quantizer.
            ``quant`` is presumably (B, T * HW, embed_dim) if ``STTransformer``
            preserves sequence length; TODO confirm.
        """
        B, T, _, _, _ = x.size()
        # Fold time into the batch so the 2-D patch conv sees single frames.
        x = self.to_patch_embed(x.reshape(B*T, *x.shape[-3:])).contiguous()
        _, HW, C = x.size()
        x = x.reshape(B, T*HW, C)
        h = self.encoder_output_proj(self.encoder(x))
        quant, embed_loss, info = self.quantizer(h)
        return quant, embed_loss, info

    def decode(self, quant):
        """
        Decode quantized latents back to pixels.

        Args:
            quant: tensor of shape (B, T * HW, embed_dim), with
                HW == (img_size // patch_size) ** 2.

        Returns:
            Tensor of shape (B, T, 3, H', W'), with H' == W' ==
            (img_size // patch_size) * patch_size.
        """
        B = quant.shape[0]
        dec = self.decoder(self.decoder_input_proj(quant)) # (B, T*H*W, C)
        dec = dec.reshape(-1, (self.img_size//self.patch_size)**2, dec.shape[-1]) # (B*T, H*W, C)
        dec = self.to_pixel(dec).contiguous() # (B*T, C, H, W)
        dec = dec.reshape(B, -1, *dec.shape[-3:]) # (B, T, C, H, W)
        return dec

    def forward(self, batch):
        """Reconstruct ``batch[0]``; returns ``(reconstruction, quantizer_loss)``."""
        x = batch[0]
        quant, embed_loss, _ = self.encode(x)
        dec = self.decode(quant)
        return dec, embed_loss

    def criterion(self, batch, output):
        """
        Compute the training loss for ``output = self(batch)``.

        The total is the L1 reconstruction error plus the mean quantizer loss.

        Returns:
            ``(loss, logs)``, where ``loss`` is a scalar tensor and ``logs``
            maps 'loss', 'rec_loss' and 'codebook_loss' to Python floats.
        """
        x = batch[0]
        xrec, qloss = output

        rec_loss = torch.abs(x - xrec).mean()
        # The quantizer can return a plain Python number (e.g. ``0`` in eval
        # mode); wrap it so ``.mean()`` / ``.item()`` work in every mode.
        if not torch.is_tensor(qloss):
            qloss = torch.as_tensor(qloss, dtype=rec_loss.dtype, device=rec_loss.device)
        codebook_loss = qloss.mean()
        loss = rec_loss + codebook_loss

        return loss, {
            'loss': loss.item(),
            'rec_loss': rec_loss.item(),
            'codebook_loss': codebook_loss.item()
        }
