from __future__ import print_function
import os
import argparse
import numpy as np
import random

import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.nn.init as init
from torchvision.utils import make_grid, save_image
from itertools import combinations, permutations

import data_extract
from disentanglement_lib.evaluation.metrics import factor_vae 

# ---------------------------------------------------------------------------
# Command-line configuration.
# Help texts below were corrected where they contradicted the actual option
# (wrong default, inverted meaning, typos); flags, types and defaults are
# unchanged.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Linear-Enc-Decom VAE Experiment')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--lr', type=float, default=1e-3, metavar='F',
                    help='Learning rate (default: 1e-3)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: 100)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=500, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--z_dim', type=int, default=10, metavar='N',
                    help='Dimension of the latent code of each VAE (default: 10)')
parser.add_argument('--output_dir', type=str, default='VAE_Ensemble', metavar='S',
                    help='Directory to save the results')
parser.add_argument('--limit', type=int, default=3, metavar='N',
                    help='limit for interpolation')
parser.add_argument('--inter', type=float, default=0.33, metavar='F',
                    help='interval for interpolation')
parser.add_argument('--beta', type=float, default=1.0, metavar='F',
                    help='beta to compare with beta-VAE')
parser.add_argument('--dataset', type=str, default='dsprites', metavar='S',
                    help='dataset for the experiment')
parser.add_argument('--runs', type=int, default=1, metavar='N',
                    help='how many times to run the experiment')
parser.add_argument('--module_num', type=int, default=2, metavar='N',
                    help='Number of vaes to include')
parser.add_argument('--gamma', type=float, default=1.0, metavar='F',
                    help='Regularization for linear transformations')

args = parser.parse_args()
# CUDA is used only when it is both requested (no --no-cuda) and available.
args.cuda = not args.no_cuda and torch.cuda.is_available()
# NOTE: --seed is parsed, but seeding is currently disabled (line kept as-is).
#torch.manual_seed(args.seed)

device = torch.device("cuda" if args.cuda else "cpu")

# DataLoader keyword arguments; not referenced elsewhere in the visible code.
kwargs = {'num_workers': 3, 'pin_memory': True} if args.cuda else {}
# NOTE(review): this is assigned after torch.cuda.is_available() has already
# been called above; if that call initialised CUDA, the restriction to GPU 0
# may not take effect. Left in place to keep the script's behaviour unchanged.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

dataset = args.dataset
# Number of image channels: 1 for dsprites, 3 for every other dataset name.
if dataset == 'dsprites':
    nc = 1
else:
    nc = 3
train_loader = data_extract.return_data_loader(dataset)
module_num = args.module_num
z_dim = args.z_dim
main_out_dir = 'Results/'+args.output_dir
# beta is read here but not referenced by the loss in the visible code.
beta = args.beta
gamma = args.gamma
runs = args.runs

# Ground-truth dSprites data used by the FactorVAE metric in the main loop.
from disentanglement_lib.data.ground_truth import dsprites 
gt_dataset = dsprites.DSprites([1, 2, 3, 4, 5])

def make_representor(model, cuda=None):
    """Wrap ``model`` as a numpy-in / numpy-out representation function.

    The returned callable is what the disentanglement_lib evaluation code
    in this script is given as its representation function.

    Args:
        model: Module whose call returns a 6-tuple; the second element is
            indexed with ``[0]`` and that entry is used as the representation.
        cuda: Whether to run on the GPU. ``None`` means "use CUDA if
            ``torch.cuda.is_available()``".

    Returns:
        A function mapping a 4-D NHWC ``np.ndarray`` to a 2-D ``np.ndarray``.
    """
    if cuda is None:
        cuda = torch.cuda.is_available()

    # Move the model once, up front, and remember where inputs must go.
    if cuda:
        model = model.cuda()
        target_device = 'cuda'
    else:
        model = model.cpu()
        target_device = 'cpu'

    def _represent(x):
        assert isinstance(x, np.ndarray),\
                "Input to the representation function must be a ndarray."
        assert x.ndim == 4, \
                "Input to the representation function must be a four dimensional NHWC tensor."
        # NHWC -> NCHW, then hand the batch to torch as float32.
        channels_first = np.moveaxis(x, 3, 1)
        batch = torch.from_numpy(channels_first).float().to(target_device)
        with torch.no_grad():
            outputs = model(batch)
            # The model must return exactly six values; only the second is used.
            _, y_all, _, _, _, _ = outputs
            y = y_all[0]
        assert y.ndim == 2, \
                "The returned output from the representor must be two dimensional (NC)."
        return y.cpu().numpy()

    return _represent


def kaiming_init(m):
    """Initialise a single layer in place.

    * ``nn.Linear`` / ``nn.Conv2d``: Kaiming-normal weights, zero bias.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: unit weight, zero bias.

    Any other module type (e.g. ``nn.ReLU``, ``nn.ConvTranspose2d``) is left
    untouched.

    Args:
        m: The module to initialise.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        # kaiming_normal_ is the in-place, non-deprecated spelling of
        # kaiming_normal and draws the same values.
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # With affine=False both weight and bias are None; skip them instead
        # of raising AttributeError.
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.fill_(0)

def normal_init(m, mean, std):
    """Initialise a single layer in place with normally distributed weights.

    * ``nn.Linear`` / ``nn.Conv2d``: weights drawn from N(mean, std), zero bias.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: unit weight, zero bias.

    Any other module type is left untouched.

    Args:
        m: The module to initialise.
        mean: Mean of the normal distribution used for the weights.
        std: Standard deviation of that distribution.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        # Test the parameter itself: for bias=False layers m.bias is None and
        # `m.bias.data` would raise AttributeError.
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
        # affine=False BatchNorm has neither weight nor bias.
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.zero_()


def load_checkpoint(self, filename):
    """Restore training state from ``<self.ckpt_dir>/<filename>`` if it exists.

    This is a module-level function that takes ``self`` explicitly; it reads
    ``self.ckpt_dir`` and writes ``self.global_iter`` and
    ``self.win_{recon,kld,var,mu}``, and calls ``load_state_dict`` on
    ``self.net`` and ``self.optim``.

    NOTE(review): the layout expected here ('win_states' with
    recon/kld/var/mu, plus 'optim_states') differs from the checkpoints
    written by this script's main loop -- confirm which producer this
    loader is meant for.

    Args:
        self: Object carrying the attributes listed above.
        filename: Checkpoint file name, relative to ``self.ckpt_dir``.
    """
    file_path = os.path.join(self.ckpt_dir, filename)

    # Nothing to restore: report and leave ``self`` untouched.
    if not os.path.isfile(file_path):
        print("=> no checkpoint found at '{}'".format(file_path))
        return

    checkpoint = torch.load(file_path)
    self.global_iter = checkpoint['iter']

    # Copy the four stored entries onto win_recon / win_kld / win_var / win_mu.
    win_states = checkpoint['win_states']
    for key in ('recon', 'kld', 'var', 'mu'):
        setattr(self, 'win_' + key, win_states[key])

    self.net.load_state_dict(checkpoint['model_states']['net'])
    self.optim.load_state_dict(checkpoint['optim_states']['optim'])
    print("=> loaded checkpoint '{} (iter {})'".format(file_path, self.global_iter))


class View(nn.Module):
    """Module wrapper around ``Tensor.view`` so a reshape can sit in ``nn.Sequential``."""

    def __init__(self, size):
        """Remember the target shape.

        Args:
            size: Shape passed unchanged to ``Tensor.view`` on every call.
        """
        super().__init__()
        self.size = size

    def forward(self, tensor):
        """Return ``tensor`` viewed with the stored shape."""
        target_shape = self.size
        return tensor.view(target_shape)

class RepresentationExtractor(nn.Module):
    """Expose a wrapped model's latent code as a plain forward pass.

    The wrapped model must return a 4-tuple whose second and third entries
    are used as the mean and log-variance. Depending on ``mode`` the
    extractor returns either the mean or a reparametrised sample.
    """

    VALID_MODES = ['mean', 'sample']

    def __init__(self, srvae, mode='mean'):
        """Store the wrapped model and the extraction mode.

        Args:
            srvae: Callable model returning ``(_, mu, logvar, _)``.
            mode: One of ``VALID_MODES``.
        """
        super().__init__()
        assert mode in self.VALID_MODES, f'`mode` must be one of {self.VALID_MODES}'
        self.encoder = srvae
        self.mode = mode

    def forward(self, x):
        """Run the wrapped model on ``x`` and return the selected representation."""
        _, mu, logvar, _ = self.encoder(x)
        if self.mode == 'mean':
            return mu
        if self.mode == 'sample':
            return self.reparametrize(mu, logvar)
        # Only reachable if ``mode`` was changed after construction.
        raise NotImplementedError

    @staticmethod
    def reparametrize(mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps`` drawn by ``torch.randn_like``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma


class VAE(nn.Module):
    """Convolutional VAE with a ``z_dim``-dimensional Gaussian latent.

    The encoder emits ``2 * z_dim`` values per sample; the first ``z_dim``
    are used as the mean and the rest as the log-variance. The shape
    comments below hold for 64x64 inputs.
    """

    def __init__(self, z_dim=10, nc=1):
        """Build encoder and decoder, then initialise the weights.

        Args:
            z_dim: Size of the latent code.
            nc: Number of image channels.
        """
        super().__init__()
        self.z_dim = z_dim
        self.nc = nc

        # Encoder layers are created before decoder layers, in stack order.
        encoder_layers = [
            nn.Conv2d(nc, 32, 4, 2, 1),          # B,  32, 32, 32
            nn.ReLU(True),
            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32, 16, 16
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2, 1),          # B,  64,  8,  8
            nn.ReLU(True),
            nn.Conv2d(64, 64, 4, 2, 1),          # B,  64,  4,  4
            nn.ReLU(True),
            nn.Conv2d(64, 256, 4, 1),            # B, 256,  1,  1
            nn.ReLU(True),
            View((-1, 256*1*1)),                 # B, 256
            nn.Linear(256, self.z_dim*2),        # B, z_dim*2
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(self.z_dim, 256),          # B, 256
            View((-1, 256, 1, 1)),               # B, 256,  1,  1
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 64, 4),      # B,  64,  4,  4
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 64, 4, 2, 1), # B,  64,  8,  8
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), # B,  32, 16, 16
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 32, 32
            nn.ReLU(True),
            nn.ConvTranspose2d(32, nc, 4, 2, 1), # B, nc, 64, 64
        ]
        self.decoder = nn.Sequential(*decoder_layers)

        self.weight_init()

    def weight_init(self):
        """Apply ``kaiming_init`` to every layer of every registered sub-block."""
        for block in self._modules.values():
            for layer in block:
                kaiming_init(layer)

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            Tuple ``(x_recon, mu, logvar, distribution)`` where
            ``distribution`` is the raw encoder output and ``x_recon`` is the
            raw decoder output (no activation is applied here).
        """
        distribution = self.encoder(x)
        split_at = self.z_dim
        mu, logvar = distribution[:, :split_at], distribution[:, split_at:]
        latent = RepresentationExtractor.reparametrize(mu, logvar)
        return self.decoder(latent), mu, logvar, distribution

    def decode(self, z_input):
        """Run only the decoder on a given latent code."""
        return self.decoder(z_input)



class VAE_Ensemble(nn.Module):
    """Ensemble of ``m`` VAEs linked by linear maps between their latent distributions.

    For every source VAE ``i`` there are ``m - 1`` linear layers; the ``j``-th
    one targets VAE ``j`` when ``j < i`` and VAE ``j + 1`` otherwise, so a VAE
    is never mapped onto itself.
    """

    def __init__(self, m=2, z_dim=10, nc=1 ):
        """Build the member VAEs and the per-source rows of linear maps.

        Args:
            m: Number of VAEs in the ensemble.
            z_dim: Latent size of each VAE.
            nc: Number of image channels.
        """
        super(VAE_Ensemble, self).__init__()
        self.m = m
        self.z_dim = z_dim
        self.nc = nc
        self.vaes = nn.ModuleList([VAE(self.z_dim, self.nc) for _ in range(self.m)])
        # Fix: the original put the *same* ModuleList object into every row,
        # so all source VAEs shared one set of linear maps. Each row now owns
        # its layers. `one_linears` is kept as an alias of row 0 so the
        # attribute and the state_dict keys (and their order) are unchanged.
        self.one_linears = self._make_linear_row()
        self.linears = nn.ModuleList(
            [self.one_linears if i == 0 else self._make_linear_row()
             for i in range(self.m)])
        # self.linears[i][j] maps the distribution of VAE i onto that of
        # VAE j (j < i) or VAE j+1 (j >= i).

    def _make_linear_row(self):
        """Create one row of ``m - 1`` independent (2*z_dim -> 2*z_dim) linear layers."""
        return nn.ModuleList([nn.Linear(2*self.z_dim, 2*self.z_dim)
                              for _ in range(self.m-1)])

    def forward(self, x):
        """Run every VAE on ``x`` and cross-decode through the linear maps.

        Returns:
            Tuple ``(r, m, l, d, d_lt, r_lt)``:
            ``r``/``m``/``l``/``d`` are per-VAE lists of reconstruction, mean,
            log-variance and raw encoder output; ``d_lt[i][j]`` is the
            linearly mapped distribution of VAE ``i``; ``r_lt`` is the flat
            list (``i``-major, ``j``-minor) of reconstructions decoded from
            samples of those mapped distributions by the target VAE's decoder.
        """
        recons, mus, logvars, dists = [], [], [], []
        mapped_dists, mapped_recons = [], []
        for i in range(self.m):
            rec, mu, logvar, dist = self.vaes[i](x)
            recons.append(rec)
            mus.append(mu)
            logvars.append(logvar)
            dists.append(dist)
            row = []
            for j in range(self.m-1):
                dist_lt = self.linears[i][j](dist)
                mu_lt = dist_lt[:, :self.z_dim]
                logvar_lt = dist_lt[:, self.z_dim:]
                z_lt = RepresentationExtractor.reparametrize(mu_lt, logvar_lt)
                # Skip the source VAE itself when picking the target decoder.
                target = j + 1 if j >= i else j
                row.append(dist_lt)
                mapped_recons.append(self.vaes[target].decoder(z_lt))
            mapped_dists.append(row)

        return recons, mus, logvars, dists, mapped_dists, mapped_recons


def kl_divergence(mu, logvar):
    """Compute KL terms from ``-0.5 * (1 + logvar - mu^2 - exp(logvar))``.

    4-D inputs are first viewed as ``(size(0), size(1))``.

    Args:
        mu: Mean tensor; its first dimension must be non-empty.
        logvar: Log-variance tensor matching ``mu``.

    Returns:
        Tuple ``(total_kld, dimension_wise_kld, mean_kld)``: the per-sample
        sum over dim 1 averaged over dim 0 (kept as a 1-element tensor), the
        average over dim 0 for each latent unit, and the per-sample mean over
        dim 1 averaged over dim 0 (kept as a 1-element tensor).
    """
    assert mu.size(0) != 0

    # Collapse 4-D inputs to two dimensions.
    if mu.dim() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.dim() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    per_unit = (1 + logvar - mu.pow(2) - logvar.exp()).mul(-0.5)

    total_kld = per_unit.sum(dim=1).mean(dim=0, keepdim=True)
    dimension_wise_kld = per_unit.mean(dim=0)
    mean_kld = per_unit.mean(dim=1).mean(dim=0, keepdim=True)
    return total_kld, dimension_wise_kld, mean_kld


def loss_function(r, m, l, d, d_lt, rec_lt, x, dataset):
    """Total ensemble loss: reconstruction + KL + ``gamma`` * linear-map regulariser.

    Args:
        r: Per-VAE reconstructions (raw decoder outputs).
        m: Per-VAE latent means.
        l: Per-VAE latent log-variances.
        d: Per-VAE raw encoder outputs.
        d_lt: ``d_lt[i][j]`` is the linearly mapped output of VAE ``i``.
        rec_lt: Flat list of reconstructions decoded from the mapped codes.
        x: Input batch that every reconstruction is compared against.
        dataset: ``'dsprites'`` selects BCE-with-logits; any other value
            selects MSE on the sigmoid of the decoder output.

    Returns:
        Tuple ``(total, rec_L, kl_L, lr_L)``. ``total`` uses the module-level
        ``gamma`` as the weight of ``lr_L``.
    """
    batch_size = x.size(0)

    def _recon_term(logits):
        # Summed over all elements, then averaged over the batch.
        # `reduction='sum'` replaces the deprecated `size_average=False`.
        if dataset == 'dsprites':
            return F.binary_cross_entropy_with_logits(
                logits, x, reduction='sum').div(batch_size)
        # The sigmoid is applied to a local value; the caller's lists are no
        # longer overwritten in place.
        return F.mse_loss(torch.sigmoid(logits), x, reduction='sum').div(batch_size)

    rec_L = 0
    for logits in r:
        rec_L += _recon_term(logits)
    for logits in rec_lt:
        rec_L += _recon_term(logits)

    kl_L = 0
    for mu, logvar in zip(m, l):
        kl_b, _, _ = kl_divergence(mu, logvar)
        kl_L += kl_b

    lr_L = 0
    for i in range(len(d)):
        for j in range(len(d)-1):
            # d_lt[i][j] is decoded by VAE j (j < i) or VAE j+1 (j >= i) in
            # VAE_Ensemble.forward, so it must be matched against that same
            # VAE's distribution (the original always used d[j]).
            target = j + 1 if j >= i else j
            lr_L += F.mse_loss(d_lt[i][j], d[target], reduction='sum').div(batch_size)

    # Fix: the original returned the undefined name `lr_reg` (NameError).
    return kl_L+rec_L+gamma*lr_L, rec_L, kl_L, lr_L


def train(model, epoch):
    """Train ``model`` for one pass over the module-level ``train_loader``.

    A fresh Adam optimizer (learning rate ``args.lr``) is created on every
    call. The pass stops early as soon as the loss is NaN.

    Args:
        model: The network to optimise.
        epoch: Epoch number, used only for logging.

    Returns:
        Tuple ``(valid_training, l1, l2, l3)``: ``valid_training`` is False
        if a NaN loss was hit; the other three are the second to fourth
        values returned by ``loss_function`` for the last processed batch.
    """
    adam = optim.Adam(model.parameters(), lr=args.lr)
    model.train()

    running_loss = 0
    valid_training = True
    # Which element of a batch holds the images depends on the dataset.
    image_slot = 0 if dataset == 'dsprites' else 1
    log_template = 'Train Epoch: {} [{}/{}({:.0f}%)]\tLoss:{:.6f}\t{:.6f}\t{:.6f}\t{:.6f}'

    for batch_idx, batch in enumerate(train_loader):
        data = batch[image_slot].to(device)

        adam.zero_grad()
        recon, mu, logvar, dist, dist_lt, rec_lt = model(data)
        loss, rec_term, kl_term, lr_term = loss_function(
            recon, mu, logvar, dist, dist_lt, rec_lt, data, dataset)

        if torch.isnan(loss):
            valid_training = False
            print("nan encouted! Retraining starts...")
            break

        loss.backward()
        running_loss += loss.item()
        adam.step()

        if batch_idx % args.log_interval == 0:
            seen = batch_idx*len(data)
            percent = 100.*batch_idx/len(train_loader)
            print(log_template.format(epoch,
                                      seen,
                                      len(train_loader.dataset),
                                      percent,
                                      loss.item(),
                                      rec_term.item(),
                                      kl_term.item(),
                                      lr_term.item()))

    average_loss = running_loss / len(train_loader)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, average_loss))
    return valid_training, rec_term, kl_term, lr_term


if __name__ == "__main__":
    # Script entry point: for each run, train epoch by epoch, score the model
    # with the FactorVAE metric after every epoch and save a checkpoint.

    overwrite = True
    m_run = []  # one list of per-epoch score dicts per run
    for run in range(runs):
        # `valid` is False before the first epoch and after an epoch that hit
        # a NaN loss; in both cases a fresh model is built below.
        valid = False
        m = []
        for epoch in range(1, args.epochs + 1):
            score = {}
            if not valid:
                print("Either initial epoch or not valid model encountered, restart training!")
                # Fix: the original instantiated `Mul_VAE`, which is not
                # defined or imported in this file (NameError). The ensemble
                # class defined here, with this (m, z_dim, nc) signature, is
                # VAE_Ensemble.
                model = VAE_Ensemble(module_num, z_dim, nc).to(device)
                model = torch.nn.DataParallel(model)

            valid, rec, kl, lr = train(model, epoch)

            # Evaluate the current model on the ground-truth dSprites data.
            fn = make_representor(model)
            factor_vae_score = factor_vae.compute_factor_vae(gt_dataset, fn,
                                                  np.random.RandomState(0), 64,
                                                  10000, 5000, 10000)
            print(factor_vae_score)
            score['factor_vae_metric'] = factor_vae_score

            # Checkpoints go to Results/<output_dir>/Run_<run>/<epoch>/<epoch>.ckpt
            output_dir = os.path.join(main_out_dir, 'Run_'+str(run), str(epoch))
            os.makedirs(output_dir, exist_ok=True) 
            model_states = {'net':model.state_dict(),}
            measure_states = {'score':score, 'rec': rec, 'kl': kl, 'lr':lr}
            states = {'iter':epoch,                             
                      'win_states': measure_states,                                         
                      'model_states':model_states,}                                        
            file_path = os.path.join(output_dir, str(epoch)+'.ckpt')
            with open(file_path, mode='wb+') as f:
                torch.save(states, f) 
                print("=> saved checkpoint '{}' (epoch {})".format(file_path, epoch))
            m.append(score)
        m_run.append(m)
    print(m_run)

