import torch
from torch import nn


def anderson(f, x0, m=5, lam=1e-2, max_iter=10, tol=1e-5, beta=1):
    """Anderson acceleration for the fixed-point iteration ``x = f(x)``.

    Args:
        f: Callable taking a tensor shaped like ``x0`` and returning a tensor
            with the same number of elements per batch entry.
        x0: Initial iterate of shape ``(batch, ...)``. Any number of trailing
            dimensions is accepted; they are flattened internally.
        m: Number of past iterates kept in the history buffers (at least 2,
            because the first two slots are filled before the loop starts).
        lam: Value added to the diagonal of the Gram matrix of residuals
            before solving for the mixing weights.
        max_iter: Largest value (inclusive) of the iteration counter ``k``,
            which starts at 2.
        tol: The loop stops once the relative residual drops below this.
        beta: Mixing factor between the combination of past ``f(x)`` values
            (``beta``) and the combination of past ``x`` values (``1 - beta``).

    Returns:
        A tuple ``(x, res, k)``: the latest iterate reshaped like ``x0``, the
        list of relative residuals recorded per iteration, and the last
        value of the iteration counter. When ``max_iter < 2`` no accelerated
        step is taken and ``(f(x0), [], 1)`` is returned.

    Raises:
        ValueError: If ``m < 2``.
    """
    if m < 2:
        raise ValueError(f"m must be at least 2, got {m}")

    bsz = x0.shape[0]
    x_init = x0.reshape(bsz, -1)
    flat_dim = x_init.shape[1]

    # Circular history of iterates (X) and their images under f (F).
    X = torch.zeros(bsz, m, flat_dim, dtype=x0.dtype, device=x0.device)
    F = torch.zeros(bsz, m, flat_dim, dtype=x0.dtype, device=x0.device)

    X[:, 0], F[:, 0] = x_init, f(x0).reshape(bsz, -1)
    X[:, 1], F[:, 1] = F[:, 0], f(F[:, 0].view_as(x0)).reshape(bsz, -1)

    # Bordered system [[0, 1^T], [1, G G^T + lam I]] [nu; alpha] = [1; 0]:
    # the first row/column enforce sum(alpha) == 1.
    kkt = torch.zeros(bsz, m + 1, m + 1, dtype=x0.dtype, device=x0.device)
    kkt[:, 0, 1:] = kkt[:, 1:, 0] = 1
    rhs = torch.zeros(bsz, m + 1, 1, dtype=x0.dtype, device=x0.device)
    rhs[:, 0] = 1

    # Built once; the top-left n x n corner is reused on every iteration.
    eye = torch.eye(m, dtype=x0.dtype, device=x0.device)

    res = []
    k = 1  # slot of the newest iterate if the loop below never executes
    for k in range(2, max_iter + 1):
        n = min(k, m)

        G = F[:, :n] - X[:, :n]
        kkt[:, 1:n + 1, 1:n + 1] = torch.bmm(G, G.transpose(1, 2)) + lam * eye[None, :n, :n]
        solution = torch.linalg.lstsq(kkt[:, :n + 1, :n + 1], rhs[:, :n + 1]).solution
        alpha = solution[:, 1:n + 1, 0]  # (bsz, n); drops the multiplier row

        slot = k % m
        X[:, slot] = (
            beta * (alpha[:, None] @ F[:, :n])[:, 0]
            + (1 - beta) * (alpha[:, None] @ X[:, :n])[:, 0]
        )
        F[:, slot] = f(X[:, slot].view_as(x0)).reshape(bsz, -1)

        # Relative residual, with the norm taken over the whole batch at once.
        res.append((F[:, slot] - X[:, slot]).norm().item() / (1e-5 + F[:, slot].norm().item()))
        if res[-1] < tol:
            break
    return X[:, k % m].view_as(x0), res, k


class SpectralConv1d(nn.Module):
    '''FFT, linear transform, and Inverse FFT
    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
        modes (int): Number of Fourier modes
    [paper](https://arxiv.org/abs/2010.08895)

    The learnable weights are stored as a real tensor of shape
    (in_channels, out_channels, modes, 2), where the last axis holds the
    real and imaginary parts. If the input has fewer rFFT frequency bins
    than ``modes``, only the available bins (and the matching leading slice
    of the weights) are used.
    '''
    def __init__(self, in_channels: int, out_channels: int, modes: int):
        super().__init__()

        self.out_channels = out_channels
        self.modes = modes
        # Weights are drawn from a normal distribution with
        # std = sqrt(2 / in_channels).
        std = torch.sqrt(torch.tensor(2.0 / in_channels, dtype=torch.float32))
        self.weights = nn.Parameter(
            torch.empty(in_channels, out_channels, modes, 2, dtype=torch.float32).normal_(0, std)
        )

    def batchmul1d(self, input, weights):
        ## (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)
        return torch.einsum("bix, iox -> box", input, weights)

    def forward(self, x):
        batchsize = x.shape[0]
        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.fft.rfft(x)
        n_freq = x_ft.size(-1)  # equals x.size(-1) // 2 + 1

        # Short inputs may expose fewer frequency bins than configured modes;
        # clamp so the einsum operands always agree on the mode axis.
        modes = min(self.modes, n_freq)
        weights = torch.view_as_complex(self.weights)[..., :modes]

        # Multiply relevant Fourier modes; the remaining bins stay zero
        out_ft = torch.zeros(
            batchsize,
            self.out_channels,
            n_freq,
            dtype=torch.cfloat,
            device=x.device,
        )
        out_ft[:, :, :modes] = self.batchmul1d(x_ft[:, :, :modes], weights)

        # Return to physical space
        return torch.fft.irfft(out_ft, n=x.size(-1))
    
    
class MLP(nn.Module):
    """Two kernel-size-1 ``Conv1d`` layers with a GELU in between.

    Maps ``in_channels`` -> ``mid_channels`` -> ``out_channels``; because
    both convolutions use kernel size 1, the last axis length is unchanged.
    """

    def __init__(self, in_channels, out_channels, mid_channels):
        super().__init__()
        self.mlp1 = nn.Conv1d(in_channels, mid_channels, kernel_size=1)
        self.mlp2 = nn.Conv1d(mid_channels, out_channels, kernel_size=1)
        self.activation = nn.GELU()

    def forward(self, x):
        hidden = self.activation(self.mlp1(x))
        return self.mlp2(hidden)


class FourierLayer1D(nn.Module):
    """One Fourier layer: a spectral branch plus a pointwise linear branch.

    The output is ``MLP(SpectralConv1d(x)) + Conv1d_1x1(x)``.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels of the output.
        modes: Number of Fourier modes passed to ``SpectralConv1d``.
        bias: Whether the kernel-size-1 ``Conv1d`` layers created directly
            by this class use a bias term.
    """

    def __init__(self, in_channels: int, out_channels: int, modes: int, bias: bool=True):
        super().__init__()

        self.spectral_conv = SpectralConv1d(in_channels, out_channels, modes)
        # The MLP consumes the spectral branch output, which already has
        # out_channels channels, so it must be sized on out_channels. When
        # in_channels == out_channels this is identical to sizing on
        # in_channels.
        self.pointwise_conv = MLP(out_channels, out_channels, mid_channels=out_channels)
        self.pointwise_conv_2 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=bias)
        # NOTE: not used in forward(); kept so parameter names and
        # initialisation order stay compatible with existing checkpoints.
        self.pointwise_conv_3 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        spectral = self.spectral_conv(x)
        return self.pointwise_conv(spectral) + self.pointwise_conv_2(x)

class FNO1D_input_injection(nn.Module):
    """Stack of ``FourierLayer1D`` blocks with an injected input.

    After every layer the tensor ``g`` is added to the layer output and a
    GELU is applied. All layers keep ``hidden_channels`` channels.
    """

    def __init__(self, hidden_channels: int, modes: int, depth: int=4, bias: bool=True):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                FourierLayer1D(hidden_channels, hidden_channels, modes, bias=bias)
                for _ in range(depth)
            ]
        )
        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        for block in self.layers:
            # Inject g before the nonlinearity at every depth.
            x = self.activation(block(x) + g)
        return x
    
class FNO1D(nn.Module):
    """Plain stack of ``depth`` ``FourierLayer1D`` blocks applied in sequence.

    Every block maps ``hidden_channels`` to ``hidden_channels``; no
    activation is applied between blocks by this class.
    """

    def __init__(self, hidden_channels: int, modes: int, depth: int=4, bias: bool=True):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                FourierLayer1D(hidden_channels, hidden_channels, modes, bias=bias)
                for _ in range(depth)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x
        for block in self.layers:
            out = block(out)
        return out
    
    
class DEQFNO(nn.Module):
    """Deep-equilibrium wrapper around ``FNO1D_input_injection``.

    The input is lifted to ``hidden_channels`` with a kernel-size-1
    convolution and used as the injected signal. A fixed point of
    ``z -> fno(z, lifted_input)`` is sought with ``anderson`` under
    ``torch.no_grad()``, starting from zeros. In training mode, ``s``
    further damped updates ``z <- (1 - tau) * z + tau * fno(z, lifted_input)``
    are applied with gradient tracking enabled. The result is projected to
    ``out_channels`` with another kernel-size-1 convolution.

    Args:
        in_channels: Channels of the raw input.
        out_channels: Channels of the returned tensor.
        hidden_channels: Channels used inside the fixed-point map.
        modes: Number of Fourier modes for each Fourier layer.
        depth: Number of Fourier layers in the fixed-point map.
        bias: Passed to the lift/projection convolutions and to the layers.
        s: Number of gradient-tracked updates applied in training mode.
        tau: Damping factor of those updates.
        max_iter: Default solver step budget when ``forward`` is called
            without ``num_steps``.
    """

    def __init__(self, in_channels: int, out_channels: int, hidden_channels: int, modes: int, depth: int=4, bias: bool=True, s: int=1, tau: float=1.0, max_iter: int=20):
        super().__init__()

        self.s = s
        self.tau = tau
        self.max_iter = max_iter

        self.lift = nn.Conv1d(in_channels, hidden_channels, kernel_size=1, bias=bias)
        self.fno = FNO1D_input_injection(hidden_channels, modes, depth, bias)
        self.p = nn.Identity()
        self.c = nn.Identity()
        self.proj = nn.Conv1d(hidden_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x: torch.Tensor, num_steps: int=None) -> torch.Tensor:
        injected = self.p(self.lift(x))

        if num_steps is None:
            num_steps = self.max_iter

        def fun(state):
            # Fixed-point map with the lifted input injected at every layer.
            return self.fno(state, injected)

        start = torch.zeros_like(injected)
        # The solver runs without building an autograd graph.
        with torch.no_grad():
            z, _, _ = anderson(fun, start, m=5, lam=1e-5, max_iter=(num_steps + 2), tol=1e-7)

        if self.training:
            z.requires_grad_()
            # Gradient-tracked damped steps starting from the solver output.
            for _ in range(self.s):
                z = (1 - self.tau) * z + self.tau * fun(z)

        return self.proj(self.c(z))
    

