import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Function
import pickle
import os
import time
from PIL import Image

import sys
sys.path.append("../EinsumNetworks/src")
sys.path.append("../integer_discrete_flows")

from EinsumNetwork import Graph, EinsumNetwork
import utils
from optimization.training import train, evaluate
import models.Model as Model


def print_cuda_mem(devices = None):
    """Print the memory currently allocated by PyTorch tensors on CUDA devices, in MiB.

    Args:
        devices: iterable of CUDA device indices to report. Defaults to every visible
            CUDA device. (This used to be hard-coded to ``cuda:0`` and ``cuda:1``,
            which failed on machines with fewer than two GPUs.)
    """
    if devices is None:
        devices = range(torch.cuda.device_count())
    for index in devices:
        name = "cuda:{}".format(index)
        # memory_allocated() reports bytes; convert to MiB.
        print(name, torch.cuda.memory_allocated(device=torch.device(name)) / 1024. / 1024.)


class ReverseGrad(Function):
    """Gradient-reversal op: identity on the forward pass, negated gradient on backward.

    Apply it as ``ReverseGrad.apply(t)``. Everything computed from the result sees
    ``t``'s value unchanged, but every tensor/parameter upstream of ``t`` receives
    the negation of the gradient that reaches it.
    """

    @staticmethod
    def forward(ctx, inp):
        # Pure pass-through; nothing has to be saved for the backward pass.
        return inp

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg()


def my_train(epoch, flow_model, einsum_models, train_loader, optimizer, args):
    """Run one single-device training epoch and report likelihood and bits per dimension.

    The per-batch step is ``train_step_logistic`` (flow + einsum networks) unless
    ``args.no_pc`` is set. In that case a flow-only baseline is used:
    ``train_step_baseline_mixture`` when ``args.n_mixtures > 1``, otherwise
    ``train_step_baseline``.

    Args:
        epoch: epoch index, used only for logging.
        flow_model: flow model; switched to train mode.
        einsum_models: einsum networks forwarded to the step function.
        train_loader: iterable of ``(data, label)`` batches supporting ``len()`` and ``.sampler``.
        optimizer: optimizer over the flow parameters.
        args: namespace providing ``input_size``, ``device``, ``no_pc`` and ``n_mixtures``.

    Returns:
        ``(mean log-likelihood, mean bits per dimension)``, averaged over batches.
    """
    flow_model.train()

    n_batches = len(train_loader)
    # bpd is normalised by the number of values per example (was hard-coded to 32 * 32 * 3).
    n_dims = int(np.prod(args.input_size))
    train_loss = np.zeros(n_batches)
    train_bpd = np.zeros(n_batches)

    num_data = 0

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(-1, *args.input_size).to(args.device)

        if not args.no_pc:
            log_likelihood, true_ll = train_step_logistic(flow_model, einsum_models, data, optimizer)
        elif args.n_mixtures > 1:
            log_likelihood, true_ll = train_step_baseline_mixture(flow_model, einsum_models, data, optimizer)
        else:
            log_likelihood, true_ll = train_step_baseline(flow_model, einsum_models, data, optimizer)

        # Natural-log likelihood -> bits per dimension.
        bpd = -true_ll / (n_dims * np.log(2.0))
        train_loss[batch_idx] = log_likelihood
        train_bpd[batch_idx] = bpd

        num_data += data.size(0)

        # Fraction of batches completed so far (reaches 100% on the last batch).
        perc = 100. * (batch_idx + 1) / n_batches
        print("\r                                                                             ", end = "")
        tmp = '\rEpoch: {:3d} [{:5d}/{:5d} ({:2.0f}%)] LL: {:.2f} bpd: {:.3f}'
        print(tmp.format(epoch, num_data, len(train_loader.sampler), perc, log_likelihood, bpd), end = "")

    print("")

    avg_loss = train_loss.mean()
    avg_bpd = train_bpd.mean()
    print('====> Epoch: {:3d} Average train loss: {:.4f} Average bpd: {:.3f}'.format(
          epoch, avg_loss, avg_bpd))

    return avg_loss, avg_bpd


def my_train_multi_gpus(epoch, flow_model, einsum_models, train_loader, optimizer, args):
    """Run one training epoch with the flow and the einsum networks on separate devices.

    Each batch is moved to ``args.flow_device`` and processed by
    ``train_step_logistic_multi_gpus``. The flow-only baseline (``args.no_pc``) is
    not supported in this mode.

    Args:
        epoch: epoch index, used only for logging.
        flow_model: flow model; switched to train mode.
        einsum_models: einsum networks forwarded to the step function.
        train_loader: iterable of ``(data, label)`` batches supporting ``len()`` and ``.sampler``.
        optimizer: optimizer over the flow parameters.
        args: namespace providing ``input_size``, ``flow_device`` and ``no_pc``
            (plus whatever the step function reads).

    Returns:
        ``(mean log-likelihood, mean bits per dimension)``, averaged over batches.

    Raises:
        NotImplementedError: if ``args.no_pc`` is set.
    """
    flow_model.train()

    n_batches = len(train_loader)
    # bpd is normalised by the number of values per example (was hard-coded to 32 * 32 * 3).
    n_dims = int(np.prod(args.input_size))
    train_loss = np.zeros(n_batches)
    train_bpd = np.zeros(n_batches)

    num_data = 0

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(-1, *args.input_size).to(args.flow_device)

        if args.no_pc:
            raise NotImplementedError("flow-only baseline (args.no_pc) is not supported in multi-GPU training")
        log_likelihood, true_ll = train_step_logistic_multi_gpus(flow_model, einsum_models, data, optimizer, args)

        # Natural-log likelihood -> bits per dimension.
        bpd = -true_ll / (n_dims * np.log(2.0))
        train_loss[batch_idx] = log_likelihood
        train_bpd[batch_idx] = bpd

        num_data += data.size(0)

        # Fraction of batches completed so far (reaches 100% on the last batch).
        perc = 100. * (batch_idx + 1) / n_batches
        print("\r                                                                             ", end = "")
        tmp = '\rEpoch: {:3d} [{:5d}/{:5d} ({:2.0f}%)] LL: {:.2f} bpd: {:.3f}'
        print(tmp.format(epoch, num_data, len(train_loader.sampler), perc, log_likelihood, bpd), end = "")

    print("")

    avg_loss = train_loss.mean()
    avg_bpd = train_bpd.mean()
    print('====> Epoch: {:3d} Average train loss: {:.4f} Average bpd: {:.3f}'.format(
          epoch, avg_loss, avg_bpd))

    return avg_loss, avg_bpd


def my_train_multi_gpus_sgd(epoch, flow_model, einsum_models, train_loader, optimizer, args):
    """Multi-device training epoch where the einsum networks are trained by their own optimizers.

    Each batch is moved to ``args.flow_device`` and processed by
    ``train_step_logistic_multi_gpus_sgd``. After the epoch, every einsum network's
    ``adamax_scheduler`` is stepped once. The flow-only baseline (``args.no_pc``) is
    not supported in this mode.

    Args:
        epoch: epoch index, used only for logging.
        flow_model: flow model; switched to train mode.
        einsum_models: einsum networks, each exposing ``adamax_scheduler``.
        train_loader: iterable of ``(data, label)`` batches supporting ``len()`` and ``.sampler``.
        optimizer: optimizer over the flow parameters.
        args: namespace providing ``input_size``, ``flow_device`` and ``no_pc``
            (plus whatever the step function reads).

    Returns:
        ``(mean log-likelihood, mean bits per dimension)``, averaged over batches.

    Raises:
        NotImplementedError: if ``args.no_pc`` is set.
    """
    flow_model.train()

    n_batches = len(train_loader)
    # bpd is normalised by the number of values per example (was hard-coded to 32 * 32 * 3).
    n_dims = int(np.prod(args.input_size))
    train_loss = np.zeros(n_batches)
    train_bpd = np.zeros(n_batches)

    num_data = 0

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(-1, *args.input_size).to(args.flow_device)

        if args.no_pc:
            raise NotImplementedError("flow-only baseline (args.no_pc) is not supported in multi-GPU training")
        log_likelihood, true_ll = train_step_logistic_multi_gpus_sgd(flow_model, einsum_models, data, optimizer, args)

        # Natural-log likelihood -> bits per dimension.
        bpd = -true_ll / (n_dims * np.log(2.0))
        train_loss[batch_idx] = log_likelihood
        train_bpd[batch_idx] = bpd

        num_data += data.size(0)

        # Fraction of batches completed so far (reaches 100% on the last batch).
        perc = 100. * (batch_idx + 1) / n_batches
        print("\r                                                                             ", end = "")
        tmp = '\rEpoch: {:3d} [{:5d}/{:5d} ({:2.0f}%)] LL: {:.2f} bpd: {:.3f}'
        print(tmp.format(epoch, num_data, len(train_loader.sampler), perc, log_likelihood, bpd), end = "")

    # Per-epoch learning-rate schedule of the einsum-network optimizers.
    for einet in einsum_models:
        einet.adamax_scheduler.step()

    print("")

    avg_loss = train_loss.mean()
    avg_bpd = train_bpd.mean()
    print('====> Epoch: {:3d} Average train loss: {:.4f} Average bpd: {:.3f}'.format(
          epoch, avg_loss, avg_bpd))

    return avg_loss, avg_bpd


def my_evaluate(epoch, flow_model, einsum_models, val_loader, args):
    """Evaluate for one epoch on a single device and report bits per dimension.

    The per-batch evaluator is chosen the same way as in training: ``eval_step_logistic``
    unless ``args.no_pc`` is set. In that case it is ``eval_step_baseline_mixture``
    (``args.n_mixtures > 1``) or ``eval_step_baseline``. The evaluator's return value
    is treated as a natural-log likelihood and converted with
    ``bpd = -ll / (D * ln 2)``, where ``D = prod(args.input_size)``.

    Args:
        epoch: epoch index, used only for logging.
        flow_model: flow model; switched to eval mode.
        einsum_models: einsum networks forwarded to the evaluator.
        val_loader: iterable of ``(data, label)`` batches supporting ``len()`` and ``.sampler``.
        args: namespace providing ``input_size``, ``device``, ``no_pc`` and ``n_mixtures``.

    Returns:
        Mean bits per dimension, averaged over batches.
    """
    flow_model.eval()

    n_batches = len(val_loader)
    # bpd is normalised by the number of values per example (was hard-coded to 32 * 32 * 3).
    n_dims = int(np.prod(args.input_size))
    val_bpd = np.zeros(n_batches)

    num_data = 0

    for batch_idx, (data, _) in enumerate(val_loader):
        data = data.view(-1, *args.input_size).to(args.device)

        if not args.no_pc:
            eval_ll = eval_step_logistic(flow_model, einsum_models, data)
        elif args.n_mixtures > 1:
            eval_ll = eval_step_baseline_mixture(flow_model, einsum_models, data)
        else:
            eval_ll = eval_step_baseline(flow_model, einsum_models, data)

        # Natural-log likelihood -> bits per dimension.
        bpd = -eval_ll / (n_dims * np.log(2.0))
        val_bpd[batch_idx] = bpd

        num_data += data.size(0)

        # Fraction of batches completed so far (reaches 100% on the last batch).
        perc = 100. * (batch_idx + 1) / n_batches
        print("\r                                                                             ", end = "")
        tmp = '\r[eval] Epoch: {:3d} [{:5d}/{:5d} ({:2.0f}%)] bpd: {:.3f}'
        print(tmp.format(epoch, num_data, len(val_loader.sampler), perc, bpd), end = "")

    print("")

    avg_bpd = val_bpd.mean()
    print('====> [eval] Epoch: {:3d} Average bpd: {:.3f}'.format(epoch, avg_bpd))

    return avg_bpd


def my_evaluate_multi_gpus(epoch, flow_model, einsum_models, val_loader, args):
    """Evaluate for one epoch with the flow and einsum networks on separate devices.

    Each batch is moved to ``args.flow_device`` and scored by
    ``eval_step_logistic_multi_gpus``. Its return value is treated as a natural-log
    likelihood and converted with ``bpd = -ll / (D * ln 2)``, where
    ``D = prod(args.input_size)``.

    Args:
        epoch: epoch index, used only for logging.
        flow_model: flow model; switched to eval mode.
        einsum_models: einsum networks forwarded to the evaluator.
        val_loader: iterable of ``(data, label)`` batches supporting ``len()`` and ``.sampler``.
        args: namespace providing ``input_size``, ``flow_device`` and ``no_pc``
            (plus whatever the evaluator reads).

    Returns:
        Mean bits per dimension, averaged over batches.

    Raises:
        NotImplementedError: if ``args.no_pc`` is set.
    """
    flow_model.eval()

    n_batches = len(val_loader)
    # bpd is normalised by the number of values per example (was hard-coded to 32 * 32 * 3).
    n_dims = int(np.prod(args.input_size))
    val_bpd = np.zeros(n_batches)

    num_data = 0

    for batch_idx, (data, _) in enumerate(val_loader):
        data = data.view(-1, *args.input_size).to(args.flow_device)

        if args.no_pc:
            raise NotImplementedError("flow-only baseline (args.no_pc) is not supported in multi-GPU evaluation")
        eval_ll = eval_step_logistic_multi_gpus(flow_model, einsum_models, data, args)

        # Natural-log likelihood -> bits per dimension.
        bpd = -eval_ll / (n_dims * np.log(2.0))
        val_bpd[batch_idx] = bpd

        num_data += data.size(0)

        # Fraction of batches completed so far (reaches 100% on the last batch).
        perc = 100. * (batch_idx + 1) / n_batches
        print("\r                                                                             ", end = "")
        tmp = '\r[eval] Epoch: {:3d} [{:5d}/{:5d} ({:2.0f}%)] bpd: {:.3f}'
        print(tmp.format(epoch, num_data, len(val_loader.sampler), perc, bpd), end = "")

    print("")

    avg_bpd = val_bpd.mean()
    print('====> [eval] Epoch: {:3d} Average bpd: {:.3f}'.format(epoch, avg_bpd))

    return avg_bpd


def my_evaluate_multi_gpus_sgd(epoch, flow_model, einsum_models, val_loader, args):
    """Evaluate the SGD-trained multi-device model for one epoch and report bits per dimension.

    Uses ``eval_step_logistic_multi_gpus``, exactly like ``my_evaluate_multi_gpus``.
    Its return value is the log-likelihood and is converted with
    ``bpd = -ll / (D * ln 2)``, where ``D = prod(args.input_size)``. The result used
    to be negated first, which reported negative bpd. The SGD training step already
    returns a positive-sign log-likelihood, so no extra sign flip is needed here.

    Args:
        epoch: epoch index, used only for logging.
        flow_model: flow model; switched to eval mode.
        einsum_models: einsum networks forwarded to the evaluator.
        val_loader: iterable of ``(data, label)`` batches supporting ``len()`` and ``.sampler``.
        args: namespace providing ``input_size``, ``flow_device`` and ``no_pc``
            (plus whatever the evaluator reads).

    Returns:
        Mean bits per dimension, averaged over batches.

    Raises:
        NotImplementedError: if ``args.no_pc`` is set.
    """
    flow_model.eval()

    n_batches = len(val_loader)
    # bpd is normalised by the number of values per example (was hard-coded to 32 * 32 * 3).
    n_dims = int(np.prod(args.input_size))
    val_bpd = np.zeros(n_batches)

    num_data = 0

    for batch_idx, (data, _) in enumerate(val_loader):
        data = data.view(-1, *args.input_size).to(args.flow_device)

        if args.no_pc:
            raise NotImplementedError("flow-only baseline (args.no_pc) is not supported in multi-GPU evaluation")
        eval_ll = eval_step_logistic_multi_gpus(flow_model, einsum_models, data, args)

        # Natural-log likelihood -> bits per dimension.
        bpd = -eval_ll / (n_dims * np.log(2.0))
        val_bpd[batch_idx] = bpd

        num_data += data.size(0)

        # Fraction of batches completed so far (reaches 100% on the last batch).
        perc = 100. * (batch_idx + 1) / n_batches
        print("\r                                                                             ", end = "")
        tmp = '\r[eval] Epoch: {:3d} [{:5d}/{:5d} ({:2.0f}%)] bpd: {:.3f}'
        print(tmp.format(epoch, num_data, len(val_loader.sampler), perc, bpd), end = "")

    print("")

    avg_bpd = val_bpd.mean()
    print('====> [eval] Epoch: {:3d} Average bpd: {:.3f}'.format(epoch, avg_bpd))

    return avg_bpd


def train_step_logistic(flow_model, einsum_models, data, flow_opt):
    """Single-device joint training step for the flow and the einsum networks.

    For each latent level, the code and its prior mean / log-scale are flattened and
    routed through ``ReverseGrad`` before the matching einsum network scores them
    with ``forward_with_grad2``. Back-propagating the einet's batch-mean output then
    gives the einet parameters its gradient, which ``em_process_batch`` consumes. The
    flow parameters receive the negated gradient, which ``flow_opt.step()`` applies
    once after all levels.

    Returns:
        ``(total, total)``: sum over levels of the batch-mean einet output (NumPy value).
    """
    flow_opt.zero_grad()

    prior_z, z, prior_ys, ys, _ = flow_model.forward_only(data)
    levels = (
        (z, prior_z[0], prior_z[1]),
        (ys[0], prior_ys[0][0], prior_ys[0][1]),
        (ys[1], prior_ys[1][0], prior_ys[1][1]),
    )

    total = 0.0
    for einet, (code, mean, logscale) in zip(einsum_models, levels):
        batch = code.size(0)
        n_vars = code.size(1) * code.size(2) * code.size(3)

        # code -> (batch, n_vars); mean / log-scale -> (batch, n_vars, -1).
        flat_code = ReverseGrad.apply(code.view(batch, -1))
        flat_mean = ReverseGrad.apply(mean.reshape(batch, -1).reshape(batch, -1, n_vars).permute(0, 2, 1))
        flat_logscale = ReverseGrad.apply(logscale.reshape(batch, -1).reshape(batch, -1, n_vars).permute(0, 2, 1))

        level_ll = einet.forward_with_grad2(flat_code, flat_mean, flat_logscale).mean()
        # The flow graph is shared by every level, so keep it alive for the next backward.
        level_ll.backward(retain_graph = True)

        einet.em_process_batch()

        total += level_ll.detach().cpu().numpy()

    flow_opt.step()

    return total, total


def depth_to_space1(x, a, b):
    """Unfold ``a*b`` channels into ``a x b`` spatial tiles, producing one output channel.

    ``(N, a*b, H, W)`` -> ``(N, 1, H*a, W*b)``. Channel ``i*b + j`` lands at row
    offset ``i`` and column offset ``j`` inside each output tile.
    """
    batch, height, width = x.size(0), x.size(2), x.size(3)
    # (N, 1, a, b, H, W) -> (N, 1, H, a, W, b): interleave tile offsets with positions.
    tiled = x.view(batch, 1, a, b, height, width).permute((0, 1, 4, 2, 5, 3))
    return tiled.contiguous().view(batch, 1, height * a, width * b)

def depth_to_space2(x, a, b):
    """Unfold channel groups into ``a x b`` spatial tiles and return a channels-last tensor.

    ``(N, C, H, W)`` with ``C`` divisible by ``a*b`` -> ``(N, H*a, W*b, C // (a*b))``.
    Input channel ``c*a*b + i*b + j`` lands at tile offset ``(i, j)``, output channel ``c``.
    """
    batch, channels, height, width = x.size(0), x.size(1), x.size(2), x.size(3)
    groups = channels // a // b
    # (N, C', a, b, H, W) -> (N, H, a, W, b, C'): tile offsets next to positions, channels last.
    tiled = x.view(batch, groups, a, b, height, width).permute((0, 4, 2, 5, 3, 1))
    return tiled.contiguous().view(batch, height * a, width * b, groups)


def train_step_logistic_multi_gpus(flow_model, einsum_models, data, flow_opt, args, use_adamax = False):
    """One joint training step with the flow and the einsum networks on separate devices.

    Steps:
        1. The flow's three latent levels (each a latent code with its prior mean and
           log-scale) are flattened.
        2. They are copied, via ``args.cpu``, to ``args.einets_device`` as fresh leaf
           tensors and scored by the matching einsum network with ``forward_with_grad2``.
        3. Each einet back-propagates its batch-mean output into its own parameters and
           into those leaves, then runs ``em_process_batch``.
        4. The leaf gradients are moved back to ``args.flow_device`` and injected, negated,
           into the flow graph before ``flow_opt.step()``.

    Args:
        flow_model: flow exposing ``forward(data, forward_only=True)`` returning
            ``(pz, z, pys, ys, _)``.
        einsum_models: one einsum network per latent level.
        data: input batch (placed on ``args.flow_device`` by the caller).
        flow_opt: optimizer over the flow parameters.
        args: namespace providing ``cpu``, ``einets_device`` and ``flow_device``.
        use_adamax: unused; kept for call compatibility.

    Returns:
        ``(sum_ll, sum_ll)``: sum over levels of the batch-mean einet output (NumPy value).
    """
    flow_opt.zero_grad()

    ## Forward flow model ##
    pz, z, pys, ys, _ = flow_model.forward(data, forward_only = True)

    ## Flatten latent codes to the layout fed to forward_with_grad2 ##
    # code -> (n, k); prior mean / log-scale -> (n, k, -1), with k = C * H * W of the code.
    latent_codes = []
    for code, mean, logscale in ((z, pz[0], pz[1]), (ys[0], pys[0][0], pys[0][1]), (ys[1], pys[1][0], pys[1][1])):
        n, k = code.size(0), code.size(1) * code.size(2) * code.size(3)
        latent_codes.append([
            code.view(n, -1),
            mean.reshape(n, -1).reshape(n, -1, k).permute(0, 2, 1),
            logscale.reshape(n, -1).reshape(n, -1, k).permute(0, 2, 1),
        ])

    ## Forward / backward einsum models on their own device ##
    flow_tensors = []
    flow_grads = []
    sum_ll = 0.0
    for einet, zs in zip(einsum_models, latent_codes):
        # Detached leaf copies, so the einet backward pass never reaches the flow graph.
        leaves = [t.detach().to(args.cpu).clone().to(args.einets_device).requires_grad_() for t in zs]

        log_likelihood = einet.forward_with_grad2(*leaves).mean()
        log_likelihood.backward()

        einet.em_process_batch()

        sum_ll += log_likelihood.detach().cpu().numpy()

        for flow_t, leaf in zip(zs, leaves):
            if flow_t.requires_grad:
                flow_tensors.append(flow_t)
                # Negated: the einet gradient is of the objective the flow should increase.
                flow_grads.append(-leaf.grad.detach().to(args.flow_device))

    ## Backward flow model ##
    # Multi-root backward; avoids concatenating every latent and gradient into extra buffers.
    torch.autograd.backward(flow_tensors, grad_tensors = flow_grads)

    flow_opt.step()

    return sum_ll, sum_ll


def train_step_logistic_multi_gpus_sgd(flow_model, einsum_models, data, flow_opt, args, use_adamax = False):
    """Multi-device joint training step where each einsum network uses its own optimizer.

    Same data flow as ``train_step_logistic_multi_gpus``: flattened latent levels are
    copied, via ``args.cpu``, to ``args.einets_device`` as leaf tensors. The
    differences are:

    - Each einet minimises ``-mean(forward_with_grad2(...))`` using
      ``einet.adamax_optimizer``, then calls ``einet.apply_reparam()``.
    - The resulting leaf gradients (of that negated objective) are injected unchanged
      into the flow graph before ``flow_opt.step()``.

    Args:
        flow_model: flow exposing ``forward(data, forward_only=True)`` returning
            ``(pz, z, pys, ys, _)``.
        einsum_models: one einsum network per latent level, each exposing
            ``adamax_optimizer`` and ``apply_reparam``.
        data: input batch (placed on ``args.flow_device`` by the caller).
        flow_opt: optimizer over the flow parameters.
        args: namespace providing ``cpu``, ``einets_device`` and ``flow_device``.
        use_adamax: unused; kept for call compatibility.

    Returns:
        ``(ll, ll)`` where ``ll`` is the sum over levels of the batch-mean einet output
        (i.e. minus the summed loss), as a NumPy value.
    """
    flow_opt.zero_grad()

    ## Forward flow model ##
    pz, z, pys, ys, _ = flow_model.forward(data, forward_only = True)

    ## Flatten latent codes to the layout fed to forward_with_grad2 ##
    # code -> (n, k); prior mean / log-scale -> (n, k, -1), with k = C * H * W of the code.
    latent_codes = []
    for code, mean, logscale in ((z, pz[0], pz[1]), (ys[0], pys[0][0], pys[0][1]), (ys[1], pys[1][0], pys[1][1])):
        n, k = code.size(0), code.size(1) * code.size(2) * code.size(3)
        latent_codes.append([
            code.view(n, -1),
            mean.reshape(n, -1).reshape(n, -1, k).permute(0, 2, 1),
            logscale.reshape(n, -1).reshape(n, -1, k).permute(0, 2, 1),
        ])

    ## Forward / backward einsum models on their own device ##
    flow_tensors = []
    flow_grads = []
    sum_nll = 0.0
    for einet, zs in zip(einsum_models, latent_codes):
        # Detached leaf copies, so the einet backward pass never reaches the flow graph.
        leaves = [t.detach().to(args.cpu).clone().to(args.einets_device).requires_grad_() for t in zs]

        einet.adamax_optimizer.zero_grad()

        loss = -1.0 * einet.forward_with_grad2(*leaves).mean()
        loss.backward()

        einet.adamax_optimizer.step()
        einet.apply_reparam()

        sum_nll += loss.detach().cpu().numpy()

        for flow_t, leaf in zip(zs, leaves):
            if flow_t.requires_grad:
                flow_tensors.append(flow_t)
                # Already the gradient of a loss, so it is passed through without a sign flip.
                flow_grads.append(leaf.grad.detach().to(args.flow_device))

    ## Backward flow model ##
    # Multi-root backward; avoids concatenating every latent and gradient into extra buffers.
    torch.autograd.backward(flow_tensors, grad_tensors = flow_grads)

    flow_opt.step()

    return -sum_nll, -sum_nll


def eval_step_logistic(flow_model, einsum_models, data):
    """Score one batch with the flow and the einsum networks on a single device.

    Runs under ``torch.no_grad()``. Each latent level (code, prior mean, prior
    log-scale) is flattened the same way as in ``train_step_logistic`` and scored
    with ``einet.forward_with_grad2``.

    Returns:
        Sum over levels of the batch-mean einet output, i.e. the log-likelihood (NumPy
        value). This used to return its negation. That was inconsistent with
        ``eval_step_logistic_multi_gpus`` and the baseline evaluators, and made
        ``my_evaluate`` (``bpd = -eval_ll / ...``) report negative bits per dimension.
    """
    with torch.no_grad():
        ## Forward flow model ##
        pz, z, pys, ys, _ = flow_model.forward_only(data)

        ## Latent codes: [code, prior mean, prior log-scale] per level ##
        latent_codes = [
            [z, pz[0], pz[1]], [ys[0], pys[0][0], pys[0][1]], [ys[1], pys[1][0], pys[1][1]]
        ]

        ## Forward einsum models ##
        sum_ll = 0.0
        for einet, zs in zip(einsum_models, latent_codes):
            n, k = zs[0].size(0), zs[0].size(1) * zs[0].size(2) * zs[0].size(3)
            h = zs[0].view(n, -1)
            h_mean = zs[1].reshape(n, -1).reshape(n, -1, k).permute(0, 2, 1)
            h_logscale = zs[2].reshape(n, -1).reshape(n, -1, k).permute(0, 2, 1)

            ll_sample = einet.forward_with_grad2(h, h_mean, h_logscale)

            sum_ll += ll_sample.mean().detach().cpu().numpy()

    return sum_ll


def eval_step_logistic_multi_gpus(flow_model, einsum_models, data, args):
    """Score one batch with the flow, placing the einsum-network inputs on ``args.einets_device``.

    Runs entirely under ``torch.no_grad()``. Each latent level (code, prior mean,
    prior log-scale) is flattened as in training, moved to ``args.einets_device``
    and scored with ``einet.forward_with_grad2``.

    Returns:
        Sum over levels of the batch-mean einet output (NumPy value).
    """
    with torch.no_grad():
        prior_z, z, prior_ys, ys, _ = flow_model.forward_only(data)
        levels = (
            (z, prior_z[0], prior_z[1]),
            (ys[0], prior_ys[0][0], prior_ys[0][1]),
            (ys[1], prior_ys[1][0], prior_ys[1][1]),
        )

        total = 0.0
        for einet, (code, mean, logscale) in zip(einsum_models, levels):
            batch = code.size(0)
            n_vars = code.size(1) * code.size(2) * code.size(3)

            # code -> (batch, n_vars); mean / log-scale -> (batch, n_vars, -1).
            flat_code = code.view(batch, -1).to(args.einets_device)
            flat_mean = mean.reshape(batch, -1).reshape(batch, -1, n_vars).permute(0, 2, 1).to(args.einets_device)
            flat_logscale = logscale.reshape(batch, -1).reshape(batch, -1, n_vars).permute(0, 2, 1).to(args.einets_device)

            level_ll = einet.forward_with_grad2(flat_code, flat_mean, flat_logscale).mean()
            total += level_ll.detach().cpu().numpy()

    return total


def eval_step_baseline_mixture(flow_model, einsum_models, data):
    """Evaluate the flow-only baseline that uses a logistic-mixture prior on the top latent.

    The top latent ``z`` is scored by a mixture of discretized logistics with
    parameters ``pz = (means, log-scales, weights)``. The two lower levels are scored
    by single discretized logistics on their standardized bin edges. The whole
    evaluation, including the flow forward pass, now runs under ``torch.no_grad()``.
    Previously only the final sum was, so an autograd graph for the flow was built
    and kept alive on every evaluation batch.

    Args:
        flow_model: flow returning ``(_, _, _, pz, z, pys, ys, _)`` when called.
        einsum_models: unused; kept for a uniform evaluator signature.
        data: input batch.

    Returns:
        Log-likelihood as a NumPy value: for each level, the batch mean (dim 0) of the
        per-element log-probability, summed over the remaining dims and over levels.
    """
    def log_mixture_discretized_logistic(x, mean, logscale, pi):
        scale = torch.exp(logscale)

        # Trailing singleton axis broadcasts x against the mixture components.
        x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)

        # Probability mass of the bin of half-width 0.5 / 256 under each component.
        p = torch.sigmoid((x + 0.5 / 256.0 - mean) / scale) \
            - torch.sigmoid((x - 0.5 / 256.0 - mean) / scale)

        p = torch.sum(p * pi, dim=-1)

        logp = torch.log(p + 1e-8)

        return logp

    def log_min_exp(a, b, epsilon=1e-8):
        # log(exp(a) - exp(b)), stabilised with epsilon inside the log.
        y = a + torch.log(1 - torch.exp(b - a) + epsilon)
        return y

    def log_discretized_logistic(xlow, xhigh):
        # log(sigmoid(xhigh) - sigmoid(xlow)) for standardized bin edges.
        logp = log_min_exp(
            F.logsigmoid(xhigh),
            F.logsigmoid(xlow))
        return logp

    with torch.no_grad():
        ## Forward flow model ##
        _, _, _, pz, z, pys, ys, _ = flow_model(data)

        ## Latent codes ##
        ll_zs = [
            [z, pz[0], pz[1], pz[2]],
            [(ys[0] - pys[0][0] - 0.5 / 256.0) / torch.exp(pys[0][1]),
             (ys[0] - pys[0][0] + 0.5 / 256.0) / torch.exp(pys[0][1])],
            [(ys[1] - pys[1][0] - 0.5 / 256.0) / torch.exp(pys[1][1]),
             (ys[1] - pys[1][0] + 0.5 / 256.0) / torch.exp(pys[1][1])]
        ]

        logp = 0.0
        for zs in ll_zs:
            if len(zs) == 4:
                logp += log_mixture_discretized_logistic(zs[0], zs[1], zs[2], zs[3]).mean(0).sum()
            else:
                logp += log_discretized_logistic(zs[0], zs[1]).mean(0).sum()

    return logp.detach().cpu().numpy()


def train_step_gaussian(flow_model, einsum_models, data, flow_opt):
    """Joint training step in which the einsum networks model standardized latents.

    The top latent ``z`` is used as-is. The two lower levels are standardized with
    their predicted prior mean and log-scale. For each level, the einsum network:
      1. receives a detached, flattened copy of the level (``forward_with_grad``);
      2. back-propagates the summed output;
      3. runs ``em_process_batch``;
      4. then ``get_input_x_grad()`` is read.
    Presumably that read returns the gradient w.r.t. the einet input (inferred from
    the name; confirm). Those gradients, negated and divided by the batch size, are
    back-propagated into the flow before ``flow_opt.step()``.

    Afterwards, without gradients, the lower/upper bin edges (half-width ``0.5 / 256``)
    of each level are passed to ``einet.forward3``. This is presumably a discretized
    log-likelihood, used only for reporting.

    Returns:
        ``(sum_ll / batch_size, sum_discrete_ll / batch_size)``: the summed einet
        outputs of the training pass and of the ``forward3`` pass, each divided by
        the batch size (NumPy values).
    """
    flow_opt.zero_grad()
    
    ## Forward flow model ##
    _, _, _, _, z, pys, ys, _ = flow_model(data)
    
    ## Latent codes ##
    # Top level as-is; lower levels standardized by their predicted prior mean / log-scale.
    latent_codes = [
        z,
        (ys[0] - pys[0][0]) / torch.exp(pys[0][1]),
        (ys[1] - pys[1][0]) / torch.exp(pys[1][1])
    ]
    latent_code_sizes = [item.size() for item in latent_codes]
    latent_codes_grad = [None for _ in range(len(einsum_models))]
    
    ## Forward einsum models ##
    sum_ll = 0.0
    for idx, (einet, zs) in enumerate(zip(einsum_models, latent_codes)):
        
        # The einet sees a detached (batch, -1) copy, so this backward does not reach the flow.
        ll_sample = einet.forward_with_grad(zs.view(latent_code_sizes[idx][0], -1).detach())
        
        log_likelihood = ll_sample.sum()
        log_likelihood.backward()
        
        einet.em_process_batch()
        
        sum_ll += log_likelihood.detach().cpu().numpy()
        
        # Get zs gradient (read after em_process_batch, as in the original ordering)
        zs_grad = einet.get_input_x_grad().clone()
        latent_codes_grad[idx] = zs_grad
        
    ## Flow model backward ##
    # Negated and divided by the batch size because the einet objective above was a
    # sum over the batch; retain_graph since the levels share the flow graph.
    for zs, zs_grad in zip(latent_codes, latent_codes_grad):
        zs.backward(-zs_grad.reshape(zs.size()) / zs.size(0), retain_graph = True)
        
    flow_opt.step()
    
    ## For computing LL/bpd only ##
    sum_discrete_ll = 0.0
    with torch.no_grad():
        # [lower edge, upper edge] of the width-1/256 bin around each latent value.
        ll_zs = [
            [z - 0.5 / 256.0, z + 0.5 / 256.0],
            [(ys[0] - pys[0][0] - 0.5 / 256.0) / torch.exp(pys[0][1]), 
             (ys[0] - pys[0][0] + 0.5 / 256.0) / torch.exp(pys[0][1])],
            [(ys[1] - pys[1][0] - 0.5 / 256.0) / torch.exp(pys[1][1]), 
             (ys[1] - pys[1][0] + 0.5 / 256.0) / torch.exp(pys[1][1])]
        ]
        
        for idx, (einet, zs) in enumerate(zip(einsum_models, ll_zs)):
            ll_sample = einet.forward3(zs[0].view(latent_code_sizes[idx][0], -1), zs[1].view(latent_code_sizes[idx][0], -1))
            sum_discrete_ll += ll_sample.sum().detach().cpu().numpy()
    
    return sum_ll / data.size(0), sum_discrete_ll / data.size(0)


def train_step_baseline(flow_model, einsum_models, data, flow_opt):
    """Train the flow alone against its own discretized-logistic priors (no einsum networks).

    Every latent level (``z`` with ``pz`` and both ``ys`` levels with ``pys``) is
    standardized at the lower/upper edges of its width-1/256 bin. Its log-probability
    is ``log(sigmoid(upper) - sigmoid(lower))``. The loss is minus the sum, over
    levels and remaining dims, of the batch-mean log-probability. ``einsum_models``
    is not used.

    Returns:
        ``(ll, ll)`` with ``ll = -loss`` as a NumPy value.
    """
    half_bin = 0.5 / 256.0

    def log_discretized_logistic(xlow, xhigh):
        # log(exp(upper) - exp(lower)) on log-sigmoids, with 1e-8 inside the log for stability.
        upper = F.logsigmoid(xhigh)
        lower = F.logsigmoid(xlow)
        return upper + torch.log(1 - torch.exp(lower - upper) + 1e-8)

    def bin_edges(x, mean, logscale):
        return ((x - mean - half_bin) / torch.exp(logscale),
                (x - mean + half_bin) / torch.exp(logscale))

    flow_opt.zero_grad()

    _, _, _, pz, z, pys, ys, _ = flow_model(data)

    levels = ((z, pz[0], pz[1]), (ys[0], pys[0][0], pys[0][1]), (ys[1], pys[1][0], pys[1][1]))
    level_logps = [log_discretized_logistic(*bin_edges(*level)).mean(0).sum() for level in levels]
    loss = -sum(level_logps)

    loss.backward()

    flow_opt.step()

    ll = -loss.detach().cpu().numpy()
    return ll, ll


def eval_step_baseline(flow_model, einsum_models, data):
    """Evaluate the flow-only baseline (discretized-logistic priors) on one batch.

    Same objective as ``train_step_baseline``, computed under ``torch.no_grad()``.
    ``einsum_models`` is not used.

    Returns:
        Log-likelihood (minus the loss) as a NumPy value.
    """
    half_bin = 0.5 / 256.0

    def log_discretized_logistic(xlow, xhigh):
        # log(exp(upper) - exp(lower)) on log-sigmoids, with 1e-8 inside the log for stability.
        upper = F.logsigmoid(xhigh)
        lower = F.logsigmoid(xlow)
        return upper + torch.log(1 - torch.exp(lower - upper) + 1e-8)

    def bin_edges(x, mean, logscale):
        return ((x - mean - half_bin) / torch.exp(logscale),
                (x - mean + half_bin) / torch.exp(logscale))

    with torch.no_grad():
        _, _, _, pz, z, pys, ys, _ = flow_model(data)

        levels = ((z, pz[0], pz[1]), (ys[0], pys[0][0], pys[0][1]), (ys[1], pys[1][0], pys[1][1]))
        loss = -sum(log_discretized_logistic(*bin_edges(*level)).mean(0).sum() for level in levels)

    return -loss.detach().cpu().numpy()


def train_step_baseline_mixture(flow_model, einsum_models, data, flow_opt):
    """Train the flow alone, with a discretized-logistic mixture prior on the top latent.

    The top latent ``z`` is scored by a mixture of discretized logistics with
    parameters ``pz = (means, log-scales, weights)``. The two lower levels use single
    discretized logistics on their standardized bin edges. The loss is minus the sum,
    over levels and remaining dims, of the batch-mean log-probability.
    ``einsum_models`` is not used.

    Returns:
        ``(ll, ll)`` with ``ll = -loss`` as a NumPy value.
    """
    half_bin = 0.5 / 256.0

    def log_mixture_discretized_logistic(x, mean, logscale, pi):
        scale = torch.exp(logscale)
        # Trailing singleton axis broadcasts x against the mixture components.
        x = x.view(x.size(0), x.size(1), x.size(2), x.size(3), 1)
        upper = torch.sigmoid((x + half_bin - mean) / scale)
        lower = torch.sigmoid((x - half_bin - mean) / scale)
        mass = torch.sum((upper - lower) * pi, dim=-1)
        return torch.log(mass + 1e-8)

    def log_discretized_logistic(xlow, xhigh):
        # log(exp(upper) - exp(lower)) on log-sigmoids, with 1e-8 inside the log for stability.
        upper = F.logsigmoid(xhigh)
        lower = F.logsigmoid(xlow)
        return upper + torch.log(1 - torch.exp(lower - upper) + 1e-8)

    def bin_edges(x, mean, logscale):
        return ((x - mean - half_bin) / torch.exp(logscale),
                (x - mean + half_bin) / torch.exp(logscale))

    flow_opt.zero_grad()

    _, _, _, pz, z, pys, ys, _ = flow_model(data)

    level_logps = [log_mixture_discretized_logistic(z, pz[0], pz[1], pz[2]).mean(0).sum()]
    for code, mean, logscale in ((ys[0], pys[0][0], pys[0][1]), (ys[1], pys[1][0], pys[1][1])):
        level_logps.append(log_discretized_logistic(*bin_edges(code, mean, logscale)).mean(0).sum())
    loss = -sum(level_logps)

    loss.backward()

    flow_opt.step()

    ll = -loss.detach().cpu().numpy()
    return ll, ll