import torch
from torch.nn import ModuleList, Linear, Module
from torch.nn.functional import softmax, log_softmax
import torch.nn.functional as F
import numpy as np
from torch_geometric.nn import MessagePassing, global_add_pool

from utils.utils import gumbel_softmax

# Manual seed for reproducibility. This seeds torch's global RNG as an
# import-time side effect of this module. It does NOT seed NumPy's global RNG,
# which the discretising layers below draw from via np.random.random().
torch.manual_seed(0)

class ArgMax(torch.autograd.Function):
    """Hard one-hot argmax over the last dimension with a hand-written backward pass.

    ``forward`` returns ``(y_hard, y_soft)``:

    * ``y_soft`` is ``softmax(input, dim=-1)``;
    * ``y_hard`` is the one-hot encoding of the position of ``y_soft``'s maximum
      along the last dimension.

    ``backward`` is not the exact Jacobian of softmax/argmax. It returns
    ``grad_y_hard * y_hard + grad_y_soft * y_soft``, i.e. each incoming gradient is
    masked element-wise by the corresponding saved forward output.
    """

    @staticmethod
    def forward(ctx, input):
        probs = input.softmax(dim=-1)
        _, winner = probs.max(-1, keepdim=True)
        one_hot = torch.zeros_like(input, memory_format=torch.legacy_contiguous_format)
        one_hot.scatter_(-1, winner, 1.0)

        ctx.save_for_backward(probs, one_hot)
        return one_hot, probs

    @staticmethod
    def backward(ctx, grad_output, grad_out_y_soft):
        probs, one_hot = ctx.saved_tensors
        # Gradient w.r.t. ``input``: hard-path gradient gated by the one-hot mask plus
        # soft-path gradient gated by the probabilities.
        return grad_output * one_hot + grad_out_y_soft * probs


def argmax(x):
    """Return only the hard one-hot output of ``ArgMax`` applied to ``x``.

    The softmax tensor that ``ArgMax`` also produces is dropped here. Its incoming
    gradient in ``ArgMax.backward`` is therefore zero, leaving ``grad * y_hard``.
    """
    y_hard, _unused_soft = ArgMax.apply(x)
    return y_hard


class MLP(torch.nn.Module):
    """Feed-forward network: ``[Linear -> BatchNorm1d -> ReLU -> Dropout] * k -> Linear``.

    Args:
        in_channels: input feature size.
        hidden_channels: width of every hidden layer.
        out_channels: output feature size.
        num_layers: total number of ``Linear`` layers. Values below 2 behave like 2,
            i.e. one hidden layer is always built.
        dropout: dropout probability applied after each hidden ReLU (training only).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super().__init__()

        n_hidden = max(num_layers - 1, 1)
        widths = [in_channels] + [hidden_channels] * n_hidden + [out_channels]
        # Linear layers are created in the same order as a layer-by-layer build, and
        # BatchNorm1d initialisation draws no random numbers, so seeded weights match.
        self.lins = torch.nn.ModuleList(
            torch.nn.Linear(width_in, width_out)
            for width_in, width_out in zip(widths[:-1], widths[1:])
        )
        self.bns = torch.nn.ModuleList(
            torch.nn.BatchNorm1d(hidden_channels) for _ in range(n_hidden)
        )

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise every Linear layer, then every BatchNorm layer."""
        for module in [*self.lins, *self.bns]:
            module.reset_parameters()

    def forward(self, x):
        *hidden_lins, head = self.lins
        for lin, bn in zip(hidden_lins, self.bns):
            activated = F.relu(bn(lin(x)))
            x = F.dropout(activated, p=self.dropout, training=self.training)
        return head(x)


class LinearSoftmax(Module):
    """Linear projection followed by a (near-)discrete softmax over ``out_channels`` states.

    The projected logits are optionally batch-normalised and then discretised in one of
    three ways, selected by the mutable ``argmax`` and ``gumbel`` attributes:

    * ``argmax`` set -> hard one-hot via the ``argmax`` helper;
    * ``gumbel`` set -> ``gumbel_softmax(..., hard=True, tau=softmax_temp, beta=beta)``;
    * otherwise      -> ``softmax(logits / softmax_temp)``.

    In training mode, with probability ``1 - alpha`` the output is the average of the
    logits and the discretised vector instead of the discretised vector alone.

    Args:
        in_channels: input feature size.
        out_channels: number of output states.
        gumbel: use Gumbel-softmax discretisation (ignored while ``argmax`` is set).
        temperature: softmax / Gumbel temperature.
        use_batch_norm: apply ``BatchNorm1d`` to the logits. The BN module is created
            regardless of this flag and is simply unused when it is False.
    """

    def __init__(self, in_channels, out_channels, gumbel=True, temperature=1.0, use_batch_norm=True):
        super().__init__()
        self.__name__ = 'LinearSoftmax'
        self.lin1 = Linear(in_channels, out_channels)
        self.bn = torch.nn.BatchNorm1d(out_channels)
        self.argmax = False
        self.gumbel = gumbel
        self.softmax_temp = temperature
        self.beta = 0.0
        self.alpha = 1.0
        self.use_batch_norm = use_batch_norm

    def forward(self, x):
        logits = self.lin1(x)
        if self.use_batch_norm:
            logits = self.bn(logits)

        if self.argmax:
            x_d = argmax(logits)
        elif self.gumbel:
            x_d = gumbel_softmax(logits, hard=True, tau=self.softmax_temp, beta=self.beta)
        else:
            x_d = softmax(logits / self.softmax_temp, dim=-1)

        # Test ``self.training`` first. Evaluation then never consumes a draw from the
        # global NumPy RNG, which would otherwise shift the training-time random stream.
        if self.training and np.random.random() > self.alpha:
            return (logits + x_d) / 2
        return x_d


class MLPSoftmax(Module):
    """Two-layer ``MLP`` followed by a (near-)discrete softmax over ``out_channels`` states.

    The MLP logits are discretised according to the mutable ``argmax`` and ``gumbel``
    attributes: hard one-hot via ``argmax``, ``gumbel_softmax(..., hard=True)``, or
    ``softmax(logits / softmax_temp)``. In training mode, with probability ``1 - alpha``
    the output is the average of the logits and the discretised vector.

    Args:
        in_channels: input feature size.
        out_channels: number of output states.
        gumbel: use Gumbel-softmax discretisation (ignored while ``argmax`` is set).
        temperature: softmax / Gumbel temperature.
        hidden_units: hidden width of the MLP.
        dropout: dropout probability inside the MLP.
    """

    def __init__(self, in_channels, out_channels, gumbel=True, temperature=1.0, hidden_units=16, dropout=0.0):
        super().__init__()
        # NOTE(review): this reuses the name 'LinearSoftmax' (likely a copy-paste slip);
        # left unchanged in case callers match on it.
        self.__name__ = 'LinearSoftmax'
        self.mlp = MLP(in_channels, hidden_units, out_channels, 2, dropout)
        self.gumbel = gumbel
        self.argmax = False
        self.beta = 0.0
        self.alpha = 1.0
        self.softmax_temp = temperature

    def forward(self, x):
        logits = self.mlp(x)

        if self.argmax:
            x_d = argmax(logits)
        elif self.gumbel:
            x_d = gumbel_softmax(logits, hard=True, tau=self.softmax_temp, beta=self.beta)
        else:
            x_d = softmax(logits / self.softmax_temp, dim=-1)

        # Test ``self.training`` first. Evaluation then never consumes a draw from the
        # global NumPy RNG, which would otherwise shift the training-time random stream.
        if self.training and np.random.random() > self.alpha:
            return (logits + x_d) / 2
        return x_d


class InputLayer(Module):
    """Encode raw node features into an initial (near-)discrete state of size ``out_channels``.

    Features are cast to float32 and linearly projected. They are then discretised with
    the ``argmax`` helper, ``gumbel_softmax`` (hard) or a temperature softmax, depending
    on the mutable ``argmax`` and ``gumbel`` flags. In training mode, with probability
    ``1 - alpha`` the result is the average of the logits and the discretised vector.

    Args:
        in_channels: number of input node features.
        out_channels: size of the node state.
        softmax_temp: softmax / Gumbel temperature.
        gumbel: use Gumbel-softmax discretisation (ignored while ``argmax`` is set).
        network: accepted for signature compatibility; not used by this layer.
    """

    def __init__(self, in_channels, out_channels, softmax_temp, gumbel=True, network='linear'):
        super().__init__()
        self.__name__ = 'FirstLayer'
        self.lin1 = Linear(in_channels, out_channels)
        self.gumbel = gumbel
        self.argmax = False
        self.beta = 0.0
        self.alpha = 1.0
        self.softmax_temp = softmax_temp

    def forward(self, x):
        # ``Linear`` needs float32 input. ``Tensor.float()`` returns ``x`` itself when it
        # is already float32 and casts any other dtype (integers, float64, ...).
        x = x.float()
        logits = self.lin1(x)

        if self.argmax:
            x_d = argmax(logits)
        elif self.gumbel:
            x_d = gumbel_softmax(logits, hard=True, tau=self.softmax_temp, beta=self.beta)
        else:
            x_d = softmax(logits / self.softmax_temp, dim=-1)

        # Test ``self.training`` first. Evaluation then never consumes a draw from the
        # global NumPy RNG, which would otherwise shift the training-time random stream.
        if self.training and np.random.random() > self.alpha:
            return (logits + x_d) / 2
        return x_d


class PoolingLayer(Module):
    """Readout head: maps features to log-probabilities over ``out_channels`` classes.

    Args:
        in_channels: input feature size.
        out_channels: number of output classes.
        network: ``'mlp'`` uses a two-layer ``MLP``; any other value uses one ``Linear``.
        hidden_units: hidden width of the MLP (``'mlp'`` only).
        dropout: dropout probability inside the MLP (``'mlp'`` only).
    """

    def __init__(self, in_channels, out_channels, network='linear', hidden_units=16, dropout=0.0):
        super().__init__()
        self.__name__ = 'PoolingLayer'
        # A Linear head is always constructed first, even when it is replaced by an MLP.
        # Building it advances torch's RNG, so the MLP's seeded initialisation depends on it.
        head = Linear(in_channels, out_channels)
        if network == 'mlp':
            head = MLP(in_channels, hidden_units, out_channels, 2, dropout)
        self.lin2 = head

    def forward(self, x):
        scores = self.lin2(x)
        return log_softmax(scores, dim=-1)


class StoneAgeGNNLayer(MessagePassing):
    """One message-passing round: sum neighbour states, clamp, then update each node's state.

    Every edge carries the state ``x_j`` of its source node (PyG's default
    ``source_to_target`` flow). Incoming messages are summed (``aggr='add'``) and clamped
    to ``[0, bounding_parameter]``. The result is concatenated with the receiving node's
    own state and passed through a ``LinearSoftmax`` or ``MLPSoftmax`` module, whose
    output is the node's next state.

    Args:
        in_channels: input width of the state-update module. Must equal the width of the
            aggregated messages plus the width of ``x``, since ``update`` concatenates
            them (``StoneAgeGNN`` passes ``2 * state_size``).
        out_channels: number of output states.
        bounding_parameter: upper clamp applied to the summed messages.
        hidden_units: MLP hidden width (``network == 'mlp'`` only).
        dropout: MLP dropout probability (``network == 'mlp'`` only).
        gumbel: forwarded to the state-update module.
        temperature: forwarded to the state-update module as its softmax temperature.
        index: only used to build ``self.__name__`` (``'stone-age-<index>'``).
        network: ``'mlp'`` selects ``MLPSoftmax``; any other value selects ``LinearSoftmax``.
        use_batch_norm: forwarded to ``LinearSoftmax`` only (ignored for ``'mlp'``).
    """

    def __init__(self, in_channels, out_channels, bounding_parameter, hidden_units=16, dropout=0.0, gumbel=False,
                 temperature=1.0, index=0, network='linear', use_batch_norm=True):
        super().__init__(aggr='add')
        self.__name__ = 'stone-age-' + str(index)
        if network == 'mlp':
            self.linear_softmax = MLPSoftmax(in_channels, out_channels, gumbel, temperature, hidden_units=hidden_units,
                                             dropout=dropout)
        else:
            self.linear_softmax = LinearSoftmax(in_channels, out_channels, gumbel, temperature,
                                                use_batch_norm=use_batch_norm)
        self.bounding_parameter = bounding_parameter

    def forward(self, x, edge_index):
        """Run one round of message passing over ``edge_index`` and return the new node states."""
        # NOTE: PyG inspects these method signatures (and may inspect this call's source)
        # to route arguments; keep parameter names such as x_j / inputs / index unchanged.
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        """Each message is the sending node's current state, unchanged."""
        return x_j

    def aggregate(self, inputs, index, ptr, dim_size):
        """Sum incoming messages per node, then clamp the sums to ``[0, bounding_parameter]``."""
        message_sums = super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
        return torch.clamp(message_sums, min=0, max=self.bounding_parameter)

    def update(self, inputs, x):
        """Compute the next node state from ``[clamped message sums, own state]``.

        ``inputs`` is the output of ``aggregate``; ``x`` is the node-state tensor passed to
        ``propagate``.
        """
        combined = torch.cat((inputs, x), 1)
        x = self.linear_softmax(combined)
        return x


class StoneAgeGNN(Module):
    """Stone Age GNN: input encoder, ``num_layers`` message-passing rounds, then a readout.

    Node features are first encoded into ``state_size`` (near-)discrete states by an
    ``InputLayer``. ``num_layers`` ``StoneAgeGNNLayer`` rounds follow, and a
    ``PoolingLayer`` produces log-softmax scores over ``out_channels``. With
    ``use_pooling`` the node states are sum-pooled per graph (``global_add_pool``) before
    the readout. With ``skip_connection`` the readout sees the concatenation of the
    states after every stage (input layer plus each round).

    The ``set_*`` methods update the discretisation attributes of the input layer and of
    each round's state-update module at once.
    """

    def __init__(self, in_channels, out_channels, bounding_parameter, state_size, num_layers=1, gumbel=False,
                 softmax_temp=1.0, network='linear', use_pooling=True, skip_connection=False, use_batch_norm=True,
                 hidden_units=16, dropout=0.0):
        super().__init__()

        self.input = InputLayer(in_channels, state_size, gumbel=gumbel, softmax_temp=softmax_temp, network=network)
        self.initial_gumbel = gumbel

        # The plain readout is always built first, because it advances torch's RNG.
        # With skip connections it is then replaced by a readout over all concatenated states.
        readout = PoolingLayer(state_size, out_channels, network=network, hidden_units=hidden_units,
                               dropout=dropout)
        if skip_connection:
            readout = PoolingLayer((num_layers + 1) * state_size, out_channels, network=network,
                                   hidden_units=hidden_units, dropout=dropout)
        self.output = readout

        self.num_layers = num_layers
        self.use_pooling = use_pooling
        self.skip_connection = skip_connection
        self.stone_age = ModuleList(
            StoneAgeGNNLayer(state_size * 2,
                             state_size,
                             bounding_parameter=bounding_parameter,
                             gumbel=gumbel,
                             temperature=softmax_temp,
                             index=layer_idx,
                             network=network,
                             use_batch_norm=use_batch_norm,
                             hidden_units=hidden_units,
                             dropout=dropout)
            for layer_idx in range(num_layers)
        )

    def _set_everywhere(self, attr, value):
        """Set ``attr`` on the input layer and on every round's state-update module."""
        setattr(self.input, attr, value)
        for layer_idx in range(self.num_layers):
            setattr(self.stone_age[layer_idx].linear_softmax, attr, value)

    def set_gumbel(self, gumbel):
        self.input.gumbel = gumbel
        final_idx = self.num_layers - 1
        for layer_idx in range(self.num_layers):
            # The last message-passing round is always switched to non-Gumbel discretisation.
            self.stone_age[layer_idx].linear_softmax.gumbel = gumbel if layer_idx != final_idx else False

    def set_argmax(self, enabled):
        self._set_everywhere('argmax', enabled)

    def set_softmax_temp(self, temperature):
        self._set_everywhere('softmax_temp', temperature)

    def set_beta(self, beta):
        self._set_everywhere('beta', beta)

    def set_alpha(self, alpha):
        self._set_everywhere('alpha', alpha)

    def forward(self, x, edge_index, batch=None):
        state = self.input(x)
        history = [state]
        for round_layer in self.stone_age:
            state = round_layer(state, edge_index)
            history.append(state)

        if self.use_pooling:
            state = global_add_pool(state, batch)
            history = [global_add_pool(h, batch) for h in history]
        if self.skip_connection:
            state = torch.cat(history, dim=1)
        return self.output(state)
