import pdb
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.utils import weight_norm as wn
import matplotlib.pyplot as plt
import numpy as np
from collections import OrderedDict

def ll_to_bpd(ll, shape=(32, 32, 3), bits=8):
    """Convert a log-likelihood in nats to bits per dimension.

    Args:
        ll: log-likelihood (natural log) summed over ``prod(shape)``
            dimensions; a scalar or an array.
        shape: data shape; its element count is the number of dimensions.
        bits: bit depth of the data.

    The ``log(2 ** (bits - 1))`` term turns a density into a discrete
    probability for bins of width ``2 / 2 ** bits``, i.e. data scaled to
    ``[-1, 1]`` (NOTE(review): inferred from the constant; confirm).
    """
    n_dims = np.prod(shape)
    nats_per_dim = ll / n_dims - np.log(2 ** (bits - 1))
    return -nats_per_dim / np.log(2)

class EMA(nn.Module):
    """Exponential moving average (EMA) of a model's weights.

    Holds the network being trained (``model``) and a second module
    (``shadow``) with the same parameter and buffer names that stores the
    averaged weights. ``forward`` runs ``model`` in training mode and
    ``shadow`` in eval mode.

    Args:
        model: module whose parameters are being optimised.
        shadow: module with exactly the same parameter/buffer names as
            ``model``. It is overwritten with ``model``'s weights on
            construction, and its parameters are detached from autograd.
        decay: EMA decay. Each ``update()`` moves every shadow parameter a
            fraction ``1 - decay`` of the way towards the model parameter.
    """

    def __init__(self, model: nn.Module, shadow: nn.Module, decay: float):
        super().__init__()
        self.decay = decay

        self.model = model
        self.shadow = shadow

        # Start the average from an exact copy of the current model weights.
        self.update(copy_all=True)

        # The shadow is only written by update(); it never receives gradients.
        for param in self.shadow.parameters():
            param.detach_()

    @torch.no_grad()
    def update(self, copy_all=False):
        """Blend the model weights into the shadow weights.

        Args:
            copy_all: if True, copy parameters verbatim instead of blending.

        Buffers are always copied verbatim. When called in eval mode this
        prints a warning to stderr and leaves the shadow untouched.

        Raises:
            AssertionError: if ``model`` and ``shadow`` differ in their
                parameter or buffer names.
        """
        if not self.training:
            import sys  # local import keeps this block self-contained
            print("EMA update should only be called during training", file=sys.stderr, flush=True)
            return

        model_params = OrderedDict(self.model.named_parameters())
        shadow_params = OrderedDict(self.shadow.named_parameters())

        # both modules must expose the same set of parameter names
        assert model_params.keys() == shadow_params.keys()

        for name, param in model_params.items():
            if copy_all:
                shadow_params[name].copy_(param)
            else:
                # shadow <- shadow - (1 - decay) * (shadow - param)
                #         = decay * shadow + (1 - decay) * param
                shadow_params[name].sub_((1. - self.decay) * (shadow_params[name] - param))

        model_buffers = OrderedDict(self.model.named_buffers())
        shadow_buffers = OrderedDict(self.shadow.named_buffers())

        # both modules must expose the same set of buffer names
        assert model_buffers.keys() == shadow_buffers.keys()

        for name, buffer in model_buffers.items():
            # buffers are copied, not averaged
            shadow_buffers[name].copy_(buffer)

    def forward(self, *args, **kwargs):
        """Run ``model`` in training mode and the averaged ``shadow`` otherwise."""
        if self.training:
            return self.model(*args, **kwargs)
        return self.shadow(*args, **kwargs)

#     def __setattr__(self, attr):
#         print(attr)
#         self.__dict

class PNGDataset(torch.utils.data.Dataset):
    """Image dataset read from pickled ``{'data': ..., 'labels': ...}`` batches.

    Despite the name, this does not read PNG files. It loads the pickled batch
    files of the downsampled ImageNet 64x64 release fully into memory.
    """

    def __init__(self, data_dir='/data/datasets', transform=None, train=True):
        """
        Args:
            data_dir (string): Root directory containing
                ``train/train_data_batch_1`` ... ``train/train_data_batch_10``
                and ``val/val_data``.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            train (bool): Load the training batches if True, otherwise the
                validation file.
        To obtain ImageNet dataset files, download ImageNet 64x64 train and val from:
        https://image-net.org/download-images.php.
        """
        self.root = data_dir
        self.transform = transform

        self.x, self.y = self.get_data(train)

    def get_data(self, train):
        """Load all images and labels.

        Returns:
            ``(x, y)``: ``x`` is a float tensor of shape ``(N, 3, 64, 64)``
            with values divided by 255; ``y`` holds the stored labels.
        """
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        if train:
            x, y = [], []
            for i in range(10):
                with open(self.root + '/train/train_data_batch_{}'.format(i + 1), 'rb') as f:
                    batch_dict = pickle.load(f)
                x.append(batch_dict['data'])
                y.append(batch_dict['labels'])
            x = np.concatenate(x, axis=0)
            y = np.concatenate(y, axis=0)
        else:
            with open(self.root + '/val/val_data', 'rb') as f:
                batch_dict = pickle.load(f)
            x = batch_dict['data']
            y = batch_dict['labels']

        x, y = torch.tensor(x), torch.tensor(y)
        x = x.reshape(-1, 3, 64, 64).float() / 255.
        return x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(image, label)``, applying ``transform`` to the image if set."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = self.x[idx]

        if self.transform:
            image = self.transform(image)

        return image, self.y[idx]

class ImageNet(torch.utils.data.Dataset):
    """Downsampled ImageNet (64x64) loaded fully into memory from pickled batches."""

    def __init__(self, root='/data/datasets/ImageNet', train=True, transform=None, zip_mode=False):
        """
        Args:
            root (string): Path to the root directory of the ImageNet dataset.
                It must contain ``train/train_data_batch_1`` ...
                ``train/train_data_batch_10`` and ``val/val_data``.
            train (bool): Load the training batches if True, otherwise the
                validation file.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            zip_mode (bool): Currently unused.
        To obtain ImageNet dataset files, download ImageNet 64x64 train and val from:
        https://image-net.org/download-images.php.
        """
        self.root = root
        self.transform = transform

        self.x, self.y = self.get_data(train)

    def get_data(self, train):
        """Load all images and labels.

        Returns:
            ``(x, y)``: ``x`` is a float tensor of shape ``(N, 3, 64, 64)``
            with values divided by 255; ``y`` holds the stored labels.
        """
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        if train:
            x = []
            y = []
            for i in range(10):
                with open(self.root + '/train/train_data_batch_{}'.format(i + 1), 'rb') as f:
                    batch_dict = pickle.load(f)
                x.append(batch_dict['data'])
                y.append(batch_dict['labels'])

            x = np.concatenate(x, axis=0)
            # labels come from the per-batch label lists, not the images
            y = np.concatenate(y, axis=0)
        else:
            with open(self.root + '/val/val_data', 'rb') as f:
                batch_dict = pickle.load(f)
            x = batch_dict['data']
            y = batch_dict['labels']

        x, y = torch.tensor(x), torch.tensor(y)
        x = x.reshape(-1, 3, 64, 64).float() / 255.
        return x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(image, label)``, applying ``transform`` to the image if set."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = self.x[idx]

        if self.transform:
            image = self.transform(image)

        return image, self.y[idx]

class GaussianDataset(torch.utils.data.Dataset):
    """Dataset of ``n`` identical copies of one fixed array, each labelled ``0.``."""

    def __init__(self, shape, n=60000, loc=None, transform=None):
        """
        Args:
            shape (tuple): shape of a single sample.
            n (int): number of samples.
            loc (array, optional): value of every sample (broadcast to
                ``shape``). If None, one draw from a standard normal, clipped
                to ``[-1, 1]`` and cast to float32, is used.
            transform (callable, optional): applied to each sample on access.
        """
        if loc is None:
            loc = np.clip(np.random.normal(size=(1,) + shape), -1., 1.).astype('f')
        self.n, self.shape, self.transform = n, shape, transform
        self.loc = loc
        # materialise all n copies up front
        self.data = np.ones((n,) + shape).astype('f') * self.loc

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        key = idx.tolist() if torch.is_tensor(idx) else idx
        sample = self.data[key]
        return (self.transform(sample) if self.transform else sample), 0.

class Dataset(torch.utils.data.Dataset):
    """Map-style ``Dataset`` over an indexable container, with an optional transform."""

    def __init__(self, data, transform=None):
        """
        Args:
            data: any object supporting ``len(data)`` and ``data[idx]``.
            transform (callable, optional): applied to each item on access.
        """
        self.data, self.transform = data, transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()  # tensor index -> Python int / list
        item = self.data[idx]
        return self.transform(item) if self.transform else item

class DistributedDataParallel(nn.parallel.DistributedDataParallel):
    """DDP wrapper that forwards ``update()`` to the wrapped module (e.g. an ``EMA``)."""

    def update(self):
        """Call ``update()`` on the wrapped module; always returns None."""
        wrapped = self.module
        wrapped.update()

def save_model(model, optimizer, epoch, save_dir):
    """Write a training checkpoint with ``torch.save``.

    Args:
        model: an ``EMA`` module, optionally wrapped in any
            ``nn.parallel.DistributedDataParallel`` (including this file's
            subclass). Any other module is saved via its own ``state_dict``.
        optimizer: optimizer whose ``state_dict`` is stored alongside.
        epoch: value stored under ``'epoch'``.
        save_dir: destination file path passed to ``torch.save``.

    For an ``EMA`` the checkpoint holds ``'model_state_dict'`` and
    ``'shadow_state_dict'``. For any other module it holds only
    ``'model_state_dict'``.
    """
    # Unwrap DDP to reach the module that actually holds the weights.
    module = model.module if isinstance(model, nn.parallel.DistributedDataParallel) else model

    if isinstance(module, EMA):
        states = {
            'model_state_dict': module.model.state_dict(),
            'shadow_state_dict': module.shadow.state_dict(),
        }
    else:
        print(type(module), " is not a supported model class for saving. Saving model directly...")
        # a plain module has no .model/.shadow: save its own weights
        states = {'model_state_dict': module.state_dict()}

    torch.save({
            'epoch': epoch,
            **states,
            'optimizer_state_dict': optimizer.state_dict(),
            }, save_dir)

def load_model_robust(model, state_dict):
    """Copy every entry of ``state_dict`` whose name and shape match ``model``.

    Unlike ``model.load_state_dict``, this ignores unknown keys, missing keys
    and shape mismatches. It prints the skipped mismatches and the
    percentage of the model's state entries that were loaded.
    """
    # Build once: state_dict() returns detached tensors sharing the model's storage.
    target = model.state_dict()
    n_tot = len(target)
    n_loaded = 0
    mismatched = []
    for name, param in state_dict.items():
        if name not in target:
            continue
        if target[name].shape != param.shape:
            # copy_ would raise here; skip so the remaining entries still load
            mismatched.append(name)
            continue
        target[name].copy_(param)
        n_loaded += 1
    if mismatched:
        print('skipped {} params with mismatched shapes: {}'.format(len(mismatched), mismatched))
    print('added {:.2f}% of params:'.format(n_loaded / n_tot * 100 if n_tot else 0.))

def load_model(model, state_dict):
    """Load ``state_dict`` into ``model``, strictly when possible.

    ``model.load_state_dict`` is tried first. If it raises ``RuntimeError``
    (e.g. missing/unexpected keys or shape mismatches), the error is printed
    and ``load_model_robust`` copies the entries it can.
    """
    try:
        model.load_state_dict(state_dict)  # fast, strict path
    except RuntimeError as err:
        # strict load failed: report why, then fall back to a name-by-name copy
        print("Could not load model normally. Using robust approach...", err, sep="\n")
        load_model_robust(model, state_dict)

def load_optimizer(optimizer, state_dict, gpu):
    """Restore ``optimizer`` from ``state_dict`` but keep its current learning rates.

    Args:
        optimizer: optimizer to restore in place.
        state_dict: an ``optimizer.state_dict()`` taken from a checkpoint.
        gpu: device that tensor entries of the per-parameter state are moved to.

    Returns:
        ``optimizer`` (the same object).
    """
    # Keep the lr each param group was constructed with, not the checkpointed one.
    new_lrs = [group['lr'] for group in optimizer.param_groups]
    optimizer.load_state_dict(state_dict)
    # set the appropriate optimizer parameters to the appropriate device
    for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
        param_group['lr'] = new_lr
        for param in param_group['params']:
            param_state = optimizer.state[param]
            for key, value in param_state.items():
                # non-tensor entries (e.g. integer counters) have no .to()
                if torch.is_tensor(value):
                    param_state[key] = value.to(gpu)

    return optimizer

def maybe_load_chkpt(model, optimizer, load_dir, gpu):
    """Restore a checkpoint written by ``save_model`` when ``load_dir`` is given.

    Args:
        model: a module exposing ``.model`` and ``.shadow`` (an ``EMA``),
            optionally wrapped in ``nn.parallel.DistributedDataParallel``.
        optimizer: optimizer to restore. A failure here is printed and ignored.
        load_dir: checkpoint file path, or None to skip loading.
        gpu: device index (an int becomes ``'cuda:<gpu>'``) or any
            ``map_location`` accepted by ``torch.load``.

    Returns:
        ``(model, optimizer, epoch)``; ``epoch`` is 0 when nothing is loaded.
    """
    if load_dir is None:
        return model, optimizer, 0

    gpu = 'cuda:{}'.format(gpu) if isinstance(gpu, int) else gpu

    checkpoint = torch.load(load_dir, map_location=gpu)
    epoch = checkpoint['epoch']
    try:
        optimizer = load_optimizer(optimizer, checkpoint['optimizer_state_dict'], gpu)
    except Exception as e:
        # best effort: a mismatched optimizer state must not block restoring the weights
        print("Couldn't load optimizer!")
        print(e)

    # reach through a DDP wrapper, mirroring save_model
    module = model.module if isinstance(model, nn.parallel.DistributedDataParallel) else model
    load_model(module.model, checkpoint['model_state_dict'])
    load_model(module.shadow, checkpoint['shadow_state_dict'])

    del checkpoint
    print("model successfully loaded!")

    return model, optimizer, epoch

def maybe_load_chkpt_old(model, optimizer, load_dir, gpu):
    if load_dir is None:
        return model, optimizer, 0

    gpu = 'cuda:{}'.format(gpu) if isinstance(gpu, int) else gpu

    state_dict = torch.load(load_dir, map_location=gpu)
    load_model(model, state_dict)
    del state_dict

    try:
        epoch = int(load_dir.split('_')[-1].split('.')[0])
    except:
        epoch = 0

    print("model successfully loaded!")

    return model, optimizer, epoch

def plt_imgs(sample, title=''):
    """Plot up to the first 10 images of ``sample`` side by side in one row.

    Args:
        sample: indexable collection of ``(C, H, W)`` tensors. Values are
            mapped with ``(v + 1) / 2``, so presumably they lie in ``[-1, 1]``.
        title: figure super-title.

    Returns:
        The matplotlib figure, or None when ``sample`` is empty.
    """
    n_samples = min(10, len(sample))
    if n_samples == 0:
        # subplots needs at least one column
        return None
    # squeeze=False keeps ``axs`` 2-D even for a single image, so axs[0][i] always works
    fig, axs = plt.subplots(nrows=1, ncols=n_samples, figsize=(n_samples * 2, 2), squeeze=False)
    plt.suptitle(title)
    for i, ax in enumerate(axs[0]):
        ax.axis('off')
        # detach() so tensors that require grad can still be converted for plotting
        img = (sample[i].detach().permute(1, 2, 0).cpu() + 1) / 2
        if img.shape[-1] == 1:
            # single-channel: pass (H, W) so the gray colormap is applied
            img = img[..., 0]
        ax.imshow(img, cmap='gray')
    return fig

def concat_elu(x, dim=1):
    """ELU applied to ``[x, -x]`` concatenated along ``dim``.

    Like concatenated ReLU (http://arxiv.org/abs/1603.05201), but with ELU.
    The output is twice as large as ``x`` along ``dim``.
    """
    both_signs = torch.cat((x, -x), dim=dim)
    return F.elu(both_signs)

def log_sum_exp(x):
    """Numerically stable ``log(sum(exp(x)))`` over the last dimension.

    The maximum along that dimension is subtracted before exponentiating, so
    large inputs do not overflow. The last dimension is reduced away.
    """
    last = x.dim() - 1
    peak = x.max(dim=last, keepdim=True)[0]
    total = torch.exp(x - peak).sum(dim=last)
    return peak.squeeze(last) + torch.log(total)


def log_prob_from_logits(x):
    """Numerically stable ``log_softmax`` over the last dimension (shape preserved)."""
    last = x.dim() - 1
    shifted = x - x.max(dim=last, keepdim=True)[0]
    return shifted - torch.log(torch.exp(shifted).sum(dim=last, keepdim=True))

def bit_normalize(x):
    """Snap values to the nearest of 256 evenly spaced levels in ``[-1, 1]``.

    Inputs are clamped to ``[-1, 1]`` first, so the result stays in that range.
    """
    levels = (x.clamp(min=-1, max=1.) + 1) * 127.5  # map to [0, 255]
    return levels.round() / 127.5 - 1

def discretized_mix_logistic_loss(x, l, per_sample=False, bin_pixels=False, discrete=True):
    """ Negative log-likelihood of ``x`` under a mixture of discretized logistics.

    Assumes the data has been rescaled to the [-1, 1] interval and has 3
    channels (the mean adjustment below indexes channels 0, 1 and 2).

    Args:
        x: targets in PyTorch ``(B, 3, H, W)`` ordering.
        l: mixture parameters in ``(B, 10 * nr_mix, H, W)`` ordering. There
            are ``nr_mix`` mixture logits, then for each channel ``nr_mix``
            means, ``nr_mix`` log-scales and ``nr_mix`` coefficients.
        per_sample: if True, return ``-log_probs`` of shape ``(B, H, W)``,
            i.e. per pixel (summed over channels), instead of the total.
        bin_pixels: only with ``discrete``. If True, first snap ``x`` to the
            8-bit grid via ``bit_normalize``.
        discrete: if True, use the binned likelihood (bin half-width 1/255).
            Otherwise use the continuous logistic log-density.

    Returns:
        ``-sum(log_probs)`` as a scalar, or ``-log_probs`` if ``per_sample``.
    """
    if discrete and bin_pixels:
        x = bit_normalize(x)
    # Pytorch ordering -> (B, H, W, C)
    x = x.permute(0, 2, 3, 1)
    l = l.permute(0, 2, 3, 1)
    xs = [int(y) for y in x.size()]
    ls = [int(y) for y in l.size()]

    # here and below: unpacking the params of the mixture of logistics
    # 10 = 1 logit + 3 channels * (mean, log-scale, coefficient)
    nr_mix = int(ls[-1] / 10)
    logit_probs = l[:, :, :, :nr_mix]
    l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3]) # 3 for mean, scale, coef
    means = l[:, :, :, :, :nr_mix]
    # log_scales = torch.max(l[:, :, :, :, nr_mix:2 * nr_mix], -7.)
    log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.)

    coeffs = l[:, :, :, :, 2 * nr_mix:3 * nr_mix].tanh()
    # here and below: getting the means and adjusting them based on preceding
    # sub-pixels (channel 1 depends on x0, channel 2 on x0 and x1)
    x = x.contiguous()
    # broadcast x against the mixture dimension: (B, H, W, C, nr_mix)
    x = x.unsqueeze(-1) + torch.zeros(xs + [nr_mix]).to(x.device)
    m2 = (means[:, :, :, 1, :] + coeffs[:, :, :, 0, :]
                * x[:, :, :, 0, :]).view(xs[0], xs[1], xs[2], 1, nr_mix)

    m3 = (means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] +
                coeffs[:, :, :, 2, :] * x[:, :, :, 1, :]).view(xs[0], xs[1], xs[2], 1, nr_mix)

    means = torch.cat((means[:, :, :, 0, :].unsqueeze(3), m2, m3), dim=3)
    centered_x = x - means

    # log density at the centre of the bin. In the discrete branch it is the
    # fallback when the bin probability is tiny; in the continuous branch it
    # is the whole log-likelihood.
    inv_stdv = torch.exp(-log_scales)
    mid_in = inv_stdv * centered_x
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)
#     log_pdf_mid = mid_in + log_scales - 2. * F.softplus(0.5 * mid_in)

    if discrete:
        # logistic CDF at the upper / lower bin edges (bins are 2/255 wide)
        plus_in = inv_stdv * (centered_x + 1. / 255.)
        cdf_plus = plus_in.sigmoid()
        min_in = inv_stdv * (centered_x - 1. / 255.)
        cdf_min = min_in.sigmoid()
        # log probability for edge case of 0 (before scaling)
        log_cdf_plus = plus_in - F.softplus(plus_in)
        # log probability for edge case of 255 (before scaling)
        log_one_minus_cdf_min = -F.softplus(min_in)
        cdf_delta = cdf_plus - cdf_min  # probability for all other cases

        # Select per sub-pixel with float masks (not indexing):
        #   x < -0.999            -> log CDF(upper edge)
        #   x >  0.999            -> log(1 - CDF(lower edge))
        #   cdf_delta > 1e-5      -> log(cdf_delta)
        #   otherwise             -> log_pdf_mid + log(bin width 2/255) = log_pdf_mid - log(127.5)
        inner_inner_cond = (cdf_delta > 1e-5).float()
        inner_inner_out  = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1. - inner_inner_cond) * (log_pdf_mid - np.log(127.5))
        inner_cond       = (x > 0.999).float()
        inner_out        = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
        cond             = (x < -0.999).float()
        log_probs        = cond * log_cdf_plus + (1. - cond) * inner_out
        # sum over channels, add mixture log-weights, then marginalise the mixture
        log_probs        = torch.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs)
        log_probs        = log_sum_exp(log_probs)
#         log_probs        = log_pdf_mid - np.log(127.5)
#         log_probs        = torch.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs)
#         log_probs        = log_sum_exp(log_probs)
    else:
        # continuous density: no bin-width term is added
        log_probs        = log_pdf_mid
        log_probs        = torch.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs)
        log_probs        = log_sum_exp(log_probs)

    # log_probs has shape (B, H, W) here
    if per_sample:
        return -log_probs
    else:
        return -torch.sum(log_probs)

# def discretized_mix_logistic_loss(x, l, per_sample=False, discrete=True):
#     """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """
#     # Pytorch ordering
#     x = x.permute(0, 2, 3, 1)
#     l = l.permute(0, 2, 3, 1)
#     xs = [int(y) for y in x.size()]
#     ls = [int(y) for y in l.size()]

#     # here and below: unpacking the params of the mixture of logistics
#     nr_mix = int(ls[-1] / 10)
#     logit_probs = l[:, :, :, :nr_mix]
#     l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3]) # 3 for mean, scale, coef
#     means = l[:, :, :, :, :nr_mix]
#     # log_scales = torch.max(l[:, :, :, :, nr_mix:2 * nr_mix], -7.)
#     log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.)

#     coeffs = l[:, :, :, :, 2 * nr_mix:3 * nr_mix].tanh()
#     # here and below: getting the means and adjusting them based on preceding
#     # sub-pixels
#     x = x.contiguous()
#     x = x.unsqueeze(-1) + Variable(torch.zeros(xs + [nr_mix]).cuda(), requires_grad=False)
#     m2 = (means[:, :, :, 1, :] + coeffs[:, :, :, 0, :]
#                 * x[:, :, :, 0, :]).view(xs[0], xs[1], xs[2], 1, nr_mix)

#     m3 = (means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] +
#                 coeffs[:, :, :, 2, :] * x[:, :, :, 1, :]).view(xs[0], xs[1], xs[2], 1, nr_mix)

#     means = torch.cat((means[:, :, :, 0, :].unsqueeze(3), m2, m3), dim=3)
#     centered_x = x - means
#     inv_stdv = torch.exp(-log_scales)
#     plus_in = inv_stdv * (centered_x + 1. / 255.)
#     cdf_plus = plus_in.sigmoid()
#     min_in = inv_stdv * (centered_x - 1. / 255.)
#     cdf_min = min_in.sigmoid()
#     # log probability for edge case of 0 (before scaling)
#     log_cdf_plus = plus_in - F.softplus(plus_in)
#     # log probability for edge case of 255 (before scaling)
#     log_one_minus_cdf_min = -F.softplus(min_in)
#     cdf_delta = cdf_plus - cdf_min  # probability for all other cases
#     mid_in = inv_stdv * centered_x
#     # log probability in the center of the bin, to be used in extreme cases
#     # (not actually used in our code)
#     log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

#     # now select the right output: left edge case, right edge case, normal
#     # case, extremely low prob case (doesn't actually happen for us)

#     # this is what we are really doing, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select()
#     # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta)))

#     # robust version, that still works if probabilities are below 1e-5 (which never happens in our code)
#     # tensorflow backpropagates through tf.select() by multiplying with zero instead of selecting: this requires use to use some ugly tricks to avoid potential NaNs
#     # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue
#     # if the probability on a sub-pixel is below 1e-5, we use an approximation
#     # based on the assumption that the log-density is constant in the bin of
#     # the observed sub-pixel value

#     inner_inner_cond = (cdf_delta > 1e-5).float()
#     inner_inner_out  = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1. - inner_inner_cond) * (log_pdf_mid - np.log(127.5))
#     inner_cond       = (x > 0.999).float()
#     inner_out        = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
#     cond             = (x < -0.999).float()
#     log_probs        = cond * log_cdf_plus + (1. - cond) * inner_out
#     log_probs        = torch.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs)
#     log_probs        = log_sum_exp(log_probs)

#     if per_sample:
#         return -log_probs
#     else:
#         return -torch.sum(log_probs)


def discretized_mix_logistic_loss_1d(x, l, nr_mix=None, per_sample=False):
    """ Negative log-likelihood of ``x`` under a mixture of discretized logistics.

    Assumes the data has been rescaled to the [-1, 1] interval (bins of
    half-width 1/255). Channels are modelled independently: there is no
    mean adjustment from preceding channels.

    Args:
        x: targets in PyTorch ``(B, C, H, W)`` ordering.
        l: parameters in ``(B, nr_mix * (1 + 2 * C), H, W)`` ordering. There
            are ``nr_mix`` logits, then for each channel ``nr_mix`` means and
            ``nr_mix`` log-scales.
        nr_mix: number of mixture components. If None it is inferred as
            ``channels(l) / 3``, which assumes ``C == 1``.
        per_sample: if True, return ``-log_probs`` of shape ``(B, H, W)``,
            i.e. per pixel, instead of the total.

    Returns:
        ``-sum(log_probs)`` as a scalar, or ``-log_probs`` if ``per_sample``.
    """
    # Pytorch ordering -> (B, H, W, C)
    x = x.permute(0, 2, 3, 1)
    l = l.permute(0, 2, 3, 1)
    xs = [int(y) for y in x.size()]
    ls = [int(y) for y in l.size()]

    # here and below: unpacking the params of the mixture of logistics
    nr_mix = int(ls[-1] / 3) if nr_mix is None else nr_mix
    logit_probs = l[:, :, :, :nr_mix]
    l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 2]) # 2 for mean, scale
    means = l[:, :, :, :, :nr_mix]
    log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.)
    # broadcast x against the mixture dimension: (B, H, W, C, nr_mix).
    # The zeros are allocated directly on x's device (no host->device copy).
    x = x.contiguous()
    x = x.unsqueeze(-1) + torch.zeros(xs + [nr_mix], device=x.device)

    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    # logistic CDF at the upper / lower bin edges
    plus_in = inv_stdv * (centered_x + 1. / 255.)
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_x - 1. / 255.)
    cdf_min = torch.sigmoid(min_in)
    # log probability for edge case of 0 (before scaling)
    log_cdf_plus = plus_in - F.softplus(plus_in)
    # log probability for edge case of 255 (before scaling)
    log_one_minus_cdf_min = -F.softplus(min_in)
    cdf_delta = cdf_plus - cdf_min  # probability for all other cases
    mid_in = inv_stdv * centered_x
    # log density at the bin centre: fallback when cdf_delta is tiny
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    # float-mask selection: left edge, right edge, normal bin, tiny-probability fallback
    inner_inner_cond = (cdf_delta > 1e-5).float()
    inner_inner_out  = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1. - inner_inner_cond) * (log_pdf_mid - np.log(127.5))
    inner_cond       = (x > 0.999).float()
    inner_out        = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
    cond             = (x < -0.999).float()
    log_probs        = cond * log_cdf_plus + (1. - cond) * inner_out
    # sum over channels, add mixture log-weights, then marginalise the mixture
    log_probs        = torch.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs)
    log_probs        = log_sum_exp(log_probs)

    if per_sample:
        return -log_probs
    else:
        return -torch.sum(log_probs)


def to_one_hot(tensor, n, fill_with=1.):
    """One-hot encode ``tensor`` along a new trailing axis of size ``n``.

    Args:
        tensor: int64 tensor of class indices in ``[0, n)``.
        n: number of classes.
        fill_with: value written at each index position; all others are 0.

    Returns:
        float32 tensor of shape ``tensor.shape + (n,)`` on ``tensor``'s device.
    """
    # Allocate on the index tensor's own device. ``.cuda()`` would use the
    # current default GPU, which breaks scatter_ when the input is on another GPU.
    one_hot = torch.zeros(tensor.size() + (n,), dtype=torch.float32, device=tensor.device)
    one_hot.scatter_(tensor.dim(), tensor.unsqueeze(-1), fill_with)
    return one_hot


def sample_from_discretized_mix_logistic_1d(l, nr_mix):
    """Draw one single-channel sample per pixel from a logistic mixture.

    Args:
        l: parameters in PyTorch ``(B, 3 * nr_mix, H, W)`` ordering: ``nr_mix``
            mixture logits, then ``nr_mix`` means, then ``nr_mix`` log-scales.
        nr_mix: number of mixture components.

    Returns:
        Tensor of shape ``(B, 1, H, W)`` with samples clipped to ``[-1, 1]``.
    """
    # Pytorch ordering -> (B, H, W, C)
    l = l.permute(0, 2, 3, 1)
    ls = [int(y) for y in l.size()]
    xs = ls[:-1] + [1]  # a single output channel

    # unpack parameters
    logit_probs = l[:, :, :, :nr_mix]
    l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 2]) # for mean, scale

    # Pick a mixture component with the Gumbel-max trick: argmax(logits + Gumbel noise).
    # Noise is drawn on l's own device (``.cuda()`` would use the default GPU instead).
    temp = torch.empty(logit_probs.size(), dtype=torch.float32, device=l.device)
    temp.uniform_(1e-5, 1. - 1e-5)
    temp = logit_probs.detach() - torch.log(- torch.log(temp))
    _, argmax = temp.max(dim=3)

    sel = F.one_hot(argmax, nr_mix).float().view(xs[:-1] + [1, nr_mix])
    # select logistic parameters of the chosen component
    means = torch.sum(l[:, :, :, :, :nr_mix] * sel, dim=4)
    log_scales = torch.clamp(torch.sum(
        l[:, :, :, :, nr_mix:2 * nr_mix] * sel, dim=4), min=-7.)
    # inverse-CDF logistic sample: mean + scale * logit(u)
    u = torch.empty(means.size(), dtype=torch.float32, device=l.device)
    u.uniform_(1e-5, 1. - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
    x0 = torch.clamp(x[:, :, :, 0], min=-1., max=1.)
    return x0.unsqueeze(1)


def sample_from_discretized_mix_logistic(l, nr_mix, bounds=1.):
    """Draw one 3-channel sample per pixel from a mixture of logistics.

    Args:
        l: parameters in PyTorch ``(B, 10 * nr_mix, H, W)`` ordering. There
            are ``nr_mix`` mixture logits, then for each of the 3 channels
            ``nr_mix`` means, ``nr_mix`` log-scales and ``nr_mix`` coefficients.
        nr_mix: number of mixture components.
        bounds: samples are clipped to ``[-bounds, bounds]``.

    Returns:
        Tensor of shape ``(B, 3, H, W)``. Samples are not rounded to 8-bit values.
    """
    # Pytorch ordering -> (B, H, W, C)
    l = l.permute(0, 2, 3, 1)
    ls = [int(y) for y in l.size()]
    xs = ls[:-1] + [3]

    # unpack parameters
    logit_probs = l[:, :, :, :nr_mix]
    l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3])
    # Pick a mixture component with the Gumbel-max trick: argmax(logits + Gumbel noise).
    # Noise is drawn on l's own device (``.cuda()`` would use the default GPU instead).
    temp = torch.empty(logit_probs.size(), dtype=torch.float32, device=l.device)
    temp.uniform_(1e-5, 1. - 1e-5)
    temp = logit_probs.detach() - torch.log(- torch.log(temp))
    _, argmax = temp.max(dim=3)

    # one-hot mask over components, broadcast across the 3 channels
    sel = F.one_hot(argmax, nr_mix).float().view(xs[:-1] + [1, nr_mix])
    # select logistic parameters
    means = torch.sum(l[:, :, :, :, :nr_mix] * sel, dim=4)
    log_scales = torch.clamp(torch.sum(
        l[:, :, :, :, nr_mix:2 * nr_mix] * sel, dim=4), min=-7.)
    coeffs = torch.sum(
        l[:, :, :, :, 2 * nr_mix:3 * nr_mix].tanh() * sel, dim=4)
    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    u = torch.empty(means.size(), dtype=torch.float32, device=l.device)
    u.uniform_(1e-5, 1. - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
    # channels are sampled in order; later ones are shifted by earlier ones
    x0 = torch.clamp(x[:, :, :, 0], min=-bounds, max=bounds)
    x1 = torch.clamp(x[:, :, :, 1] + coeffs[:, :, :, 0] * x0, min=-bounds, max=bounds)
    x2 = torch.clamp(
        x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 + coeffs[:, :, :, 2] * x1,
        min=-bounds, max=bounds)

    out = torch.stack([x0, x1, x2], dim=3)
    # put back in Pytorch ordering
    return out.permute(0, 3, 1, 2)



''' utilities for shifting the image around, efficient alternative to masking convolutions '''
def down_shift(x, pad=None):
    """Shift a tensor in PyTorch ``(N, C, H, W)`` ordering down by one row.

    The last row is dropped and a new first row is added by ``pad``. The
    default pad is ``nn.ZeroPad2d((0, 0, 1, 0))``, which adds zeros.
    """
    if pad is None:
        pad = nn.ZeroPad2d((0, 0, 1, 0))  # (left, right, top, bottom)
    return pad(x[:, :, :-1, :])


def right_shift(x, pad=None):
    """Shift a tensor in PyTorch ``(N, C, H, W)`` ordering right by one column.

    The last column is dropped and a new first column is added by ``pad``.
    The default pad is ``nn.ZeroPad2d((1, 0, 0, 0))``, which adds zeros.
    """
    if pad is None:
        pad = nn.ZeroPad2d((1, 0, 0, 0))  # (left, right, top, bottom)
    return pad(x[:, :, :, :-1])


def load_part_of_model(model, path, map_location=None):
    """Copy the entries of the state dict saved at ``path`` that ``model`` has.

    Each entry whose name exists in ``model.state_dict()`` is copied in
    place. Copy failures (e.g. shape mismatches) are printed and skipped.
    Finally the fraction of the model's state entries that were filled is
    printed.

    Args:
        model: destination module.
        path: file holding a ``{name: tensor}`` dict saved with ``torch.save``.
        map_location: forwarded to ``torch.load``. The default None keeps the
            saved devices.
    """
    params = torch.load(path, map_location=map_location)
    # Build once: state_dict() returns detached tensors sharing the model's storage.
    target = model.state_dict()
    added = 0
    for name, param in params.items():
        if name not in target:
            continue
        try:
            target[name].copy_(param)
        except Exception as e:
            # deliberate best effort: report and keep loading the other entries
            print(e)
            continue
        added += 1
    total = len(target)
    print('added %s of params:' % (added / float(total) if total else 0.))
