import numpy as np

import torch
import torch.nn as nn
#from jutils import *

## cubic
# lowersize = 40
# hiddensize = 6

## Gaussian
# lowersize = 20
# hiddensize = 8

## club vs l1out
# Active size settings (the "cubic" and "Gaussian" alternatives are commented out above).
# NOTE(review): the CLUB class below does not read either constant -- it sizes its
# hidden layers from ``y_dim`` -- so these are presumably consumed by code that
# imports this module; TODO confirm.
lowersize = 40
hiddensize = 8


class CLUB(nn.Module):  # CLUB: Mutual Information Contrastive Learning Upper Bound
    """Sample-based mutual-information upper-bound estimator between x and y.

    Two small MLPs map ``x`` to the mean (``p_mu``) and log-variance
    (``p_logvar``) of a diagonal Gaussian over ``y``. The networks are fitted
    by maximising ``loglikeli`` and the fitted Gaussian is then used by
    ``mi_est`` / ``mi_est_sample`` to contrast matched ``(x, y)`` pairs with
    mismatched ones.

    Attributes set here and read by the methods:
        hiddensize: width of the hidden layer; equal to ``y_dim``.
        version: ``0`` feeds inputs to ``update`` as given, ``1`` first
            flattens all leading dimensions into one batch dimension.
        beta: factor applied to the estimate returned by ``update``.
        optimizer: Adam over all parameters of this module.
    """

    def __init__(self, x_dim: int, y_dim: int, lr: float = 1e-3, beta: float = 0):
        """Build both networks and their optimizer.

        Args:
            x_dim: size of the last dimension of ``x_samples``.
            y_dim: size of the last dimension of ``y_samples``; also used as
                the hidden width.
            lr: learning rate handed to Adam.
            beta: scale applied to the value returned by ``update``; with the
                default ``0`` that value is always zero-scaled.
        """
        super().__init__()
        self.hiddensize = y_dim
        self.version = 0
        self.p_mu = nn.Sequential(
            nn.Linear(x_dim, self.hiddensize),
            nn.ReLU(),
            nn.Linear(self.hiddensize, y_dim),
        )
        # The trailing Tanh keeps the predicted log-variance inside (-1, 1).
        self.p_logvar = nn.Sequential(
            nn.Linear(x_dim, self.hiddensize),
            nn.ReLU(),
            nn.Linear(self.hiddensize, y_dim),
            nn.Tanh(),
        )

        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self.beta = beta

    def get_mu_logvar(self, x_samples: torch.Tensor):
        """Return ``(mu, logvar)`` predicted from ``x_samples``."""
        return self.p_mu(x_samples), self.p_logvar(x_samples)

    def mi_est_sample(self, x_samples: torch.Tensor, y_samples: torch.Tensor) -> torch.Tensor:
        """Estimate the bound using one random mismatched ``y`` per ``x``.

        Each row of ``x_samples`` is contrasted with its own row of
        ``y_samples`` (positive) and with a row drawn uniformly with
        replacement from ``y_samples`` (negative).

        Returns:
            Scalar tensor: mean over rows of the summed positive terms minus
            the summed negative terms.
        """
        mu, logvar = self.get_mu_logvar(x_samples)

        sample_size = x_samples.shape[0]
        # Drawn directly on the device of the tensor it indexes; randint
        # already yields int64, so no cast is needed.
        random_index = torch.randint(
            sample_size, (sample_size,), device=y_samples.device
        )

        positive = -(mu - y_samples) ** 2 / 2. / logvar.exp()
        negative = -(mu - y_samples[random_index]) ** 2 / 2. / logvar.exp()
        return (positive.sum(dim=-1) - negative.sum(dim=-1)).mean()

    def mi_est(self, x_samples: torch.Tensor, y_samples: torch.Tensor) -> torch.Tensor:
        """Estimate the bound using every ``y`` as a negative for every ``x``.

        Builds an ``[nsample, nsample, dim]`` tensor of squared differences,
        so memory grows quadratically with the number of samples.

        Returns:
            Scalar tensor: mean over rows of the summed positive terms minus
            the summed (pairwise-averaged) negative terms.
        """
        mu, logvar = self.get_mu_logvar(x_samples)

        positive = -(mu - y_samples) ** 2 / 2. / logvar.exp()

        prediction_1 = mu.unsqueeze(1)        # [nsample, 1, dim]
        y_samples_1 = y_samples.unsqueeze(0)  # [1, nsample, dim]
        # Average the squared error of each prediction against all y rows.
        pairwise_sq = ((y_samples_1 - prediction_1) ** 2).mean(dim=1)  # [nsample, dim]
        negative = -pairwise_sq / 2. / logvar.exp()
        return (positive.sum(dim=-1) - negative.sum(dim=-1)).mean()

    def loglikeli(self, x_samples: torch.Tensor, y_samples: torch.Tensor) -> torch.Tensor:
        """Return the training objective maximised by ``update``.

        Computes ``-(mu - y)^2 / exp(logvar) - logvar``, summed over dim 1 and
        averaged over dim 0. This is the diagonal-Gaussian log-density with
        its constant factor and additive constant dropped.
        """
        mu, logvar = self.get_mu_logvar(x_samples)
        return (-(mu - y_samples) ** 2 / logvar.exp() - logvar).sum(dim=1).mean(dim=0)

    def update(self, x_samples: torch.Tensor, y_samples: torch.Tensor) -> torch.Tensor:
        """Run one optimizer step, then return the scaled sampled estimate.

        The module is switched to training mode and left there. With
        ``version == 1`` both inputs are first reshaped to
        ``(-1, last_dim)``.

        Returns:
            ``mi_est_sample(x_samples, y_samples) * beta``, evaluated with the
            parameters *after* the optimizer step.

        Raises:
            ValueError: if ``self.version`` is neither 0 nor 1 (previously
                this case silently returned ``None`` without training).
        """
        if self.version not in (0, 1):
            raise ValueError(
                "unsupported CLUB version: {!r} (expected 0 or 1)".format(self.version)
            )

        self.train()
        if self.version == 1:
            # Collapse every leading dimension into a single batch dimension.
            x_samples = torch.reshape(x_samples, (-1, x_samples.shape[-1]))
            y_samples = torch.reshape(y_samples, (-1, y_samples.shape[-1]))

        loss = -self.loglikeli(x_samples, y_samples)

        self.optimizer.zero_grad()
        # retain_graph=True is kept from the original: if the inputs carry an
        # upstream autograd graph, this stops it being freed here.
        # NOTE(review): presumably callers backprop through the returned
        # estimate afterwards -- confirm before removing.
        loss.backward(retain_graph=True)
        self.optimizer.step()

        return self.mi_est_sample(x_samples, y_samples) * self.beta
