import torch
import torch.nn as nn
import numpy as np
import random
import torch.nn.functional as F

def VAE_LL_loss(Xbatch,Xest,logvar,mu):
    """
    VAE training objective: reconstruction error plus KL divergence,
    both divided by the batch size.

    :param Xbatch: batch of target inputs; its first dimension is the batch size
    :param Xest: reconstruction of Xbatch
    :param logvar: log-variance produced by the encoder
    :param mu: mean produced by the encoder
    :return: tuple (total loss, reconstruction term, KL term)
    """
    inv_batch = 1./Xbatch.shape[0]

    # Sum of squared errors between the reconstruction and the target.
    sq_err = torch.nn.functional.mse_loss(Xest, Xbatch, reduction='sum')

    # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).
    kl_terms = torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    KLD = inv_batch * -0.5 * kl_terms
    mse = inv_batch * sq_err
    return mse + KLD, mse, KLD

class Encoder(nn.Module):
    """
    Convolutional VAE encoder.

    Three conv/batch-norm/ReLU stages reduce the spatial side length by a
    factor of four; two linear heads then produce the mean and log-variance
    of the latent distribution.
    """

    def __init__(self, z_dim, c_dim,img_size):
        """
        Encoder initializer
        :param z_dim: dimension of the latent representation
        :param c_dim: number of channels of the input image
        :param img_size: side length of the input image (the linear heads are
                         sized for an img_size x img_size input)
        """
        super(Encoder, self).__init__()
        
        self.model_enc = nn.Sequential(
            nn.Conv2d(int(c_dim), 64, 4, stride=2, padding=1),  
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 4, stride=2, padding=1), 
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # Asymmetric padding so the following 4x4 stride-1 convolution
            # leaves the spatial size unchanged.
            nn.ZeroPad2d((1,2,1,2)),
            nn.Conv2d(64, 64, 4, stride=1, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        # Each stride-2 convolution (kernel 4, padding 1) maps a side length H
        # to floor(H/2), so the final feature map has side floor(img_size/4).
        # Using integer division keeps the linear heads consistent with the
        # convolution output even when img_size is not a multiple of 4.
        side = int(img_size) // 4
        feat_dim = 64 * side * side
        self.fc_mu = nn.Linear(feat_dim, z_dim)
        self.fc_var = nn.Linear(feat_dim, z_dim)
        

    def encode(self, img):
        """
        Map a batch of images to the latent mean and log-variance.
        :param img: batch of images with c_dim channels
        :return: tuple (mu, logvar)
        """
        out = self.model_enc(img)
        # flatten everything except the batch dimension
        out = out.view(out.size(0),-1)
      
        return self.fc_mu(out),self.fc_var(out)

    def reparameterize(self, mu, logvar):
        """
        Draw z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I).
        """
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """
        Encode x and sample a latent code.
        :return: tuple (z, mu, logvar)
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar
    
class Decoder(nn.Module):
    """
    Convolutional decoder: a linear layer expands the latent code into a
    64-channel feature map, which three transposed convolutions turn into an
    image with c_dim channels passed through batch-norm and a sigmoid.
    """

    def __init__(self,z_dim,c_dim,img_size):
        """
        Decoder initializer
        :param z_dim: dimension of the latent representation
        :param c_dim: number of channels of the generated image
        :param img_size: side length of the generated image
        """
        super().__init__()
        # side length of the feature map fed to the transposed convolutions
        self.img_4 = img_size/4
        n_features = int(self.img_4*self.img_4*64)
        self.fc = nn.Sequential(
            nn.Linear(z_dim, n_features),
            nn.ReLU(),
        )

        out_channels = int(c_dim)
        self.model = nn.Sequential(
            nn.ConvTranspose2d(64, 64, 4, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, 4, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, out_channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """
        Decode a batch of latent codes into images.
        :param z: batch of latent codes; the first dimension is the batch size
        :return: output of the transposed-convolution stack
        """
        side = int(self.img_4)
        # reshape the flat linear output into a (batch, 64, side, side) map
        feature_map = self.fc(z).view(z.shape[0], 64, side, side)
        return self.model(feature_map)
    
def joint_uncond(params, decoder, classifier, query_label, device):
    """
    Sample-based estimate of the (negative) mutual information between the
    first K latent coordinates (alpha) and the binary classifier output for
    one label.

    For each of Nalpha draws of alpha, Nbeta latent codes sharing that alpha
    but with independent beta are decoded and classified. The sigmoid of the
    classifier output in column query_label gives p(label | z); averaging
    over beta estimates p(y|alpha), and averaging those over alpha estimates
    p(y). The returned value is -(sum_alpha mean[sum_y p log p] - sum_y q log q).

    :param params: dict with keys
                   'Nalpha' (number of alpha draws),
                   'Nbeta' (number of beta draws per alpha),
                   'K' (length of alpha, stored in the first K latent coordinates),
                   'L' (length of beta, stored in the remaining coordinates),
                   'z_dim' (total latent dimension)
    :param decoder: callable mapping a (Nbeta, z_dim) float tensor to samples
    :param classifier: callable mapping decoded samples to per-label scores,
                       indexed as [:, query_label]
    :param query_label: column of the classifier output to evaluate
    :param device: device on which tensors are created
    :return: tuple (negCausalEffect, info) where info holds the decoded
             samples 'xhat' and two-column probabilities 'yhat' of the last
             alpha draw (both None when Nalpha is 0)
    """
    eps = 1e-8
    n_alpha = params['Nalpha']
    n_beta = params['Nbeta']
    k = params['K']

    I = 0.0
    # yhat always has exactly two columns (label absent, label present), so
    # the marginal estimate has two entries independent of params['M'].
    q = torch.zeros(2, device=device)
    # Defined up front so the return value is well-formed when Nalpha is 0.
    xhat, yhat = None, None

    for _ in range(n_alpha):
        alpha = np.random.randn(k)
        zs = np.zeros((n_beta, params['z_dim']))
        zs[:, :k] = alpha                                   # same alpha in every row
        zs[:, k:] = np.random.randn(n_beta, params['L'])    # independent beta per row
        # decode and classify batch of Nbeta samples with same alpha
        xhat = decoder(torch.from_numpy(zs).float().to(device))
        yhat = torch.sigmoid(classifier(xhat))[:,query_label]
        yhat = torch.stack([1.0 - yhat, yhat], dim=1)
        p = 1./float(n_beta) * torch.sum(yhat,0) # estimate of p(y|alpha)
        I = I + 1./float(n_alpha) * torch.sum(torch.mul(p, torch.log(p+eps)))
        q = q + 1./float(n_alpha) * p # accumulate estimate of p(y)
    I = I - torch.sum(torch.mul(q, torch.log(q+eps)))
    negCausalEffect = -I
    info = {"xhat" : xhat, "yhat" : yhat}
    return negCausalEffect, info

def set_deterministic(seed):
    """
    Seed the Python, NumPy and PyTorch random generators and switch cuDNN
    to deterministic mode.

    :param seed: seed value; when None nothing is seeded and only a
                 message is printed
    """
    # seed by default is None
    if seed is None:
        print("Non-deterministic")
        return

    print(f"Deterministic with seed = {seed}")
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # trade cuDNN autotuning for repeatable kernels
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def weights_init_normal(m):
    """
    Initialize a single module in place; can be passed to nn.Module.apply.

    Modules whose class name contains 'Conv' get weights drawn from
    N(0, 0.02). Modules whose class name contains 'linear' in any letter
    case (e.g. Linear, Bilinear) get the same weight initialization and a
    zero bias. Every other module is left untouched.

    :param m: module to initialize
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    # Case-insensitive match: a case-sensitive search for 'linear' never
    # matches the class name 'Linear', which left nn.Linear layers untouched.
    elif 'linear' in classname.lower():
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        # layers built with bias=False have no bias tensor to reset
        if getattr(m, 'bias', None) is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)

class Linear_Encoder(nn.Module):
    """
    Fully connected VAE encoder: one 512-unit ReLU hidden layer followed by
    separate linear heads for the latent mean and log-variance.
    """

    def __init__(self,x_dim,z_dim):
        """
        :param x_dim: dimension of the input vector
        :param z_dim: dimension of the latent representation
        """
        super().__init__()
        hidden_dim = 512

        self.model = nn.Sequential(
            nn.Linear(int(x_dim), hidden_dim),
            nn.ReLU(),
        )
        self.f_mu = nn.Linear(hidden_dim, z_dim)
        self.f_var = nn.Linear(hidden_dim, z_dim)

    def encode(self, img):
        """
        Map a batch of inputs to the latent mean and log-variance.
        :return: tuple (mu, logvar)
        """
        hidden = self.model(img)
        mu = self.f_mu(hidden)
        logvar = self.f_var(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """
        Draw z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I).
        """
        std = (0.5*logvar).exp()
        noise = torch.randn_like(std)
        return mu + noise*std

    def forward(self, x):
        """
        Encode x and sample a latent code.
        :return: tuple (z, mu, logvar)
        """
        mu, logvar = self.encode(x)
        return self.reparameterize(mu, logvar), mu, logvar
    
class Linear_Decoder(nn.Module):
    """
    Fully connected decoder: latent code -> 512-unit ReLU hidden layer ->
    linear output of dimension x_dim.
    """

    def __init__(self,x_dim,z_dim):
        """
        :param x_dim: dimension of the decoded output
        :param z_dim: dimension of the latent representation
        """
        super().__init__()
        hidden_dim = 512
        layers = [
            nn.Linear(z_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, x_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """
        Decode a batch of latent codes.
        :return: output of the final linear layer (no activation applied)
        """
        return self.model(z)
    
class SoftCrossEntropy(nn.Module):
    """
    Cross-entropy between soft target distributions and logits.

    :param reduce: when True the per-sample losses are averaged into a
                   scalar; otherwise the per-sample losses are returned
    """

    def __init__(self, reduce=True):
        super().__init__()
        self.reduce = reduce

    def forward(self, y, z):
        """
        :param y: target weights, one row per sample
        :param z: logits with the same layout as y; log-softmax is taken
                  over dim 1
        :return: mean loss if self.reduce is set, else per-sample losses
        """
        log_probs = F.log_softmax(z, dim=1)
        weighted = y * log_probs
        loss = - torch.sum(weighted, dim=1)
        return loss.mean() if self.reduce else loss