import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from model.ivp_solvers.mask import get_mask
from model.ivp_solvers.mlp import MLP


class ContinuousAffineCoupling(nn.Module):
    """
    Continuous affine coupling layer. If `dim = 1`, set `mask = 'none'`.
    Similar to `Coupling` but applies only an affine transformation
    which here depends on time `t` such that it's identity map at `t = 0`.

    When a graph embedding `h` is passed to `forward`, a second scale/shift
    pair is predicted from `h` and fused with the pair predicted from `x`.

    Args:
        latent_net (Type[nn.Module]): Inputs concatenation of `x` and `t` (and optionally
            `latent`) and outputs affine transformation parameters (size `2 * dim`)
        latent_net2 (Type[nn.Module]): Stored on the module; not used by `forward`
        time_net (Type[nn.Module]): Time embedding with the same output
            size as `latent_net`
        mask (str): Mask name resolved through `get_mask`
        **kwargs: Optional modules, only needed when `h` is given to `forward`:
            `latent_net_h` (maps masked `h`, `t` and optionally `latent` to
            scale/shift), `merge_scale` and `merge_shift` (each receives the
            concatenation of the `x`-based and `h`-based parameter)
    """
    def __init__(self, latent_net, latent_net2, time_net, mask, **kwargs):
        super().__init__()

        self.latent_net = latent_net
        self.latent_net2 = latent_net2
        self.latent_net_h = kwargs.get('latent_net_h')
        self.mask_func = get_mask(mask)
        self.time_net = time_net
        self.merge_scale = kwargs.get('merge_scale')
        self.merge_shift = kwargs.get('merge_shift')

    def get_mask(self, x):
        """Build the mask for the last dimension of `x`, expanded and cast to match `x`."""
        return self.mask_func(x.shape[-1]).expand_as(x).to(x)

    @staticmethod
    def _conditioner_input(value, mask, t, latent):
        """Concatenate the masked `value`, `t` and (optionally) `latent` along the last axis."""
        # With a single feature nothing may leak into the conditioner, so it is zeroed.
        visible = value * 0 if value.shape[-1] == 1 else value * mask
        z = torch.cat([visible, t], -1)
        if latent is not None:
            z = torch.cat([z, latent], -1)
        return z

    def forward(self, x, h, t, latent=None, reverse=False, **kwargs):
        """
        Args:
            x (tensor): Input with shape (..., dim)
            h (tensor or None): Graph embedding; when given, its scale/shift
                are merged with the ones predicted from `x`
            t (tensor): Time input with shape (..., 1)
            latent (tensor): Conditioning vector with shape (..., latent_dim)
            reverse (bool, optional): Whether to calculate inverse. Default: False

        Returns:
            y (tensor): Transformed input with shape (..., dim)
        """
        mask = self.get_mask(x)
        z = self._conditioner_input(x, mask, t, latent)
        scale, shift = self.latent_net(z).chunk(2, dim=-1)
        t_scale, t_shift = self.time_net(t).chunk(2, dim=-1)

        if h is not None:
            z_h = self._conditioner_input(h, self.get_mask(h), t, latent)
            scale_h, shift_h = self.latent_net_h(z_h).chunk(2, dim=-1)
            scale = self.merge_scale(torch.cat([scale, scale_h], -1))
            shift = self.merge_shift(torch.cat([shift, shift_h], -1))

        # Both directions must use the same (possibly merged) parameters,
        # otherwise `reverse=True` does not undo the forward transformation.
        if reverse:
            y = (x - shift * t_shift) * torch.exp(-scale * t_scale)
        else:
            y = x * torch.exp(scale * t_scale) + shift * t_shift

        # Masked coordinates pass through unchanged.
        return y * (1 - mask) + x * mask

class ContinuousAffineCouplingGNN(nn.Module):
    """
    Continuous affine transformation of `x` with an additional graph branch.

    The scale/shift applied to `x` are predicted by `latent_net` from the
    unmasked concatenation of `x` and `t`. The graph embedding `h` is mixed
    with the adjacency matrix and passed through two small linear stacks; the
    result is fed to `latent_net_h` to produce the updated graph embedding.

    Args:
        dim (int): Feature size of the last axis of `h`
        hidden_dims (Sequence[int]): Only the first entry is used, as the hidden
            width of the graph branch
        n_layers: Unused
        nodes: Unused
        latent_net (Type[nn.Module]): Inputs concatenation of `x` and `t` (and optionally
            `latent`) and outputs scale and shift stacked on the last axis
        time_net (Type[nn.Module]): Time embedding with the same output
            size as `latent_net`
        mask (str): Mask name resolved through `get_mask` (the mask is not
            applied in `forward`)
        **kwargs: `latent_net_h` maps the flattened graph features to the
            returned embedding; `merge_scale` / `merge_shift` are stored only
    """
    def __init__(self, dim, hidden_dims, n_layers, nodes, latent_net, time_net, mask, **kwargs):
        super().__init__()

        self.latent_net = latent_net
        self.latent_net_h = kwargs.get('latent_net_h')
        self.mask_func = get_mask(mask)
        self.time_net = time_net
        self.merge_scale = kwargs.get('merge_scale')
        self.merge_shift = kwargs.get('merge_shift')
        self.dropout = nn.Dropout(p=0.2)
        hidden_dims = hidden_dims[0]
        self.lin_n = nn.Sequential(nn.Linear(dim, hidden_dims),
                                   nn.Linear(hidden_dims, hidden_dims),
                                   )
        self.lin_r = nn.Sequential(nn.Linear(hidden_dims, hidden_dims),
                                   nn.Linear(hidden_dims, dim * 2),
                                   )
        self.act = nn.ReLU()

    def get_mask(self, x):
        """Build the mask for the last dimension of `x`, expanded and cast to match `x`."""
        return self.mask_func(x.shape[-1]).expand_as(x).to(x)

    def forward(self, nodes, x, h, t, adj=None, latent=None, reverse=False, **kwargs):
        """
        Args:
            nodes: Unused
            x (tensor): Input to transform; concatenated with `t` on the last axis
            h (tensor): Graph embedding, 5-D with the node axis second to last
                (contracted with `adj` via ``'nlakd,kj->nlajd'``)
            t (tensor): Time input
            adj (tensor): Adjacency matrix with shape (nodes_in, nodes_out).
                Required despite the `None` default.
            latent (tensor): Optional conditioning vector appended to both branches
            reverse (bool, optional): Selects the inverse-style formula. Default: False

        Returns:
            y (tensor): Transformed input
            h_output (tensor): Output of `latent_net_h` on the graph features
                flattened to (h.shape[0], h.shape[1], -1)
        """
        z = torch.cat([x, t], -1)
        if latent is not None:
            z = torch.cat([z, latent], -1)
        scale, shift = self.latent_net(z).chunk(2, dim=-1)
        t_scale, t_shift = self.time_net(t).chunk(2, dim=-1)

        z_graph = h
        if latent is not None:
            z_graph = torch.cat([z_graph, latent], -1)

        # Aggregate node features over the adjacency, then two linear stacks.
        z_n = self.lin_n(torch.einsum('nlakd,kj->nlajd', z_graph, adj))
        z_r = self.lin_r(self.act(z_n))

        batch, length = h.shape[0], h.shape[1]

        # Computed for both directions: previously it was only assigned in the
        # forward branch, so `reverse=True` failed with UnboundLocalError.
        h_output = self.latent_net_h(z_r.reshape(batch, length, -1))

        # NOTE(review): the original chunked `z_r` into scale2/shift2 and then
        # immediately replaced them with reshaped copies of `scale`/`shift`, so
        # the graph branch only influences `h_output`. That behaviour is kept
        # here; confirm whether the graph-derived values were meant to be used.
        scale2 = scale.contiguous().view(batch, length, -1)
        shift2 = shift.contiguous().view(batch, length, -1)

        if reverse:
            # NOTE(review): uses plain scale/shift, unlike the forward branch.
            y = (x - shift * t_shift) * torch.exp(-scale * t_scale)
        else:
            y = x * torch.exp(scale * scale2 * t_scale) + shift * shift2 * t_shift

        return y, h_output
    



class ContinuousAffineCouplingGNN2(nn.Module):
    """
    Continuous affine coupling layer whose scale/shift come from a graph branch.

    The graph embedding `h` is split per node, masked, concatenated with the
    time input, mixed over the adjacency matrix and mapped to a per-node
    scale and shift. These are modulated by `time_net(t)` and applied to `x`;
    masked coordinates of `x` pass through unchanged.

    Args:
        dim (int): Per-node feature size (last axis of `h` divided by `nodes`)
        hidden_dims (Sequence[int]): Only the first entry is used, as the hidden
            width of the graph branch
        n_layers: Unused
        nodes: Unused by the constructor (the node count is passed to `forward`)
        latent_net (Type[nn.Module]): Stored on the module; not used by `forward`
        time_net (Type[nn.Module]): Time embedding returning scale and shift
            factors stacked on the last axis
        mask (str): Mask name resolved through `get_mask`
    """
    def __init__(self, dim, hidden_dims, n_layers, nodes, latent_net, time_net, mask, **kwargs):
        super().__init__()

        self.latent_net = latent_net

        self.mask_func = get_mask(mask)
        self.time_net = time_net

        self.dropout = nn.Dropout(p=0.4)
        hidden_dims = hidden_dims[0]
        # `dim + 1`: per-node features plus the time channel appended in `forward`.
        self.lin_n = nn.Sequential(nn.Linear(dim + 1, hidden_dims),
                                   nn.Linear(hidden_dims, hidden_dims),
                                   )

        self.lin_r = nn.Sequential(nn.Linear(hidden_dims, hidden_dims),
                                   nn.Linear(hidden_dims, dim * 2),
                                   )

        self.act = nn.ReLU()

    def get_mask(self, x):
        """Build the mask for the last dimension of `x`, expanded and cast to match `x`."""
        return self.mask_func(x.shape[-1]).expand_as(x).to(x)

    def forward(self, nodes, x, h, t, adj=None, latent=None, reverse=False, **kwargs):
        """
        Args:
            nodes (int): Number of graph nodes packed into the last axis of `h`
            x (tensor): Input to transform; must be compatible with the
                flattened scale/shift of shape (batch, length, nodes * dim)
            h (tensor): Graph embedding with shape (batch, length, nodes * dim)
            t (tensor): Time input, 3-D; expanded across the node axis
            adj (tensor): Adjacency matrix with shape (nodes, nodes_out).
                Required despite the `None` default.
            latent (tensor): Optional conditioning vector appended to the
                per-node features
            reverse (bool, optional): Whether to calculate inverse. Default: False

        Returns:
            y (tensor): Transformed input
            h_output (tensor): The flattened shift, shape (batch, length, nodes * dim)
        """
        batch, length = h.shape[0], h.shape[1]
        h_res = h.reshape(batch, length, nodes, h.shape[-1] // nodes)
        mask = self.get_mask(h_res)

        # With a single feature nothing may leak into the conditioner, so it is zeroed.
        visible = h_res * 0 if h_res.shape[-1] == 1 else h_res * mask
        t_per_node = t.unsqueeze(2).expand(-1, -1, nodes, -1)
        z = torch.cat([visible, t_per_node], -1)
        if latent is not None:
            z = torch.cat([z, latent], -1)

        z = self.dropout(z)
        # Aggregate node features over the adjacency, then two linear stacks.
        z_n = self.lin_n(torch.einsum('nlkd,kj->nljd', z, adj))
        z_r = self.lin_r(self.act(z_n))
        scale, shift = z_r.chunk(2, dim=-1)

        t_scale, t_shift = self.time_net(t).chunk(2, dim=-1)
        scale = scale.reshape(batch, length, -1)
        shift = shift.reshape(batch, length, -1)
        h_output = shift

        if reverse:
            # Invert the same tensor the forward branch transforms (`x`);
            # the original subtracted from `h` here, which does not undo it.
            y = (x - shift * t_shift) * torch.exp(-scale * t_scale)
        else:
            y = x * torch.exp(scale * t_scale) + shift * t_shift

        # `mask` may still be a stride-0 expanded tensor (when `.to` is a
        # no-op), which `.view` cannot flatten; `.reshape` copies when needed.
        mask = mask.reshape(batch, length, -1)
        y = y * (1 - mask) + x * mask

        return y, h_output