import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_normal_, kaiming_normal_, kaiming_uniform_
from torch import Tensor
from typing import List, Tuple
import math
from functools import partial
from scipy.special import eval_legendre
from sympy import Poly, legendre, Symbol, chebyshevt

'''
Modified from the official implementation of the multiwavelet-based
operator (MWT) model:
https://arxiv.org/pdf/2109.13459.pdf
'''

def get_initializer(name='xavier_normal'):
    """Return the weight initializer registered under ``name``.

    Args:
        name: One of ``'xavier_normal'``, ``'kaiming_uniform'`` or
            ``'kaiming_normal'``.

    Returns:
        A ``functools.partial`` wrapping the matching in-place initializer
        from ``torch.nn.init``.

    Raises:
        ValueError: If ``name`` is not one of the supported keys.
    """
    initializers = {
        'xavier_normal': xavier_normal_,
        'kaiming_uniform': kaiming_uniform_,
        'kaiming_normal': kaiming_normal_,
    }
    if name not in initializers:
        # Previously an unknown name fell through every branch and surfaced
        # as an UnboundLocalError on the return statement.
        raise ValueError(
            "Unknown initializer {!r}; expected one of {}".format(
                name, sorted(initializers)))
    return partial(initializers[name])

def legendreDer(k, x):
    """Evaluate the derivative of the degree-``k`` Legendre polynomial at ``x``.

    Uses the expansion ``P_k'(x) = sum_i (2*i + 1) * P_i(x)`` taken over
    ``i = k-1, k-3, ...`` down to 0 or 1.

    Args:
        k: Polynomial degree.
        x: Scalar or array of evaluation points.

    Returns:
        The accumulated sum; the integer ``0`` when no term is summed.
    """
    orders = np.arange(k-1, -1, -2)
    return sum((2*order + 1) * eval_legendre(order, x) for order in orders)


def phi_(phi_c, x, lb=0, ub=1):
    """Evaluate a polynomial restricted to the closed interval ``[lb, ub]``.

    Args:
        phi_c: Polynomial coefficients in ascending order of degree.
        x: Scalar or array of evaluation points.
        lb: Lower bound of the support (inclusive).
        ub: Upper bound of the support (inclusive).

    Returns:
        The polynomial value at each point, multiplied by zero wherever
        ``x`` lies outside ``[lb, ub]``.
    """
    outside = np.logical_or(x < lb, x > ub)
    values = np.polynomial.polynomial.Polynomial(phi_c)(x)
    return values * np.where(outside, 0.0, 1.0)


def get_phi_psi(k, base):
    """Build the scaling functions and the two halves of the wavelet functions.

    For each of the ``k`` basis indices the routine starts from the rescaled
    scaling function ``sqrt(2) * phi(2x)``, removes its projections onto the
    ``k`` scaling functions and onto the previously built wavelets, and
    normalises the result (a Gram-Schmidt style procedure).

    Args:
        k: Number of basis functions.
        base: ``'legendre'`` or ``'chebyshev'``.

    Returns:
        A tuple ``(phi, psi1, psi2)`` of three length-``k`` lists of callables.
        ``psi1[i]`` is the piece of wavelet ``i`` used on ``[0, 0.5]`` and
        ``psi2[i]`` the piece used on ``[0.5, 1]``. For ``'legendre'`` the
        entries are ``np.poly1d`` objects; for ``'chebyshev'`` they are
        ``functools.partial`` wrappers around ``phi_`` that evaluate to zero
        outside their interval.

    Raises:
        ValueError: If ``base`` is not a supported basis name.
    """
    if base not in ('legendre', 'chebyshev'):
        # Without this guard an unknown base skipped both branches and failed
        # on the return statement with an UnboundLocalError.
        raise ValueError(
            "Unsupported base {!r}; expected 'legendre' or 'chebyshev'".format(base))

    x = Symbol('x')
    # Row i holds the ascending-order polynomial coefficients of basis i.
    phi_coeff = np.zeros((k, k))
    phi_2x_coeff = np.zeros((k, k))
    if base == 'legendre':
        for ki in range(k):
            # Legendre polynomial shifted to [0, 1], scaled by sqrt(2*ki + 1).
            coeff_ = Poly(legendre(ki, 2*x-1), x).all_coeffs()
            phi_coeff[ki, :ki+1] = np.flip(np.sqrt(2*ki+1)
                                           * np.array(coeff_).astype(np.float64))
            # Same polynomial with argument 2x, with an extra sqrt(2) factor.
            coeff_ = Poly(legendre(ki, 4*x-1), x).all_coeffs()
            phi_2x_coeff[ki, :ki+1] = np.flip(np.sqrt(2) * np.sqrt(
                2*ki+1) * np.array(coeff_).astype(np.float64))

        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))
        for ki in range(k):
            psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
            for i in range(k):
                a = phi_2x_coeff[ki, :ki+1]
                b = phi_coeff[i, :i+1]
                # Coefficients of the product polynomial a(x) * b(x).
                prod_ = np.convolve(a, b)
                prod_[np.abs(prod_) < 1e-8] = 0
                # Closed-form integral of the product over [0, 0.5]:
                # sum_n c_n * 0.5**(n+1) / (n+1).
                proj_ = (prod_ * 1/(np.arange(len(prod_))+1) *
                         np.power(0.5, 1+np.arange(len(prod_)))).sum()
                psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
                psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]
            for j in range(ki):
                # Remove the components along the wavelets built so far.
                a = phi_2x_coeff[ki, :ki+1]
                b = psi1_coeff[j, :]
                prod_ = np.convolve(a, b)
                prod_[np.abs(prod_) < 1e-8] = 0
                proj_ = (prod_ * 1/(np.arange(len(prod_))+1) *
                         np.power(0.5, 1+np.arange(len(prod_)))).sum()
                psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
                psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]

            # Squared norm of the first piece, integrated over [0, 0.5].
            a = psi1_coeff[ki, :]
            prod_ = np.convolve(a, a)
            prod_[np.abs(prod_) < 1e-8] = 0
            norm1 = (prod_ * 1/(np.arange(len(prod_))+1) *
                     np.power(0.5, 1+np.arange(len(prod_)))).sum()

            # Squared norm of the second piece, integrated over [0.5, 1].
            a = psi2_coeff[ki, :]
            prod_ = np.convolve(a, a)
            prod_[np.abs(prod_) < 1e-8] = 0
            norm2 = (prod_ * 1/(np.arange(len(prod_))+1) *
                     (1-np.power(0.5, 1+np.arange(len(prod_))))).sum()
            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki, :] /= norm_
            psi2_coeff[ki, :] /= norm_
            # Note: the threshold is applied to the whole matrix each pass.
            psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0

        # np.poly1d expects descending-order coefficients, hence the flip.
        phi = [np.poly1d(np.flip(phi_coeff[i, :])) for i in range(k)]
        psi1 = [np.poly1d(np.flip(psi1_coeff[i, :])) for i in range(k)]
        psi2 = [np.poly1d(np.flip(psi2_coeff[i, :])) for i in range(k)]

    elif base == 'chebyshev':
        for ki in range(k):
            if ki == 0:
                phi_coeff[ki, :ki+1] = np.sqrt(2/np.pi)
                phi_2x_coeff[ki, :ki+1] = np.sqrt(2/np.pi) * np.sqrt(2)
            else:
                coeff_ = Poly(chebyshevt(ki, 2*x-1), x).all_coeffs()
                phi_coeff[ki, :ki+1] = np.flip(2/np.sqrt(np.pi)
                                               * np.array(coeff_).astype(np.float64))
                coeff_ = Poly(chebyshevt(ki, 4*x-1), x).all_coeffs()
                phi_2x_coeff[ki, :ki+1] = np.flip(np.sqrt(2) * 2 / np.sqrt(
                    np.pi) * np.array(coeff_).astype(np.float64))

        # Scaling functions restricted to [0, 1] (zero outside).
        phi = [partial(phi_, phi_coeff[i, :]) for i in range(k)]

        x = Symbol('x')
        kUse = 2*k
        # Quadrature nodes: roots of the degree-2k Chebyshev polynomial mapped
        # to [0, 1], all sharing the weight wm (Gauss-Chebyshev style rule).
        roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1)
        # not needed for our purpose here, we use even k always to avoid
        wm = np.pi / kUse / 2

        psi1_coeff = np.zeros((k, k))
        psi2_coeff = np.zeros((k, k))

        psi1 = [[] for _ in range(k)]
        psi2 = [[] for _ in range(k)]

        for ki in range(k):
            psi1_coeff[ki, :] = phi_2x_coeff[ki, :]
            for i in range(k):
                # Projections are approximated by the quadrature sum.
                proj_ = (wm * phi[i](x_m) * np.sqrt(2) * phi[ki](2*x_m)).sum()
                psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :]
                psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :]

            for j in range(ki):
                proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2*x_m)).sum()
                psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :]
                psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :]

            # Temporary callables, used only to measure the norm below.
            psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5)
            psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5, ub=1)

            norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum()
            norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum()

            norm_ = np.sqrt(norm1 + norm2)
            psi1_coeff[ki, :] /= norm_
            psi2_coeff[ki, :] /= norm_
            psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0
            psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0

            # Final callables built from the normalised coefficients.
            psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5+1e-16)
            psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5+1e-16, ub=1)

    return phi, psi1, psi2


def get_filter(base, k):
    """Compute the multiwavelet filter matrices for the chosen basis.

    Every entry is a quadrature sum over nodes ``x_m`` with weights ``wm``.
    ``H0``/``H1`` pair the scaling functions evaluated at ``x/2`` and
    ``(x+1)/2`` with the scaling functions at ``x``; ``G0``/``G1`` do the same
    with the piecewise wavelet functions.

    Args:
        base: ``'legendre'`` or ``'chebyshev'``.
        k: Number of basis functions; every returned matrix is ``(k, k)``.

    Returns:
        ``(H0, H1, G0, G1, PHI0, PHI1)`` as NumPy arrays. ``PHI0`` and ``PHI1``
        are identity matrices for the Legendre basis and quadrature-computed
        matrices for the Chebyshev basis. Entries with magnitude below
        ``1e-8`` are set to zero (the Legendre identities are left untouched).

    Raises:
        ValueError: If ``base`` is not supported. ``ValueError`` derives from
            ``Exception``, so handlers written for the previous generic
            ``Exception`` keep working.
    """
    if base not in ('legendre', 'chebyshev'):
        raise ValueError('Base not supported')

    def psi(psi1, psi2, i, inp):
        # Piecewise wavelet: psi1 on [0, 0.5], psi2 on (0.5, 1].
        mask = (inp <= 0.5) * 1.0
        return psi1[i](inp) * mask + psi2[i](inp) * (1-mask)

    x = Symbol('x')
    phi, psi1, psi2 = get_phi_psi(k, base)

    # Quadrature nodes and weights are the only basis-specific inputs of the
    # H/G computation below.
    if base == 'legendre':
        roots = Poly(legendre(k, 2*x-1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        wm = 1/k/legendreDer(k, 2*x_m-1)/eval_legendre(k-1, 2*x_m-1)
    else:
        kUse = 2*k
        roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots()
        x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64)
        # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1)
        # not needed for our purpose here, we use even k always to avoid
        wm = np.pi / kUse / 2

    H0 = np.zeros((k, k))
    H1 = np.zeros((k, k))
    G0 = np.zeros((k, k))
    G1 = np.zeros((k, k))

    inv_sqrt2 = 1 / np.sqrt(2)
    # phi[kpi](x_m) does not depend on ki, so evaluate it once per column.
    phi_at_nodes = [phi[kpi](x_m) for kpi in range(k)]
    for ki in range(k):
        # Weighted left-hand factors, shared by every column of row ki.
        h0_row = wm * phi[ki](x_m/2)
        g0_row = wm * psi(psi1, psi2, ki, x_m/2)
        h1_row = wm * phi[ki]((x_m+1)/2)
        g1_row = wm * psi(psi1, psi2, ki, (x_m+1)/2)
        for kpi in range(k):
            H0[ki, kpi] = inv_sqrt2 * (h0_row * phi_at_nodes[kpi]).sum()
            G0[ki, kpi] = inv_sqrt2 * (g0_row * phi_at_nodes[kpi]).sum()
            H1[ki, kpi] = inv_sqrt2 * (h1_row * phi_at_nodes[kpi]).sum()
            G1[ki, kpi] = inv_sqrt2 * (g1_row * phi_at_nodes[kpi]).sum()

    if base == 'legendre':
        PHI0 = np.eye(k)
        PHI1 = np.eye(k)
    else:
        PHI0 = np.zeros((k, k))
        PHI1 = np.zeros((k, k))
        phi_first_half = [phi[i](2*x_m) for i in range(k)]
        phi_second_half = [phi[i](2*x_m-1) for i in range(k)]
        for ki in range(k):
            for kpi in range(k):
                PHI0[ki, kpi] = (wm * phi_first_half[ki] *
                                 phi_first_half[kpi]).sum() * 2
                PHI1[ki, kpi] = (wm * phi_second_half[ki] *
                                 phi_second_half[kpi]).sum() * 2

        PHI0[np.abs(PHI0) < 1e-8] = 0
        PHI1[np.abs(PHI1) < 1e-8] = 0

    H0[np.abs(H0) < 1e-8] = 0
    H1[np.abs(H1) < 1e-8] = 0
    G0[np.abs(G0) < 1e-8] = 0
    G1[np.abs(G1) < 1e-8] = 0

    return H0, H1, G0, G1, PHI0, PHI1


class SpatialConv2d(nn.Module):
    """3x3 convolution + ReLU followed by a per-pixel linear layer.

    The last two input axes ``(c, k**2)`` are flattened into channels,
    convolved over the two spatial axes, projected back to ``c * k**2``
    features and restored to the input layout.

    Args:
        k: Block size; the layer works on ``k**2`` coefficients per channel.
        alpha: Width multiplier; the convolution emits ``alpha * k**2`` channels.
        c: Channel count of the input.
        nl: Accepted for interface compatibility; unused.
        initializer: Accepted for interface compatibility; unused.
    """

    def __init__(self,
                 k, alpha, c=1,
                 nl=1,
                 initializer=None,
                 **kwargs):
        super().__init__()

        self.k = k
        self.conv = self.convBlock(k, c*k**2, alpha)
        self.linear = nn.Linear(alpha*k**2, c*k**2)

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(B, Nx, Ny, c, k**2)``.

        Returns a tensor with the same shape as ``x``.
        """
        B, Nx, Ny, c, ich = x.shape  # (B, Nx, Ny, c, k**2)
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(B, Nx, Ny, -1)
        x = x.permute(0, 3, 1, 2)  # channels first for Conv2d
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1)  # channels last for Linear
        x = self.linear(x)
        x = x.reshape(B, Nx, Ny, c, ich)

        return x

    def convBlock(self, k, W, alpha):
        """Build the ``Conv2d(W -> alpha*k**2, kernel 3, stride 1, pad 1)`` + ReLU stack."""
        och = alpha * k**2
        net = nn.Sequential(
            nn.Conv2d(W, och, 3, 1, 1),
            nn.ReLU(inplace=True),
        )
        return net



class SpectralConv2d(nn.Module):
    """Fourier-domain channel mixing followed by ReLU and a linear layer.

    The input is transformed with ``rfft2``; the retained low-frequency modes
    are multiplied by learned complex weights, everything else is zeroed, and
    the result is transformed back with ``irfft2``.

    Args:
        k: Block size; the layer works on ``c * k**2`` channels.
        modes: Maximum number of Fourier modes kept along each spatial axis.
        c: Channel count of the input.
        nl: Accepted for interface compatibility; unused.
        initializer: Accepted for interface compatibility; unused.
    """

    def __init__(self,
                 k,
                 modes,
                 c=1,
                 nl=1,
                 initializer=None,
                 **kwargs):
        super().__init__()

        self.modes = modes

        # weights1 acts on the leading rows of the spectrum, weights2 on the
        # trailing rows (see forward).
        self.weights1 = nn.Parameter(
            torch.zeros(c*k**2, c*k**2,
                        self.modes, self.modes, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(
            torch.zeros(c*k**2, c*k**2,
                        self.modes, self.modes, dtype=torch.cfloat))
        nn.init.xavier_normal_(self.weights1)
        nn.init.xavier_normal_(self.weights2)

        self.linear = nn.Linear(c*k**2, c*k**2)
        self.k = k

    @staticmethod
    def complex_matmul2d(x, w):
        """Contract the channel axis independently for every Fourier mode."""
        # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y)
        return torch.einsum("bixy,ioxy->boxy", x, w)

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(B, Nx, Ny, c, k**2)``.

        Returns a tensor with the same shape as ``x``.
        """
        B, Nx, Ny, c, ich = x.shape  # (B, N, N, c, k^2)

        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(B, Nx, Ny, -1)
        x = x.permute(0, 3, 1, 2)
        x_fft = torch.fft.rfft2(x)

        # Multiply relevant Fourier modes; the mode count is capped by the
        # size of the spectrum so small grids still work.
        l1 = min(self.modes, Nx//2+1)
        l2 = min(self.modes, Ny//2+1)
        out_ft = torch.zeros(B, c*ich, Nx, Ny//2 + 1,
                             device=x.device, dtype=torch.cfloat)

        out_ft[:, :, :l1, :l2] = self.complex_matmul2d(
            x_fft[:, :, :l1, :l2], self.weights1[:, :, :l1, :l2])
        # When 2*l1 > Nx the two row ranges overlap and this second
        # assignment wins on the shared rows.
        out_ft[:, :, -l1:, :l2] = self.complex_matmul2d(
            x_fft[:, :, -l1:, :l2], self.weights2[:, :, :l1, :l2])

        # Return to physical space
        x = torch.fft.irfft2(out_ft, s=(Nx, Ny))

        x = x.permute(0, 2, 3, 1)
        x = F.relu(x)
        x = self.linear(x)
        x = x.reshape(B, Nx, Ny, c, ich)
        return x


class MWT_CZ2d(nn.Module):
    """One multiwavelet layer: decompose, transform each level, reconstruct.

    ``forward`` repeatedly halves the grid with ``wavelet_transform``, applies
    the sub-layers ``A``, ``B`` and ``C`` at every level, applies a linear map
    at the coarsest level and then rebuilds the original resolution with
    ``evenOdd``.

    Args:
        k: Block size; each channel carries ``k**2`` coefficients.
        modes: Number of Fourier modes used by the spectral sub-layer ``A``.
        coarsest: The decomposition stops at a ``2**coarsest`` square grid.
        c: Channel count.
        base: Basis passed to ``get_filter`` (``'legendre'`` or ``'chebyshev'``).
        initializer: Optional in-place initializer applied to ``linear.weight``.
        debug: Stored on the instance; not used by this class.
    """

    def __init__(self,
                 k=3,
                 modes=5,
                 coarsest=0,
                 c=1,
                 base='legendre',
                 initializer=None,
                 debug=False,
                 **kwargs):
        super().__init__()

        self.k = k
        self.coarsest = coarsest
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        # Reconstruction filters: decomposition filters combined with PHI.
        H0r = H0@PHI0
        G0r = G0@PHI0
        H1r = H1@PHI1
        G1r = G1@PHI1
        H0r[np.abs(H0r) < 1e-8] = 0
        H1r[np.abs(H1r) < 1e-8] = 0
        G0r[np.abs(G0r) < 1e-8] = 0
        G1r[np.abs(G1r) < 1e-8] = 0

        self.A = SpectralConv2d(k, modes, c)
        self.B = SpatialConv2d(k, c, c)
        self.C = SpatialConv2d(k, c, c)

        self.linear = nn.Linear(c*k**2, c*k**2) # T0

        if initializer is not None:
            self.reset_parameters(initializer)

        # Decomposition matrices, shape (4*k**2, k**2): one stacked block per
        # member of a 2x2 cell, in the order used by wavelet_transform.
        self.register_buffer('ec_s', torch.Tensor(
            np.concatenate((np.kron(H0, H0).T,
                            np.kron(H0, H1).T,
                            np.kron(H1, H0).T,
                            np.kron(H1, H1).T,
                            ), axis=0)))
        self.register_buffer('ec_d', torch.Tensor(
            np.concatenate((np.kron(G0, G0).T,
                            np.kron(G0, G1).T,
                            np.kron(G1, G0).T,
                            np.kron(G1, G1).T,
                            ), axis=0)))

        # Reconstruction matrices, shape (2*k**2, k**2): one per output
        # position of a 2x2 cell (even/odd row x even/odd column).
        self.register_buffer('rc_ee', torch.Tensor(
            np.concatenate((np.kron(H0r, H0r),
                            np.kron(G0r, G0r),
                            ), axis=0)))
        self.register_buffer('rc_eo', torch.Tensor(
            np.concatenate((np.kron(H0r, H1r),
                            np.kron(G0r, G1r),
                            ), axis=0)))
        self.register_buffer('rc_oe', torch.Tensor(
            np.concatenate((np.kron(H1r, H0r),
                            np.kron(G1r, G0r),
                            ), axis=0)))
        self.register_buffer('rc_oo', torch.Tensor(
            np.concatenate((np.kron(H1r, H1r),
                            np.kron(G1r, G1r),
                            ), axis=0)))

        self.debug = debug

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(B, Nx, Ny, c, k**2)``.

        The number of levels is derived from ``Nx`` only, and the coarsest
        level is viewed as a ``2**coarsest`` square grid, so the input is
        expected to be square with a power-of-two side.
        """
        bsz, Nx, Ny, c, k2 = x.shape  # (B, Nx, Ny, c, k**2)
        n_level = math.floor(np.log2(Nx))

        Ud = torch.jit.annotate(List[Tensor], [])
        Us = torch.jit.annotate(List[Tensor], [])

        # decompose: x becomes the coarser coefficients at every level
        for i in range(n_level-self.coarsest):
            d, x = self.wavelet_transform(x)
            Ad = self.A(d)
            Bx = self.B(x)
            Cd = self.C(d)
            Ud.append(Ad + Bx)
            Us.append(Cd)

        x = self.linear(x.view(bsz, 2**self.coarsest, 2**self.coarsest, -1)).view(
            bsz, 2**self.coarsest, 2**self.coarsest, c, k2)  # coarsest scale transform

        # reconstruct: walk the levels back from coarse to fine
        for i in range(n_level-1-self.coarsest, -1, -1):
            x = x + Us[i]

            x = torch.cat((x, Ud[i]), -1)
            x = self.evenOdd(x)

        return x

    def wavelet_transform(self, x):
        """Halve both spatial axes; return ``(d, s)`` from ``ec_d`` and ``ec_s``."""
        # Gather the four members of every 2x2 cell along the last axis.
        xa = torch.cat([x[:, ::2, ::2, :, :],
                        x[:, ::2, 1::2, :, :],
                        x[:, 1::2, ::2, :, :],
                        x[:, 1::2, 1::2, :, :]
                        ], -1)
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s

    def evenOdd(self, x):
        """Double both spatial axes of ``x`` (last axis of size ``2*k**2``)."""
        B, Nx, Ny, c, ich = x.shape  # (B, Nx, Ny, c, 2*k**2)
        assert ich == 2*self.k**2
        x_ee = torch.matmul(x, self.rc_ee)
        x_eo = torch.matmul(x, self.rc_eo)
        x_oe = torch.matmul(x, self.rc_oe)
        x_oo = torch.matmul(x, self.rc_oo)

        # new_zeros inherits dtype and device from x; a default-dtype buffer
        # would silently downcast e.g. float64 data to float32.
        out = x.new_zeros(B, Nx*2, Ny*2, c, self.k**2)
        out[:, ::2, ::2, :, :] = x_ee
        out[:, ::2, 1::2, :, :] = x_eo
        out[:, 1::2, ::2, :, :] = x_oe
        out[:, 1::2, 1::2, :, :] = x_oo
        return out

    def reset_parameters(self, initializer):
        """Re-initialise the coarsest-scale linear weight in place."""
        initializer(self.linear.weight)

class MWT2d(nn.Module):
    """Stack of ``MWT_CZ2d`` layers with input and output projections.

    Args:
        in_dim: Feature count of the concatenated inputs, excluding the
            coordinate features.
        k: Block size; each channel carries ``k**2`` coefficients.
        modes: Number of Fourier modes used inside each ``MWT_CZ2d``.
        c: Channel count (hidden width is ``c * k**2``).
        n_layers: Number of ``MWT_CZ2d`` layers.
        coarsest: Coarsest decomposition level passed to ``MWT_CZ2d``.
        spatial_dim: Number of coordinate features added to ``in_dim``.
        base: Basis name passed to ``MWT_CZ2d``.
        initializer: Optional in-place initializer for weight matrices.
        out_dim: Feature count of the prediction.
        return_latent: If true, ``forward`` also returns every layer output.
        input_size: Working resolution ``(Nx, Ny)`` the input is interpolated
            to; a falsy value keeps the resolution of the input.
        debug: Stored on the instance and forwarded to ``MWT_CZ2d``.
    """

    def __init__(self,
                 in_dim=1,
                 k=3,
                 modes=2,
                 c=1,
                 n_layers=3,
                 coarsest=0,
                 spatial_dim=2,
                 base='legendre',
                 initializer=xavier_normal_,
                 out_dim=1,
                 return_latent=False,
                 input_size = (128, 128),
                 debug=False,
                 **kwargs):
        super().__init__()

        self.k = k
        self.c = c
        self.clvl = coarsest
        self.n_layers = n_layers
        self.in_proj = nn.Linear(in_dim+spatial_dim, c*k**2)

        self.MWT_CZ = nn.ModuleList(
            [MWT_CZ2d(k, modes, coarsest, c, base,
                      initializer,
                      debug=debug) for _ in range(n_layers)]
        )

        self.out_proj = nn.Sequential(nn.Linear(c*k**2, 128),
                                      nn.ReLU(),
                                      nn.Linear(128, out_dim))

        if initializer is not None:
            self._reset_parameters(initializer)
        self.return_latent = return_latent
        self.input_size = input_size
        self.debug = debug

    def forward(self, x, gradx=None, grid=None, **inputs):
        """Run the model on channel-last inputs of shape ``(B, Nx, Ny, *)``.

        Args:
            x: Main input field.
            gradx: Optional extra features concatenated after ``x``.
            grid: Optional coordinate features concatenated last.
            **inputs: Ignored.

        Returns:
            ``dict(preds=...)`` at the resolution of the input, plus
            ``preds_latent`` (a list with one tensor per layer) when
            ``return_latent`` is set.
        """
        x_latent = []
        # gradx/grid default to None; skip missing pieces instead of handing
        # None to torch.cat. The concatenated width must still match in_proj.
        x = torch.cat([part for part in (x, gradx, grid) if part is not None],
                      dim=-1)
        bsz, *size, _ = x.shape  # (B, Nx, Ny, in_dim)

        if self.input_size:
            # Resample to the fixed working resolution.
            x = x.permute(0, 3, 1, 2)
            x = F.interpolate(x, size=self.input_size,
                              mode='bilinear',
                              align_corners=True)
            x = x.permute(0, 2, 3, 1)
            Nx, Ny = self.input_size
        else:
            # No working resolution configured: keep the input grid.
            Nx, Ny = size

        x = self.in_proj(x)  # channel expansion (*, c*k**2)
        x = x.reshape(bsz, Nx, Ny, self.c, self.k**2)

        for i in range(self.n_layers):
            x = self.MWT_CZ[i](x)
            if self.return_latent:
                x_latent.append(x.contiguous())
            if i < self.n_layers-1:
                x = F.relu(x)

        x = x.reshape(bsz, Nx, Ny, -1)  # collapse c and k**2
        x = self.out_proj(x)

        if self.input_size:
            # Resample the prediction back to the resolution of the input.
            x = x.permute(0, 3, 1, 2)
            x = F.interpolate(x, size=size,
                              mode='bilinear',
                              align_corners=True)

            x = x.permute(0, 2, 3, 1)

        if self.return_latent:
            return dict(preds=x,
                        preds_latent=x_latent)
        else:
            return dict(preds=x)

    def _reset_parameters(self, initializer):
        """Initialise ``out_proj``: matrices via ``initializer``, vectors to zero."""
        for params in self.out_proj.parameters():
            if params.ndim > 1:
                initializer(params)
            else:
                nn.init.constant_(params, 0)

if __name__ == "__main__":
    # Smoke test: build a model and print its layer summary with torchinfo.
    from torchinfo import summary
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = dict(
              in_dim=3,  # x (1 feature) + gradx (2 features); grid is covered by spatial_dim
              modes=12,
              c=4,  # number of channels = c*k**2
              k=4, # block size
              base='legendre',  # or 'chebyshev'
              n_layers=4,  # number of MWT layers
              # The constructor argument is `coarsest`; the former `L=0` key was
              # absorbed by **kwargs and had no effect.
              coarsest=0,  # 0 means decompose all the way down to a 1x1 grid
              initializer=xavier_normal_,
              debug=False
              )
    model = MWT2d(**config)
    model.to(device)
    batch_size, n_grid = 10, 128
    # One input shape per positional argument of forward: x, gradx, grid.
    summary(model, input_size=[(batch_size, n_grid, n_grid, 1),
                               (batch_size, n_grid, n_grid, 2),
                               (batch_size, n_grid, n_grid, 2)], device=device)
