import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from .spectral_normalization import SpectralNorm


class PosLinear(nn.Linear):
    """Linear layer that applies the element-wise absolute value of its weight.

    The stored ``weight`` parameter is unconstrained; ``forward`` uses
    ``torch.abs(self.weight)`` so the effective weight matrix is non-negative.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: If True, a learnable bias is created and added in ``forward``.
            Defaults to False.
        *args, **kwargs: Accepted for call-site compatibility; they are not
            forwarded to ``nn.Linear``.
    """

    def __init__(self, in_features, out_features, bias=False, *args, **kwargs):
        super().__init__(in_features, out_features, bias=bias)

    def forward(self, x):
        """Return ``x @ |W|^T`` plus the bias when one was requested."""
        # ``self.bias`` is None when bias=False, in which case F.linear adds
        # nothing. Passing it through makes bias=True actually take effect
        # instead of leaving an unused parameter on the module.
        return F.linear(x, torch.abs(self.weight), self.bias)


class MaxReLUPairwiseActivation(nn.Module):
    """Pairwise activation mixing a max-pool branch with a scaled-ReLU branch.

    Adjacent feature pairs are reduced twice: once by taking their maximum,
    and once by averaging ``relu(x * softplus(weights))`` over the pair. The
    two halves are concatenated along the feature axis. With an odd number of
    features the unpaired trailing feature is dropped by the pooling ops.

    Args:
        num_features: Number of input features; sets the width of the
            learnable per-feature scale ``weights`` (initialised to zero).
    """

    def __init__(self, num_features):
        super().__init__()
        self.weights = nn.Parameter(torch.zeros(1, num_features))
        self.avg_pool = nn.AvgPool1d(2, 2)

    def forward(self, x):
        """Apply the activation to ``x`` and return the concatenated branches."""
        # Pooling ops need a channel axis, so add a singleton one at dim 1.
        pooled_input = torch.unsqueeze(x, 1)

        # Softplus keeps the learned per-feature scale strictly positive.
        scale = F.softplus(self.weights)
        rectified = F.relu(pooled_input * scale)

        pair_max = F.max_pool1d(pooled_input, 2)
        pair_mean = self.avg_pool(rectified)

        combined = torch.cat([pair_max, pair_mean], dim=-1)
        return torch.squeeze(combined, 1)


class FICNN(nn.Module):
    """Feed-forward network with input skip connections and non-negative hidden weights.

    Presumably a "fully input-convex neural network": every layer after the
    first combines a ``PosLinear`` map of the previous hidden state with an
    unconstrained ``nn.Linear`` map of the original input.

    Args:
        in_features: Width of the input ``y``.
        out_features: Width of the network output.
        layers: Number of hidden layers; ``layers + 1`` input projections
            (``W_y``) are created in total, the last one producing the output.
        hidden_size: Width of every hidden layer.
        activation: One of ``'relu'``, ``'softplus'`` or ``'maxrelu'``.
        last_activation: If True the activation is also applied to the output
            layer. Note that the ``'maxrelu'`` activation is sized by
            ``hidden_size``, not ``out_features``.
        nonneg_constraint: Accepted for interface compatibility; not used by
            this class.

    Raises:
        ValueError: If ``activation`` is not a recognised name.
    """

    _ACTIVATIONS = ('relu', 'softplus', 'maxrelu')

    def __init__(self,
                 in_features,
                 out_features,
                 layers,
                 hidden_size,
                 activation='relu',
                 last_activation=False,
                 nonneg_constraint='abs'
                ):
        super().__init__()

        # Fail fast: previously an unknown name left ``self.activation`` unset
        # and only surfaced as an AttributeError on the first forward pass.
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unknown activation {activation!r}; "
                f"expected one of {self._ACTIVATIONS}"
            )

        self.W_z = nn.ModuleList()
        self.W_y = nn.ModuleList()

        # First layer: only the input projection exists. The Identity entry
        # keeps W_z and W_y index-aligned.
        self.W_z.append(nn.Identity())
        self.W_y.append(nn.Linear(in_features, hidden_size, bias=True))

        # Middle layers
        self.W_z.extend([PosLinear(hidden_size, hidden_size, bias=False)
                         for _ in range(layers - 1)])
        self.W_y.extend([nn.Linear(in_features, hidden_size, bias=True)
                         for _ in range(layers - 1)])

        # Final layer
        self.W_z.append(PosLinear(hidden_size, out_features, bias=False))
        self.W_y.append(nn.Linear(in_features, out_features, bias=True))

        self.layers = layers
        self.last_activation = last_activation

        # Built after the linear layers so module registration order (and
        # therefore parameter ordering) is unchanged.
        if activation == 'relu':
            self.activation = F.relu
        elif activation == 'softplus':
            self.activation = nn.Softplus()
        else:  # 'maxrelu'
            self.activation = MaxReLUPairwiseActivation(hidden_size)

    def forward(self, y):
        """Run the network on ``y`` and return the final layer's output."""
        z = self.activation(self.W_y[0](y))

        last = len(self.W_z) - 1
        for i in range(1, last + 1):
            z = self.W_z[i](z) + self.W_y[i](y)
            # The output layer is only activated when explicitly requested.
            if self.last_activation or i < last:
                z = self.activation(z)
        return z


class SpectralFICNN(nn.Module):
    """Variant of the FICNN architecture with every linear layer wrapped in ``SpectralNorm``.

    Every layer after the first combines a spectrally-normalised ``PosLinear``
    map of the previous hidden state with a spectrally-normalised
    ``nn.Linear`` map of the original input.

    Args:
        in_features: Width of the input ``y``.
        out_features: Width of the network output.
        layers: Number of hidden layers; ``layers + 1`` input projections
            (``W_y``) are created in total, the last one producing the output.
        hidden_size: Width of every hidden layer.
        activation: One of ``'relu'``, ``'softplus'`` or ``'maxrelu'``.
        last_activation: If True the activation is also applied to the output
            layer. Note that the ``'maxrelu'`` activation is sized by
            ``hidden_size``, not ``out_features``.
        nonneg_constraint: Accepted for interface compatibility; not used by
            this class.

    Raises:
        ValueError: If ``activation`` is not a recognised name.
    """

    _ACTIVATIONS = ('relu', 'softplus', 'maxrelu')

    def __init__(self,
                 in_features,
                 out_features,
                 layers,
                 hidden_size,
                 activation='relu',
                 last_activation=False,
                 nonneg_constraint='abs'
                ):
        super().__init__()

        # Fail fast: previously an unknown name left ``self.activation`` unset
        # and only surfaced as an AttributeError on the first forward pass.
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unknown activation {activation!r}; "
                f"expected one of {self._ACTIVATIONS}"
            )

        self.W_z = nn.ModuleList()
        self.W_y = nn.ModuleList()

        # First layer: only the input projection exists. The Identity entry
        # keeps W_z and W_y index-aligned.
        self.W_z.append(nn.Identity())
        self.W_y.append(SpectralNorm(nn.Linear(in_features, hidden_size, bias=True)))

        # Middle layers
        self.W_z.extend([SpectralNorm(PosLinear(hidden_size, hidden_size, bias=False))
                         for _ in range(layers - 1)])
        self.W_y.extend([SpectralNorm(nn.Linear(in_features, hidden_size, bias=True))
                         for _ in range(layers - 1)])

        # Final layer
        self.W_z.append(SpectralNorm(PosLinear(hidden_size, out_features, bias=False)))
        self.W_y.append(SpectralNorm(nn.Linear(in_features, out_features, bias=True)))

        self.layers = layers
        self.last_activation = last_activation

        # Built after the linear layers so module registration order (and
        # therefore parameter ordering) is unchanged.
        if activation == 'relu':
            self.activation = F.relu
        elif activation == 'softplus':
            self.activation = nn.Softplus()
        else:  # 'maxrelu'
            self.activation = MaxReLUPairwiseActivation(hidden_size)

    def forward(self, y):
        """Run the network on ``y`` and return the final layer's output."""
        z = self.activation(self.W_y[0](y))

        last = len(self.W_z) - 1
        for i in range(1, last + 1):
            z = self.W_z[i](z) + self.W_y[i](y)
            # The output layer is only activated when explicitly requested.
            if self.last_activation or i < last:
                z = self.activation(z)
        return z


class InvariantFICNN(nn.Module):
    """FICNN-style network applied to the mean of its input along dim 1.

    The input is first reduced to its mean over dim 1, so the output does not
    depend on the order of the entries along that dimension. All input
    projections therefore take a single feature.

    Args:
        in_features: Accepted for interface parity with the other FICNN
            classes; the input projections here always take one feature.
        out_features: Width of the network output.
        layers: Number of hidden layers; ``layers + 1`` input projections
            (``W_y``) are created in total, the last one producing the output.
        hidden_size: Width of every hidden layer.
        activation: One of ``'relu'``, ``'softplus'`` or ``'maxrelu'``.
        last_activation: If True the activation is also applied to the output
            layer. Note that the ``'maxrelu'`` activation is sized by
            ``hidden_size``, not ``out_features``.
        nonneg_constraint: Accepted for interface compatibility; not used by
            this class.

    Raises:
        ValueError: If ``activation`` is not a recognised name.
    """

    _ACTIVATIONS = ('relu', 'softplus', 'maxrelu')

    def __init__(self,
                 in_features,
                 out_features,
                 layers,
                 hidden_size,
                 activation='relu',
                 last_activation=False,
                 nonneg_constraint='abs'
                ):
        super().__init__()

        # Fail fast: previously an unknown name left ``self.activation`` unset
        # and only surfaced as an AttributeError on the first forward pass.
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unknown activation {activation!r}; "
                f"expected one of {self._ACTIVATIONS}"
            )

        self.W_z = nn.ModuleList()
        self.W_y = nn.ModuleList()

        # First layer: only the input projection exists. The Identity entry
        # keeps W_z and W_y index-aligned.
        self.W_z.append(nn.Identity())
        self.W_y.append(nn.Linear(1, hidden_size, bias=True))

        # Middle layers
        self.W_z.extend([PosLinear(hidden_size, hidden_size, False)
                         for _ in range(layers - 1)])
        self.W_y.extend([nn.Linear(1, hidden_size, bias=True)
                         for _ in range(layers - 1)])

        # Final layer
        self.W_z.append(PosLinear(hidden_size, out_features, False))
        self.W_y.append(nn.Linear(1, out_features, bias=True))

        self.layers = layers
        self.last_activation = last_activation

        # Built after the linear layers so module registration order (and
        # therefore parameter ordering) is unchanged.
        if activation == 'relu':
            self.activation = F.relu
        elif activation == 'softplus':
            self.activation = nn.Softplus()
        else:  # 'maxrelu'
            self.activation = MaxReLUPairwiseActivation(hidden_size)

    def forward(self, y):
        """Reduce ``y`` to its mean over dim 1 and run the network on it."""
        # keepdim=True leaves dim 1 with size exactly 1, so the network is
        # evaluated once on that single column; the former per-column loop
        # and sum always ran for a single iteration.
        y_mean = torch.mean(y, dim=1, keepdim=True)

        z = self.activation(self.W_y[0](y_mean))

        last = len(self.W_z) - 1
        for i in range(1, last + 1):
            z = self.W_z[i](z) + self.W_y[i](y_mean)
            # The output layer is only activated when explicitly requested.
            if self.last_activation or i < last:
                z = self.activation(z)
        return z
