import numpy.random as npr
import torch
import torch.nn as nn
import math

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce


class Patch_embedding(torch.nn.Module):
    """Cut an image into non-overlapping patches and embed each patch.

    A convolution whose kernel and stride both equal ``patch_size`` maps every
    patch to ``hdim`` features. The feature map is flattened into a token axis
    and a learned positional table is added.

    Args:
        patch_size: patch side length (used as conv kernel size and stride).
        in_channels: number of channels of the input image.
        hdim: embedding width.
        max_len: number of rows of the positional table. The number of patches
            produced by the conv must broadcast against it.
        drop_rate: dropout probability applied after adding positions.

    ``forward(x)`` returns ``(layer_norm(emb), emb)``, where ``emb`` is the
    dropped-out patch + position embedding of shape (batch, patches, hdim).
    """

    def __init__(self, patch_size, in_channels, hdim, max_len, drop_rate):
        super(Patch_embedding, self).__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.hdim = hdim
        self.max_len = max_len
        # Flattened size of one patch; stored on the module, not used in this class.
        self.idim = in_channels * patch_size ** 2

        # Small random positional table drawn from numpy's global RNG.
        pos_init = torch.tensor(npr.randn(max_len, hdim), dtype=torch.float32)
        self.pos_emb = nn.Parameter(pos_init * 1e-1)

        patchify = nn.Conv2d(in_channels, hdim, kernel_size=patch_size, stride=patch_size)
        to_tokens = Rearrange('b e (h) (w) -> b (h w) e')
        self.linear_proj = nn.Sequential(patchify, to_tokens)

        self.ln = nn.LayerNorm(hdim)
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, x):
        tokens = self.dropout(self.linear_proj(x) + self.pos_emb)
        return self.ln(tokens), tokens
        
class Embeddings(torch.nn.Module):
    """Token plus positional embeddings, projected to the hidden width.

    Args:
        vocab_size: number of token ids; id 0 is the padding index.
        max_len: number of rows of the positional embedding table.
        emb_size: width of both embedding tables.
        h_size: width produced by the linear projection.
        drop_rate: dropout probability applied after the projection.
    """

    def __init__(self, vocab_size, max_len, emb_size, h_size, drop_rate):
        super(Embeddings, self).__init__()
        # Submodules are created in a fixed order so seeded initialisation
        # consumes the RNG in the same sequence.
        self.token_embeds = nn.Embedding(vocab_size, emb_size, padding_idx=0)
        self.pos_embeds = nn.Embedding(max_len, emb_size)
        self.layer_norm = nn.LayerNorm(h_size)
        self.project = nn.Linear(emb_size, h_size)
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, input_data, pos):
        """Return ``(layer_norm(h), h)`` with ``h = dropout(project(tok + pos))``."""
        summed = self.token_embeds(input_data) + self.pos_embeds(pos)
        hidden = self.dropout(self.project(summed))
        return self.layer_norm(hidden), hidden


class FC(torch.nn.Module):
    """Two-layer position-wise MLP with width ``hdim`` in and out.

    Layer order: Linear -> Dropout -> GELU -> Linear -> Dropout. ``forward``
    returns only the MLP output and adds no residual. ``self.ln`` is
    registered, so it appears in the state dict, but it is not applied.
    """

    def __init__(self, hdim, drop_rate=0.):
        super(FC, self).__init__()
        self.hdim = hdim
        self.act = torch.nn.GELU()
        layers = [
            nn.Linear(hdim, hdim),
            nn.Dropout(drop_rate),
            self.act,
            nn.Linear(hdim, hdim),
            nn.Dropout(drop_rate),
        ]
        self.fc = nn.Sequential(*layers)
        self.ln = nn.LayerNorm(hdim)

    def forward(self, x):
        return self.fc(x)
    

class ClassficationHead_vit(torch.nn.Module):
    """Sequence-pooling classification head.

    Pooling strategy as in https://arxiv.org/abs/2104.05704. A linear scorer
    gives one score per token, and a softmax over the token axis turns the
    scores into weights. The weighted token sum is then layer-normalised and
    mapped to ``num_class`` logits.
    """

    def __init__(self, hdim, num_class):
        super(ClassficationHead_vit, self).__init__()
        self.hdim = hdim
        self.num_class = num_class
        self.fc = nn.Linear(hdim, num_class)
        self.seqpool = nn.Linear(hdim, 1)
        self.ln = nn.LayerNorm(hdim)

    def forward(self, x):
        # x must be 4-D (the permute requires it): tokens on dim 2, features last.
        scores = self.seqpool(x).permute(0, 1, 3, 2)  # (d0, d1, 1, tokens)
        weights = torch.softmax(scores, -1)
        pooled = torch.mean(weights @ x, 2)  # collapse the singleton query axis
        return self.fc(self.ln(pooled))

class ClassficationHead(torch.nn.Module):
    """Masked mean-pooling classification head.

    Tokens are multiplied by ``input_mask``, averaged over dim 2, then
    layer-normalised and passed through ``Linear(hdim, num_class)`` + Dropout.

    NOTE(review): the mean divides by the full token count, not by the number
    of unmasked tokens, so padding shrinks the pooled vector uniformly. The
    LayerNorm that follows largely cancels such a uniform rescaling (up to its
    eps), which is why this is left as-is.
    """

    def __init__(self, hdim, num_class, drop_rate=0.):
        super(ClassficationHead, self).__init__()
        self.hdim = hdim
        self.num_class = num_class
        self.fc = nn.Sequential(nn.Linear(hdim, num_class), nn.Dropout(drop_rate))
        self.ln = nn.LayerNorm(hdim)

    def forward(self, x, input_mask):
        # For a (batch, tokens) mask this yields (batch, 1, tokens, 1), which
        # broadcasts over dim 1 and the feature axis of x.
        expanded_mask = input_mask.unsqueeze(-1).unsqueeze(1)
        pooled = (x * expanded_mask).mean(dim=2)
        return self.fc(self.ln(pooled))

def kernel_ard(X1, X2, log_ls, log_sf):
    """ARD squared-exponential kernel, evaluated per head.

    k(x, y) = sf * exp(-0.5 * sum_d ((x_d - y_d) / ls_d) ** 2),
    with ls = exp(log_ls) and sf = exp(log_sf).

    Args:
        X1: 4-D tensor (batch, heads, N, d); the permute on X2 below fixes rank 4.
        X2: 4-D tensor (batch, heads, M, d).
        log_ls: per-head log lengthscales, broadcast as (heads, 1, d) after
            ``unsqueeze(1)``; presumably shape (heads, d).
        log_sf: per-head log signal variance, broadcast as (heads, 1, 1) after
            ``unsqueeze(1)``; presumably shape (heads, 1).

    Returns:
        Kernel matrix of shape (batch, heads, N, M), bounded above by sf.
    """
    inv_ls = torch.exp(-log_ls).unsqueeze(1)
    X1 = X1 * inv_ls
    X2 = X2 * inv_ls
    sq_norm1 = torch.sum(X1.pow(2), -1)
    sq_norm2 = torch.sum(X2.pow(2), -1)
    # ||a - b||^2 expanded as ||a||^2 + ||b||^2 - 2 a.b can come out slightly
    # negative through floating-point cancellation. That would make k(x, x)
    # exceed sf and can break positive-definiteness downstream, so clamp at 0.
    sq_dist = sq_norm1.unsqueeze(3) + sq_norm2.unsqueeze(2) - 2 * X1 @ X2.permute(0, 1, 3, 2)
    sq_dist = sq_dist.clamp(min=0)
    return torch.exp(log_sf).unsqueeze(1) * torch.exp(-0.5 * sq_dist)


def kernel_exp(X1, X2, log_ls, log_sf):
    """Exponential (dot-product) kernel, evaluated per head.

    k(x, y) = sf * exp((x / ls) . (y / ls)), with ls = exp(log_ls) and
    sf = exp(log_sf).

    Args:
        X1: 4-D tensor (batch, heads, N, d); the permute on X2 fixes rank 4.
        X2: 4-D tensor (batch, heads, M, d).
        log_ls: per-head log lengthscales, broadcast as (heads, 1, d) after
            ``unsqueeze(1)``; presumably shape (heads, d).
        log_sf: per-head log amplitude, broadcast as (heads, 1, 1) after
            ``unsqueeze(1)``; presumably shape (heads, 1).

    Returns:
        Kernel matrix of shape (batch, heads, N, M).
    """
    scale = torch.exp(-log_ls).unsqueeze(1)
    gram = (X1 * scale) @ (X2 * scale).permute(0, 1, 3, 2)
    amplitude = torch.exp(log_sf).unsqueeze(1)
    return amplitude * torch.exp(gram)


def scale_dot(X1, X2):
    """Scaled dot-product attention weights: softmax(X1 @ X2^T / sqrt(d_k)).

    Both inputs are 4-D, with d_k taken from the last axis of X2. The softmax
    is taken over dim 3, the key axis of the (…, queries, keys) score matrix.
    """
    scores = X1 @ X2.permute(0, 1, 3, 2)
    scores = scores / math.sqrt(X2.shape[3])
    return torch.softmax(scores, 3)


class SGP_LAYER(nn.Module):
    """Multi-head kernel attention with an optional sparse-GP posterior.

    With ``flag_sgp=False`` the output is ``K(q, k_gamma) @ v_gamma``, followed
    by the output projection ``W_O``. With ``flag_sgp=True`` the layer also
    uses inducing keys ``k_beta``: the per-layer keys projected with the
    query/key weights. It forms a posterior mean and a diagonal-approximated
    variance, then draws ``sample_size`` Gaussian samples. In training mode it
    also returns a KL term.

    Args:
        device: device for the identity (jitter) matrices and sampling noise.
        num_heads: number of heads; ``vdim = hdim // num_heads``.
        max_len: maximum sequence length. Stored on the module; the kernel split
            uses the actual input length.
        hdim: model width. ``W_O`` is Linear(hdim, hdim) applied to
            num_heads * vdim features, so hdim must be divisible by num_heads.
        kernel_type: 'exponential', 'ard' or 'scale_dot'. 'scale_dot' is only
            supported with flag_sgp=False.
        sample_size: number of samples drawn when flag_sgp is set.
        jitter: diagonal jitter added to K(k_beta, k_beta) before inverse/Cholesky.
        keys_len: number of inducing keys.
        drop_rate: dropout probability after ``W_O``.
        flag_sgp: enable the sparse-GP branch.
        inference_mode: if True, K(k_beta, k_beta), its inverse and a covariance
            correction are computed on the first forward and cached on the
            module. Nothing in this class resets those caches. No KL is computed
            in this mode.

    Raises:
        ValueError: unknown ``kernel_type``, or 'scale_dot' combined with flag_sgp.
    """

    # Cholesky attempts (jitter grows 10x per retry) before giving up.
    _MAX_CHOLESKY_TRIES = 10

    def __init__(self, device, num_heads, max_len, hdim, kernel_type, sample_size, jitter, keys_len, drop_rate, flag_sgp, inference_mode):
        super(SGP_LAYER, self).__init__()
        self.max_len = max_len
        self.num_heads = num_heads
        self.hdim = hdim
        self.vdim = self.hdim // self.num_heads
        self.dq = self.vdim
        self.flag_sgp = flag_sgp
        self.keys_len = keys_len
        self.drop_rate = drop_rate
        self.inference_mode = inference_mode
        # Inference-only caches, filled on the first forward when inference_mode=True.
        self.K_k_beta_k_beta = None
        self.cache_inverse1 = None
        self.cache_inverse2 = None

        # Kernel hyper-parameters. The npr.randn calls keep their original order
        # so seeded initialisation is unchanged.
        if kernel_type == 'exponential':
            self.log_sf = nn.Parameter(-4. + 0.* torch.tensor(npr.randn(self.num_heads,1), dtype=torch.float32))
            self.log_ls = nn.Parameter(4. + 1.* torch.tensor(npr.randn(self.num_heads,self.dq), dtype=torch.float32))
        elif kernel_type == 'ard':
            self.log_sf = nn.Parameter(0. + 0.* torch.tensor(npr.randn(self.num_heads,1), dtype=torch.float32))
            self.log_ls = nn.Parameter(0. + 1.* torch.tensor(npr.randn(self.num_heads,self.dq), dtype=torch.float32))
        elif kernel_type == 'scale_dot':
            # The SGP branch needs K(q, q) and K(q, k_beta), which only the
            # 'exponential' and 'ard' kernels provide.
            if flag_sgp:
                raise ValueError("kernel_type='scale_dot' does not support flag_sgp=True; use 'exponential' or 'ard'.")
        else:
            raise ValueError("The argument 'kernel_type' should be either 'exponential', 'ard', or 'scale_dot'.")

        self.sample_size = sample_size
        self.jitter = jitter
        self.device = device
        self.kernel_type = kernel_type

        # Joint projection producing, per head, [q_h | v_h] (see get_q_k_v_ssqrt).
        self.fc_qkv = nn.Linear(self.hdim, 2* self.num_heads* self.vdim, bias=False)
        if self.kernel_type == 'scale_dot':
            self.fc_k = nn.Linear(self.hdim, self.hdim, bias=False)

        if self.flag_sgp:
            # Variational parameters attached to the inducing keys.
            self.v = nn.Parameter(torch.tensor(npr.randn(self.num_heads, 1, self.keys_len, self.vdim), dtype=torch.float32))
            self.s_sqrt_ltri = nn.Parameter( torch.tensor(npr.randn(self.num_heads, 1, self.vdim, self.keys_len, self.keys_len), dtype=torch.float32))
            self.log_s_sqrt_diag = nn.Parameter( torch.tensor(npr.randn(self.num_heads, 1, self.vdim, self.keys_len), dtype=torch.float32))

        self.W_O = nn.Sequential(nn.Linear(self.hdim, self.hdim), nn.Dropout(self.drop_rate))

    def get_q_k_v_ssqrt(self, x, cur_k):
        """Project the input and, in SGP mode, the inducing keys.

        Args:
            x: (batch, seq_len, hdim) layer input (fixed by the ``view`` below).
            cur_k: inducing keys for this layer. Must broadcast against
                (num_heads, 1, keys_len, hdim); TODO confirm against caller.
                Unused when flag_sgp is False.

        Returns:
            Without SGP: ``(q, k_gamma, v_gamma)``, each (batch, num_heads, seq_len, vdim).
            With SGP: ``(q, k_gamma, k_beta, v_gamma, v_beta, log_ssqrt)`` where,
            for cur_k of the shape above, k_beta and v_beta are
            (1, num_heads, keys_len, vdim) and log_ssqrt is (1, num_heads, vdim, keys_len).
        """
        batch, seq_len = x.shape[0], x.shape[1]
        q, v_gamma = self.fc_qkv(x).view(batch, seq_len, self.num_heads, 2 * self.vdim).permute(0, 2, 1, 3).chunk(chunks=2, dim=-1)
        if self.kernel_type == 'scale_dot':
            k_gamma = self.fc_k(x).view(batch, seq_len, self.num_heads, self.vdim).permute(0, 2, 1, 3)
        else:
            k_gamma = q
        if not self.flag_sgp:
            return q, k_gamma, v_gamma

        # fc_qkv's output is laid out head-major as [q_h | v_h] (see the view
        # above), so head h's query rows are weight[h*2*vdim : h*2*vdim + vdim].
        # A plain weight[:hdim] slice would mix q and v rows for every head
        # except the first. NOTE: checkpoints trained with that slice will see
        # different k_beta values.
        W_qk = self.fc_qkv.weight.view(self.num_heads, 2 * self.vdim, self.hdim)[:, :self.vdim]
        k_beta = W_qk[:, None, None] @ cur_k.unsqueeze(-1)
        k_beta = k_beta.squeeze(-1).permute(1, 0, 2, 3)
        v_beta = self.v.permute(1, 0, 2, 3)
        log_ssqrt = self.log_s_sqrt_diag.permute(1, 0, 2, 3)
        return q, k_gamma, k_beta, v_gamma, v_beta, log_ssqrt

    def _cholesky_with_jitter(self, K):
        """Cholesky factor of ``K + jitter * I``; retry with 10x jitter on failure.

        The number of attempts is bounded, so a matrix that can never be
        factorised (e.g. one containing NaNs) raises instead of looping forever.
        """
        eye = torch.eye(K.shape[2], device=self.device)
        jitter = self.jitter
        last_jitter = jitter
        last_err = None
        for _ in range(self._MAX_CHOLESKY_TRIES):
            try:
                return torch.linalg.cholesky(K + jitter * eye)
            except RuntimeError as err:  # torch.linalg.LinAlgError subclasses RuntimeError
                last_err = err
                last_jitter = jitter
                # A zero jitter would never grow; fall back to a small positive value.
                jitter = jitter * 10 if jitter > 0 else 1e-6
        raise RuntimeError(
            f"Cholesky of K_k_beta_k_beta failed after {self._MAX_CHOLESKY_TRIES} attempts "
            f"(final jitter {last_jitter:g})."
        ) from last_err

    def forward(self, x, cur_k):
        """Apply attention and return ``(samples, kl)``.

        Returns:
            samples: (batch, S, seq_len, hdim) after ``W_O``, where S is 1
                without SGP and ``sample_size`` with it.
            kl: KL tensor in SGP training mode; None otherwise.
        """
        # We set W_q = W_k to maintain a valid symmetric deep kernel, so q = k_gamma below when kernel_type='exponential' or 'ard'.
        # We can use different projection matrices if necessary.
        if self.flag_sgp:
            q, k_gamma, k_beta, v_gamma, v_beta, log_ssqrt = self.get_q_k_v_ssqrt(x, cur_k)
        else:
            q, k_gamma, v_gamma = self.get_q_k_v_ssqrt(x, cur_k)

        if self.kernel_type in ('exponential', 'ard'):
            kernel_fn = kernel_exp if self.kernel_type == 'exponential' else kernel_ard
            if not self.flag_sgp:
                K_qq = kernel_fn(q, q, self.log_ls, self.log_sf)  # [bs, num_heads, seq_len, seq_len]
            else:
                # One kernel call against [q; k_beta], split at the actual
                # sequence length so K_qq is square even for inputs shorter
                # than max_len.
                seq_len = q.shape[2]
                q_and_k_beta = torch.cat([q, k_beta.tile(q.shape[0], 1, 1, 1)], 2)
                K_qq, K_qk_beta = kernel_fn(q, q_and_k_beta, self.log_ls, self.log_sf).tensor_split([seq_len], -1)
                K_k_beta_k_gamma = K_qk_beta.permute(0, 1, 3, 2)

                if self.K_k_beta_k_beta is not None:
                    K_k_beta_k_beta = self.K_k_beta_k_beta
                else:
                    K_k_beta_k_beta = kernel_fn(k_beta, k_beta, self.log_ls, self.log_sf)
                    if self.inference_mode:
                        self.K_k_beta_k_beta = K_k_beta_k_beta
            K_qk_gamma = K_qq
        elif self.kernel_type == 'scale_dot':
            K_qk_gamma = scale_dot(q, k_gamma)
        else:
            raise ValueError("The argument 'kernel_type' should be either 'exponential', 'ard' or 'scale_dot'.")

        if not self.flag_sgp:
            mean = K_qk_gamma @ v_gamma
            samples = torch.flatten(mean.unsqueeze(2).permute(0, 2, 3, 1, 4), -2, -1)
            return self.W_O(samples), None

        # Lower-triangular square root of the variational covariance:
        # positive diagonal exp(log_ssqrt) plus a free strictly-lower part.
        s_sqrt = torch.exp(log_ssqrt)
        s_sqrt_diag = torch.diag_embed(s_sqrt)
        s_sqrt_local = s_sqrt_diag + torch.tril(self.s_sqrt_ltri.permute(1, 0, 2, 3, 4), diagonal=-1)

        if self.inference_mode and self.cache_inverse1 is None:
            K_reg = K_k_beta_k_beta + self.jitter * torch.eye(K_k_beta_k_beta.shape[2], device=self.device)
            K_kk_inverse = torch.linalg.inv(K_reg)
            self.cache_inverse1 = K_kk_inverse
            K_kk_inverse = K_kk_inverse.unsqueeze(2)
            chol_K_kk = torch.linalg.cholesky(K_reg).unsqueeze(2)
            self.cache_inverse2 = K_kk_inverse @ chol_K_kk @ s_sqrt_local @ s_sqrt_local.permute(0, 1, 2, 4, 3) @ chol_K_kk.permute(0, 1, 2, 4, 3) @ K_kk_inverse - K_kk_inverse

        # Notice here we make diagonal approximation of the full covariance to accelerate sampling.
        # Empirically, it doesn't seem to hurt the performance.
        chol_covar1 = torch.diagonal(K_qq.unsqueeze(2), dim1=3, dim2=4).permute(0, 1, 3, 2).unsqueeze(2)
        if self.inference_mode:
            # During inference, use the cached inverse instead of solving linear systems.
            mean1 = K_qk_gamma @ v_gamma
            mean = mean1 - K_qk_beta @ (self.cache_inverse1 @ (K_k_beta_k_gamma @ v_gamma)) + K_qk_beta @ v_beta
            chol_covar = (chol_covar1 + ((K_qk_beta.unsqueeze(2) @ self.cache_inverse2) * K_qk_beta.unsqueeze(2)).sum(-1).permute(0, 1, 3, 2).unsqueeze(2)).pow(0.5)
        else:
            chol_K_kk = self._cholesky_with_jitter(K_k_beta_k_beta)

            v1 = torch.linalg.solve_triangular(chol_K_kk, K_k_beta_k_gamma, upper=False)
            v2 = torch.linalg.solve_triangular(chol_K_kk, K_k_beta_k_gamma @ v_gamma, upper=False)
            v3 = v1.unsqueeze(2).permute(0, 1, 2, 4, 3) @ s_sqrt_local

            mean1 = K_qk_gamma @ v_gamma
            mean = mean1 - v1.permute(0, 1, 3, 2) @ v2 + K_qk_beta @ v_beta

            chol_covar2 = v3.pow(2).sum(-1).permute(0, 1, 3, 2).unsqueeze(2) - \
                v1.unsqueeze(2).permute(0, 1, 2, 4, 3).pow(2).sum(-1).permute(0, 1, 3, 2).unsqueeze(2)
            chol_covar = (chol_covar1 + chol_covar2).pow(0.5)

        noise = torch.randn((mean.shape[0], mean.shape[1], self.sample_size, mean.shape[2], mean.shape[3]), device=self.device)
        samples = mean.unsqueeze(2) + chol_covar * noise
        samples = torch.flatten(samples.permute(0, 2, 3, 1, 4), -2, -1)
        samples = self.W_O(samples)

        if self.inference_mode:
            return samples, None

        kl = -0.5* self.keys_len* self.vdim * self.num_heads
        kl += 0.5* torch.mean(torch.sum(s_sqrt_local.pow(2), (-1,-2,-3,-4)))
        kl += 0.5* torch.mean(torch.sum(v_beta.permute(0,1,3,2).unsqueeze(3) @ K_k_beta_k_beta.unsqueeze(2) @ v_beta.permute(0,1,3,2).unsqueeze(4), (1,2)))
        second_term = v2.permute(0,1,3,2).unsqueeze(3) @ v2.permute(0,1,3,2).unsqueeze(4)
        temp = v_gamma.permute(0,1,3,2).unsqueeze(3) @ mean1.permute(0,1,3,2).unsqueeze(4) - second_term
        kl += 0.5* torch.mean(torch.sum(temp, (1,2)))
        kl -= torch.mean(torch.sum(log_ssqrt, (-1, -2, -3)))
        return samples, kl
            
class ViT(torch.nn.Module):
    """Vision transformer whose attention blocks are SGP_LAYERs.

    Layer 0 may draw ``sample_size`` samples. Later layers draw one sample each
    and run on the flattened (batch * samples) axis. ``forward`` returns
    ``(logits, total_kl)``; ``total_kl`` is None unless flag_sgp is set and
    inference_mode is off.

    Args:
        device: device passed to every SGP_LAYER.
        depth: number of attention + MLP blocks.
        patch_size, in_channels, max_len: patch-embedding configuration.
        num_class: number of output classes.
        hdim, num_heads: model width and number of heads.
        sample_size: samples drawn by the first layer (forced to 1 without SGP).
        jitter: Cholesky jitter for the SGP layers.
        drop_rate: dropout probability used throughout.
        keys_len: number of inducing keys per layer.
        kernel_type: 'exponential', 'ard' or 'scale_dot'.
        flag_sgp: enable the sparse-GP branch.
        inference_mode: forwarded to the SGP layers (caching, no KL).
    """

    def __init__(self, device, depth, patch_size, in_channels, max_len, num_class, hdim, num_heads, sample_size, jitter, drop_rate, keys_len, kernel_type, flag_sgp, inference_mode=False):
        super(ViT, self).__init__()
        self.hdim = hdim
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.max_len = max_len
        self.num_class = num_class
        self.sample_size = sample_size
        self.depth = depth
        self.jitter = jitter
        self.flag_sgp = flag_sgp
        # Without the SGP branch nothing is sampled, so there is a single "sample".
        if not self.flag_sgp:
            self.sample_size = 1
        self.keys_len = keys_len
        self.kernel_type = kernel_type
        self.drop_rate = drop_rate
        self.inference_mode = inference_mode

        # Submodule/parameter creation order is unchanged so seeded init is stable.
        self.patch_embedding = Patch_embedding(patch_size=patch_size, in_channels=in_channels, hdim=hdim, max_len=max_len, drop_rate=drop_rate)

        self.class_head = ClassficationHead_vit(hdim=hdim, num_class=num_class)

        self.device = device

        # A single LayerNorm shared by tokens and keys at every depth.
        self.ln = nn.LayerNorm(hdim)

        # One set of inducing keys per layer.
        self.keys = nn.ParameterList([nn.Parameter(torch.tensor(npr.randn(self.num_heads, 1, self.keys_len, self.hdim), dtype=torch.float32)) for i in range(self.depth)])

        self.sgp_layer_list = nn.ModuleList([SGP_LAYER(device=device, num_heads=num_heads, max_len=max_len, hdim=hdim, kernel_type=self.kernel_type, drop_rate=self.drop_rate, \
            keys_len=self.keys_len, sample_size=self.sample_size, jitter=jitter, flag_sgp=flag_sgp, inference_mode=self.inference_mode)])
        self.mlp_layer_list = nn.ModuleList([FC(hdim=hdim, drop_rate=self.drop_rate)])

        for i in range(1, depth):
            self.sgp_layer_list.append(SGP_LAYER(device=device, num_heads=num_heads, max_len=max_len, hdim=hdim,\
                kernel_type=self.kernel_type, drop_rate=self.drop_rate, keys_len=self.keys_len, sample_size=1, jitter=jitter, flag_sgp=flag_sgp, inference_mode=self.inference_mode))
            self.mlp_layer_list.append(FC(hdim=hdim, drop_rate=self.drop_rate))

    def forward(self, X):
        patch_emb_ln, patch_emb = self.patch_embedding(X)
        z, total_kl = self.sgp_layer_list[0](patch_emb_ln, self.keys[0])

        # Residual around attention; patch_emb gains a sample axis to match z.
        z_prime = patch_emb.unsqueeze(1) + z
        z = self.mlp_layer_list[0](self.ln(z_prime)) + z_prime

        # The next layer's keys pass through the same MLP + residual as the
        # tokens. With depth == 1 there is no keys[1], so skip it.
        cur_k = None
        if self.flag_sgp and self.depth > 1:
            cur_k = self.mlp_layer_list[0](self.keys[1]) + self.keys[1]
        for i in range(1, self.depth):
            # Fold the sample axis into the batch axis for the remaining layers.
            z_prev = z.reshape(-1, z.shape[-2], z.shape[-1])
            z_ln = self.ln(z_prev)
            if self.flag_sgp:
                cur_k = self.ln(cur_k)
            z, kl = self.sgp_layer_list[i](z_ln, cur_k)
            if self.flag_sgp and not self.inference_mode:
                # Out-of-place so earlier autograd graph nodes are not mutated.
                total_kl = total_kl + kl
            z_prime = z_prev.unsqueeze(1) + z
            z = self.mlp_layer_list[i](self.ln(z_prime)) + z_prime
            if self.flag_sgp and i < self.depth - 1:
                cur_k = self.mlp_layer_list[i](self.keys[i + 1]) + self.keys[i + 1]

        logits = self.class_head(z).squeeze(1)
        return logits, total_kl

    def loss(self, X, y, anneal_kl=1.):
        """Cross-entropy over all samples plus the annealed KL term.

        Labels are repeated ``sample_size`` times to line up with the
        per-sample logits. The KL term is added only when it exists (flag_sgp
        on, inference_mode off) and is positive.
        """
        logits, total_kl = self(X)
        ce_loss = nn.CrossEntropyLoss()
        y = torch.unsqueeze(y, 1)
        y = torch.tile(torch.unsqueeze(y, 1), (1, self.sample_size, 1)).view(-1, y.shape[1])
        neg_ElogPyGf = ce_loss(logits.view(-1, self.num_class), y.view(-1))
        # total_kl is None in inference mode; calling .item() on it would crash.
        if self.flag_sgp and total_kl is not None and total_kl.item() > 0:
            return neg_ElogPyGf + anneal_kl * total_kl
        return neg_ElogPyGf