import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from timm.layers import trunc_normal_
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """Build a fixed 2D sine-cosine positional table for a patch grid.

    Args:
        embed_dim: width of each embedding row (split evenly between the two axes).
        grid_size: an int for a square grid, or a length-2 sequence ``(rows, cols)``.
        cls_token: together with ``extra_tokens > 0``, prepend zero rows.
        extra_tokens: number of zero rows to prepend when ``cls_token`` is true.

    Returns:
        numpy array of shape ``(rows * cols, embed_dim)``, or
        ``(extra_tokens + rows * cols, embed_dim)`` when zero rows are prepended.
    """
    if isinstance(grid_size, int):
        rows = cols = grid_size
    else:
        assert len(grid_size) == 2
        rows, cols = grid_size

    coords_h = np.arange(rows, dtype=np.float32)
    coords_w = np.arange(cols, dtype=np.float32)
    # The w coordinates are passed first, so grid[0] holds column indices and
    # grid[1] holds row indices.
    mesh = np.stack(np.meshgrid(coords_w, coords_h), axis=0)
    mesh = mesh.reshape([2, 1, rows, cols])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not (cls_token and extra_tokens > 0):
        return table
    padding = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([padding, table], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1D sin-cos embeddings of the two coordinate planes of ``grid``.

    Args:
        embed_dim: total output width; must be even, half goes to each plane.
        grid: array whose ``[0]`` and ``[1]`` entries are coordinate planes
            (each is flattened by the 1D helper).

    Returns:
        ``(num_positions, embed_dim)`` array: embedding of ``grid[0]`` in the
        first half of the columns, embedding of ``grid[1]`` in the second half.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    halves = [
        get_1d_sincos_pos_embed_from_grid(half_dim, plane)
        for plane in (grid[0], grid[1])
    ]
    return np.concatenate(halves, axis=1)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sine-cosine embedding of an array of scalar positions.

    Args:
        embed_dim: output width; must be even (half sine, half cosine columns).
        pos: array of positions of any shape; it is flattened to length M.

    Returns:
        ``(M, embed_dim)`` array ``[sin(pos * w_i) | cos(pos * w_i)]`` with
        frequencies ``w_i = 1 / 10000 ** (i / (embed_dim / 2))``.
    """
    assert embed_dim % 2 == 0
    freqs = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.)
    freqs = 1. / 10000 ** freqs

    # Outer product: one row per position, one column per frequency.
    angles = np.outer(pos.reshape(-1), freqs)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)


class LayerScale(nn.Module):
    """Scale the last dimension of the input by a learnable per-channel vector.

    Args:
        dim: size of the last input dimension.
        init_values: initial value of every entry of ``gamma``.
    """

    def __init__(self, dim, init_values=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim).mul(init_values))

    def forward(self, x):
        # gamma broadcasts against the trailing dimension of x.
        return torch.mul(x, self.gamma)
    

class ViTBlock(nn.Module):
    """Pre-norm transformer block with LayerScale on both residual branches.

    Computes ``x + ls1(attn(norm1(x)))`` followed by ``x + ls2(mlp(norm2(x)))``.
    The timm ``Attention`` uses QKV bias, per-head LayerNorm on q/k, and
    ``drop_rate`` for both attention and projection dropout; the MLP uses a
    tanh-approximated GELU and no dropout. Extra ``block_kwargs`` are forwarded
    to ``Attention``.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_rate=0.1, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, eps=1e-5)
        self.attn = Attention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=True,
            norm_layer=nn.LayerNorm,
            proj_drop=drop_rate,
            attn_drop=drop_rate,
            **block_kwargs,
        )
        self.norm2 = nn.LayerNorm(hidden_size, eps=1e-5)

        def tanh_gelu():
            return nn.GELU(approximate="tanh")

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=tanh_gelu,
            drop=0,
        )
        self.ls1 = LayerScale(hidden_size)
        self.ls2 = LayerScale(hidden_size)

    def forward(self, x):
        attn_branch = self.ls1(self.attn(self.norm1(x)))
        x = x + attn_branch
        mlp_branch = self.ls2(self.mlp(self.norm2(x)))
        return x + mlp_branch


class ViT(nn.Module):
    """Stack of ``ViTBlock`` layers over a token sequence with fixed positions.

    A non-trainable positional table of shape
    ``(1, cls_num + num_patches, hidden_size)`` is added to the input: the
    first ``cls_num`` rows are zeros, the rest hold a 2D sin-cos embedding of a
    square ``(input_size // patch_size)``-sided grid. The input must therefore
    broadcast with that table — presumably ``(batch, cls_num + num_patches,
    hidden_size)``.
    """

    def __init__(
        self,
        input_size=32,
        patch_size=4,
        hidden_size=192,
        depth=8,
        num_heads=6,
        mlp_ratio=4.0,
        drop_rate=0.1,
        cls_num=4,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.num_patches = (input_size // patch_size) ** 2

        total_tokens = cls_num + self.num_patches
        self.pos_embed = nn.Parameter(
            torch.zeros(1, total_tokens, hidden_size), requires_grad=False)

        grid_side = int(self.num_patches ** 0.5)
        grid_table = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_side)
        with torch.no_grad():
            # Leave the cls rows at zero; fill only the patch rows.
            self.pos_embed[:, cls_num:, :] = torch.from_numpy(grid_table).float().unsqueeze(0)

        self.blocks = nn.ModuleList(
            [ViTBlock(hidden_size, num_heads, mlp_ratio, drop_rate) for _ in range(depth)]
        )
        self.initialize_weights()

    def initialize_weights(self):
        """Truncated-normal (std 0.02) weights and zero biases for every nn.Linear."""
        def _init_linear(mod):
            if not isinstance(mod, nn.Linear):
                return
            trunc_normal_(mod.weight, std=0.02)
            if mod.bias is not None:
                nn.init.constant_(mod.bias, 0)
        self.apply(_init_linear)

    def forward(self, x):
        out = x + self.pos_embed
        for layer in self.blocks:
            out = layer(out)
        return out

class Encoder(nn.Module):
    """Patchify an image and summarise it into ``cls_num`` latent tokens.

    Pipeline:
      1. timm ``PatchEmbed`` produces patch tokens, which are then LayerNorm'd.
      2. ``cls_num`` learnable tokens (initialised with a 1D sin-cos pattern)
         are prepended.
      3. The sequence runs through a ``ViT``.
      4. Only the first ``cls_num`` tokens are kept; they are LayerNorm'd and
         projected to ``latent_dim * 2`` channels (``latent_dim`` when
         ``double_out`` is false).

    Input is whatever timm ``PatchEmbed`` accepts — presumably
    ``(B, in_channels, input_size, input_size)``. The output has shape
    ``(B, cls_num, out_dim)``.
    """

    def __init__(
            self,
            input_size=256,
            patch_size=16,
            in_channels=3,
            hidden_size=768,
            depth=12,
            num_heads=6,
            mlp_ratio=4.0,
            drop_rate=0.1,
            latent_dim=1536,
            cls_num=4,
            double_out=True,
        ):
        super().__init__()
        self.vit = ViT(
            input_size=input_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            depth=depth,
            cls_num=cls_num,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
        )
        self.patch_embed = PatchEmbed(
            img_size=input_size,
            patch_size=patch_size,
            in_chans=in_channels,
            embed_dim=hidden_size,
        )
        self.norm_in = nn.LayerNorm(hidden_size, eps=1e-5)
        self.norm_out = nn.LayerNorm(hidden_size, eps=1e-5)
        out_dim = latent_dim * 2 if double_out else latent_dim
        self.out_proj = nn.Linear(hidden_size, out_dim, bias=False)

        self.cls_num = cls_num
        # Learnable cls tokens, initialised to a 1D sin-cos pattern over 0..cls_num-1.
        cls_pos = get_1d_sincos_pos_embed_from_grid(
            hidden_size,
            np.arange(self.cls_num, dtype=np.float32))
        self.cls_token = nn.Parameter(torch.from_numpy(cls_pos).float().unsqueeze(0))

        # timm's signature is trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.).
        # std must be passed by keyword: positionally, 0.02 would become the
        # mean and the weights would be drawn with std=1.
        trunc_normal_(self.out_proj.weight, std=0.02)

    def forward(self, x):
        x = self.norm_in(self.patch_embed(x))

        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = self.vit(torch.cat([cls_token, x], dim=1))
        # Only the leading cls tokens form the latent output.
        return self.out_proj(self.norm_out(x[:, :self.cls_num, :]))
    

class ViTBlock_2(nn.Module):
    """Dropout-free pre-norm transformer block with LayerScale residuals.

    ``forward(c, x)`` ignores ``c``. It is accepted only so this block can be
    called with the same ``(context, tokens)`` arguments as ``EmbedBlock``.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, eps=1e-5)
        self.attn = Attention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=True,
            norm_layer=nn.LayerNorm,
            **block_kwargs,
        )
        self.norm2 = nn.LayerNorm(hidden_size, eps=1e-5)

        def tanh_gelu():
            return nn.GELU(approximate='tanh')

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=tanh_gelu,
            drop=0,
        )
        self.ls1 = LayerScale(hidden_size)
        self.ls2 = LayerScale(hidden_size)

    def forward(self, c, x):
        h = x + self.ls1(self.attn(self.norm1(x)))
        return h + self.ls2(self.mlp(self.norm2(h)))


class EmbedAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0., **block_kwargs):
        super().__init__()
        self.num_heads = num_heads

        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.q_norm = nn.LayerNorm(dim // num_heads, elementwise_affine=True, bias=False)
        self.k_norm = nn.LayerNorm(dim // num_heads, elementwise_affine=True, bias=False)

    def forward(self, x, emb):
        B, N, C = x.shape
        _, M, _ = emb.shape
        
        q = self.wq(emb).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        kv = self.wkv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        
        x = nn.functional.scaled_dot_product_attention(
            q, k, v, 
            dropout_p=self.attn_drop if self.training else 0.)
        x = x.permute(0, 2, 1, 3).reshape(B, M, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class EmbedBlock(nn.Module):
    """Cross-attention block that updates query tokens ``emb`` from context ``x``.

    ``emb <- emb + ls1(attn(norm1(x), norm_emb(emb)))`` is followed by
    ``emb <- emb + ls2(mlp(norm2(emb)))``. The context ``x`` is only read,
    never updated. Attention and projection dropout use ``drop_rate``; the MLP
    has no dropout.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_rate=0.1, **block_kwargs):
        super().__init__()
        self.norm_emb = nn.LayerNorm(hidden_size, eps=1e-5)
        self.norm1 = nn.LayerNorm(hidden_size, eps=1e-5)
        self.attn = EmbedAttention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            proj_drop=drop_rate,
            attn_drop=drop_rate,
            **block_kwargs,
        )
        self.norm2 = nn.LayerNorm(hidden_size, eps=1e-5)

        def tanh_gelu():
            return nn.GELU(approximate='tanh')

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=tanh_gelu,
            drop=0,
        )
        self.ls1 = LayerScale(hidden_size)
        self.ls2 = LayerScale(hidden_size)

    def forward(self, x, emb):
        context = self.norm1(x)
        queries = self.norm_emb(emb)
        emb = emb + self.ls1(self.attn(context, queries))
        return emb + self.ls2(self.mlp(self.norm2(emb)))
    

class Decoder_d(nn.Module):
    """Output head: LayerNorm (eps 1e-5) followed by a biased linear projection.

    Maps the last dimension from ``hidden_size`` to ``out_channels``.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, eps=1e-5)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)

    def forward(self, x):
        normalized = self.norm_final(x)
        projected = self.linear(normalized)
        return projected


class Decoder(nn.Module):
    """Reconstruct a patch grid from ``cls_num`` latent tokens via cross-attention.

    The latent tokens are linearly projected to ``hidden_size``, LayerNorm'd and
    offset by learnable ``cls_pos``. A grid of ``num_patches`` learnable query
    tokens (plus fixed 2D sin-cos ``query_pos``) is then refined by ``depth``
    pairs of layers:
      - ``EmbedBlock``: queries attend to the latent tokens.
      - ``ViTBlock_2``: self-attention over the queries.
    Finally the queries are projected to pixel patches and un-patchified.

    Args:
        input_size: image side; ``input_size // patch_size`` patches per side.
        patch_size: patch side length.
        in_channels: channels of the reconstructed output; doubled when
            ``double_out`` is true.
        out_channels: feature width of each incoming latent token.
        hidden_size, depth, num_heads, mlp_ratio, drop_rate: transformer config.
        cls_num: number of latent tokens expected by ``forward``.
        double_out: whether to double ``in_channels``.
    """

    def __init__(
        self,
        input_size=32,
        patch_size=4,
        in_channels=3,
        out_channels=4,
        hidden_size=192,
        depth=6,
        num_heads=6,
        mlp_ratio=4.0,
        drop_rate=0.1,
        cls_num=4,
        double_out=True,
    ):
        super().__init__()
        self.in_channels = in_channels * 2 if double_out else in_channels
        self.out_channels = out_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.cls_num = cls_num

        size = input_size // patch_size
        self.num_patches = size * size

        # Interleave cross-attention (queries <- latents) with query self-attention.
        blocks = []
        for _ in range(depth):
            blocks.append(EmbedBlock(hidden_size, num_heads, mlp_ratio, drop_rate))
            blocks.append(ViTBlock_2(hidden_size, num_heads, mlp_ratio))
        self.blocks = nn.ModuleList(blocks)

        self.encoder = nn.Linear(out_channels, hidden_size, bias=True)
        self.cls_norm = nn.LayerNorm(hidden_size)
        self.decoder = Decoder_d(hidden_size, self.in_channels * patch_size ** 2)

        # Placeholders; init_token() fills them and marks which ones are trainable.
        self.query_token = nn.Parameter(
            torch.zeros(1, self.num_patches, hidden_size),
            requires_grad=False)
        self.query_pos = nn.Parameter(
            torch.zeros(1, self.num_patches, hidden_size),
            requires_grad=False)
        self.cls_pos = nn.Parameter(
            torch.zeros(1, cls_num, hidden_size),
            requires_grad=False)

        self.initialize_weights()
        self.init_token()

    def initialize_weights(self):
        """Truncated-normal (std 0.02) weights and zero biases for every nn.Linear."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def init_token(self):
        """(Re)initialise the query tokens and positional tables.

        - ``query_token``: row-normalised orthogonal matrix; trainable.
        - ``query_pos``: fixed 2D sin-cos grid; stays frozen.
        - ``cls_pos``: 1D sin-cos over 0..cls_num-1; trainable.

        The copies run under ``torch.no_grad()``. Without that, calling this a
        second time would fail, because in-place writes to leaf parameters that
        already require grad are rejected by autograd.
        """
        query_token = torch.nn.init.orthogonal_(
            torch.empty((self.num_patches, self.hidden_size), dtype=torch.float64))
        query_token = F.normalize(query_token, p=2, dim=1)

        pos_embed = get_2d_sincos_pos_embed(
            self.hidden_size,
            int(self.num_patches ** 0.5))
        cls_pos = get_1d_sincos_pos_embed_from_grid(
            self.hidden_size,
            np.arange(self.cls_num, dtype=np.float32))

        with torch.no_grad():
            self.query_token.copy_(query_token.float().unsqueeze(0))
            self.query_pos.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
            self.cls_pos.copy_(torch.from_numpy(cls_pos).float().unsqueeze(0))
        self.query_token.requires_grad_(True)
        self.cls_pos.requires_grad_(True)

    def unpatchify(self, x):
        """Fold ``(B, h*w, p*p*c)`` patch vectors into a ``(B, c, h*p, w*p)`` tensor.

        The patch grid must be square (h == w); this is asserted.
        """
        c = self.in_channels
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(x.shape[0], h, w, p, p, c)
        x = torch.einsum('nhwpqc->nchpwq', x)
        return x.reshape(x.shape[0], c, h * p, w * p)

    def forward(self, x):
        """Decode latent tokens into an image-shaped tensor.

        ``x`` holds the latent tokens with last dimension ``out_channels``. Its
        second dimension must broadcast with ``cls_pos``, i.e. presumably
        ``(B, cls_num, out_channels)``. The result is
        ``(B, self.in_channels, h*p, w*p)``.
        """
        x = self.cls_norm(self.encoder(x))
        x = x + self.cls_pos

        emb = (self.query_token + self.query_pos).expand(x.shape[0], -1, -1)
        # Every block takes (context, queries) and returns updated queries.
        for block in self.blocks:
            emb = block(x, emb)

        emb = self.decoder(emb)
        return self.unpatchify(emb)

