import torch
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
import torch.distributions as dist
from woods.layers.VAE_EncDec import Encoder, Decoder

class PriorBlock(nn.Module):
    """Map integer class indices to a location vector and a positive scale vector.

    The index is one-hot encoded, passed through a bias-free linear layer
    followed by batch norm and ReLU, and then fed to two linear heads: one
    producing the location, one (through Softplus) producing the scale.

    Args:
        d_dim: number of distinct classes (width of the one-hot encoding).
        zd_dim: width of the hidden layer and of both outputs.
    """

    def __init__(self, d_dim, zd_dim):
        super().__init__()
        self.d_dim = d_dim

        # Shared trunk: one-hot -> hidden representation.
        self.fc1 = nn.Sequential(
            nn.Linear(d_dim, zd_dim, bias=False),
            nn.BatchNorm1d(zd_dim),
            nn.ReLU(),
        )

        # Location head.
        self.fc21 = nn.Sequential(
            nn.Linear(zd_dim, zd_dim),
        )

        # Scale head; Softplus keeps the output non-negative.
        self.fc22 = nn.Sequential(
            nn.Linear(zd_dim, zd_dim),
            nn.Softplus(),
        )

        self.init_weights()

    def init_weights(self):
        """Xavier-initialise every linear weight and zero the head biases."""
        # The order of the random draws (fc1, fc21, fc22) is kept stable so
        # that seeded runs reproduce the same parameters.
        nn.init.xavier_uniform_(self.fc1[0].weight)
        for head in (self.fc21, self.fc22):
            nn.init.xavier_uniform_(head[0].weight)
            nn.init.zeros_(head[0].bias)

    def forward(self, d):
        """Return ``(loc, scale)`` for a batch of class indices.

        Args:
            d: integer tensor holding one class index per sample, shaped
                ``(B,)`` or ``(B, 1)``.

        Returns:
            Tuple ``(zd_loc, zd_scale)``, both of shape ``(B, zd_dim)``.
        """
        batch = d.shape[0]
        # Build the one-hot matrix directly on the input's device with a
        # single indexed write, instead of a per-sample Python loop whose
        # `.item()` calls force one host sync per row.
        d_onehot = torch.zeros(batch, self.d_dim, device=d.device)
        rows = torch.arange(batch, device=d.device)
        d_onehot[rows, d.reshape(-1).long()] = 1

        hidden = self.fc1(d_onehot)
        zd_loc = self.fc21(hidden)
        # Small epsilon guards against Softplus underflowing to exactly zero.
        zd_scale = self.fc22(hidden) + 1e-7

        return zd_loc, zd_scale

# Auxiliary tasks
class Projection(nn.Module):
    """Linear read-out applied to a rectified latent vector.

    Args:
        d_dim: number of output units.
        zd_dim: width of the incoming latent vector.
    """

    def __init__(self, d_dim, zd_dim):
        super().__init__()

        head = nn.Linear(zd_dim, d_dim)
        # Xavier weights, zero bias.
        nn.init.xavier_uniform_(head.weight)
        nn.init.zeros_(head.bias)
        self.fc1 = head

    def forward(self, zd):
        """Apply ReLU to ``zd`` and return the linear layer's output."""
        rectified = torch.relu(zd)
        return self.fc1(rectified)


class GILE(nn.Module):
    """Two-branch variational autoencoder with domain and label heads.

    The model holds two encoders (``specific_encoder`` for the domain latent
    ``zd`` and ``invariant_encoder`` for the label latent ``zy``), one decoder
    fed with their concatenation, a ``PriorBlock`` per latent that is
    conditioned on the domain / class index, and a ``Projection`` classifier
    per latent.

    Args:
        dataset: object providing ``get_nb_training_domains()``,
            ``OUTPUT_SIZE``, ``INPUT_SHAPE`` and ``SEQ_LEN``.
        model_hparams: dict read for the keys ``device``, ``d_AE``,
            ``fc_dim``, ``kernel_size``, ``beta_d``, ``beta_y``,
            ``weight_true``, ``weight_false``, ``aux_loss_multiplier_y`` and
            ``aux_loss_multiplier_d``.
        input_size: optional override for the number of input channels;
            defaults to ``np.prod(dataset.INPUT_SHAPE)``.
    """

    def __init__(self, dataset, model_hparams, input_size = None):
        super().__init__()
        self.dataset = dataset
        self.device = model_hparams['device']
        # Both latent blocks share the same width ('d_AE').
        self.zd_dim = model_hparams['d_AE']
        self.zx_dim = 0 # keep but not used
        self.zy_dim = model_hparams['d_AE']
        self.fc_dim = model_hparams['fc_dim']
        self.kernel_size = model_hparams['kernel_size']
        self.d_dim = dataset.get_nb_training_domains()
        self.y_dim = self.dataset.OUTPUT_SIZE
        self.input_channel = np.prod(dataset.INPUT_SHAPE) if input_size is None else input_size # configs.n_features
        self.len_seq = self.dataset.SEQ_LEN

        self.z_dim = self.zd_dim + self.zy_dim

        # Sub-module construction order is kept stable: it fixes both the
        # parameter registration order and the order of random initialisation.
        self.decoder = Decoder(self.z_dim, self.fc_dim, self.input_channel, self.kernel_size, output_channels=[64,128,512,1024])
        self.domain_prior = PriorBlock(self.d_dim, self.zd_dim)
        self.label_prior = PriorBlock(self.y_dim, self.zy_dim)

        self.specific_encoder = Encoder(output_dim=self.zd_dim, input_dim=self.input_channel, fc_dim=self.fc_dim, kernel_size=self.kernel_size,out_channels=[1024, 512, 128, 64])
        self.invariant_encoder = Encoder(output_dim=self.zy_dim, input_dim=self.input_channel, fc_dim=self.fc_dim, kernel_size=self.kernel_size, out_channels=[1024, 512, 128, 64])

        self.domain_projection = Projection(self.d_dim, self.zd_dim)
        self.label_projection = Projection(self.y_dim, self.zy_dim)

        # Weights of the two KL terms.
        self.beta_d = model_hparams['beta_d']
        self.beta_y = model_hparams['beta_y']

        # Weights of the constraint-loss terms and of the auxiliary classifiers.
        self.weight_true = model_hparams['weight_true']
        self.weight_false = model_hparams['weight_false']
        self.aux_loss_multiplier_y = model_hparams['aux_loss_multiplier_y']
        self.aux_loss_multiplier_d = model_hparams['aux_loss_multiplier_d']

    @staticmethod
    def _hard_one_hot(logits, like):
        """Return the one-hot encoding of the top-1 class of ``logits``.

        The result has the shape of ``logits`` and the dtype/device of ``like``.
        """
        probs = F.softmax(logits, dim=1)
        _, top = torch.topk(probs, 1)
        return like.new_zeros(probs.size()).scatter_(1, top, 1.0)

    def forward(self, x, y=None, d=None):
        """Run the training or the evaluation pass depending on ``self.training``.

        Args:
            x: 3-D input tensor; its last two axes are swapped before use
                (presumably it arrives as (batch, seq_len, channels) —
                TODO confirm against the data loader).
            y: class indices, required in training mode. It is squeezed on
                dim 1 before the classification loss, so a ``(B, 1)`` tensor
                is expected.
            d: domain indices, required in training mode.

        Returns:
            Training mode: ``(y_hat, auxil_loss, constraint_loss)``.
            Evaluation mode: the ``(logits, features)`` pair of ``forward_test``.

        Raises:
            ValueError: in training mode when ``y`` or ``d`` is missing.
        """
        x = x.permute(0, 2, 1)

        if not self.training:
            return self.forward_test(x)

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if y is None or d is None:
            raise ValueError("GILE.forward requires both `y` and `d` in training mode")

        (x_recon, d_hat, y_hat, qzd, pzd, zd,
         _, _, _, qzy, pzy, zy) = self.forward_train(d, x, y)

        # The decoder may return more time steps than the input; crop the excess.
        if x_recon.shape[-1] > x.shape[-1]:
            x_recon = x_recon[:, :, :x.shape[-1]]
        reconstruction_loss = F.mse_loss(x_recon, x.float())

        # Sum of log p(z) - log q(z|x) at the sampled z: a single-sample
        # estimate of the *negative* KL divergence, hence it is subtracted
        # in the total loss below.
        neg_kl_domain = torch.sum(pzd.log_prob(zd) - qzd.log_prob(zd))
        neg_kl_label = torch.sum(pzy.log_prob(zy) - qzy.log_prob(zy))

        # Auxiliary classification losses on the sampled latents.
        domain_cls = F.cross_entropy(d_hat, d, reduction='sum')
        label_cls = F.cross_entropy(y_hat, y.squeeze(1), reduction='sum')

        auxil_loss = reconstruction_loss \
            - self.beta_d * neg_kl_domain - self.beta_y * neg_kl_label \
            + self.aux_loss_multiplier_d * domain_cls + self.aux_loss_multiplier_y * label_cls

        constraint_loss = self.constraint_loss(x)
        return y_hat, auxil_loss, constraint_loss

    def forward_train(self, d, x, y):
        """Encode, sample, decode and classify one training batch.

        Args:
            d: domain indices passed to the domain prior.
            x: input already permuted by ``forward``.
            y: class indices passed to the label prior.

        Returns:
            A 12-tuple: ``(x_recon, d_hat, y_hat, q(zd|x), p(zd|d), zd, None,
            p(zx), None, q(zy|x), p(zy|y), zy)``. The zx slots are
            placeholders; ``zx_dim`` is 0.
        """
        # Encode q(z|x)
        zd_mean, zd_scale, _, _ = self.specific_encoder(x)
        zy_mean, zy_scale, idxs_y, sizes_y = self.invariant_encoder(x)

        # Reparameterization trick: z = mu + sigma * epsilon
        domain_space = dist.Normal(zd_mean, zd_scale)
        domain_feature = domain_space.rsample()
        invariant_space = dist.Normal(zy_mean, zy_scale)
        invariant_feature = invariant_space.rsample()

        # Decode p(x|z)
        z = torch.cat((domain_feature, invariant_feature), dim=-1)
        x_recon = self.decoder(z, idxs_y, sizes_y)

        # Priors conditioned on the domain and class indices
        domain_prior_mean, domain_prior_scale = self.domain_prior(d)
        label_prior_mean, label_prior_scale = self.label_prior(y)

        # Standard-normal prior for the (zero-width) zx latent. Allocated with
        # the encoder output's device and dtype rather than a hard-coded
        # `.cuda()`, so the model also runs on CPU and on non-default GPUs.
        x_prior_mean = zd_mean.new_zeros((zd_mean.size(0), self.zx_dim))
        x_prior_scale = zd_scale.new_ones((zd_scale.size(0), self.zx_dim))

        domain_prior_distribution = dist.Normal(domain_prior_mean, domain_prior_scale)
        label_prior_distribution = dist.Normal(label_prior_mean, label_prior_scale)
        x_prior_distribution = dist.Normal(x_prior_mean, x_prior_scale)

        # Auxiliary classifiers on the sampled latents
        d_hat = self.domain_projection(domain_feature)
        y_hat = self.label_projection(invariant_feature)

        return x_recon, d_hat, y_hat, domain_space, domain_prior_distribution, domain_feature, None, x_prior_distribution, None, invariant_space, label_prior_distribution, invariant_feature

    def constraint_loss(self, x):
        """Return ``weight_true * true_term - weight_false * false_term``.

        Hard one-hot predictions are formed from the posterior means:
        each latent through its own classifier ("true") and through the other
        latent's classifier ("false"). Both terms are summed cross-entropies
        that use the "true" one-hots as targets.

        NOTE(review): every input of the loss is created under
        ``torch.no_grad()``, so the returned scalar has no gradient path to
        any model parameter; ``requires_grad_(True)`` only turns it into a
        leaf so that ``backward()`` does not fail. The "true" term also
        compares each one-hot tensor with itself. This behaviour is preserved
        as found — confirm whether it is intended.
        """
        with torch.no_grad():
            zd, _, _, _ = self.specific_encoder(x)
            pred_d = self._hard_one_hot(self.domain_projection(zd), x)

            zy, _, _, _ = self.invariant_encoder(x)
            pred_y = self._hard_one_hot(self.label_projection(zy), x)

            # Cross projections: each latent through the other head.
            pred_d_false = self._hard_one_hot(self.domain_projection(zy), x)
            pred_y_false = self._hard_one_hot(self.label_projection(zd), x)

        loss_classify_true = self.weight_true * (
            F.cross_entropy(pred_d, pred_d, reduction='sum')
            + F.cross_entropy(pred_y, pred_y, reduction='sum')
        )
        loss_classify_false = self.weight_false * (
            F.cross_entropy(pred_d_false, pred_d, reduction='sum')
            + F.cross_entropy(pred_y_false, pred_y, reduction='sum')
        )

        loss = loss_classify_true - loss_classify_false
        loss.requires_grad_(True)

        return loss

    def forward_test(self, x):
        """Predict from the invariant branch only.

        Args:
            x: input already permuted by ``forward``.

        Returns:
            ``(logits, features)``: ``logits`` is the one-hot top-1 prediction
            with a singleton axis inserted at dim 1; ``features`` is the
            posterior mean reshaped to ``(B, 1, -1)``.
        """
        B = x.shape[0]
        zy, _, _, _ = self.invariant_encoder(x)
        # Despite the name, `logits` holds hard one-hot predictions.
        logits = self._hard_one_hot(self.label_projection(zy), x).unsqueeze(1)
        features = zy.reshape(B, 1, -1)
        return logits, features

    def get_features(self, x):
        """Return one reparameterised sample of the invariant latent for ``x``.

        NOTE(review): unlike ``forward``, ``x`` is passed to the encoder
        without the ``permute(0, 2, 1)`` — presumably callers permute it
        themselves; verify.
        """
        zy_mean, zy_scale, _, _ = self.invariant_encoder(x)
        return dist.Normal(zy_mean, zy_scale).rsample()


