import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
#from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import pickle


class Net_S(nn.Module):
    """Three-layer fully connected network that scores an ``(x, y)`` pair.

    The two inputs are concatenated along dimension 1, passed through two
    ReLU-activated hidden layers of width ``H`` and projected to one scalar
    per row.

    Args:
        dimX: Number of features in ``x``.
        dimY: Number of features in ``y``.
        H: Width of both hidden layers.
    """

    def __init__(self, dimX, dimY, H):
        super().__init__()

        # Layers are created in this order so that seeded initialisation and
        # state-dict keys stay stable.
        self.fc1 = nn.Linear(dimX + dimY, H)
        self.fc2 = nn.Linear(H, H)
        self.fc3 = nn.Linear(H, 1)

    def forward(self, x, y):
        """Return one unnormalised score per row of the joint input.

        ``x`` and ``y`` must agree on every dimension except dimension 1 so
        that they can be concatenated there.
        """
        joint = torch.cat([x, y], dim=1)
        hidden = torch.relu(self.fc1(joint))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)

# Module-wide default device: the first CUDA device when one is available,
# otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def mine_criterion(pred_xy, pred_x_y):
    """Return the negated Donsker-Varadhan lower bound used as the MINE loss.

    The bound is ``mean(pred_xy) - log(mean(exp(pred_x_y)))``. The second
    term is evaluated as ``logsumexp(pred_x_y) - log(N)``, which is
    mathematically identical but does not overflow to ``inf`` when the
    scores are large.

    Args:
        pred_xy: Network scores for paired (joint) samples.
        pred_x_y: Network scores for samples whose second argument was
            shuffled (product of marginals).

    Returns:
        Scalar tensor holding ``-bound``; minimising it maximises the bound.
    """
    scores = pred_x_y.reshape(-1)
    # log(N) is built from ``scores`` so it shares its dtype and device.
    log_count = torch.log(scores.new_tensor(float(scores.numel())))
    log_mean_exp = torch.logsumexp(scores, dim=0) - log_count

    bound = torch.mean(pred_xy) - log_mean_exp
    return -bound  # negate: the optimizer minimises, we want to maximise

def train(trainloader, epoch, net, optimizers):
    """Run one epoch of MINE training and return the mean batch loss.

    For each ``(X, Y)`` batch the network scores the paired samples and a
    copy in which the rows of ``Y`` are randomly permuted. Both score sets
    are passed to ``mine_criterion`` and one optimizer step is taken on the
    resulting loss.

    Args:
        trainloader: Iterable yielding ``(X, Y)`` tensor batches. The batches
            are fed to ``net`` as-is, so they are assumed to already live on
            the same device as ``net`` (TODO confirm against callers).
        epoch: Index of the current epoch. Unused; kept for interface
            compatibility.
        net: Module called as ``net(x, y)``.
        optimizers: Optimizer whose ``step()`` updates ``net``'s parameters.

    Returns:
        The mean loss over all batches as a Python ``float``. The value is
        detached from autograd, so it can be stored in a numpy array or
        logged without keeping any graph alive.

    Raises:
        ValueError: If ``trainloader`` yields no batches.
    """
    net.train()
    total_loss = 0.0
    n_batches = 0

    for X, Y in trainloader:
        # MINE needs samples from the product of marginals: permute the rows
        # of Y. Only the index vector is drawn on the host (still from
        # ``np.random``), so the batch itself never leaves its device and
        # keeps its dtype.
        perm = torch.from_numpy(np.random.permutation(len(Y))).to(Y.device)
        y_shuffle = Y[perm]

        pred_xy = net(X, Y)
        pred_x_y = net(X, y_shuffle)
        loss = mine_criterion(pred_xy, pred_x_y)

        net.zero_grad()
        loss.backward()
        optimizers.step()

        # ``item()`` detaches the scalar; summing the tensors themselves
        # would yield a result that still requires grad.
        total_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("trainloader yielded no batches")

    return total_loss / n_batches


def mine(trainloader, net, n_epoch, lr):
    """Train ``net`` with the MINE objective and return the MI estimate.

    Args:
        trainloader: Iterable of ``(X, Y)`` batches, passed to ``train`` once
            per epoch.
        net: Network to optimise; called as ``net(x, y)`` by ``train``.
        n_epoch: Number of training epochs.
        lr: Initial Adam learning rate. It is divided by 5 after every 100
            epochs.

    Returns:
        The negated mean training loss over the last ``min(20, n_epoch)``
        epochs, which is the mutual-information estimate.
    """
    optimizers = torch.optim.Adam(list(net.parameters()), lr=lr)

    train_loss = np.zeros(n_epoch)
    for epoch in range(n_epoch):
        # float() accepts both a plain number and a scalar tensor (including
        # one that requires grad or lives on the GPU), which numpy cannot
        # always store directly.
        train_loss[epoch] = float(train(trainloader, epoch, net, optimizers))

        if epoch % 10 == 0:
            print('Epoch: %d' % epoch, 'train_loss: %.3f' % train_loss[epoch])
        if (epoch + 1) % 100 == 0:
            lr = lr / 5
            # NOTE(review): a fresh optimizer also resets Adam's moment
            # estimates. Kept as in the original; confirm this restart is
            # intended rather than a plain learning-rate decay.
            optimizers = torch.optim.Adam(list(net.parameters()), lr=lr)

    return -np.mean(train_loss[-min(20, n_epoch):])  # this is the MI

class GaussianDataset(Dataset):
    """Dataset of paired samples ``(x[i], y[i])`` held as float32 tensors.

    Both arrays are converted once in the constructor and moved to the
    module-level ``device``, so indexing returns tensors already on that
    device.

    Args:
        x: numpy array of samples of the first variable, indexed along the
            first axis.
        y: numpy array of samples of the second variable, with the same
            length as ``x``.
        transform: Unused; kept for interface compatibility.

    Raises:
        ValueError: If ``x`` and ``y`` do not have the same length.
    """

    def __init__(self, x, y, transform=None):
        if len(x) != len(y):
            raise ValueError(
                "x and y must have the same number of samples, got %d and %d"
                % (len(x), len(y))
            )
        # ``data`` holds x and ``target`` holds y so that __getitem__ yields
        # the pair in the same (x, y) order the constructor receives it.
        self.data = torch.from_numpy(x).float().to(device)
        self.target = torch.from_numpy(y).float().to(device)

    def __getitem__(self, index):
        """Return the ``(x, y)`` pair stored at ``index``."""
        return self.data[index], self.target[index]

    def __len__(self):
        """Return the number of stored pairs."""
        return len(self.data)

if __name__ == "__main__":
    # Demo: estimate the mutual information between two independently drawn
    # standard-normal samples of the same dimension.
    lr_pa = 1e-3  # 0.001. Could be higher
    batch_size = 100
    n_epoch = 200
    dim = 5
    n = 5000

    # Data
    X = np.random.randn(n, dim)
    Y = np.random.randn(n, dim)

    # Make loader
    trainset = GaussianDataset(X, Y)
    trainloader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

    # Hidden width scales with the input dimension.
    H = 30 * dim

    # Net_S takes the feature counts of x and y separately; both are `dim`.
    net = Net_S(dim, dim, H).to(device)

    MI = mine(trainloader, net, n_epoch, lr_pa)
    print(MI)