# @Author  : Peizhao Li
# @Contact : peizhaoli05@gmail.com

import numpy as np

import torch
from torch import nn
from torch.nn import Parameter
from sklearn.cluster import KMeans
from torch.autograd import Variable
import torch.nn.functional as F

from utils import init_weights
from scipy.optimize import linear_sum_assignment
import torchvision.models as models

def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
    """Return the gradient-reversal coefficient for step ``iter_num``.

    The coefficient follows a scaled logistic curve in ``iter_num / max_iter``:
    it equals ``low`` at ``iter_num == 0`` and saturates towards ``high`` as
    ``iter_num`` grows. ``alpha`` controls how steep the ramp is.

    Args:
        iter_num: current training iteration.
        high: upper bound of the schedule.
        low: value of the schedule at iteration 0.
        alpha: steepness of the logistic ramp.
        max_iter: iteration count used to normalise ``iter_num``.

    Returns:
        The coefficient as a built-in ``float``.
    """
    span = high - low
    ramp = 2.0 * span / (1.0 + np.exp(-alpha * iter_num / max_iter))
    # ``np.float`` was only an alias of the built-in ``float`` and has been
    # removed from recent NumPy releases, so convert with the built-in.
    return float(ramp - span + low)


def grl_hook(coeff):
    """Build a backward hook that flips and scales a gradient by ``coeff``.

    The returned callable takes a gradient tensor and returns a new tensor
    equal to ``-coeff * grad``; the incoming gradient is cloned first, so it
    is never modified in place.
    """
    def _reverse(grad):
        copied = grad.clone()
        return -coeff * copied

    return _reverse


def adv_loss(features, ad_net, N_0, N_1):
    """Binary cross-entropy loss of the discriminator ``ad_net`` on ``features``.

    The target labels the first ``N_0`` rows as 1 and the following ``N_1``
    rows as 0, so ``ad_net(features)`` is expected to have ``N_0 + N_1`` rows
    and a single output column with values in ``[0, 1]``.

    Args:
        features: input passed unchanged to ``ad_net``.
        ad_net: callable returning the discriminator probabilities.
        N_0: number of leading samples labelled 1.
        N_1: number of trailing samples labelled 0.

    Returns:
        The scalar ``nn.BCELoss`` between the discriminator output and the target.
    """
    ad_out = ad_net(features)
    # Build the target on the same device/dtype as the discriminator output
    # instead of forcing CUDA, so the loss also works on CPU-only setups.
    dc_target = torch.cat(
        [
            torch.ones(N_0, 1, dtype=ad_out.dtype, device=ad_out.device),
            torch.zeros(N_1, 1, dtype=ad_out.dtype, device=ad_out.device),
        ],
        dim=0,
    )
    return nn.BCELoss()(ad_out, dc_target)


class Encoder(nn.Module):
    """Convolutional encoder that maps square images to a 64-d latent code.

    Four 3x3 conv + batch-norm + LeakyReLU stages (the 2nd and 4th use
    stride 2) feed a 512-unit fully connected layer, followed by two linear
    heads producing the latent mean and log-variance.

    Args:
        input_dim: number of input channels.
        input_shape: side length of the (square) input images.
    """

    def __init__(self, input_dim=1, input_shape=32):
        super(Encoder, self).__init__()
        self.input_shape = input_shape

        self.conv1 = nn.Conv2d(input_dim, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(32)
        self.conv4 = nn.Conv2d(32, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(16)

        # A stride-2 conv with kernel 3 / padding 1 maps a side of n to
        # (n - 1) // 2 + 1. Using that formula (instead of input_shape / 4 in
        # float arithmetic) keeps fc1 consistent with the real conv output
        # even when input_shape is not a multiple of 4; for multiples of 4
        # the resulting size is unchanged.
        side = int(input_shape)
        for _ in range(2):
            side = (side - 1) // 2 + 1

        self.fc1 = nn.Linear(side * side * 16, 512)
        self.fc_bn1 = nn.BatchNorm1d(512)
        self.fc21 = nn.Linear(512, 64)  # latent mean head
        self.fc22 = nn.Linear(512, 64)  # latent log-variance head

        self.relu = nn.LeakyReLU()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight)
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0.1)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(logvar / 2)`` in training mode; return ``mu`` otherwise."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, x):
        """Encode ``x`` and return ``(z, mu, logvar)``."""
        # W x H
        h = self.relu(self.bn1(self.conv1(x)))
        # W/2 x H/2
        h = self.relu(self.bn2(self.conv2(h)))
        # W/2 x H/2
        h = self.relu(self.bn3(self.conv3(h)))
        # W/4 x H/4
        h = self.relu(self.bn4(self.conv4(h)))
        # Flatten everything but the batch axis so the batch size is never
        # silently reinterpreted when the spatial size is unexpected.
        h = torch.flatten(h, 1)

        h = self.relu(self.fc_bn1(self.fc1(h)))
        mu, logvar = self.fc21(h), self.fc22(h)
        z = self.reparameterize(mu, logvar)

        return z, mu, logvar

    def get_parameters(self):
        """Return the optimizer parameter groups (one group, ``lr_mult`` of 1)."""
        return [{"params": self.parameters(), "lr_mult": 1}]

class Decoder(nn.Module):
    """Transposed-convolution decoder that maps a 64-d code back to an image.

    Two fully connected layers expand the code to a 16-channel feature map of
    side ``int(input_shape / 4)``; four ConvTranspose2d + batch-norm +
    LeakyReLU stages then produce an ``input_dim``-channel output.

    Args:
        input_dim: number of output channels.
        input_shape: side length the decoder is configured for.
    """

    def __init__(self, input_dim = 1, input_shape = 32):
        super().__init__()

        self.input_shape = input_shape

        # Layers are registered in the order conv1..conv4, fc2, fc1 (the
        # reverse of how forward() applies them).
        self.conv1 = nn.ConvTranspose2d(16, input_dim, kernel_size=3, stride=1, padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(input_dim)
        self.conv2 = nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(16)
        self.conv3 = nn.ConvTranspose2d(32, 32, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn3 = nn.BatchNorm2d(32)
        self.conv4 = nn.ConvTranspose2d(16, 32, kernel_size=5, stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(32)

        flat_size = int(input_shape/4 * input_shape/4 * 16)
        self.fc2 = nn.Linear(512, flat_size)
        self.fc1_bn = nn.BatchNorm1d(512)
        self.fc1 = nn.Linear(64, 512)

        self.relu = nn.LeakyReLU()

        # NOTE(review): the layers above are ConvTranspose2d, which is not an
        # nn.Conv2d subclass, so the Conv2d branch below does not appear to
        # match any layer of this module -- confirm whether that is intended.
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_uniform_(layer.weight)
            elif isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)
                nn.init.constant_(layer.bias, 0.1)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.normal_(layer.weight, 1.0, 0.02)

    def forward(self, x):
        """Decode a batch of latent codes ``x`` into images."""
        side = int(self.input_shape/4)

        hidden = self.fc1(x)
        hidden = self.relu(self.fc1_bn(hidden))
        feature_map = self.fc2(hidden).view(-1, 16, side, side)

        up = self.relu(self.bn4(self.conv4(feature_map)))
        up = self.relu(self.bn3(self.conv3(up)))
        up = self.relu(self.bn2(self.conv2(up)))
        return self.relu(self.bn1(self.conv1(up)))

    def get_parameters(self):
        """Return the optimizer parameter groups (one group, ``lr_mult`` of 1)."""
        group = {"params": self.parameters(), "lr_mult": 1}
        return [group]

class ClusterAssignment(nn.Module):
    """Soft assignment of embeddings to learnable cluster centres.

    For each sample ``z_i`` and centre ``c_j`` the layer computes
    ``1 / (1 + ||z_i - c_j||^2 / alpha)`` and normalises over the clusters so
    every row sums to one.

    Args:
        cluster_number: number of clusters.
        embedding_dimension: size of each embedding vector.
        alpha: scale applied to the squared distance in the kernel.
        cluster_centers: initial centres, or ``None`` to initialise them with
            k-means on the first batch passed to ``forward``.
    """

    def __init__(self, cluster_number, embedding_dimension, alpha, cluster_centers):
        super(ClusterAssignment, self).__init__()
        self.embedding_dimension = embedding_dimension
        self.cluster_number = cluster_number
        self.alpha = alpha
        if cluster_centers is None:
            # Placeholder values; overwritten by k-means on the first forward.
            self.cluster_centers = Parameter(torch.zeros((cluster_number, embedding_dimension), dtype=torch.float))
            self.initialized = False
        else:
            print('cluster centers loaded')
            self.cluster_centers = Parameter(cluster_centers)
            self.initialized = True

    def forward(self, batch):
        """Return the soft assignment matrix, one row per sample of ``batch``."""
        if not self.initialized:
            km = KMeans(n_clusters=self.cluster_number, random_state = 10).fit(batch.detach().cpu().numpy())
            # Match the batch dtype/device: k-means may hand back float64 centres.
            centers = torch.tensor(km.cluster_centers_).to(device=batch.device, dtype=batch.dtype)
            if self.cluster_centers is None:
                self.cluster_centers = Parameter(centers)
            else:
                # Overwrite the existing Parameter's storage instead of
                # replacing the Parameter object, so an optimizer created
                # before this first forward keeps updating the live centres.
                self.cluster_centers.data = centers

            print('centers initialized')

            self.initialized = True

        # (batch, 1, dim) - (clusters, dim) broadcasts to (batch, clusters, dim).
        norm_squared = torch.sum((batch.unsqueeze(1) - self.cluster_centers) ** 2, 2)
        # NOTE(review): no exponent is applied to the kernel; the usual
        # Student's t form would raise it to (alpha + 1) / 2, which only
        # coincides with this for alpha == 1 -- confirm this is intended.
        numerator = 1.0 / (1.0 + (norm_squared / self.alpha))

        return numerator / torch.sum(numerator, dim=1, keepdim=True)
    
class DFC(nn.Module):
    """Clustering head wrapping a single ``ClusterAssignment`` layer.

    Args:
        cluster_number: number of clusters.
        hidden_dimension: size of the embeddings that are clustered.
        alpha: kernel scale forwarded to the assignment layer.
        cluster_centers: optional initial centres forwarded to the assignment
            layer.
    """

    def __init__(self, cluster_number, hidden_dimension, alpha=1, cluster_centers=None):
        super(DFC, self).__init__()
        self.hidden_dimension = hidden_dimension
        self.cluster_number = cluster_number
        self.alpha = alpha
        self.coupled = False
        self.assignment_0 = ClusterAssignment(cluster_number, self.hidden_dimension, alpha, cluster_centers=cluster_centers)

    def forward(self, batch, sens=0):
        """Return the soft cluster assignment of ``batch``; ``sens`` is not used."""
        return self.assignment_0(batch)

    def set_centers(self, centers):
        """Replace the cluster centres of the assignment layer.

        ``centers`` may be an ``nn.Parameter``, a plain tensor (wrapped into a
        Parameter here, since a module refuses to store a bare tensor under a
        registered parameter name) or ``None``.
        """
        if centers is not None:
            if not isinstance(centers, nn.Parameter):
                centers = nn.Parameter(centers)
            # Explicitly supplied centres must not be overwritten by the
            # k-means initialisation that the assignment layer runs while it
            # is still flagged as uninitialised.
            self.assignment_0.initialized = True
        self.assignment_0.cluster_centers = centers

    def get_parameters(self):
        """Return the optimizer parameter groups (one group, ``lr_mult`` of 0.1)."""
        return [{"params": self.parameters(), "lr_mult": 1e-1}]


class AdversarialNetwork(nn.Module):
    """Three-layer MLP discriminator fed through a gradient-reversal hook.

    The network outputs one sigmoid probability per sample. In the backward
    pass the gradient flowing into the input is multiplied by ``-coeff``,
    where ``coeff`` comes from ``calc_coeff`` and depends on how many training
    forward passes have been made so far.

    Args:
        in_feature: size of the input features.
        hidden_size: width of the two hidden layers.
        max_iter: iteration count passed to ``calc_coeff`` as its ``max_iter``.
        lr_mult: learning-rate multiplier reported by ``get_parameters``.
    """

    def __init__(self, in_feature, hidden_size, max_iter, lr_mult):
        super(AdversarialNetwork, self).__init__()
        self.ad_layer1 = nn.Linear(in_feature, hidden_size)
        self.ad_layer2 = nn.Linear(hidden_size, hidden_size)
        self.ad_layer3 = nn.Linear(hidden_size, 1)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)
        self.sigmoid = nn.Sigmoid()
        self.apply(init_weights)
        # State of the coefficient schedule evaluated in forward().
        self.iter_num = 0
        self.alpha = 10
        self.low = 0.0
        self.high = 1.0
        self.max_iter = float(max_iter)
        self.lr_mult = lr_mult

    def forward(self, x, ):
        """Return the discriminator probability for every row of ``x``."""
        if self.training:
            self.iter_num += 1
        coeff = calc_coeff(self.iter_num, self.high, self.low, self.alpha, self.max_iter)
        # Multiplying by 1.0 creates a new tensor, so the hook is attached to
        # that copy rather than to the tensor object owned by the caller.
        x = x * 1.0
        # register_hook raises on tensors that do not require grad (e.g. when
        # evaluating under torch.no_grad()); there is no gradient to reverse
        # in that case, so the hook is simply skipped.
        if x.requires_grad:
            x.register_hook(grl_hook(coeff))
        x = self.ad_layer1(x)
        x = self.relu1(x)
        x = self.dropout1(x)
        x = self.ad_layer2(x)
        x = self.relu2(x)
        x = self.dropout2(x)
        y = self.ad_layer3(x)
        y = self.sigmoid(y)
        return y

    def get_parameters(self):
        """Return the optimizer parameter groups (one group using ``self.lr_mult``)."""
        return [{"params": self.parameters(), "lr_mult": self.lr_mult}]

    
class Encoder_RGB(nn.Module):
    """ResNet-50 based encoder producing a latent mean and log-variance.

    The torchvision ResNet-50 (without its final fc layer) is followed by two
    fully connected + batch-norm + ReLU stages and two linear heads of size
    ``CNN_embed_dim``.

    Args:
        fc_hidden1: width of the first fully connected layer.
        fc_hidden2: width of the second fully connected layer.
        drop_p: accepted but not used by the code in this class.
        CNN_embed_dim: size of the latent code.
    """

    def __init__(self, fc_hidden1=1024, fc_hidden2=768, drop_p=0.3, CNN_embed_dim=256):
        super().__init__()

        self.fc_hidden1 = fc_hidden1
        self.fc_hidden2 = fc_hidden2
        self.CNN_embed_dim = CNN_embed_dim

        # CNN architecture settings (channels, kernel sizes, strides, padding)
        self.ch1, self.ch2, self.ch3, self.ch4 = 16, 32, 64, 128
        self.k1, self.k2, self.k3, self.k4 = (5, 5), (3, 3), (3, 3), (3, 3)
        self.s1, self.s2, self.s3, self.s4 = (2, 2), (2, 2), (2, 2), (2, 2)
        self.pd1, self.pd2, self.pd3, self.pd4 = (0, 0), (0, 0), (0, 0), (0, 0)

        # Backbone: every child of the pretrained ResNet except its last fc layer.
        backbone = models.resnet50(pretrained=True)
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])

        self.fc1 = nn.Linear(backbone.fc.in_features, self.fc_hidden1)
        self.bn1 = nn.BatchNorm1d(self.fc_hidden1, momentum=0.01)
        self.fc2 = nn.Linear(self.fc_hidden1, self.fc_hidden2)
        self.bn2 = nn.BatchNorm1d(self.fc_hidden2, momentum=0.01)

        # Latent heads: mean and log-variance
        self.fc3_mu = nn.Linear(self.fc_hidden2, self.CNN_embed_dim)
        self.fc3_logvar = nn.Linear(self.fc_hidden2, self.CNN_embed_dim)

        self.relu = nn.ReLU(inplace=True)

    def encode(self, x):
        """Return ``(mu, logvar)`` for the image batch ``x``."""
        feats = self.resnet(x)
        feats = feats.view(feats.size(0), -1)  # flatten the backbone output

        hidden = self.relu(self.bn1(self.fc1(feats)))
        hidden = self.relu(self.bn2(self.fc2(hidden)))

        return self.fc3_mu(hidden), self.fc3_logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(logvar / 2)`` in training mode; return ``mu`` otherwise."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return noise * std + mu

    def forward(self, x):
        """Encode ``x`` and return ``(z, mu, logvar)``."""
        mu, logvar = self.encode(x)
        return self.reparameterize(mu, logvar), mu, logvar

    def get_parameters(self):
        """Return the optimizer parameter groups (one group, ``lr_mult`` of 1)."""
        group = {"params": self.parameters(), "lr_mult": 1}
        return [group]
    
class Decoder_RGB(nn.Module):
    """Decoder that maps a latent code to a 3-channel 224x224 image.

    Two fully connected + batch-norm + LeakyReLU stages expand the code to a
    64x4x4 feature map; three transposed-convolution stages (the last ending
    in a sigmoid) upsample it, and the result is bilinearly resized to
    224x224.

    Args:
        fc_hidden1: stored on the instance; not used by the layers below.
        fc_hidden2: width of the first fully connected layer.
        drop_p: accepted but not used by the code in this class.
        CNN_embed_dim: size of the latent code.
    """

    def __init__(self, fc_hidden1=1024, fc_hidden2=768, drop_p=0.3, CNN_embed_dim=256):
        super(Decoder_RGB, self).__init__()

        self.fc_hidden1, self.fc_hidden2, self.CNN_embed_dim = fc_hidden1, fc_hidden2, CNN_embed_dim

        # CNN architecture settings
        self.ch1, self.ch2, self.ch3, self.ch4 = 16, 32, 64, 128
        self.k1, self.k2, self.k3, self.k4 = (5, 5), (3, 3), (3, 3), (3, 3)      # 2d kernel size
        self.s1, self.s2, self.s3, self.s4 = (2, 2), (2, 2), (2, 2), (2, 2)      # 2d strides
        self.pd1, self.pd2, self.pd3, self.pd4 = (0, 0), (0, 0), (0, 0), (0, 0)  # 2d padding

        # Fully connected expansion of the sampled latent vector
        self.fc4 = nn.Linear(self.CNN_embed_dim, self.fc_hidden2)
        self.fc_bn4 = nn.BatchNorm1d(self.fc_hidden2)
        self.fc5 = nn.Linear(self.fc_hidden2, 64 * 4 * 4)
        self.fc_bn5 = nn.BatchNorm1d(64 * 4 * 4)
        self.relu = nn.LeakyReLU(inplace=True)

        # Transposed-convolution upsampling stack
        self.convTrans6 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=self.k4, stride=self.s4,
                               padding=self.pd4),
            nn.BatchNorm2d(32, momentum=0.01),
            nn.ReLU(inplace=True),
        )
        self.convTrans7 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=8, kernel_size=self.k3, stride=self.s3,
                               padding=self.pd3),
            nn.BatchNorm2d(8, momentum=0.01),
            nn.ReLU(inplace=True),
        )

        self.convTrans8 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=8, out_channels=3, kernel_size=self.k2, stride=self.s2,
                               padding=self.pd2),
            nn.BatchNorm2d(3, momentum=0.01),
            nn.Sigmoid()    # y = (y1, y2, y3) \in [0 ,1]^3
        )

    def decode(self, z):
        """Map latent codes ``z`` to images of spatial size 224x224."""
        x = self.relu(self.fc_bn4(self.fc4(z)))
        # Reshape the flat 1024-d activation into a 64-channel 4x4 map.
        x = self.relu(self.fc_bn5(self.fc5(x))).view(-1, 64, 4, 4)
        x = self.convTrans6(x)
        x = self.convTrans7(x)
        x = self.convTrans8(x)
        x = F.interpolate(x, size=(224, 224), mode='bilinear')
        return x

    def forward(self, z):
        """Return the reconstruction decoded from ``z``."""
        x_reconst = self.decode(z)

        return x_reconst

    def get_parameters(self):
        """Return the optimizer parameter groups (one group, ``lr_mult`` of 1).

        Provided so this decoder exposes the same ``get_parameters`` accessor
        as the other model classes in this file.
        """
        return [{"params": self.parameters(), "lr_mult": 1}]
    
class Encoder_tab(nn.Module):
    """MLP encoder for tabular inputs producing a 64-d latent code.

    Two linear + batch-norm + ReLU stages (128 then 64 units) feed two linear
    heads that output the latent mean and log-variance.

    Args:
        input_dim: number of input features.
        latent_dim: accepted for the interface.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()

        # NOTE(review): latent_dim is not used below; both heads are fixed at
        # 64 outputs -- confirm against callers before changing.
        self.linear1 = nn.Linear(input_dim, 128)
        self.bn1 = nn.BatchNorm1d(128)
        self.linear2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)

        self.linear3_1 = nn.Linear(64, 64)  # latent mean head
        self.linear3_2 = nn.Linear(64, 64)  # latent log-variance head

        self.relu = nn.ReLU()

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(logvar / 2)`` in training mode; return ``mu`` otherwise."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return noise * std + mu

    def forward(self, x):
        """Encode ``x`` and return ``(z, mu, logvar)``."""
        hidden = self.relu(self.bn1(self.linear1(x)))
        hidden = self.relu(self.bn2(self.linear2(hidden)))

        mu = self.linear3_1(hidden)
        logvar = self.linear3_2(hidden)

        return self.reparameterize(mu, logvar), mu, logvar

    def get_parameters(self):
        """Return the optimizer parameter groups (one group, ``lr_mult`` of 1)."""
        group = {"params": self.parameters(), "lr_mult": 1}
        return [group]
    
    
