# Modified from:
#   taming-transformers: https://github.com/CompVis/taming-transformers
#   maskgit: https://github.com/google-research/maskgit
from dataclasses import dataclass, field
from typing import List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin

if __name__ == '__main__':
    import sys
    sys.path.append('.')
from modelling.modules import Encoder, Decoder, TimmViTEncoder, TimmViTDecoder
from modelling.quantizers.vq import VectorQuantizer
from modelling.quantizers.kl import DiagonalGaussianDistribution
from modelling.quantizers.softvq import SoftVectorQuantizer
from losses.bit_estimator import BitEstimator, GaussianBitEstimator

from timm import create_model
from einops import rearrange
from torchvision import transforms


def mean_flat(x):
    """
    Average a tensor over every axis except the leading (batch) axis.

    The result keeps only dim 0 of ``x``.
    """
    non_batch_axes = list(range(1, x.dim()))
    return x.mean(dim=non_batch_axes)


def build_mlp(hidden_size, projector_dim, z_dim):
    """Build a three-layer projector MLP with SiLU activations.

    Layer widths are ``hidden_size -> projector_dim -> projector_dim -> z_dim``;
    no activation follows the final linear layer.
    """
    # The linear layers are created in input-to-output order so that their
    # random initialisation consumes the RNG in that order.
    input_layer = nn.Linear(hidden_size, projector_dim)
    hidden_layer = nn.Linear(projector_dim, projector_dim)
    output_layer = nn.Linear(projector_dim, z_dim)
    return nn.Sequential(
        input_layer,
        nn.SiLU(),
        hidden_layer,
        nn.SiLU(),
        output_layer,
    )


@dataclass
class ModelArgs:
    """Hyper-parameters read by the tokenizer models defined in this file.

    ``VQModel``, ``SoftVQModel``, ``KLModel`` and ``AEModel`` read these
    fields when they build their encoder, decoder, quantizer and auxiliary
    heads.  Field order defines the positional signature of the generated
    ``__init__``, so new fields should be appended rather than inserted.
    """
    # Input geometry; forwarded to the ViT encoder/decoder as img_size,
    # base_img_size, num_frames and variable_num_frames.
    image_size: int = 256
    base_image_size: int = 256
    num_frames: int = 1
    variable_num_frames: bool = False

    # Quantizer / bottleneck settings.  codebook_* , commit_loss_beta and
    # entropy_loss_ratio are passed to VectorQuantizer; tau and num_codebooks
    # are passed to SoftVectorQuantizer; kl_loss_weight scales the KL term in
    # KLModel.forward.
    codebook_size: int = 16384
    codebook_embed_dim: int = 8
    codebook_l2_norm: bool = True
    codebook_show_usage: bool = True
    commit_loss_beta: float = 0.25
    entropy_loss_ratio: float = 0.0
    vq_loss_ratio: float = 1.0 # for soft vq
    kl_loss_weight: float = 0.000001
    tau: float = 0.1
    num_codebooks: int = 1

    # CNN encoder/decoder settings (used when enc_type / dec_type == 'cnn').
    encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    z_channels: int = 256
    dropout_p: float = 0.0

    # Architecture selection: enc_type / dec_type are compared against 'cnn'
    # and 'vit'; the *_model names pick the concrete network.  A decoder_model
    # name containing 'movq' makes the decoder also receive the raw latents.
    enc_type: str = 'cnn'
    dec_type: str = 'cnn'
    encoder_model: str = 'llamagen_encoder'
    decoder_model: str = 'llamagen_decoder'
    num_latent_tokens: int = 256
    to_pixel: str = 'linear'

    # for pre-trained models
    # enc_tuning_method == 'frozen' additionally freezes quant_conv and the
    # quantizer in VQModel.__init__.
    enc_tuning_method: str = 'full'
    dec_tuning_method: str = 'full'
    enc_pretrained: bool = True
    dec_pretrained: bool = False 

    # for vit 
    enc_patch_size: int = 16
    dec_patch_size: int = 16
    t_patch_size: int = 1
    enc_drop_path_rate: float = 0.0
    dec_drop_path_rate: float = 0.0

    # encoder token drop
    # forwarded to TimmViTEncoder as token_drop / token_drop_max
    enc_token_drop: float = 0.0
    enc_token_drop_max: float = 0.6

    # latents token drop
    # upper bound of the uniform distribution AEModel samples the per-sample
    # latent masking ratio from; 0 disables latent masking
    latent_token_drop_max: float = 0.0

    # decoder cls token
    dec_cls_token: bool = True

    # PE
    use_ape: bool = True 
    # rope
    # NOTE(review): rope_dim / rope_heads / rope_layers are annotated ``int``
    # but default to None; the values are forwarded unchanged to the ViT
    # encoder/decoder.
    use_rope: bool = False
    rope_mixed: bool = False
    rope_dim: int = None
    rope_heads: int = None
    rope_layers: int = None
    rope_sbm: bool = False
    rope_theta: float = 10.0
    rope_theta_t: float = 100.0

    # repa for vit
    # repa_model is the timm model name loaded (pretrained, frozen) as the
    # alignment target; repa_proj_dim is the hidden width of the projection
    # MLP; repa_align is one of 'global', 'avg_1d', 'avg_1d_shuffle', 'repeat'.
    repa: bool = False
    repa_patch_size: int = 16
    repa_model: str = 'vit_base_patch16_224'
    repa_proj_dim: int = 2048
    repa_loss_weight: float = 0.1
    repa_align: str = 'global'

    # stored on the model as self.vq_mean / self.vq_std
    vq_mean: float = 0.0
    vq_std: float = 1.0

    # Channel-wise attention
    enc_channel_attn: bool = False
    dec_channel_attn: bool = False

    # attention with PE
    enc_attn_pe: bool = False
    dec_attn_pe: bool = False
    enc_attn_latent_rope: bool = False
    dec_attn_latent_rope: bool = False

    # use different mask tokens in decoder
    # (spelling "seperate" matches the keyword expected by TimmViTDecoder)
    dec_seperate_mask_token: bool = False

    # rate loss
    # NOTE(review): annotated ``int`` but only compared with 0 and multiplied
    # into the rate term in AEModel.encode; fractional weights are presumably
    # intended - confirm.
    rate_loss_weight: int = 0

    # aux decoder model
    # aux_dec_model is the model_name of every auxiliary decoder built by
    # AEModel; the aux_*_dec flags enable the HOG / DINO / CLIP heads.
    aux_dec_model: str = 'vit_tinytiny_patch14_dinov2_movq'
    aux_loss_mask: bool = False
    aux_dec_cls_token: bool = True
    aux_hog_dec: bool = False
    aux_dino_dec: bool = False
    aux_clip_dec: bool = False

    # forwarded to TimmViTDecoder as use_coord_mlp
    use_coord_mlp: bool = False

class VQModel(nn.Module, PyTorchModelHubMixin):
    """Autoencoder with a vector-quantized bottleneck.

    Pipeline: ``encoder -> quant_conv -> quantize -> post_quant_conv -> decoder``.
    Encoder and decoder are each either the CNN ``Encoder`` / ``Decoder`` or
    ``TimmViTEncoder`` / ``TimmViTDecoder``, selected by ``config.enc_type`` and
    ``config.dec_type``.  When ``config.repa`` is set (and the encoder is a
    ViT), a frozen timm model is also loaded and used as an alignment target
    for the quantized latents during training.
    """

    def __init__(self, config: ModelArgs, 
                tags=["arxiv:2412.10958", "image-generation", "32 tokens", "SoftVQ-VAE"], 
                repo_url="https://github.com/Hhhhhhao/continuous_tokenizer", 
                license="apache-2.0"):
        """Build encoder, decoder, projections and quantizer from ``config``.

        Args:
            config: Model hyper-parameters.
            tags, repo_url, license: Not read inside this constructor; kept
                for interface compatibility.
        """
        super().__init__()
        self.config = config
        self.vq_mean = config.vq_mean
        self.vq_std = config.vq_std
        self.num_latent_tokens = config.num_latent_tokens
        self.codebook_embed_dim = config.codebook_embed_dim

        self.init_repa(config)

        # NOTE(review): an enc_type / dec_type other than 'cnn' or 'vit' leaves
        # the corresponding modules undefined; no error is raised here.
        if config.enc_type == 'cnn':
            if config.encoder_model == 'llamagen_encoder':
                self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
            else:
                raise NotImplementedError
            self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
        elif config.enc_type == 'vit':
            self.encoder = TimmViTEncoder(
                in_channels=3, num_latent_tokens=config.num_latent_tokens,
                model_name=config.encoder_model,  # 'vit_small_patch14_dinov2.lvd142m', #'vit_base_patch14_dinov2.lvd142m',  #
                model_kwargs={
                    'img_size': config.image_size, 'num_frames': config.num_frames, 
                    'patch_size': config.enc_patch_size, 't_patch_size': config.t_patch_size, 
                    'drop_path_rate': config.enc_drop_path_rate
                },
                pretrained=config.enc_pretrained,
                tuning_method=config.enc_tuning_method,
                tuning_kwargs={'r': 8},
                use_ape=config.use_ape, use_rope=config.use_rope, rope_mixed=config.rope_mixed, 
                rope_dim=config.rope_dim, rope_sbm=config.rope_sbm,
                rope_heads=config.rope_heads, rope_layers=config.rope_layers,
                rope_theta=config.rope_theta, rope_theta_t=config.rope_theta_t,
                token_drop=config.enc_token_drop,
                token_drop_max=config.enc_token_drop_max,
                base_img_size=config.base_image_size,
                use_channel_attn=config.enc_channel_attn,
                use_attn_pe=config.enc_attn_pe,
                use_attn_latent_rope=config.enc_attn_latent_rope,
                variable_num_frames=config.variable_num_frames,
            )
            self.quant_conv = nn.Linear(self.encoder.embed_dim, config.codebook_embed_dim)

        if config.dec_type == 'cnn':
            if config.decoder_model == 'llamagen_decoder':
                self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
            else:
                raise NotImplementedError
            self.post_quant_conv = nn.Conv2d(config.codebook_embed_dim, config.z_channels, 1)
        elif config.dec_type == 'vit':
            self.decoder = TimmViTDecoder(
                in_channels=3, num_latent_tokens=config.num_latent_tokens,
                model_name=config.decoder_model,  # 'vit_small_patch14_dinov2.lvd142m', #'vit_base_patch14_dinov2.lvd142m',  #
                model_kwargs={
                    'img_size': config.image_size, 'num_frames': config.num_frames, 
                    'patch_size': config.dec_patch_size, 't_patch_size': config.t_patch_size, 
                    'drop_path_rate': config.dec_drop_path_rate, 'latent_dim': config.codebook_embed_dim
                },
                pretrained=config.dec_pretrained,
                tuning_method=config.dec_tuning_method,
                tuning_kwargs={'r': 8},
                use_ape=config.use_ape, use_rope=config.use_rope, rope_mixed=config.rope_mixed, 
                rope_dim=config.rope_dim, rope_sbm=config.rope_sbm,
                rope_heads=config.rope_heads, rope_layers=config.rope_layers,
                rope_theta=config.rope_theta, rope_theta_t=config.rope_theta_t,
                cls_token=config.dec_cls_token,
                codebook_embed_dim=config.codebook_embed_dim,
                to_pixel=config.to_pixel,
                base_img_size=config.base_image_size,
                use_channel_attn=config.dec_channel_attn,
                seperate_mask_token=config.dec_seperate_mask_token,
                use_attn_pe=config.dec_attn_pe,
                use_attn_latent_rope=config.dec_attn_latent_rope,
                variable_num_frames=config.variable_num_frames,
                use_coord_mlp=config.use_coord_mlp,
            )
            self.post_quant_conv = nn.Linear(config.codebook_embed_dim, self.decoder.embed_dim)

        # A decoder whose model name contains 'movq' additionally receives the
        # un-projected latents as its second argument (see ``decode``).
        self.use_movq = 'movq' in config.decoder_model

        self.quantize = VectorQuantizer(config.codebook_size, config.codebook_embed_dim, 
                                        config.commit_loss_beta, config.entropy_loss_ratio,
                                        config.codebook_l2_norm, config.codebook_show_usage)

        # A frozen encoder also freezes everything between it and the latents.
        if config.enc_tuning_method == 'frozen':
            for param in self.quant_conv.parameters():
                param.requires_grad = False
            for param in self.quantize.parameters():
                param.requires_grad = False

    def init_repa(self, config):
        """Set up the representation-alignment (REPA) target model and projector.

        The teacher, the projection MLP and the two normalisation modules are
        only created when ``config.repa`` is true and the encoder is a ViT.
        ``self.repa_z_dim`` is always defined and stays ``None`` otherwise.
        """
        self.repa = config.repa
        self.repa_loss_weight = config.repa_loss_weight
        self.repa_align = config.repa_align
        self.repa_z_dim = None
        # NOTE(review): with config.repa=True and a non-ViT encoder, self.repa
        # stays True but no teacher is built, so the training path of
        # ``encode`` would fail on the missing attributes.
        if config.repa and config.enc_type == 'vit':
            self.repa_model = create_model(config.repa_model, pretrained=True, img_size=config.image_size, patch_size=config.repa_patch_size)
            for param in self.repa_model.parameters():
                param.requires_grad = False
            self.repa_model.eval()
            self.repa_z_dim = self.repa_model.embed_dim
            self.projection = build_mlp(config.codebook_embed_dim, config.repa_proj_dim, self.repa_z_dim)
            from modelling.lpips.lpips_timm import Normalize, Denormalize
            # Inputs are de-normalised with mean/std 0.5 and re-normalised with
            # the mean/std below before being fed to the teacher.
            self.de_scale = Denormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            self.scale = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def encode(self, x):
        """Encode ``x`` into quantized latents.

        Returns:
            ``(quant, emb_loss, info)`` where ``emb_loss`` is a dict with the
            keys ``vq_loss``, ``commit_loss``, ``entropy_loss`` and
            ``codebook_usage`` (the four entries returned by the quantizer, in
            that order), plus ``repa_loss`` when REPA is enabled and the module
            is in training mode.

        Raises:
            ValueError: if REPA is active and ``self.repa_align`` is not one of
                the supported modes.
        """
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        emb_loss = {
            'vq_loss': emb_loss[0],
            'commit_loss': emb_loss[1],
            'entropy_loss': emb_loss[2],
            'codebook_usage': emb_loss[3],
        }

        if self.repa and self.training:
            # Fail early with a clear message instead of an UnboundLocalError
            # on ``z_hat`` after the (expensive) teacher forward pass.
            if self.repa_align not in ('global', 'avg_1d', 'avg_1d_shuffle', 'repeat'):
                raise ValueError(
                    f"Unsupported repa_align {self.repa_align!r}; expected one of "
                    "'global', 'avg_1d', 'avg_1d_shuffle', 'repeat'."
                )

            # get z from repa_encoder (prefix tokens are dropped)
            rescale_x = self.scale(self.de_scale(x))
            z = self.repa_model.forward_features(rescale_x)[:, self.repa_model.num_prefix_tokens:]

            if self.repa_align == 'global':
                # compare token-averaged features
                z = z.mean(dim=1)
                z_hat = quant.mean(dim=1)
                z_hat = self.projection(z_hat)
            elif self.repa_align == 'avg_1d':
                # pool the teacher tokens down to the number of latent tokens
                z = F.adaptive_avg_pool1d(z.permute(0, 2, 1), quant.shape[1]).permute(0, 2, 1)
                z_hat = quant
                z_hat = self.projection(z_hat)
            elif self.repa_align == 'avg_1d_shuffle':
                # shuffle the length dimension of z and avg
                indices = torch.randperm(z.shape[1])
                z = F.adaptive_avg_pool1d(z[:, indices, :].permute(0, 2, 1) , quant.shape[1]).permute(0, 2, 1)
                z_hat = quant
                z_hat = self.projection(z_hat)
            elif self.repa_align == 'repeat':
                # repeat each projected latent so the lengths match the teacher
                z_hat = self.projection(quant)
                b, l, d = z_hat.shape
                z_hat = z_hat.unsqueeze(2).expand(-1, -1, z.size(1) // l, -1).reshape(b, -1, d)

            # negative cosine similarity, averaged and weighted
            z = F.normalize(z, dim=-1)
            z_hat = F.normalize(z_hat, dim=-1)
            proj_loss = mean_flat(-(z * z_hat).sum(dim=-1))
            proj_loss = proj_loss.mean()
            proj_loss *= self.repa_loss_weight

            emb_loss['repa_loss'] = proj_loss

        return quant, emb_loss, info

    def decode(self, quant, x=None, h=None, w=None):
        """Decode latents; ``x`` is accepted but not used.

        ``h`` and ``w`` are forwarded to the decoder.  For 'movq' decoders the
        un-projected latents are passed as the decoder's second argument,
        otherwise ``None`` is passed in that position.
        """
        tmp_quant = quant
        quant = self.post_quant_conv(quant)
        if self.use_movq:
            dec = self.decoder(quant, tmp_quant, h, w)
        else:
            dec = self.decoder(quant, None, h, w)
        return dec

    def decode_code(self, code_b, shape=None, channel_first=True):
        """Look up codebook entries for ``code_b`` and decode them."""
        quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input):
        """Encode then decode ``input``; returns ``(dec, diff, info)``.

        The latest latents are also kept on ``self.quant``.
        """
        h, w = input.shape[-2:]
        quant, diff, info = self.encode(input)
        self.quant = quant
        dec = self.decode(quant, x=input, h=h, w=w)
        return dec, diff, info


class SoftVQModel(VQModel, PyTorchModelHubMixin):
    """``VQModel`` whose quantizer is replaced by a ``SoftVectorQuantizer``."""

    def __init__(self, config: ModelArgs, 
                tags=["arxiv:2412.10958", "image-generation", "32 tokens", "SoftVQ-VAE"], 
                repo_url="https://github.com/Hhhhhhao/continuous_tokenizer", 
                license="apache-2.0"):
        """Build the parent model, then swap in the soft quantizer.

        Args:
            config: Model hyper-parameters.
            tags, repo_url, license: Not read inside this constructor; kept
                for interface compatibility.
        """
        super().__init__(config)
        self.quantize = SoftVectorQuantizer(config.codebook_size, config.codebook_embed_dim, 
                                            config.entropy_loss_ratio, 
                                            config.tau,                                   
                                            config.num_codebooks,
                                            config.codebook_l2_norm, config.codebook_show_usage)
        # The parent froze the quantizer it created when the encoder is
        # frozen.  That module has just been replaced, so the freeze must be
        # re-applied or the new quantizer would silently stay trainable.
        if config.enc_tuning_method == 'frozen':
            for param in self.quantize.parameters():
                param.requires_grad = False


class KLModel(VQModel):
    """Variant of ``VQModel`` with a diagonal-Gaussian (KL) bottleneck.

    The quantizer is dropped and ``quant_conv`` is rebuilt to emit
    ``2 * codebook_embed_dim`` channels, which are handed to
    ``DiagonalGaussianDistribution`` (presumably as mean and log-variance -
    confirm against that class).
    """

    def __init__(self, config: ModelArgs):
        super().__init__(config)
        self.kl_loss_weight = config.kl_loss_weight
        self.quantize = None

        if config.enc_type == 'cnn':
            self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim * 2, 1)
        elif config.enc_type == 'vit':
            self.quant_conv = nn.Linear(self.encoder.embed_dim, config.codebook_embed_dim * 2)

        # The parent froze the quant_conv it created when the encoder is
        # frozen.  That layer has just been replaced, so re-apply the freeze;
        # otherwise the new projection would silently stay trainable.
        if config.enc_tuning_method == 'frozen':
            for param in self.quant_conv.parameters():
                param.requires_grad = False

    def encode(self, x):
        """Return the posterior distribution over latents for ``x``."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        h_posterior = DiagonalGaussianDistribution(h)
        return h_posterior

    def decode(self, z):
        """Project latents ``z`` back to decoder width and decode them."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def decode_code(self, posterior, shape=None):
        """Sample from ``posterior`` and decode the sample; ``shape`` is unused."""
        z = posterior.sample()
        dec = self.decode(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Auto-encode ``input``.

        Args:
            input: Batch to reconstruct.
            sample_posterior: Sample from the posterior when true, otherwise
                use its mode.

        Returns:
            ``(dec, diff, None)`` where ``diff`` is a 4-tuple whose first entry
            is the weighted KL loss and whose remaining entries are zero
            scalars.
        """
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        # compute kl loss: total KL divided by the size of its leading dim
        kl_loss = posterior.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        diff = (kl_loss * self.kl_loss_weight, torch.tensor(0.), torch.tensor(0.), torch.tensor(0.))
        return dec, diff, None

import math
class HOGGenerator(nn.Module):
    """Generate HOG feature for images.

    This module is used in MaskFeat to generate HOG feature. The code is
    modified from file `slowfast/models/operators.py
    <https://github.com/facebookresearch/SlowFast/blob/main/slowfast/models/operators.py>`_.
    Here is the link of `HOG wikipedia
    <https://en.wikipedia.org/wiki/Histogram_of_oriented_gradients>`_.

    Args:
        nbins (int): Number of orientation bins. Defaults to 9.
        pool (int): Side length, in pixels, of the square cells over which
            gradient histograms are summed. Defaults to 8.
        gaussian_window (int): Size of the Gaussian weighting kernel; a falsy
            value disables the weighting. Defaults to 16.
    """

    def __init__(self,
                 nbins: int = 9,
                 pool: int = 8,
                 gaussian_window: int = 16) -> None:
        super().__init__()
        self.nbins = nbins
        self.pool = pool
        self.pi = math.pi
        # 3x3 Sobel kernels, one copy per RGB channel (used with groups=3).
        weight_x = torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
        weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1).contiguous()
        weight_y = weight_x.transpose(2, 3).contiguous()
        self.register_buffer('weight_x', weight_x)
        self.register_buffer('weight_y', weight_y)

        self.gaussian_window = gaussian_window
        if gaussian_window:
            gaussian_kernel = self.get_gaussian_kernel(gaussian_window,
                                                       gaussian_window // 2)
            self.register_buffer('gaussian_kernel', gaussian_kernel)

    def get_gaussian_kernel(self, kernlen: int, std: int) -> torch.Tensor:
        """Returns a 2D Gaussian kernel of shape (kernlen, kernlen) summing to 1."""

        def _gaussian_fn(kernlen: int, std: int) -> torch.Tensor:
            n = torch.arange(0, kernlen).float()
            n -= n.mean()
            n /= std
            w = torch.exp(-0.5 * n**2)
            return w

        kernel_1d = _gaussian_fn(kernlen, std)
        kernel_2d = kernel_1d[:, None] * kernel_1d[None, :]
        return kernel_2d / kernel_2d.sum()

    def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor:
        """Reshape HOG Features for output.

        Channels and bins are merged, then the cell grid is cut into
        non-overlapping square groups of ``unfold_size`` cells (the grid width
        divided by 16) and every group is flattened into one feature vector.
        """
        hog_feat = hog_feat.flatten(1, 2)
        self.unfold_size = hog_feat.shape[-1] // 16
        hog_feat = hog_feat.permute(0, 2, 3, 1)
        hog_feat = hog_feat.unfold(1, self.unfold_size,
                                   self.unfold_size).unfold(
                                       2, self.unfold_size, self.unfold_size)
        hog_feat = hog_feat.flatten(1, 2).flatten(2)
        return hog_feat

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Generate hog feature for each batch images.

        Args:
            x (torch.Tensor): Input images of shape (N, 3, H, W). When the
                Gaussian window is enabled, H and W must each be a multiple
                of ``gaussian_window``; they no longer need to be equal.

        Returns:
            torch.Tensor: Hog features, a 3-D tensor (batch, groups, feature).
        """
        # input is RGB image with shape [B 3 H W]
        self.h, self.w = x.size(-2), x.size(-1)
        x = F.pad(x, pad=(1, 1, 1, 1), mode='reflect')
        gx_rgb = F.conv2d(
            x, self.weight_x, bias=None, stride=1, padding=0, groups=3)
        gy_rgb = F.conv2d(
            x, self.weight_y, bias=None, stride=1, padding=0, groups=3)
        norm_rgb = torch.stack([gx_rgb, gy_rgb], dim=-1).norm(dim=-1)
        phase = torch.atan2(gx_rgb, gy_rgb)
        phase = phase / self.pi * self.nbins  # [-9, 9]

        b, c, h, w = norm_rgb.shape
        out = torch.zeros((b, c, self.nbins, h, w),
                          dtype=x.dtype,
                          device=x.device)
        phase = phase.view(b, c, 1, h, w)
        norm_rgb = norm_rgb.view(b, c, 1, h, w)
        if self.gaussian_window:
            if h != self.gaussian_window or w != self.gaussian_window:
                assert h % self.gaussian_window == 0, 'h {} gw {}'.format(
                    h, self.gaussian_window)
                assert w % self.gaussian_window == 0, 'w {} gw {}'.format(
                    w, self.gaussian_window)
                # Tile the window independently along each axis so that
                # non-square inputs are supported; for square inputs this is
                # identical to tiling with the height ratio on both axes.
                temp_gaussian_kernel = self.gaussian_kernel.repeat(
                    [h // self.gaussian_window, w // self.gaussian_window])
            else:
                temp_gaussian_kernel = self.gaussian_kernel
            norm_rgb *= temp_gaussian_kernel

        # accumulate gradient magnitude into the orientation bin of each pixel
        out.scatter_add_(2, phase.floor().long() % self.nbins, norm_rgb)

        # sum the per-pixel histograms over pool x pool cells
        out = out.unfold(3, self.pool, self.pool)
        out = out.unfold(4, self.pool, self.pool)
        out = out.sum(dim=[-1, -2])

        self.out = F.normalize(out, p=2, dim=2)

        return self._reshape(self.out)

    def generate_hog_image(self, hog_out: torch.Tensor) -> np.ndarray:
        """Generate HOG image according to HOG features.

        Uses the image size recorded by the most recent ``forward`` call.

        Raises:
            ImportError: if OpenCV (``cv2``) is not installed.
        """
        # cv2 is only needed for this visualisation helper and was never
        # imported at module level, so import it lazily here.
        try:
            import cv2
        except ImportError as err:
            raise ImportError(
                'HOGGenerator.generate_hog_image requires OpenCV (cv2).'
            ) from err

        assert hog_out.size(0) == 1 and hog_out.size(1) == 3, \
            'Check the input batch size and the channcel number, only support'\
            '"batch_size = 1".'
        hog_image = np.zeros([self.h, self.w])
        cell_gradient = np.array(hog_out.mean(dim=1).squeeze().detach().cpu())
        cell_width = self.pool / 2
        max_mag = np.array(cell_gradient).max()
        angle_gap = 360 / self.nbins

        for x in range(cell_gradient.shape[1]):
            for y in range(cell_gradient.shape[2]):
                cell_grad = cell_gradient[:, x, y]
                cell_grad /= max_mag
                angle = 0
                for magnitude in cell_grad:
                    angle_radian = math.radians(angle)
                    x1 = int(x * self.pool +
                             magnitude * cell_width * math.cos(angle_radian))
                    y1 = int(y * self.pool +
                             magnitude * cell_width * math.sin(angle_radian))
                    x2 = int(x * self.pool -
                             magnitude * cell_width * math.cos(angle_radian))
                    y2 = int(y * self.pool -
                             magnitude * cell_width * math.sin(angle_radian))
                    magnitude = 0 if magnitude < 0 else magnitude
                    cv2.line(hog_image, (y1, x1), (y2, x2),
                             int(255 * math.sqrt(magnitude)))
                    angle += angle_gap
        return hog_image
        
import scipy.stats as stats
from modelling.lpips.lpips_timm import Normalize, Denormalize
class AEModel(VQModel):
    def __init__(self, config: ModelArgs,
                tags=["arxiv:xxx", "image-generation", "1d-tokenizer", "128 tokens", "MAETok"], 
                repo_url="https://github.com/Hhhhhhao/continuous_tokenizer", 
                license="apache-2.0"):
        """Build the continuous auto-encoder and its optional auxiliary heads.

        Args:
            config: Model hyper-parameters.  NOTE: ``config.repa`` is
                overwritten in place (see below), so the caller's object is
                modified.
            tags, repo_url, license: Not read inside this constructor.
        """
        # The DINO auxiliary head reads self.repa_model, which the parent only
        # builds when config.repa is true, so the flag is forced on here.
        # NOTE(review): this also sets self.repa, so the repa loss in
        # ``encode`` becomes active whenever aux_dino_dec is on - confirm.
        config.repa = config.repa or config.aux_dino_dec
        super().__init__(config)
        # no quantizer: the quant_conv output is used directly as the latent
        self.quantize = None
        # NOTE(review): this replaces the self.scale created by init_repa
        # (mean [0.485, 0.456, 0.406]) with a 0.5/0.5 one; ``encode`` feeds the
        # repa model through scale(de_scale(x)) - confirm this is intended.
        self.de_scale = Denormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.scale = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        # aux decoder
        self.aux_loss_mask = config.aux_loss_mask
        self.aux_hog_decoder = config.aux_hog_dec
        self.aux_dino_decoder = config.aux_dino_dec
        self.aux_clip_decoder = config.aux_clip_dec
        if self.aux_hog_decoder:
            self.init_aux_hog(config)
        if self.aux_dino_decoder:
            self.init_aux_dino(config)
        if self.aux_clip_decoder:
            self.init_aux_clip(config)

        # latents drop
        # When enabled, a per-sample masking ratio is drawn uniformly from
        # [0, latent_token_drop_max] and masked latents are replaced by the
        # learned ``mask_latent`` token.
        self.latent_drop = config.latent_token_drop_max > 0
        if self.latent_drop:
            self.mask_latent_ratio_generator = stats.uniform(loc=0, scale=config.latent_token_drop_max)
            self.mask_latent = nn.Parameter(torch.zeros(1, 1, config.codebook_embed_dim))
            nn.init.normal_(self.mask_latent, std=.02)

        # rate loss
        # the bit estimator is only created when the rate loss has a positive weight
        self.rate_loss_weight = config.rate_loss_weight
        if self.rate_loss_weight > 0:
            self.bit_estimator = BitEstimator(config.num_latent_tokens, config.codebook_embed_dim)
            # self.bit_estimator = GaussianBitEstimator()

    def init_aux_hog(self, config):
        """Create the auxiliary decoder head that predicts HOG features.

        Builds ``decoder_hog`` (a ``TimmViTDecoder`` with identity pixel
        head), the latent-to-decoder and decoder-to-108-dim linear layers,
        and the ``HOGGenerator`` producing the targets.
        """
        print('Using HOG decoder:', config.aux_dec_model)
        vit_kwargs = {
            'img_size': config.image_size,
            'num_frames': config.num_frames,
            'patch_size': config.dec_patch_size,
            't_patch_size': config.t_patch_size,
            'drop_path_rate': 0.,
        }
        self.decoder_hog = TimmViTDecoder(
            in_channels=3,
            num_latent_tokens=config.num_latent_tokens,
            model_name=config.aux_dec_model,
            model_kwargs=vit_kwargs,
            pretrained=False,
            tuning_method='full',
            tuning_kwargs={'r': 8},
            use_ape=config.use_ape,
            use_rope=config.use_rope,
            rope_mixed=config.rope_mixed,
            rope_theta=config.rope_theta,
            cls_token=config.aux_dec_cls_token,
            codebook_embed_dim=config.codebook_embed_dim,
            to_pixel='identity',
        )
        self.post_quant_conv_hog = nn.Linear(config.codebook_embed_dim, self.decoder_hog.embed_dim)
        # 108 is the width of the predicted HOG descriptor per token
        self.to_pixel_hog = nn.Linear(self.decoder_hog.embed_dim, 108)
        self.hog_generator = HOGGenerator()
        # 'movq' decoders also receive the un-projected latents
        self.hog_use_movq = 'movq' in config.aux_dec_model

    def init_aux_dino(self, config):
        """Create the auxiliary decoder head that predicts repa-model features.

        Requires ``self.repa_model`` to exist already: the output layer maps
        to ``self.repa_model.embed_dim``.
        """
        print('Using DINO decoder:', config.aux_dec_model)
        vit_kwargs = {
            'img_size': config.image_size,
            'num_frames': config.num_frames,
            'patch_size': config.dec_patch_size,
            't_patch_size': config.t_patch_size,
            'drop_path_rate': 0.,
        }
        self.decoder_dino = TimmViTDecoder(
            in_channels=3,
            num_latent_tokens=config.num_latent_tokens,
            model_name=config.aux_dec_model,
            model_kwargs=vit_kwargs,
            pretrained=False,
            tuning_method='full',
            tuning_kwargs={'r': 8},
            use_ape=config.use_ape,
            use_rope=config.use_rope,
            rope_mixed=config.rope_mixed,
            rope_theta=config.rope_theta,
            cls_token=config.aux_dec_cls_token,
            codebook_embed_dim=config.codebook_embed_dim,
            to_pixel='identity',
        )
        self.post_quant_conv_dino = nn.Linear(config.codebook_embed_dim, self.decoder_dino.embed_dim)
        self.to_pixel_dino = nn.Linear(self.decoder_dino.embed_dim, self.repa_model.embed_dim)
        # 'movq' decoders also receive the un-projected latents
        self.dino_use_movq = 'movq' in config.aux_dec_model

    def init_aux_clip(self, config):
        """Create the frozen CLIP-style teacher and its auxiliary decoder head.

        The teacher is the timm model 'vit_so400m_patch14_siglip_gap_224',
        always built at 224 px with patch size 14, together with a resize
        and a de-normalise / re-normalise pair for its inputs.
        """
        teacher = create_model('vit_so400m_patch14_siglip_gap_224', pretrained=True, img_size=224, patch_size=14)
        self.clip_model = teacher
        for weight in teacher.parameters():
            weight.requires_grad = False
        teacher.eval()
        self.clip_resize = transforms.Resize(224, antialias=True)
        self.clip_de_scale = Denormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.clip_scale = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        print('Using CLIP decoder:', config.aux_dec_model)
        vit_kwargs = {
            'img_size': config.image_size,
            'num_frames': config.num_frames,
            'patch_size': config.dec_patch_size,
            't_patch_size': config.t_patch_size,
            'drop_path_rate': 0.,
        }
        self.decoder_clip = TimmViTDecoder(
            in_channels=3,
            num_latent_tokens=config.num_latent_tokens,
            model_name=config.aux_dec_model,
            model_kwargs=vit_kwargs,
            pretrained=False,
            tuning_method='full',
            tuning_kwargs={'r': 8},
            use_ape=config.use_ape,
            use_rope=config.use_rope,
            rope_mixed=config.rope_mixed,
            rope_theta=config.rope_theta,
            cls_token=config.aux_dec_cls_token,
            codebook_embed_dim=config.codebook_embed_dim,
            to_pixel='identity',
        )
        self.post_quant_conv_clip = nn.Linear(config.codebook_embed_dim, self.decoder_clip.embed_dim)
        self.to_pixel_clip = nn.Linear(self.decoder_clip.embed_dim, self.clip_model.embed_dim)
        # 'movq' decoders also receive the un-projected latents
        self.clip_use_movq = 'movq' in config.aux_dec_model

    def sample_orders(self, bsz, seq_len):
        return torch.stack([torch.randperm(seq_len) for _ in range(bsz)], dim=0)

    def random_latent_masking(self, x, orders):
        bsz, seq_len = x.size(0), x.size(1)
        mask = torch.zeros(bsz, seq_len, dtype=torch.bool, device=x.device)
        # stats.truncnorm.rvs
        mask_ratios = self.mask_latent_ratio_generator.rvs(size=bsz)
        
        for i in range(bsz):
            ratio = mask_ratios[i]
            num_mask = int(seq_len * ratio)
            indices = orders[i][:num_mask]
            mask[i, indices] = True
            
        return mask

    def encode(self, x, num_frames=None, fps=None, raw_num_frames=None, frame_pts=None):
        if self.training:
            h, mask = self.encoder(x, return_mask=True, num_frames=num_frames, fps=fps, raw_num_frames=raw_num_frames, frame_pts=frame_pts)
        else:
            h = self.encoder(x, num_frames=num_frames, fps=fps, raw_num_frames=raw_num_frames, frame_pts=frame_pts)
        quant = self.quant_conv(h) # b m c
        emb_loss = {}
        info = None

        if self.rate_loss_weight > 0 and self.training:
            rate = self.bit_estimator(quant)
            b, n, c = quant.shape
            rate = rate / (b * n)
            emb_loss['entropy_loss'] = self.rate_loss_weight * rate # entropy loss

        is_video = x.ndim == 5
        
        if self.repa and self.training:
            # get z from repa_encoder
            b = x.shape[0]
            if is_video:
                x = rearrange(x, 'b c f h w -> (b f) c h w') # for dino
            rescale_x = self.scale(self.de_scale(x)).to(x.dtype)
            z = self.repa_model.forward_features(rescale_x)[:, self.repa_model.num_prefix_tokens:]
            if is_video:
                z = rearrange(z, '(b f) n c -> b (f n) c', b=b) # for dino

            # taking average over spatial dimension
            if self.repa_align == 'global':
                z = z.mean(dim=1)
                z_hat = quant.mean(dim=1)
                # calculate repa loss
                z_hat = self.projection(z_hat)
            elif self.repa_align == 'avg_1d':
                z = F.adaptive_avg_pool1d(z.permute(0, 2, 1), quant.shape[1]).permute(0, 2, 1)
                z_hat = quant
                z_hat = self.projection(z_hat)
            elif self.repa_align == 'avg_1d_shuffle':
                # shuffle the length dimension of z and avg
                indices = torch.randperm(z.shape[1])
                z = F.adaptive_avg_pool1d(z[:, indices, :].permute(0, 2, 1) , quant.shape[1]).permute(0, 2, 1)
                z_hat = quant
                z_hat = self.projection(z_hat)
            elif self.repa_align == 'repeat':
                z_hat = self.projection(quant)
                b, l, d = z_hat.shape
                z_hat = z_hat.unsqueeze(2).expand(-1, -1, z.size(1) // l, -1).reshape(b, -1, d)
            

            z = F.normalize(z, dim=-1)
            z_hat = F.normalize(z_hat, dim=-1)
            proj_loss = mean_flat(-(z * z_hat).sum(dim=-1))
            proj_loss = proj_loss.mean()
            proj_loss *= self.repa_loss_weight
            
            emb_loss['repa_loss'] = proj_loss

        if self.latent_drop and self.training:
            orders = self.sample_orders(bsz=quant.size(0), seq_len=quant.size(1)).to(quant.device)
            latent_mask = self.random_latent_masking(quant, orders).unsqueeze(-1)
            quant = torch.where(latent_mask.bool(), self.mask_latent, quant)

        if self.training:
            return quant, emb_loss, info, mask
        else:
            return quant, emb_loss, info

    def decode(self, quant, x=None, h=None, w=None, num_frames=None, fps=None, frame_pts=None):
        tmp_quant = quant 
        quant = self.post_quant_conv(quant)
        if self.use_movq:
            dec = self.decoder(quant, tmp_quant, h, w, num_frames=num_frames, fps=fps, frame_pts=frame_pts)
        else:
            dec = self.decoder(quant, None, h, w, num_frames=num_frames, fps=fps, frame_pts=frame_pts)
        return dec

    def decode_hog(self, quant, x=None, h=None, w=None):
        """
        Run the auxiliary HOG decoder on a set of latents.

        `quant` is projected with `self.post_quant_conv_hog`. When
        `self.hog_use_movq` is set, the un-projected latents are handed to
        `self.decoder_hog` as its second argument; otherwise `None` is passed.
        `x` is accepted but not used here.
        """
        projected = self.post_quant_conv_hog(quant)
        raw_latents = quant if self.hog_use_movq else None
        return self.decoder_hog(projected, raw_latents, h, w)
    
    def decode_dino(self, quant, x=None, h=None, w=None):
        """
        Run the auxiliary DINO decoder on a set of latents.

        `quant` is projected with `self.post_quant_conv_dino`. When
        `self.dino_use_movq` is set, the un-projected latents are handed to
        `self.decoder_dino` as its second argument; otherwise `None` is
        passed. `x` is accepted but not used here.
        """
        projected = self.post_quant_conv_dino(quant)
        raw_latents = quant if self.dino_use_movq else None
        return self.decoder_dino(projected, raw_latents, h, w)

    def decode_clip(self, quant, x=None, h=None, w=None):
        """
        Run the auxiliary CLIP decoder on a set of latents.

        `quant` is projected with `self.post_quant_conv_clip`. When
        `self.clip_use_movq` is set, the un-projected latents are handed to
        `self.decoder_clip` as its second argument; otherwise `None` is
        passed. `x` is accepted but not used here.
        """
        projected = self.post_quant_conv_clip(quant)
        raw_latents = quant if self.clip_use_movq else None
        return self.decoder_clip(projected, raw_latents, h, w)

    def forward(self, input, enc_num_frames=None, enc_fps=None, dec_num_frames=None, dec_fps=None):
        """
        Encode `input`, decode the latents and, in training mode, attach the
        auxiliary feature-prediction losses.

        Args:
            input: batch of images, or a 5-D video batch laid out as
                `b c f h w` (the layout the video rearrange below expects).
            enc_num_frames, enc_fps: forwarded to `self.encode` as
                `num_frames` / `fps`.
            dec_num_frames, dec_fps: forwarded to `self.decode` as
                `num_frames` / `fps`.

        Returns:
            `(dec, diff, info)`. In training mode `diff` additionally gets
            `'hog_loss'`, `'dino_loss'` and `'clip_loss'` for whichever of the
            auxiliary decoders is enabled.

        Side effects:
            The latents returned by `self.encode` are stored on `self.quant`.
        """
        batch_size = input.shape[0]
        h, w = input.shape[-2:]

        encoded = self.encode(input, num_frames=enc_num_frames, fps=enc_fps)
        if self.training:
            # In training mode encode returns a fourth value: the mask that
            # the masked auxiliary losses below are weighted with.
            quant, diff, info, mask = encoded
        else:
            quant, diff, info = encoded

        self.quant = quant
        dec = self.decode(quant, x=input, h=h, w=w, num_frames=dec_num_frames, fps=dec_fps)

        is_video = input.ndim == 5
        if not self.training:
            return dec, diff, info

        # The auxiliary target extractors are fed per-frame images, so a video
        # is flattened into a batch of frames first.
        frames = rearrange(input, 'b c f h w -> (b f) c h w') if is_video else input

        def regroup(feats):
            # Fold per-frame token sequences back into one sequence per video.
            if is_video:
                return rearrange(feats, '(b f) n c -> b (f n) c', b=batch_size)
            return feats

        def negative_cosine(pred, target):
            # Both arguments are expected to be L2-normalised along the last dim.
            if self.aux_loss_mask:
                per_token = -(pred * target).sum(dim=-1, keepdim=True)
                return (per_token * mask).sum() / mask.sum()
            return mean_flat(-(pred * target).sum(dim=-1)).mean()

        # MAETok is trained using mask modeling at the encoder, with a mask
        # ratio of 40-60%, and **predicts multiple target features**.
        # The raw mask indicates the tokens to mask; it could be inverted
        # (mask = ~mask) to indicate the valid tokens instead, but it is used
        # as-is here.

        # HOG feature target
        if self.aux_hog_decoder:
            hog_pred = self.to_pixel_hog(self.decode_hog(quant, x=input, h=h, w=w))
            hog_target = regroup(self.hog_generator(frames))
            if self.aux_loss_mask:
                elementwise = F.mse_loss(hog_pred, hog_target, reduction='none')
                hog_loss = (elementwise * mask).sum() / mask.sum() / hog_target.size(-1)
            else:
                hog_loss = F.mse_loss(hog_pred, hog_target)
            diff['hog_loss'] = hog_loss

        # DINOv2 feature target
        if self.aux_dino_decoder:
            dino_pred = self.to_pixel_dino(self.decode_dino(quant, x=input, h=h, w=w))

            # Target tokens come from the repa encoder; its prefix tokens are dropped.
            dino_in = self.scale(self.de_scale(frames)).to(frames.dtype)
            dino_target = self.repa_model.forward_features(dino_in)[:, self.repa_model.num_prefix_tokens:]

            dino_target = F.normalize(dino_target, dim=-1)
            dino_pred = F.normalize(dino_pred, dim=-1)
            dino_target = regroup(dino_target)
            diff['dino_loss'] = negative_cosine(dino_pred, dino_target)

        # CLIP feature target
        if self.aux_clip_decoder:
            clip_pred = self.to_pixel_clip(self.decode_clip(quant, x=input, h=h, w=w))

            clip_in = self.clip_scale(self.clip_resize(self.clip_de_scale(frames))).to(frames.dtype)
            clip_target = self.clip_model.forward_features(clip_in)[:, self.clip_model.num_prefix_tokens:]

            clip_target = F.normalize(clip_target, dim=-1)
            clip_pred = F.normalize(clip_pred, dim=-1)
            clip_target = regroup(clip_target)
            diff['clip_loss'] = negative_cosine(clip_pred, clip_target)

        return dec, diff, info



#################################################################################
#                              VQ Model Configs                                 #
#################################################################################
def VQ_8(**kwargs):
    """
    Build a `VQModel` whose encoder and decoder both use the channel
    multipliers [1, 2, 2, 4]. Remaining keyword arguments go to `ModelArgs`.
    """
    args = ModelArgs(
        encoder_ch_mult=[1, 2, 2, 4],
        decoder_ch_mult=[1, 2, 2, 4],
        **kwargs,
    )
    return VQModel(args)

def VQ_16(**kwargs):
    """
    Build a `VQModel` whose encoder and decoder both use the channel
    multipliers [1, 1, 2, 2, 4]. Remaining keyword arguments go to `ModelArgs`.
    """
    args = ModelArgs(
        encoder_ch_mult=[1, 1, 2, 2, 4],
        decoder_ch_mult=[1, 1, 2, 2, 4],
        **kwargs,
    )
    return VQModel(args)

def KL_8(**kwargs):
    """
    Build a `KLModel` whose encoder and decoder both use the channel
    multipliers [1, 2, 2, 4]. Remaining keyword arguments go to `ModelArgs`.
    """
    args = ModelArgs(
        encoder_ch_mult=[1, 2, 2, 4],
        decoder_ch_mult=[1, 2, 2, 4],
        **kwargs,
    )
    return KLModel(args)

def KL_16(**kwargs):
    """
    Build a `KLModel` whose encoder and decoder both use the channel
    multipliers [1, 1, 2, 2, 4]. Remaining keyword arguments go to `ModelArgs`.
    """
    args = ModelArgs(
        encoder_ch_mult=[1, 1, 2, 2, 4],
        decoder_ch_mult=[1, 1, 2, 2, 4],
        **kwargs,
    )
    return KLModel(args)

def AE_16(**kwargs):
    """
    Build an `AEModel` whose encoder and decoder both use the channel
    multipliers [1, 1, 2, 2, 4]. Remaining keyword arguments go to `ModelArgs`.
    """
    args = ModelArgs(
        encoder_ch_mult=[1, 1, 2, 2, 4],
        decoder_ch_mult=[1, 1, 2, 2, 4],
        **kwargs,
    )
    return AEModel(args)

def SoftVQ(**kwargs):
    """
    Build a `SoftVQModel` whose encoder and decoder both use the channel
    multipliers [1, 1, 2, 2, 4]. Remaining keyword arguments go to `ModelArgs`.
    """
    args = ModelArgs(
        encoder_ch_mult=[1, 1, 2, 2, 4],
        decoder_ch_mult=[1, 1, 2, 2, 4],
        **kwargs,
    )
    return SoftVQModel(args)


# Registry mapping a model-name string to the factory function that builds it.
VQ_models = {
    'AE-16': AE_16,
    'VQ-16': VQ_16,
    'VQ-8': VQ_8,
    'KL-16': KL_16,
    'KL-8': KL_8,
    'SoftVQ': SoftVQ,
}


if __name__ == '__main__':
    # Manual smoke test: build the model described by a YAML config, then run
    # a single `encode` call on a dummy, temporally sub-sampled video tensor.
    # The model and the input are moved to CUDA.
    import math  # used by transcode(); not imported at module level
    from typing import Tuple

    import ruamel.yaml as yaml
    from PIL import Image
    import cv2
    import numpy as np
    from utils.data import center_crop_arr
    from einops import rearrange
    
    # ckpt_path = 'experiments/tokenizer/exp069-f16-t512-p8-fp4-ae-b-L1-gan-std-rope3d-rd1e-3-k600-ds_config_bs4x1-exp051_175k/checkpoints/vq_0175000/pytorch_model.bin'
    config_path = 'configs/var_fps/f16-t512-p8-fp4-ae-b-L1-gan-rope_6head-mixfps-k600.yaml'

    with open(config_path, 'r', encoding='utf-8') as f:
        file_yaml = yaml.YAML()
        config = file_yaml.load(f)
    
    ae_model = VQ_models[config['vq_model']](
        image_size=config['image_size'],
        num_frames=config['num_frames'],
        # codebook_size=config['codebook_size'],
        codebook_embed_dim=config['codebook_embed_dim'],
        # codebook_l2_norm=config['codebook_l2_norm'],
        commit_loss_beta=config['commit_loss_beta'],
        entropy_loss_ratio=config['entropy_loss_ratio'],
        vq_loss_ratio=config['vq_loss_ratio'],
        kl_loss_weight=config['kl_loss_weight'],
        dropout_p=config['dropout_p'],
        enc_type=config['enc_type'],
        encoder_model=config['encoder_model'],
        dec_type=config['dec_type'],
        decoder_model=config['decoder_model'],
        num_latent_tokens=config['num_latent_tokens'],
        enc_patch_size=config['encoder_patch_size'],
        dec_patch_size=config['decoder_patch_size'],
        t_patch_size=config.get('t_patch_size', 1),
        enc_pretrained=False,
        dec_pretrained=False,
        tau=config['tau'],
        repa=config['repa'],
        repa_model=config['repa_model'],
        repa_patch_size=config['repa_patch_size'],
        repa_proj_dim=config['repa_proj_dim'],
        repa_loss_weight=config['repa_loss_weight'],
        repa_align=config['repa_align'],
        # num_codebooks=config['num_codebooks'],
        use_ape=config['use_ape'],
        use_rope=config['use_rope'],
        rope_mixed=config['rope_mixed'],
        rope_dim=config.get('rope_dim', None),
        rope_sbm=config.get('rope_sbm', False),
        rope_heads=config.get('rope_heads', None),
        rope_layers=config.get('rope_layers', None),
        rope_theta=config['rope_theta'],
        rope_theta_t=config.get('rope_theta_t', 100.0),
        enc_channel_attn=config.get('enc_channel_attn', False),
        dec_channel_attn=config.get('dec_channel_attn', False),
        enc_attn_pe=config.get('enc_attn_pe', False),
        dec_attn_pe=config.get('dec_attn_pe', False),
        enc_attn_latent_rope=config.get('enc_attn_latent_rope', False),
        dec_attn_latent_rope=config.get('dec_attn_latent_rope', False),
        dec_seperate_mask_token=config.get('dec_seperate_mask_token', False),
        rate_loss_weight=config.get('rate_loss_weight', 0),
        enc_token_drop=config.get('enc_token_drop', 0),
        enc_token_drop_max=config.get('enc_token_drop_max', 0),
        aux_dec_model=config.get('aux_dec_model', 'vit_tinytiny_patch14_dinov2_movq'),
        aux_loss_mask=config.get('aux_loss_mask', False),
        aux_dec_cls_token=config.get('aux_dec_cls_token', True),
        aux_hog_dec=config.get('aux_hog_dec', False),
        aux_dino_dec=config.get('aux_dino_dec', False),
        aux_clip_dec=config.get('aux_clip_dec', False),
        variable_num_frames=config.get('variable_num_frames', False),
    ).cuda()

    def transcode(tensor: torch.Tensor,
               src_fps: float,
               dst_fps: float,
               t_patch_size: int) -> Tuple[torch.Tensor, float, torch.Tensor]:
        """
        Temporally sub-sample a video tensor from `src_fps` down to `dst_fps`.

        Frames are selected along dim 2 of a 5-D `tensor`. The number of kept
        frames is rounded down to a multiple of `t_patch_size`.

        Returns:
            A tuple `(out, new_fps, idx)`: the sub-sampled tensor, the
            effective frame rate of the kept frames, and the indices of the
            source frames that were kept.

        Raises:
            ValueError: if not `0 < dst_fps <= src_fps`.
        """
        # Named `total_frames` rather than `F` so the module-level alias for
        # torch.nn.functional is not shadowed.
        total_frames = tensor.shape[2]
        if not 0 < dst_fps <= src_fps:
            raise ValueError(
                f"dst_fps must satisfy 0 < dst_fps <= src_fps, got dst_fps={dst_fps}, src_fps={src_fps}"
            )

        ratio = src_fps / dst_fps
        kept_frames = math.floor(total_frames / ratio / t_patch_size) * t_patch_size
        new_fps = kept_frames / total_frames * src_fps

        # Kept frame k maps to source frame floor(k * ratio), clamped to the last frame.
        idx_float = torch.arange(kept_frames, dtype=torch.float32) * ratio
        idx = torch.floor(idx_float).long()
        idx = torch.clamp(idx, max=total_frames - 1)

        out = tensor[:, :, idx, :, :]
        return out, new_fps, idx
    
    # print(next(ae_model.parameters()).dtype)

    # from utils.misc import load_model_state_dict
    # checkpoint = torch.load(ckpt_path, map_location="cpu")
    # ae_model.load_state_dict(load_model_state_dict(checkpoint, ae_model))
    num_frames = 20
    fps = 30
    x = torch.zeros(1, 3, num_frames, 256, 256).cuda()
    low_fps_x, low_fps, frame_pts = transcode(x, fps, 24, 4)
    low_fps_num_frames = low_fps_x.shape[2]

    # NOTE(review): `low_fps` is computed above but the source `fps` is what is
    # passed to encode() - confirm this is intended.
    x = ae_model.encode(low_fps_x, num_frames=low_fps_num_frames, fps=fps, raw_num_frames=num_frames, frame_pts=frame_pts)[0]

    # ae_model.load_state_dict(checkpoint)
    # # print(checkpoint['model']['encoder.model.pos_embed'].shape)
    # # print(ae_model.encoder.model.pos_embed.shape)
    # print(next(ae_model.parameters()).dtype)
    # exit()

    # in_img = Image.open(data_path)
    # transform = transforms.Compose([
    #     transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, config['image_size'])),
    #     transforms.ToTensor(),
    #     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    # ])
    # in_img = transform(in_img)[None]
    # with torch.no_grad():
    #     out_img, _, _ = ae_model(in_img)
    # out_img = (torch.clamp(out_img, min=-1, max=1) + 1) / 2
    # out_img = out_img.cpu().numpy()
    # out_img = (out_img * 255).astype(np.uint8)
    # out_img = rearrange(out_img[0], 'c h w -> h w c')
    # out_img = cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)
    # cv2.imwrite('out.png', out_img)