import torch
import torch.nn as nn
import math
import numpy as np

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", DEVICE)

# =========== Helper Layers ========================================================================
class CheckShape(nn.Module):
    """Pass-through layer that can print the input shape and/or apply a callable.

    Args:
        remark: Label printed together with ``x.shape`` on every forward call;
            ``None`` disables printing.
        key: Optional callable applied to the input. When ``None`` the input
            is returned unchanged.
    """

    def __init__(self, remark, key=None):
        super().__init__()
        self.remark = remark
        self.key = key

    def forward(self, x, **kwargs):
        """Optionally print ``x.shape``, then return ``key(x)`` or ``x`` itself.

        Extra keyword arguments are accepted and ignored.
        """
        if self.remark is not None:
            print(self.remark, x.shape)

        if self.key is None:
            return x
        return self.key(x)

class VAE_Latent(nn.Module):
    """Latent head that samples with the reparametrization trick while training.

    Two parallel projections of the input produce a mean (``mu``) and a
    strictly positive spread (``var``, a linear layer followed by Softplus).
    Note that ``var`` multiplies the noise directly, i.e. it is used as a scale.

    Args:
        emb_size: Size of the last dimension of the input.
        out_size: Size of the last dimension of the latent output.
    """

    def __init__(self, emb_size, out_size):
        super().__init__()

        self.mu = nn.Linear(emb_size, out_size)
        self.var = nn.Sequential(
            nn.Linear(emb_size, out_size),
            nn.Softplus()
        )

    def forward(self, x, latent_only=True):
        """Return the latent ``z`` (and optionally ``mu`` and ``var``).

        Args:
            x: Input tensor whose last dimension is ``emb_size``.
            latent_only: When True return only ``z``; otherwise return the
                tuple ``(z, mu, var)``.

        Returns:
            ``z`` or ``(z, mu, var)``. In training mode ``z = mu + var * eps``
            with ``eps ~ N(0, 1)``; in eval mode ``z`` is ``mu``.
        """
        # generate mean and variance
        mu, var = self.mu(x), self.var(x)

        # reparametrization trick
        if self.training:
            # randn_like already matches var's device and dtype; forcing the
            # noise onto a global device breaks modules that live elsewhere.
            eps = torch.randn_like(var)
            z = mu + var * eps
        else:
            z = mu

        # output
        if latent_only:
            return z
        return z, mu, var

class LinearAE(nn.Module):
    """Single-layer linear autoencoder.

    Args:
        in_size: Size of the last dimension of the input (and reconstruction).
        latent_size: Size of the encoded representation.
    """

    def __init__(self, in_size, latent_size):
        super().__init__()
        self.enc = nn.Linear(in_size, latent_size)
        self.dec = nn.Linear(latent_size, in_size)

    def forward(self, x):
        """Return ``(reconstruction, code)`` for the input ``x``."""
        code = self.enc(x)
        reconstruction = self.dec(code)
        return reconstruction, code
    
class LinearProb(nn.Module):
    """Linear probe: learned weighted sum over channels, then a linear classifier.

    Args:
        num_channel: Number of channels ``C`` that are fused.
        embed_size: Embedding size ``E`` of every channel.
        num_classes: Number of output logits.
    """

    def __init__(self, num_channel, embed_size, num_classes):
        super().__init__()
        # Fusion weights start as a uniform average over the channels.
        self.fuse_w = nn.Parameter(torch.ones(1, 1, num_channel)/num_channel)

        self.fc = nn.Sequential(
            nn.Linear(embed_size, num_classes),
        )

    def forward(self, x): # N, C, E
        """Fuse the channel axis of ``x`` (N, C, E) and classify the result."""
        channel_last = x.permute(0, 2, 1)  # N, E, C
        weighted = channel_last * self.fuse_w
        fused = torch.sum(weighted, dim=2)  # N, E
        return self.fc(fused)

class EmptyLater(nn.Module):
    """Placeholder layer that echoes back the ``x`` keyword argument."""

    def __init__(self):
        super().__init__()

    def forward(self, **kwargs):
        """Return the value passed as ``x``, or ``None`` when it is absent."""
        if 'x' in kwargs:
            return kwargs['x']
        return None

class VitPosEmbedAdjust(nn.Module):
    """Apply a fixed-length positional embedding to sequences of any length.

    The stored embedding is aligned with the END of the sequence: shorter
    inputs use its last ``L`` rows, longer inputs use the last ``L`` rows of
    the embedding tiled along the sequence axis. A small-scale ``tAPE``
    temporal embedding (which also applies dropout) is added afterwards.

    Args:
        origin_pos_embed: Positional embedding indexed as
            ``(1, origin_L, dim)``, e.g. (1, 577, 768).
    """

    def __init__(self, origin_pos_embed):
        super().__init__()
        self.pos_embed = origin_pos_embed # 1, 577, 768
        self.origin_L = origin_pos_embed.shape[1]

        # Use the embedding's own width instead of a hard-coded 768 so that
        # embeddings of other sizes broadcast correctly in forward().
        self.time_pos_embed = tAPE(origin_pos_embed.shape[-1], max_len=2048, scale_factor=1e-5)

    def forward(self, x, **kwargs):
        """Add positional and temporal embeddings to ``x`` of shape (N, L, dim).

        Extra keyword arguments are accepted and ignored.
        """
        _, L, _ = x.shape
        if L <= self.origin_L:
            # Slicing with -L: returns the whole embedding when L == origin_L.
            pos_embed = self.pos_embed[:, -L:, :]
        else:
            # Tile with a single concatenation (ceil(L / origin_L) copies),
            # then keep the last L rows.
            num_copies = -(-L // self.origin_L)
            tiled = torch.cat([self.pos_embed] * num_copies, dim=1)
            pos_embed = tiled[:, -L:, :]
        return self.time_pos_embed(x + pos_embed)

# =========== Encoder Block ========================================================================
class RNNEncoder(nn.Module):
    """Single-layer, batch-first LSTM encoder.

    Args:
        in_channel: Number of input features per time step.
        emb_size: Hidden size of the LSTM (size of each output step).
    """

    def __init__(
        self,
        in_channel,
        emb_size=48,
    ):
        super().__init__()

        self.rnn = nn.LSTM(in_channel, emb_size, 1, batch_first=True)

    def forward(self, x):
        """Return the per-step LSTM outputs; the final states are discarded.

        A size-1 axis at position 1 of ``x`` is squeezed away before the LSTM.
        """
        sequence = x.squeeze(1)
        outputs, _ = self.rnn(sequence)
        return outputs
    
class EmbConvBlock(nn.Module):
    """Embed a multichannel sequence with a temporal then a spatial convolution.

    Args:
        in_channel: Width of the spatial kernel; when it equals the input's
            last dimension the channel axis collapses to size 1.
        T_kernel_size: Length of the temporal kernel (e.g. 8).
        emb_size: Number of output embedding features.
        hidden_size: Number of feature maps between the two convolutions.
    """

    def __init__(
        self,
        in_channel,
        T_kernel_size,  # 8,
        emb_size=48,
        hidden_size=48*4
    ):
        super().__init__()

        def add_feature_axis(t):
            # (N, L, C) -> (N, 1, L, C)
            return t.unsqueeze(1)

        def to_sequence(t):
            # Swap the feature-map and channel axes, then drop axis 1 if it
            # has size 1 -> (N, L, emb_size) when C == in_channel.
            return torch.permute(t, (0, 3, 2, 1)).squeeze(1)

        # Convolves along the sequence axis only.
        temporal = [
            nn.Conv2d(1, hidden_size, kernel_size=[T_kernel_size, 1], padding='same'),
            nn.BatchNorm2d(hidden_size),
            nn.GELU(),
        ]
        # Mixes across the channel axis only.
        spatial = [
            nn.Conv2d(hidden_size, emb_size, kernel_size=[1, in_channel], padding='valid'),
            nn.BatchNorm2d(emb_size),
            nn.GELU(),
        ]

        # Input shape: (N, L, C)
        self.liner = nn.Sequential(
            CheckShape(None, key=add_feature_axis),
            *temporal,
            *spatial,
            CheckShape(None, key=to_sequence),
        )

    def forward(self, x):
        """Embed ``x`` of shape (N, L, C)."""
        return self.liner(x)

class EmbConvBlockRGB(nn.Module):
    """Three-plane variant of the temporal/spatial convolutional embedding.

    The first convolution takes 3 input planes and uses a dilation of 2.

    Args:
        in_channel: Width of the spatial kernel; when it equals the input's
            last dimension that axis collapses to size 1.
        T_kernel_size: Length of the temporal kernel (e.g. 8).
        emb_size: Number of output embedding features.
        hidden_size: Number of feature maps between the two convolutions.
    """

    def __init__(
        self,
        in_channel,
        T_kernel_size,  # 8,
        emb_size=48,
        hidden_size=48*4
    ):
        super().__init__()

        def to_sequence(t):
            # Swap the feature-map and last axes, then drop axis 1 if it has
            # size 1 -> (N, L, emb_size) when the last dim equals in_channel.
            return torch.permute(t, (0, 3, 2, 1)).squeeze(1)

        # Dilated convolution along the sequence axis ("no warning" variant).
        temporal = [
            nn.Conv2d(3, hidden_size, kernel_size=[T_kernel_size, 1], padding='same', dilation=2),
            nn.BatchNorm2d(hidden_size),
            nn.GELU(),
        ]
        # Mixes across the last axis only.
        spatial = [
            nn.Conv2d(hidden_size, emb_size, kernel_size=[1, in_channel], padding='valid'),
            nn.BatchNorm2d(emb_size),
            nn.GELU(),
        ]

        # Input shape: (N, 3, L, F) -- the first convolution expects 3 planes.
        self.liner = nn.Sequential(
            *temporal,
            *spatial,
            CheckShape(None, key=to_sequence),
        )

    def forward(self, x):
        """Embed ``x``, a 4-D tensor with 3 planes on axis 1."""
        return self.liner(x)

# =========== Position Embedding ========================================================================
class tAPE(nn.Module):
    """Sinusoidal positional encoding whose frequencies scale with ``d_model / max_len``.

    The table is stored as the non-trainable buffer ``pe`` of shape
    ``(1, max_len, d_model)``.

    Args:
        d_model: Embedding size of the inputs (odd sizes are supported).
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions in the table; inputs must not be longer.
        scale_factor: Multiplier applied to the whole table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=1024, scale_factor=1.0):
        super(tAPE, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = (position * div_term) * (d_model / max_len)

        pe = torch.zeros(max_len, d_model)  # positional encoding
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns; trim so the assignment fits (no-op for even d_model).
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = scale_factor * pe.unsqueeze(0)
        self.register_buffer('pe', pe)  # this stores the variable in the state_dict (used for non-trainable variables)

    def forward(self, x): # N, L, C
        """Add the LAST ``L`` rows of the table to ``x`` (N, L, C), then apply dropout."""
        seq_len = x.shape[1]
        x = x + self.pe[:, -seq_len:, :]
        return self.dropout(x)
    
# =========== Multi-Head Self-Attention ========================================================================
def split_last(x, shape):
    """Reshape the last dimension of ``x`` into ``shape``.

    At most one entry of ``shape`` may be ``-1``; it is replaced by the size
    inferred from the remaining entries.
    """
    dims = list(shape)
    wildcard_positions = [i for i, d in enumerate(dims) if d == -1]
    assert len(wildcard_positions) <= 1
    if wildcard_positions:
        # The product contains exactly one -1, so negating it gives the
        # product of the known entries.
        known = -np.prod(dims)
        dims[wildcard_positions[0]] = int(x.size(-1) / known)
    leading = x.size()[:-1]
    return x.view(*leading, *dims)


def merge_last(x, n_dims):
    """Collapse the last ``n_dims`` dimensions of ``x`` into a single one.

    ``n_dims`` must be greater than 1 and smaller than the rank of ``x``.
    """
    full_shape = x.size()
    assert 1 < n_dims < len(full_shape)
    kept = full_shape[:-n_dims]
    return x.view(*kept, -1)

class EmbedMatch(nn.Module):
    """Cross-attention that pools a sequence of embeddings against one query.

    Keys and values are projected from the sequence; the query is used as
    given (no projection), broadcast to every position.

    Args:
        in_embed_size: Embedding size of the sequence input.
        out_embed_size: Size of the query and of the output.
        num_heads: Number of attention heads.
        dropout: Dropout probability (attention weights, projection, FFN).
    """

    def __init__(self, in_embed_size=768, out_embed_size=384, num_heads=12, dropout=0.1):
        super().__init__()
        self.proj_k = nn.Linear(in_embed_size, out_embed_size)
        self.proj_v = nn.Linear(in_embed_size, out_embed_size)
        self.drop = nn.Dropout(dropout)
        self.n_heads = num_heads

        self.proj = nn.Linear(out_embed_size, out_embed_size)
        self.norm = nn.LayerNorm(out_embed_size, eps=1e-6)

        self.fc = FeedForward(out_embed_size, out_embed_size*4, dropout=dropout)

        self.out_embed_size = out_embed_size

    def forward(self, x, q, mask=None, before_fuse_out=False):
        '''
        x: vit embedding, (N, L, 768),
        q: nlp embedding, (N, 384)
        mask: optional (N, L) tensor, 1 = keep, 0 = suppress
        before_fuse_out: if True, return the projected values and the
            attention weights summed over heads instead of the fused output
        return: (N, 384)
        '''
        batch, seq_len, _ = x.shape
        query = q.unsqueeze(1).expand(batch, seq_len, self.out_embed_size)
        key = self.proj_k(x)
        value = self.proj_v(x)

        # Residual for the pooled output: mean of the projected values.
        value_mean = torch.mean(value, dim=1)

        def to_heads(t):
            # (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
            return split_last(t, (self.n_heads, -1)).transpose(1, 2)

        q_h, k_h, v_h = to_heads(query), to_heads(key), to_heads(value)

        # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
        scores = q_h @ k_h.transpose(-2, -1) / np.sqrt(k_h.size(-1))
        if mask is not None:
            mask = mask[:, None, None, :].float()
            scores -= 10000.0 * (1.0 - mask)
        scores = self.drop(nn.functional.softmax(scores, dim=-1))

        # intermediate output
        if before_fuse_out:
            return value, torch.sum(scores, dim=1)

        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
        context = (scores @ v_h).transpose(1, 2).contiguous()
        # -merge-> (B, S, D), then average over the sequence -> (B, D)
        pooled = torch.mean(merge_last(context, 2), dim=1)
        fused = self.norm(self.drop(self.proj(pooled)) + value_mean)
        return self.fc(fused)

class MultiHeadedSelfAttention(nn.Module):
    """Multi-headed dot-product self-attention followed by a feed-forward block.

    Args:
        dim: Embedding size D (must be divisible by ``num_heads``).
        num_heads: Number of attention heads H.
        dropout: Dropout probability (attention weights, projection, FFN).
    """

    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.proj_q = nn.Linear(dim, dim)
        self.proj_k = nn.Linear(dim, dim)
        self.proj_v = nn.Linear(dim, dim)
        self.drop = nn.Dropout(dropout)
        self.n_heads = num_heads
        self.scores = None  # last attention weights, kept for visualization

        self.proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim, eps=1e-6)

        self.fc = FeedForward(dim, dim*4, dropout=dropout)

    def forward(self, x, mask=None):
        """
        x : (B(batch_size), S(seq_len), D(dim))
        mask : (B(batch_size) x S(seq_len)), 1 = keep, 0 = suppress
        * D(dim) is split into (H(n_heads), W(width of head)) ; D = H * W
        """
        def to_heads(t):
            # (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
            return split_last(t, (self.n_heads, -1)).transpose(1, 2)

        query = to_heads(self.proj_q(x))
        key = to_heads(self.proj_k(x))
        value = to_heads(self.proj_v(x))

        # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
        scores = query @ key.transpose(-2, -1) / np.sqrt(key.size(-1))
        if mask is not None:
            mask = mask[:, None, None, :].float()
            scores -= 10000.0 * (1.0 - mask)
        scores = self.drop(nn.functional.softmax(scores, dim=-1))

        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
        context = (scores @ value).transpose(1, 2).contiguous()
        # -merge-> (B, S, D)
        context = merge_last(context, 2)
        self.scores = scores

        attended = self.norm(self.drop(self.proj(context)) + x)
        return self.fc(attended)

class MultiHeadedSelfAttentionMS(nn.Module):
    """Projection-free multi-headed attention with an added relevance bias.

    The input itself is used as query, key and value. On top of the softmax
    attention weights, every key position receives a bias equal to the largest
    cosine similarity between its per-head value vector and the columns of the
    output weight of ``self.fc``.
    """
    def __init__(self, dim, num_heads, dropout):
        """
        dim : embedding size D
        num_heads : number of heads H
        dropout : dropout probability (attention weights and output projection)
        """
        super().__init__()
        self.drop = nn.Dropout(dropout)
        self.n_heads = num_heads
        self.scores = None # for visualization

        self.proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim, eps=1e-6)

        # NOTE: forward() never calls self.fc; only the weight of its last
        # linear layer is read, as the "target" embeddings.
        self.fc = FeedForward(dim, dim*2, dropout=dropout)

    def forward(self, x, mask=None):
        """
        x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
        mask : (B(batch_size) x S(seq_len))
        * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W
        returns : same shape as x (LayerNorm of projection + residual)
        """
        # calculate importance
        # (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        # No learned projections here: q, k and v are all the raw input.
        q, k, v = x, x, x
        # The loop variable below is local to the generator expression, so the
        # outer x is still the original input when used for the residual.
        q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
        # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
        scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
        if mask is not None:
            # Subtract a large constant from the logits where mask == 0.
            mask = mask[:, None, None, :].float()
            scores -= 10000.0 * (1.0 - mask)
        scores = self.drop(nn.functional.softmax(scores, dim=-1))

        # calculate relevance
        # Weight of the last linear layer of self.fc; the first axis (TE) is
        # split across heads, then reshaped to (1, H, 1, TE // H, -1).
        target_embed = self.fc.fc_liner[2].weight
        TE, _ = target_embed.shape
        target_embed = target_embed.view(self.n_heads, TE // self.n_heads, -1).unsqueeze(1).unsqueeze(0)

        # dot prod
        # v_ gains a trailing axis so it broadcasts against every target
        # column; the sum over dim=3 contracts the per-head width axis.
        v_ = v.unsqueeze(-1)
        imp = torch.sum(target_embed * v_, dim=3)

        # denominator (complete cos similarity)
        # NOTE(review): no epsilon is added, so an all-zero value vector or
        # target column would divide by zero -- confirm this cannot occur.
        den = torch.sqrt(torch.sum(target_embed ** 2, dim=-2)) * torch.sqrt(torch.sum(v_ ** 2, dim=-2))
        imp = imp / den

        # take highest correlation
        # One value per (batch, head, position), reshaped so that it is added
        # along the key axis of scores; the sum is not re-normalized.
        imp = torch.max(imp, dim=-1).values.unsqueeze(-2)
        scores = scores + imp

        # final output
        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
        h = (scores @ v).transpose(1, 2).contiguous()
        # -merge-> (B, S, D)
        h = merge_last(h, 2)
        self.scores = scores
        return self.norm(self.drop(self.proj(h))+x)
    
################################################

class FeedForward(nn.Module):
    """Two-layer GELU feed-forward block with an optional residual LayerNorm.

    Args:
        emb_size: Input and output size.
        hidden_size: Size of the hidden layer.
        dropout: Dropout probability applied after the output layer.
        add_norm: When True return ``LayerNorm(x + ffn(x))``; otherwise
            return ``ffn(x)`` alone.
        vae_out: When True the output layer is a ``VAE_Latent`` instead of a
            plain ``nn.Linear``.
    """

    def __init__(self, emb_size, hidden_size, dropout=0.1, add_norm=True, vae_out=False):
        super().__init__()
        self.add_norm = add_norm

        expand = nn.Linear(emb_size, hidden_size)
        if vae_out:
            project = VAE_Latent(hidden_size, emb_size)
        else:
            project = nn.Linear(hidden_size, emb_size)

        self.fc_liner = nn.Sequential(
            expand,
            nn.GELU(),
            project,
            nn.Dropout(p=dropout),
        )

        self.LayerNorm = nn.LayerNorm(emb_size, eps=1e-6)

    def forward(self, x):
        """Apply the block to ``x`` (last dimension ``emb_size``)."""
        out = self.fc_liner(x)
        if not self.add_norm:
            return out
        return self.LayerNorm(x + out)

# =========== Main Transformer Layer ========================================================================
    
class Transformer(nn.Module):
    """Stack of ``MultiHeadedSelfAttention`` blocks applied one after another.

    Args:
        num_layers: Number of blocks in the stack.
        emb_size: Embedding size passed to every block.
        num_heads: Number of attention heads per block.
        dropout: Dropout probability passed to every block.
    """

    def __init__(
        self,
        num_layers=1,
        emb_size=768,
        num_heads=12,
        dropout=0.1,
    ):
        super().__init__()

        layers = [
            MultiHeadedSelfAttention(emb_size, num_heads, dropout)
            for _ in range(num_layers)
        ]
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Run ``x`` (N, L, C) through every block in order; the shape is kept."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden)
        return hidden
    
class SiT(nn.Module): # Signal Transformer (SiT)
    """Signal Transformer: convolutional embedding, [CLS] token, transformer stack.

    Args:
        num_layers: Number of transformer blocks.
        emb_size: Embedding size E.
        num_heads: Number of attention heads per block.
        dropout: Dropout probability (positional encoding and blocks).
        in_channel: Size F of the last input dimension (data related).
    """

    def __init__(
        self,
        num_layers=6,
        emb_size=384,
        num_heads=12,
        dropout=0.1,
        # data related
        in_channel=65
    ):
        super().__init__()

        # Learnable [CLS] token, initialised uniformly in [0, 1).
        self.cls = nn.Parameter(torch.rand(1, 1, emb_size), requires_grad=True)

        # Embeddings (Potential Improvement here)
        patch_embed = EmbConvBlockRGB(in_channel, 8, emb_size, hidden_size=emb_size*4)
        position_embed = tAPE(emb_size, dropout=dropout, max_len=2048)
        self.emb = nn.Sequential(patch_embed, position_embed)  # -> (N, L, E)

        # transformers blocks
        self.transformer = Transformer(
            num_layers=num_layers,
            emb_size=emb_size,
            num_heads=num_heads,
            dropout=dropout
        )

    def forward(self, x): # Input shape: (N, 3, L, F)
        """Embed ``x``, prepend the [CLS] token and return (N, L+1, E)."""
        tokens = self.emb(x)  # (N, L, E)
        batch, _, width = tokens.shape
        cls_tokens = self.cls.expand(batch, 1, width)
        sequence = torch.cat((cls_tokens, tokens), dim=1)  # prepend [CLS] token
        return self.transformer(sequence)  # (N, L+1, E)

class TemporalFusion(nn.Module):
    """Fuse a sequence of embeddings (N, L, E) into one vector per sample.

    Fusion methods:
        * ``'mean'``  - average over the sequence axis.
        * ``'last'``  - embedding of the last time step.
        * ``'msitf'`` - weighted sum of projected values, where the weights
          combine recency (exponential decay towards the past), relevance
          (cosine similarity to a query) and importance (a learned,
          Gumbel-perturbed keep probability).

    Args:
        num_neurons: Embedding size E of the input sequence.
        query_size: Size of the query and of the fused output.
        fuse_method: One of ``'mean'``, ``'last'``, ``'msitf'``.
        dropout: Dropout probability of the output feed-forward block.
    """

    def __init__(self, num_neurons=768, query_size=384, fuse_method='msitf', dropout=0.1):
        super().__init__()
        self.fuse_method = fuse_method # mean, last, msitf

        if fuse_method == 'msitf':
            # importance, learn-to-drop
            self.cd = nn.Sequential(
                nn.Linear(num_neurons, 2),
                nn.Sigmoid()
            )
            self.t_bound = 5e-1  # temperature of the importance softmax

            # relevance
            self.k = nn.Linear(num_neurons, query_size)
            self.v = nn.Linear(num_neurons, query_size)
            self.q = nn.Linear(query_size, query_size)

            # recency
            self.decay_r = 0.997

            # output fc
            self.out_fc = FeedForward(query_size, query_size*4, dropout=dropout, vae_out=True)

    def forward(self, x, query, mask=None, return_scores=False): # x: (N, L, E)
        """Fuse ``x`` along the sequence axis.

        Args:
            x: Sequence of embeddings, (N, L, E).
            query: Query embedding, (N, query_size); only used by ``'msitf'``.
            mask: Accepted for interface compatibility; currently unused.
            return_scores: For ``'msitf'`` only, return
                ``(v, recency, relevance, importance)`` instead of the fused
                vector.

        Returns:
            ``'mean'`` / ``'last'``: tensor of shape (N, E).
            ``'msitf'``: tensor of shape (N, query_size), or the score tuple
            with shapes (N, L, query_size), (1, L), (N, L), (N, L).
        """
        if self.fuse_method == 'mean':
            return torch.mean(x, dim=1)
        elif self.fuse_method == 'last':
            # Last time step, (N, E) -- same shape as the 'mean' result.
            return x[:, -1, :]

        # memory stream based fusion
        N, L, E = x.shape
        # recency: decay_r ** (distance from the last step); the most recent
        # step gets weight 1. Built on the input's device/dtype so the module
        # works wherever its inputs live.
        recency = torch.tensor(
            [self.decay_r ** p for p in range(L - 1, -1, -1)],
            dtype=torch.float32,
            device=x.device,
        ).to(x.dtype).view(1, -1) # 1, L

        # relevance
        # query: N, E_q
        query = query.unsqueeze(1) # N, 1, E_q
        q = self.q(query)
        k = self.k(x)
        v = self.v(x)
        den = torch.norm(q, dim=-1)*torch.norm(k, dim=-1) # N, L
        attn = torch.sum(k*q, dim=-1) / den # N, L (cosine similarity)
        relevance = torch.softmax(attn, dim=-1) # N, L

        # importance
        prob = self.cd(x) # N, L, 2
        log_prob = torch.log(prob)
        if self.training:
            # Gumbel noise, sampled in float32 and clamped away from zero so
            # the double logarithm stays finite.
            u = torch.rand(prob.shape, dtype=torch.float32, device=prob.device)
            u = u.clamp_min(torch.finfo(torch.float32).tiny)
            eps = (-torch.log(-torch.log(u))).to(prob.dtype)
        else:
            eps = 0.577 # empirical mean of gumbel distribution
        log_prob = (log_prob + eps) / self.t_bound # N, L, 2
        # Softmax over the two classes; numerically stable form of
        # exp(log_prob) / exp(log_prob).sum().
        prob_matrix = torch.softmax(log_prob, dim=-1) # N, L, 2
        importance = prob_matrix[:, :, 1] # N, L

        # integrate scores, normalize, and weighted sum
        retrieval = recency+relevance+importance # N, L
        retrieval = torch.softmax(retrieval, dim=-1) # N, L
        final_out = torch.sum(v*retrieval.unsqueeze(-1), dim=1) # N, E_q

        final_out = self.out_fc(final_out)

        # return
        if return_scores:
            return v, recency, relevance, importance
        return final_out

# test shapes of input and output
if __name__ == '__main__':
    # Smoke test: run TemporalFusion on random data and print the shapes of
    # the returned values and scores.
    frames = torch.rand((2, 4, 768))
    text_query = torch.rand((2, 384))
    # fusion = EmbedMatch()
    # out = fusion(frames, text_query)
    fusion = TemporalFusion()
    values, recency, relevance, importance = fusion(frames, query=text_query, return_scores=True)
    for tensor in (values, recency, relevance, importance):
        print(tensor.shape)

    # Alternative smoke test for the full SiT model:
    # num_sample = 3
    # seq_length = 256
    # in_channel = 65

    # x = torch.rand((num_sample, 3, seq_length, in_channel)) # N, 3, L, F

    # model = SiT(
    #     num_layers=6,
    #     emb_size=384,
    #     # data related
    #     in_channel=in_channel
    # )

    # # count number of params
    # pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # print("Num param:", pytorch_total_params)

    # print("x shape:", x.shape)
    # y = model(x)
    # print('y shape:', y.shape) # (3, 257, 384)
    
