import torch
import torch.nn as nn
import math
import numpy as np

from torchvision import transforms
from pytorch_pretrained_vit import ViT

def freeze_model(model):
    """Turn off gradients for every parameter of ``model`` and return ``model``."""
    for weight in model.parameters():
        weight.requires_grad_(False)
    return model

def unfreeze_model(model):
    """Turn on gradients for every parameter of ``model`` and return ``model``."""
    for weight in model.parameters():
        weight.requires_grad_(True)
    return model

# =========== Helper Layers ========================================================================
class CheckShape(nn.Module):
    """Debug/adapter layer usable inside ``nn.Sequential``.

    Args:
        remark: if not None, ``forward`` prints ``remark`` followed by the input shape.
        key: optional callable applied to the input; its result is returned.
            When None the input is returned unchanged (same object).
    """
    def __init__(self, remark, key=None):
        super().__init__()
        self.remark = remark
        self.key = key

    def forward(self, x, **kwargs):
        # extra keyword arguments are accepted and ignored
        if self.remark is not None:
            print(self.remark, x.shape)
        return x if self.key is None else self.key(x)

class LinearAE(nn.Module):
    """Single-layer linear autoencoder.

    ``forward(x)`` returns ``(reconstruction, code)`` where ``code = enc(x)`` has
    ``latent_size`` features and ``reconstruction = dec(code)`` has ``in_size``.
    """
    def __init__(self, in_size, latent_size):
        super().__init__()
        self.enc = nn.Linear(in_size, latent_size)
        self.dec = nn.Linear(latent_size, in_size)

    def forward(self, x):
        code = self.enc(x)
        reconstruction = self.dec(code)
        return reconstruction, code
    
class LinearProb(nn.Module):
    """Linear probe over per-channel embeddings.

    The C channel embeddings are combined with a learnable weight per channel
    (``fuse_w``, initialised to 1/num_channel each, i.e. a plain average) and the
    fused embedding is classified by one linear layer.
    """
    def __init__(self, num_channel, embed_size, num_classes):
        super().__init__()
        uniform_weights = torch.ones(1, 1, num_channel) / num_channel
        self.fuse_w = nn.Parameter(uniform_weights)

        self.fc = nn.Linear(embed_size, num_classes)

    def forward(self, x):
        """x: (N, C, E) -> logits (N, num_classes)."""
        weighted = x.permute(0, 2, 1) * self.fuse_w  # (N, E, C)
        fused = weighted.sum(dim=2)  # (N, E)
        return self.fc(fused)

class EmptyLater(nn.Module):
    """Keyword-only pass-through layer.

    ``forward`` returns the ``x`` keyword argument unchanged, or None when no
    ``x`` keyword is given; all other keywords are ignored.
    """
    def __init__(self):
        super().__init__()

    def forward(self, **kwargs):
        if 'x' in kwargs:
            return kwargs['x']
        return None

# =========== Encoder Block ========================================================================
class RNNEncoder(nn.Module):
    """Single-layer, batch-first LSTM encoder returning the output at every step.

    Args:
        in_channel: feature size of each time step.
        emb_size: LSTM hidden size, i.e. the output feature size.
    """
    def __init__(
        self,
        in_channel,
        emb_size=48,
    ):
        super().__init__()
        self.rnn = nn.LSTM(in_channel, emb_size, 1, batch_first=True)

    def forward(self, x):
        # dim 1 is squeezed (removed if it is a singleton) before the LSTM
        sequence = x.squeeze(1)
        per_step_output, _final_state = self.rnn(sequence)
        return per_step_output
    
class EmbConvBlock(nn.Module):
    """Convolutional embedding for single-plane ``(N, L, C)`` inputs.

    The input gets a singleton plane axis, then:
      * temporal conv: kernel ``[T_kernel_size, 1]``, ``'same'`` padding,
        1 -> ``hidden_size`` maps, followed by BatchNorm + GELU;
      * spatial conv: kernel ``[1, in_channel]``, no padding,
        ``hidden_size`` -> ``emb_size`` maps, followed by BatchNorm + GELU.
    For ``C == in_channel`` the spatial conv leaves width 1 and the output is
    ``(N, L, emb_size)``.
    """
    def __init__(
        self,
        in_channel,
        T_kernel_size,
        emb_size=48,
        hidden_size=48*4
    ):
        super().__init__()

        # (N, L, C) -> (N, 1, L, C)
        add_plane_axis = CheckShape(None, key=lambda t: t.unsqueeze(1))
        temporal_conv = nn.Conv2d(1, hidden_size, kernel_size=[T_kernel_size, 1], padding='same')
        temporal_norm = nn.BatchNorm2d(hidden_size)
        temporal_act = nn.GELU()
        spatial_conv = nn.Conv2d(hidden_size, emb_size, kernel_size=[1, in_channel], padding='valid')
        spatial_norm = nn.BatchNorm2d(emb_size)
        spatial_act = nn.GELU()
        # (N, E, L, 1) -> (N, 1, L, E) -> (N, L, E)
        to_sequence = CheckShape(None, key=lambda t: t.permute(0, 3, 2, 1).squeeze(1))

        # Layer order (and therefore the "liner.<i>" parameter names) is fixed.
        self.liner = nn.Sequential(
            add_plane_axis,
            temporal_conv, temporal_norm, temporal_act,
            spatial_conv, spatial_norm, spatial_act,
            to_sequence,
        )

    def forward(self, x):
        """x: (N, L, C) -> (N, L, emb_size) when C == in_channel."""
        return self.liner(x)

class EmbConvBlockRGB(nn.Module):
    """Convolutional embedding for 3-plane ``(N, 3, L, C)`` inputs.

      * temporal conv: kernel ``[T_kernel_size, 1]``, dilation 2, ``'same'``
        padding, 3 -> ``hidden_size`` maps, followed by GELU;
      * spatial conv: kernel ``[1, in_channel]``, no padding,
        ``hidden_size`` -> ``emb_size`` maps, followed by GELU.
    No BatchNorm is used here, unlike ``EmbConvBlock``. For ``C == in_channel``
    the output is ``(N, L, emb_size)``.
    """
    def __init__(
        self,
        in_channel,
        T_kernel_size,
        emb_size=48,
        hidden_size=48*4
    ):
        super().__init__()

        temporal_conv = nn.Conv2d(3, hidden_size, kernel_size=[T_kernel_size, 1], padding='same', dilation=2)
        temporal_act = nn.GELU()
        spatial_conv = nn.Conv2d(hidden_size, emb_size, kernel_size=[1, in_channel], padding='valid')
        spatial_act = nn.GELU()
        # (N, E, L, 1) -> (N, 1, L, E) -> (N, L, E)
        to_sequence = CheckShape(None, key=lambda t: t.permute(0, 3, 2, 1).squeeze(1))

        # Layer order (and therefore the "liner.<i>" parameter names) is fixed.
        self.liner = nn.Sequential(
            temporal_conv, temporal_act,
            spatial_conv, spatial_act,
            to_sequence,
        )

    def forward(self, x):
        """x: (N, 3, L, C) -> (N, L, emb_size) when C == in_channel."""
        return self.liner(x)


# =========== Position Embedding ========================================================================
class tAPE(nn.Module):
    """Time Absolute Position Encoding: a fixed sinusoidal table whose angles are
    scaled by ``d_model / max_len``.

    Args:
        d_model: embedding size E of the inputs (odd sizes are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of rows in the precomputed table. An input with more than
            ``max_len`` positions cannot be broadcast against the table.
        scale_factor: multiplier applied to the whole table.

    ``forward`` adds the LAST ``L`` rows of the table to an ``(N, L, E)`` input.
    """
    def __init__(self, d_model, dropout=0.1, max_len=1024, scale_factor=1.0):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # (max_len, ceil(d_model / 2)) angles, stretched by d_model / max_len
        angle = (position * div_term) * (d_model / max_len)

        pe = torch.zeros(max_len, d_model)  # positional encoding
        pe[:, 0::2] = torch.sin(angle)
        # There are only floor(d_model / 2) odd columns; for odd d_model the last
        # frequency has no cosine partner, so drop it here instead of crashing.
        pe[:, 1::2] = torch.cos(angle[:, : d_model // 2])
        pe = scale_factor * pe.unsqueeze(0)  # (1, max_len, d_model)
        # buffer: saved in the state_dict but not trained
        self.register_buffer('pe', pe)

    def forward(self, x):  # N, L, C
        x = x + self.pe[:, -x.shape[1]:, :]
        return self.dropout(x)
    
# =========== Multi-Head Self-Attention ========================================================================
def split_last(x, shape):
    """Reshape the last dimension of ``x`` into ``shape``.

    At most one entry of ``shape`` may be -1; it is replaced by the size that
    makes the element count match.
    """
    target = list(shape)
    assert target.count(-1) <= 1
    if -1 in target:
        unknown_at = target.index(-1)
        # the product includes the -1, so negate it to get the known size
        known = -np.prod(target)
        target[unknown_at] = int(x.size(-1) / known)
    leading = x.size()[:-1]
    return x.view(*leading, *target)


def merge_last(x, n_dims):
    """Collapse the trailing ``n_dims`` dimensions of ``x`` into a single one."""
    dims = x.size()
    assert 1 < n_dims < len(dims)
    kept = dims[:len(dims) - n_dims]
    return x.view(*kept, -1)

class MultiHeadedSelfAttention(nn.Module):
    """Multi-headed dot-product self-attention followed by a feed-forward layer.

    Output = ``fc(norm(dropout(proj(attention(x))) + x))``, where ``fc`` is a
    ``FeedForward`` with hidden size ``4 * dim``. The post-dropout attention
    weights of the last call are kept in ``self.scores`` for visualization.
    """
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        # creation order is kept fixed (parameter init order / state_dict names)
        self.proj_q = nn.Linear(dim, dim)
        self.proj_k = nn.Linear(dim, dim)
        self.proj_v = nn.Linear(dim, dim)
        self.drop = nn.Dropout(dropout)
        self.n_heads = num_heads
        self.scores = None  # for visualization

        self.proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim, eps=1e-6)

        self.fc = FeedForward(dim, dim*4, dropout=dropout)

    def forward(self, x, mask=None):
        """
        x : (B, S, D) with D = H * W, H = n_heads.
        mask : optional (B, S); key positions where mask == 0 get a -10000 penalty
               before the softmax.
        Returns (B, S, D).
        """
        def to_heads(t):
            # (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
            return split_last(t, (self.n_heads, -1)).transpose(1, 2)

        queries = to_heads(self.proj_q(x))
        keys = to_heads(self.proj_k(x))
        values = to_heads(self.proj_v(x))

        # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S)
        logits = queries @ keys.transpose(-2, -1) / np.sqrt(keys.size(-1))
        if mask is not None:
            keep = mask[:, None, None, :].float()
            logits -= 10000.0 * (1.0 - keep)
        attn = self.drop(nn.functional.softmax(logits, dim=-1))
        self.scores = attn

        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -> (B, S, H, W) -> (B, S, D)
        context = merge_last((attn @ values).transpose(1, 2).contiguous(), 2)
        residual = self.drop(self.proj(context)) + x
        return self.fc(self.norm(residual))

class MultiHeadedSelfAttentionMS(nn.Module):
    """Multi-headed self-attention with an added per-key relevance bias.

    Differences from ``MultiHeadedSelfAttention``:
      * there are no q/k/v projections: q, k and v are all the head-split input;
      * after softmax + dropout, each attention row is shifted by a per-key
        relevance score. That score is the maximum cosine similarity between the
        key's per-head value slice and the head-split columns of
        ``self.fc.fc_liner[2].weight``.
      * ``self.fc`` is never applied to the data; it only supplies those target
        vectors. The block returns ``norm(dropout(proj(h)) + x)``.
    Because of the shift, attention rows no longer sum to one.
    """
    # Lower bound for the cosine-similarity denominator. Without it a zero-norm
    # value slice (e.g. all-zero padding) or target column yields 0/0 = NaN.
    _COS_EPS = 1e-8

    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.drop = nn.Dropout(dropout)
        self.n_heads = num_heads
        self.scores = None # for visualization

        self.proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim, eps=1e-6)

        self.fc = FeedForward(dim, dim*2, dropout=dropout)

    def forward(self, x, mask=None):
        """
        x : (B, S, D) with D = H * W, H = n_heads.
        mask : optional (B, S); key positions where mask == 0 get a -10000 penalty
               before the softmax.
        Returns (B, S, D).
        """
        # (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W); shared by q, k, v
        heads = split_last(x, (self.n_heads, -1)).transpose(1, 2)
        q = k = v = heads

        # importance: (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
        scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
        if mask is not None:
            mask = mask[:, None, None, :].float()
            scores -= 10000.0 * (1.0 - mask)
        scores = self.drop(nn.functional.softmax(scores, dim=-1))

        # relevance: target vectors come from the second linear layer of self.fc,
        # whose weight is (D, hidden); rows are split per head
        # -> (H, W, hidden) -> (1, H, 1, W, hidden)
        target_embed = self.fc.fc_liner[2].weight
        TE, _ = target_embed.shape
        target_embed = target_embed.view(self.n_heads, TE // self.n_heads, -1).unsqueeze(1).unsqueeze(0)

        v_ = v.unsqueeze(-1)  # (B, H, S, W, 1)
        # dot product over W -> (B, H, S, hidden)
        imp = torch.sum(target_embed * v_, dim=3)

        # cosine denominator (norms over W), clamped so zero-norm vectors cannot produce NaN
        den = torch.sqrt(torch.sum(target_embed ** 2, dim=-2)) * torch.sqrt(torch.sum(v_ ** 2, dim=-2))
        imp = imp / den.clamp_min(self._COS_EPS)

        # best-matching target per key token -> (B, H, 1, S), broadcast over query rows
        imp = torch.max(imp, dim=-1).values.unsqueeze(-2)
        scores = scores + imp

        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W) -merge-> (B, S, D)
        h = merge_last((scores @ v).transpose(1, 2).contiguous(), 2)
        self.scores = scores
        return self.norm(self.drop(self.proj(h)) + x)
    
################################################

class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Linear -> Dropout.

    With ``add_norm`` (default) the result is ``LayerNorm(x + mlp(x))``;
    otherwise ``mlp(x)`` is returned. ``MultiHeadedSelfAttentionMS`` reads
    ``fc_liner[2].weight`` directly, so the layer order in ``fc_liner`` must not
    change.
    """
    def __init__(self, emb_size, hidden_size, dropout=0.1, add_norm=True):
        super().__init__()
        self.add_norm = add_norm

        layers = [
            nn.Linear(emb_size, hidden_size),  # fc_liner[0]
            nn.GELU(),                         # fc_liner[1]
            nn.Linear(hidden_size, emb_size),  # fc_liner[2]
            nn.Dropout(p=dropout),             # fc_liner[3]
        ]
        self.fc_liner = nn.Sequential(*layers)

        self.LayerNorm = nn.LayerNorm(emb_size, eps=1e-6)

    def forward(self, x):
        branch = self.fc_liner(x)
        if not self.add_norm:
            return branch
        return self.LayerNorm(x + branch)

# =========== Main Transformer Layer ========================================================================
    
class Transformer(nn.Module):
    """Stack of ``num_layers`` ``MultiHeadedSelfAttention`` blocks applied in order.

    Each block sits inside its own ``nn.Sequential``, so parameter names look like
    ``blocks.<i>.0.<...>``. Input and output are both ``(N, L, emb_size)``.
    """
    def __init__(
        self,
        num_layers=1,
        emb_size=768,
        num_heads=12,
        dropout=0.1,
    ):
        super().__init__()

        def make_block():
            return nn.Sequential(MultiHeadedSelfAttention(emb_size, num_heads, dropout))

        self.blocks = nn.ModuleList([make_block() for _ in range(num_layers)])

    def forward(self, x):
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return hidden  # (N, L, E)
    
class SiT(nn.Module): # Signal Transformer (SiT)
    """Signal Transformer: conv embedding + tAPE + [CLS] token + attention stack.

    Input ``(N, 3, L, in_channel)`` -> output ``(N, L + 1, emb_size)``. Token 0 is
    the learnable [CLS] token. It is prepended after the position encoding, so it
    receives no positional signal. ``L`` must not exceed 2048 (the tAPE table size).
    """
    def __init__(
        self,
        num_layers=6,
        emb_size=384,
        num_heads=12,
        dropout=0.1,
        # data related
        in_channel=65
    ):
        super().__init__()

        # created first: keeps the original parameter-initialisation order
        self.cls = nn.Parameter(torch.rand(1, 1, emb_size))

        patch_embed = EmbConvBlockRGB(in_channel, 8, emb_size, hidden_size=emb_size*4)
        position_embed = tAPE(emb_size, dropout=dropout, max_len=2048)
        self.emb = nn.Sequential(patch_embed, position_embed)  # (N, 3, L, F) -> (N, L, E)

        self.transformer = Transformer(
            num_layers=num_layers,
            emb_size=emb_size,
            num_heads=num_heads,
            dropout=dropout,
        )

    def forward(self, x):
        tokens = self.emb(x)  # (N, L, E)
        batch, _, width = tokens.shape
        cls_tokens = self.cls.expand(batch, 1, width)
        sequence = torch.cat((cls_tokens, tokens), dim=1)  # (N, L+1, E)
        return self.transformer(sequence)

# ===== ADJUST VIT ==========================================

class VitPosEmbedAdjust(nn.Module):
    """Positional embedding that accepts any token count.

    The pretrained table ``origin_pos_embed`` of shape ``(1, P, D)`` is tiled
    along the token axis as many times as needed, and its LAST ``L`` rows are
    added to an ``(N, L, D)`` input. A tAPE time encoding (scale 1e-5,
    ``max_len`` 2048) is then applied on top.
    """
    def __init__(self, origin_pos_embed):
        super().__init__()
        self.pos_embed = origin_pos_embed # (1, P, D), e.g. (1, 577, 768) for ViT-B/16
        self.origin_L = origin_pos_embed.shape[1]

        # width follows the pretrained table instead of a hard-coded 768
        self.time_pos_embed = tAPE(origin_pos_embed.shape[-1], max_len=2048, scale_factor=1e-5)

    def forward(self, x, **kwargs):
        _, L, _ = x.shape
        # number of copies needed to cover L tokens (1 when L <= P)
        copies = math.ceil(L / self.origin_L)
        table = self.pos_embed if copies == 1 else self.pos_embed.repeat(1, copies, 1)
        return self.time_pos_embed(x + table[:, -L:, :])
        
class ViTAdjust(nn.Module):
    """Pretrained ViT (``B_16_imagenet1k``) adapted to ``(N, 3, L, F)`` inputs.

    Changes to the pretrained model:
      * the patch-embedding stride becomes (4, 16): 4 along height (L), 16 along
        width (F);
      * the positional embedding is replaced by ``VitPosEmbedAdjust``, so any
        token count is accepted;
      * every ViT parameter is frozen except the transformer blocks listed in
        ``finetune_layers``;
      * the classification head ``vit.fc`` is removed.

    Args:
        finetune_layers: indices of transformer blocks to keep trainable.
            Negative indices are allowed. ``None`` (default) means no block is
            fine-tuned.
        upper_layer: split point between ``upper_forward`` and ``lower_forward``.
            * If < 0, ``upper_forward`` runs the whole transformer and
              ``lower_forward`` only applies the final norm.
            * Otherwise blocks ``[0, upper_layer)`` run in ``upper_forward`` and
              the rest run in ``lower_forward``.
    """
    def __init__(self, finetune_layers=None, upper_layer=-1):
        super().__init__()
        self.vit = ViT('B_16_imagenet1k', pretrained=True) # construct and load pretrained weight

        # Keep stride 16 along the width (feature) axis, use stride 4 along the
        # height (time) axis.
        self.vit.patch_embedding.stride = (4, 16)

        # replace the fixed-length position embedding with a length-agnostic one
        self.vit.positional_embedding = VitPosEmbedAdjust(self.vit.positional_embedding.pos_embedding)

        freeze_model(self.vit)

        # fresh list per instance (no shared mutable default argument)
        if finetune_layers is None:
            finetune_layers = []
        self.upper_layer = upper_layer
        self.finetune_layers = finetune_layers
        for layer in finetune_layers:
            unfreeze_model(self.vit.transformer.blocks[layer])

        # the classification head is never used
        self.vit.fc = None

    def train(self, mode=True):
        """Set training mode like ``nn.Module.train`` and return ``self``.

        Frozen transformer blocks (those not in ``finetune_layers``) are always
        put back into eval mode.
        """
        super().train(mode)

        blocks = self.vit.transformer.blocks
        # normalise negative indices so they match the enumerate() positions below
        trainable = {i % len(blocks) for i in self.finetune_layers}
        for b_i, block in enumerate(blocks):
            if b_i not in trainable:
                block.eval()
        return self

    def upper_forward(self, x, mask=None): # Input shape: (N, 3, L, F)
        """Embed ``x`` into ``L + 1`` tokens (with [CLS]) and run the upper blocks.

        ``mask`` (optional) is forwarded to every transformer block that runs here.
        Returns ``(N, L + 1, D)``.
        """
        N, _, L, _ = x.shape
        # map values with mean/std 0.5 (i.e. (x - 0.5) / 0.5), then prepend 16 zero
        # rows along the height (time) axis
        out = transforms.Normalize(0.5, 0.5)(x)
        out = nn.functional.pad(out, (0, 0, 16, 0), value=0) # N, 3, L+16, F

        # overlapping patches (stride 4 along height)
        out = self.vit.patch_embedding(out) # b,d,gh,gw
        out = out.flatten(2).transpose(1, 2) # b,gh*gw,d
        out = out[:, -L:, :] # keep the last L tokens: b,L,d

        # prepend [CLS] token, then add position embeddings
        out = torch.cat((self.vit.class_token.expand(N, -1, -1), out), dim=1) # b,L+1,d
        out = self.vit.positional_embedding(out)

        if self.upper_layer < 0:
            return self.vit.transformer(out, mask=mask)
        for b_i in range(self.upper_layer):
            out = self.vit.transformer.blocks[b_i](out, mask=mask)
        return out

    def lower_forward(self, out, mask=None):
        """Run the remaining blocks (only when ``upper_layer >= 0``) and the final norm."""
        if self.upper_layer >= 0:
            for b_i in range(self.upper_layer, len(self.vit.transformer.blocks)):
                out = self.vit.transformer.blocks[b_i](out, mask=mask)
        return self.vit.norm(out)

    def forward(self, x, mask=None): # Input shape: (N, 3, L, F)
        """Full pass; ``mask`` is applied in both the upper and the lower blocks."""
        out = self.upper_forward(x, mask=mask)
        return self.lower_forward(out, mask=mask)

# Smoke test: build a SiT, then print its trainable-parameter count and the
# input/output shapes for a random batch.
if __name__ == '__main__':
    batch_size, seq_len, n_features = 3, 256, 65

    x = torch.rand((batch_size, 3, seq_len, n_features)) # N, 3, L, F

    model = SiT(
        num_layers=6,
        emb_size=384,
        in_channel=n_features,
    )

    # trainable parameters only
    n_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
    print("Num param:", n_params)

    print("x shape:", x.shape)
    y = model(x)
    print('y shape:', y.shape) # expected (3, 257, 384)
    
