import torch.nn.functional as F
import torch
import torch.nn as nn
from timm.models.vision_transformer import PatchEmbed, Block
from util.pos_embed import get_2d_sincos_pos_embed

class ResnetBlock(nn.Module):
    """Pre-activation residual block built from two 3x3 convolutions.

    ``forward`` returns ``shortcut(x) + 0.1 * residual(x)``. The residual
    branch is ``conv_1(actvn(conv_0(actvn(x))))``. The shortcut is a
    bias-free 1x1 convolution when ``fin != fout``; otherwise it is the
    identity.

    Args:
        fin: number of input channels.
        fout: number of output channels.
        fhidden: channels between the two 3x3 convolutions; defaults to
            ``min(fin, fout)``.
        is_bias: whether ``conv_1`` has a bias term (``conv_0`` always does).
    """

    def __init__(self, fin, fout, fhidden=None, is_bias=True):
        super().__init__()

        self.is_bias = is_bias
        self.learned_shortcut = fin != fout
        self.fin = fin
        self.fout = fout
        self.fhidden = min(fin, fout) if fhidden is None else fhidden

        # Submodules are created in the order conv_0, conv_1, conv_s so that
        # seeded default initialisation stays reproducible.
        self.conv_0 = nn.Conv2d(fin, self.fhidden, 3, stride=1, padding=1)
        self.conv_1 = nn.Conv2d(self.fhidden, fout, 3, stride=1, padding=1, bias=is_bias)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, 1, stride=1, padding=0, bias=False)

    def forward(self, x):
        """Return the shortcut plus the 0.1-scaled residual branch."""
        residual = self.conv_1(actvn(self.conv_0(actvn(x))))
        return self._shortcut(x) + 0.1 * residual

    def _shortcut(self, x):
        """Project ``x`` to ``fout`` channels when needed, else pass it through."""
        return self.conv_s(x) if self.learned_shortcut else x

def actvn(x):
    """Apply a leaky ReLU with negative slope 0.2 to ``x`` and return the result."""
    return F.leaky_relu(x, negative_slope=0.2)


class Encoder_ResNet_AE(nn.Module):
    """Convolutional encoder mapping each image patch to a ``latent_dim`` vector.

    Images are split into non-overlapping ``patch_size`` x ``patch_size``
    patches (see ``patchify``). Every patch is then processed independently
    by a stack of stages, each ending in a ``ResnetBlock``:

    * ``numlayer[0]`` stride-2 transposed convolutions, each upsampling by 2.
      The first one maps ``in_chans`` -> ``2 * nfilter`` channels and is
      always built.
    * ``numlayer[1]`` stride-1 3x3 convolutions (size preserving).
    * ``numlayer[2]`` stride-2 4x4 convolutions, each downsampling by 2.

    The final feature map has spatial size
    ``dim_lastfeat = patch_size * 2**numlayer[0] // 2**numlayer[2]``. It is
    flattened and linearly projected to ``latent_dim``.

    Raises:
        ValueError: if ``numlayer[0] < 1``, if ``numlayer[2] < 0``, or if
            ``patch_size * 2**numlayer[0]`` is not divisible by
            ``2**numlayer[2]``. In each case ``dim_lastfeat`` would not match
            the real feature-map size.
    """

    def __init__(self, latent_dim=768, in_chans=3, patch_size=16, numlayer=(2, 0, 3), nfilter=16):

        super().__init__()

        n_up, n_down = numlayer[0], numlayer[2]
        # The first upsampling stage below is always built, so
        # numlayer[0] == 0 would make dim_lastfeat disagree with the actual
        # feature size.
        if n_up < 1 or n_down < 0:
            raise ValueError(
                f"numlayer must satisfy numlayer[0] >= 1 and numlayer[2] >= 0, got {numlayer}")
        upsampled = patch_size * 2 ** n_up
        # A stride-2 Conv2d(k=4, p=1) halves even sizes exactly but floors odd
        # ones. Require exact halving so dim_lastfeat is an integer matching
        # the real output, instead of being silently truncated.
        if upsampled % 2 ** n_down != 0:
            raise ValueError(
                f"patch_size * 2**numlayer[0] ({upsampled}) must be divisible by "
                f"2**numlayer[2] ({2 ** n_down})")

        self.latent_dim = latent_dim
        self.in_chans = in_chans
        self.patch_size = patch_size
        self.numlayer = numlayer
        self.nfilter = nfilter

        layers = nn.ModuleList()
        # Stage 1 (always present): upsample x2 and widen to 2 * nfilter channels.
        layers.append(nn.Sequential(
            nn.ConvTranspose2d(self.in_chans, 2 * nfilter, kernel_size=4, stride=2, padding=1),
            ResnetBlock(2 * nfilter, nfilter)
        ))
        # Remaining upsampling stages (x2 each).
        for _ in range(self.numlayer[0] - 1):
            layers.append(nn.Sequential(
                nn.LeakyReLU(0.2),
                nn.ConvTranspose2d(nfilter, nfilter, kernel_size=4, stride=2, padding=1),
                ResnetBlock(nfilter, nfilter)
            ))
        # Size-preserving stages.
        for _ in range(self.numlayer[1]):
            layers.append(nn.Sequential(
                nn.LeakyReLU(0.2),
                nn.Conv2d(nfilter, nfilter, kernel_size=3, stride=1, padding=1),
                ResnetBlock(nfilter, nfilter)
            ))
        # Downsampling stages (/2 each).
        for _ in range(self.numlayer[2]):
            layers.append(nn.Sequential(
                nn.LeakyReLU(0.2),
                nn.Conv2d(nfilter, nfilter, kernel_size=4, stride=2, padding=1),
                ResnetBlock(nfilter, nfilter)
            ))
        self.layers = layers
        self.depth = len(layers)

        self.dim_lastfeat = upsampled // 2 ** n_down
        self.embedding = nn.Linear(nfilter * self.dim_lastfeat ** 2, self.latent_dim)

    def patchify(self, imgs):
        """Split square images into non-overlapping patches.

        Args:
            imgs: tensor of shape ``(N, in_chans, H, W)`` with ``H == W`` and
                ``H`` divisible by ``patch_size``.

        Returns:
            Tensor of shape ``(N * h * w, in_chans, patch_size, patch_size)``
            with ``h = w = H // patch_size``. Patches of each image are
            ordered row-major, and images follow one another.
        """
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % self.patch_size == 0

        h = w = imgs.shape[2] // self.patch_size
        x = imgs.reshape(shape=(imgs.shape[0], self.in_chans, h, self.patch_size, w, self.patch_size))
        x = torch.einsum('nchpwq->nhwcpq', x)
        x = x.reshape(shape=(imgs.shape[0] * h * w, self.in_chans, self.patch_size, self.patch_size))
        return x

    def forward(self, imgs):
        """Encode every patch of ``imgs``.

        Returns:
            ``(patch_input, enc_output)``: the patchified input of shape
            ``(N * num_patches, in_chans, patch_size, patch_size)`` and the
            per-patch codes of shape ``(N * num_patches, latent_dim)``.
        """
        patch_input = self.patchify(imgs)

        out = patch_input
        for layer in self.layers:
            out = layer(out)
        out = actvn(out)
        enc_output = self.embedding(out.reshape(patch_input.shape[0], -1))

        return patch_input, enc_output


class Decoder_ResNet_AE(nn.Module):
    """Convolutional decoder mapping per-patch latent vectors back to patches.

    This mirrors ``Encoder_ResNet_AE``. A linear layer maps ``latent_dim`` to
    an ``nfilter x dim_lastfeat x dim_lastfeat`` feature map, followed by a
    ``ResnetBlock``. The map then passes through:

    * ``numlayer[2]`` stride-2 transposed convolutions (x2 each);
    * ``numlayer[1]`` size-preserving 3x3 convolutions;
    * ``numlayer[0] - 1`` stride-2 convolutions (/2 each);
    * a final stride-2 convolution to ``in_chans`` channels and a ``Tanh``.

    The output spatial size is therefore
    ``dim_lastfeat * 2**numlayer[2] / 2**numlayer[0] == patch_size``.

    Raises:
        ValueError: if ``numlayer[0] < 1`` (the final downsampling conv is
            always built), if ``numlayer[2] < 0``, or if
            ``patch_size * 2**numlayer[0]`` is not divisible by
            ``2**numlayer[2]``. In each case the reconstructed patches would
            not be ``patch_size`` wide.
    """

    def __init__(self, latent_dim=768, in_chans=3, patch_size=16, num_patch=64, numlayer=(2, 0, 3), nfilter=16):

        super().__init__()

        n0, n2 = numlayer[0], numlayer[2]
        if n0 < 1 or n2 < 0:
            raise ValueError(
                f"numlayer must satisfy numlayer[0] >= 1 and numlayer[2] >= 0, got {numlayer}")
        upsampled = patch_size * 2 ** n0
        # Integer feature size required; the original int(float) silently
        # truncated and produced wrongly sized patches.
        if upsampled % 2 ** n2 != 0:
            raise ValueError(
                f"patch_size * 2**numlayer[0] ({upsampled}) must be divisible by "
                f"2**numlayer[2] ({2 ** n2})")

        self.latent_dim = latent_dim
        self.in_chans = in_chans
        self.patch_size = patch_size
        self.num_patch = num_patch
        self.numlayer = numlayer
        self.nfilter = nfilter
        self.dim_lastfeat = upsampled // 2 ** n2

        layers = nn.ModuleList()
        # Layer 0: latent -> flattened feature map (reshaped in forward()).
        layers.append(nn.Linear(self.latent_dim, self.nfilter * self.dim_lastfeat ** 2))
        layers.append(ResnetBlock(nfilter, nfilter))
        # Undo the encoder's downsampling stages (x2 each).
        for _ in range(self.numlayer[2]):
            layers.append(nn.Sequential(
                nn.LeakyReLU(0.2),
                nn.ConvTranspose2d(nfilter, nfilter, kernel_size=4, stride=2, padding=1),
                ResnetBlock(nfilter, nfilter)
            ))
        # Size-preserving stages.
        for _ in range(self.numlayer[1]):
            layers.append(nn.Sequential(
                nn.LeakyReLU(0.2),
                nn.Conv2d(nfilter, nfilter, kernel_size=3, stride=1, padding=1),
                ResnetBlock(nfilter, nfilter)
            ))
        # Undo all but one of the encoder's upsampling stages (/2 each).
        for _ in range(self.numlayer[0] - 1):
            layers.append(nn.Sequential(
                nn.LeakyReLU(0.2),
                nn.Conv2d(nfilter, nfilter, kernel_size=4, stride=2, padding=1),
                ResnetBlock(nfilter, nfilter)
            ))
        # Final /2 stage to image channels, squashed into [-1, 1].
        layers.append(nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.Conv2d(nfilter, in_chans, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        ))
        self.layers = layers
        self.depth = len(layers)

    def unpatchify(self, x, bsz):
        """Reassemble patches into square images (inverse of the encoder's patchify).

        Args:
            x: tensor of shape ``(bsz * num_patch, in_chans, patch_size, patch_size)``
                with patches ordered row-major per image.
            bsz: number of images.

        Returns:
            Tensor of shape ``(bsz, in_chans, h * patch_size, w * patch_size)``
            with ``h = w = sqrt(num_patch)``.
        """
        h = w = int(self.num_patch ** 0.5)
        assert h * w == self.num_patch

        x = x.reshape(shape=(bsz, h, w, self.in_chans, self.patch_size, self.patch_size))
        x = torch.einsum('nhwcpq->nchpwq', x)
        recon_img = x.reshape(shape=(bsz, self.in_chans, h * self.patch_size, w * self.patch_size))

        return recon_img

    def forward(self, z, bsz, to_img=False):
        """Decode latent vectors ``z`` of shape ``(M, latent_dim)``.

        Args:
            z: per-patch latent codes.
            bsz: number of images; only used when ``to_img`` is true, in which
                case ``M`` must equal ``bsz * num_patch``.
            to_img: if true, reassemble the patches into full images.

        Returns:
            Patches of shape ``(M, in_chans, patch_size, patch_size)``, or
            images of shape ``(bsz, in_chans, H, W)`` when ``to_img``.
        """
        out = z
        for i, layer in enumerate(self.layers):
            out = layer(out)
            if i == 0:
                # Linear output -> (M, nfilter, dim_lastfeat, dim_lastfeat).
                out = out.reshape(z.shape[0], self.nfilter, self.dim_lastfeat, self.dim_lastfeat)
        recon_patch = out

        if to_img:
            return self.unpatchify(recon_patch, bsz)
        return recon_patch

    
class UniViT(nn.Module):
    """Transformer VAE over sequences of per-patch vectors.

    One stack of timm ``Block``s (``self.blocks``) is shared by both passes:

    * Encoder (``forward_encoder``): the ``x_dim`` inputs are embedded, a CLS
      token is prepended, positional embeddings are added and the stack is
      applied. ``head_qz`` then maps the CLS output to the mean and log-std
      of a ``z_dim`` Gaussian.
    * Decoder (``forward_decoder``): ``z`` is embedded as the first token and
      followed by ``num_patches`` copies of ``mask_token``. The same stack is
      applied, and ``decoder_pred`` predicts ``x_dim`` values (plus one extra
      channel) per patch token.

    ``ELastNorm`` / ``DLastNorm`` choose the layer applied after the stack in
    the encoder / decoder pass:

    * ``'LN'``: ``norm_layer(embed_dim)``;
    * ``'NO'``: identity;
    * ``'IN'``: ``InstanceNorm1d(L)``, treating the L tokens as channels;
    * ``'LR'``: ``LeakyReLU(0.2)``.

    Raises:
        ValueError: if ``ELastNorm`` or ``DLastNorm`` is not one of the
            options above.
    """

    # Allowed values for the ELastNorm / DLastNorm constructor options.
    _LAST_NORM_CHOICES = ('LN', 'NO', 'IN', 'LR')

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 x_dim=768, z_dim=128,
                 embed_dim=1024, depth=24, num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm,
                 ELastNorm='LR', DLastNorm='LR'):
        super().__init__()

        # Fail fast on unknown options. Previously encoder_norm/decoder_norm
        # were simply never assigned and forward() crashed later with an
        # AttributeError.
        for option_name, option in (('ELastNorm', ELastNorm), ('DLastNorm', DLastNorm)):
            if option not in self._LAST_NORM_CHOICES:
                raise ValueError(
                    f"{option_name} must be one of {self._LAST_NORM_CHOICES}, got {option!r}")

        self.in_chans = in_chans
        self.num_patches = (img_size // patch_size) ** 2
        self.L = self.num_patches + 1  # patch tokens plus one leading (CLS / z) token
        self.patch_size = patch_size
        self.img_size = img_size
        self.z_dim = z_dim
        self.x_dim = x_dim

        # NOTE: submodules are created in the original order so that seeded
        # default initialisation and parameter ordering are unchanged.
        self.input_embed = nn.Sequential(
            nn.Linear(x_dim, 2 * embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * embed_dim, embed_dim),
        )
        # Fixed (non-trainable) positional table, filled in initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim),
                                      requires_grad=False)

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.encoder_norm = self._make_last_norm(ELastNorm, embed_dim, norm_layer)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Outputs 2 * z_dim values: (mean, log-std) of q(z | x).
        self.head_qz = nn.Sequential(
            nn.Linear(embed_dim, 2 * embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * embed_dim, 2 * z_dim),
        )

        self.z_embed = nn.Sequential(
            nn.Linear(z_dim, 2 * embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * embed_dim, embed_dim),
        )

        self.decoder_norm = self._make_last_norm(DLastNorm, embed_dim, norm_layer)

        # Outputs x_dim + 1 values per token: x_dim means plus one extra channel.
        self.decoder_pred = nn.Sequential(
            nn.Linear(embed_dim, 2 * embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * embed_dim, x_dim + 1),
        )

        self.initialize_weights()

    def _make_last_norm(self, choice, embed_dim, norm_layer):
        """Build the post-transformer layer selected by ``choice``."""
        if choice == 'LN':
            return norm_layer(embed_dim)
        if choice == 'NO':
            return nn.Identity()
        if choice == 'IN':
            # Token sequences are (N, L, embed_dim), so the L tokens act as
            # channels and each token is normalised over its features.
            return nn.InstanceNorm1d(self.L, affine=True)
        if choice == 'LR':
            return nn.LeakyReLU(0.2)
        raise ValueError(f"unknown last-norm option {choice!r}")

    def initialize_weights(self):
        """Fill the fixed positional table and initialise tokens and layers."""
        # The helper is asked for cls and extra ('nos') token rows. The last
        # row is dropped so the table has num_patches + 1 rows like pos_embed.
        # NOTE(review): confirm the row layout against util.pos_embed.
        pos_embed_all = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1], int(self.num_patches ** .5),
            nos_token=True, cls_token=True)
        pos_embed = pos_embed_all[:-1, ...]
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm to (1, 0)."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward_encoder(self, bar_imgs):
        """Return ``(mu, log_std)`` of q(z | bar_imgs), each ``(N, z_dim)``.

        ``bar_imgs`` must hold ``num_patches`` tokens of size ``x_dim``;
        adding ``pos_embed`` requires that token count.
        """
        x = self.input_embed(bar_imgs)

        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        cx = torch.cat([cls_token, x], dim=1)

        cx = cx + self.pos_embed

        for blk in self.blocks:
            cx = blk(cx)
        cx_embed = self.encoder_norm(cx)

        # Only the CLS output (token 0) parameterises the posterior.
        mu_zqst, lstd_zqst = torch.chunk(self.head_qz(cx_embed[:, 0, :]), chunks=2, dim=-1)

        return mu_zqst, lstd_zqst

    def forward_decoder(self, z):
        """Decode ``z`` of shape ``(N, z_dim)`` into per-patch predictions.

        Returns:
            ``(mu_xp, lstd_xp)``. ``mu_xp`` has shape
            ``(N, num_patches, x_dim)``. ``lstd_xp`` is all zeros with shape
            ``(N, 1, 1)``.
        """
        latent_z = self.z_embed(z)

        # Sequence = [z token, mask_token x num_patches].
        x = self.mask_token.expand(z.shape[0], self.num_patches, -1)
        x = torch.cat([latent_z.unsqueeze(1), x], dim=1)

        x = x + self.pos_embed

        for blk in self.blocks:
            x = blk(x)
        dec_embed = self.decoder_norm(x)

        # Drop the z token; predict only for patch tokens.
        latent_pred_patch = self.decoder_pred(dec_embed[:, 1:, :])
        # NOTE(review): the log-std slice takes only the LAST token's extra
        # channel (shape (N, 1, 1)), not every token's ([:, :, -1:]). It is
        # zeroed on the next line anyway, so only the shape is affected;
        # it is kept as-is to preserve the returned shape.
        mu_xp, lstd_xp = latent_pred_patch[:, :, :-1], latent_pred_patch[:, -1:, -1:]

        # The output log-std is fixed at 0 (unit std).
        lstd_xp = torch.zeros_like(lstd_xp)

        return mu_xp, lstd_xp

    def forward(self, bar_imgs):
        """Encode, sample z with the reparameterisation trick, and return ``mu_xp``."""
        mu_z, lstd_z = self.forward_encoder(bar_imgs)
        z = mu_z + lstd_z.exp() * torch.randn_like(mu_z)
        mu_xp, _ = self.forward_decoder(z)

        return mu_xp


class ViTPatchAE(nn.Module):
    """Patch autoencoder whose per-patch codes are also reconstructed by a ViT.

    ``Encoder_ResNet_AE`` maps each patch to a ``latent_dim`` code ``bar``.
    ``Decoder_ResNet_AE`` reconstructs patches from ``bar`` plus Gaussian noise
    (std ``noise_std``). ``vit_model`` receives the codes grouped per image,
    ``(B, num_patch, latent_dim)``, and its output is compared against
    ``bar``.

    Args:
        input_size: side length of the square input images.
        latent_dim: size of each per-patch code.
        in_chans: number of image channels.
        patch_size: side length of the square patches.
        noise_std: std of the Gaussian noise added to codes before decoding.
        nfilter: channel width of the conv encoder/decoder.
        numlayer: stage counts passed to the conv encoder/decoder.
        vit_model: callable mapping ``(B, num_patch, latent_dim)`` to a tensor
            reshapeable to ``(B * num_patch, latent_dim)`` (e.g. a ``UniViT``
            with ``x_dim == latent_dim``). It must be set before ``forward``.
    """

    def __init__(
            self,
            input_size=128,
            latent_dim=256,
            in_chans=3,
            patch_size=16,
            noise_std=1,
            nfilter=16,
            numlayer=(2, 0, 3),
            vit_model = None
    ):

        super().__init__()
        self.input_size = input_size
        self.latent_dim = latent_dim
        self.in_chans = in_chans
        self.patch_size = patch_size
        self.num_patch = (input_size // patch_size) ** 2
        self.noise_std = noise_std
        self.nfilter = nfilter
        self.numlayer = numlayer

        self.encoder = Encoder_ResNet_AE(latent_dim=latent_dim, in_chans=in_chans, patch_size=patch_size,
                                         numlayer=numlayer, nfilter=nfilter)
        self.decoder = Decoder_ResNet_AE(latent_dim=latent_dim, in_chans=in_chans, patch_size=patch_size,
                                         num_patch=self.num_patch, numlayer=numlayer, nfilter=nfilter)

        self.vit_model = vit_model

    def forward(self, imgs, detach_bar, z_loss_type):
        """Run both reconstruction paths and return the losses.

        Args:
            imgs: image batch; each image is expected to yield ``num_patch``
                patches (i.e. be ``input_size`` wide).
            detach_bar: if true, the code-reconstruction loss does not
                backpropagate into ``bar`` as the target.
            z_loss_type: loss for the code reconstruction; only ``'MSE'`` is
                supported.

        Returns:
            dict with ``x_loss`` (patch MSE), ``z_loss`` (code MSE),
            ``bar_loss`` (mean squared code magnitude), ``bar`` (codes,
            ``(B * num_patch, latent_dim)``) and ``rec_bar`` (ViT
            reconstruction, same shape as ``bar``).

        Raises:
            ValueError: if ``z_loss_type`` is not supported. Previously this
                ran the whole forward pass and then failed with an
                ``UnboundLocalError`` because ``z_loss`` was never assigned.
        """
        if z_loss_type != 'MSE':
            raise ValueError(f"Unsupported z_loss_type {z_loss_type!r}; only 'MSE' is implemented")

        patch_input, bar = self.encoder(imgs)

        noise_bar = bar + self.noise_std * torch.randn_like(bar)
        recon_patch = self.decoder(noise_bar, imgs.shape[0])

        # Group the per-patch codes by image for the ViT, then flatten back.
        bar_input = bar.reshape(-1, self.num_patch, self.latent_dim)
        rec_bar = self.vit_model(bar_input).reshape(-1, self.latent_dim)

        x_loss = F.mse_loss(
            recon_patch.reshape(recon_patch.shape[0], -1),
            patch_input.reshape(patch_input.shape[0], -1),
            reduction="mean",
        )

        z_target = bar.detach() if detach_bar else bar
        z_loss = F.mse_loss(rec_bar, z_target, reduction="mean")

        bar_loss = bar.pow(2).mean()

        return {'x_loss': x_loss,
                'z_loss': z_loss,
                'bar_loss': bar_loss,
                'bar': bar,
                'rec_bar': rec_bar
                }