import pywt
import numpy as np
import torch
import torch.nn as nn
from modules.basic_blocks import MLP
from modules.utils import normalize, Swish

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import jax.numpy as jnp
from jax.config import config
import neural_tangents as nt
from neural_tangents import stax

class BachKernel(nn.Module):
    """Learnable kernel between batches of (b, c, h, w) feature maps.

    Each input map is multiplied elementwise by a learned modulation ``g``
    evaluated at every (pixel location, grid point ``s``) pair, normalized,
    and compared under a per-grid-point quadratic form ``Ewwt``.  The final
    kernel is an equal-weight mixture of Gaussian kernels of the resulting
    squared distance over a fixed list of bandwidths.

    NOTE(review): ``MLP`` and ``normalize`` come from project modules that are
    not visible here; comments about them below are assumptions to confirm.
    """

    def __init__(self, in_channels, s_dim, slice_per_dim=10):
        """Build the grid of ``s`` points and the learnable parameters.

        Args:
            in_channels: channel count ``c`` of the feature maps being compared.
            s_dim: dimensionality of each grid point ``s``.
            slice_per_dim: number of evenly spaced values in [-1, 1] per ``s`` axis.
        """
        super(BachKernel, self).__init__()
        self.in_channels = in_channels
        self.s_dim = s_dim
        # Total number of grid points: slice_per_dim values along each of s_dim axes.
        self.s_count = slice_per_dim ** s_dim
        self.feature_depth = self.s_count

        # Regular grid over [-1, 1]^s_dim, flattened to (s_count, s_dim).
        slices = np.linspace(-1, 1, slice_per_dim)
        coord = np.stack(np.meshgrid(*([slices] * s_dim), indexing='ij'), axis=-1).reshape(-1, s_dim)
        # Kept as a parameter with requires_grad=False: the grid itself is not trained.
        self.s_set = nn.Parameter(torch.Tensor(coord), requires_grad=False)

        # Maps a concatenated (2-D pixel location, s) vector of size s_dim + 2
        # to a modulation vector; presumably of size in_channels -- confirm in MLP.
        self.proj = MLP([s_dim + 2, in_channels * 2, in_channels])
        # One mean vector and one pre-activation variance vector per grid point.
        self.w_mu = nn.Parameter(torch.Tensor(self.s_count, in_channels))
        self.w_sigma = nn.Parameter(torch.Tensor(self.s_count, in_channels))

        # Disabled alternative: a single mean/variance shared by all grid points.
        # self.w_mu = nn.Parameter(torch.Tensor(in_channels))
        # self.w_sigma = nn.Parameter(torch.Tensor(in_channels))

        # torch.Tensor(...) above allocates uninitialized storage; fill it here.
        nn.init.normal_(self.w_mu, 0., 0.02)
        nn.init.normal_(self.w_sigma, 0., 0.02)

    def forward(self, x, y=None):
        """Return the (bx, by) kernel matrix between batches ``x`` and ``y``.

        Args:
            x: tensor unpacked as (bx, c, h, w).
            y: tensor unpacked as (by, _, _, _); defaults to ``x``.  The views
                below reuse ``h``, ``w`` and ``c`` from ``x``, so ``y`` must
                have matching non-batch dimensions.
        """
        if y is None:
            y = x
        bx, c, h, w = x.shape
        by, _, _, _ = y.shape

        # Channels-last features, pixels flattened: (bx or by, h * w, 1, c).
        # The singleton axis broadcasts against the s_count axis of gs.
        fx = self.get_f(x).permute(0, 2, 3, 1).view(bx, h * w, 1, c)
        fy = self.get_f(y).permute(0, 2, 3, 1).view(by, h * w, 1, c)

        # Modulation for every (pixel, grid point) pair: (1, h * w, s_count, c).
        gs = self.get_g(self.s_set, h, w).view(1, h * w, self.s_count, c)

        # Per-grid-point quadratic form: (s_count, c, c).
        Ewwt = self.get_Ewwt()

        # Modulated, normalized features times the quadratic form (per grid point s).
        fsx = normalize(fx * gs)
        fsx_Ewwt = torch.einsum('xdsi, sij -> xdsj', fsx, Ewwt)
        

        fsy = normalize(fy * gs)
        fsy_Ewwt = torch.einsum('ydsi, sij -> ydsj', fsy, Ewwt)

        # Inner products under Ewwt, summed over pixels, grid points and channels.
        # XYT is (bx, by); the *_diag terms are self inner products, shape (b, 1).
        XYT = torch.einsum('xdsi, ydsi -> xy', fsx_Ewwt, fsy)
        XXT_diag = torch.einsum('xdsi, xdsi -> x', fsx_Ewwt, fsx).unsqueeze(1)
        YYT_diag = torch.einsum('ydsi, ydsi -> y', fsy_Ewwt, fsy).unsqueeze(1)
        # YYT = torch.einsum('xdsi, ydsi -> xy', fsy_Ewwt, fsy)
        # YYT_diag = YYT.diag().unsqueeze(1)

        # Disabled experiment: variable-bandwidth rbf kernel
        # sigma = 1
        # m = self.s_count
        # Y_distance = YYT_diag - 2 * YYT + YYT_diag.t()
        # Y_distance_exp = (-Y_distance / sigma).exp()
        # theta_Y = (Y_distance_exp.sum(-1) - Y_distance_exp.diag()) / (by * ((np.pi * sigma) ** (m / 2)))
        # theta_Y = theta_Y.unsqueeze(1)

        # X_distance = XXT_diag - 2 * XYT + YYT_diag.t()
        # X_distance_exp = (-X_distance / sigma).exp()
        # theta_X = (X_distance_exp.sum(-1)) / (by * ((np.pi * sigma) ** (m / 2)))
        # theta_X = theta_X.unsqueeze(1)

        # gamma = 1. / (sigma * (theta_X ** (-1/m)) * (theta_Y.t() ** (-1/m)))
        # KXY = torch.exp(-gamma * X_distance)

        # Mixed fixed-bandwidth rbf kernel.  Previously tried bandwidth lists are
        # kept commented out for reference; only the uncommented one is used.
        # sigma_list = [1, 1/2, 1/4, 1/8]
        # sigma_list = [1, np.sqrt(2), 2, 2 * np.sqrt(2), 4]
        # sigma_list = [1, 2, 4, 8, 16, 32, 64]
        # sigma_list = [1 / 2, 1 / np.sqrt(2), 1, np.sqrt(2), 2]
        # sigma_list = [1, 4, 16, 64, 256]
        # sigma_list = [1, np.sqrt(2) / 2, 1 / 2, np.sqrt(2) / 4, 1 / 4]
        # sigma_list = [1]
        # sigma_list = [np.exp(2), 1, np.exp(-2), np.exp(-4)]
        sigma_list = [1, np.exp(-2), np.exp(-3), np.exp(-4)]
        # sigma_list = [np.exp(-4)]
        # sigma_list = [1, 1 / 2, 1 / 4, 1 / 8, 1 / 16]

        # Squared distance under the quadratic form: <x,x> - 2<x,y> + <y,y>, (bx, by).
        exponent = XXT_diag - 2 * XYT + YYT_diag.t()

        # Average of exp(-d / (2 sigma^2)) over all bandwidths.
        KXY = 0.
        for sigma in sigma_list:
            gamma = 1.0 / (2 * sigma**2)
            KXY += torch.exp(-gamma * exponent)
        KXY /= len(sigma_list)

        # Disabled alternatives: plain (normalized) inner-product kernels.
        # KXY = XYT / (h * w * self.s_count)

        # KXY = torch.einsum('xdsi, ydsi -> xy', mid, psi_y) / (h * w * self.s_count)
        # KXY = mid.reshape(bx, -1).mm(psi_y.reshape(by, -1).transpose(-2, -1)) / (h * w * self.s_count)

        # fx = self.get_feature(x).view(bx, -1)
        # # fy = self.get_feature(y).view(by, -1)
        # fy = y.view(by, -1)

        # KXY = fx.mm(fy.transpose(-2, -1))
        # KXY = mix_rbf_kernel(fx, fy)[:bx, bx:]
        
        # KXY = KXY / KXY.sum(-1, keepdim=True)

        return KXY
    
    def get_f(self, x):
        """Return the feature map for ``x``; currently the identity."""
        # Currently assume x is already f(x)
        return x
    
    def get_g(self, s, h, w):
        """Evaluate ``self.proj`` at every (pixel location, grid point) pair.

        Args:
            s: grid points, viewed as (1, s_count, s_dim).
            h, w: spatial size; pixel locations are spaced evenly over [-1, 1].

        Returns:
            Tensor viewed as (h * w, s_count, -1); the last axis is whatever
            ``self.proj`` outputs.
        """
        # Pixel coordinates in [-1, 1]^2, row-major ('ij'): (h * w, 2).
        slice_h = np.linspace(-1, 1, h)
        slice_w = np.linspace(-1, 1, w)
        loc = np.stack(np.meshgrid(slice_h, slice_w, indexing='ij'), axis=-1).reshape(h * w, 2)

        # Pair every location with every grid point, then flatten the pairs
        # into rows so the projection sees a 2-D (rows, features) input.
        loc_t = torch.Tensor(loc).to(s.device).view(h * w, 1, 2).expand(h * w, self.s_count, 2).reshape(h * w * self.s_count, 2)
        s_t = s.view(1, self.s_count, self.s_dim).expand(h * w, self.s_count, self.s_dim).reshape(h * w * self.s_count, self.s_dim)

        # Each row is (loc_h, loc_w, s_1, ..., s_{s_dim}), i.e. s_dim + 2 features.
        s_w_loc = torch.cat((loc_t, s_t), dim=-1)
        g_s = self.proj(s_w_loc)

        return g_s.view(h * w, self.s_count, -1)
    
    def get_Ewwt(self):
        """Return the per-grid-point quadratic form, shape (s_count, c, c).

        Computed as outer(mu, mu) + diag(sigma) with ``mu = normalize(w_mu)``;
        presumably the second moment E[w w^T] of a weight with mean ``mu`` and
        diagonal covariance ``sigma`` -- confirm against the method's derivation.
        """
        mu_normed = normalize(self.w_mu)
        # Batched outer product: (s_count, c, 1) x (s_count, 1, c) -> (s_count, c, c).
        mu_mut = mu_normed.unsqueeze(2).bmm(mu_normed.unsqueeze(1))
        # Diagonal entries are sigmoid(w_sigma + 2) + 0.1, so they lie in (0.1, 1.1).
        sigma = torch.diag_embed(torch.sigmoid(self.w_sigma + 2.) + 0.1)

        Ewwt = (mu_mut + sigma).view(self.s_count, self.in_channels, self.in_channels)

        # Disabled alternative for a single shared mean/variance vector.
        # mu_mut = self.w_mu.unsqueeze(1).mm(self.w_mu.unsqueeze(0))
        # sigma = torch.diag_embed(self.w_sigma)

        # Ewwt = (mu_mut + sigma).view(self.in_channels, self.in_channels)
        return Ewwt

    def get_feature(self, x):
        """Project ``x`` onto the normalized mean weights, per grid point.

        Args:
            x: tensor unpacked as (bx, c, h, w).

        Returns:
            Contiguous tensor of shape (bx, s_count, h, w).
        """
        bx, c, h, w = x.shape
        mu_normed = normalize(self.w_mu)
        w_mean_reshaped = mu_normed.view(1, 1, 1, self.s_count, c)

        # fx: (bx, h, w, 1, c) and gx: (1, h, w, s_count, c) broadcast together.
        fx = self.get_f(x).permute(0, 2, 3, 1).view(bx, h, w, 1, c)
        gx = self.get_g(self.s_set, h, w).view(1, h, w, self.s_count, c)

        # Inner product over the channel axis with the mean weight of each grid point.
        feature = (normalize(fx * gx) * w_mean_reshaped).sum(dim=-1)

        # Move the grid-point axis into the channel position.
        feature = feature.view(bx, h, w, self.s_count).permute(0, 3, 1, 2).contiguous()
        return feature
    
    def diag(self, x):
        """Return a vector of ones with one entry per sample in ``x``.

        In ``forward`` the squared distance of a sample to itself cancels to
        zero, so every Gaussian term in the mixture equals 1 there.
        """
        return torch.ones(x.size(0),).to(x.device)

class NeuralKernel(nn.Module):
    """Infinite-width fully connected network kernel computed by neural_tangents.

    The architecture is ``n_layers`` dense layers with an activation between
    consecutive layers.  Only the closed-form kernel function returned by
    ``stax.serial`` is kept; no finite network is ever instantiated.

    The kernel is evaluated in JAX on detached CPU copies of the inputs, so
    no gradient flows back to the PyTorch tensors.
    """

    def __init__(self, in_channels, enc_channels, out_channels=None, n_layers=2, normalize=False, kernel_type='ntk', activation='erf'):
        """
        Args:
            in_channels: input feature size; only recorded as ``feature_depth``.
            enc_channels: width passed to the hidden ``stax.Dense`` layers.
            out_channels: width of the last dense layer (default: ``enc_channels``).
            n_layers: total number of dense layers.
            normalize: if true, inputs go through ``normalize`` before the kernel.
            kernel_type: ``'ntk'`` or ``'nngp'``; forwarded to the kernel function.
            activation: ``'erf'`` or ``'relu'``.
        """
        super(NeuralKernel, self).__init__()

        assert kernel_type in ['ntk', 'nngp']
        assert activation in ['erf', 'relu']
        act_dict = {'erf': stax.Erf(), 'relu': stax.Relu()}
        if out_channels is None:
            out_channels = enc_channels
        layers = []
        for _ in range(n_layers - 1):
            layers.append(stax.Dense(enc_channels))
            layers.append(act_dict[activation])
        layers.append(stax.Dense(out_channels))
        # stax.serial returns a tuple whose last element is the kernel function.
        self.kernel_fn = stax.serial(*layers)[-1]
        self.normalize = normalize
        self.feature_depth = in_channels
        self.kernel_type = kernel_type

    def forward(self, x, y=None):
        """Return the kernel matrix between ``x`` and ``y`` (``y`` defaults to ``x``)."""
        if y is None:
            y = x

        x = self.get_feature(x)
        y = self.get_feature(y)

        return self._get_kernel(x, y)

    # Not differentiable for the moment
    def _get_kernel(self, x, y):
        """Evaluate ``kernel_fn`` on float32 JAX copies of ``x`` and ``y``.

        The result is copied back into a new tensor on ``x``'s device.
        """
        # convert into jax array
        x_jnp = jnp.asarray(x.detach().cpu().numpy().astype(np.float32))
        y_jnp = jnp.asarray(y.detach().cpu().numpy().astype(np.float32))

        K = self.kernel_fn(x_jnp, y_jnp, self.kernel_type)
        # convert back to pytorch
        return torch.tensor(np.copy(K), device=x.device)

    def get_feature(self, x):
        """Return ``x``, passed through ``normalize`` when that option is on."""
        if self.normalize:
            return normalize(x)
        return x

    def diag(self, x, batch_size=5000):
        """Return the diagonal of the kernel matrix of ``x`` with itself.

        The diagonal is assembled from square blocks of at most ``batch_size``
        rows so the full n x n matrix is never materialized.  Like
        ``_get_kernel``, this works on ``x`` as given (``get_feature`` is not
        applied here).

        Args:
            x: tensor whose first dimension indexes samples.
            batch_size: maximum number of rows per block; must be positive.

        Returns:
            1-D tensor with one entry per sample; empty when ``x`` is empty.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        n = x.size(0)
        if n == 0:
            # The chunk arithmetic below would divide by zero on an empty batch.
            return torch.zeros(0, dtype=torch.float32, device=x.device)
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        batch_size = min(n, batch_size)

        blocks = []
        for start in range(0, n, batch_size):
            x_batch = x[start:start + batch_size]
            blocks.append(self._get_kernel(x_batch, x_batch).diag())
        return torch.cat(blocks, dim=0)


class RandFeatureKernel(nn.Module):
    """Kernel applied to batch-centred features of an MLP projection.

    Inputs are mapped through ``rand_proj``, each batch is centred by its own
    mean over the batch dimension, and the result is handed to the wrapped
    kernel object.
    """

    def __init__(self, in_channels, enc_channels, out_channels=None, n_layers=2, activation=nn.Tanh, kernel_fn=lambda c: PolynomialKernel(c, deg=1, c=0)):
        """
        Args:
            in_channels: input feature size of the projection.
            enc_channels: hidden width of the projection; also the size passed
                to ``kernel_fn``.
            out_channels: output width of the projection (default: ``enc_channels``).
            n_layers: number of layers in the projection MLP.
            activation: activation class forwarded to ``MLP``.
            kernel_fn: factory taking a channel count and returning the kernel
                module applied to the projected features.
        """
        super(RandFeatureKernel, self).__init__()
        if out_channels is None:
            out_channels = enc_channels
        widths = (in_channels, *([enc_channels] * (n_layers - 1)), out_channels)
        self.rand_proj = MLP(widths, activation=activation, batchnorm=False, zero_output=False)
        self.kernel_obj = kernel_fn(enc_channels)
        self.feature_depth = self.kernel_obj.feature_depth

    def forward(self, x, y=None, requires_grad=False):
        """Return the kernel matrix between projected, centred ``x`` and ``y``.

        Args:
            x: tensor whose first dimension indexes samples.
            y: second batch; defaults to ``x``.
            requires_grad: accepted for interface compatibility; not used.
        """
        if y is None:
            y = x
        bx = x.size(0)
        by = y.size(0)

        x = self.rand_proj(self.get_feature(x))
        y = self.rand_proj(self.get_feature(y))

        # Centre out-of-place: an in-place subtraction would overwrite the
        # projection output itself, which corrupts the caller's tensor if the
        # projection returns its input and can invalidate values autograd
        # saved for the backward pass.
        x = x - x.mean(dim=0, keepdim=True)
        y = y - y.mean(dim=0, keepdim=True)

        return self.kernel_obj(x.view(bx, -1), y.view(by, -1))

    def get_feature(self, x):
        """Return ``x`` unchanged (no feature preprocessing)."""
        return x

    def diag(self, x):
        """Return the wrapped kernel's diagonal for the projected ``x``.

        NOTE(review): unlike ``forward``, the projection is not mean-centred
        here -- confirm whether callers expect this to match
        ``forward(x).diag()``.
        """
        return self.kernel_obj.diag(self.rand_proj(self.get_feature(x)))
    
class MixRBFKernel(nn.Module):
    """Equal-weight mixture of Gaussian (RBF) kernels over several bandwidths."""

    def __init__(self, in_channels, sigma=(1, np.sqrt(2), 2, 2 * np.sqrt(2), 4), normalize=False):
        """
        Args:
            in_channels: input feature size; only recorded as ``feature_depth``.
            sigma: iterable of bandwidths, one Gaussian kernel per entry.
            normalize: if true, inputs go through ``normalize`` before the kernel.
        """
        super(MixRBFKernel, self).__init__()
        self.feature_depth = in_channels
        # Copy so instances never share (or mutate) the default or the
        # caller's sequence.
        self.sigma = list(sigma)
        self.normalize = normalize

    def forward(self, x, y=None):
        """Return the kernel matrix between ``x`` and ``y`` (``y`` defaults to ``x``).

        Both inputs are flattened to (batch, -1) before the kernel is applied.
        """
        if y is None:
            y = x
        bx = x.size(0)
        by = y.size(0)

        x = self.get_feature(x)
        y = self.get_feature(y)

        return mix_rbf_kernel(x.view(bx, -1), y.view(by, -1), self.sigma)

    def diag(self, x):
        """Return ones, one per sample: an RBF kernel of a point with itself is 1."""
        b = x.size(0)
        return torch.ones(b,).to(x.device)

    def get_feature(self, x):
        """Return ``x``, passed through ``normalize`` when that option is on."""
        if self.normalize:
            return normalize(x)
        return x

class ExpKernel(nn.Module):
    """Exponential dot-product kernel ``exp(<x, y> / tau)``."""

    def __init__(self, in_channels, tau=1., normalize=False):
        """
        Args:
            in_channels: input feature size; only recorded as ``feature_depth``.
            tau: temperature dividing the inner product.
            normalize: if true, inputs go through ``normalize`` before the kernel.
        """
        super(ExpKernel, self).__init__()
        self.feature_depth = in_channels
        # Honour the caller's temperature (it was previously hard-coded to 1).
        self.tau = tau
        self.normalize = normalize

    def forward(self, x, y=None):
        """Return the kernel matrix between ``x`` and ``y`` (``y`` defaults to ``x``).

        Both inputs are flattened to (batch, -1) before the kernel is applied.
        """
        if y is None:
            y = x
        bx = x.size(0)
        by = y.size(0)

        x = self.get_feature(x)
        y = self.get_feature(y)

        return exp_kernel(x.view(bx, -1), y.view(by, -1), self.tau)

    def diag(self, x):
        """Return ``exp(||x_i||^2 / tau)`` for every sample of ``x``.

        The same feature preprocessing as ``forward`` is applied so the
        result matches the diagonal of ``forward(x)``.
        """
        b = x.size(0)
        x = self.get_feature(x)
        return torch.exp(x.view(b, -1).pow(2).sum(-1) / self.tau)

    def get_feature(self, x):
        """Return ``x``, passed through ``normalize`` when that option is on."""
        if self.normalize:
            return normalize(x)
        return x
    
class PolynomialKernel(nn.Module):
    """Polynomial kernel ``(<x, y> + c) ** deg``."""

    def __init__(self, in_channels, deg=2, c=1, normalize=False):
        """
        Args:
            in_channels: input feature size; only recorded as ``feature_depth``.
            deg: exponent of the polynomial.
            c: additive offset applied to the inner product.
            normalize: if true, inputs go through ``normalize`` before the kernel.
        """
        super(PolynomialKernel, self).__init__()
        self.deg = deg
        self.c = c
        self.feature_depth = in_channels
        self.normalize = normalize

    def forward(self, x, y=None):
        """Return the kernel matrix between ``x`` and ``y`` (``y`` defaults to ``x``).

        Both inputs are flattened to (batch, -1) before the kernel is applied.
        """
        if y is None:
            y = x
        bx = x.size(0)
        by = y.size(0)

        x = self.get_feature(x)
        y = self.get_feature(y)

        return polynomial_kernel(x.view(bx, -1), y.view(by, -1), deg=self.deg, c=self.c)

    def diag(self, x):
        """Return ``(||x_i||^2 + c) ** deg`` for every sample of ``x``.

        Provided so this kernel offers the same ``diag`` interface as the
        other kernel modules in this file (``RandFeatureKernel.diag`` calls
        it on its wrapped kernel).
        """
        b = x.size(0)
        x = self.get_feature(x)
        sq_norm = x.view(b, -1).pow(2).sum(-1)
        return torch.pow(sq_norm + self.c, self.deg)

    def get_feature(self, x):
        """Return ``x``, passed through ``normalize`` when that option is on."""
        if self.normalize:
            return normalize(x)
        return x

class ArccosKernel(nn.Module):
    """Arc-cosine kernel of a given degree, optionally composed over layers."""

    def __init__(self, in_channels, layers=1, deg=0, normalize=False):
        """
        Args:
            in_channels: input feature size; only recorded as ``feature_depth``.
            layers: number of kernel compositions.
            deg: degree of the arc-cosine kernel (0, 1 or 2).
            normalize: if true, inputs go through ``normalize`` first and the
                kernel is told to treat its inputs as normalized.
        """
        super(ArccosKernel, self).__init__()
        self.layers = layers
        self.deg = deg
        self.feature_depth = in_channels
        self.normalize = normalize

    def forward(self, x, y=None):
        """Return the kernel matrix between ``x`` and ``y`` (``y`` defaults to ``x``).

        Both inputs are flattened to (batch, -1) before the kernel is applied.
        """
        if y is None:
            y = x

        bx = x.size(0)
        by = y.size(0)

        x = self.get_feature(x)
        y = self.get_feature(y)

        return arccos_kernel(x.view(bx, -1), y.view(by, -1), self.layers, self.deg, self.normalize)

    def diag(self, x):
        """Return the kernel of every sample with itself.

        Degree 0 gives ones; degree 1 gives the squared norm of each
        (preprocessed) sample.

        Raises:
            NotImplementedError: for any other degree.  Previously this case
                fell through and silently returned ``None``.
        """
        b = x.size(0)
        x = self.get_feature(x)
        if self.deg == 0:
            return torch.ones(b,).to(x.device)
        if self.deg == 1:
            return x.pow(2).view(b, -1).sum(-1)
        raise NotImplementedError(
            "ArccosKernel.diag is only implemented for deg 0 and 1, got deg={}".format(self.deg))

    def get_feature(self, x):
        """Return ``x``, passed through ``normalize`` when that option is on."""
        if self.normalize:
            return normalize(x)
        return x

def mix_rbf_kernel(X, Y, sigma_list=(1, np.sqrt(2), 2, 2 * np.sqrt(2), 4)):
    """Return the mean of Gaussian kernels between the rows of ``X`` and ``Y``.

    Args:
        X: 2-D tensor (m, d).
        Y: 2-D tensor (n, d).
        sigma_list: bandwidths; one kernel ``exp(-||x - y||^2 / (2 sigma^2))``
            is computed per entry and the kernels are averaged.

    Returns:
        Tensor of shape (m, n).
    """
    # Pairwise squared distances via ||x||^2 - 2 <x, y> + ||y||^2.
    XYT = torch.mm(X, Y.t())
    X_norm_sqr = (X * X).sum(-1).unsqueeze(1)
    Y_norm_sqr = (Y * Y).sum(-1).unsqueeze(1)
    exponent = X_norm_sqr - 2 * XYT + Y_norm_sqr.t()

    K = 0.
    for sigma in sigma_list:
        gamma = 1.0 / (2 * sigma**2)
        K = K + torch.exp(-gamma * exponent)
    return K / len(sigma_list)

def mix_rbf_kernel_full(X, Y, sigma_list=(1, np.sqrt(2), 2, 2 * np.sqrt(2), 4)):
    """Return the three blocks of the mixed-RBF kernel of ``X`` stacked on ``Y``.

    Args:
        X: 2-D tensor (m, d).
        Y: 2-D tensor (n, d).
        sigma_list: bandwidths; the Gaussian kernels are averaged.

    Returns:
        Tuple ``(K_XX, K_XY, K_YY)`` with shapes (m, m), (m, n) and (n, n).
    """
    m = X.size(0)

    # Build one Gram matrix for the stacked samples and slice it at the end.
    Z = torch.cat((X, Y), 0)
    ZZT = torch.mm(Z, Z.t())
    diag_ZZT = torch.diag(ZZT).unsqueeze(1)
    Z_norm_sqr = diag_ZZT.expand_as(ZZT)
    exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t()

    K = 0.
    for sigma in sigma_list:
        gamma = 1.0 / (2 * sigma**2)
        K = K + torch.exp(-gamma * exponent)
    K = K / len(sigma_list)

    return K[:m, :m], K[:m, m:], K[m:, m:]

def dot_prod_kernel(X, Y):
    """Return the linear kernel between the rows of ``X`` and ``Y``."""
    # Expressed as a degree-1 polynomial kernel with zero offset.
    return polynomial_kernel(X, Y, c=0, deg=1)

def polynomial_kernel(X, Y, deg=1, c=1):
    """Return ``(X Y^T + c) ** deg`` for 2-D tensors ``X`` (m, d) and ``Y`` (n, d)."""
    gram = X.mm(Y.t())
    return (gram + c).pow(deg)

def polynomial_kernel_full(X, Y, scale=1., deg=1, c=1):
    """Return the polynomial-kernel blocks ``(K_XX, K_XY, K_YY)``.

    The kernel ``(scale * <z, z'> + c) ** deg`` is computed once on ``X``
    stacked on top of ``Y`` and then sliced into its three blocks.
    """
    n_x = X.size(0)
    stacked = torch.cat((X, Y), 0)
    gram = stacked.mm(stacked.t())
    full = (scale * gram + c).pow(deg)
    k_xx = full[:n_x, :n_x]
    k_xy = full[:n_x, n_x:]
    k_yy = full[n_x:, n_x:]
    return k_xx, k_xy, k_yy

def cosine_sim_kernel(X, Y):
    """Return pairwise cosine similarities of the rows of ``X`` stacked on ``Y``.

    Note that the result is the full (m + n, m + n) matrix over the stacked
    samples, not just the (m, n) cross block.
    """
    stacked = torch.cat((X, Y), dim=0)
    gram = stacked.mm(stacked.t())
    # Row norms are the square roots of the Gram matrix diagonal.
    norms = torch.diag(gram).sqrt()
    denom = norms.unsqueeze(1) * norms.unsqueeze(0)
    return gram / denom

def exp_kernel(X, Y, tau=1.):
    """Return ``exp(X Y^T / tau)`` for 2-D tensors ``X`` (m, d) and ``Y`` (n, d)."""
    scaled_gram = X.mm(Y.t()) / tau
    return scaled_gram.exp()

def arccos_kernel(X, Y, layers=1, deg=1, normalized=False):
    """Return the arc-cosine kernel between the rows of ``X`` and ``Y``.

    Args:
        X: 2-D tensor (m, d).
        Y: 2-D tensor (n, d).
        layers: number of kernel compositions.
        deg: kernel degree, between 0 and 2; selects the angular function.
        normalized: forwarded to the recursion, which then uses 1 in place
            of the product of norms.
    """
    assert 0 <= deg <= 2

    def _angular_deg0(theta):
        return 1. - theta / np.pi

    def _angular_deg1(theta):
        return (torch.sin(theta) + (np.pi - theta) * torch.cos(theta)) / np.pi

    def _angular_deg2(theta):
        return (3. * torch.sin(theta) * torch.cos(theta)
                + (np.pi - theta) * (1 + 2. * (torch.cos(theta).pow(2)))) / np.pi

    if deg == 0:
        J = _angular_deg0
    elif deg == 1:
        J = _angular_deg1
    else:
        J = _angular_deg2

    return _arccos_kernel_recursive(X, Y, layers, deg, J, normalized=normalized)

def _arccos_kernel_recursive(X, Y, l, deg, J, normalized=False):
    """Compute the ``l``-layer arc-cosine kernel between rows of ``X`` and ``Y``.

    Layer 1 works on the raw inner products and norms of the inputs; each
    further layer treats the previous layer's kernel as the inner product and
    the square roots of its self-kernels as the norms.

    Args:
        X: 2-D tensor (m, d).
        Y: 2-D tensor (n, d).
        l: number of layers; must be at least 1.
        deg: kernel degree, used as the exponent of the norm product.
        J: angular function applied to the angle between samples.
        normalized: if true, the norm product at this layer is taken to be 1.

    Returns:
        Tensor of shape (m, n).

    Raises:
        ValueError: if ``l`` is smaller than 1.
    """
    if l < 1:
        # Without this guard neither branch below assigns xty / nxny and the
        # function fails with an unhelpful UnboundLocalError.
        raise ValueError("number of layers must be at least 1, got {}".format(l))

    if l == 1:
        xty = torch.mm(X, Y.t())
        if normalized:
            nxny = 1.
        else:
            nxny = torch.norm(X, p=2, dim=-1).unsqueeze(1) * torch.norm(Y, p=2, dim=-1).unsqueeze(0)
    else:
        # The recursive calls deliberately keep the default normalized=False,
        # exactly as before: only the outermost layer uses the shortcut.
        xty = _arccos_kernel_recursive(X, Y, l - 1, deg, J)
        if normalized:
            nxny = 1.
        else:
            k_xx = _arccos_kernel_recursive(X, X, l - 1, deg, J).diag().unsqueeze(1)
            k_yy = _arccos_kernel_recursive(Y, Y, l - 1, deg, J).diag().unsqueeze(0)
            nxny = torch.sqrt(k_xx * k_yy)

    # Clamp strictly inside [-1, 1] so acos stays finite with finite gradient.
    a = torch.acos(torch.clamp(xty / nxny, min=-1. + 1e-7, max=1. - 1e-7))
    return (nxny ** deg) * J(a)
