import torch
import torch.nn as nn
from torch_scatter import scatter_sum
from .feature_propagation import FeaturePropagator
from .processor import RevProcessor

class MLP(nn.Module):
    """Stack of fully connected layers with an optional output normalisation.

    Every linear layer except the last one is followed by the selected
    activation. The last linear layer is followed only by the normalisation
    layer, when one is requested.

    Args:
        in_dim: Size of the input features.
        layer_dim_list: Output size of each linear layer, in order (list or
            tuple). Needs at least one entry; an empty sequence raises
            ``IndexError`` when the output layer is built.
        norm_type: ``None`` for no normalisation, or one of 'LayerNorm',
            'BatchNorm', 'GraphNorm', 'InstanceNorm', 'MessageNorm'.
            'LayerNorm' maps to ``nn.LayerNorm`` and 'BatchNorm' to
            ``nn.BatchNorm1d``. The remaining names have no class of that
            name in ``torch.nn`` and raise ``AttributeError``.
        activation_type: ``None`` for no activation (identity), or one of
            'ReLU', 'LeakyReLU', 'Tanh'.

    Raises:
        AssertionError: If ``activation_type`` or ``norm_type`` is not one of
            the accepted names.
        AttributeError: If the accepted ``norm_type`` has no ``torch.nn`` class.
    """

    # accepted names whose torch.nn class is spelled differently
    _NORM_CLASS_NAMES = {'BatchNorm': 'BatchNorm1d'}

    def __init__(self, in_dim, layer_dim_list, norm_type='LayerNorm', activation_type='ReLU'):
        super().__init__()

        # resolve the activation class; identity when no activation is wanted
        if activation_type is not None:
            assert activation_type in ['ReLU', 'LeakyReLU', 'Tanh'], \
                f'unsupported activation_type: {activation_type!r}'
            activation = getattr(nn, activation_type)
        else:
            activation = nn.Identity

        # resolve the normalisation class up front so a bad name fails early
        norm_layer = None
        if norm_type is not None:
            assert norm_type in ['LayerNorm', 'BatchNorm', 'GraphNorm', 'InstanceNorm', 'MessageNorm'], \
                f'unsupported norm_type: {norm_type!r}'
            class_name = self._NORM_CLASS_NAMES.get(norm_type, norm_type)
            norm_layer = getattr(nn, class_name, None)
            if norm_layer is None:
                raise AttributeError(
                    f"module 'torch.nn' has no attribute '{class_name}' "
                    f"(requested norm_type={norm_type!r})")

        # dims of the fully connected layers, input size first
        dim_list = [in_dim] + list(layer_dim_list)

        # hidden layers: linear followed by the activation
        fc_layers = []
        for fan_in, fan_out in zip(dim_list[:-2], dim_list[1:-1]):
            fc_layers += [nn.Linear(fan_in, fan_out), activation()]

        # output layer without activation
        fc_layers.append(nn.Linear(dim_list[-2], dim_list[-1]))

        # normalisation is applied to the output of the MLP only
        if norm_layer is not None:
            fc_layers.append(norm_layer(dim_list[-1]))

        # NOTE: the name-mangled attribute is part of the state_dict keys
        # ('_MLP__fcs.*'); renaming it would break existing checkpoints.
        self.__fcs = nn.Sequential(*fc_layers)

    def forward(self, x):
        """Run ``x`` through the linear stack and the optional normalisation."""
        return self.__fcs(x)


class Encoder(nn.Module):
    """Encode node, edge and global features to latent space, one MLP each.

    The node encoder consumes the raw node features concatenated with the
    encoded globals of the graph each node belongs to, so its input size is
    ``node_feature_dim + glob_latent_dim``.

    Keyword Args:
        node_feature_dim, node_enc_mlp_layers, node_latent_dim: Input size,
            hidden layer sizes (list) and output size of the node encoder.
        edge_feature_dim, edge_enc_mlp_layers, edge_latent_dim: Same for the
            edge encoder.
        glob_feature_dim, glob_enc_mlp_layers, glob_latent_dim: Same for the
            global encoder.
        Any additional keyword arguments are ignored.

    Raises:
        AssertionError: If any of the settings above is missing.
    """

    def __init__(self, **kwargs):
        super().__init__()

        # every key read below must be present; 'glob_feature_dim' included
        required = {'node_feature_dim', 'node_enc_mlp_layers', 'node_latent_dim',
                    'edge_feature_dim', 'edge_enc_mlp_layers', 'edge_latent_dim',
                    'glob_feature_dim', 'glob_enc_mlp_layers', 'glob_latent_dim'}
        missing = required.difference(kwargs)
        assert not missing, f'Encoder is missing required settings: {sorted(missing)}'

        # initialize the node, edge and global params encoder MLPs
        self.node_encoder = MLP(kwargs['node_feature_dim'] + kwargs['glob_latent_dim'],
                                kwargs['node_enc_mlp_layers'] + [kwargs['node_latent_dim']],
                                activation_type='ReLU', norm_type='LayerNorm')
        self.edge_encoder = MLP(kwargs['edge_feature_dim'],
                                kwargs['edge_enc_mlp_layers'] + [kwargs['edge_latent_dim']],
                                activation_type='ReLU', norm_type='LayerNorm')
        self.global_encoder = MLP(kwargs['glob_feature_dim'],
                                  kwargs['glob_enc_mlp_layers'] + [kwargs['glob_latent_dim']],
                                  activation_type='ReLU', norm_type='LayerNorm')

    def forward(self, x, edge_attr, globals, batch):
        """Return the encoded ``(nodes, edges, globals)``.

        ``batch`` indexes the encoded globals to pick one row per node before
        concatenation with ``x`` along dim 1 (presumably the graph index of
        each node; confirm against the data loader).
        """
        globals_enc = self.global_encoder(globals)
        x_enc = self.node_encoder(torch.cat([x, globals_enc[batch]], dim=1))
        edge_attr_enc = self.edge_encoder(edge_attr)

        return x_enc, edge_attr_enc, globals_enc

class Decoder(nn.Module):
    """Map latent node and global features to output space, one MLP each.

    Keyword Args:
        node_latent_dim, node_dec_mlp_layers, node_out_dim: Input size, hidden
            layer sizes (list) and output size of the node decoder.
        glob_latent_dim, glob_dec_mlp_layers, glob_out_dim: Same for the
            global decoder.
        Any additional keyword arguments are ignored.
    """

    def __init__(self, **kwargs):
        super().__init__()

        required_keys = {
            'node_latent_dim', 'node_dec_mlp_layers', 'node_out_dim',
            'glob_latent_dim', 'glob_dec_mlp_layers', 'glob_out_dim',
        }
        assert required_keys.issubset(kwargs)

        # both decoders end in a plain linear layer: no output normalization
        self.node_decoder = MLP(
            in_dim=kwargs['node_latent_dim'],
            layer_dim_list=kwargs['node_dec_mlp_layers'] + [kwargs['node_out_dim']],
            activation_type='ReLU',
            norm_type=None,
        )
        self.glob_decoder = MLP(
            in_dim=kwargs['glob_latent_dim'],
            layer_dim_list=kwargs['glob_dec_mlp_layers'] + [kwargs['glob_out_dim']],
            activation_type='ReLU',
            norm_type=None,
        )

    def forward(self, x, globals):
        """Return the decoded ``(nodes, globals)`` pair."""
        decoded_nodes = self.node_decoder(x)
        decoded_globals = self.glob_decoder(globals)
        return decoded_nodes, decoded_globals


class FRGNN(nn.Module):
    """Encode-process-decode graph network with feature propagation.

    Missing node features are first filled in by a ``FeaturePropagator``. The
    graph is then encoded, processed by a ``RevProcessor`` and decoded back to
    node and global outputs.

    Keyword Args:
        encoder_settings, processor_settings, decoder_settings: Dicts unpacked
            into the ``Encoder``, ``RevProcessor`` and ``Decoder``
            constructors together with all top-level keyword arguments.
        fp_steps: Number of feature propagation iterations.
        glob_loss_factor: Weight of the global-output MSE in the loss.
        div_loss_factor: Weight of the divergence loss; ``0.0`` disables it.
        noise_sigma: Scale of the Gaussian noise added to the node features
            while training; ``0.0`` disables it.
    """

    def __init__(self, **kwargs):
        super().__init__()
        assert {'encoder_settings', 'processor_settings', 'decoder_settings', 'fp_steps',
                'glob_loss_factor', 'div_loss_factor', 'noise_sigma'}.issubset(kwargs)

        # init the feature propagator
        self.feature_propagator = FeaturePropagator(num_iterations=kwargs['fp_steps'])

        # init the encoder, processor and decoder
        self.encoder = Encoder(**kwargs, **kwargs['encoder_settings'])
        self.processor = RevProcessor(**kwargs, **kwargs['processor_settings'])
        self.decoder = Decoder(**kwargs, **kwargs['decoder_settings'])

        # init the loss factors
        self.glob_loss_factor = kwargs['glob_loss_factor']
        self.div_loss_factor = kwargs['div_loss_factor']

        # noise control
        self.noise_sigma = kwargs['noise_sigma']

    def forward(self, data):
        """Run the full pipeline on ``data`` and return the same object.

        ``data`` is modified in place: ``x``, ``edge_attr`` and ``globals``
        are overwritten with the filled, encoded and finally decoded values.

        Raises:
            AssertionError: If ``data`` lacks one of the attributes read here.
        """
        required_attrs = ('x', 'known_feature_mask', 'edge_index', 'edge_attr',
                          'batch', 'pos', 'globals')
        missing = [attr for attr in required_attrs if not hasattr(data, attr)]
        assert not missing, f'data is missing required attributes: {missing}'

        # first propagate missing features; known entries keep their value
        filled_features = self.feature_propagator.propagate(x=data.x.clone(), edge_index=data.edge_index, mask=data.known_feature_mask)
        data.x = torch.where(data.known_feature_mask, data.x, filled_features)

        # apply noise to nodes during training
        if self.training and self.noise_sigma != 0.0:
            noise = torch.randn_like(data.x)*self.noise_sigma
            # the noisy value is added as a detached offset, so gradients
            # flow through data.x exactly as if no noise had been applied
            data.x = ((noise + data.x).detach() - data.x).detach() + data.x

        # encode the mesh nodes, edges and globals to their latent forms
        data.x, data.edge_attr, data.globals = self.encoder(data.x, data.edge_attr, data.globals, data.batch)

        # message passing network to process encoded graph
        data.x = self.processor(data.x, data.edge_index, data.edge_attr)

        # decode the nodes and globals back to the output space
        data.x, data.globals = self.decoder(data.x, data.globals)
        return data

    def compute_loss(self, data):
        """Return node MSE + weighted global MSE (+ weighted divergence loss).

        Compares ``data.x`` with ``data.y`` and ``data.globals`` with
        ``data.globals_y``. The divergence term is only evaluated when
        ``div_loss_factor`` is non-zero.
        """
        node_loss = nn.functional.mse_loss(data.x, data.y)
        glob_loss = nn.functional.mse_loss(data.globals, data.globals_y)

        loss = node_loss + self.glob_loss_factor * glob_loss
        if self.div_loss_factor != 0.0:
            loss = loss + self.div_loss_factor * self.compute_div_loss(data)
        return loss

    def compute_div_loss(self, data):
        """Return the mean squared net flux over nodes with ``node_type == 0``.

        For every edge the flux is ``(u . edge_rd) * edge_s``, with ``u`` the
        de-normalised columns 1:3 of the sender node. Fluxes are summed per
        sender node.

        Assumes ``normalization_values`` holds one (offset, scale) row per
        node feature, and ``edge_rd`` / ``edge_s`` hold one direction vector
        and one scalar per edge; confirm against the dataset.
        """
        # undo the normalisation: column 1 is the scale, column 0 the offset
        x = data.x * data.normalization_values[:, 1] + data.normalization_values[:, 0]

        # velocity at the first (sender) node of every edge
        senders = data.edge_index[0, :]
        num_edges = data.edge_index.shape[1]
        u = x[senders, 1:3]

        # calculate for each edge the flux j = (U dot edge_rd) * edge_s
        j = torch.bmm(u.view(num_edges, 1, -1), data.edge_rd.view(num_edges, -1, 1))
        # explicit reshape instead of squeeze(): stays 1-D for a single edge
        j = j.reshape(num_edges) * data.edge_s.squeeze()

        # for each node sum all the sender edge fluxes; dim_size gives one row
        # per node even when the highest-numbered nodes send no edges, so the
        # node_type mask below always lines up
        div = scatter_sum(j, senders, dim=0, dim_size=x.shape[0])

        # get divergence only for fluid nodes
        div = div[data.node_type == 0]

        return nn.functional.mse_loss(div, torch.zeros_like(div))
