import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import math


class FF(nn.Module):
    """Configurable feed-forward stack followed by a linear output layer.

    Each of the ``args.ff_layers`` hidden blocks is an ``nn.Sequential`` of
    ``[LayerNorm] -> Linear -> activation -> Dropout``; the ``LayerNorm`` is
    present only when ``args.ff_layer_norm`` is truthy. A final ``nn.Linear``
    maps to ``dim_output``.

    Args:
        args: Configuration object providing ``ff_residual_connection``,
            ``ff_layers``, ``ff_layer_norm`` and ``ff_activation``
            (``'tanh'`` or ``'relu'``).
        dim_input: Feature size of the input.
        dim_hidden: Output feature size of every hidden ``Linear``.
        dim_output: Feature size of the final output.
        dropout_rate: Probability handed to every ``nn.Dropout``.

    Raises:
        AssertionError: If residual connections are requested while
            ``dim_hidden != dim_input``.
        KeyError: If at least one hidden block is built and
            ``args.ff_activation`` is neither ``'tanh'`` nor ``'relu'``.
    """

    def __init__(self, args, dim_input, dim_hidden, dim_output, dropout_rate=0):
        super().__init__()

        use_residual = args.ff_residual_connection
        if use_residual:
            # x + block(x) is only well-formed when the block preserves width.
            assert dim_hidden == dim_input

        self.residual_connection = use_residual
        self.num_layers = args.ff_layers
        self.layer_norm = args.ff_layer_norm
        self.activation = args.ff_activation

        activation_classes = {'tanh': nn.Tanh, 'relu': nn.ReLU}

        self.stack = nn.ModuleList()
        width = dim_input  # feature size entering the next module
        for _ in range(self.num_layers):
            parts = [nn.LayerNorm(width)] if self.layer_norm else []
            parts.append(nn.Linear(width, dim_hidden))
            # Looked up per block so an unknown activation only fails when a
            # hidden block is actually constructed.
            parts.append(activation_classes[self.activation]())
            parts.append(nn.Dropout(dropout_rate))
            self.stack.append(nn.Sequential(*parts))
            width = dim_hidden

        # With no hidden blocks `width` is still `dim_input`.
        self.out = nn.Linear(width, dim_output)

    def forward(self, x):
        """Run ``x`` through every hidden block, then the output layer."""
        for block in self.stack:
            hidden = block(x)
            if self.residual_connection:
                x = x + hidden
            else:
                x = hidden
        return self.out(x)


class InfoNCEPointwise(nn.Module):
    """Per-sample InfoNCE-style lower bound scored by a learned critic.

    The critic ``self.net`` maps a concatenated ``(z_c, z_d)`` vector to one
    scalar score. The bound is returned per sample (it is not reduced over
    the batch).

    Args:
        args: Configuration object forwarded to ``FF``.
        zc_dim: Feature size of ``z_c``.
        zd_dim: Feature size of ``z_d``.
    """

    def __init__(self, args, zc_dim, zd_dim):
        super().__init__()
        # Critic over concatenated pairs; hidden width is zc_dim, output is 1.
        self.net = FF(args, zc_dim + zd_dim, zc_dim, 1)

    def forward(self, z_c, z_d, average=True):
        """Compute the per-sample lower bound.

        Args:
            z_c: Tensor of shape ``[sample_size, zc_dim]``.
            z_d: Tensor of shape ``[sample_size, zd_dim]``; must have the
                same number of rows as ``z_c``.
            average: Accepted but not used by this implementation
                (presumably kept for signature parity with other
                estimators -- confirm against callers).

        Returns:
            A tuple ``(lower_bound, 0., 0.)`` where ``lower_bound`` has shape
            ``[sample_size]``.
        """
        sample_size = z_d.shape[0]

        # Broadcast views instead of `repeat`: `torch.cat` below copies the
        # data anyway, so materialising the tiles first only wastes memory.
        # zc_tile[i, j] == z_c[j]  -> [sample_size, sample_size, zc_dim]
        zc_tile = z_c.unsqueeze(0).expand(sample_size, -1, -1)
        # zd_tile[i, j] == z_d[i]  -> [sample_size, sample_size, zd_dim]
        zd_tile = z_d.unsqueeze(1).expand(-1, sample_size, -1)

        # Scores of the aligned pairs (z_c[i], z_d[i]): [sample_size, 1]
        T0 = self.net(torch.cat([z_c, z_d], dim=-1))
        # Scores of every pair (z_c[j], z_d[i]): [sample_size, sample_size, 1]
        T1 = self.net(torch.cat([zc_tile, zd_tile], dim=-1))

        # For each z_d[i], log-mean-exp of its scores against all z_c[j]
        # (logsumexp over j minus log of the number of samples).
        log_mean_exp = T1.logsumexp(dim=1).mean(dim=-1) - np.log(sample_size)
        lower_bound = T0.mean(dim=1) - log_mean_exp  # [sample_size]
        return lower_bound, 0., 0.

    def learning_loss(self, z_c, z_d):
        """Return the negated per-sample lower bound, shape ``[sample_size]``."""
        return - self(z_c, z_d)[0]