import sys
import torch

from einops import rearrange
from torch import nn

from torch.nn import functional as F

from pyprojroot import here as project_root

sys.path.insert(0, str(project_root()))

# from diffusers.models.autoencoders.vae import Encoder, Decoder
# from taming.modules.diffusionmodules.model import Encoder, Decoder
from src.models.model_utils import get_model_params
from src.local_vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
from src.local_vector_quantize_pytorch.residual_vq import ResidualVQ

class VQVAE(nn.Module):
  def __init__(self, args):
    """Build the encoder/decoder pair, the quantizer and the projections around the codebook.

    Args:
      args: configuration namespace, kept as `self.args`. Always read: `dataset`, `codebook_dim`,
        `codebook_size`, `vq_decay`, `codebook` ('cosine' | 'euclidean' | 'gumbel'),
        `arch` ('taming' | 'enhancing') and `f`. Read only on some branches:
        `double_fp_shadow_decoder` ('enhancing' arch); `threshold_ema_dead_code`, `batch_size`,
        `dim_per_quantizer`, `num_quantizers`, `split_vq`, `stochastic_sample_codes`,
        `sample_codebook_temp`, `commit_loss_p`, `commit_weight`, `koleo` (cosine/euclidean codebooks);
        `commit_weight` (gumbel codebook).

    Raises:
      ValueError: if `args.arch` or `args.codebook` is not supported, or if `args.f` is not one of
        16, 8, 4 for the 'taming' architecture.
      AssertionError: if `args.dim_per_quantizer != -1` and it does not divide `args.codebook_dim`.
    """
    super().__init__()
    channels, resolution = get_model_params(args.dataset)
    embed_dim = args.codebook_dim
    n_embed = args.codebook_size
    decay = args.vq_decay
    self.args = args
    self.num_codes = n_embed
    self.cosine = (args.codebook == 'cosine')

    # Default for every architecture, so that forward() can always test
    # `self.shadow_decoder is not None`. Only the 'enhancing' branch may replace it with a module.
    self.shadow_decoder = None

    if args.arch == 'taming':
      from taming.modules.diffusionmodules.model import Encoder, Decoder

      z_channels = 256
      # args.f -> (ch_mult, attn_resolutions)
      # ch_mult: the internal channels will be ch * ch_mult at different resolutions
      # attn_resolutions: apply attention once the feature map has this resolution
      taming_layouts = {
        16: ([1, 1, 2, 2, 4], [16]),
        8: ([1, 1, 2, 4], [32]),
        4: ([1, 2, 4], []),
      }
      if args.f not in taming_layouts:
        # Fail here with a clear message instead of an UnboundLocalError on `ch_mult` below.
        raise ValueError(f'f: {args.f} is not supported for arch taming (expected 16, 8 or 4).')
      ch_mult, attn_resolutions = taming_layouts[args.f]

      # For Encoder:
      # double_z: ignore
      # z_channels: the output channels of the encoder
      # resolution: the resolution of the input image
      # in_channels: the number of channels of the input image (usually 3)
      # out_ch: only used in the decoder. the number of channels of the output image (usually 3)
      # ch: the output channels of conv_in
      # ch_mult: the output channels of each block
      # num_res_blocks: the number of residual blocks in each block
      # attn_resolutions: if the resolution has been downsampled to this value, apply attention
      # dropout: dropout rate

      # the resolution will be /2, /4, /8, /16, /16 # (the last layer does not downsample)
      # Encoder and Decoder are configured with exactly the same keyword arguments.
      taming_params = dict(
        double_z=False,
        z_channels=z_channels,
        resolution=resolution,
        in_channels=channels,
        out_ch=channels,
        ch=128,
        ch_mult=ch_mult,
        num_res_blocks=2,
        attn_resolutions=attn_resolutions,
        dropout=0.0,
      )
      self.encoder = Encoder(**taming_params)
      self.decoder = Decoder(**taming_params)
    elif args.arch == 'enhancing':
      from src.modules.vit import ViTEncoder as Encoder
      from src.modules.vit import ViTDecoder as Decoder
      z_channels = 768
      # Encoder and Decoder are configured with exactly the same keyword arguments.
      vit_params = dict(
        image_size=resolution,
        patch_size=args.f,
        dim=z_channels,
        depth=12,
        heads=12,
        mlp_dim=3072,
        channels=channels,
      )
      self.encoder = Encoder(**vit_params)
      self.decoder = Decoder(**vit_params)

      if args.double_fp_shadow_decoder:
        from copy import deepcopy
        # Frozen copy of the decoder; forward() re-syncs its weights from self.decoder before use.
        self.shadow_decoder = deepcopy(self.decoder)
        for p in self.shadow_decoder.parameters():
          p.requires_grad = False
        # NOTE: .eval() is not set
    else:
      # Fail here with a clear message instead of a NameError on `z_channels` further down.
      raise ValueError(f'arch: {args.arch} is not supported.')

    if args.codebook == 'cosine' or args.codebook == 'euclidean':
      # Note: ema_update=True and learnable_codebook=False, so will use ema updates to learn codebook vectors.
      # Normalize by empirical constants. We try to compare the total number of codes vs that of a reference setting
      # The total number of codes: args.batch_size * (resolution / args.f) ** 2
      # The reference setting: batch_size=256, resolution=32, f=16
      thresh_ema_normed = args.threshold_ema_dead_code * (args.batch_size * (resolution / args.f) ** 2) / (256 * (32 / 16) ** 2)
      # Then normalize by the codebook size. Compare with the reference setting (n_embed=1024)
      thresh_ema_normed = thresh_ema_normed * (1024) / n_embed

      # set num_quantizers
      if args.dim_per_quantizer == -1:
        num_quantizers = args.num_quantizers
      else:
        # the naming is a bit misleading
        # dim_per_quantizer is only used to determine the number of quantizers
        # - for both split VQ and residual VQ, the number of quantizers = embed_dim / dim_per_quantizer
        # the actual dim_per_quantizer depends on the type of VQ
        # - for residual VQ, the actual dim_per_quantizer = embed_dim
        # - for split VQ, the actual dim_per_quantizer = embed_dim / num_quantizers
        assert embed_dim % args.dim_per_quantizer == 0, f"embed_dim {embed_dim} is not divisible by dim_per_quantizer {args.dim_per_quantizer}"
        num_quantizers = embed_dim // args.dim_per_quantizer
      print(f"embed_dim: {embed_dim}, num_quantizers: {num_quantizers}, dim_per_quantizer: {args.dim_per_quantizer}")
      print(f"split_vq: {args.split_vq}, type={type(args.split_vq)}")

      if num_quantizers == -1:
        # -1 selects a single plain vector quantizer.
        self.vq = VectorQuantize(dim=embed_dim, codebook_size=n_embed, commitment_weight=args.commit_weight, decay=decay,
                                accept_image_fmap=True, use_cosine_sim=(args.codebook == 'cosine'),
                                threshold_ema_dead_code=thresh_ema_normed,
                                stochastic_sample_codes=args.stochastic_sample_codes,
                                sample_codebook_temp=args.sample_codebook_temp,
                                commit_loss_p=args.commit_loss_p,
                                koleo=args.koleo,
        )
      else:
        self.vq = ResidualVQ(
          dim=embed_dim,
          codebook_size=n_embed,
          num_quantizers=num_quantizers,
          decay=decay,
          accept_image_fmap=True,
          use_cosine_sim=(args.codebook == 'cosine'),
          threshold_ema_dead_code=thresh_ema_normed,
          stochastic_sample_codes=args.stochastic_sample_codes,
          sample_codebook_temp=args.sample_codebook_temp,
          commit_loss_p=args.commit_loss_p,
          is_split_vq=args.split_vq,
        )
    elif args.codebook == 'gumbel':
      from src.modules.vit import GumbelQuantizer
      self.vq = GumbelQuantizer(
        embed_dim=embed_dim,
        n_embed=n_embed,
        temp_init=1.0, # TODO: anneal this
        use_norm=True, # True corresponds to "cosine" codebook
        use_residual=False,
        commitment_weight=args.commit_weight,
      )
    else:
      raise ValueError(f'codebook: {args.codebook} is not supported.')

    # Set up projections into and out of codebook.
    # Modification: disable nn.Identity to ensure consistency between different settings.
    if args.codebook == 'cosine':
      self.pre_quant_proj = nn.Sequential(nn.Linear(z_channels, embed_dim),
                                          nn.LayerNorm(embed_dim)) # if embed_dim != z_channels else nn.Identity()
    else:
      self.pre_quant_proj = nn.Sequential(
        nn.Linear(z_channels, embed_dim)) # if embed_dim != z_channels else nn.Identity()
    self.post_quant_proj = nn.Linear(embed_dim, z_channels) # if embed_dim != z_channels else nn.Identity()

  @torch.no_grad()
  def log_encoder_eff_dim(self, x):
    """Log the effective dimension of the encoder activations and of the first pre-quant projection weight.

    `x` is all-gathered over the process group, flattened to one row per spatial position
    ('b c h w -> (b h w) c') and passed to `get_effective_dim`; the same measure is taken on
    `self.pre_quant_proj[0].weight`. Both tensors are stashed in `trainer.to_save` and the two
    scalars are sent to `wandb_log`. Runs without gradient tracking.

    Args:
      x: channel-first activation map ('b c h w').
    """
    import torch.distributed as dist
    from src.eval_utils.pca import get_effective_dim

    # One receive buffer per rank, then stack all ranks along the batch axis.
    world_size = dist.get_world_size()
    per_rank = [torch.zeros_like(x) for _ in range(world_size)]
    dist.all_gather(per_rank, x)
    activations = rearrange(torch.cat(per_rank, dim=0), 'b c h w -> (b h w) c')
    act_eff_dim = get_effective_dim(activations)

    weight = self.pre_quant_proj[0].weight
    assert len(weight.shape) == 2, weight.shape
    # get_effective_dim expects the first dim >= second dim
    if weight.shape[0] < weight.shape[1]:
      weight = weight.T
    weight_eff_dim = get_effective_dim(weight)

    # Keep the raw tensors around for the trainer to save.
    import src.train_utils.trainer as trainer
    trainer.to_save['encoder_act'] = activations
    trainer.to_save['encoder_w'] = weight

    # Report the two scalars.
    from src.train_utils.wandb_utils import wandb_log
    metrics = {'encoder_act_eff_dim': act_eff_dim, 'encoder_w_eff_dim': weight_eff_dim}
    wandb_log(metrics)

  @torch.no_grad()
  def log_decoder_eff_dim(self, x):
    """Log the effective dimension of the decoder input activations and of the post-quant projection weight.

    `x` is all-gathered over the process group, flattened to one row per spatial position
    ('b h w c -> (b h w) c') and passed to `get_effective_dim`; the same measure is taken on
    `self.post_quant_proj.weight`. Both tensors are stashed in `trainer.to_save` and the two
    scalars are sent to `wandb_log`. Runs without gradient tracking.

    Args:
      x: channel-last activation map ('b h w c').
    """
    import torch.distributed as dist
    from src.eval_utils.pca import get_effective_dim

    # One receive buffer per rank, then stack all ranks along the batch axis.
    world_size = dist.get_world_size()
    per_rank = [torch.zeros_like(x) for _ in range(world_size)]
    dist.all_gather(per_rank, x)
    activations = rearrange(torch.cat(per_rank, dim=0), 'b h w c -> (b h w) c')
    act_eff_dim = get_effective_dim(activations)

    weight = self.post_quant_proj.weight
    assert len(weight.shape) == 2, weight.shape
    # get_effective_dim expects the first dim >= second dim
    if weight.shape[0] < weight.shape[1]:
      weight = weight.T
    weight_eff_dim = get_effective_dim(weight)

    # Keep the raw tensors around for the trainer to save.
    import src.train_utils.trainer as trainer
    trainer.to_save['decoder_act'] = activations
    trainer.to_save['decoder_w'] = weight

    # Report the two scalars.
    from src.train_utils.wandb_utils import wandb_log
    metrics = {'decoder_act_eff_dim': act_eff_dim, 'decoder_w_eff_dim': weight_eff_dim}
    wandb_log(metrics)

  def get_codes(self, x):
    """Encode `x` and return the codebook indices chosen by the quantizer.

    Args:
      x: input batch accepted by `self.encoder`.

    Returns:
      The `indices` element of the quantizer output (second element for the VQ quantizers,
      third element for `GumbelQuantizer`).
    """
    # Encoder output is channel-first; the projection acts on the last axis, hence the round trip.
    latents = rearrange(self.encoder(x), 'b c h w -> b h w c')
    latents = rearrange(self.pre_quant_proj(latents), 'b h w c -> b c h w')

    from src.modules.vit import GumbelQuantizer
    if isinstance(self.vq, GumbelQuantizer):
      # The Gumbel quantizer is fed channel-last input and returns the indices last.
      _, _, indices = self.vq(rearrange(latents, 'b c h w -> b h w c'))
    else:
      _, indices, _ = self.vq(latents)
    return indices

  def decode(self, indices):
    q = self.vq.get_codes_from_indices(indices)
    if self.cosine:
      q = q / torch.norm(q, dim=1, keepdim=True)

    # Decode.
    x = self.post_quant_proj(q)
    x = rearrange(x, 'b (h w) c -> b c h w', h=16)
    x = self.decoder(x)
    return x

  def encode_forward(self, x):
    """Encode `x` and return the quantized latents as a channel-first ('b c h w') map.

    Args:
      x: input batch accepted by `self.encoder`.

    Returns:
      The `quantized` element of the quantizer output; for `GumbelQuantizer` it is converted
      back from channel-last to channel-first.
    """
    # Encoder output is channel-first; the projection acts on the last axis, hence the round trip.
    latents = rearrange(self.encoder(x), 'b c h w -> b h w c')
    latents = rearrange(self.pre_quant_proj(latents), 'b h w c -> b c h w')

    from src.modules.vit import GumbelQuantizer
    if isinstance(self.vq, GumbelQuantizer):
      # The Gumbel quantizer works channel-last and returns (quantized, commit_loss, indices).
      quantized, _, _ = self.vq(rearrange(latents, 'b c h w -> b h w c'))
      return rearrange(quantized, 'b h w c -> b c h w')

    quantized, _, _ = self.vq(latents)
    return quantized

  def decoder_forward(self, q):
    if self.cosine:
      q = q / torch.norm(q, dim=1, keepdim=True)

    # Decode.
    x = rearrange(q, 'b c h w -> b h w c')
    x = self.post_quant_proj(x)
    x = rearrange(x, 'b h w c -> b c h w')
    x = self.decoder(x)
    return x

  @staticmethod
  def get_very_efficient_rotation(u, q, e):
    w = ((u + q) / torch.norm(u + q, dim=1, keepdim=True)).detach()
    e = e - 2 * torch.bmm(torch.bmm(e, w.unsqueeze(-1)), w.unsqueeze(1)) + 2 * torch.bmm(
      torch.bmm(e, u.unsqueeze(-1).detach()), q.unsqueeze(1).detach())
    return e

  def forward(self, x, vhp=False, return_rec=False, double_fp=False, rot=False, loss_scale=1.0):
    """Encode `x`, quantize, decode, and return the training losses.

    Args:
      x: input batch; the decoder output is compared against it with an MSE loss.
      vhp: if True, compute a vector-Hessian product of the decoder loss at the quantized latents
        (with `v = e - q`) and add it to the gradient that reaches the encoder output `e`.
      return_rec: if True, also return the reconstruction under 'rec'.
      double_fp: if True, the quantized latents are detached (no straight-through gradient) and a
        second decode pass is run from the un-quantized encoder output; its loss is returned as
        'fp2_rec_loss', scaled by `1 / args.scale_factor`, while the gradient entering the encoder
        is scaled back up by `args.scale_factor`.
      rot: if True, apply the rotation trick to the quantized latents before decoding.
      loss_scale: multiplier applied to the loss inside the `vhp` computation only.

    Returns:
      dict with 'rec_loss' and 'commit_loss', plus 'rec' when `return_rec` and 'fp2_rec_loss'
      when `double_fp`.
    """
    init_x = x
    # Encode.
    x = self.encoder(x)
    x = rearrange(x, 'b c h w -> b h w c')
    x = self.pre_quant_proj(x)
    x = rearrange(x, 'b h w c -> b c h w')

    # ViT-VQGAN codebook: "We also apply l2 normalization on the encoded latent variables ze(x)
    # and codebook latent variables e."
    if self.cosine:
      x = x / torch.norm(x, dim=1, keepdim=True)

    # record encoder_eff_dim
    # Best-effort logging (e.g. it fails when no process group is initialised); it must never break
    # the forward pass, but KeyboardInterrupt/SystemExit are no longer swallowed.
    try:
      self.log_encoder_eff_dim(x)
    except Exception:
      pass

    import src.train_utils.trainer as trainer
    # `e` is the (un-quantized) encoder output. It is bound on both branches below so that `vhp`
    # and `double_fp` also work while quantization is still disabled by the warm-up.
    e = x
    if self.args.vq_warmup <= trainer.global_step:
      # VQ lookup.
      from src.modules.vit import GumbelQuantizer
      if not isinstance(self.vq, GumbelQuantizer):
        quantized, _, commit_loss = self.vq(x)
      else:
        # The Gumbel quantizer works channel-last and returns (quantized, commit_loss, indices).
        quantized, commit_loss, _ = self.vq(rearrange(x, 'b c h w -> b h w c'))
        quantized = rearrange(quantized, 'b h w c -> b c h w')
    else:
      # No quantization.
      quantized = x
      commit_loss = torch.zeros(1, device=x.device)
    q = quantized

    # If using the rotation trick.
    if rot:
      b, _, h, w = x.shape
      x = rearrange(x, 'b c h w -> (b h w) c')
      quantized = rearrange(quantized, 'b c h w -> (b h w) c')
      x_norm = torch.norm(x, dim=1, keepdim=True) + 1e-6
      q_norm = torch.norm(quantized, dim=1, keepdim=True)
      # squeeze(1) removes only the axis added by unsqueeze(1); a bare squeeze() would also drop
      # the token or channel axis whenever it has size 1.
      pre_norm_q = self.get_very_efficient_rotation(x / x_norm,
                                                    quantized / (q_norm + 1e-6),
                                                    x.unsqueeze(1)).squeeze(1)
      # Rescale the rotated encoder output to the norm of the code; the factor carries no gradient.
      quantized = pre_norm_q * (q_norm / x_norm).detach()
      quantized = rearrange(quantized, '(b h w) c -> b c h w', b=b, h=h, w=w)

    # If doing a double forward pass to get exact gradients, **do not** use the STE to update the encoder.
    if double_fp:
      quantized = quantized.detach()  # Remove STE estimator here.
    if self.cosine:
      quantized = quantized / torch.norm(quantized, dim=1, keepdim=True)

    # Use codebook ema: no embed loss.
    # emb_loss = F.mse_loss(quantized, x.detach())

    # Decode.
    x = rearrange(quantized, 'b c h w -> b h w c')
    x = self.post_quant_proj(x)

    # Best-effort logging, same policy as for the encoder above.
    try:
      self.log_decoder_eff_dim(x)
    except Exception:
      pass

    x = rearrange(x, 'b h w c -> b c h w')
    x = self.decoder(x)
    rec = x
    rec_loss = F.mse_loss(init_x, x)

    # If using Hessian approximation of the gradients...
    if vhp:
      def vhp_fn(z):
        return loss_scale * F.mse_loss(init_x, self.decoder_forward(z))

      df2 = torch.autograd.functional.vhp(func=vhp_fn, inputs=q, v=e - q)[1]
      # NOTE(review): the cast to bfloat16 is hard-coded; presumably training runs in bf16 - confirm.
      e.register_hook(lambda grad: grad + df2.to(torch.bfloat16))

    # If doing a double forward-pass to compute exact gradients...
    if double_fp:
      x = e
      if self.cosine:
        x = x / torch.norm(x, dim=1, keepdim=True)

      # Decode.
      x = rearrange(x, 'b c h w -> b h w c')
      x.register_hook(lambda grad: grad * self.args.scale_factor)  # Scale gradient back up.
      x = self.post_quant_proj(x)
      x = rearrange(x, 'b h w c -> b c h w')
      # The attribute may be missing on models built without a shadow decoder; fall back to the
      # regular decoder in that case instead of raising AttributeError.
      shadow_decoder = getattr(self, 'shadow_decoder', None)
      if shadow_decoder is not None:
        # Sync the frozen copy with the current decoder weights before using it.
        shadow_decoder.load_state_dict(self.decoder.state_dict())
        x = shadow_decoder(x)
      else:
        x = self.decoder(x)

      fp2_rec_loss = 1 / self.args.scale_factor * F.mse_loss(init_x, x)
      # rec_loss += fp2_rec_loss # mod: need to separate fp2_rec_loss from rec_loss for logging

    ret = {
      'rec_loss': rec_loss,
      'commit_loss': commit_loss,
    }
    if return_rec:
      ret['rec'] = rec
    if double_fp:
      ret['fp2_rec_loss'] = fp2_rec_loss
    return ret
