from typing import List, Optional

import torch.nn as nn
import libs.stribor as st
from torch import Tensor
from torch.nn import Module
from model.ivp_solvers.gnn_flow import GCNFlowNF, GraphFlowNF
from model.ivp_solvers.h_coupling_flow import ContinuousAffineCouplingGNN
from model.ivp_solvers.mlp import MLP
from model.ivp_solvers.gnn import GNN
class Flow(nn.Module):
    """Sequential container that chains graph-aware flow transforms.

    Each transform is called as
    ``f(nodes, x * mask, h, latent=..., mask=..., t=..., adj=..., **kwargs)``
    and must return a pair ``(x, h)``; the updated ``h`` is threaded into the
    next transform.

    Args:
        base_dist: Optional base distribution. Only needed by
            :meth:`log_prob` and :meth:`sample`.
        transforms: Iterable of transform modules applied in order (reversed
            when ``reverse=True``). Defaults to no transforms.
    """

    def __init__(self, base_dist=None, transforms=None):
        super().__init__()
        self.base_dist = base_dist
        # ``None`` instead of a shared mutable ``[]`` default; ModuleList
        # copies the iterable into its own registry.
        self.transforms = nn.ModuleList([] if transforms is None else transforms)

    def forward(self, nodes, x, h=None, latent=None, mask=None, t=None, adj=None, reverse=False, **kwargs):
        """Push ``x`` (and auxiliary state ``h``) through every transform.

        Args:
            nodes: Passed unchanged as the first positional argument of each
                transform.
            x: Input tensor.
            h: Optional auxiliary state threaded from transform to transform.
            latent, t, adj: Forwarded to each transform as keyword arguments.
            mask: Optional multiplicative mask applied to ``x`` before every
                transform (and also forwarded to it). ``None`` means no mask.
            reverse: If True, apply the transforms in reverse order.
            **kwargs: Extra keyword arguments forwarded to every transform.

        Returns:
            The transformed ``x``. The final ``h`` is discarded.
        """
        transforms = self.transforms[::-1] if reverse else self.transforms
        _mask = 1 if mask is None else mask
        for f in transforms:
            # Call the module itself (not ``.forward``) so registered hooks run.
            x, h = f(nodes, x * _mask, h, latent=latent, mask=mask, t=t, adj=adj, **kwargs)
        return x

    def log_prob(self, x, **kwargs):
        """Return base log-probability plus summed log-Jacobian diagonal, with a trailing unit dim.

        Raises:
            ValueError: If ``base_dist`` was not provided.
        """
        if self.base_dist is None:
            raise ValueError('Please define `base_dist` if you need log-probability')
        # NOTE(review): ``Flow`` itself defines no ``inverse`` method, so this
        # raises AttributeError unless a subclass provides
        # ``inverse(x, **kwargs) -> (x, log_jac_diag)`` — confirm intent.
        x, log_jac_diag = self.inverse(x, **kwargs)
        log_prob = self.base_dist.log_prob(x) + log_jac_diag.sum(-1)
        return log_prob.unsqueeze(-1)

    def sample(self, num_samples, latent=None, mask=None, nodes=None, **kwargs):
        """Draw reparameterized base samples and push them through the flow.

        Args:
            num_samples: int or shape tuple passed to ``base_dist.rsample``.
            latent, mask: Forwarded to :meth:`forward`.
            nodes: Forwarded as the ``nodes`` argument of :meth:`forward`.
            **kwargs: Further keyword arguments for :meth:`forward`.

        Returns:
            The transformed samples.

        Raises:
            ValueError: If ``base_dist`` was not provided.
        """
        if self.base_dist is None:
            raise ValueError('Please define `base_dist` if you need sampling')
        if isinstance(num_samples, int):
            num_samples = (num_samples,)

        x = self.base_dist.rsample(num_samples)
        # ``forward`` takes ``nodes`` positionally and returns only ``x``.
        return self.forward(nodes, x, latent=latent, mask=mask, **kwargs)

class CouplingFlow(Module):
    """Stack of stribor continuous affine-coupling layers conditioned on time.

    Args:
        dim: Feature dimension of ``x``.
        n_layers: Number of coupling layers.
        hidden_dims: Hidden sizes of each layer's latent MLP.
        time_net: Name of a time-network class looked up in ``st.net``.
        time_hidden_dim: ``hidden_dim`` passed to the time network.
        **kwargs: Ignored; accepted so a shared config can be passed through.
    """

    def __init__(
        self,
        dim: int,
        n_layers: int,
        hidden_dims: List[int],
        time_net: str,
        time_hidden_dim: Optional[int] = None,
        **kwargs
    ):
        super().__init__()

        transforms = []
        for i in range(n_layers):
            transforms.append(st.ContinuousAffineCoupling(
                # Input width dim + 1 (presumably x plus one time feature) and
                # output width 2 * dim (presumably scale and shift) — TODO confirm.
                latent_net=st.net.MLP(dim + 1, hidden_dims, 2 * dim),
                time_net=getattr(st.net, time_net)(
                    2 * dim, hidden_dim=time_hidden_dim),
                # Alternate the coupling mask between layers; 'none' when
                # dim == 1 (presumably nothing to partition).
                mask='none' if dim == 1 else f'ordered_{i % 2}'))

        self.flow = st.Flow(transforms=transforms)

    def forward(
        self,
        x: Tensor,
        t: Tensor,
        t0: Optional[Tensor] = None,
    ) -> Tensor:
        """Evaluate the flow at times ``t``.

        Args:
            x: Input. When its size along dim -2 is 1 it is repeated to match
                ``t.shape[-2]``.
            t: Target times.
            t0: Optional start times. When given, ``x`` is first mapped back
                through ``self.flow.inverse`` at ``t0``.

        Returns:
            The first element returned by ``self.flow(x, t=t)``.
        """
        if x.shape[-2] == 1:
            # Broadcast a single initial state across every requested time.
            x = x.repeat_interleave(t.shape[-2], dim=-2)

        if t0 is not None:
            x = self.flow.inverse(x, t=t0)[0]

        return self.flow(x, t=t)[0]


class ResNetFlow(Module):
    """Sequence of stribor ``ResNetFlow`` blocks evaluated at times ``t``.

    Each of the ``n_layers`` blocks also receives ``n_layers`` as its third
    positional argument, and all use ReLU activations with no final activation.
    """

    def __init__(
        self,
        dim: int,
        n_layers: int,
        hidden_dims: List[int],
        time_net: str,
        time_hidden_dim: Optional[int] = None,
        invertible: Optional[bool] = True,
        **kwargs
    ):
        super().__init__()
        shared_options = dict(
            activation='ReLU',
            final_activation=None,
            time_net=time_net,
            time_hidden_dim=time_hidden_dim,
            invertible=invertible,
        )
        self.layers = nn.ModuleList([
            st.net.ResNetFlow(dim, hidden_dims, n_layers, **shared_options)
            for _ in range(n_layers)
        ])

    def forward(
        self,
        x: Tensor,
        t: Tensor,
    ) -> Tensor:
        """Apply every block in turn to ``x`` at times ``t``.

        A size-1 dim -2 of ``x`` is first repeated to ``t.shape[-2]``.
        """
        if x.shape[-2] == 1:
            # A single initial state is replicated once per time point.
            x = x.repeat_interleave(t.shape[-2], dim=-2)

        result = x
        for block in self.layers:
            result = block(result, t)
        return result

class GNNFlow(Module):
    """Stack of ``GCNFlowNF`` layers that propagate ``(x, h)`` over time.

    Args:
        dim: Passed to every layer as its first positional argument.
        nodes: Passed to every layer as ``nodes=`` (presumably the number of
            graph nodes — TODO confirm).
        n_layers: Number of stacked layers; each layer also receives this
            value as its third positional argument.
        hidden_dims: Passed to every layer.
        time_net: Name of the time network, passed through to every layer.
        time_hidden_dim: Passed through to every layer. Defaults to ``None``,
            matching the other flows in this module.
        invertible: Passed through to every layer.
        **kwargs: Ignored.
    """

    def __init__(self,
                 dim: int,
                 nodes: int,
                 n_layers: int,
                 hidden_dims,
                 time_net: str,
                 time_hidden_dim: Optional[int] = None,
                 invertible: Optional[bool] = True,
                 **kwargs
                 ):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers.append(GCNFlowNF(dim,
                                    hidden_dims,
                                    n_layers,
                                    nodes=nodes,
                                    activation='ReLU',
                                    final_activation=None,
                                    time_net=time_net,
                                    time_hidden_dim=time_hidden_dim,
                                    invertible=invertible))
        self.layers = nn.ModuleList(layers)

    def set_graph(self, g):
        """Attach ``g`` to every layer as its ``g`` attribute."""
        for layer in self.layers:
            layer.g = g

    def forward(self, nodes, x, h, t_start, t_end, adj):
        """Run every layer on ``(x, h)`` and return the final ``x``.

        Size-1 axes are broadcast to ``t_end.shape[-2]`` first: dim -3 of
        ``h`` (when ``h`` is given) and dim -2 of ``x``. Each layer is called
        as ``layer(nodes, x, h, t_start, t_end, adj)`` and must return
        ``(x, h)``.
        """
        if h is not None and h.shape[-3] == 1:
            h = h.repeat_interleave(t_end.shape[-2], dim=-3)
        if x.shape[-2] == 1:
            x = x.repeat_interleave(t_end.shape[-2], dim=-2)

        for layer in self.layers:
            x, h = layer(nodes, x, h, t_start, t_end, adj)

        return x


class InverGNNFlow(Module):
    """Graph coupling flow built from ``ContinuousAffineCouplingGNN`` layers.

    The layers are chained by the module-level ``Flow`` container. All MLP
    widths are expressed in terms of ``dim * nodes``, so node features are
    presumably flattened inside the coupling layer (TODO confirm).

    Args:
        dim: Feature dimension passed to every coupling layer.
        nodes: Passed to every coupling layer and used to size its MLPs.
        n_layers: Number of coupling layers; each also receives this value as
            its third positional argument.
        hidden_dims: Hidden sizes of every MLP.
        time_net: Name of a time-network class looked up in ``st.net``.
        time_hidden_dim: ``hidden_dim`` of the time network. Defaults to
            ``None``, matching the other flows in this module.
        invertible: Accepted for config compatibility; currently unused.
        **kwargs: Ignored.
    """

    def __init__(
        self,
        dim: int,
        nodes: int,
        n_layers: int,
        hidden_dims,
        time_net: str,
        time_hidden_dim: Optional[int] = None,
        invertible: Optional[bool] = True,
        **kwargs
    ):
        super().__init__()

        transforms = []
        for i in range(n_layers):
            transforms.append(ContinuousAffineCouplingGNN(
                dim,
                hidden_dims,
                n_layers,
                nodes=nodes,
                latent_net=MLP(dim * nodes + 1, hidden_dims, 2 * dim * nodes),
                latent_net_h=MLP(dim * nodes * 2, hidden_dims, dim * nodes),
                merge_scale=MLP(dim * nodes * 2, hidden_dims, dim * nodes),
                merge_shift=MLP(dim * nodes * 2, hidden_dims, dim * nodes),
                time_net=getattr(st.net, time_net)(2 * dim * nodes, hidden_dim=time_hidden_dim),
                # Alternate the coupling mask between layers; 'none' when dim == 1.
                mask='none' if dim == 1 else f'ordered_{i % 2}'))

        self.flow = Flow(transforms=transforms)

    def forward(
        self,
        nodes,
        x: Tensor,
        h: Tensor,
        t: Tensor,
        adj,
        t0: Optional[Tensor] = None,
    ) -> Tensor:
        """Evaluate the flow at times ``t``.

        Size-1 axes are broadcast to ``t.shape[-2]`` first: dim -3 of ``h``
        (when ``h`` is given) and dim -2 of ``x``.

        Args:
            nodes, adj: Forwarded to the underlying ``Flow``.
            x: Input tensor.
            h: Optional auxiliary state.
            t: Target times.
            t0: Accepted for interface parity with ``CouplingFlow`` but
                currently ignored.

        Returns:
            The transformed ``x``.
        """
        if h is not None and h.shape[-3] == 1:
            h = h.repeat_interleave(t.shape[-2], dim=-3)
        if x.shape[-2] == 1:
            x = x.repeat_interleave(t.shape[-2], dim=-2)

        return self.flow(nodes, x, h, t=t, adj=adj)
    

class GraphFlow(Module):
    """Stack of ``GraphFlowNF`` layers applied to node features at times ``t``.

    Each layer is built with ``(dim, hidden_dims, n_layers)`` positionally,
    ReLU activations and no final activation.
    """

    def __init__(
        self,
        dim: int,
        nodes,
        n_layers: int,
        hidden_dims: List[int],
        time_net: str,
        time_hidden_dim: Optional[int] = None,
        invertible: Optional[bool] = True,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            GraphFlowNF(
                dim,
                hidden_dims,
                n_layers,
                nodes=nodes,
                activation='ReLU',
                final_activation=None,
                time_net=time_net,
                time_hidden_dim=time_hidden_dim,
                invertible=invertible,
            )
            for _ in range(n_layers)
        ])

    def forward(self, nodes, x, h, t, adj):
        """Apply every layer to ``x`` and return the result.

        Size-1 axes are first broadcast to ``t.shape[-2]``: dim -3 of ``h``
        (when provided) and dim -2 of ``x``.
        """
        if h is not None and h.shape[-3] == 1:
            h = h.repeat_interleave(t.shape[-2], dim=-3)
        if x.shape[-2] == 1:
            x = x.repeat_interleave(t.shape[-2], dim=-2)

        # NOTE(review): ``adj`` is accepted but never passed to the layers, and
        # only ``x`` is updated per layer (the same ``h`` goes to each one) —
        # confirm this matches GraphFlowNF's signature.
        out = x
        for graph_layer in self.layers:
            out = graph_layer(nodes, out, h, t)
        return out
