import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
import torch
import torch.nn as nn
import torch.distributions as dist
import os


def mask_correlated_samples(N):
    """Build an ``N x N`` boolean mask selecting the negative pairs of a 2-view batch.

    Entry ``[r, c]`` is False on the diagonal and for every pair
    ``(i, N//2 + i)`` / ``(N//2 + i, i)`` with ``i < N//2``; all other entries
    are True.
    """
    half = N // 2
    keep = torch.eye(N) == 0
    rows = torch.arange(half)
    # Knock out each sample's counterpart in the other half, in both directions.
    keep[rows, rows + half] = False
    keep[rows + half, rows] = False
    return keep

def compute_feature(h_i, h_j, T=0.5):
    """Per-sample contrastive (InfoNCE-style) loss between two sets of features.

    ``h_i`` and ``h_j`` are flattened to ``(-1, d)`` and stacked into ``2B`` anchors,
    where ``B`` is the number of rows of the flattened ``h_i``. Each anchor's positive
    is its counterpart in the other view, and its negatives are every other anchor
    except itself. Similarities are raw dot products divided by the temperature ``T``.

    The loss of a sample is the mean of its view-i and view-j anchor losses. The
    result has shape ``(h_i.shape[0], -1)``: for inputs of shape ``(k, bs, d)`` it is
    ``(k, bs)``, aligned element-wise with the inputs.
    """
    k = h_i.shape[0]
    h_i = h_i.reshape(-1, h_i.shape[-1])
    h_j = h_j.reshape(-1, h_j.shape[-1])
    batch_size = h_i.shape[0]
    N = 2 * batch_size
    h = torch.cat((h_i, h_j), dim=0)

    sim = torch.matmul(h, h.T) / T
    # Excluding self-similarity with -inf leaves exactly {positive} + {N - 2 negatives}
    # in every row, i.e. the same logit set as gathering them with a boolean mask.
    self_mask = torch.eye(N, dtype=torch.bool, device=sim.device)
    logits = sim.masked_fill(self_mask, float('-inf'))
    # Anchor r's positive is its counterpart in the other view.
    labels = (torch.arange(N, device=sim.device) + batch_size) % N

    loss = nn.CrossEntropyLoss(reduction='none')(logits, labels)
    # loss is [view-i anchors..., view-j anchors...]; average the two views of the
    # SAME sample (the previous reshape(k, -1, 2) paired adjacent samples instead).
    return loss.reshape(2, k, -1).mean(dim=0)


class BASE_MVAE(nn.Module):
    """Base class for a two-modality VAE whose latent ``u`` is split into ``[z, w]``.

    The first ``args.z_dim`` coordinates (``z``) are fed to the classifier, with the
    trailing ``w`` coordinates zeroed. The remaining ``args.w_dim`` coordinates (``w``)
    are regularized towards the fixed prior ``N(pw_params[0], pw_params[1]) = N(0, 1)``.

    Encoder/prior outputs are consumed as ``(mean, scale)`` pairs: they are unpacked
    straight into ``dist.Normal``, so the second element is a standard deviation.

    NOTE(review): this class reads attributes it never defines: ``encoder``,
    ``decoder``, ``classifier``, ``cond_prior``, ``likelihood_s``, ``likelihood_t``,
    ``run``, ``num_steps``, ``pseudo_samples``, ``alpha``, ``classifier_scale_sup`` and
    ``classifier_scale_unsup``. Presumably a subclass or the training code provides
    them; confirm.
    """

    def __init__(self, args, pseudo_samples_a, pseudo_samples_b):
        super(BASE_MVAE, self).__init__()
        self.z_dim = args.z_dim
        self.w_dim = args.w_dim
        self.latent_dim = args.z_dim + args.w_dim

        # Fixed (non-trainable) mean and standard deviation of the w prior: N(0, 1).
        self.pw_params = nn.ParameterList([
            nn.Parameter(torch.zeros(1, self.w_dim), requires_grad=False),
            nn.Parameter(torch.ones(1, self.w_dim), requires_grad=False)
        ])

        # Trainable pseudo-samples, one tensor per modality. Presumably ``run`` exposes
        # one of them as ``self.pseudo_samples`` for run_unsup; confirm.
        self.pseudo_samples_a = nn.Parameter(pseudo_samples_a)
        self.pseudo_samples_b = nn.Parameter(pseudo_samples_b)

    def _split_zw(self, t):
        """Split the last dimension of ``t`` into its ``(z, w)`` parts."""
        return torch.split(t, [self.z_dim, self.w_dim], dim=-1)

    def _capacity(self):
        """KL capacity target: grows linearly with ``num_steps`` and is capped at 0.5."""
        return min(self.num_steps * 0.005, 0.5)

    def _weighted_objective(self, u, loss):
        """Return minus the softmax-over-k weighted mean of ``loss``; rescale u's gradient alike."""
        with torch.no_grad():
            # Normalized weights across the k latent samples (dim 0).
            grad_wt = (loss - torch.logsumexp(loss, 0, keepdim=True)).exp()
            if u.requires_grad:
                u.register_hook(lambda grad: grad_wt.unsqueeze(-1) * grad)
        return -(grad_wt * loss).mean()

    def get_mixture_prior_params(self, samples):
        """Return the conditional prior's ``(mean, scale)`` for ``samples``."""
        return self.cond_prior(samples)

    def match(self, batch_a, batch_b, direction='bi'):
        """Supervised step; dispatches ``run_match`` through ``self.run`` (defined elsewhere)."""
        return self.run(batch_a, batch_b, direction, self.run_match)

    def unsup(self, batch_a, batch_b, direction):
        """Unsupervised step; dispatches ``run_unsup`` through ``self.run`` (defined elsewhere)."""
        return self.run(batch_a, batch_b, direction, self.run_unsup)

    def compute_kl(self, locs_q, scale_q, locs_p=None, scale_p=None):
        """Closed-form ``KL(N(locs_q, scale_q) || N(locs_p, scale_p))``, averaged over the last dim.

        ``scale_*`` are standard deviations. ``locs_p``/``scale_p`` default to a
        standard normal.
        """
        if locs_p is None:
            locs_p = torch.zeros_like(locs_q)
        if scale_p is None:
            scale_p = torch.ones_like(scale_q)

        kl = 0.5 * (2 * scale_p.log() - 2 * scale_q.log() +
                    (locs_q - locs_p).pow(2) / scale_p.pow(2) +
                    scale_q.pow(2) / scale_p.pow(2) - torch.ones_like(locs_q))
        return kl.mean(dim=-1)

    def compute_kl_enum(self, locs_q, scale_q, locs_p, scale_p,
                        z, logit_probs):
        """Sample-based KL estimate ``log q(z) - log p(z)``, divided by the latent size.

        ``q`` is the diagonal Gaussian ``N(locs_q, scale_q)``. ``p`` is a Gaussian
        mixture whose components are the rows of ``locs_p``/``scale_p`` and whose
        mixing logits are ``logit_probs``.
        Assumes locs_q is (bs, d), locs_p is (m, d), z is (k, bs, d) and
        logit_probs is (bs, m) -- TODO confirm against callers other than run_unsup.
        """
        log_qz = dist.Normal(locs_q, scale_q).log_prob(z).sum(dim=-1)
        # Give every batch element its own copy of the mixture components.
        locs_p = locs_p.unsqueeze(0).expand(locs_q.shape[0], -1, -1)
        scale_p = scale_p.unsqueeze(0).expand(locs_q.shape[0], -1, -1)

        mix = dist.Categorical(logits=logit_probs)
        comp = dist.Independent(dist.Normal(locs_p, scale_p), 1)
        gmm = dist.mixture_same_family.MixtureSameFamily(mix, comp)

        kl = log_qz - gmm.log_prob(z)

        return kl / z.size()[-1]

    def classifier_loss_img(self, data, targ, k=10, z_sample=None):
        """Log-mean-exp of ``likelihood_t`` over posterior draws of z (w zeroed).

        Draws ``k`` samples of z from the encoder posterior for ``data``, scores them
        against ``targ``, and averages in probability space via
        ``logsumexp - log(count)``. If ``z_sample`` is given, each of its draws is
        appended to the ``k`` posterior draws, giving one estimate per ``z_sample``
        draw. Otherwise the leading dim of the result has size 1.
        """
        bs = data.shape[0]
        post_mean, post_scale = self.encoder(data)
        post_z_mean, _ = self._split_zw(post_mean)
        post_z_scale, _ = self._split_zw(post_scale)
        z = dist.Normal(post_z_mean, post_z_scale).rsample(torch.Size([k]))
        # new_zeros follows z's dtype and device.
        w = z.new_zeros(k, bs, self.w_dim)
        u = torch.cat((z, w), dim=-1).view(k, bs, self.latent_dim)
        preds = self.classifier(u)
        probs = self.likelihood_t(preds, targ)
        probs = probs.view(1, k, -1)  # no_z x no_k x bs
        if z_sample is not None:
            z_sample = z_sample.view(-1, bs, self.z_dim)
            w_sample = z_sample.new_zeros(z_sample.shape[0], bs, self.w_dim)
            u_sample = torch.cat((z_sample, w_sample), dim=-1)
            preds_samples = self.classifier(u_sample)
            probs_samples = self.likelihood_t(preds_samples, targ)
            probs = probs.expand(u_sample.shape[0], -1, -1)
            probs = torch.cat((probs, probs_samples.unsqueeze(1)), dim=1)
        log_qts = torch.logsumexp(probs, dim=1) - np.log(probs.shape[1])
        return log_qts

    def run_match(self, data, targ, k=1):
        """Supervised objective for a labeled batch; returns a scalar to minimize."""
        u_post_params = self.encoder(data)
        u = dist.Normal(*u_post_params).rsample(torch.Size([k]))
        z, w = self._split_zw(u)

        # The classifier only sees z: w is replaced by zeros.
        u_new = torch.cat((z, torch.zeros_like(w)), dim=-1)
        pred = self.classifier(u_new)
        log_qtz = self.likelihood_t(pred, targ)

        c_prior_params = self.cond_prior(targ)
        prior_u = dist.Normal(*c_prior_params).rsample(torch.Size([k]))
        prior_z, _ = self._split_zw(prior_u)

        post_mean, post_scale = u_post_params
        prior_mean, prior_scale = c_prior_params
        post_z_mean, post_w_mean = self._split_zw(post_mean)
        post_z_scale, post_w_scale = self._split_zw(post_scale)
        prior_z_mean, _ = self._split_zw(prior_mean)
        prior_z_scale, _ = self._split_zw(prior_scale)

        capacity_z = capacity_w = self._capacity()

        kl_z = self.compute_kl(post_z_mean, post_z_scale, prior_z_mean, prior_z_scale)
        kl_w = self.compute_kl(post_w_mean, post_w_scale, self.pw_params[0], self.pw_params[1])

        recon = self.decoder(u)
        log_psz = self.likelihood_s(recon, data)

        log_qts = self.classifier_loss_img(data, targ, k=10, z_sample=z)
        weight = torch.exp(log_qtz - log_qts)

        loss = (weight.detach() * (log_psz - log_qtz
                                   - torch.abs(kl_z - capacity_z)
                                   - torch.abs(kl_w - capacity_w))
                + self.classifier_scale_sup * (log_qts + log_psz)
                - kl_z
                - self.alpha * compute_feature(z, prior_z))

        return self._weighted_objective(u, loss)

    def run_unsup(self, data, targ=None, k=1):
        """Unsupervised objective using a mixture prior over pseudo-samples; returns a scalar to minimize."""
        bs = data.shape[0]
        u_post_params = self.encoder(data)
        u = dist.Normal(*u_post_params).rsample(torch.Size([k]))
        z, _ = self._split_zw(u)

        c_prior_params = self.get_mixture_prior_params(self.pseudo_samples[:bs])

        post_mean, post_scale = u_post_params
        prior_mean, prior_scale = c_prior_params
        post_z_mean, post_w_mean = self._split_zw(post_mean)
        post_z_scale, post_w_scale = self._split_zw(post_scale)
        prior_z_mean, _ = self._split_zw(prior_mean)
        prior_z_scale, _ = self._split_zw(prior_scale)

        capacity_z = capacity_w = self._capacity()

        # Uniform mixing weights over the available components. The slice above yields
        # fewer than bs rows when pseudo_samples is shorter than the batch, so size the
        # logits by the actual component count instead of assuming bs.
        logit_probs = torch.ones((bs, prior_z_mean.shape[0]), device=data.device)
        kl_z = self.compute_kl_enum(post_z_mean, post_z_scale, prior_z_mean, prior_z_scale,
                                    z=z, logit_probs=logit_probs)
        kl_w = self.compute_kl(post_w_mean, post_w_scale, self.pw_params[0], self.pw_params[1])

        recon = self.decoder(u)
        log_psz = self.likelihood_s(recon, data)

        loss = self.classifier_scale_unsup * log_psz - torch.abs(kl_z - capacity_z) - torch.abs(kl_w - capacity_w)

        return self._weighted_objective(u, loss)


class MVAE(BASE_MVAE):
    """Concrete model adding within/cross-modality translation and a t-SNE plot of z.

    NOTE(review): ``a_to_z``/``b_to_z`` (used as encoders returning Normal
    ``(mean, scale)``) and ``z_to_a``/``z_to_b`` (decoders whose first output is
    returned) are not defined in this class. Presumably they are attached elsewhere;
    confirm.
    """

    def __init__(self, args, pseudo_samples_a, pseudo_samples_b):
        super(MVAE, self).__init__(args, pseudo_samples_a, pseudo_samples_b)
        self.to(args.device)
        self.device = args.device
        # Fixed scale factors read by BASE_MVAE.run_match / run_unsup.
        self.classifier_scale_sup = 10
        self.classifier_scale_unsup = 10
        self.alpha = args.alpha
        self.a_name = args.a_name
        self.b_name = args.b_name

    def _zero_w(self, u):
        """Keep the z part of ``u`` and replace its w part with zeros."""
        z, w = torch.split(u, [self.z_dim, self.w_dim], dim=-1)
        return torch.cat((z, torch.zeros_like(w)), dim=-1)

    def _resample_w(self, u, k):
        """Pair the z part of a 2-D ``u`` with ``k`` prior draws of w; result is ``(k, rows, latent)``."""
        z, w = torch.split(u, [self.z_dim, self.w_dim], dim=-1)
        w_prior = dist.Normal(self.pw_params[0], self.pw_params[1])
        # pw_params are (1, w_dim), so the draw is (k, rows, 1, w_dim); drop the singleton.
        w_new = w_prior.sample([k, w.size()[0]]).squeeze(2)
        z = z.unsqueeze(0).expand(k, -1, -1)
        return torch.cat((z, w_new), dim=-1)

    def a_to_b(self, data, k=1):
        """Translate modality-a ``data`` into modality b.

        With ``k == 1`` the sampled w is zeroed. Otherwise ``k`` outputs are produced
        by pairing the sampled z with ``k`` prior draws of w.
        """
        u = dist.Normal(*self.a_to_z(data)).sample()
        u_new = self._zero_w(u) if k == 1 else self._resample_w(u, k)
        return self.z_to_b(u_new)[0]

    def b_to_a(self, data, k=1):
        """Translate modality-b ``data`` into modality a (same ``k`` semantics as ``a_to_b``)."""
        u = dist.Normal(*self.b_to_z(data)).sample()
        u_new = self._zero_w(u) if k == 1 else self._resample_w(u, k)
        return self.z_to_a(u_new)[0]

    def a_to_a(self, data, k=1):
        """Reconstruct modality a.

        With ``k == 1`` the full sampled latent (including w) is decoded. Otherwise w
        is replaced by ``k`` prior draws.
        """
        u = dist.Normal(*self.a_to_z(data)).sample()
        if k != 1:
            u = self._resample_w(u, k)
        return self.z_to_a(u)[0]

    def b_to_b(self, data, k=1):
        """Reconstruct modality b (same ``k`` semantics as ``a_to_a``)."""
        u = dist.Normal(*self.b_to_z(data)).sample()
        if k != 1:
            u = self._resample_w(u, k)
        return self.z_to_b(u)[0]

    def _sample_z_cpu(self, encode, batch):
        """Sample u from ``encode``'s Normal posterior for ``batch`` and return its z part on the CPU."""
        u = dist.Normal(*encode(batch.to(self.device))).sample()
        z, _ = torch.split(u, [self.z_dim, self.w_dim], dim=-1)
        return z.cpu()

    def tsne_plot(self, loader, save_dir, n):
        """Plot a 2-D t-SNE of the z codes of both modalities and optionally save it.

        ``loader`` yields ``(batch_a, batch_b, y)``. Modality-b points are labeled
        ``y + n`` so each modality gets its own legend entries and colors (even
        ``tab20`` indices for a, odd for b). ``n`` is presumably the number of classes:
        only labels ``0..n-1`` are plotted. When ``save_dir`` is given, the figure is
        written to ``<save_dir>/embedding_z.png`` and then closed, so repeated calls do
        not accumulate open figures. Otherwise it is left open for the caller.

        NOTE: this mutates the global matplotlib rcParams (font family and size).
        """
        plt.rc('font', family='Times New Roman')
        plt.rcParams.update({'font.size': 15})
        with torch.no_grad():
            feats_a, feats_b, labels = [], [], []
            for batch_a, batch_b, y in loader:
                labels.append(y)
                feats_a.append(self._sample_z_cpu(self.a_to_z, batch_a))
                feats_b.append(self._sample_z_cpu(self.b_to_z, batch_b))

            enc_feats = torch.cat(feats_a + feats_b, dim=0)
            labs_a = torch.cat(labels, dim=0)
            labs = torch.cat([labs_a, labs_a + n], dim=0).cpu().numpy()

            model_tsne_high = TSNE(n_components=2, random_state=0)
            z_embed = model_tsne_high.fit_transform(enc_feats.numpy())

            fig = plt.figure(figsize=(7, 5))
            for parity, (offset, prefix) in enumerate(((0, self.a_name), (n, self.b_name))):
                for ic in range(n):
                    sel = labs == ic + offset
                    color = plt.cm.tab20(2 * ic + parity)
                    plt.scatter(z_embed[sel, 0], z_embed[sel, 1], s=10, color=color, label=prefix + "_%i" % ic)
            plt.rcParams.update({'font.size': 9})
            plt.legend(loc='center right')
            name = 'embedding_z.png'
            if save_dir is not None:
                fig.savefig(os.path.join(save_dir, name))
                # Release the figure once it is on disk.
                plt.close(fig)


