from typing import Callable
import torch
import torch.nn as nn
from typing import List, Optional, Callable
from model.ivp_solvers.gnn import GNN,GCN
from model.ivp_solvers.mlp import MLP
import model.ivp_solvers as mods
import math
import dgl
import dgl.function as fn
from dgl import batch
import torch.nn.functional as F
import libs.stribor as st

class GCNFlowBlock_test2(nn.Module):
    """Graph-conditioned residual flow step (experimental "test2" variant).

    ``forward`` returns ``(x_out, h_out)`` where::

        h_out = lin_r(relu(lin_n(adj-aggregated h)))
        x_out = x + time_net(t_end - t_start)
                  * h_out.view(x.shape)
                  * net2(cat([x, t_end - t_start, t_start, t_end], -1))

    Args:
        input_size: Per-node feature size; width of the last axis of ``h``.
        hidden_size: Sequence of hidden sizes. Only ``hidden_size[0]`` is used,
            as the width between ``lin_n`` and ``lin_r``.
        nodes: Number of graph nodes. NOTE(review): currently ignored and
            overridden with 37 in ``__init__`` -- confirm before relying on it.
        activation: Passed to ``MLP`` for ``net2``.
        final_activation: Passed to ``MLP`` for ``net2``.
        time_net: Name of a class in ``st.net`` used to embed ``t_end - t_start``.
        time_hidden_dim: ``hidden_dim`` passed to the ``time_net`` constructor.
        n_power_iterations: Power iterations used by ``spectral_norm``.
        bias: Unused; kept for signature compatibility.
        invertible: If True, linear layers are wrapped in ``spectral_norm``.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: List[int],
                 nodes: int,
                 activation,
                 final_activation,
                 time_net: str,
                 time_hidden_dim,
                 n_power_iterations: int = 5,
                 bias: bool = True,
                 invertible: bool = True):
        super().__init__()
        wrapper = None
        if invertible:
            def wrapper(layer):
                return torch.nn.utils.spectral_norm(layer, n_power_iterations=n_power_iterations)
        # lin_n / lin_r call the wrapper directly, so fall back to identity when
        # spectral norm is disabled (calling ``None`` would raise TypeError).
        wrap_linear = wrapper if wrapper is not None else (lambda layer: layer)

        # NOTE(review): hard-coded node count overrides the ``nodes`` argument.
        # It was previously switched by hand between 37, 10, 96 and 14
        # (presumably one value per dataset) -- consider passing ``nodes`` instead.
        nodes = 37

        # x carries input_size * nodes features; the extra 3 inputs are dt,
        # t_start and t_end (assuming each has a trailing dim of 1 -- TODO confirm).
        self.net2 = MLP(input_size * nodes + 3, [32, 32, 64, 32], input_size * nodes,
                        activation, final_activation, wrapper_func=wrapper)
        hidden_size = hidden_size[0]

        # Sequential wrappers keep the original state_dict keys (lin_n.0.*, lin_r.0.*).
        self.lin_n = nn.Sequential(wrap_linear(nn.Linear(input_size, hidden_size)))
        self.lin_r = nn.Sequential(wrap_linear(nn.Linear(hidden_size, input_size)))

        self.act = nn.ReLU()
        self.time_net = getattr(st.net, time_net)(input_size * nodes, hidden_dim=time_hidden_dim)

    def forward(self, nodes, x, h, t_start, t_end, adj):
        """Apply one flow step.

        Args:
            nodes: Unused.
            x: State tensor; ``h_output`` is reshaped to ``x.shape``, so ``x``
                is expected to hold ``input_size * nodes`` features per sample.
            h: Node features; layout depends on the rank of ``adj`` (below).
            t_start: Interval start, concatenated onto ``x`` for ``net2``.
            t_end: Interval end, concatenated onto ``x`` for ``net2``.
            adj: Adjacency tensor. Supported layouts (einsum subscripts):
                rank 3 ``bij`` with ``h`` as ``cbtjd``;
                rank 4 ``btij`` with ``h`` as ``btajd``;
                rank 5 ``cbtij`` with ``h`` as ``cbtjd``.

        Returns:
            ``(x_output, h_output)``; ``h_output`` has the shape of the
            aggregated ``h`` (last axis mapped back to ``input_size``).

        Raises:
            ValueError: If ``adj`` does not have 3, 4 or 5 dimensions.
        """
        dt = t_end - t_start
        t_output = self.time_net(dt)

        # Neighbour aggregation: sum over j of adj[..., i, j] * h[..., j, :].
        rank = adj.dim()
        if rank == 3:
            aggregated = torch.einsum('bij,cbtjd->cbtid', adj, h)
        elif rank == 4:
            aggregated = torch.einsum('btij,btajd->btaid', adj, h)
        elif rank == 5:
            aggregated = torch.einsum('cbtij,cbtjd->cbtid', adj, h)
        else:
            raise ValueError(f"adj must have 3, 4 or 5 dimensions, got {rank}")

        h_output = self.lin_r(self.act(self.lin_n(aggregated)))

        gate = self.net2(torch.cat([x, dt, t_start, t_end], -1))
        x_output = x + t_output * torch.mul(h_output.view(x.shape), gate)
        return x_output, h_output

class GCNFlowNF(nn.Module):
    """Stack of ``num_layers`` ``GCNFlowBlock_test2`` steps applied in sequence.

    Args:
        dim: Per-node feature size (``input_size`` of each block).
        hidden_dims: Hidden sizes forwarded to each block.
        num_layers: Number of stacked blocks.
        nodes: Node count forwarded to each block.
        activation: Forwarded to each block.
        final_activation: Forwarded to each block.
        time_net: Name of the ``st.net`` time-embedding class.
        time_hidden_dim: Hidden size of the time embedding.
        n_power_iterations: Power iterations used by ``spectral_norm``.
        invertible: Whether blocks wrap their layers in ``spectral_norm``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        dim,
        hidden_dims,
        num_layers,
        nodes: int,
        activation='ReLU',
        final_activation=None,
        time_net=None,
        time_hidden_dim=None,
        n_power_iterations=5,
        invertible=True,
        **kwargs
    ):
        super().__init__()

        # Pass the spectral-norm settings by keyword: the block signature has
        # ``n_power_iterations`` and ``bias`` between ``time_hidden_dim`` and
        # ``invertible``, so a positional ``invertible`` landed in
        # ``n_power_iterations`` and this class's own setting was dropped.
        self.blocks = nn.ModuleList([
            GCNFlowBlock_test2(dim, hidden_dims, nodes, activation, final_activation,
                               time_net, time_hidden_dim,
                               n_power_iterations=n_power_iterations,
                               invertible=invertible)
            for _ in range(num_layers)
        ])

    def forward(self, nodes, x, h, t_start, t_end, adj):
        """Thread ``(x, h)`` through every block in order and return the final pair."""
        for block in self.blocks:
            x, h = block(nodes, x, h, t_start, t_end, adj)
        return x, h

    def inverse(self, y, t):
        """Apply each block's ``inverse`` in reverse order.

        NOTE(review): ``GCNFlowBlock_test2`` as defined here has no ``inverse``
        method, so this raises AttributeError for stacks built by ``__init__``.
        """
        for block in reversed(self.blocks):
            y = block.inverse(y, t)
        return y
class GraphFlowNF(nn.Module):
    """Sequential stack of ``GraphFlowBlock`` steps sharing the same ``h`` and ``t``.

    Args:
        dim: Feature size forwarded to each block.
        hidden_dims: Hidden sizes forwarded to each block.
        num_layers: Number of stacked blocks.
        nodes: Accepted for interface compatibility; not used here.
        activation: Forwarded to each block.
        final_activation: Forwarded to each block.
        time_net: Name of the ``st.net`` time-embedding class.
        time_hidden_dim: Hidden size of the time embedding.
        n_power_iterations: Power iterations used by ``spectral_norm``.
        invertible: Whether blocks wrap their layers in ``spectral_norm``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        dim,
        hidden_dims,
        num_layers,
        nodes: int,
        activation='ReLU',
        final_activation=None,
        time_net=None,
        time_hidden_dim=None,
        n_power_iterations=5,
        invertible=True,
        **kwargs
    ):
        super().__init__()
        block_args = (dim, hidden_dims, activation, final_activation, time_net,
                      time_hidden_dim, n_power_iterations, invertible)
        self.blocks = nn.ModuleList([GraphFlowBlock(*block_args) for _ in range(num_layers)])

    def forward(self, nodes, x, h, t):
        """Feed ``x`` through each block in turn; ``nodes`` is unused."""
        state = x
        for step in self.blocks:
            state = step(state, h, t)
        return state



class GraphFlowBlock(nn.Module):
    """Residual flow step conditioned on graph features ``h``.

    ``forward`` computes::

        x + time_net(t) * net(cat([x, h.view(x.shape), t], -1))
                        * net2(cat([x, t], -1))

    Args:
        dim: Feature size. NOTE(review): overridden with 37 in ``__init__``,
            so the passed value is currently ignored -- confirm intent.
        hidden_dims: Hidden sizes of ``net``.
        activation: Passed to both ``MLP``s.
        final_activation: Passed to both ``MLP``s.
        time_net: Name of a class in ``st.net`` used to embed ``t``.
        time_hidden_dim: ``hidden_dim`` passed to the ``time_net`` constructor.
        n_power_iterations: Power iterations used by ``spectral_norm``.
        invertible: If True, MLP layers are wrapped in ``spectral_norm``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dim, hidden_dims, activation, final_activation, time_net,
                 time_hidden_dim, n_power_iterations, invertible=True, **kwargs):
        super().__init__()
        self.invertible = invertible

        if invertible:
            def wrapper(layer):
                return torch.nn.utils.spectral_norm(layer, n_power_iterations=n_power_iterations)
        else:
            wrapper = None

        # NOTE(review): hard-coded feature size overrides the ``dim`` argument.
        dim = 37
        # net sees [x, h, t] -> 2 * dim + 1 inputs; net2 sees [x, t] -> dim + 1
        # (assuming t has a trailing dim of 1 -- TODO confirm).
        self.net = MLP(2 * dim + 1, hidden_dims, dim, activation, final_activation,
                       wrapper_func=wrapper)
        self.net2 = MLP(dim + 1, [32, 64, 32], dim, activation, final_activation,
                        wrapper_func=wrapper)
        self.time_net = getattr(st.net, time_net)(dim, hidden_dim=time_hidden_dim)

    def forward(self, x, h, t):
        """Return ``x`` plus the time-gated product of ``net`` and ``net2`` outputs."""
        gate = self.time_net(t)
        h_flat = h.view(x.shape)
        coupling = self.net(torch.cat([x, h_flat, t], -1))
        modulation = self.net2(torch.cat([x, t], -1))
        return x + gate * torch.mul(coupling, modulation)

