import pdb
import torch
import numpy as np
from torch import nn, einsum
import torch.nn.functional as F
import math
from functools import partial
from einops import rearrange
from timm.models.layers import trunc_normal_
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from vqgan.vqperceptual import VQLPIPSWithDiscriminator

from vqgan.patch_vqgan import VectorQuantizer, EMAVectorQuantizer

from modeling_finetune import VisionTransformer

def get_default_config(name='encoder', model='base'):
    """Build the keyword arguments for one of the two VisionTransformer roles.

    Args:
        name: ``'encoder'`` or ``'decoder'``; selects which set of overrides
            is applied on top of the shared defaults.
        model: size tag. Only ``'base'`` is meaningful and it leaves the
            shared defaults untouched; the value is otherwise not consulted.

    Returns:
        A fresh ``dict`` of constructor keyword arguments.

    Raises:
        NotImplementedError: if ``name`` is neither ``'encoder'`` nor
            ``'decoder'``.
    """
    # Settings shared by both roles (the 'base' variant).
    shared = {
        'img_size': 224,
        'patch_size': 16,
        'in_chans': 3,
        'num_classes': 1000,
        'embed_dim': 768,
        'depth': 12,
        'num_heads': 12,
        'mlp_ratio': 4.,
        'qkv_bias': False,
        'qk_scale': None,
        'drop_rate': 0.,
        'attn_drop_rate': 0.,
        'drop_path_rate': 0.,
        'norm_layer': partial(nn.LayerNorm, eps=1e-6),
        'init_values': 0.,
        'use_abs_pos_emb': True,
        'use_rel_pos_bias': False,
        'use_shared_rel_pos_bias': False,
        'use_mean_pooling': True,
        'init_scale': 0.001,
    }

    if name == 'decoder':
        # The decoder is configured as a 14x14 grid of 1x1 "patches" with
        # 32 input channels, and without a classification head.
        role_overrides = {
            'img_size': 14,
            'patch_size': 1,
            'in_chans': 32,
            'num_classes': 0,
        }
    elif name == 'encoder':
        # The encoder only drops the classification head.
        role_overrides = {'num_classes': 0}
    else:
        raise NotImplementedError

    return {**shared, **role_overrides}

def get_default_vitvqgan_config():
    """Return the default ``(encoder_config, decoder_config, loss_config)``.

    The encoder and decoder dictionaries come from ``get_default_config``
    with ``model='base'``. The loss dictionary sets both ``disc_factor`` and
    ``disc_weight`` to 0, and ``logitlaplace_loss_weight`` to 0.
    """
    encoder_config, decoder_config = (
        get_default_config(name=role, model='base')
        for role in ('encoder', 'decoder')
    )

    loss_config = dict(
        disc_start=250001,
        disc_factor=0,
        disc_weight=0,
        perceptual_type='vgg',
        perceptual_weight=1.0,
        logitlaplace_loss_weight=0.,
        codebook_weight=1.0,
    )
    return encoder_config, decoder_config, loss_config

class ViTVQGAN(nn.Module):
    """VQGAN-style autoencoder whose encoder and decoder are ViTs.

    Pipeline: ``pre_process`` -> ``encoder`` (VisionTransformer) ->
    ``encode_task_layer`` -> ``quantize`` (EMAVectorQuantizer) -> ``decoder``
    (VisionTransformer) -> ``decode_task_layer``. ``forward`` returns the
    training loss; ``get_tokens`` / ``decode_img`` expose the tokenizer.
    """

    def __init__(self,
                 encoder_config,
                 decoder_config,
                 loss_config,
                 n_embed=8192,
                 embed_dim=32,
                 ignore_keys=None,
                 quantizer_type='EMAVQ',
                 embed_ema=True,
                 quantizer_dis_type='cosine',
                 decay=0.99,
                 orthogonal_reg_weight=0.,
                 orthogonal_reg_active_codes_only=False,
                 orthogonal_reg_max_codes=None,
                 ckpt_path=None,
                 norm_target=True,
                 process_type='default',
                 rec_out_channels=3,
                 **kwargs
                 ):
        """Build encoder, decoder, quantizer, loss and the two task heads.

        Args:
            encoder_config: kwargs for the encoder ``VisionTransformer``;
                ``img_size`` and ``embed_dim`` are also read here.
            decoder_config: kwargs for the decoder ``VisionTransformer``. If
                its ``in_chans`` differs from ``embed_dim`` a copy with the
                corrected value is used (the caller's dict is left intact).
            loss_config: kwargs for ``VQLPIPSWithDiscriminator``, or ``None``
                to build the model without a loss (``forward`` is then
                unavailable).
            n_embed: number of codebook entries.
            embed_dim: dimensionality of each codebook entry.
            ignore_keys: key prefixes to drop from the checkpoint before
                loading; ``None`` means drop nothing.
            quantizer_type: ``'EMAVQ'`` (EMA codebook updates) or ``'VQ'``
                (same quantizer class with EMA disabled).
            embed_ema: accepted for backward compatibility but not used; EMA
                is decided solely by ``quantizer_type``.
            quantizer_dis_type: forwarded as the quantizer's ``distance_type``.
            decay: forwarded as the quantizer's ``decay``.
            orthogonal_reg_weight: forwarded to the quantizer.
            orthogonal_reg_active_codes_only: forwarded to the quantizer.
            orthogonal_reg_max_codes: forwarded to the quantizer.
            ckpt_path: optional checkpoint restored once every sub-module
                has been created.
            norm_target: stored and passed to the loss; in ``decode_img`` it
                enables de-normalisation with per-patch statistics.
            process_type: ``'default'``, ``'dall-e'`` or ``'imagenet_norm'``;
                selects the branch in ``pre_process`` / ``post_process``.
            rec_out_channels: channels predicted per pixel by the decode head.
            **kwargs: stored verbatim on ``self.kwargs``.

        Raises:
            NotImplementedError: for an unknown ``quantizer_type``.
            AssertionError: if the logit-laplace loss is enabled without
                ``process_type='dall-e'`` and ``rec_out_channels=6``.
        """
        super().__init__()
        if decoder_config['in_chans'] != embed_dim:
            print(f"Rewrite the in_chans in decoder from {decoder_config['in_chans']} to {embed_dim}")
            # Work on a copy so the dictionary owned by the caller is not modified.
            decoder_config = {**decoder_config, 'in_chans': embed_dim}
        # encoder & decode params
        self.encoder = VisionTransformer(**encoder_config)
        self.decoder = VisionTransformer(**decoder_config)

        print(f'using {quantizer_type} vector quantizer ({n_embed}x{embed_dim}) with {quantizer_dis_type} distance')
        if quantizer_type not in ('EMAVQ', 'VQ'):
            raise NotImplementedError(f"Unsupport {quantizer_type} vector quantizer")
        # Both supported types use EMAVectorQuantizer; they differ only in
        # whether the EMA codebook update is switched on.
        self.quantize = EMAVectorQuantizer(
            n_embed, embed_dim, beta=0.25, decay=decay,
            distance_type=quantizer_dis_type,
            embed_ema=(quantizer_type == 'EMAVQ'),
            orthogonal_reg_weight=orthogonal_reg_weight,
            orthogonal_reg_active_codes_only=orthogonal_reg_active_codes_only,
            orthogonal_reg_max_codes=orthogonal_reg_max_codes)

        self.loss = None if loss_config is None else VQLPIPSWithDiscriminator(**loss_config)

        print(f"Norm target for reconstruciton: {norm_target}")
        self.norm_target = norm_target
        self.patch_size = self.encoder.patch_embed.patch_size[0]

        # Number of tokens along each image side at the configured resolution.
        tokens_per_side = encoder_config['img_size'] // self.patch_size
        self.token_shape = (tokens_per_side, tokens_per_side)

        # task layer
        self.encode_task_layer = nn.Sequential(
            nn.Linear(encoder_config['embed_dim'], encoder_config['embed_dim']),
            nn.Tanh(),
            nn.Linear(encoder_config['embed_dim'], embed_dim),
        )

        if (loss_config is not None) and (loss_config['logitlaplace_loss_weight'] > 0.):
            assert process_type == 'dall-e', f"Image processer should following dall-e, not {process_type}"
            assert rec_out_channels == 6, f"the rec output channel should =6 when logit-laplace loss is enabled but get {rec_out_channels}"

        self.rec_out_channels = rec_out_channels
        self.decode_task_layer = nn.Sequential(
            nn.Linear(decoder_config['embed_dim'], decoder_config['embed_dim']),
            nn.Tanh(),
            nn.Linear(decoder_config['embed_dim'], self.patch_size * self.patch_size * self.rec_out_channels),
        )

        self.process_type = process_type # in ['default', 'dall-e']
        self.logit_laplace_eps = 0.1
        self.kwargs = kwargs

        self.encode_task_layer.apply(self._init_weights)
        self.decode_task_layer.apply(self._init_weights)

        # Restore the checkpoint last: the task layers must already exist
        # (and be past their random initialisation) for their saved weights
        # to be loaded instead of being reported as unexpected keys.
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` / ``nn.LayerNorm`` modules in place."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def init_from_ckpt(self, path, ignore_keys=None):
        """Load weights from ``path`` non-strictly and report mismatches.

        Args:
            path: checkpoint file readable by ``torch.load``. The weights are
                taken from its ``'model'`` entry, else its ``'state_dict'``
                entry, else the loaded object itself is used as state dict.
            ignore_keys: iterable of key prefixes; every entry whose name
                starts with one of them is removed before loading.
        """
        ignored_prefixes = tuple(ignore_keys or ())

        # NOTE: torch.load relies on pickle; only load checkpoints you trust.
        checkpoint = torch.load(path, map_location="cpu")
        if 'model' in checkpoint:
            sd = checkpoint['model']
        elif 'state_dict' in checkpoint:
            sd = checkpoint['state_dict']
        else:
            sd = checkpoint

        # Iterate over a snapshot of the keys because entries are deleted.
        # One startswith() call covers all prefixes, so a key matched by
        # several prefixes is removed exactly once.
        for k in list(sd.keys()):
            if k.startswith(ignored_prefixes):
                print(f"PatchVQGAN: Deleting key {k} from state_dict.")
                del sd[k]

        missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
        print(f"missing_keys: {missing_keys}")
        print(f"unexpected_keys: {unexpected_keys}")
        print(f"PatchVQGAN: Restored from {path}")

    @property
    def device(self):
        """Device holding the encoder's ``cls_token``."""
        return self.encoder.cls_token.device

    def pre_process(self, data):
        """Map input images to the value range selected by ``process_type``.

        * ``'default'``: move to the model device, multiply by 255 when the
          maximum is <= 1, then rescale with ``x / 127.5 - 1``.
        * ``'dall-e'``: ``(1 - 2 * eps) * x + eps`` with
          ``eps = self.logit_laplace_eps``.
        * ``'imagenet_norm'``: subtract the ImageNet mean and divide by the
          ImageNet std, per channel.

        Any other ``process_type`` returns ``data`` unchanged. Only the
        ``'default'`` branch moves ``data`` to the model device.
        """
        if self.process_type == 'default':
            # TODO: modify for adapt
            data = data.to(self.device)
            if data.max() <= 1.:
                data = data * 255.
            data = data / 127.5 - 1.0
        elif self.process_type == 'dall-e':
            data = (1 - 2 * self.logit_laplace_eps) * data + self.logit_laplace_eps
        elif self.process_type == 'imagenet_norm':
            mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(self.device)[None, :, None, None]
            std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(self.device)[None, :, None, None]
            data = (data - mean) / std
        return data

    def post_process(self, data):
        """Undo ``pre_process`` for the configured ``process_type``.

        ``'default'`` and ``'dall-e'`` produce values clamped to [0, 255];
        ``'imagenet_norm'`` only reverses the mean/std normalisation. Any
        other ``process_type`` returns ``data`` unchanged.
        """
        if self.process_type == 'default':
            # TODO: implement the norm target
            data = (data + 1.0) * 127.5
            data = torch.clamp(data, min=0.0, max=255.0)
        elif self.process_type == 'dall-e':
            data = torch.clamp((data - self.logit_laplace_eps) / (1 - 2 * self.logit_laplace_eps), 0, 1) * 255.
        elif self.process_type == 'imagenet_norm':
            mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(self.device)[None, :, None, None]
            std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(self.device)[None, :, None, None]
            data = data * std + mean
        return data

    def get_number_of_tokens(self):
        """Return the quantizer's ``n_e`` attribute (the codebook size)."""
        return self.quantize.n_e

    def _to_token_grid(self, x):
        """Encode pre-processed images into a ``b c h w`` grid for the quantizer.

        Runs the encoder (patch tokens only) and the encode head, then
        reshapes the ``b (h w) c`` sequence using ``self.token_shape``.
        """
        tokens = self.encoder(x, return_patch_tokens=True)
        tokens = self.encode_task_layer(tokens)
        # reshape for quantizer
        return rearrange(tokens, 'b (h w) c -> b c h w', h=self.token_shape[0], w=self.token_shape[1])

    def get_tokens(self, data, **kwargs):
        """Tokenise raw images.

        Args:
            data: images accepted by ``pre_process``.
            **kwargs: accepted for interface compatibility; not used.

        Returns:
            dict with ``'token'`` (codebook indices flattened to one row per
            sample) and ``'input_img'`` (the pre-processed images).
        """
        # with torch.cuda.amp.autocast():
        data = self.pre_process(data)
        idx = self.quantize(self._to_token_grid(data))['index']

        return {
            'token': idx.view(idx.shape[0], -1),
            'input_img': data,
        }

    def encode(self, x):
        """Encode already pre-processed images.

        Returns:
            ``(quant, emb_loss)``: the ``'quantize'`` and ``'quantize_loss'``
            entries of the quantizer output.
        """
        quant_out = self.quantize(self._to_token_grid(x))
        return quant_out['quantize'], quant_out['quantize_loss']

    def decode(self, quant):
        """Run the decoder and the decode head on quantized latents."""
        rec = self.decoder(quant, return_patch_tokens=True)
        return self.decode_task_layer(rec)

    def decode_img(self, token, input_img=None):
        """Reconstruct images from codebook indices.

        Args:
            token: codebook indices; the first dimension is the batch and the
                remaining entries must fill a ``self.token_shape`` grid.
            input_img: pre-processed reference images (``b c h w``). Required
                when ``self.norm_target`` is true, because the prediction is
                then de-normalised with the per-patch mean and standard
                deviation of these images.

        Returns:
            The reconstruction after ``post_process``.

        Raises:
            AssertionError: if ``norm_target`` is set and ``input_img`` is
                ``None``.
        """
        patch_size = self.patch_size

        bhw = (token.shape[0], self.token_shape[0], self.token_shape[1])
        quant = self.quantize.get_codebook_entry(token.view(-1), shape=bhw)
        rec = self.decode(quant)
        if self.rec_out_channels == 6:
            # Keep the first half of the predicted features and squash it
            # with a sigmoid; the second half is discarded here.
            rec = rec[:, :, :rec.shape[-1] // 2].sigmoid()

        if self.norm_target:
            assert input_img is not None
            _, _, h, w = input_img.shape
            grid_h, grid_w = h // patch_size, w // patch_size
            images_squeeze = rearrange(input_img, 'b c (h p1) (w p2) -> b (h w) (p1 p2) c', p1=patch_size, p2=patch_size)
            patch_std = images_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
            patch_mean = images_squeeze.mean(dim=-2, keepdim=True)
            rec = rearrange(rec, 'b n (p c) -> b n p c', c=3)
            rec = rec * patch_std + patch_mean
        else:
            grid_h, grid_w = self.token_shape
            rec = rearrange(rec, 'b n (p c) -> b n p c', c=3)

        # Stitch the per-patch pixels back into full images.
        rec = rearrange(rec, 'b (h w) (p1 p2) c -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=grid_h, w=grid_w)
        return self.post_process(rec)

    def get_codebook_indices(self, x):
        """Shortcut for ``get_tokens(x)['token']``."""
        return self.get_tokens(x)['token']

    def forward(self, x):
        """
        x: shape [B, 3, H, W] in [0, 1]

        Pre-processes ``x``, encodes, quantizes and decodes it, then returns
        ``(aeloss, log_dict_ae)`` as produced by ``self.loss``. The loss is
        called with ``split="train"`` in training mode and ``split="val"``
        otherwise.

        Raises:
            RuntimeError: if the model was built with ``loss_config=None``.
        """
        if self.loss is None:
            raise RuntimeError(
                "ViTVQGAN.forward needs a loss module, but the model was built with loss_config=None")

        _, _, h, w = x.shape
        x = self.pre_process(x) # rescale to [-1, 1]
        quant, emb_loss = self.encode(x)
        xrec = self.decode(quant)

        split = "train" if self.training else "val"
        # NOTE(review): the two positional zeros are presumably the optimizer
        # index and the global step -- confirm against VQLPIPSWithDiscriminator.
        aeloss, log_dict_ae = self.loss(emb_loss, x, xrec, 0, 0, last_layer=None, split=split,
                                        norm_target=self.norm_target, patch_size=self.patch_size, img_h=h, img_w=w, rec_out_channels=self.rec_out_channels)
        return aeloss, log_dict_ae


if __name__ == '__main__':
    # Manual scratch check: print a boolean mask that is True wherever the
    # entry of `logits` is not greater than 2.
    logits = torch.tensor([[0, 1, 2], [4, 5, 6]])
    mask = torch.logical_not(logits.gt(2))
    print(mask)






