# Contrastive VAE - z is also a distribution
# training with random atrophy each epoch
import os, fnmatch, time, datetime
import numpy as np, random
import pandas as pd
import matplotlib.pyplot as plt
import sklearn 
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.linalg import sqrtm
from scipy.spatial import distance
import itertools
from collections import OrderedDict
import torch.distributions as D

# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed every RNG used in this file (Python, NumPy, torch CPU and CUDA) from the
# single `seed` value so runs are repeatable. (GPU kernels may still be
# non-deterministic.)
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

def weights_init(m):
    """Re-initialise Linear-like modules with N(0, 0.1) samples.

    Meant to be passed to ``nn.Module.apply``. A module matches when its class
    name contains ``'Linear'``. Its ``weight`` and, when present, its ``bias``
    are filled in place from a normal distribution with mean 0 and std 0.1.
    Every other module is left untouched.
    """
    if 'Linear' not in m.__class__.__name__:
        return
    m.weight.data.normal_(0, 0.1)
    if m.bias is not None:
        m.bias.data.normal_(0, 0.1)

class DecayLR:
    """Linear learning-rate decay factor, usable as a ``LambdaLR`` lambda.

    The factor stays at 1.0 until ``epoch + offset`` reaches ``decay_epochs``.
    After that it falls linearly and hits 0.0 at ``epoch + offset == epochs``.
    It is not clamped, so later epochs give negative factors.
    """

    def __init__(self, epochs, offset, decay_epochs):
        assert epochs - decay_epochs > 0, "Decay must start before the training session ends!"
        self.epochs = epochs
        self.offset = offset
        self.decay_epochs = decay_epochs

    def step(self, epoch):
        """Return the learning-rate multiplier for ``epoch``."""
        decay_span = self.epochs - self.decay_epochs
        elapsed = max(0, epoch + self.offset - self.decay_epochs)
        return 1.0 - elapsed / decay_span

class dotdict(dict):
    """``dict`` subclass whose keys can also be used as attributes.

    Reading a missing attribute returns ``None`` (like ``dict.get``) instead
    of raising ``AttributeError``. Setting an attribute stores a key.
    Deleting an attribute removes the key and raises ``KeyError`` if it is
    absent.
    """

    def __getattr__(self, key, default=None):
        return self.get(key, default)

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        del self[key]

def format_time(elapsed):
    """Format a duration in seconds as an ``h:mm:ss`` string.

    The value is first rounded to a whole second with Python's ``round``, so
    exact halves go to the even neighbour. Durations of a day or more come
    out in ``datetime.timedelta`` form, e.g. ``'1 day, 0:00:00'``.
    """
    whole_seconds = int(round(elapsed))
    return str(datetime.timedelta(seconds=whole_seconds))

def mmd(x, y, gammas, device):
    """Return the biased (V-statistic) squared MMD between samples ``x`` and ``y``.

    The kernel is the multi-bandwidth Gaussian kernel built by ``gram_matrix``:
        MMD^2 = mean K(x, x) + mean K(y, y) - 2 * mean K(x, y)
    The means run over the full Gram matrices, diagonals included.

    Args:
        x: samples accepted by ``torch.cdist``, e.g. (n, d).
        y: samples with the same feature size as ``x``, e.g. (m, d).
        gammas: 1-D tensor of kernel bandwidth multipliers.
        device: device ``gammas`` is moved to. ``x`` and ``y`` must already
            live there.

    Returns:
        0-dim floating-point tensor, clamped at 0. The estimate is
        non-negative in exact arithmetic but can dip below 0 through
        round-off.
    """
    gammas = gammas.to(device)

    cost = torch.mean(gram_matrix(x, x, gammas=gammas))
    cost = cost + torch.mean(gram_matrix(y, y, gammas=gammas))
    cost = cost - 2 * torch.mean(gram_matrix(x, y, gammas=gammas))

    # Clamp instead of branching on `cost < 0`. This keeps the return dtype
    # floating-point: the old code returned an integer tensor(0). It also
    # avoids a host/device sync on the comparison.
    return torch.clamp(cost, min=0.0)

def gram_matrix(x, y, gammas):
    """Multi-bandwidth Gaussian kernel matrix between the rows of ``x`` and ``y``.

    Entry (i, j) is ``sum_k exp(-gammas[k] * ||x_i - y_j||^2)``.

    Args:
        x, y: tensors accepted by ``torch.cdist``, e.g. (n, d) and (m, d).
        gammas: 1-D tensor of bandwidth multipliers.

    Returns:
        Tensor with the shape of the pairwise distance matrix, e.g. (n, m).
    """
    sq_dist = torch.cdist(x, y, p=2.0).square()
    # One row per gamma: gamma_k * d^2 for every (i, j) pair, flattened.
    scaled = gammas.unsqueeze(1) * sq_dist.reshape(1, -1)
    kernel_sum = torch.exp(-scaled).sum(dim=0)
    return kernel_sum.reshape(sq_dist.shape)

def cal_w_distance(mean_1, mean_2, std_1, std_2):
    """Squared 2-Wasserstein distance between two diagonal Gaussians.

    With every dimension (ROI) treated as independent:
        W_2^2 = sum((mean_1 - mean_2)^2 + (std_1 - std_2)^2)

    Args:
        mean_1, mean_2: per-dimension means.
        std_1, std_2: per-dimension standard deviations.

    Returns:
        The summed value (a numpy scalar). Note it is the *squared* distance.
    """
    # (std_1 - std_2)^2 equals std_1^2 + std_2^2 - 2*std_1*std_2 algebraically,
    # but it avoids catastrophic cancellation and can never come out negative.
    return np.sum(np.square(mean_1 - mean_2) + np.square(std_1 - std_2))
    
def cal_w_distance_with_cov(mean_1, mean_2, std_1, std_2):
    """Squared 2-Wasserstein distance between two full-covariance Gaussians.

        W_2^2 = ||mean_1 - mean_2||^2
                + tr(S1 + S2 - 2 * (S2^{1/2} S1 S2^{1/2})^{1/2})

    Args:
        mean_1, mean_2: mean vectors.
        std_1, std_2: covariance matrices S1 and S2 (despite the names).

    Returns:
        The real part of the result.
    """
    # Compute the matrix square root of S2 once; it is used on both sides.
    root_2 = sqrtm(std_2)
    cross = sqrtm(root_2.dot(std_1).dot(root_2))
    distance = np.sum(np.square(mean_1 - mean_2)) + np.trace(std_1 + std_2 - 2 * cross)
    # sqrtm may return a complex array (e.g. round-off on near-singular
    # matrices); only the real part is meaningful here.
    return distance.real

def cal_validate_distance(mean, std, eva_data, independent):
    """Wasserstein distance between a reference Gaussian and a fit to ``eva_data``.

    The column means of ``eva_data`` are always computed. When ``independent``
    is truthy, its column standard deviations are compared with ``std`` via
    ``cal_w_distance``. Otherwise its column covariance matrix is compared
    with ``std`` via ``cal_w_distance_with_cov``.

    Args:
        mean: reference per-column mean.
        std: reference per-column std (independent) or covariance matrix
            (not independent).
        eva_data: 2-D array, rows are samples and columns are features.
        independent: selects the diagonal or the full-covariance formula.

    Returns:
        A single distance value.
    """
    eva_mean = np.mean(eva_data, axis=0)
    if not independent:
        eva_cov = np.cov(np.transpose(eva_data))
        return cal_w_distance_with_cov(mean, eva_mean, std, eva_cov)
    eva_std = np.std(eva_data, axis=0)
    return cal_w_distance(mean, eva_mean, std, eva_std)

def eval_w_distances_forward(samplea, sampleb, independent=True):
    """Wasserstein distance between Gaussian fits of two sample sets.

    A Gaussian is fitted to ``sampleb``. It uses the per-column mean plus
    either the per-column std (``independent=True``) or the full column
    covariance. It is then compared with the same statistics of ``samplea``
    via ``cal_validate_distance``.

    Args:
        samplea: torch tensor, rows are samples and columns are features
            (ROIs); the set being evaluated.
        sampleb: torch tensor with the same layout; the reference set.
        independent: treat features as independent (diagonal Gaussians)
            when True, otherwise use full covariance matrices.

    Returns:
        The value returned by ``cal_validate_distance``.
    """
    # detach() so tensors that still require grad (e.g. generator outputs)
    # can be converted: Tensor.numpy() raises on those.
    reference = sampleb.detach().to('cpu').numpy()
    mean = np.mean(reference, axis=0)  # mean of each ROI
    if independent:
        spread = np.std(reference, axis=0)
    else:
        # covariance of each ROI with every other ROI
        spread = np.cov(np.transpose(reference))
    eva_data = samplea.detach().to('cpu').numpy()
    return cal_validate_distance(mean, spread, eva_data, independent)

class Generator(nn.Module):
    """Encoder/decoder networks of the contrastive VAE.

    ``modelenc_z`` maps an ``nROI``-feature input to ``2 * nLatent_z`` values,
    split into mean and log-std of the common latent ``z``. ``modelenc_s``
    does the same for the salient latent ``s`` (``nLatent_s`` dims).
    ``modeldec`` maps the concatenation ``[z, s]`` back to ``nROI`` features.

    Args:
        nROI: number of input features (ROIs).
        nLatent_z: size of the common latent space.
        nLatent_s: size of the salient latent space.
    """

    def __init__(self, nROI, nLatent_z, nLatent_s):
        super().__init__()
        self.nLatent_s = nLatent_s
        self.nLatent_z = nLatent_z
        hidden = int(nROI / 2)
        # Layer order (and hence state_dict keys) must stay fixed so existing
        # checkpoints keep loading.
        self.modelenc_z = nn.Sequential(nn.Linear(nROI, hidden),
                                        nn.LeakyReLU(0.2, inplace=True),
                                        nn.Linear(hidden, nLatent_z * 2))
        self.modelenc_s = nn.Sequential(nn.Linear(nROI, hidden),
                                        nn.LeakyReLU(0.2, inplace=True),
                                        nn.Linear(hidden, nLatent_s * 2))
        self.modeldec = nn.Sequential(nn.Linear(nLatent_z + nLatent_s, hidden),
                                      nn.LeakyReLU(0.2, inplace=True),
                                      nn.Linear(hidden, nROI))

    @staticmethod
    def _reparameterize(mu, logsigma, n_rows, n_cols, device):
        """Sample ``mu + exp(logsigma) * eps`` with ``eps ~ N(0, I)`` of shape (n_rows, n_cols)."""
        eps = torch.randn(n_rows, n_cols, device=device)
        # exp() is already non-negative, so no abs() is needed.
        return eps * logsigma.exp() + mu

    def forward(self, input_x, input_y):
        """Encode, sample and decode a target batch ``input_x`` and a background batch ``input_y``.

        Background samples get a salient latent of zeros. Noise is drawn in
        the order: salient (x), common (x), common (y).

        Returns:
            (x_recon, y_recon, mu_sx, logsigma_sx, zx, zy,
             mu_zx, logsigma_zx, mu_zy, logsigma_zy)
        """
        device = input_x.device
        n_x = input_x.shape[0]
        n_y = input_y.shape[0]

        # Salient latent: sampled for target samples, fixed to 0 for background.
        mu_sx, logsigma_sx = self.modelenc_s(input_x).chunk(2, dim=-1)
        sx = self._reparameterize(mu_sx, logsigma_sx, n_x, self.nLatent_s, device)
        sy = torch.zeros(n_y, self.nLatent_s, device=device)

        # Common latent: sampled for both groups with the shared encoder.
        mu_zy, logsigma_zy = self.modelenc_z(input_y).chunk(2, dim=-1)
        mu_zx, logsigma_zx = self.modelenc_z(input_x).chunk(2, dim=-1)
        zx = self._reparameterize(mu_zx, logsigma_zx, n_x, self.nLatent_z, device)
        zy = self._reparameterize(mu_zy, logsigma_zy, n_y, self.nLatent_z, device)

        # Decode [z, s] for each group.
        x_recon = self.modeldec(torch.cat([zx, sx], dim=1))
        y_recon = self.modeldec(torch.cat([zy, sy], dim=1))
        return x_recon, y_recon, mu_sx, logsigma_sx, zx, zy, mu_zx, logsigma_zx, mu_zy, logsigma_zy

class ContrastiveVAE(object):
    """Training wrapper around the contrastive VAE ``Generator``.

    Call :meth:`create` with an options mapping to build the network,
    optimizer and (optionally) LR scheduler. Then call :meth:`train_instance`
    once per batch. ``x`` inputs are the target group; ``y`` inputs are the
    background group.
    """

    def __init__(self):
        # Options, filled in by create().
        self.opt = None

        ##### networks
        self.G_forward = None
        # NOTE(review): the D_* attributes below are never set by this class;
        # kept as None so any external code reading them keeps working.
        self.D_forward = None

        ##### optimizers
        self.optimizer_G = None
        self.optimizer_D_forward = None

        ##### criteria
        self.criterionRecon = F.mse_loss  # mean-squared-error reconstruction loss
        self.mmd = mmd  # MMD between common-latent samples

        ##### loss weights (set from opt in create())
        self.mmd_lambda = None
        self.kld_lambda = None

        ##### learning-rate scheduling
        self.lr_lambda = None
        self.lr_scheduler_G = None
        self.lr_scheduler_D_forward = None
        self.lr_scheduler_D_backward = None

    def create(self, opt):
        """Build the generator, its optimizer, an optional LR scheduler and the loss weights.

        Args:
            opt: mapping of options. Keys read here: ``nROI``, ``nLatent_z``,
                ``nLatent_s``, ``DEVICE``, ``lr``, ``beta1``, ``scheduler``,
                ``epochs`` and ``decay_epochs`` (only when ``scheduler`` is
                truthy), ``mmd_lambda`` and ``kld_lambda``.
        """
        # Copy into a dotdict so options can be read as attributes
        # (missing keys read as None).
        self.opt = dotdict({key: opt[key] for key in opt.keys()})

        self.G_forward = Generator(self.opt.nROI, self.opt.nLatent_z, self.opt.nLatent_s).to(self.opt.DEVICE)
        self.G_forward.apply(weights_init)

        self.optimizer_G = torch.optim.Adam(self.G_forward.parameters(),
                                            lr=self.opt.lr, betas=(self.opt.beta1, 0.999),
                                            weight_decay=2.5 * 1e-3)

        if self.opt.scheduler:
            self.lr_lambda = DecayLR(self.opt.epochs, 0, self.opt.decay_epochs).step
            self.lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(self.optimizer_G, lr_lambda=self.lr_lambda)

        self.mmd_lambda = self.opt.mmd_lambda
        self.kld_lambda = self.opt.kld_lambda

    @staticmethod
    def _gaussian_kl(mu, logsigma):
        """Sum over all elements of KL(N(mu, exp(logsigma)^2) || N(0, 1))."""
        return -0.5 * torch.sum(1 + 2 * logsigma - torch.square(mu) - torch.square(logsigma.exp()))

    def train_instance(self, real_x, real_y, reduction='sum'):
        """Run one optimisation step on a target/background pair of batches.

        Args:
            real_x: 2-D tensor (batch_x, nROI) of target samples.
            real_y: 2-D tensor (batch_y, nROI) of background samples. The
                batch size may differ from ``real_x``.
            reduction: reduction passed to the reconstruction criterion.

        Returns:
            OrderedDict with ``'kldz_loss'`` (common-latent KL term),
            ``'Combined_loss'`` (total loss) and ``'Generator_loss_list'``
            ``[recon_x, recon_y, kl_salient, mmd_z, 0]``.
        """
        assert len(real_x.shape)==2 and real_x.shape[-1]==self.opt.nROI, 'real_x should be a 2D tensor with last dim equal to number of ROIs'
        assert len(real_y.shape)==2 and real_y.shape[-1]==self.opt.nROI, 'real_y should be a 2D tensor with last dim equal to number of ROIs'

        real_x = real_x.to(self.opt.DEVICE)
        real_y = real_y.to(self.opt.DEVICE)
        batch_x = real_x.shape[0]
        batch_y = real_y.shape[0]

        self.G_forward.train()

        # real_x is the target data, real_y the background data.
        (fake_x, fake_y, mu_sx, logsigma_sx, zx, zy,
         mu_zx, logsigma_zx, mu_zy, logsigma_zy) = self.G_forward(real_x, real_y)

        # Kernel bandwidths 1e-3 ... 1e2 for the multi-scale MMD.
        gammas = torch.FloatTensor([10 ** k for k in range(-3, 3)])

        # (1) MMD between the common latents of target and background.
        dist_loss_z = self.mmd(zx, zy, gammas, self.opt.DEVICE)

        # (2) Reconstruction loss. NOTE(review): the criterion result is
        # multiplied by nROI. With reduction='mean' this equals the summed
        # squared error per sample; with the default 'sum' it scales the total
        # sum by nROI — confirm this is intended.
        recon_xloss = self.criterionRecon(fake_x, real_x, reduction=reduction) * self.opt['nROI']
        recon_yloss = self.criterionRecon(fake_y, real_y, reduction=reduction) * self.opt['nROI']

        # (3) KL divergences, each normalised by its own batch size. The
        # common-latent term averages the target and background KLs; this
        # matches the previous formula when the batch sizes are equal.
        kldiv_s = self._gaussian_kl(mu_sx, logsigma_sx) / batch_x
        kldiv_zxy = 0.5 * (self._gaussian_kl(mu_zx, logsigma_zx) / batch_x
                           + self._gaussian_kl(mu_zy, logsigma_zy) / batch_y)

        loss_G = recon_xloss + kldiv_s + recon_yloss + self.kld_lambda*kldiv_zxy + self.mmd_lambda*dist_loss_z # in semi synth lambda=10, for real lambda=100

        # Update the generator (encoders + decoder) weights.
        self.optimizer_G.zero_grad()
        loss_G.backward()
        self.optimizer_G.step()

        losses = OrderedDict([('kldz_loss', kldiv_zxy.item()),
                              ('Combined_loss', loss_G.item()),
                              ('Generator_loss_list', [recon_xloss.item(), recon_yloss.item(), kldiv_s.item(), dist_loss_z.item(), 0])])

        return losses
