import torch.nn as nn
import torch.nn.functional as F
import torch
import math
import numpy as np
import copy
from torch.distributions import Categorical
from tqdm import tqdm

from model.dbnn_full import DBNN_Activation
from model.dbnn_base import DBNN_Base

def kron_block_index(D_in, D_out, k, l):
    """
    Locate the part of a Kronecker product that holds cov(W_k, W_l).

    Row k of W is taken to occupy the flat index range [k * D_in, (k + 1) * D_in)
    and the Kronecker product is taken to be tiled with D_out x D_out blocks.

    Returns a 6-tuple:
        (block_row_start, block_row_end, block_col_start, block_col_end,
         row_offset, col_offset)
    where the block_* values are (end-exclusive) indices into the outer Kronecker
    factor, and the *_offset values are the positions, inside the expanded
    sub-product, at which the requested covariance begins.
    """
    # Flat index ranges covered by row k (rows of the covariance) and row l (columns).
    row_begin, row_stop = k * D_in, (k + 1) * D_in
    col_begin, col_stop = l * D_in, (l + 1) * D_in

    # First block touched: round the starting position down to a block boundary.
    block_row_start = math.floor(row_begin / D_out)
    block_col_start = math.floor(col_begin / D_out)

    # One past the last block touched: round the end position up.
    block_row_end = math.ceil(row_stop / D_out)
    block_col_end = math.ceil(col_stop / D_out)

    # Offset of the starting position inside its first block.
    row_offset = row_begin % D_out
    col_offset = col_begin % D_out

    return (
        block_row_start,
        block_row_end,
        block_col_start,
        block_col_end,
        row_offset,
        col_offset,
    )

def kron(A, B):
    """
    Kronecker product of two 2-D tensors, computed via broadcasting.

    The result has shape [A.shape[0] * B.shape[0], A.shape[1] * B.shape[1]] and
    entry [i * B.shape[0] + j, k * B.shape[1] + l] equals A[i, k] * B[j, l].
    """
    # See https://github.com/pytorch/pytorch/issues/74442
    out_rows = A.shape[0] * B.shape[0]
    out_cols = A.shape[1] * B.shape[1]

    a_expanded = A.unsqueeze(1).unsqueeze(3)  # [rA, 1, cA, 1]
    b_expanded = B.unsqueeze(0).unsqueeze(2)  # [1, rB, 1, cB]

    # Broadcasted product is [rA, rB, cA, cB]; collapsing pairs of axes yields the Kronecker layout.
    return (a_expanded * b_expanded).reshape(out_rows, out_cols)

def get_covariance(D_in, D_out, k, l, A, B):
    """
    Given the posterior covariance stored in Kronecker-factor form, return cov(w_k, w_l).

    Only the slice of A needed to cover the requested entries is expanded with B,
    so the full kron(A, B) matrix is never materialised. The returned tensor is the
    [D_in, D_in] window equal to
    kron(A, B)[k * D_in:(k + 1) * D_in, l * D_in:(l + 1) * D_in].

    NOTE(review): this reads kron(A, B) as if W were flattened row by row
    (row k at flat indices [k * D_in, (k + 1) * D_in)). Confirm that this ordering
    matches the one used by the code that produced the Kronecker factors.
    """
    (
        row_lo,
        row_hi,
        col_lo,
        col_hi,
        row_offset,
        col_offset,
    ) = kron_block_index(D_in, D_out, k, l)

    # Expand only the blocks of the product that overlap the requested window.
    partial_product = kron(A[row_lo:row_hi, col_lo:col_hi], B)

    row_window = slice(row_offset, row_offset + D_in)
    col_window = slice(col_offset, col_offset + D_in)
    return partial_product[row_window, col_window]

def forward_aW_kron(a_mean, a_cov, weight, bias, w_cov_A, w_cov_B, b_cov):
    """
    Compute mean and covariance of h = a @ W^T + b when the posterior over W has a
    Kronecker-factored covariance.

    ----- Input -----
    a_mean: [N, D_in] mean(a)
    a_cov: [N, D_in, D_in] a_cov[i][j] = cov(a_i, a_j)
    weight: [D_out, D_in] W
    bias: [D_out, ] b
    w_cov_A: [D_in, D_in]
    w_cov_B: [D_out, D_out] we have cov(W) = w_cov_A kron_product w_cov_B
    b_cov: [D_out, D_out] b_cov[k][l]: cov(b_k, b_l)
    ----- Output -----
    h_mean: [N, D_out]
    h_cov: [N, D_out, D_out] h_cov[k][l] = cov(h_k, h_l)
    """
    D_out, D_in = weight.shape
    N = a_mean.shape[0]

    # mean(h)
    h_mean = F.linear(a_mean, weight, bias)

    # Second moment of the input: mean(a) mean(a)^T + cov(a).
    #   [N, D_in] -> [N, D_in, 1]
    #   [N, D_in] -> [N, 1, D_in]
    #   [N, D_in, 1] @ [N, 1, D_in] -> [N, D_in, D_in]
    A_matrix = torch.bmm(a_mean.unsqueeze(2), a_mean.unsqueeze(1))
    A_matrix += a_cov

    # Allocate the output in the same dtype as the mean so that non-default
    # precisions (e.g. float64) are not silently downcast to float32 when the
    # entries are written into this buffer.
    h_cov = torch.zeros((N, D_out, D_out), device=weight.device, dtype=h_mean.dtype)

    # Fills h_cov in place.
    cal_cov_implicit(h_cov, A_matrix, a_cov, D_in, D_out, w_cov_A, w_cov_B, b_cov, weight)

    return h_mean, h_cov

def cal_cov_implicit(h_cov, A_matrix, a_cov, D_in, D_out, w_cov_A, w_cov_B, b_cov, weight):
    """
    Fill ``h_cov`` in place with the output covariance, one (k, l) entry at a time.

    For every pair k <= l the entry is
        sum_ij( A_matrix[:, i, j] * cov(w_k, w_l)[i, j] + a_cov[:, i, j] * W[k, i] * W[l, j] )
    and is mirrored to (l, k). ``b_cov`` is added to the whole tensor at the end.
    Nothing is returned; the result lives in ``h_cov``.
    """
    for k in range(D_out):
        # Column vector for row k, reused for every l in the inner loop.
        w_k_column = weight[k].reshape(-1, 1)

        for l in range(k, D_out):
            weight_cov_kl = get_covariance(D_in, D_out, k, l, w_cov_A, w_cov_B)
            weight_outer_kl = w_k_column @ weight[l].reshape(1, -1)

            # Contribution of weight uncertainty and of input uncertainty, respectively.
            from_weight_cov = A_matrix * weight_cov_kl
            from_input_cov = a_cov * weight_outer_kl

            h_cov[:, k, l] = torch.sum(from_weight_cov + from_input_cov, dim=(2, 1))
            # Only the upper triangle is computed; copy it to the lower triangle.
            h_cov[:, l, k] = h_cov[:, k, l]

    h_cov += b_cov

class DBNN_Linear_Kron(nn.Module):
    """
    Linear layer that propagates a mean and a covariance instead of a point value,
    using a Kronecker-factored covariance for the weights.
    """

    def __init__(self, org_linear, w_cov_A, w_cov_B, b_cov):
        """
        org_linear: layer providing ``weight`` and ``bias``
        w_cov_A, w_cov_B: Kronecker factors of the weight covariance
        b_cov: covariance of the bias
        """
        super().__init__()

        # Kept as plain tensor attributes (taken via ``.data``), i.e. they are not
        # registered as Parameters of this module.
        self.weight = org_linear.weight.data
        self.bias = org_linear.bias.data
        self.w_cov_A = w_cov_A
        self.w_cov_B = w_cov_B
        self.b_cov = b_cov

    def forward(self, a_mean, a_cov=None):
        """
        a_mean: [N, D_in] input mean
        a_cov: [N, D_in, D_in] input covariance, or None for a deterministic input

        Returns (h_mean, h_cov) as produced by ``forward_aW_kron``.
        """
        # Identity check: comparing a tensor to None with ``==`` relies on
        # Tensor.__eq__ falling back gracefully, which is fragile and non-idiomatic.
        if a_cov is None:
            n_samples, n_features = a_mean.shape[0], a_mean.shape[1]
            # A deterministic input has zero covariance; follow a_mean's dtype and device.
            a_cov = torch.zeros(
                (n_samples, n_features, n_features),
                device=a_mean.device,
                dtype=a_mean.dtype,
            )

        h_mean, h_cov = forward_aW_kron(
            a_mean, a_cov, self.weight, self.bias, self.w_cov_A, self.w_cov_B, self.b_cov
        )

        return h_mean, h_cov

    def reconstruct_weight_covariance(self):
        """Return the dense weight covariance kron(w_cov_A, w_cov_B)."""
        return kron(self.w_cov_A, self.w_cov_B)

class DBNN_MLP_Kron(DBNN_Base):
    """
    Wraps a network so that its linear layers and activations propagate a mean and
    a covariance, with Kronecker-factored weight covariances for the linear layers.
    """

    def __init__(self, org_model, hessian_eigenvector, hessian_eigenvalue, hessian_kfac, likelihood, scale_init, prior_precision, sigma_noise = None):
        """
        Only take in the model that corresponds to the latent function part (in the
        regression case it is the whole network, in the classification case remove
        the softmax layer).

        likelihood and scale_init are forwarded to the base class; the remaining
        arguments are used by ``convert_model``. ``sigma_noise`` is only read by
        ``forward`` in the 'regression' case.
        """
        super().__init__(likelihood, scale_init)
        self.sigma_noise = sigma_noise
        self.convert_model(org_model, hessian_eigenvector, hessian_eigenvalue, hessian_kfac, prior_precision)

    def forward(self, data):
        """
        Run the latent forward pass and map it through the likelihood.

        'classification': returns softmax(kappa * mean) with
            kappa = 1 / sqrt(1 + pi / 8 * diag(cov)).
        'regression': returns (mean, diag(cov) + sigma_noise ** 2).
        Any other likelihood value falls through and returns None.
        """
        # forward_latent is not defined in this class; presumably provided by DBNN_Base.
        out_mean, out_cov = self.forward_latent(data)

        if self.likelihood == 'classification':
            kappa = 1 / torch.sqrt(1. + np.pi / 8 * out_cov.diagonal(dim1=1, dim2=2))
            return torch.softmax(kappa * out_mean, dim=-1)

        if self.likelihood == 'regression':
            return out_mean, torch.diagonal(out_cov, dim1=1, dim2=2) + self.sigma_noise ** 2

    def convert_model(self, org_model, hessian_eigenvector, hessian_eigenvalue, hessian_kfac, prior_precision):
        """
        Build ``self.model`` from a deep copy of ``org_model`` in which every
        ``nn.Linear`` becomes a ``DBNN_Linear_Kron`` and every module whose class name
        is listed in ``torch.nn.modules.activation.__all__`` becomes a
        ``DBNN_Activation``.

        The hessian_* sequences are indexed with a counter that advances by 2 per
        linear layer. For each linear layer the entry at that index must unpack as
        (factor for the output side, factor for the input side).
        """
        p_model = copy.deepcopy(org_model)
        prior_precision_sqrt = np.sqrt(prior_precision)
        activation_names = torch.nn.modules.activation.__all__

        # Snapshot the module tree before editing it, so replacements never
        # interfere with the traversal. Keys are dotted paths ('' is the root).
        modules_by_name = dict(p_model.named_modules())

        loc = 0
        for name, layer in modules_by_name.items():
            if isinstance(layer, nn.Linear):
                Q_B, Q_A = hessian_eigenvector[loc]
                l_B = copy.deepcopy(hessian_eigenvalue[loc][0])
                l_A = copy.deepcopy(hessian_eigenvalue[loc][1])

                # Only the first factor is needed here (for the bias covariance).
                B, _ = hessian_kfac[loc]

                # The prior precision is split evenly (as a square root) across
                # the two Kronecker factors.
                l_A += prior_precision_sqrt
                l_B += prior_precision_sqrt

                w_cov_A = torch.linalg.inv(Q_A @ torch.diag(l_A) @ Q_A.T)
                w_cov_B = torch.linalg.inv(Q_B @ torch.diag(l_B) @ Q_B.T)

                b_cov = torch.linalg.inv(B + prior_precision * torch.eye(l_B.shape[0], device = l_B.device))

                new_layer = DBNN_Linear_Kron(layer, w_cov_A, w_cov_B, b_cov)

                loc += 2
            elif type(layer).__name__ in activation_names:
                new_layer = DBNN_Activation(layer)
            else:
                continue

            if name == "":
                # The root module itself is being replaced.
                p_model = new_layer
                continue

            # ``name`` may be a dotted path such as "block.0". setattr with the full
            # dotted name on the root would register a new top-level child instead
            # of replacing the nested one, so set it on the actual parent module.
            parent_name, _, child_name = name.rpartition(".")
            setattr(modules_by_name[parent_name], child_name, new_layer)

        self.model = p_model

    def fit_scale_factor(self, scale_fit_dataloader, n_epoches, lr, verbose = False):
        """
        Optimise ``self.scale_factor`` with Adam by minimising the negative log
        predictive density on precomputed latent means and variances.

        Each batch from ``scale_fit_dataloader`` is a pair (x_mean, x_var_label):
        x_mean has one column per class, and x_var_label is split along dim 1 into
        chunks of that width, giving the variances followed by the labels. The
        class index is taken as ``label.argmax(1)``.

        Returns the list of per-epoch mean NLPD values.
        """
        optimizer = torch.optim.Adam([self.scale_factor], lr)

        print("fitting scale...")

        total_train_nlpd = []

        for epoch in range(n_epoches):
            cur_nlpd = []
            for data_pair in tqdm(scale_fit_dataloader):
                x_mean, x_var_label = data_pair
                num_class = x_mean.shape[1]
                # NOTE(review): relies on the wrapped model exposing a ``device`` attribute.
                x_mean = x_mean.to(self.model.device)
                x_var, label = x_var_label.split(num_class, dim=1)
                x_var = x_var.to(self.model.device)
                label = label.to(self.model.device)

                optimizer.zero_grad()
                # make prediction: rescale the variance, then apply the same
                # kappa-scaled softmax as in forward()
                x_var = x_var / (self.alpha * self.scale_factor)
                kappa = 1 / torch.sqrt(1. + np.pi / 8 * x_var)
                posterior_predict_mean = torch.softmax(kappa * x_mean, dim=-1)

                # construct log posterior predictive distribution
                posterior_predictive_dist = Categorical(posterior_predict_mean)
                # calculate nlpd and update
                nlpd = -posterior_predictive_dist.log_prob(label.argmax(1)).mean()
                nlpd.backward()
                optimizer.step()
                # log nlpd
                cur_nlpd.append(nlpd.item())

            epoch_nlpd = np.mean(cur_nlpd)
            total_train_nlpd.append(epoch_nlpd)
            if verbose:
                print(f"Epoch {epoch}: {epoch_nlpd}")

        return total_train_nlpd