from models_utility.function_gp import lt_log_determinant
from torch import triangular_solve
from sklearn.decomposition import PCA
import numpy as np
import torch
from torch import nn
from torch.distributions.multivariate_normal import MultivariateNormal as MVN
from torch.distributions import kl_divergence
from torch.nn import functional as F
import gpytorch

# Run everything in double precision: the Cholesky factorisations used by the
# model are sensitive to round-off. `torch.set_default_tensor_type` is
# deprecated; `torch.set_default_dtype` is the supported equivalent.
torch.set_default_dtype(torch.float64)

# Small jitter added to the feature-space Gram diagonal for numerical stability.
zitter = 1e-8


class NG_GPLVM(nn.Module):
    """GPLVM whose kernel is approximated by sampled spectral-mixture features.

    Latent inputs ``x`` (N x Q) get a factorised Gaussian variational posterior
    (``mu_x``, ``softplus(log_sigma_x)``). For each of the ``num_m`` mixture
    components, ``num_samplept`` spectral points of dimension 2Q are drawn by
    reparameterisation and turned into cos/sin features; the kernel matrix is
    approximated as ``Phi @ Phi.T``.
    """

    def __init__(self, num_batch, num_sample_pt, param_dict, Y, device=None, ifPCA=True):
        """Create the model parameters.

        :param num_batch: stored as ``self.num_batch`` (not used in this class).
        :param num_sample_pt: number of spectral points drawn per component.
        :param param_dict: mapping with keys ``latent_dim``, ``N``, ``num_m``,
            ``noise_err`` and ``lr_hyp``.
        :param Y: observations; fed to PCA to initialise ``mu_x`` when
            ``ifPCA`` is true.
        :param device: torch device used for every parameter.
        :param ifPCA: initialise ``mu_x`` with PCA of ``Y`` (otherwise randomly).
        """
        super().__init__()
        self.device = device
        self.name = None
        self.num_batch = num_batch
        self.num_samplept = num_sample_pt  # L/2, the number of spectral points
        self.latent_dim = param_dict['latent_dim']  # Q
        self.N = param_dict['N']                    # number of latent points (rows of mu_x)
        self.num_m = param_dict['num_m']            # m, number of mixture components
        self.noise = param_dict['noise_err']
        self.lr_hyp = param_dict['lr_hyp']
        self.Y = Y
        self.total_num_sample = self.num_samplept * self.num_m  # m * L/2
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()

        # shape: m * 1
        self.log_weight = nn.Parameter(torch.randn(self.num_m, 1, device=self.device), requires_grad=True)

        # With a single component (SE kernel) the spectral means stay fixed at
        # zero; otherwise they are learned.
        learn_mu = self.num_m != 1
        self.mu1 = nn.Parameter(torch.zeros(self.num_m, self.latent_dim, device=self.device),
                                requires_grad=learn_mu)  # shape: m * Q
        self.mu2 = nn.Parameter(torch.zeros(self.num_m, self.latent_dim, device=self.device),
                                requires_grad=learn_mu)  # shape: m * Q

        # shape: m * 2Q * 2Q
        self.log_std = nn.Parameter(
            torch.randn(self.num_m, 2 * self.latent_dim, 2 * self.latent_dim, device=self.device),
            requires_grad=True)

        if ifPCA:
            pca = PCA(n_components=self.latent_dim)
            X = pca.fit_transform(self.Y)
        else:
            X = torch.randn(self.N, self.latent_dim, device=self.device)
        # as_tensor avoids the copy-construct warning when X is already a tensor
        # and keeps mu_x in the same dtype as the other parameters.
        X = torch.as_tensor(X, dtype=torch.get_default_dtype(), device=self.device)
        self.mu_x = nn.Parameter(X, requires_grad=True)    # shape: N * Q
        self.log_sigma_x = nn.Parameter(torch.zeros(self.N, self.latent_dim, device=self.device),
                                        requires_grad=True)

    def _sm_features(self, inputs, spectral_pt, weight):
        """Return the cos/sin feature block of one component for ``inputs``.

        :param inputs: latent inputs, shape n * Q.
        :param spectral_pt: sampled spectral points, shape L/2 * 2Q; the first
            Q columns and the last Q columns are the two spectral halves.
        :param weight: positive weight of the component.
        :return: tensor of shape n * L.
        """
        proj1 = (2 * np.pi) * inputs.matmul(spectral_pt[:, :self.latent_dim].t())   # n * L/2
        proj2 = (2 * np.pi) * inputs.matmul(spectral_pt[:, -self.latent_dim:].t())  # n * L/2
        scale = (weight / (4 * self.num_samplept)).sqrt()
        # Both halves contribute to the cosine AND the sine features.
        return scale * torch.cat([proj1.cos() + proj2.cos(), proj1.sin() + proj2.sin()], 1)

    def _compute_sm_basis(self, x_star=None, f_eval=False):
        """Sample spectral points and build the feature matrix.

        :param x_star: optional prediction inputs (N_star * Q). When given, the
            same sampled spectral points are also evaluated at ``x_star``.
        :param f_eval: if true use the posterior mean ``mu_x`` as input,
            otherwise draw ``x`` from q(x) by reparameterisation.
        :return: ``Phi`` (N * (m*L)) when ``x_star`` is None, else the pair
            ``(Phi, Phi_star)`` with ``Phi_star`` of shape N_star * (m*L).
        """
        if f_eval:  # used to evaluate the latent function 'f'
            x = self.mu_x
        else:
            # sample from q(x); the noise is redrawn on every call
            std = F.softplus(self.log_sigma_x)   # shape: N * Q
            eps = torch.randn_like(std)
            x = self.mu_x + eps * std

        if x_star is not None:
            x_star = torch.as_tensor(x_star, dtype=x.dtype, device=x.device)

        SM_weight = F.softplus(self.log_weight)  # m * 1
        SM_std = F.softplus(self.log_std)        # m * 2Q * 2Q

        multiple_Phi = []
        multiple_Phi_star = []
        for i_th in range(self.num_m):  # TODO: check if it can be improved without using for
            SM_eps = torch.randn(self.num_samplept, 2 * self.latent_dim, device=self.device)  # L/2 * 2Q
            cov = SM_std[i_th].t().matmul(SM_std[i_th])
            chol = torch.linalg.cholesky(cov)  # 2Q * 2Q
            mean = torch.cat((self.mu1[i_th], self.mu2[i_th]), dim=0)
            sampled_spectral_pt = mean + SM_eps.matmul(chol)  # L/2 * 2Q

            multiple_Phi.append(self._sm_features(x, sampled_spectral_pt, SM_weight[i_th]))
            if x_star is not None:
                multiple_Phi_star.append(
                    self._sm_features(x_star, sampled_spectral_pt, SM_weight[i_th]))

        if x_star is None:
            return torch.cat(multiple_Phi, 1)  # N * (m * L)
        return torch.cat(multiple_Phi, 1), torch.cat(multiple_Phi_star, 1)

    def _compute_gram_approximate(self, Phi):  # shape:  (m*L) x (m*L)
        """Return ``Phi^T Phi + (noise + zitter) * I`` in feature space."""
        num_feat = Phi.shape[1]
        ridge = self.likelihood.noise + zitter
        return Phi.t() @ Phi + ridge * torch.eye(num_feat, dtype=Phi.dtype, device=Phi.device)

    def _compute_gram_approximate_2(self, Phi):  # shape:  N x N
        """Return the approximate kernel matrix ``Phi Phi^T``."""
        return Phi @ Phi.T

    def _kl_div_qp(self):
        """Return KL(q(x) || N(0, I)) summed and divided by ``N * latent_dim``."""
        # shape: N x Q
        q_dist = torch.distributions.Normal(loc=self.mu_x, scale=F.softplus(self.log_sigma_x))
        p_dist = torch.distributions.Normal(loc=torch.zeros_like(q_dist.loc),
                                            scale=torch.ones_like(q_dist.loc))

        return kl_divergence(q_dist, p_dist).sum().div(self.N * self.latent_dim)

    def compute_loss(self, batch_y, kl_option):
        """Return the training loss for one batch.

        :param batch_y: observations, shape N * obs_dim (array-like or tensor).
        :param kl_option: if true, add the scaled KL term of q(x).
        :return: approximate negative log marginal likelihood, plus the KL
            regulariser when ``kl_option`` is set.
        """
        batch_y = torch.as_tensor(batch_y, dtype=torch.double, device=self.device)
        obs_num, obs_dim = batch_y.shape[0], batch_y.shape[1]
        Phi = self._compute_sm_basis()

        # negative log-marginal likelihood
        if Phi.shape[0] > Phi.shape[1]:  # if N > (m*L): work in feature space
            Approximate_gram = self._compute_gram_approximate(Phi)  # (m*L) x (m*L)
            L = torch.linalg.cholesky(Approximate_gram)
            Lt_inv_Phi_y = torch.linalg.solve_triangular(L, Phi.t().matmul(batch_y), upper=False)

            # todo: need to double-check this part
            # NOTE(review): the `2 * noise.sqrt()` terms look like they were
            # meant to be `log(noise)` (i.e. 2 * log(sigma)), and this branch is
            # not normalised like the one below -- left unchanged, please confirm.
            neg_log_likelihood = (0.5 / self.likelihood.noise) * (batch_y.pow(2).sum() - Lt_inv_Phi_y.pow(2).sum())
            neg_log_likelihood += lt_log_determinant(L)
            neg_log_likelihood += (-self.total_num_sample) * 2 * self.likelihood.noise.sqrt()
            neg_log_likelihood += 0.5 * obs_num * (np.log(2 * np.pi) + 2 * self.likelihood.noise.sqrt())

        else:
            k_matrix = self._compute_gram_approximate_2(Phi=Phi)  # shape: N x N
            eye = torch.eye(Phi.shape[0], dtype=Phi.dtype, device=Phi.device)
            C_matrix = k_matrix + self.likelihood.noise * eye
            L = torch.linalg.cholesky(C_matrix)  # shape: N x N
            L_inv_y = torch.linalg.solve_triangular(L, batch_y, upper=False)
            constant_term = 0.5 * obs_num * np.log(2 * np.pi) * obs_dim
            # 0.5 * log|C| = sum_i log(L_ii): take the log BEFORE summing.
            log_det_term = torch.diagonal(L, dim1=-2, dim2=-1).log().sum() * obs_dim
            yy = 0.5 * L_inv_y.pow(2).sum()
            neg_log_likelihood = (constant_term + log_det_term + yy).div(obs_dim * obs_num)

        if kl_option:
            kl_x = self._kl_div_qp().div(self.N * 50)
            return neg_log_likelihood + kl_x
        return neg_log_likelihood

    def f_eval(self, batch_y, x_star=None):
        """Evaluate the latent mapping function at ``x_star``.

        :param batch_y: observations characterising the GP, shape N * obs_dim.
        :param x_star: prediction inputs, shape N_star * Q; defaults to the
            posterior mean ``mu_x``.
        :return: ``(f_star, k_matrix)`` where ``f_star`` is N_star * obs_dim and
            ``k_matrix`` is the N * N approximate kernel matrix.
        """
        batch_y = torch.as_tensor(batch_y, dtype=torch.double, device=self.device)

        if x_star is None:
            x_star = self.mu_x

        Phi, Phi_star = self._compute_sm_basis(x_star=x_star, f_eval=True)

        cross_matrix = Phi_star @ Phi.T                           # shape: N_star * N

        k_matrix = self._compute_gram_approximate_2(Phi=Phi)      # shape: N * N
        eye = torch.eye(Phi.shape[0], dtype=Phi.dtype, device=Phi.device)
        C_matrix = k_matrix + self.likelihood.noise * eye

        L = torch.linalg.cholesky(C_matrix)                                       # shape: N x N
        L_inv_y = torch.linalg.solve_triangular(L, batch_y, upper=False)          # inv(L) * y
        K_L_inv = torch.linalg.solve_triangular(L, cross_matrix.T, upper=False)   # inv(L) * K_{N, N_star}

        f_star = K_L_inv.T @ L_inv_y                              # shape: N_star * obs_dim

        return f_star, k_matrix


def ssgpr_sm():
    """Placeholder with no implementation yet; always returns ``None``."""
    result = None
    return result

if __name__ == "__main__":
    # Smoke test: build the model on a small synthetic data set and print one
    # sampled feature matrix.
    Y = torch.tensor(np.arange(-5, 5, 0.1).reshape(-1, 5))
    setting_dict = {}
    setting_dict['num_m'] = 2  # if num_m = 1, it is using SE kernel
    setting_dict['num_sample_pt'] = 50
    setting_dict['num_total_pt'] = setting_dict['num_m'] * setting_dict['num_sample_pt']
    setting_dict['num_batch'] = 1
    setting_dict['lr_hyp'] = .01
    setting_dict['iter'] = 100
    setting_dict['num_repexp'] = 1
    setting_dict['kl_option'] = True  # if adding X regularization in loss function
    setting_dict['latent_dim'] = 2
    setting_dict['N'] = Y.shape[0]
    # The earlier hard-coded noise_err = 100.0 was always overwritten by this
    # data-dependent value, so only this assignment is kept.
    setting_dict['noise_err'] = .05 * Y.std()

    # The class defined in this file is NG_GPLVM; the previous name
    # (NSRFF_GPLVM) does not exist here and raised NameError.
    GPLVM_model = NG_GPLVM(setting_dict['num_batch'],
                           setting_dict['num_sample_pt'],
                           setting_dict,
                           Y,
                           device="cpu").to("cpu")
    print(GPLVM_model._compute_sm_basis())