import torch
from torch_geometric.nn import MessagePassing
# from torch_geometric.nn import GCNConv
# torch_geometric.nn.
import torch.nn.functional as F
#from torchdiffeq import odeint_adjoint as odeint
from torchdiffeq import odeint as odeint
import torch.nn as nn
from torch_geometric.nn import GCNConv, Linear


class GFUNK_layer(nn.Module):
    """GFunk spectral layer.

    Performs a graph Fourier transform (projection onto a reduced set of
    eigenvectors), a learned reduction in the spectral domain, and the inverse
    graph Fourier transform back to node space.
    """

    def __init__(self, in_channels, out_channels, modes):
        """
        Args:
            in_channels: number of input feature channels; only used to set the
                initial weight scale.
            out_channels: number of output feature channels; only used to set
                the initial weight scale.
            modes: number of retained spectral modes; the learned weight matrix
                is (modes, modes).
        """
        # The original called super(GFNO, self); GFNO is not defined in this
        # module, so constructing the layer raised NameError.
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes

        self.scale = 1 / (in_channels * out_channels)
        # Learned spectral weights, initialized to a scaled identity matrix.
        self.weights1 = nn.Parameter(self.scale * torch.eye(modes, dtype=torch.float))

    def forward(self, x, redfor, redvals):
        """Apply the spectral layer.

        Args:
            x: node features, matmul-compatible with ``redfor``; presumably
                (num_nodes, channels) — TODO confirm against caller.
            redfor: reduced set of eigenvectors; presumably (modes, num_nodes)
                with one eigenvector per row — TODO confirm against caller.
            redvals: reduced set of eigenvalues, one per row of ``redfor``.

        Returns:
            ``redfor.transpose(0, 1) @ y`` where ``y`` is the reweighted
            spectral signal, i.e. the result mapped back to node space.
        """
        # Graph Fourier transform: project the node signal onto the eigenbasis.
        x_gft = torch.matmul(redfor, x)
        # Expand across the eigenvalue basis {1, lambda, lambda^2}; the stacked
        # basis is (len(redvals), 3) and the result is indexed [m, j, k].
        basis = torch.stack([redvals**0, redvals, redvals**2], dim=-1)
        x_gft = torch.einsum('jm, jk->mjk', basis, x_gft)
        # Multiply by the learned parameters.
        # NOTE(review): 'n' (weights1 columns) and 'm' (basis index) are both
        # summed independently, so this equals
        #   weights1.sum(1)[j] * x_gft.sum(0)[j, k];
        # only the row sums of weights1 matter. Verify this is the intended
        # contraction.
        x_gft = torch.einsum('jn, mjk->jk', self.weights1, x_gft)
        # Inverse graph Fourier transform back to node space.
        return torch.matmul(redfor.transpose(0, 1), x_gft)


class GFUNK(torch.nn.Module):
    """GFunk network: input lifting, three spectral residual blocks, projection head.

    Each block computes ``gelu(spectral_layer(x) + linear(x))``. The head then
    applies ``Linear(width, 32) -> gelu -> Linear(32, 1)``.
    """

    def __init__(self, modes, width):
        """
        Args:
            modes: number of spectral modes used by each GFUNK_layer.
            width: hidden feature width.
        """
        # The original called super(gfnogamma, self); gfnogamma is not defined
        # in this module, so constructing the network raised NameError.
        super().__init__()

        self.modes1 = modes
        self.width = width
        # Lift the single-channel input to `width` channels.
        self.fco = nn.Linear(1, self.width)

        # Attribute names are kept as-is so existing state_dict keys still load.
        self.conv0 = GFUNK_layer(self.width, self.width, self.modes1)
        self.w0 = Linear(self.width, self.width)
        self.conv1 = GFUNK_layer(self.width, self.width, self.modes1)
        self.w1 = Linear(self.width, self.width)
        self.conv2 = GFUNK_layer(self.width, self.width, self.modes1)
        self.w2 = Linear(self.width, self.width)

        self.fc1 = Linear(self.width, 32)
        self.fcOut = Linear(32, 1)

    def forward(self, u, redfor, redvals):
        """Evaluate the network.

        Args:
            u: input signal whose last dimension is 1 (required by ``self.fco``).
            redfor: reduced eigenvectors, forwarded to every spectral layer.
            redvals: reduced eigenvalues, forwarded to every spectral layer.

        Returns:
            Output of ``self.fcOut``, whose last dimension is 1.
        """
        x = self.fco(u)

        # Spectral branch plus pointwise linear branch, then GELU.
        for conv, w in ((self.conv0, self.w0), (self.conv1, self.w1), (self.conv2, self.w2)):
            x = F.gelu(conv(x, redfor, redvals) + w(x))

        x = F.gelu(self.fc1(x))
        return self.fcOut(x)


class ODEfunc(nn.Module):
    """Wrap ``model(u, **params)`` so it matches the ``f(t, u)`` ODE signature."""

    def __init__(self, model, params=None):
        super().__init__()
        self.model = model
        # Keyword arguments forwarded to `model` on every call. A dict passed by
        # the caller is stored by reference, not copied.
        self.params = params if params is not None else {}

    def forward(self, t, u):
        """Evaluate the wrapped model at state ``u``. ``t`` is accepted but ignored."""
        extra_kwargs = self.params
        return self.model(u, **extra_kwargs)

    def update_params(self, params):
        """Merge ``params`` into the stored keyword arguments, in place."""
        self.params.update(params)


class ODEBlock(nn.Module):
    """Integrate an ODE right-hand side ``odefunc(t, u)`` with ``odeint``."""

    def __init__(self, odefunc, method, rtol=1e-7, atol=1e-9, options=None):
        """
        Args:
            odefunc: right-hand side, called as ``odefunc(t, u)``.
            method: integrator name passed to ``odeint``.
            rtol: relative tolerance passed to ``odeint``.
            atol: absolute tolerance passed to ``odeint``.
            options: solver options dict passed to ``odeint``. Defaults to
                ``{'step_size': 1e-3}``, the value previously hard-coded here.
                NOTE(review): 'step_size' is presumably only honored by
                fixed-grid methods; verify against the torchdiffeq docs for
                the chosen `method`.
        """
        super().__init__()
        self.odefunc = odefunc
        self.rtol = rtol
        self.atol = atol
        self.method = method
        # Copy so later mutation of the caller's dict does not affect this block.
        self.options = {'step_size': 1e-3} if options is None else dict(options)

    def forward(self, x, t):
        """Integrate from initial state ``x`` over the time points ``t``.

        Returns:
            The solution returned by ``odeint`` at the times in ``t``.
        """
        # The unused `adjoint_options` local was removed: this module calls the
        # non-adjoint odeint, so it was never passed anywhere.
        return odeint(
            self.odefunc,
            x,
            t,
            method=self.method,
            rtol=self.rtol,
            atol=self.atol,
            # Pass a fresh copy so the solver can never mutate stored options.
            options=dict(self.options),
        )


# TODO: Boundary conditions are assumed to be the same on the entire boundary. This can be relaxed.
def get_full_model(modes, width,
        bd_conditions,
        withD,
        msg_nodes,
        msg_dim,
        aggr_nodes,
        int_method='adaptive_heun',
        int_rtol=1e-5,
        int_atol=1e-5,
        device='cuda'
):
    """
    Builds and returns the full ODE model: a GFUNK network used as the ODE
    right-hand side (via ODEfunc), integrated by ODEBlock.

    Args:
        modes: number of spectral modes for GFUNK.
        width: hidden feature width for GFUNK.
        bd_conditions: currently unused by this function.
        withD: currently unused by this function.
        msg_nodes: currently unused by this function.
        msg_dim: currently unused by this function.
        aggr_nodes: currently unused by this function.
        int_method: method used for the ODE integrator
        int_rtol: relative tolerance used for the ODE integrator
        int_atol: absolute tolerance used for the ODE integrator
        device: device the model is moved to (e.g. 'cpu', 'cuda')

    Returns: ODEBlock wrapping the GFUNK network, moved to `device`.

    """
    # The original instantiated `gfnogamma`, which is not defined in this
    # module (NameError). GFUNK is the network class defined here with the
    # matching (modes, width) constructor.
    ode_func = ODEfunc(GFUNK(modes=modes, width=width)).to(device)
    ode_model = ODEBlock(ode_func, int_method, rtol=int_rtol, atol=int_atol).to(device)

    return ode_model
