'''
Define custom losses and test-statistics here.
'''

import torch
import numpy as np


def compute_pdist_sq(x, y=None):
    """Return the matrix of squared Euclidean distances between rows.

    With ``y`` given, entry ``[i, j]`` is ``||x[i] - y[j]||^2`` computed from
    the expansion ``||x||^2 + ||y||^2 - 2 <x, y>``.
    With ``y`` omitted, ``x`` is first flattened to ``(x.shape[0], -1)`` and
    the distances between its own rows are returned.
    In both cases negative values caused by round-off are clamped to zero.
    """
    if y is None:
        flat = x.view(x.shape[0], -1)
        gram = flat @ flat.T
        sq_norms = torch.diag(gram)
        sq_dist = torch.clamp(sq_norms + sq_norms.unsqueeze(-1) - 2 * gram, min=0)

        # Keep only the lower triangle (diagonal included) and mirror it, so
        # the returned matrix is exactly symmetric.
        lower = torch.tril(sq_dist)
        return lower + lower.transpose(0, 1)

    x_sq = (x ** 2).sum(1).reshape(-1, 1)
    y_sq = (y ** 2).sum(1).reshape(1, -1)
    cross = torch.matmul(2.0 * x, y.T)
    return torch.clamp(x_sq + y_sq - cross, min=0)


def gaussian_kernel(X, sigma2=1.0, Y=None, normalized=False, **ignored):
    """Gaussian kernel matrix ``exp(-D / sigma2)``.

    ``D`` is the squared-distance matrix returned by ``compute_pdist_sq(X, Y)``.

    Args:
        X: first set of points, one per row.
        sigma2: bandwidth dividing the squared distances. ``None`` selects the
            median of the squared-distance matrix.
        Y: optional second set of points; forwarded to ``compute_pdist_sq``.
        normalized: if true, every row of ``X`` (and of ``Y`` when given) is
            divided by its norm before distances are computed.
        **ignored: extra keyword arguments are accepted and discarded.
    """
    def _unit_rows(M):
        return M / torch.linalg.norm(M, dim=1, keepdim=True)

    if normalized:
        X = _unit_rows(X)
        Y = Y if Y is None else _unit_rows(Y)

    sq_dists = compute_pdist_sq(X, Y)
    bandwidth = sq_dists.median() if sigma2 is None else sigma2
    return torch.exp(-sq_dists / bandwidth)


def gaussian_rff_kernel(X, sigma2, rff_dim=128, **ignored):
    """Placeholder for a random-Fourier-feature Gaussian kernel (disabled).

    Every call raises ``NotImplementedError`` because support for a second
    input has not been written. The statements after the ``raise`` are an
    unreachable single-input draft kept for reference.
    """
    raise NotImplementedError('add second variable')
    # todo: add seed, add second variable
    # NOTE(review): the draft divides by sqrt(sigma2), which would approximate
    # exp(-D / (2 * sigma2)) rather than gaussian_kernel's exp(-D / sigma2) --
    # confirm the intended bandwidth convention before enabling.
    input_dim = X.shape[-1]
    frequencies = torch.randn(rff_dim, input_dim, device=X.device, dtype=torch.float)
    phases = torch.rand(rff_dim, device=X.device) * 2.0 * np.pi
    projected = X.mm(frequencies.T) / np.sqrt(sigma2)
    rff = torch.cos(projected + phases) * np.sqrt(2 / rff_dim)
    # todo: return only features
    return rff.mm(rff.T)


def cosine_kernel(X, Y=None, **_ignored):
    """Cosine-similarity kernel between the rows of ``X`` and the rows of ``Y``.

    Args:
        X: 2-D tensor, one sample per row.
        Y: optional second 2-D tensor with the same number of columns;
            defaults to ``X``, matching the other kernels in this module.
        **_ignored: extra keyword arguments are accepted and discarded.

    Returns:
        Tensor of shape ``(X.shape[0], Y.shape[0])`` whose entry ``[i, j]`` is
        ``<X[i], Y[j]> / (||X[i]|| * ||Y[j]||)``.
    """
    if Y is None:
        Y = X
    # Clamp the norms away from zero so an all-zero row yields similarity 0
    # instead of NaN.
    tiny = 1e-12
    X_unit = X / torch.clamp(torch.linalg.norm(X, dim=1, keepdim=True), min=tiny)
    Y_unit = Y / torch.clamp(torch.linalg.norm(Y, dim=1, keepdim=True), min=tiny)
    return X_unit @ Y_unit.transpose(1, 0)


def linear_kernel(X, intercept=0.0, Y=None, **_ignored):
    """Linear kernel: inner products of the rows of ``X`` and ``Y`` plus ``intercept``.

    ``Y`` defaults to ``X``. Extra keyword arguments are accepted and discarded.
    """
    other = X if Y is None else Y
    return torch.matmul(X, other.transpose(0, 1)) + intercept


def polynomial_kernel(X, intercept=1.0, Y=None, p=1.0, **_ignored):
    """Polynomial kernel ``(<x, y> + intercept) ** p`` between rows of ``X`` and ``Y``.

    ``Y`` defaults to ``X``. Extra keyword arguments are accepted and discarded.
    """
    other = X if Y is None else Y
    inner = torch.matmul(X, other.transpose(0, 1))
    return torch.pow(inner + intercept, p)


def hsic_matrices(Kx, Ky, biased=False):
    """Compute an HSIC estimate from two precomputed kernel matrices.

    Args:
        Kx: square kernel matrix of the first variable.
        Ky: kernel matrix of the second variable, same shape as ``Kx``.
        biased: selects the biased estimator; otherwise the unbiased one is
            used. The unbiased formula divides by ``n * (n - 3)``, ``n - 2``
            and ``(n - 1) * (n - 2)``, so it raises ``ZeroDivisionError`` for
            fewer than four samples.

    Returns:
        Scalar tensor holding the estimate.
    """
    if biased:
        col_mean_x = Kx.mean(dim=0)
        col_mean_y = Ky.mean(dim=0)
        # same as tr(HAHB)/m^2 for A=Kx, B=Ky, H=I - 11^T/m (centering matrix)
        joint = (Kx * Ky).mean()
        cross = (col_mean_x * col_mean_y).mean()
        marginal = col_mean_x.mean() * col_mean_y.mean()
        return joint - 2 * cross + marginal

    n = Kx.shape[0]

    # The unbiased estimator works on the kernel matrices with zeroed diagonals.
    Kx_off = Kx - torch.diag(torch.diag(Kx))
    Ky_off = Ky - torch.diag(torch.diag(Ky))

    trace_term = (Kx_off * Ky_off).sum()  # tr(KL)
    row_term = torch.dot(Kx_off.sum(dim=1), Ky_off.sum(dim=1))  # 1^T KL 1
    total_term = Kx_off.sum() * Ky_off.sum()  # (1^T K 1)(1^T L 1)

    # Unbiased HSIC.
    scale = 1 / (n * (n - 3))
    row_weight = 2. / (n - 2)
    total_weight = 1 / ((n - 1) * (n - 2))
    return scale * (trace_term - row_weight * row_term + total_weight * total_term)

def hsic(X, Y, kernelX='gaussian', kernelX_params=None, kernelY='linear', kernelY_params=None, biased=False):
    '''X ind. Y

    Compute HSIC between samples X and Y.

    Args:
        X, Y: sample tensors passed as the first argument of the chosen kernels.
        kernelX, kernelY: kernel names; ``<name>_kernel`` is looked up in this
            module's scope.
        kernelX_params, kernelY_params: keyword arguments forwarded to the
            kernels; ``None`` means "no extra arguments".
        biased: forwarded to ``hsic_matrices``.

    Returns:
        The value returned by ``hsic_matrices`` for the two kernel matrices.
    '''
    # todo:
    #  alternative implementation for RFF
    #  faster implementation for biased

    # `**None` raises a TypeError, so the documented defaults must become dicts.
    kernelX_params = {} if kernelX_params is None else kernelX_params
    kernelY_params = {} if kernelY_params is None else kernelY_params

    # NOTE(review): eval() executes whatever expression the kernel name forms;
    # only pass trusted kernel names, or replace this with an explicit
    # name -> function table.
    Kx = eval(f'{kernelX}_kernel(X, **kernelX_params)')
    Ky = eval(f'{kernelY}_kernel(Y, **kernelY_params)')

    return hsic_matrices(Kx, Ky, biased)

def hscic(X, Y, Z, ridge_lambda, kernelX='gaussian', kernelX_params=None,
          kernelY='gaussian', kernelY_params=None, kernelZ='gaussian', kernelZ_params=None):
    '''X ind. Y | Z

    Conditional independence statistic following https://arxiv.org/pdf/2207.09768.pdf.

    Args:
        X, Y, Z: sample tensors passed as the first argument of the chosen kernels.
        ridge_lambda: ridge term added to the diagonal of the Z kernel matrix
            before solving the linear system.
        kernelX, kernelY, kernelZ: kernel names; ``<name>_kernel`` is looked
            up in this module's scope.
        kernelX_params, kernelY_params, kernelZ_params: keyword arguments
            forwarded to the kernels; ``None`` means "no extra arguments".

    Returns:
        Scalar tensor: the statistic summed over the columns of the ridge
        solution and divided by the number of rows of the Z kernel matrix.
    '''
    # todo:
    #  swap Z and Y
    #  alternative implementation for RFF

    # `**None` raises a TypeError, so the documented defaults must become dicts.
    kernelX_params = {} if kernelX_params is None else kernelX_params
    kernelY_params = {} if kernelY_params is None else kernelY_params
    kernelZ_params = {} if kernelZ_params is None else kernelZ_params

    # NOTE(review): eval() executes whatever expression the kernel name forms;
    # only pass trusted kernel names.
    Kx = eval(f'{kernelX}_kernel(X, **kernelX_params)')
    Ky = eval(f'{kernelY}_kernel(Y, **kernelY_params)')
    Kz = eval(f'{kernelZ}_kernel(Z, **kernelZ_params)')

    n = Kz.shape[0]
    # Build the identity directly on Kz's device and dtype instead of
    # creating it on the CPU and copying it over.
    ridge = ridge_lambda * torch.eye(n, device=Kz.device, dtype=Kz.dtype)
    # (Kz + lambda I)^{-1} Kz; column j holds the ridge weights for sample j.
    WtKzz = torch.linalg.solve(Kz + ridge, Kz)  # * Kz.shape[0] for ridge_lambda
    # todo:
    #   SVD + LOO here? but that doesn't scale with # params
    #   three LOO for all three regressions?
    #   re-do it every few iters?
    # sum_i A_(i.)B_(.i) = tr(AB) = (A * B^T).sum()
    # A = Kzz^T, B = the other one, so the transposes cancel out
    term_1 = (WtKzz * ((Kx * Ky) @ WtKzz)).sum()  # tr(WtKzz.T @ (Kx * Ky) @ WtKzz)
    WkKxWk = WtKzz * (Kx @ WtKzz)
    KyWk = Ky @ WtKzz
    term_2 = (WkKxWk * KyWk).sum()
    # here it's crucial that the first dimension is the batch of other matrices
    term_3 = (WkKxWk.sum(dim=0) * (WtKzz * KyWk).sum(dim=0)).sum()

    return (term_1 - 2 * term_2 + term_3) / n


def hsic_corrected_no_precomuted(X, Y, Y_heldout, Z, Z_heldout, ridge_lambda, kernelX='gaussian', kernelX_params=None,
                                 kernelY='gaussian', kernelY_params=None, kernelZ='gaussian', kernelZ_params=None,
                                 biased=False):
    '''X ind. Y | Z

    HSIC between the X kernel and a Y kernel corrected by a kernel ridge
    regression on Z that is fitted on the held-out samples inside this call.

    Args:
        X, Y, Z: batch samples.
        Y_heldout, Z_heldout: held-out samples used to fit the ridge
            regression. If ``Z_heldout`` is ``None``, copies of ``Z`` and ``Y``
            are used for both (any ``Y_heldout`` passed is then replaced).
        ridge_lambda: ridge term added to the diagonal of the held-out Z
            kernel matrix.
        kernelX, kernelY, kernelZ: kernel names; ``<name>_kernel`` is looked
            up in this module's scope.
        kernelX_params, kernelY_params, kernelZ_params: keyword arguments
            forwarded to the kernels; ``None`` means "no extra arguments".
        biased: forwarded to ``hsic_matrices``.

    Returns:
        The value returned by ``hsic_matrices`` for the X kernel and the
        corrected Y kernel.
    '''
    # todo:
    #  swap Z and Y
    #  alternative implementation for RFF
    #  different kernels

    # `**None` raises a TypeError, so the documented defaults must become dicts.
    kernelX_params = {} if kernelX_params is None else kernelX_params
    kernelY_params = {} if kernelY_params is None else kernelY_params
    kernelZ_params = {} if kernelZ_params is None else kernelZ_params

    if Z_heldout is None:
        Z_heldout = Z.clone()
        Y_heldout = Y.clone()

    # NOTE(review): eval() executes whatever expression the kernel name forms;
    # only pass trusted kernel names.
    Kx = eval(f'{kernelX}_kernel(X, **kernelX_params)')

    # Joint kernel matrices over [batch; held-out]; the blocks are sliced below.
    Y_all = torch.vstack((Y, Y_heldout))
    Ky_all = eval(f'{kernelY}_kernel(Y_all, **kernelY_params)')

    Z_all = torch.vstack((Z, Z_heldout))
    Kz_all = eval(f'{kernelZ}_kernel(Z_all, **kernelZ_params)')

    n_points = Z.shape[0]
    n_heldout = Z_heldout.shape[0]

    # (Kz_heldout + lambda I)^{-1} Kz_(heldout, batch): n_heldout x n_points.
    ridge = ridge_lambda * torch.eye(n_heldout, device=Z.device)
    WtKzz = torch.linalg.solve(Kz_all[n_points:, n_points:] + ridge,
                               Kz_all[n_points:, :n_points])
    Ky_mod = Ky_all[:, n_points:] @ WtKzz
    batch_cross = Ky_mod[:n_points]
    Kres = Ky_all[:n_points, :n_points] - batch_cross - batch_cross.T + WtKzz.T @ Ky_mod[n_points:]

    return hsic_matrices(Kx, Kres, biased)


def hsic_corrected(X, Y, Y_heldout, Z, Z_heldout, W_1, W_2, kernelX='gaussian', kernelX_params=None,
                   kernelY='gaussian', kernelY_params=None, kernelZ='gaussian', kernelZ_params=None,
                   biased=False, cond_cov=False):
    '''X ind. Y | Z

    Variant of the corrected HSIC that takes the matrices ``W_1`` and ``W_2``
    as inputs instead of solving a ridge system inside the call.

    Args:
        X, Y, Z: batch samples.
        Y_heldout, Z_heldout: held-out samples.
        W_1, W_2: matrices applied between the held-out/batch Z kernel blocks;
            the matrix products below require both to be
            n_heldout x n_heldout. NOTE(review): presumably derived from the
            held-out ridge regression -- confirm against the caller.
        kernelX, kernelY, kernelZ: kernel names; ``<name>_kernel`` is looked
            up in this module's scope. The Y and Z kernels are called with a
            ``Y=`` second input and therefore must accept one.
        kernelX_params, kernelY_params, kernelZ_params: keyword arguments
            forwarded to the kernels; ``None`` means "no extra arguments".
        biased: with ``cond_cov`` selects the mean over all entries instead of
            the strictly upper triangle; otherwise forwarded to ``hsic_matrices``.
        cond_cov: if true, return the mean of the elementwise product
            ``Kx * Kres * Ky`` instead of an HSIC value.

    Returns:
        Scalar tensor with the selected statistic.
    '''
    # todo:
    #  swap Z and Y
    #  alternative implementation for RFF

    # `**None` raises a TypeError, so the documented defaults must become dicts.
    kernelX_params = {} if kernelX_params is None else kernelX_params
    kernelY_params = {} if kernelY_params is None else kernelY_params
    kernelZ_params = {} if kernelZ_params is None else kernelZ_params

    # NOTE(review): eval() executes whatever expression the kernel name forms;
    # only pass trusted kernel names.
    Y_all = torch.vstack((Y, Y_heldout))
    Ky_all = eval(f'{kernelY}_kernel(Y_all, Y=Y, **kernelY_params)')  # n_all x n_batch

    Kz_all = eval(f'{kernelZ}_kernel(Z_heldout, Y=Z, **kernelZ_params)')  # n_heldout x n_batch

    n_points = Z.shape[0]

    A = (0.5 * Kz_all.T @ W_2 - Ky_all[n_points:, :].T) @ W_1 @ Kz_all
    Kres = Ky_all[:n_points, :n_points] + A + A.T

    Kx = eval(f'{kernelX}_kernel(X, **kernelX_params)')

    Ky_batch = Ky_all[:n_points, :]
    if cond_cov:
        Kx = Kx * Kres * Ky_batch
        if biased:
            return Kx.mean()
        # Build the index tensor on the same device as the matrix it indexes.
        idx = torch.triu_indices(n_points, n_points, 1, device=Kx.device)
        return Kx[idx[0], idx[1]].mean()
    return hsic_matrices(Kx, Kres * Ky_batch, biased)


def gcm(x, fz, y, gz):
    '''
    Generalized covariance measure for multivariate X & Y.
    From https://arxiv.org/abs/1804.07203, Eq.(3)

    Args:
        x, y: 2-D tensors with one sample per row.
        fz, gz: tensors subtracted from ``x`` and ``y`` to form the residuals.

    Returns:
        Scalar tensor: the largest absolute entry of the normalised
        residual-product statistic over all (x-dim, y-dim) pairs.
    '''
    n = x.shape[0]

    residual_x = x - fz
    residual_y = y - gz
    # Per-sample outer products of the residuals: n x dx x dy.
    R = torch.bmm(residual_x.unsqueeze(-1), residual_y.unsqueeze(1))
    R_avg = R.mean(dim=0)
    tau_N = np.sqrt(n) * R_avg
    # E[R^2] - E[R]^2 can come out slightly negative through round-off when the
    # products are (nearly) constant; clamp so the sqrt cannot produce NaN.
    variance = torch.clamp((R ** 2).mean(dim=0) - (R_avg ** 2), min=0)
    tau_D = torch.sqrt(variance)
    T_n = torch.div(tau_N, tau_D + 1e-10)

    ## uncomment to compute covariance matrix of random variable T_n.
    # dx = x.shape[1]; dy = y.shape[1]
    # T_Sigma = Tn_cov(R, R_avg, n, dx, dy)

    S_n = torch.abs(T_n).max()
    return S_n


def Tn_cov(R, R_avg, n, dx, dy):
    '''
    Computes the covariance matrix for multivariate gaussian variable T_n.
    From https://arxiv.org/abs/1804.07203, Sec. 3.2

    Args:
        R: tensor indexed as ``R[sample, j, k]`` with ``j < dx`` and ``k < dy``.
        R_avg: tensor indexed as ``R_avg[j, k]`` (the caller's average of ``R``).
        n: divisor used for the second moments.
        dx, dy: number of first and second indices to cover.

    Returns:
        Tensor of shape ``(dx, dy, dx, dy)``; entry ``[j, k, l, m]`` is the
        covariance of ``R[:, j, k]`` and ``R[:, l, m]`` divided by the product
        of their standard deviations. The result keeps the dtype and device
        of ``R``.
    '''
    # Collapse every (j, k) pair onto one axis so all pairwise terms come out
    # of a single matrix product instead of a four-level Python loop.
    R_flat = R[:, :dx, :dy].reshape(R.shape[0], dx * dy)
    avg_flat = R_avg[:dx, :dy].reshape(dx * dy)

    cov = R_flat.T @ R_flat / n - avg_flat.unsqueeze(1) * avg_flat.unsqueeze(0)
    std = torch.sqrt((R_flat ** 2).sum(dim=0) / n - avg_flat ** 2)
    Sigma = cov / (std.unsqueeze(1) * std.unsqueeze(0))

    return Sigma.reshape(dx, dy, dx, dy)
