import torch
import torch.nn as nn
import numpy as np
from .stg import StochasticGates, StochasticGates2
from .utils import mySequential, PDF


class SparseDeepMTE(nn.Module):
    """Two stacked stochastic-gate layers followed by a user-supplied network.

    ``forward`` returns a scalar-style loss made of four terms: the negated
    log-determinant dependence score from :meth:`_get_corr`, the
    regularisation terms reported by both gate layers, and ``lamdis`` times
    the value of a ``PDF(3 * out_dim, 'gauss')`` module evaluated on the
    concatenated network outputs.
    """

    # Ridge added to the diagonal of C_xx before it is inverted.
    _RIDGE = 1e-3

    def __init__(self, x_dim, y_dim, xynet, lamx, lamy, lamdis, out_dim,
                 gate_init1, gate_init2, sigmax=1, sigmay=1):
        """
        :param x_dim: Dx; used as the feature count of BOTH gate layers
        :param y_dim: Dy; accepted for interface compatibility, not used here
        :param xynet: network applied after the gates; must output the
            triple (X_hat, Y_hat, WY_hat)
        :param lamx: regularizer passed to the first gate layer
            (``StochasticGates``)
        :param lamy: regularizer passed to the second gate layer
            (``StochasticGates2``)
        :param lamdis: weight of the ``PDF`` term in the loss (discrepancy of
            the outputs to a Gaussian distribution)
        :param out_dim: per-output width; the ``PDF`` module is built with
            ``3 * out_dim`` because the three outputs are concatenated
        :param gate_init1: initialization for the first gate layer
        :param gate_init2: initialization for the second gate layer
        :param sigmax: std passed to the first gate layer
        :param sigmay: std passed to the second gate layer
        """
        super().__init__()
        self.f = self._create_network(
            x_dim, xynet, lamx, lamy, sigmax, sigmay, gate_init1, gate_init2
        )

        self.lamdis = lamdis
        self.g = PDF(3 * out_dim, 'gauss')

    def forward(self, X, L):
        """
        Compute the training loss for one batch.

        :param X: input forwarded to the gated network
        :param L: second input forwarded unchanged to the gated network
        :return: ``-corr + reg(gate 1) + reg(gate 2) + lamdis * g(outputs)``
        """
        X_hat, Y_hat, Wy_hat = self.f(X, L)

        dependence = self._get_corr(X_hat, Y_hat, Wy_hat)
        reg_first = self.f[0].get_reg()
        reg_second = self.f[1].get_reg()
        # The three outputs are joined along dim 1 before being scored by g.
        discrepancy = self.g(torch.cat((X_hat, Y_hat, Wy_hat), 1))

        return -dependence + reg_first + reg_second + self.lamdis * discrepancy

    def get_gates(self):
        """
        use this function to retrieve the combined gate values

        :return: tuple ``(g1 * g2, g1 * (1 - g2))`` where ``g1`` and ``g2``
            are the values reported by the first and second gate layer
        """
        # Each layer is queried once and the result reused in both products.
        first = self.f[0].get_gates()
        second = self.f[1].get_gates()
        return first * second, first * (1 - second)

    def get_function_parameters(self):
        """
        use this function if you wish to use a different optimizer for
        functions and gates

        :return: list of ALL parameters of ``self.f`` (both gate layers and
            ``xynet``); parameters of ``self.g`` are not included
        """
        # NOTE(review): this overlaps with get_gates_parameters(), since the
        # gate layers live inside self.f. Left unchanged because callers may
        # rely on the full list - confirm the intended split.
        return list(self.f.parameters())

    def get_gates_parameters(self):
        """
        use this function if you wish to use a different optimizer for
        functions and gates

        :return: list of the parameters of the FIRST gate layer only
            (``self.f[0]``)
        """
        # NOTE(review): self.f[1] is also a gate layer (see forward/get_gates)
        # but is not returned here - confirm whether that is intended.
        return list(self.f[0].parameters())

    def get_components(self, X, L):
        """
        :return: the raw output of the gated network, i.e. the triple
            (X_hat, Y_hat, WY_hat) that ``forward`` unpacks
        """
        return self.f(X, L)

    @staticmethod
    def _create_network(in_features, xynet, lam1, lam2, sigma1, sigma2,
                        gate_init1, gate_init2):
        """
        Chain the two gate layers and ``xynet`` into one ``mySequential``.

        Both gate layers are created with the same ``in_features``.
        """
        first_gate = StochasticGates(
            in_features, sigma1, lam1, gate_init=gate_init1
        )
        second_gate = StochasticGates2(
            in_features, sigma2, lam2, gate_init=gate_init2
        )
        return mySequential(first_gate, second_gate, xynet)

    def _get_corr(self, X, Y, WY):
        """
        Difference of two log-determinant dependence scores w.r.t. ``X``.

        For a centred block ``Z`` the score is
        ``logdet(C_zz) - logdet(C_zz - C_zx C_xx^{-1} C_xz)``; this method
        returns ``score(WY) - score(Y)``. Rows are treated as samples
        (centring and covariances are taken over axis 0).

        :param X: conditioning block, shape (N, Dx)
        :param Y: block whose score is subtracted
        :param WY: block whose score is added
        :return: ``score(WY) - score(Y)``
        """
        psi_x = X - X.mean(axis=0)

        # C_xx depends only on X, so it is built and inverted once and shared
        # by both scores. The ridge keeps the inverse well defined; C_xx is
        # assumed invertible so that only the Schur complement and C_zz
        # determinants appear in each score.
        C_xx = self._cov(psi_x, psi_x)
        ridge = torch.eye(C_xx.shape[0], device=X.device) * self._RIDGE
        C_xx_inv = torch.inverse(C_xx + ridge)

        l1 = self._logdet_gain(psi_x, WY - WY.mean(axis=0), C_xx_inv)
        l2 = self._logdet_gain(psi_x, Y - Y.mean(axis=0), C_xx_inv)

        return l1 - l2

    def _logdet_gain(self, psi_x, psi_z, C_xx_inv):
        """
        Return ``logdet(C_zz) - logdet(C_zz - C_zx C_xx^{-1} C_xz)``.

        :param psi_x: centred conditioning block
        :param psi_z: centred block being scored
        :param C_xx_inv: (regularised) inverse covariance of ``psi_x``
        """
        C_zz = self._cov(psi_z, psi_z)
        C_zx = self._cov(psi_z, psi_x)
        C_xz = self._cov(psi_x, psi_z)

        schur = C_zz - torch.linalg.multi_dot([C_zx, C_xx_inv, C_xz])
        return -torch.logdet(schur) + torch.logdet(C_zz)

    # Disabled experiment, kept for reference: spatial Laplacian score for
    # encouraging smooth structure on the spatial domain.
    #
    # def _laplacian_score(self, X, L):
    #     psi_x = X - X.mean(axis=0)
    #     psi_y = L @ psi_x
    #     fLf = self._cov(psi_x, psi_y)
    #     ff = self._cov(psi_x, psi_x)
    #     m = fLf @ torch.inverse(torch.diag(torch.diag(ff)) + torch.eye(ff.shape[0]) * 1e-6)
    #     return torch.trace(m) / X.shape[1]

    @staticmethod
    def _cov(psi_x, psi_y):
        """
        Cross-covariance of two already-centred sample matrices.

        :param psi_x: centred samples, shape (N, Dx)
        :param psi_y: centred samples, shape (N, Dy)
        :return: matrix of shape (Dx, Dy), equal to
            ``psi_x.T @ psi_y / (N - 1)``
        """
        N = psi_x.shape[0]
        return (psi_y.T @ psi_x).T / (N - 1)

