import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
from torch.distributions import Categorical
from tqdm import tqdm

def forward_aW_diag(a_mean, a_var, weight, bias, w_var, b_var):
    """
    Moments of h = a @ W^T + b when a, W and b all carry independent
    (diagonal) Gaussian uncertainty.

    ----- Input -----
    a_mean: [N, D_in] mean of the input a
    a_var: [N, D_in] per-dimension variance of a
    weight: [D_out, D_in] mean of W
    bias: [D_out, ] mean of b
    w_var: [D_out, D_in] w_var[k][i] = var(w_ki)
    b_var: [D_out, ] b_var[k] = var(b_k)
    ----- Output -----
    h_mean: [N, D_out]
    h_var: [N, D_out] h_var[n][k] = var(h_k) for sample n
    """

    # E[h] = E[a] W^T + b
    mean_out = F.linear(a_mean, weight, bias)

    # var(h_k) = sum_i ( E[a_i]^2 var(w_ki) + var(a_i) (E[w_ki]^2 + var(w_ki)) ) + var(b_k)
    w_second_moment = weight ** 2 + w_var  # [D_out, D_in]
    from_weight_noise = torch.matmul(a_mean ** 2, w_var.T)
    from_input_noise = torch.matmul(a_var, w_second_moment.T)
    var_out = from_weight_noise + from_input_noise + b_var

    return mean_out, var_out

def forward_linear_diag_Bayesian_weight(e_mean, e_var, w_mean, w_var, bias = None):
    """
    Push a diagonal Gaussian through a linear layer whose weight is itself a
    diagonal Gaussian.

    With e ~ N(e_mean, diag(e_var)), W ~ N(w_mean, diag(w_var)) and
    h = e W^T (+ bias), return the mean and per-dimension variance of h.

    Only the weight is Bayesian; the bias is deterministic, so it shifts the
    output mean and leaves the variance untouched.

    Input
        e_mean: [B, T, D_in] embedding mean
        e_var: [B, T, D_in] embedding variance
        w_mean: [D_out, D_in] weight mean
        w_var: [D_out, D_in] weight variance, w_var[k][i] = var(w_ki)
        bias: optional [D_out, ] deterministic bias
    Output
        h_mean: [B, T, D_out]
        h_var: [B, T, D_out] variance of each output dimension
    """

    mean_out = F.linear(e_mean, w_mean, bias)

    # var(h_k) = sum_i E[e_i]^2 var(w_ki) + var(e_i) (E[w_ki]^2 + var(w_ki))
    w_second_moment = w_mean ** 2 + w_var  # [D_out, D_in]
    from_weight_noise = torch.matmul(e_mean ** 2, w_var.T)
    from_input_noise = torch.matmul(e_var, w_second_moment.T)

    return mean_out, from_weight_noise + from_input_noise

def forward_linear_diag_determinstic_weight(e_mean, e_var, weight, bias = None):
    """
    Push a diagonal Gaussian through an ordinary (deterministic) linear layer.

    For e ~ N(e_mean, diag(e_var)) and fixed W, b, the output h = e W^T + b has
    mean e_mean W^T + b and per-dimension variance sum_i W_ki^2 var(e_i).

    The off-diagonal covariance of h is discarded on purpose: the next layer
    is assumed to take a diagonal-covariance input again.

    Input
        e_mean: [B, T, D_in] embedding mean
        e_var: [B, T, D_in] embedding variance
        weight: [D_out, D_in] weight
        bias: optional [D_out, ] bias (affects the mean only)
    Output
        h_mean: [B, T, D_out]
        h_var: [B, T, D_out] variance of each output dimension
    """

    squared_weight = weight ** 2

    mean_out = F.linear(e_mean, weight, bias)
    var_out = F.linear(e_var, squared_weight)

    return mean_out, var_out

@torch.enable_grad()
def forward_activation_diag(activation_func, h_mean, h_var):
    """
    Pass a distribution with diagonal covariance through an activation layer.

    Given h ~ N(h_mean, diag(h_var)) and an element-wise activation g(·),
    linearise g around h_mean and approximate a = g(h) as
    a ~ N(g(h_mean), g'(h_mean)^2 * h_var).

    The derivative g'(h_mean) is obtained with autograd, which is why this
    function re-enables gradient tracking even when called under
    ``torch.no_grad()``. ``torch.autograd.grad`` is used rather than
    ``Tensor.backward`` so that nothing is accumulated into the ``.grad`` of
    any parameters the activation may own.

    Input
        activation_func: g(·), applied element-wise
        h_mean: [B, T, D] input mean
        h_var: [B, T, D] input variance

    Output
        a_mean: [B, T, D] g(h_mean), detached from the autograd graph
        a_var: [B, T, D]
    """

    # Differentiate w.r.t. a private leaf copy so the caller's tensor is untouched.
    h_mean_grad = h_mean.detach().clone().requires_grad_()

    a_mean = activation_func(h_mean_grad)

    # For an element-wise g, the vector-Jacobian product with an all-ones
    # vector is exactly the per-element derivative g'(h_mean).
    (nabla,) = torch.autograd.grad(
        a_mean, h_mean_grad, grad_outputs=torch.ones_like(a_mean)
    )  # [B, T, D]

    a_var = nabla ** 2 * h_var

    return a_mean.detach(), a_var
    
def forward_layer_norm_diag(e_mean, e_var, ln_weight, ln_eps):
    """
    Propagate per-dimension variance through a LayerNorm layer.

    The normalisation statistic is taken from the input *mean* only (the
    biased variance of e_mean over the last dimension); each input variance
    is then rescaled by ln_weight^2 / (var(e_mean) + ln_eps).

    Input
        e_mean: mean of input distribution [B, T, D]
        e_var: variance of input distribution [B, T, D]
        ln_weight: layer norm scale factor
        ln_eps: layer norm eps

    Output
        output_var [B, T, D]
    """

    # biased variance of the mean vector across features -> [B, T, 1]
    feature_var = torch.var(e_mean, dim=-1, unbiased=False, keepdim=True)

    # squared LayerNorm gain divided by the squared normaliser -> [B, T, D]
    gain = (1 / (feature_var + ln_eps)) * ln_weight ** 2

    return gain * e_var

def forward_value_cov_Bayesian_W(W_v, W_v_var, input_mean, input_var, n_h, D_v):
    """
    Per-head covariance of the value vectors v = W_v @ e when both the value
    matrix W_v ~ N(W_v, W_v_var) and the input e ~ N(input_mean, input_var)
    are uncertain (each with diagonal covariance).

    Input:
        n_h: number of attention heads
        D_v: dimension of value per head, n_h * D_v = D
        W_v: value weight matrix [D, D], viewed as [n_h, D_v, D]
        W_v_var: variance of value matrix, [D, D]
        input_mean: mean of input [B, T, D]
        input_var: variance of input [B, T, D]

    Output:
        v_cov [B, T, n_h, D_v, D_v]
    """

    B, T, D = input_var.size()
    per_head = (1, 1, n_h, D_v, D)
    per_token = (B, T, 1, 1, D)

    w_heads = W_v.reshape(per_head)            # [1, 1, n_h, D_v, D]
    var_tokens = input_var.reshape(per_token)  # [B, T, 1, 1, D]

    ## full covariance from the input noise: W diag(var) W^T for every head
    scaled_w_t = torch.transpose(w_heads * var_tokens, 3, 4)  # [B, T, n_h, D, D_v]
    cov = torch.matmul(w_heads, scaled_w_t)                   # [B, T, n_h, D_v, D_v]

    ## weight noise only adds to the diagonal:
    ## sum_d (var(e_d) + E[e_d]^2) * var(W_id)
    input_second_moment = var_tokens + input_mean.reshape(per_token) ** 2  # [B, T, 1, 1, D]
    diag_extra = (input_second_moment * W_v_var.reshape(per_head)).sum(dim=4)  # [B, T, n_h, D_v]
    cov = cov + torch.diag_embed(diag_extra)

    torch.cuda.empty_cache()

    return cov

def forward_value_cov_determinstic_W(W_v, input_var, n_h, D_v):
    """
    Per-head covariance of the value vectors v = W_v @ e for a fixed value
    matrix W_v and an input e with diagonal covariance diag(input_var),
    i.e. W_h diag(input_var) W_h^T for every head h.

    Input:
        n_h: number of attention heads
        D_v: dimension of value per head, n_h * D_v = D
        W_v: value weight matrix [D, D], viewed as [n_h, D_v, D]
        input_var: variance of input [B, T, D]

    Output:
        v_cov [B, T, n_h, D_v, D_v]
    """

    B, T, D = input_var.size()

    w_heads = W_v.reshape(1, 1, n_h, D_v, D)        # [1, 1, n_h, D_v, D]
    var_tokens = input_var.reshape(B, T, 1, 1, D)   # [B, T, 1, 1, D]

    # (W * var)^T has shape [B, T, n_h, D, D_v]; left-multiplying by W gives
    # W diag(var) W^T with shape [B, T, n_h, D_v, D_v]
    scaled_w_t = torch.transpose(w_heads * var_tokens, 3, 4)
    cov = torch.matmul(w_heads, scaled_w_t)

    torch.cuda.empty_cache()

    return cov

def forward_QKV_cov(attention_score, v_cov):
    """
    Given attention scores A and values V ~ N(mean(V), cov(V)), compute the
    covariance of the attention output E = A V.

    Each output token is a weighted sum of value vectors, and the value
    vectors of different tokens are combined as if independent, so
    cov(E_t) = sum_s A[t, s]^2 * cov(V_s), computed separately per head.

    Input:
        attention_score: [B, n_h, T, T] attention_score[b, h, t] holds token
            t's attention weights over all tokens
        v_cov: [B, T, n_h, D_v, D_v] covariance of value
    Output:
        QKV_cov: [B, n_h, T, D_v, D_v] covariance of output E
    """

    B, T, n_h, D_v, _ = v_cov.size()

    # Bring the head axis in front of the token axis *before* flattening.
    # A bare reshape from [B, T, n_h, ...] to [B, n_h, T, ...] would only
    # reinterpret memory and mix up tokens and heads.
    v_cov_per_head = v_cov.transpose(1, 2).reshape(B, n_h, T, D_v * D_v)
        # [B, T, n_h, D_v, D_v] -> [B, n_h, T, D_v, D_v] -> [B, n_h, T, D_v * D_v]

    QKV_cov = attention_score ** 2 @ v_cov_per_head
        # [B, n_h, T, T] @ [B, n_h, T, D_v * D_v] -> [B, n_h, T, D_v * D_v]
    QKV_cov = QKV_cov.reshape(B, n_h, T, D_v, D_v)
        # [B, n_h, T, D_v * D_v] -> [B, n_h, T, D_v, D_v]

    torch.cuda.empty_cache()

    return QKV_cov

def forward_fuse_multi_head_cov(QKV_cov, project_W):
    """
    Given the per-head covariance of the concatenated multi-head embedding
    and the output projection matrix W, compute the variance of each output
    dimension.

    Heads are treated as independent, so for output dimension d

        var(out_d) = sum_h  w_{d,h}^T  cov_h  w_{d,h}

    where w_{d,h} = project_W[d, h * D_v : (h + 1) * D_v] is the slice of
    row d that multiplies head h of the concatenated input.

    Input:
        QKV_cov: [B, n_h, T, D_v, D_v]
        project_W: [D, D]  D_out x D_in (torch.nn.Linear weight layout),
            with D_in = n_h * D_v ordered head by head

    Output:
        output_var [B, T, D]
    """

    B, n_h, T, D_v, _ = QKV_cov.size()
    D, _ = project_W.shape

    # Split the *input* axis of W into heads: W_heads[d, h, i] = W[d, h * D_v + i]
    W_heads = project_W.reshape(D, n_h, D_v)

    # Outer product w_{d,h} w_{d,h}^T for every (output dim, head) pair.
    project_W_outer = W_heads.unsqueeze(-1) * W_heads.unsqueeze(-2)  # [D, n_h, D_v, D_v]

    # Move the token axis in front of the head axis (a real permute, not a
    # reshape) so that (head, i, j) can be flattened into one contraction axis.
    cov_flat = QKV_cov.permute(0, 2, 1, 3, 4).reshape(B, T, n_h * D_v * D_v)

    # One matmul contracts over (head, i, j) for all output dims at once.
    output_var = cov_flat @ project_W_outer.reshape(D, n_h * D_v * D_v).T  # [B, T, D]

    del project_W_outer, cov_flat
    torch.cuda.empty_cache()

    return output_var

class LayerNorm_DBNN_Diag(nn.Module):
    """
    Wrap a LayerNorm layer so that it propagates a (mean, variance) pair.

    The mean goes through the wrapped layer unchanged; the variance is
    rescaled by forward_layer_norm_diag using the wrapped layer's weight.
    """

    def __init__(self, LayerNorm):
        super().__init__()

        self.LayerNorm = LayerNorm

    def forward(self, x_mean, x_var):
        """
        Input
            x_mean: [B, T, D] input mean
            x_var: [B, T, D] input variance
        Output
            out_mean, out_var: both [B, T, D]
        """

        # Use the eps the wrapped layer was configured with so the variance
        # scaling matches the normalisation applied to the mean. Layers that
        # do not expose `eps` fall back to the previously hard-coded 1e-5.
        ln_eps = getattr(self.LayerNorm, "eps", 1e-5)

        with torch.no_grad():

            out_mean = self.LayerNorm.forward(x_mean)
            out_var = forward_layer_norm_diag(x_mean, x_var, self.LayerNorm.weight, ln_eps)

        return out_mean, out_var


class Classifier_DBNN_Diag(nn.Module):
    """
    Bayesian classification head: a linear layer whose weight and bias both
    carry diagonal posterior variance.

    The flat variance vectors passed in are reshaped to match the wrapped
    classifier's weight and bias.
    """

    def __init__(self, classifier, w_var, b_var):
        super().__init__()

        weight, bias = classifier.weight, classifier.bias

        self.weight = weight
        self.bias = bias
        self.w_var = w_var.reshape(weight.shape)
        self.b_var = b_var.reshape(bias.shape)

    @torch.no_grad()
    def forward(self, x_mean, x_var):
        """Return (mean, variance) of the logits for input moments (x_mean, x_var)."""

        return forward_aW_diag(
            x_mean,
            x_var,
            self.weight.data,
            self.bias.data,
            self.w_var,
            self.b_var,
        )

class MLP_DBNN_Diag(nn.Module):
    """
    Propagate (mean, variance) through a two-layer MLP: c_fc -> gelu -> c_proj.

    When `determinstic` is False, the weights of both linear layers are
    treated as Gaussians with the supplied diagonal variances; the biases
    are always deterministic.
    """

    def __init__(self, MLP, determinstic = True, w_fc_var = None, w_proj_var = None):
        super().__init__()

        self.MLP = MLP
        self.determinstic = determinstic

        if determinstic:
            # no weight variances are stored in the deterministic case
            return

        self.w_fc_var = w_fc_var.reshape(self.MLP.c_fc.weight.shape)
        self.w_proj_var = w_proj_var.reshape(self.MLP.c_proj.weight.shape)

    def forward(self, x_mean, x_var):
        """Return the output (mean, variance) for input moments (x_mean, x_var)."""

        fc, proj = self.MLP.c_fc, self.MLP.c_proj

        # first fc layer
        with torch.no_grad():
            if not self.determinstic:
                mean, var = forward_linear_diag_Bayesian_weight(
                    x_mean, x_var, fc.weight.data, self.w_fc_var, fc.bias.data
                )
            else:
                mean, var = forward_linear_diag_determinstic_weight(
                    x_mean, x_var, fc.weight.data, fc.bias.data
                )

        # activation function (needs autograd, so it runs outside no_grad)
        mean, var = forward_activation_diag(self.MLP.gelu, mean, var)

        # second fc layer
        with torch.no_grad():
            if not self.determinstic:
                mean, var = forward_linear_diag_Bayesian_weight(
                    mean, var, proj.weight.data, self.w_proj_var, proj.bias.data
                )
            else:
                mean, var = forward_linear_diag_determinstic_weight(
                    mean, var, proj.weight.data, proj.bias.data
                )

        return mean, var

class Attention_DBNN_Diag(nn.Module):
    """
    Propagate (mean, variance) through a multi-head attention layer.

    The mean and the attention scores come from the wrapped attention
    module; the variance is pushed through the value projection, the
    attention weighting and the output projection. When `determinstic` is
    False, the value matrix carries the diagonal variance W_v_var.
    """

    def __init__(self, Attention, determinstic = True, W_v_var = None):
        super().__init__()

        self.Attention = Attention
        self.determinstic = determinstic

        if not self.determinstic:
            self.W_v_var = W_v_var # [D * D]

    def forward(self, x_mean, x_var):
        """Return the output (mean, variance) for input moments (x_mean, x_var)."""

        attn = self.Attention

        with torch.no_grad():

            # second positional argument True: the wrapped module is expected
            # to also return its attention scores
            output_mean, attention_score = attn.forward(x_mean, True)

            n_h = attn.n_head
            B, T, D = x_mean.size()
            D_v = D // n_h

            W_v = attn.c_attn_v.weight.data
            project_W = attn.c_proj.weight.data

            # covariance of the value vectors, per token and head
            if not self.determinstic:
                v_cov = forward_value_cov_Bayesian_W(
                    W_v, self.W_v_var.reshape(D, D), x_mean, x_var, n_h, D_v
                )
            else:
                v_cov = forward_value_cov_determinstic_W(W_v, x_var, n_h, D_v)

            # attention weighting, then output projection back to a diagonal variance
            QKV_cov = forward_QKV_cov(attention_score, v_cov)
            output_var = forward_fuse_multi_head_cov(QKV_cov, project_W)

            # drop the large covariance tensors before releasing cached memory
            del v_cov
            del QKV_cov
            torch.cuda.empty_cache()

            return output_mean, output_var
        
class Transformer_Block_DBNN_Diag(nn.Module):
    """
    Pre-LayerNorm transformer block that propagates (mean, variance):

        x = x + attn(ln_1(x))
        x = x + mlp(ln_2(x))

    At each residual connection the means are added and the variances are
    added.
    """

    def __init__(self, MLP, Attention, LN_1, LN_2, MLP_determinstic, Attn_determinstic, w_fc_var = None, w_proj_var = None, W_v_var = None):
        super().__init__()

        self.ln_1 = LayerNorm_DBNN_Diag(LN_1)
        self.ln_2 = LayerNorm_DBNN_Diag(LN_2)
        self.attn = Attention_DBNN_Diag(Attention, Attn_determinstic, W_v_var)
        self.mlp = MLP_DBNN_Diag(MLP, MLP_determinstic, w_fc_var, w_proj_var)

    def forward(self, x_mean, x_var):
        """Return the block output (mean, variance) for input moments (x_mean, x_var)."""

        # attention sub-layer with residual connection
        normed_mean, normed_var = self.ln_1(x_mean, x_var)
        attn_mean, attn_var = self.attn(normed_mean, normed_var)
        mid_mean = attn_mean + x_mean
        mid_var = attn_var + x_var

        # MLP sub-layer with residual connection
        normed_mean, normed_var = self.ln_2(mid_mean, mid_var)
        mlp_mean, mlp_var = self.mlp(normed_mean, normed_var)
        out_mean = mlp_mean + mid_mean
        out_var = mlp_var + mid_var

        return out_mean, out_var

class ViT_DBNN_Diag(nn.Module):
    """
    ViT whose last blocks and classifier propagate a diagonal Gaussian
    (mean, variance) instead of a point estimate.

    The first `num_det_blocks` blocks of the wrapped ViT are reused as-is;
    every later block is wrapped in Transformer_Block_DBNN_Diag, and the
    classifier in Classifier_DBNN_Diag.

    `posterior_variance` is a flat vector consumed in this order:
        for each wrapped block:
            c_fc weight var, c_proj weight var   (only if not MLP_determinstic)
            value-matrix weight var              (only if not Attn_determinstic)
        classifier weight var
        classifier bias var (everything that remains)
    """

    def __init__(self, ViT, posterior_variance, scale_factor, MLP_determinstic, Attn_determinstic, alpha = 1., num_det_blocks = 10):
        super().__init__()

        self.transformer = nn.ModuleDict(dict(
            pte = ViT.transformer.pte,
            h = nn.ModuleList(),
            ln_f = LayerNorm_DBNN_Diag(ViT.transformer.ln_f)
        ))

        # learnable scalar that rescales the predictive variance (see fit_scale_factor)
        self.scale_factor = nn.Parameter(torch.Tensor([scale_factor]).to(ViT.device))
        self.alpha = alpha

        num_param_c_fc = ViT.transformer.h[0].mlp.c_fc.weight.numel()
        num_param_c_proj = ViT.transformer.h[0].mlp.c_proj.weight.numel()
        # NOTE(review): the size of the value matrix is read from attn.c_proj,
        # which presumably has the same [D, D] shape as the value projection — confirm.
        num_param_value_matrix = ViT.transformer.h[0].attn.c_proj.weight.numel()

        index = 0
        for block_index, vit_block in enumerate(ViT.transformer.h):

            if block_index < num_det_blocks:
                self.transformer.h.append(vit_block)
                continue

            # Gather whichever variances this block needs, then build exactly
            # ONE wrapped block per ViT block, regardless of which of the
            # MLP / attention parts are Bayesian.
            w_fc_var = w_proj_var = w_v_var = None

            if not MLP_determinstic:
                w_fc_var = posterior_variance[index: index + num_param_c_fc]
                index += num_param_c_fc
                w_proj_var = posterior_variance[index: index + num_param_c_proj]
                index += num_param_c_proj

            if not Attn_determinstic:
                w_v_var = posterior_variance[index: index + num_param_value_matrix]
                index += num_param_value_matrix

            self.transformer.h.append(
                Transformer_Block_DBNN_Diag(vit_block.mlp,
                                            vit_block.attn,
                                            vit_block.ln_1,
                                            vit_block.ln_2,
                                            MLP_determinstic,
                                            Attn_determinstic,
                                            w_fc_var,
                                            w_proj_var,
                                            w_v_var))

        num_param_classifier_weight = ViT.classifier.weight.numel()
        self.classifier = Classifier_DBNN_Diag(ViT.classifier, posterior_variance[index: index + num_param_classifier_weight], posterior_variance[index + num_param_classifier_weight:])

    @staticmethod
    def _predictive_probs(logit_mean, logit_var):
        """Class probabilities softmax(kappa * mean) with kappa = 1 / sqrt(1 + pi/8 * var)."""
        kappa = 1 / torch.sqrt(1. + np.pi / 8 * logit_var)
        return torch.softmax(kappa * logit_mean, dim=-1)

    def forward(self, pixel_values, interpolate_pos_encoding = None):
        """
        Return the predictive class probabilities for `pixel_values`.

        The logit mean and (scaled) variance from forward_latent are combined
        into softmax(kappa * mean), kappa = 1 / sqrt(1 + pi/8 * var).
        """
        x_mean, x_var = self.forward_latent(pixel_values, interpolate_pos_encoding)

        return self._predictive_probs(x_mean, x_var)

    def forward_latent(self, pixel_values, interpolate_pos_encoding = None):
        """
        Return (logit mean, logit variance) for `pixel_values`.

        The variance is already divided by alpha * scale_factor. Only the
        first token (index 0 along the sequence axis) is fed to the classifier.
        """
        x_mean = self.transformer.pte(
            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
        )

        # the embedding is a point estimate: start with zero variance
        x_var = torch.zeros_like(x_mean)

        for block in self.transformer.h:

            if isinstance(block, Transformer_Block_DBNN_Diag):
                x_mean, x_var = block(x_mean, x_var)
            else:
                # plain ViT block: only the mean is transformed
                x_mean = block(x_mean)

        x_mean, x_var = self.transformer.ln_f(x_mean, x_var)

        x_mean, x_var = self.classifier(x_mean[:, 0, :], x_var[:, 0, :])
        x_var = x_var / (self.alpha * self.scale_factor)

        return x_mean, x_var

    def fit_scale_factor(self, scale_fit_dataloader, n_epoches, lr, verbose = False):
        """
        Fit `scale_factor` with Adam by minimising the negative log predictive
        density on pre-computed logit moments.

        Input
            scale_fit_dataloader: yields pairs (x_mean, x_var_label). x_mean is
                [N, num_class]; x_var_label is split along dim 1 into chunks of
                width num_class, giving the unscaled logit variance and the
                label (the target class is taken as label.argmax(1)).
            n_epoches: number of passes over the dataloader
            lr: Adam learning rate
            verbose: print the mean NLPD after every epoch

        Output
            list with the mean training NLPD of each epoch
        """

        # Work on the device the parameter actually lives on, rather than
        # guessing from CUDA availability.
        device = self.scale_factor.device
        optimizer = torch.optim.Adam([self.scale_factor], lr)

        print("fitting scale...")

        total_train_nlpd = []

        for epoch in range(n_epoches):
            cur_nlpd = []
            for data_pair in tqdm(scale_fit_dataloader):
                x_mean, x_var_label = data_pair
                num_class = x_mean.shape[1]
                x_mean = x_mean.to(device)
                x_var, label = x_var_label.split(num_class, dim=1)
                x_var = x_var.to(device)
                label = label.to(device)

                optimizer.zero_grad()
                # make prediction
                x_var = x_var / (self.alpha * self.scale_factor)
                posterior_predict_mean = self._predictive_probs(x_mean, x_var)

                # construct log posterior predictive distribution
                posterior_predictive_dist = Categorical(posterior_predict_mean)
                # calculate nlpd and update
                nlpd = -posterior_predictive_dist.log_prob(label.argmax(1)).mean()
                nlpd.backward()
                optimizer.step()
                # log nlpd
                cur_nlpd.append(nlpd.item())

            epoch_nlpd = np.mean(cur_nlpd)
            total_train_nlpd.append(epoch_nlpd)
            if verbose:
                print(f"Epoch {epoch}: {epoch_nlpd}")

        return total_train_nlpd