import torch.nn
from utils import *
from numpy import shape
from numpy.linalg import eigh

class PowerKCI(object):
    def __init__(self, z_tr, z_te, x_tr, x_te, y_tr, y_te,
                 test_method = None, null_samples = 1000, thresh_hold = 1e-5):
        """Store the train/test splits and initialise all kernel widths.

        Args:
            z_tr, z_te: conditioning variable, training / test split.
            x_tr, x_te: first variable, training / test split.
            y_tr, y_te: second variable, training / test split.
            test_method: ``'gamma'`` or ``'chi_square'``; selects how
                ``cal_pvalue`` approximates the null distribution.
            null_samples: number of null draws used by the ``'chi_square'``
                method.
            thresh_hold: relative eigenvalue cut-off (stored as
                ``self.thresh``).
        """
        # how the p-value is obtained
        self.test_method = test_method
        self.null_samples = null_samples
        self.thresh = thresh_hold

        # settings for learning the kernel on z
        self.lr = 0.1          # Adam learning rate
        self.iteration = 100   # number of Adam steps
        self.lamb = 1e-10      # lambda in variance

        # totensor comes from utils -- presumably converts its arguments to
        # torch tensors (later code calls .shape / .numpy() on the results).
        self.x_tr, self.y_tr, self.z_tr = totensor(x_tr, y_tr, z_tr)
        self.x_te, self.y_te, self.z_te = totensor(x_te, y_te, z_te)
        # the last axis of z is taken as its dimensionality
        self.Dz = self.z_tr.shape[-1]

        # starting widths from set_median_width (utils), evaluated on the
        # training split only
        self.xlength_init = set_median_width(self.x_tr)
        self.ylength_init = set_median_width(self.y_tr)
        self.zlength_med = set_median_width(self.z_tr)
        # one trainable width per z dimension, starting at the median value
        self.zlength_init_func(self.zlength_med)

    def zlength_init_func(self, zlength_init):
        """(Re)create ``self.zlength`` as a trainable per-dimension width.

        Every one of the ``self.Dz`` entries starts at ``zlength_init``.
        """
        widths = torch.tensor(zlength_init * np.ones(self.Dz)).detach()
        param = torch.nn.Parameter(widths)
        param.requires_grad = True
        self.zlength = param

    def compute_pvalue(self):
        """Run the test twice and return both p-values.

        The residual kernels of x and y given z are computed once.  The
        p-value is then evaluated on the test split with (a) the z-kernel
        whose widths were learned on the training split via ``SelectingKz``
        and (b) the z-kernel using the median-heuristic width.

        Returns:
            ``(power_pvalue, median_pvalue)``.
        """
        resid_x_tr, resid_x_te = self.regression(self.x_tr, self.x_te, self.xlength_init)
        resid_y_tr, resid_y_te = self.regression(self.y_tr, self.y_te, self.ylength_init)

        # (a) widths learned by maximising the power criterion on training data
        learned_Kz = self.SelectingKz(resid_x_tr, resid_y_tr)
        power_pvalue, _ = self.cal_pvalue(
            resid_x_te.detach(), resid_y_te.detach(), learned_Kz.detach())

        # (b) baseline: single median-heuristic width
        median_Kz = self.kernelz(self.z_te, self.zlength_med).detach()
        median_pvalue, _ = self.cal_pvalue(
            resid_x_te.clone(), resid_y_te.clone(), median_Kz)

        return power_pvalue, median_pvalue


    def SelectingKz(self, KRx, KRy):
        """Learn the z-kernel widths by maximising ``power`` with Adam.

        The widths are reset to the median value, updated for
        ``self.iteration`` steps on the training split, and clamped to
        ``[1e-2, 20]`` after every step.

        Args:
            KRx, KRy: training-split kernel matrices passed to ``power``.

        Returns:
            The z-kernel on the test split evaluated at the learned widths
            (still attached to the autograd graph of ``self.zlength``).
        """
        self.zlength_init_func(self.zlength_med)
        optimizer = torch.optim.Adam([self.zlength], lr=self.lr)

        for _ in range(self.iteration):
            # minimise the negated criterion
            loss = -self.power(KRx, KRy, self.kernelz(self.z_tr, self.zlength))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # keep widths in a sane range; done on .data so autograd ignores it
            self.zlength.data = self.zlength.data.clamp_(1e-2, 20)

        return self.kernelz(self.z_te, self.zlength)

    def cal_pvalue(self, KxR, KyR, Kz):
        """Compute the test statistic and its p-value.

        The statistic is ``sum((KxR * Kz) * centred(KyR))``.  Its null
        distribution is approximated from the eigen-structure returned by
        ``get_uuprod``, either by a moment-matched gamma distribution
        (``'gamma'``) or by sampling weighted chi-square variables
        (``'chi_square'``).

        Args:
            KxR, KyR, Kz: square torch tensors of identical shape.

        Returns:
            ``(pvalue, test_stat)``; ``test_stat`` is a 0-d NumPy array.

        Raises:
            NotImplementedError: if ``self.test_method`` is neither
                ``'gamma'`` nor ``'chi_square'``.
        """
        # Fail before any eigendecomposition is done for an unsupported method.
        if self.test_method not in ('gamma', 'chi_square'):
            raise NotImplementedError('test method not implemented')

        KxR = KxR * Kz
        KyR = self.kernel_centering(KyR)
        test_stat = torch.sum(KxR * KyR).detach().cpu().numpy()

        # The null approximation is pure NumPy; hand it ndarrays explicitly
        # rather than relying on implicit tensor -> array conversion.
        Kx_np = KxR.detach().cpu().numpy()
        Ky_np = KyR.detach().cpu().numpy()
        uu_prod, size_u = self.get_uuprod(Kx_np, Ky_np)

        if self.test_method == 'gamma':
            k_appr, theta_appr = self.get_kappa(uu_prod)
            # sf() is the upper tail computed directly; 1 - cdf() cancels
            # catastrophically when the p-value is small.
            pvalue = stats.gamma.sf(test_stat, k_appr, 0, theta_appr)
        else:
            null_draws = self.null_sample_spectral(uu_prod, size_u, Kx_np.shape[0])
            exceed = np.count_nonzero(null_draws > test_stat)
            pvalue = exceed / float(self.null_samples)

        return pvalue, test_stat

    def power(self, Kx, Ky, Kz):
        """Estimate the power criterion J used to learn the kernel on z.

        The three kernel matrices are multiplied element-wise and the
        diagonal is dropped, giving pairwise terms ``S[i, j]`` for ``i != j``.
        The statistic ``KCIu`` is their mean over the ``n * (n - 1)``
        off-diagonal pairs.  Its spread ``sigma1`` is the root-mean-square
        deviation of the per-sample column means ``Sj`` from ``KCIu``,
        regularised by ``self.lamb``.

        Args:
            Kx, Ky, Kz: square ``(n, n)`` torch tensors of equal shape,
                with ``n >= 2``.

        Returns:
            0-d torch tensor ``KCIu / (2 * sigma1)``.
        """
        n = Kx.shape[0]
        S = self.diag_zero(Kx * Ky * Kz)

        # Average over off-diagonal pairs so KCIu is on the same scale as Sj
        # (the mean of Sj equals KCIu); the raw sum would make Sj - KCIu
        # meaningless as a deviation.
        KCIu = torch.sum(S) / (n * (n - 1))
        Sj = S.sum(0) / (n - 1)
        sigma1 = torch.sqrt(torch.sum((Sj - KCIu) ** 2) / n + self.lamb)

        # Parentheses matter: `KCIu / 2*sigma1` parses as (KCIu / 2) * sigma1.
        return KCIu / (2 * sigma1)

    def kernelz(self, z, zlength):
        """Gaussian kernel matrix of the rows of ``z`` after width scaling.

        Args:
            z: 2-D torch tensor, one sample per row.
            zlength: width(s) each column is divided by (scalar or one value
                per column).

        Returns:
            ``exp(-0.5 * ||z_i/zlength - z_j/zlength||^2)`` for all row pairs.
        """
        scaled = z / zlength
        sq_norms = (scaled ** 2).sum(dim=1, keepdim=True)
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
        pairwise_sq = sq_norms + sq_norms.T - 2 * scaled.mm(scaled.T)
        return torch.exp(- 0.5 * pairwise_sq)


    def diag_zero(self, K):
        """Return ``K`` with its main diagonal subtracted out (set to zero)."""
        return K - torch.diag_embed(K.diag())

    def regression(self, x_tr, x_te, xlength):
        """Fit a GP from z to features of x and return residual kernels.

        A centred kernel matrix of ``x_tr`` is reduced by ``reduce_func``
        (from utils; presumably a thresholded eigen-feature map -- confirm)
        and used as the regression target of a Gaussian process on
        ``self.z_tr``.  The fitted GP then defines the residual kernels on
        both splits.

        Returns:
            ``(KR_tr, KR_te)``: residual kernels for training and test data.
        """
        centred_gram = self.kernel_centering(cal_kernel(x_tr, xlength))
        targets = torch.from_numpy(reduce_func(centred_gram, self.thresh))

        regressor = self.gp(self.zlength_med)
        regressor.fit(self.z_tr.numpy(), targets)

        # residual kernels: training split first, then test split
        return (
            self.residual(regressor, self.z_tr, x_tr, xlength),
            self.residual(regressor, self.z_te, x_te, xlength),
        )

    def residual(self, gp, z, x, xlength):
        """Residual kernel of ``x`` after regressing on ``z`` with a fitted GP.

        Computes ``Rz @ Kx @ Rz`` with
        ``Rz = noise * pinv(Kz + noise * I)``.

        Args:
            gp: fitted regressor exposing ``kernel_`` (as built by ``gp()``,
                where ``theta[-1]`` is the log noise level and ``k1`` is the
                noise-free part of the kernel).
            z: conditioning samples the GP kernel is evaluated on.
            x: samples whose kernel is residualised.
            xlength: kernel width forwarded to ``cal_kernel`` (utils).
        """
        fitted = gp.kernel_
        # scikit-learn stores hyperparameters log-transformed
        noise_np = np.exp(fitted.theta[-1])
        Kz_np = fitted.k1(z)
        size = shape(Kz_np)[0]

        # switch to torch for the matrix algebra
        noise = torch.tensor(noise_np)
        Kz = torch.from_numpy(Kz_np)
        Kx = cal_kernel(x, xlength)

        Rz = noise * torch.linalg.pinv(Kz + noise * torch.eye(size))
        return Rz.matmul(Kx.matmul(Rz))

    def gp(self, zlength_init):
        """Build an unfitted GP regressor with kernel ``C * RBF + White``.

        The RBF part has one length scale per z dimension, all starting at
        ``zlength_init``.
        """
        from sklearn.gaussian_process import GaussianProcessRegressor
        from sklearn.gaussian_process.kernels import RBF, WhiteKernel, ConstantKernel

        amplitude = ConstantKernel(1.0, (1e-3, 1e3))
        smooth = RBF(zlength_init * np.ones(self.Dz), (1e-2, 1e2))
        noise = WhiteKernel(0.1, (1e-10, 1e+1))
        return GaussianProcessRegressor(kernel=amplitude * smooth + noise)

    def regression_residual(self, Kx, Kz, epsilon):
        """Residual kernel ``Rz @ Kx @ Rz`` for a fixed ridge ``epsilon``.

        ``Rz = epsilon * inv(Kz + epsilon * I)``.
        """
        size = Kx.shape[0]
        ridge = Kz + epsilon * torch.eye(size)
        Rz = epsilon * torch.linalg.inv(ridge)
        return Rz.matmul(Kx.matmul(Rz))

    @staticmethod
    def kernel_centering(K):
        n = shape(K)[0]
        K_colsums = K.sum(axis=0)
        K_allsum = K_colsums.sum()
        return K - (K_colsums[None, :] + K_colsums[:, None]) / n + (K_allsum / n ** 2)

    def get_uuprod(self, Kx, Ky):
        """Build the product-feature Gram matrix used for the null law.

        Each kernel matrix is symmetrised and eigendecomposed; eigenpairs
        whose eigenvalue is at most ``self.thresh`` times the largest one are
        dropped, and the remaining eigenvectors are scaled by the square root
        of their eigenvalue.  ``uu`` holds every element-wise product of one
        x-column with one y-column (column ``i * num_eigy + j``).

        Args:
            Kx, Ky: square ``(T, T)`` matrices (NumPy arrays or torch
                tensors).

        Returns:
            ``(uu_prod, size_u)`` where ``size_u`` is the number of columns
            of ``uu`` and ``uu_prod`` is ``uu @ uu.T`` if ``size_u > T`` and
            ``uu.T @ uu`` otherwise (the smaller of the two; both share the
            same non-zero eigenvalues).
        """
        def to_ndarray(K):
            # Everything below is NumPy linear algebra; convert tensors
            # explicitly instead of relying on implicit conversion in eigh.
            if hasattr(K, 'detach'):
                K = K.detach().cpu().numpy()
            return np.asarray(K)

        def scaled_eigvecs(K):
            w, v = eigh(0.5 * (K + K.T))
            order = np.argsort(-w)          # descending eigenvalues
            w, v = w[order], v[:, order]
            keep = w > np.max(w) * self.thresh
            # column scaling; same as v.dot(diag(sqrt(w))) without the
            # dense diagonal matrix
            return v[:, keep] * np.sqrt(w[keep])

        Kx = to_ndarray(Kx)
        Ky = to_ndarray(Ky)
        vx = scaled_eigvecs(Kx)
        vy = scaled_eigvecs(Ky)

        T = Kx.shape[0]
        size_u = vx.shape[1] * vy.shape[1]
        # uu[:, i * num_eigy + j] = vx[:, i] * vy[:, j]
        uu = (vx[:, :, None] * vy[:, None, :]).reshape(T, size_u)

        if size_u > T:
            uu_prod = uu.dot(uu.T)
        else:
            uu_prod = uu.T.dot(uu)
        return uu_prod, size_u


    def get_kappa(self, uu_prod):
        """Moment-match a gamma distribution to the null statistic.

        The mean is taken as ``trace(uu_prod)`` and the variance as
        ``2 * trace(uu_prod @ uu_prod)``.

        Returns:
            ``(k_appr, theta_appr)``: gamma shape ``mean**2 / var`` and scale
            ``var / mean``.
        """
        mean_appr = np.trace(uu_prod)
        # trace(A @ A) == sum_ij A[i, j] * A[j, i]; this avoids the O(n^3)
        # matrix product that is only needed for its diagonal.
        var_appr = 2 * np.sum(uu_prod * uu_prod.T)
        k_appr = mean_appr ** 2 / var_appr
        theta_appr = var_appr / mean_appr
        return k_appr, theta_appr

    def null_sample_spectral(self, uu_prod, size_u, T):
        """Draw null samples as eigenvalue-weighted sums of chi-square(1).

        The eigenvalues of ``uu_prod`` are sorted in descending order, cut
        to the first ``min(T, size_u)``, and those not exceeding
        ``self.thresh`` times the largest are dropped.

        Returns:
            1-D array of ``self.null_samples`` draws.
        """
        from numpy.linalg import eigvalsh

        spectrum = -np.sort(-eigvalsh(uu_prod))
        spectrum = spectrum[:np.min((T, size_u))]
        significant = spectrum[spectrum > np.max(spectrum) * self.thresh]

        # one row of chi-square(1) draws per retained eigenvalue
        draws = np.random.chisquare(1, (significant.shape[0], self.null_samples))
        return significant.T.dot(draws)