import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

from base.base_net import BaseNet
from torch.nn import Parameter
from torch.autograd import Variable

def entropy(p):
    """Return the Shannon entropy (natural log) of the probabilities in ``p``.

    A 1-D tensor is treated as a single distribution and its entropy is
    returned. For a 2-D tensor the entropy summed over all entries is divided
    by the size of the first dimension. ``1e-18`` is added inside the
    logarithm so that zero entries do not yield ``-inf``.

    Raises:
        NotImplementedError: if ``p`` is neither 1-D nor 2-D.
    """
    rank = p.dim()
    if rank not in (1, 2):
        raise NotImplementedError

    total = -(p * (p + 1e-18).log()).sum()
    if rank == 1:
        return total
    return total / float(p.size(0))

def _l2_normalize(d):
    """Scale each sample of ``d`` to unit L2 norm, in place, and return ``d``.

    The first dimension indexes samples; all remaining dimensions are
    flattened to compute one norm per sample. ``1e-18`` is added to the norm
    so an all-zero sample stays zero rather than dividing by zero. The input
    tensor itself is modified.
    """
    # Keep trailing singleton dims so the norm broadcasts against ``d``.
    trailing = (1,) * (d.dim() - 2)
    flat = d.view(d.shape[0], -1, *trailing)
    norms = torch.norm(flat, dim=1, keepdim=True)
    d /= norms + 1e-18
    return d


class MNIST_MLP(BaseNet):
    """MLP encoder/decoder over flattened 28x28 inputs with cluster centres.

    ``forward`` returns the ``rep_dim``-dimensional code; ``reconstruct``
    runs encoder and decoder. ``mu`` (``n_clusters`` x ``rep_dim``) and
    ``mu_smooth`` (``n_bins`` x ``rep_dim``) are learnable centres used by
    the soft-assignment helpers.
    """

    def __init__(self):
        super().__init__()
        self.input_dim = 28 * 28
        self.rep_dim = 10
        self.n_clusters = 10
        self.alpha = 1
        self.n_bins = 2
        # torch.Tensor(rows, cols) allocates without initialising, so both
        # centre matrices hold arbitrary values until they are assigned.
        # NOTE(review): presumably set by the training code -- confirm.
        self.mu = Parameter(torch.Tensor(self.n_clusters, self.rep_dim))
        self.mu_smooth = Parameter(torch.Tensor(self.n_bins, self.rep_dim))

        self.encoder = nn.Sequential(nn.Linear(self.input_dim, 500),
                                     nn.ReLU(),
                                     nn.Linear(500, 2000),
                                     nn.ReLU(),
                                     nn.Linear(2000, 500),
                                     nn.ReLU(),
                                     nn.Linear(500, self.rep_dim))

        self.decoder = nn.Sequential(nn.Linear(self.rep_dim, 500),
                                     nn.ReLU(),
                                     nn.Linear(500, 2000),
                                     nn.ReLU(),
                                     nn.Linear(2000, 500),
                                     nn.ReLU(),
                                     nn.Linear(500, self.input_dim))

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_dim)`` and return the encoder output."""
        x = x.view(-1, self.input_dim)
        return self.encoder(x)

    def reconstruct(self, x):
        """Flatten ``x``, encode it and return the decoder output."""
        x = x.view(-1, self.input_dim)
        return self.decoder(self.encoder(x))

    def _student_t_assign(self, z, centers):
        """Row-normalised kernel of squared distances from ``z`` to ``centers``.

        Returns a ``(len(z), len(centers))`` tensor whose rows sum to one.
        """
        sq_dist = torch.sum((z.unsqueeze(1) - centers) ** 2, dim=2)
        q = 1.0 / (1.0 + sq_dist / self.alpha)
        # ``**`` binds tighter than ``/``: this is (q ** (alpha + 1)) / 2, and
        # the constant 1/2 cancels in the normalisation below.
        # NOTE(review): a Student-t kernel would use the exponent
        # (alpha + 1) / 2 -- confirm which is intended before changing, since
        # it alters the assignments.
        q = q ** (self.alpha + 1.0) / 2.0
        return q / torch.sum(q, dim=1, keepdim=True)

    def get_q_smooth(self, z):
        """Soft assignment of codes ``z`` to the ``n_bins`` centres ``mu_smooth``."""
        return self._student_t_assign(z, self.mu_smooth)

    def soft_assign(self, x):
        """Soft assignment of codes ``x`` to the ``n_clusters`` centres ``mu``."""
        return self._student_t_assign(x, self.mu)

    def target_distribution_smooth(self, q):
        """Flattened target: ``(q + 1e-8) ** 0.01`` over column sums, row-normalised."""
        p = torch.pow(q + 0.00000001, 0.01) / torch.sum(q, dim=0)
        return p / torch.sum(p, dim=1, keepdim=True)

    def target_distribution(self, q):
        """Sharpened target: ``q ** 2`` over column sums, row-normalised."""
        p = q ** 2 / torch.sum(q, dim=0)
        return p / torch.sum(p, dim=1, keepdim=True)

    def cluster_loss(self, p, q):
        """Batch mean of the row-wise KL divergence of target ``p`` from ``q``.

        ``1e-6`` is added to ``q`` only; a zero entry in ``p`` yields NaN.
        """
        return torch.mean(torch.sum(p * torch.log(p / (q + 1e-6)), dim=1))

    def get_fair_centroids(self, z, attr):
        """Return the mean of the rows of ``z`` for each distinct value of ``attr``.

        Args:
            z: 2-D array-like of codes, one row per sample.
            attr: per-sample attribute values; flattened to 1-D.

        Returns:
            ``(n_unique, z_dim)`` NumPy array with rows ordered like
            ``np.unique(attr)``.
        """
        attr = np.ravel(np.asarray(attr))
        z = np.asarray(z)
        uniq_vals = np.unique(attr)
        fair_cent = np.zeros((uniq_vals.size, z.shape[1]))
        for i, val in enumerate(uniq_vals):
            # A boolean mask keeps the selection 2-D even for a group with a
            # single member; squeezing np.where output collapsed that case
            # and averaged over features instead of samples.
            fair_cent[i, :] = z[attr == val].mean(axis=0)
        return fair_cent

    def get_cluster_centroids(self, z, p):
        """Return ``(z^T p (p^T p)^-1)^T``, a ``(p.shape[1], z.shape[1])`` tensor.

        The inverse is computed with NumPy on a detached copy of ``p``, so no
        gradient flows through it. ``z`` is moved to ``p``'s device.
        """
        p_np = p.detach().cpu().numpy()
        ptp_inv = np.linalg.inv(np.matmul(p_np.T, p_np))
        ptp_inv = torch.from_numpy(ptp_inv).to(device=p.device, dtype=p.dtype)
        proj = torch.mm(p, ptp_inv)
        z_t = torch.transpose(z, 1, 0).to(p.device)
        return torch.transpose(torch.mm(z_t, proj), 1, 0)

    def encodeBatch(self, dataloader, islabel=False):
        """Encode every batch of ``dataloader`` and collect the fourth item.

        Moves the model to CUDA when available and leaves it in eval mode.
        Each batch must unpack into four items: the first is the input and
        the fourth is a tensor that is gathered into a flat Python list.
        ``islabel`` is unused.

        Returns:
            ``(codes, attr)``: the concatenated CPU codes and the list
            gathered from each batch's fourth item.
        """
        if torch.cuda.is_available():
            self.cuda()
        # Follow the model's device so CPU-only hosts work too.
        device = next(self.parameters()).device

        encoded = []
        attr = []
        self.eval()
        with torch.no_grad():
            for inputs, _, _, psvs in dataloader:
                z = self.forward(inputs.to(device))
                encoded.append(z.cpu())
                attr.extend(psvs.detach().cpu().numpy().tolist())
        return torch.cat(encoded, dim=0), attr


class MNIST_MLP_Autoencoder(BaseNet):
    """Symmetric MLP autoencoder over flattened 28x28 inputs.

    The encoder maps 784 -> 500 -> 2000 -> 500 -> ``rep_dim`` and the decoder
    mirrors it back to 784, with ReLU between linear layers. ``mu`` and
    ``mu_smooth`` are allocated as parameters but not used in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.input_dim = 28 * 28
        self.rep_dim = 10
        self.n_clusters = 10
        self.n_bins = 2
        # Allocated uninitialised; the autoencoder forward pass never reads them.
        self.mu = Parameter(torch.Tensor(self.n_clusters, self.rep_dim))
        self.mu_smooth = Parameter(torch.Tensor(self.n_bins, self.rep_dim))

        hidden = (500, 2000, 500)
        # Encoder
        self.encoder = self._stack((self.input_dim,) + hidden + (self.rep_dim,))
        # Decoder
        self.decoder = self._stack((self.rep_dim,) + hidden + (self.input_dim,))

    @staticmethod
    def _stack(widths):
        """Build Linear layers between consecutive ``widths`` with ReLU in between."""
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        # No activation after the final linear layer.
        return nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Flatten ``x`` to ``(-1, 784)`` and return its reconstruction."""
        flat = x.view(-1, 28 * 28)
        code = self.encoder(flat)
        return self.decoder(code)


class MNIST_LeNet(BaseNet):
    """Small convolutional encoder with entropy and VAT loss helpers.

    ``forward`` maps a ``(N, 1, 28, 28)`` batch to ``rep_dim`` logits (the
    linear layer expects ``4 * 7 * 7`` features after two 2x2 poolings).
    The decoder layers are registered but not used by ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.rep_dim = 10
        self.prop_eps = 0.25
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(1, 8, 5,  padding=2)
        self.bn1 = nn.BatchNorm2d(8, eps=1e-04)
        self.conv2 = nn.Conv2d(8, 4, 5, padding=2)
        self.bn2 = nn.BatchNorm2d(4, eps=1e-04)
        self.fc1 = nn.Linear(4 * 7 * 7, self.rep_dim)
        # Decoder (not used in forward; kept so parameter names are unchanged)
        self.deconv1 = nn.ConvTranspose2d(2, 4, 5,  padding=2)
        self.bn3 = nn.BatchNorm2d(4, eps=1e-04)
        self.deconv2 = nn.ConvTranspose2d(4, 8, 5,  padding=3)
        self.bn4 = nn.BatchNorm2d(8, eps=1e-04)
        self.deconv3 = nn.ConvTranspose2d(8, 1, 5, padding=2)

    def forward(self, x):
        """Return the ``rep_dim`` logits for the image batch ``x``."""
        x = self.conv1(x)
        x = self.pool(F.leaky_relu(self.bn1(x)))
        x = self.conv2(x)
        x = self.pool(F.leaky_relu(self.bn2(x)))
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

    def compute_entropy(self, x):
        """Return ``(mean per-sample entropy, entropy of the batch-mean prediction)``."""
        p = F.softmax(self.forward(x), dim=1)
        p_ave = torch.sum(p, dim=0) / len(x)
        return entropy(p), entropy(p_ave)

    def kl(self, p, q):
        """KL divergence of ``p`` from ``q`` summed over entries, divided by ``len(p)``."""
        return torch.sum(p * torch.log((p + 1e-8) / (q + 1e-8))) / float(len(p))

    def distance(self, y0, y1):
        """KL divergence between the softmax outputs of logits ``y0`` and ``y1``."""
        return self.kl(F.softmax(y0, dim=1), F.softmax(y1, dim=1))

    def loss_unlabeled(self, x, eps_list):
        """Return the VAT loss of ``x``; thin wrapper around :meth:`vat`."""
        return self.vat(x, eps_list)

    def vat(self, x, eps_list, xi=10, Ip=1):
        """Virtual adversarial loss for the image batch ``x``.

        Args:
            x: input batch; the perturbation has the same shape.
            eps_list: per-sample perturbation radii, scaled by ``prop_eps``.
            xi: scale of the probe perturbation in the power iteration.
            Ip: number of power-iteration steps.

        Returns:
            KL divergence between predictions on ``x`` (computed without
            gradient) and on ``x`` plus the adversarial perturbation.
        """
        device = x.device
        with torch.no_grad():
            y = self.forward(x)
        # NOTE(review): uniform [0, 1) noise makes every component of the
        # start direction non-negative; a zero-mean draw (e.g. randn) may
        # have been intended -- confirm before changing.
        d = torch.rand(x.shape).to(device)
        d = _l2_normalize(d)
        for _ in range(Ip):
            d_var = d.detach().to(device).requires_grad_(True)
            y_p = self.forward(x + xi * d_var)
            kl_loss = self.distance(y, y_p)
            # Only the gradient w.r.t. the perturbation is needed.
            # autograd.grad leaves the parameters' .grad untouched, whereas
            # .backward() here would add to them before the caller's own
            # backward pass.
            grad_d, = torch.autograd.grad(kl_loss, d_var)
            d = _l2_normalize(grad_d)
        d = d.to(device)
        eps = (self.prop_eps * eps_list).to(device).view(-1, 1, 1, 1)
        y_2 = self.forward(x + eps * d)
        return self.distance(y, y_2)

class MNIST_LeNet_Autoencoder(BaseNet):
    """Convolutional autoencoder for ``(N, 1, 28, 28)`` inputs.

    Two conv/batch-norm/pool stages and a linear layer produce a
    ``rep_dim``-dimensional code; the code is reshaped to ``(N, 2, 4, 4)``
    and upsampled through three transposed convolutions back to 28x28,
    followed by a sigmoid. ``mu`` is allocated but unused in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.rep_dim = 32
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.n_clusters = 10
        # Allocated uninitialised; not read by the forward pass.
        self.mu = Parameter(torch.Tensor(self.n_clusters, self.rep_dim))

        # Encoder (must match the LeNet network above)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm2d(num_features=8, eps=1e-4)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=4, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm2d(num_features=4, eps=1e-4)
        self.fc1 = nn.Linear(in_features=4 * 7 * 7, out_features=self.rep_dim)

        # Decoder
        self.deconv1 = nn.ConvTranspose2d(in_channels=2, out_channels=4, kernel_size=5, padding=2)
        self.bn3 = nn.BatchNorm2d(num_features=4, eps=1e-4)
        self.deconv2 = nn.ConvTranspose2d(in_channels=4, out_channels=8, kernel_size=5, padding=3)
        self.bn4 = nn.BatchNorm2d(num_features=8, eps=1e-4)
        self.deconv3 = nn.ConvTranspose2d(in_channels=8, out_channels=1, kernel_size=5, padding=2)

    def forward(self, x):
        """Encode ``x`` and return its sigmoid-activated reconstruction."""
        # Encoder: 28x28 -> 14x14 -> 7x7 -> flat code.
        h = self.pool(F.leaky_relu(self.bn1(self.conv1(x))))
        h = self.pool(F.leaky_relu(self.bn2(self.conv2(h))))
        code = self.fc1(h.view(h.size(0), -1))

        # Decoder: reshape the code into 4x4 feature maps and upsample.
        h = code.view(code.size(0), int(self.rep_dim / 16), 4, 4)
        h = self.deconv1(F.interpolate(F.leaky_relu(h), scale_factor=2))
        h = self.deconv2(F.interpolate(F.leaky_relu(self.bn3(h)), scale_factor=2))
        h = self.deconv3(F.interpolate(F.leaky_relu(self.bn4(h)), scale_factor=2))
        return torch.sigmoid(h)

class MNIST_RIM(BaseNet):
    """784-1200-1200-10 MLP with batch norm, plus entropy and VAT loss helpers.

    ``forward`` flattens its input to ``(-1, 784)`` and returns 10 logits.
    ``classifier`` is registered but not used by ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.prop_eps = 0.25
        self.rep_dim = 10
        self.n_clusters = 10
        self.fc1 = nn.Linear(28*28, 1200)
        torch.nn.init.normal_(self.fc1.weight, std=0.1*math.sqrt(2/(28*28)))
        self.fc1.bias.data.fill_(0)
        self.fc2 = nn.Linear(1200, 1200)
        torch.nn.init.normal_(self.fc2.weight, std=0.1*math.sqrt(2/1200))
        self.fc2.bias.data.fill_(0)
        self.fc3 = nn.Linear(1200, 10)
        torch.nn.init.normal_(self.fc3.weight, std=0.0001*math.sqrt(2/1200))
        self.fc3.bias.data.fill_(0)
        self.bn1 = nn.BatchNorm1d(1200, affine=True)
        self.bn2 = nn.BatchNorm1d(1200, affine=True)

        self.classifier = nn.Linear(self.rep_dim, self.n_clusters)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, 784)`` and return the 10 output logits."""
        x = x.view(-1, 28 * 28)
        x = self.fc1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

    def compute_entropy(self, x):
        """Return ``(mean per-sample entropy, entropy of the batch-mean prediction)``."""
        p = F.softmax(self.forward(x), dim=1)
        p_ave = torch.sum(p, dim=0) / len(x)
        return entropy(p), entropy(p_ave)

    def kl(self, p, q):
        """KL divergence of ``p`` from ``q`` summed over entries, divided by ``len(p)``."""
        return torch.sum(p * torch.log((p + 1e-8) / (q + 1e-8))) / float(len(p))

    def distance(self, y0, y1):
        """KL divergence between the softmax outputs of logits ``y0`` and ``y1``."""
        return self.kl(F.softmax(y0, dim=1), F.softmax(y1, dim=1))

    def loss_unlabeled(self, x, eps_list):
        """Return the VAT loss of ``x``; thin wrapper around :meth:`vat`."""
        return self.vat(x, eps_list)

    def vat(self, x, eps_list, xi=10, Ip=1):
        """Virtual adversarial loss for the 2-D batch ``x``.

        Args:
            x: ``(N, D)`` input batch; the perturbation is drawn with shape
                ``(x.size(0), x.size(1))``.
            eps_list: per-sample perturbation radii, scaled by ``prop_eps``.
            xi: scale of the probe perturbation in the power iteration.
            Ip: number of power-iteration steps.

        Returns:
            KL divergence between predictions on ``x`` (computed without
            gradient) and on ``x`` plus the adversarial perturbation.
        """
        device = x.device
        with torch.no_grad():
            y = self.forward(x)
        # Drawn on the CPU and moved afterwards, as before, so the same
        # random stream is consumed.
        d = torch.randn((x.size()[0], x.size()[1]))
        d = F.normalize(d, p=2, dim=1)
        for _ in range(Ip):
            d_var = d.detach().to(device).requires_grad_(True)
            y_p = self.forward(x + xi * d_var)
            kl_loss = self.distance(y, y_p)
            # Only the gradient w.r.t. the perturbation is needed.
            # autograd.grad leaves the parameters' .grad untouched, whereas
            # .backward() here would add to them before the caller's own
            # backward pass.
            grad_d, = torch.autograd.grad(kl_loss, d_var)
            d = F.normalize(grad_d, p=2, dim=1)
        d = d.to(device)
        eps = (self.prop_eps * eps_list).to(device).view(-1, 1)
        y_2 = self.forward(x + eps * d)
        return self.distance(y, y_2)

class CREDIT_RIM(BaseNet):
    """14-50-50-2 MLP with batch norm, plus entropy and VAT loss helpers.

    ``forward`` expects a ``(N, 14)`` batch and returns 2 logits.
    ``classifier`` is registered but not used by ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.prop_eps = 0.2
        self.rep_dim = 2
        self.n_clusters = 2
        self.fc1 = nn.Linear(14, 50)
        torch.nn.init.normal_(self.fc1.weight, std=0.1*math.sqrt(2/(50)))
        self.fc1.bias.data.fill_(0)
        self.fc2 = nn.Linear(50, 50)
        torch.nn.init.normal_(self.fc2.weight, std=0.1*math.sqrt(2/50))
        self.fc2.bias.data.fill_(0)
        self.fc3 = nn.Linear(50, 2)
        torch.nn.init.normal_(self.fc3.weight, std=0.0001*math.sqrt(2/50))
        self.fc3.bias.data.fill_(0)
        self.bn1 = nn.BatchNorm1d(50, affine=True)
        self.bn2 = nn.BatchNorm1d(50, affine=True)

        self.classifier = nn.Linear(self.rep_dim, self.n_clusters)

    def forward(self, x):
        """Return the 2 output logits for the feature batch ``x``."""
        x = self.fc1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

    def compute_entropy(self, x):
        """Return ``(mean per-sample entropy, entropy of the batch-mean prediction)``."""
        p = F.softmax(self.forward(x), dim=1)
        p_ave = torch.sum(p, dim=0) / len(x)
        return entropy(p), entropy(p_ave)

    def kl(self, p, q):
        """KL divergence of ``p`` from ``q`` summed over entries, divided by ``len(p)``."""
        return torch.sum(p * torch.log((p + 1e-8) / (q + 1e-8))) / float(len(p))

    def distance(self, y0, y1):
        """KL divergence between the softmax outputs of logits ``y0`` and ``y1``."""
        return self.kl(F.softmax(y0, dim=1), F.softmax(y1, dim=1))

    def loss_unlabeled(self, x, eps_list):
        """Return the VAT loss of ``x``; thin wrapper around :meth:`vat`."""
        return self.vat(x, eps_list)

    def vat(self, x, eps_list, xi=10, Ip=1):
        """Virtual adversarial loss for the 2-D batch ``x``.

        Args:
            x: ``(N, D)`` input batch; the perturbation is drawn with shape
                ``(x.size(0), x.size(1))``.
            eps_list: per-sample perturbation radii, scaled by ``prop_eps``.
            xi: scale of the probe perturbation in the power iteration.
            Ip: number of power-iteration steps.

        Returns:
            KL divergence between predictions on ``x`` (computed without
            gradient) and on ``x`` plus the adversarial perturbation.
        """
        device = x.device
        with torch.no_grad():
            y = self.forward(x)
        # Drawn on the CPU and moved afterwards, as before, so the same
        # random stream is consumed.
        d = torch.randn((x.size()[0], x.size()[1]))
        d = F.normalize(d, p=2, dim=1)
        for _ in range(Ip):
            d_var = d.detach().to(device).requires_grad_(True)
            y_p = self.forward(x + xi * d_var)
            kl_loss = self.distance(y, y_p)
            # Only the gradient w.r.t. the perturbation is needed.
            # autograd.grad leaves the parameters' .grad untouched, whereas
            # .backward() here would add to them before the caller's own
            # backward pass.
            grad_d, = torch.autograd.grad(kl_loss, d_var)
            d = F.normalize(grad_d, p=2, dim=1)
        d = d.to(device)
        eps = (self.prop_eps * eps_list).to(device).view(-1, 1)
        y_2 = self.forward(x + eps * d)
        return self.distance(y, y_2)

class BANK_RIM(BaseNet):
    """3-50-50-2 MLP with batch norm, plus entropy and VAT loss helpers.

    ``forward`` expects a ``(N, 3)`` batch and returns 2 logits.
    ``classifier`` is registered but not used by ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.prop_eps = 0.2
        self.rep_dim = 2
        self.n_clusters = 2
        self.fc1 = nn.Linear(3, 50)
        torch.nn.init.normal_(self.fc1.weight, std=0.1*math.sqrt(2/(50)))
        self.fc1.bias.data.fill_(0)
        self.fc2 = nn.Linear(50, 50)
        torch.nn.init.normal_(self.fc2.weight, std=0.1*math.sqrt(2/50))
        self.fc2.bias.data.fill_(0)
        self.fc3 = nn.Linear(50, 2)
        torch.nn.init.normal_(self.fc3.weight, std=0.0001*math.sqrt(2/50))
        self.fc3.bias.data.fill_(0)
        self.bn1 = nn.BatchNorm1d(50, affine=True)
        self.bn2 = nn.BatchNorm1d(50, affine=True)

        self.classifier = nn.Linear(self.rep_dim, self.n_clusters)

    def forward(self, x):
        """Return the 2 output logits for the feature batch ``x``."""
        x = self.fc1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

    def compute_entropy(self, x):
        """Return ``(mean per-sample entropy, entropy of the batch-mean prediction)``."""
        p = F.softmax(self.forward(x), dim=1)
        p_ave = torch.sum(p, dim=0) / len(x)
        return entropy(p), entropy(p_ave)

    def kl(self, p, q):
        """KL divergence of ``p`` from ``q`` summed over entries, divided by ``len(p)``."""
        return torch.sum(p * torch.log((p + 1e-8) / (q + 1e-8))) / float(len(p))

    def distance(self, y0, y1):
        """KL divergence between the softmax outputs of logits ``y0`` and ``y1``."""
        return self.kl(F.softmax(y0, dim=1), F.softmax(y1, dim=1))

    def loss_unlabeled(self, x, eps_list):
        """Return the VAT loss of ``x``; thin wrapper around :meth:`vat`."""
        return self.vat(x, eps_list)

    def vat(self, x, eps_list, xi=10, Ip=1):
        """Virtual adversarial loss for the 2-D batch ``x``.

        Args:
            x: ``(N, D)`` input batch; the perturbation is drawn with shape
                ``(x.size(0), x.size(1))``.
            eps_list: per-sample perturbation radii, scaled by ``prop_eps``.
            xi: scale of the probe perturbation in the power iteration.
            Ip: number of power-iteration steps.

        Returns:
            KL divergence between predictions on ``x`` (computed without
            gradient) and on ``x`` plus the adversarial perturbation.
        """
        device = x.device
        with torch.no_grad():
            y = self.forward(x)
        # Drawn on the CPU and moved afterwards, as before, so the same
        # random stream is consumed.
        d = torch.randn((x.size()[0], x.size()[1]))
        d = F.normalize(d, p=2, dim=1)
        for _ in range(Ip):
            d_var = d.detach().to(device).requires_grad_(True)
            y_p = self.forward(x + xi * d_var)
            kl_loss = self.distance(y, y_p)
            # Only the gradient w.r.t. the perturbation is needed.
            # autograd.grad leaves the parameters' .grad untouched, whereas
            # .backward() here would add to them before the caller's own
            # backward pass.
            grad_d, = torch.autograd.grad(kl_loss, d_var)
            d = F.normalize(grad_d, p=2, dim=1)
        d = d.to(device)
        eps = (self.prop_eps * eps_list).to(device).view(-1, 1)
        y_2 = self.forward(x + eps * d)
        return self.distance(y, y_2)

class ADULT_RIM(BaseNet):
    """5-50-50-2 MLP with batch norm, plus entropy and VAT loss helpers.

    ``forward`` expects a ``(N, 5)`` batch and returns 2 logits.
    ``classifier`` is registered but not used by ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.prop_eps = 0.2
        self.rep_dim = 2
        self.n_clusters = 2
        self.fc1 = nn.Linear(5, 50)
        torch.nn.init.normal_(self.fc1.weight, std=0.1*math.sqrt(2/(50)))
        self.fc1.bias.data.fill_(0)
        self.fc2 = nn.Linear(50, 50)
        torch.nn.init.normal_(self.fc2.weight, std=0.1*math.sqrt(2/50))
        self.fc2.bias.data.fill_(0)
        self.fc3 = nn.Linear(50, 2)
        torch.nn.init.normal_(self.fc3.weight, std=0.0001*math.sqrt(2/50))
        self.fc3.bias.data.fill_(0)
        self.bn1 = nn.BatchNorm1d(50, affine=True)
        self.bn2 = nn.BatchNorm1d(50, affine=True)

        self.classifier = nn.Linear(self.rep_dim, self.n_clusters)

    def forward(self, x):
        """Return the 2 output logits for the feature batch ``x``."""
        x = self.fc1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

    def compute_entropy(self, x):
        """Return ``(mean per-sample entropy, entropy of the batch-mean prediction)``."""
        p = F.softmax(self.forward(x), dim=1)
        p_ave = torch.sum(p, dim=0) / len(x)
        return entropy(p), entropy(p_ave)

    def kl(self, p, q):
        """KL divergence of ``p`` from ``q`` summed over entries, divided by ``len(p)``."""
        return torch.sum(p * torch.log((p + 1e-8) / (q + 1e-8))) / float(len(p))

    def distance(self, y0, y1):
        """KL divergence between the softmax outputs of logits ``y0`` and ``y1``."""
        return self.kl(F.softmax(y0, dim=1), F.softmax(y1, dim=1))

    def loss_unlabeled(self, x, eps_list):
        """Return the VAT loss of ``x``; thin wrapper around :meth:`vat`."""
        return self.vat(x, eps_list)

    def vat(self, x, eps_list, xi=10, Ip=1):
        """Virtual adversarial loss for the 2-D batch ``x``.

        Args:
            x: ``(N, D)`` input batch; the perturbation is drawn with shape
                ``(x.size(0), x.size(1))``.
            eps_list: per-sample perturbation radii, scaled by ``prop_eps``.
            xi: scale of the probe perturbation in the power iteration.
            Ip: number of power-iteration steps.

        Returns:
            KL divergence between predictions on ``x`` (computed without
            gradient) and on ``x`` plus the adversarial perturbation.
        """
        device = x.device
        with torch.no_grad():
            y = self.forward(x)
        # Drawn on the CPU and moved afterwards, as before, so the same
        # random stream is consumed.
        d = torch.randn((x.size()[0], x.size()[1]))
        d = F.normalize(d, p=2, dim=1)
        for _ in range(Ip):
            d_var = d.detach().to(device).requires_grad_(True)
            y_p = self.forward(x + xi * d_var)
            kl_loss = self.distance(y, y_p)
            # Only the gradient w.r.t. the perturbation is needed.
            # autograd.grad leaves the parameters' .grad untouched, whereas
            # .backward() here would add to them before the caller's own
            # backward pass.
            grad_d, = torch.autograd.grad(kl_loss, d_var)
            d = F.normalize(grad_d, p=2, dim=1)
        d = d.to(device)
        eps = (self.prop_eps * eps_list).to(device).view(-1, 1)
        y_2 = self.forward(x + eps * d)
        return self.distance(y, y_2)

class HAR_RIM(BaseNet):
    """561-1200-1200-6 MLP with batch norm, plus entropy and VAT loss helpers.

    ``forward`` expects a ``(N, 561)`` batch and returns 6 logits.
    ``classifier`` is registered but not used by ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.prop_eps = 0.2
        self.rep_dim = 6
        self.n_clusters = 6
        self.fc1 = nn.Linear(561, 1200)
        torch.nn.init.normal_(self.fc1.weight, std=0.1*math.sqrt(2/(561)))
        self.fc1.bias.data.fill_(0)
        self.fc2 = nn.Linear(1200, 1200)
        torch.nn.init.normal_(self.fc2.weight, std=0.1*math.sqrt(2/1200))
        self.fc2.bias.data.fill_(0)
        self.fc3 = nn.Linear(1200, 6)
        torch.nn.init.normal_(self.fc3.weight, std=0.0001*math.sqrt(2/1200))
        self.fc3.bias.data.fill_(0)
        self.bn1 = nn.BatchNorm1d(1200, affine=True)
        self.bn2 = nn.BatchNorm1d(1200, affine=True)

        self.classifier = nn.Linear(self.rep_dim, self.n_clusters)

    def forward(self, x):
        """Return the 6 output logits for the feature batch ``x``."""
        x = self.fc1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

    def compute_entropy(self, x):
        """Return ``(mean per-sample entropy, entropy of the batch-mean prediction)``."""
        p = F.softmax(self.forward(x), dim=1)
        p_ave = torch.sum(p, dim=0) / len(x)
        return entropy(p), entropy(p_ave)

    def kl(self, p, q):
        """KL divergence of ``p`` from ``q`` summed over entries, divided by ``len(p)``."""
        return torch.sum(p * torch.log((p + 1e-8) / (q + 1e-8))) / float(len(p))

    def distance(self, y0, y1):
        """KL divergence between the softmax outputs of logits ``y0`` and ``y1``."""
        return self.kl(F.softmax(y0, dim=1), F.softmax(y1, dim=1))

    def loss_unlabeled(self, x, eps_list):
        """Return the VAT loss of ``x``; thin wrapper around :meth:`vat`."""
        return self.vat(x, eps_list)

    def vat(self, x, eps_list, xi=10, Ip=1):
        """Virtual adversarial loss for the 2-D batch ``x``.

        Args:
            x: ``(N, D)`` input batch; the perturbation is drawn with shape
                ``(x.size(0), x.size(1))``.
            eps_list: per-sample perturbation radii, scaled by ``prop_eps``.
            xi: scale of the probe perturbation in the power iteration.
            Ip: number of power-iteration steps.

        Returns:
            KL divergence between predictions on ``x`` (computed without
            gradient) and on ``x`` plus the adversarial perturbation.
        """
        device = x.device
        with torch.no_grad():
            y = self.forward(x)
        # Drawn on the CPU and moved afterwards, as before, so the same
        # random stream is consumed.
        d = torch.randn((x.size()[0], x.size()[1]))
        d = F.normalize(d, p=2, dim=1)
        for _ in range(Ip):
            d_var = d.detach().to(device).requires_grad_(True)
            y_p = self.forward(x + xi * d_var)
            kl_loss = self.distance(y, y_p)
            # Only the gradient w.r.t. the perturbation is needed.
            # autograd.grad leaves the parameters' .grad untouched, whereas
            # .backward() here would add to them before the caller's own
            # backward pass.
            grad_d, = torch.autograd.grad(kl_loss, d_var)
            d = F.normalize(grad_d, p=2, dim=1)
        d = d.to(device)
        eps = (self.prop_eps * eps_list).to(device).view(-1, 1)
        y_2 = self.forward(x + eps * d)
        return self.distance(y, y_2)


class DCC_LeNet(BaseNet):
    """Convolutional autoencoder with learnable cluster centres ``mu``.

    ``forward`` returns ``(z, q, x_rec)``: the ``rep_dim`` code, its soft
    assignment to the ``n_clusters`` centres, and the sigmoid reconstruction.
    The code is reshaped to ``(N, rep_dim / 16, 4, 4)`` and fed to a
    transposed convolution with 2 input channels, so the decoder only works
    for ``rep_dim == 32``.
    """

    def __init__(self, n_clusters=2, rep_dim=32):
        super().__init__()
        self.n_clusters = n_clusters
        self.rep_dim = rep_dim
        self.pool = nn.MaxPool2d(2, 2)
        # Encoder
        self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2)
        self.bn1 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
        self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2)
        self.bn2 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
        self.fc1 = nn.Linear(4 * 7 * 7, self.rep_dim, bias=False)
        # Decoder
        self.deconv1 = nn.ConvTranspose2d(2, 4, 5, bias=False, padding=2)
        self.bn3 = nn.BatchNorm2d(4, eps=1e-04, affine=False)
        self.deconv2 = nn.ConvTranspose2d(4, 8, 5, bias=False, padding=3)
        self.bn4 = nn.BatchNorm2d(8, eps=1e-04, affine=False)
        self.deconv3 = nn.ConvTranspose2d(8, 1, 5, bias=False, padding=2)
        # torch.Tensor(rows, cols) allocates without initialising, so the
        # centres hold arbitrary values until they are assigned.
        # NOTE(review): presumably set by the training code -- confirm.
        self.mu = Parameter(torch.Tensor(n_clusters, rep_dim))

    def soft_assign(self, z):
        """Row-normalised soft assignment of codes ``z`` to the centres ``mu``."""
        q = 1.0 / (1.0 + torch.sum((z.unsqueeze(1) - self.mu) ** 2, dim=2) / 1.0)
        # ``**`` binds tighter than ``/``: this is (q ** 2) / 2, and the
        # constant 1/2 cancels in the normalisation below.
        # NOTE(review): a Student-t kernel with alpha = 1 would use exponent
        # (1 + 1) / 2 = 1 -- confirm which is intended before changing.
        q = q ** (1.0 + 1.0) / 2.0
        q = q / torch.sum(q, dim=1, keepdim=True)
        return q

    def target_distribution(self, q):
        """Sharpened target: ``q ** 2`` over column sums, row-normalised."""
        p = q ** 2 / torch.sum(q, dim=0)
        p = p / torch.sum(p, dim=1, keepdim=True)
        return p

    def loss_function(self, p, q):
        """Batch mean of the row-wise KL divergence of target ``p`` from ``q``.

        ``1e-6`` is added to ``q`` only; a zero entry in ``p`` yields NaN.
        """
        return torch.mean(torch.sum(p * torch.log(p / (q + 1e-6)), dim=1))

    def encodeBatch(self, dataloader, islabel=False):
        """Encode every batch of ``dataloader`` and return the stacked codes.

        Moves the model to CUDA when available and leaves it in eval mode.
        Each batch must unpack into four items, the first being the input.
        ``islabel`` is unused.

        Returns:
            CPU tensor with the codes of all batches concatenated along dim 0.
        """
        if torch.cuda.is_available():
            self.cuda()
        # Follow the model's device so CPU-only hosts work too.
        device = next(self.parameters()).device

        encoded = []
        self.eval()
        with torch.no_grad():
            for inputs, _, _, _ in dataloader:
                z, _, _ = self.forward(inputs.to(device))
                encoded.append(z.cpu())

        return torch.cat(encoded, dim=0)

    def forward(self, x):
        """Return ``(code, soft assignment, reconstruction)`` for the batch ``x``."""
        x = self.conv1(x)
        x = self.pool(F.leaky_relu(self.bn1(x)))
        x = self.conv2(x)
        x = self.pool(F.leaky_relu(self.bn2(x)))
        x = x.view(x.size(0), -1)
        z = self.fc1(x)

        # Reshape the code into 4x4 feature maps and upsample back to 28x28.
        x = z.view(z.size(0), int(self.rep_dim / 16), 4, 4)
        x = F.interpolate(F.leaky_relu(x), scale_factor=2)
        x = self.deconv1(x)
        x = F.interpolate(F.leaky_relu(self.bn3(x)), scale_factor=2)
        x = self.deconv2(x)
        x = F.interpolate(F.leaky_relu(self.bn4(x)), scale_factor=2)
        x = self.deconv3(x)
        x = torch.sigmoid(x)

        q = self.soft_assign(z)
        return z, q, x


class DCC_LeNet_Autoencoder(BaseNet):
    """Bias-free convolutional autoencoder for ``(N, 1, 28, 28)`` inputs.

    Two conv/batch-norm/pool stages and a linear layer produce a
    ``rep_dim``-dimensional code; the code is reshaped to ``(N, 2, 4, 4)``
    and upsampled through three transposed convolutions back to 28x28,
    followed by a sigmoid. Batch-norm layers have no affine parameters.
    """

    def __init__(self):
        super().__init__()

        self.rep_dim = 32
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, bias=False, padding=2)
        self.bn1 = nn.BatchNorm2d(num_features=8, eps=1e-4, affine=False)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=4, kernel_size=5, bias=False, padding=2)
        self.bn2 = nn.BatchNorm2d(num_features=4, eps=1e-4, affine=False)
        self.fc1 = nn.Linear(in_features=4 * 7 * 7, out_features=self.rep_dim, bias=False)

        # Decoder
        self.deconv1 = nn.ConvTranspose2d(in_channels=2, out_channels=4, kernel_size=5, bias=False, padding=2)
        self.bn3 = nn.BatchNorm2d(num_features=4, eps=1e-4, affine=False)
        self.deconv2 = nn.ConvTranspose2d(in_channels=4, out_channels=8, kernel_size=5, bias=False, padding=3)
        self.bn4 = nn.BatchNorm2d(num_features=8, eps=1e-4, affine=False)
        self.deconv3 = nn.ConvTranspose2d(in_channels=8, out_channels=1, kernel_size=5, bias=False, padding=2)

    def forward(self, x):
        """Encode ``x`` and return its sigmoid-activated reconstruction."""
        # Encoder: 28x28 -> 14x14 -> 7x7 -> flat code.
        h = self.pool(F.leaky_relu(self.bn1(self.conv1(x))))
        h = self.pool(F.leaky_relu(self.bn2(self.conv2(h))))
        code = self.fc1(h.view(h.size(0), -1))

        # Decoder: reshape the code into 4x4 feature maps and upsample.
        h = code.view(code.size(0), int(self.rep_dim / 16), 4, 4)
        h = self.deconv1(F.interpolate(F.leaky_relu(h), scale_factor=2))
        h = self.deconv2(F.interpolate(F.leaky_relu(self.bn3(h)), scale_factor=2))
        h = self.deconv3(F.interpolate(F.leaky_relu(self.bn4(h)), scale_factor=2))
        return torch.sigmoid(h)
