from __future__ import print_function, division
import torch
from torch.autograd import Variable
# from PIL import Image
import torch.utils.data as data
import numpy as np
import os
import sys
try:
    import cPickle as pickle
except ModuleNotFoundError:
    import pickle

import torch.nn.functional as F

# Smallest positive normal float64. Kept as a module-level constant for any
# code that imports it; note that adding it to a float32/float16 tensor rounds
# to zero, which is why `w_to_std` derives its floor from the input dtype.
EPS = torch.finfo(torch.float64).tiny

def w_to_std(w, beta=1, threshold=20):
    """Map an unconstrained parameter tensor to a strictly positive std.

    Computes ``tiny + softplus(w)`` where ``tiny`` is the smallest positive
    normal number of ``w``'s own floating-point dtype. For float64 input this
    equals the module-level ``EPS``; for lower-precision input it is a floor
    that is actually representable, so the result stays > 0 even when the
    softplus term underflows to zero for very negative ``w``.

    Args:
        w: floating-point tensor of unconstrained parameters.
        beta: forwarded to ``F.softplus``.
        threshold: forwarded to ``F.softplus``.

    Returns:
        Tensor with the same shape and dtype as ``w``.
    """
    # A float64-sized epsilon would be rounded to 0 in float32/float16, so
    # pick the floor that matches the tensor being processed.
    floor = torch.finfo(w.dtype).tiny
    std_w = floor + F.softplus(w, beta=beta, threshold=threshold)
    return std_w

def sample_weights(W_mu, b_mu, W_p, b_p):
    """Draw one sample of weights (and optionally biases).

    Each parameter is sampled with the reparameterisation

        value = mean + std * noise,   noise ~ N(0, 1)

    where ``std`` is obtained from the unconstrained parameter via
    ``w_to_std`` and ``noise`` has the same size as the mean.

    Args:
        W_mu: mean of the weights.
        b_mu: mean of the biases, or ``None`` to skip bias sampling.
        W_p: unconstrained std parameter of the weights.
        b_p: unconstrained std parameter of the biases (only read when
            ``b_mu`` is not ``None``).

    Returns:
        Tuple ``(W, b)``; ``b`` is ``None`` when ``b_mu`` is ``None``.
    """
    def _standard_normal_like(ref):
        # New tensor of ref's type and size, filled in place from N(0, 1).
        return ref.data.new(ref.size()).normal_()

    # Weight noise is drawn first so the random stream is consumed in the
    # order: weights, then biases.
    noise_W = _standard_normal_like(W_mu)
    W = W_mu + w_to_std(W_p) * noise_W

    if b_mu is None:
        return W, None

    sigma_b = w_to_std(b_p)
    noise_b = _standard_normal_like(b_mu)
    return W, b_mu + sigma_b * noise_b


def load_object(filename):
    """Read and return the pickled object stored in ``filename``.

    NOTE: unpickling can execute arbitrary code; only use this on files
    that come from a trusted source.
    """
    with open(filename, 'rb') as handle:
        obj = pickle.load(handle)
    return obj


def save_object(obj, filename):
    """Pickle ``obj`` into ``filename`` using the highest pickle protocol.

    The file is opened in binary write mode, so an existing file at that
    path is overwritten.
    """
    protocol = pickle.HIGHEST_PROTOCOL
    with open(filename, 'wb') as handle:
        pickle.dump(obj, handle, protocol)


def mkdir(paths):
    """Create one or several directories, including missing parents.

    Directories that already exist are left untouched. Creation is attempted
    directly rather than guarded by a prior existence check, so another
    process creating the same directory at the same moment does not make
    this call fail.

    Args:
        paths: a single path, or a list/tuple of paths.

    Raises:
        OSError: if a directory cannot be created and does not exist
            afterwards (e.g. the path is an existing regular file).
    """
    if not isinstance(paths, (list, tuple)):
        paths = [paths]
    for path in paths:
        try:
            os.makedirs(path)
        except OSError:
            # Either the directory was already there (possibly created
            # concurrently) or creation genuinely failed; only the former
            # is acceptable.
            if not os.path.isdir(path):
                raise


# Unit labels in increasing order; each step is a factor of 1024.
suffixes = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']


def humansize(nbytes):
    """Format a byte count as a string with two decimals and a unit suffix.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit in ``suffixes`` is reached, e.g. ``1536 -> '1.50KB'``.
    """
    largest_unit = len(suffixes) - 1
    unit = 0
    size = nbytes
    while unit < largest_unit and size >= 1024:
        size /= 1024.
        unit += 1
    return '%.2f%s' % (size, suffixes[unit])


def get_num_batches(nb_samples, batch_size, roundup=True):
    """Return how many batches of ``batch_size`` cover ``nb_samples``.

    With ``roundup`` the sample count is first padded up to the next
    multiple of ``batch_size``, so a trailing partial batch is counted.
    Without it the plain quotient is returned, which may be fractional.
    The result comes from true division, so it is not truncated to an int
    here; callers that need an integer must convert it themselves.
    """
    if not roundup:
        return nb_samples / batch_size
    # Number of extra samples needed to reach a full multiple of batch_size.
    padding = -nb_samples % batch_size
    return (nb_samples + padding) / batch_size


def generate_ind_batch(nb_samples, batch_size, random=True, roundup=True):
    """Yield consecutive slices of sample indices, one slice per batch.

    Args:
        nb_samples: total number of samples.
        batch_size: number of indices per slice (the last slice may be
            shorter).
        random: if true, indices are a NumPy random permutation of
            ``nb_samples``; otherwise they are ``range(int(nb_samples))``.
        roundup: forwarded to ``get_num_batches`` to decide whether a
            trailing partial batch is produced.

    Yields:
        A slice of the index sequence for each batch.
    """
    indices = np.random.permutation(nb_samples) if random else range(int(nb_samples))
    # get_num_batches may return a non-integer value; truncate it.
    n_batches = int(get_num_batches(nb_samples, batch_size, roundup))
    for batch_no in range(n_batches):
        start = batch_no * batch_size
        yield indices[start: start + batch_size]


def to_variable(var=(), cuda=True, volatile=False):
    """Convert each element of ``var`` and return the results as a list.

    For every element, in order:
      * a NumPy array is turned into a ``torch.FloatTensor``;
      * if ``cuda`` is true and the element is not on a CUDA device yet, it
        is moved with ``.cuda()``;
      * if the element is not already a ``Variable`` it is wrapped in one,
        with ``volatile`` forwarded to the constructor.

    Args:
        var: iterable of NumPy arrays and/or tensors.
        cuda: whether to move elements to CUDA.
        volatile: passed to ``Variable`` when wrapping.

    Returns:
        List with one converted element per input element.
    """
    def _convert(item):
        if isinstance(item, np.ndarray):
            item = torch.from_numpy(item).type(torch.FloatTensor)

        if not item.is_cuda and cuda:
            item = item.cuda()

        if isinstance(item, Variable):
            return item
        return Variable(item, volatile=volatile)

    return [_convert(item) for item in var]


def cprint(color, text, **kwargs):
    """Print ``text`` wrapped in an ANSI colour escape sequence.

    Args:
        color: one-letter colour key (``a r g y b p c w``), optionally
            prefixed with ``*`` to add the ``1;`` (bold) attribute.
        text: value to print.
        **kwargs: forwarded to ``print``.

    ``sys.stdout`` is flushed after printing.
    """
    bold = color[0] == '*'
    if bold:
        # Strip the marker so the remainder is the plain colour key.
        color = color[1:]
    ansi_codes = {
        'a': '30',
        'r': '31',
        'g': '32',
        'y': '33',
        'b': '34',
        'p': '35',
        'c': '36',
        'w': '37'
    }
    attribute = '1;' if bold else ''
    print("\x1b[%s%sm%s\x1b[0m" % (attribute, ansi_codes[color], text), **kwargs)
    sys.stdout.flush()


def shuffle_in_unison_scary(a, b):
    """Shuffle ``a`` and ``b`` in place, each from the same RNG state.

    The global NumPy random state is captured once and restored before each
    shuffle, so both calls to ``np.random.shuffle`` start from identical
    state. Nothing is returned; both arguments are modified in place.
    """
    saved_state = np.random.get_state()
    for arr in (a, b):
        # Rewind to the captured state so every array sees the same draws.
        np.random.set_state(saved_state)
        np.random.shuffle(arr)


class Datafeed(data.Dataset):
    """Dataset that pairs ``x_train[i]`` with ``y_train[i]``.

    Args:
        x_train: indexable container of inputs; its length is the dataset
            length.
        y_train: indexable container of targets, indexed the same way.
        transform: optional callable applied to each input on access.
    """

    def __init__(self, x_train, y_train, transform=None):
        self.x_train, self.y_train = x_train, y_train
        self.transform = transform

    def __len__(self):
        return len(self.x_train)

    def __getitem__(self, index):
        # Only the input is transformed; the target is returned as stored.
        sample = self.x_train[index]
        if self.transform is None:
            return sample, self.y_train[index]
        return self.transform(sample), self.y_train[index]

# class DatafeedImage(data.Dataset):
#     def __init__(self, x_train, y_train, transform=None):
#         self.x_train = x_train
#         self.y_train = y_train
#         self.transform = transform

#     def __getitem__(self, index):
#         img = self.x_train[index]
#         img = Image.fromarray(np.uint8(img))
#         if self.transform is not None:
#             img = self.transform(img)
#         return img, self.y_train[index]

#     def __len__(self):
#         return len(self.x_train)


# Functions for BNNs with Gaussian output

def diagonal_gauss_loglike(x, mu, sigma):
    """Log-density of ``x`` under a Gaussian with diagonal covariance.

    Each dimension is treated as an independent 1-D Gaussian with mean
    ``mu`` and standard deviation ``sigma``; the per-dimension terms

        -0.5 * log(2 * pi) - log(sigma) - 0.5 * ((x - mu) / sigma) ** 2

    are summed over ``dim=1`` (that dimension is not kept).
    """
    half_log_two_pi = 0.5 * np.log(2 * np.pi)
    standardized = (x - mu) / sigma
    # Same left-to-right summation as: constant + log-det term + distance term.
    per_dim = -half_log_two_pi + (-torch.log(sigma)) + (-0.5 * standardized ** 2)
    return per_dim.sum(dim=1, keepdim=False)

def get_rms(mu, y, y_means, y_stds):
    """Root of the summed squared error, per sample, in un-normalised units.

    Both ``mu`` and ``y`` are mapped back with ``value * y_stds + y_means``.
    The squared differences are summed over all elements and divided by
    ``y.shape[0]`` before taking the square root.
    """
    def _unnormalise(values):
        return values * y_stds + y_means

    squared_error = (_unnormalise(mu) - _unnormalise(y)) ** 2
    return torch.sqrt(squared_error.sum() / y.shape[0])


def get_loglike(mu, sigma, y, y_means, y_stds):
    """Mean diagonal-Gaussian log-likelihood in un-normalised units.

    ``mu`` and ``y`` are mapped back with ``value * y_stds + y_means`` and
    ``sigma`` is rescaled by ``y_stds``. The values returned by
    ``diagonal_gauss_loglike`` are then averaged over ``dim=0``.
    """
    pred_mean = mu * y_stds + y_means
    target = y * y_stds + y_means
    # A standard deviation only scales; it takes no mean offset.
    pred_std = sigma * y_stds
    return diagonal_gauss_loglike(target, pred_mean, pred_std).mean(dim=0)
