import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
import math
from scipy.stats import entropy
from torch.distributions.multivariate_normal import MultivariateNormal

class VAE(nn.Module):
    """Conditional convolutional VAE.

    An image ``x`` is encoded into a diagonal-Gaussian posterior
    (``mu``, ``logvar``); a second image ``y`` is embedded into a
    128-dimensional feature by ``encoder_Y`` and concatenated with the latent
    sample before decoding.

    The layer sizes are hard-coded:

    * ``encoder`` takes 1-channel input whose spatial size reduces to 7x7
      after the two stride-2 convolutions (e.g. 1x28x28).
    * ``encoder_Y`` / ``project_hy`` take 1-channel input whose flattened
      conv output has ``128 * 2 * 2 * 4 = 2048`` features (e.g. 1x14x14 or
      1x16x16).
    * ``decoder`` upsamples a 7x7 map twice by a factor of 2 and ends in
      ``Tanh``, so reconstructions are 1x28x28 with values in [-1, 1].

    Args:
        input_dim: Stored on the instance as ``self.input_dim``. NOTE: it is
            not used to size any layer.
        latent_dim: Dimensionality of the latent variable ``z``.
    """

    def __init__(self, input_dim=(1, 28, 28), latent_dim=2):
        super().__init__()

        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Sub-modules are created in the order encoder, project_hy,
        # encoder_Y, decoder and keep their positions inside each Sequential,
        # so state_dict keys and seeded initialisation stay stable.

        # Encoder network X: outputs mu and logvar stacked along dim 1.
        self.encoder = nn.Sequential(
            *self._conv_trunk((128, 128, 128)),
            nn.Linear(128 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, 2 * latent_dim),
        )

        # Projection of y used by score_y / score_yz.
        self.project_hy = nn.Sequential(
            *self._conv_trunk((128, 128, 128)),
            nn.Linear(128 * 2 * 2 * 4, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
        )

        # Encoder network Y: embedding of y that conditions the decoder.
        self.encoder_Y = nn.Sequential(
            *self._conv_trunk((32, 64, 128)),
            nn.Linear(128 * 2 * 2 * 4, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
        )

        # Decoder network: [z, h_y] -> reconstruction.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + 128, 512),
            nn.ReLU(),
            nn.Linear(512, 128 * 7 * 7),
            nn.ReLU(),
            nn.Unflatten(1, (128, 7, 7)),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Dropout(p=0.5),  # regularizer; only active in training mode
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    @staticmethod
    def _conv_trunk(widths):
        """Return the shared conv layers (1 -> widths, two stride-2 steps) plus Flatten."""
        first, second, third = widths
        return [
            nn.Conv2d(1, first, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(first, second, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(second, third, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        ]

    def encode(self, x):
        """Run the encoder on ``x`` and split its output into ``(mu, logvar)``.

        The encoder output has ``2 * latent_dim`` features; the first half is
        returned as ``mu`` and the second half as ``logvar``.
        """
        h = self.encoder(x)
        mu, logvar = torch.chunk(h, 2, dim=1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, y):
        """Decode ``z`` conditioned on ``y``.

        Returns:
            ``(x, h_y)``: the decoder output and the ``encoder_Y`` embedding
            of ``y`` that was concatenated with ``z``.
        """
        h_y = self.encoder_Y(y)
        z_m = torch.cat((z, h_y), dim=1)
        x = self.decoder(z_m)
        return x, h_y

    def score_y(self, y):
        """Return the ``project_hy`` projection of ``y``."""
        return self.project_hy(y)

    def score_yz(self, y, z):
        """Apply ``self.score`` to ``z`` concatenated with the projection of ``y``.

        This class does not define a ``score`` sub-module; one has to be
        attached to the instance (e.g. ``model.score = ...``) before calling
        this method.

        Raises:
            AttributeError: if the instance has no ``score`` attribute.
        """
        # Fail fast with an actionable message instead of the generic
        # nn.Module attribute error raised after the projection has run.
        scorer = getattr(self, "score", None)
        if scorer is None:
            raise AttributeError(
                "VAE.score_yz requires a callable `score` attribute, but this "
                "model does not define one; attach it (e.g. `model.score = ...`) "
                "before calling score_yz."
            )

        hy_pj = self.project_hy(y)
        hy_pjm = torch.cat((z, hy_pj), dim=1)
        return scorer(hy_pjm)

    def forward(self, x, y):
        """Encode ``x``, sample ``z`` and reconstruct conditioned on ``y``.

        Returns:
            ``(x_hat, mu, logvar, z, h_y)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat, h_y = self.decode(z, y)
        return x_hat, mu, logvar, z, h_y
    
class NCEClassifier(nn.Module):
    """Binary classifier over pairs ``(y, u)``.

    ``y`` is passed through the convolutional ``project_hy`` network to obtain
    a ``latent_dim``-dimensional feature, which is concatenated with ``u`` and
    fed to an MLP that outputs one logit per row.

    ``project_hy`` takes 1-channel input whose flattened conv output has
    ``128 * 2 * 2 * 4 = 2048`` features (e.g. 1x14x14 or 1x16x16).

    Args:
        latent_dim: Width of the projected ``y`` feature. The projection's
            output layer is sized from this value so that it always matches
            the MLP input (``latent_dim + u_dim``).
        u_dim: Number of features in ``u``.
        nf: Hidden width of the MLP.
    """

    def __init__(self, latent_dim=128, u_dim=4, nf=128):
        super().__init__()

        self.latent_dim = latent_dim
        self.sigmoid = nn.Sigmoid()

        # Classifier MLP over the concatenated [projected y, u] vector.
        self.network = nn.Sequential(
            nn.Linear(latent_dim + u_dim, nf),
            nn.LeakyReLU(),
            nn.Linear(nf, nf),
            nn.LeakyReLU(),
            nn.Linear(nf, nf),
            nn.LeakyReLU(),
            nn.Linear(nf, nf),
            nn.LeakyReLU(),
            nn.Linear(nf, 1),
        )

        # Convolutional projection of y. Its output width must equal
        # latent_dim, otherwise the concatenation fed to `network` has the
        # wrong number of features for any latent_dim other than 128.
        self.project_hy = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 2 * 2 * 4, 512),
            nn.ReLU(),
            nn.Linear(512, latent_dim),
        )

    def _classify(self, y_feat, u):
        """Score already-projected ``y`` features against ``u``; return ``(prob, logit)``."""
        logits = self.network(torch.cat((y_feat, u), dim=1))
        return self.sigmoid(logits), logits

    def forward(self, y, u):
        """Score row-aligned pairs: row ``i`` of ``y`` with row ``i`` of ``u``.

        Returns:
            ``(prob, out)``: ``out`` is the raw logit and ``prob`` its sigmoid.
        """
        return self._classify(self.project_hy(y), u)

    def forward_2(self, y, u):
        """Score one ``y`` against every row of ``u``.

        The projected ``y`` is tiled ``len(u)`` times along dim 0, so the
        concatenation with ``u`` only lines up when ``y`` holds a single
        sample.

        Returns:
            ``(prob, out)``: ``out`` is the raw logit and ``prob`` its sigmoid.
        """
        y_feat = self.project_hy(y).repeat(len(u), 1)
        return self._classify(y_feat, u)

def generate_M(B=128, N=1024, L=2, device='cuda'):
    """Build a ``(1, L * (N // L))`` long tensor of block-repeated symbols.

    The symbols ``0 .. L-1`` are each repeated ``N // L`` times in a row
    (``[0, ..., 0, 1, ..., 1, ...]``), i.e. every entry carries log2(L) bits.
    When ``N`` is not a multiple of ``L`` the result is shorter than ``N``.

    Args:
        B: Unused; kept for interface compatibility.
        N: Target number of entries.
        L: Number of distinct symbols.
        device: Device the tensor is created on.

    Returns:
        ``torch.int64`` tensor with a leading singleton dimension.
    """
    # Floor division keeps the repeat count exact (no float round-trip).
    per_symbol = int(N // L)
    # Build directly on the target device instead of going through NumPy and
    # a host-to-device copy.
    symbols = torch.arange(L, dtype=torch.long, device=device)
    return symbols.repeat_interleave(per_symbol)[None, ...]