import matplotlib.pyplot as plt
import numpy as np
import scipy.misc
import scipy.special

from data import maf_data as datasets
from .pdfs import *


# filesystem locations
root_output = "output/"  # trained models are saved here
root_data = "data/"  # datasets are read from here

# currently loaded dataset and its name; both stay None until load_data is called
data, data_name = None, None

# training parameters
minibatch = 100
patience = 30
monitor_every = 1
weight_decay_rate = 1e-6
a_made = 1e-3
a_flow = 1e-4


def load_data(name):
    """
    Loads a dataset into the module-level holders. Must run before any evaluation function.
    Requesting the dataset that is already loaded is a no-op.
    :param name: string, one of 'mnist', 'bsds300', 'cifar10', 'power', 'gas', 'hepmass', 'miniboone'
    :raises ValueError: if the name doesn't match any known dataset
    """
    global data, data_name

    assert isinstance(name, str), "Name must be a string"
    datasets.root = root_data

    # the requested dataset is already the one being held
    if data_name == name:
        return

    # constructors are wrapped in lambdas so that only the requested dataset gets built
    factories = {
        "mnist": lambda: datasets.MNIST(logit=True, dequantize=True),
        "bsds300": lambda: datasets.BSDS300(),
        "cifar10": lambda: datasets.CIFAR10(logit=True, flip=True, dequantize=True),
        "power": lambda: datasets.POWER(),
        "gas": lambda: datasets.GAS(),
        "hepmass": lambda: datasets.HEPMASS(),
        "miniboone": lambda: datasets.MINIBOONE(),
    }

    make_dataset = factories.get(name)
    if make_dataset is None:
        raise ValueError("Unknown dataset")

    # record the name only once the dataset has been built successfully
    data = make_dataset()
    data_name = name


def is_data_loaded():
    """
    Tells whether load_data has recorded a dataset name yet.
    :return: boolean, True once a dataset name has been set
    """
    loaded = data_name is not None
    return loaded


def _sort_samples_by_logprob(samples, lp_samples):
    """
    Orders samples from highest to lowest log probability, dropping those whose log probability is NaN.
    :param samples: array of samples, indexed by sample along the first axis
    :param lp_samples: array with one log probability per sample
    :return: the samples that have a valid log probability, in decreasing order of log probability
    """
    lp_samples = np.ravel(lp_samples)

    # the same mask must be applied to both arrays, otherwise the argsort indices
    # computed on the filtered log probs would point at the wrong sample rows
    valid = np.logical_not(np.isnan(lp_samples))
    samples = samples[valid]
    order = np.argsort(lp_samples[valid])

    return samples[order][::-1]


def _display_samples(samples):
    """
    Post-processes samples according to the loaded dataset and passes them to util.disp_imdata.
    :param samples: array of samples as produced by the model, one per row
    :raises ValueError: if the loaded dataset is not 'mnist', 'bsds300' or 'cifar10'
    """
    # NOTE(review): 'util' is not imported by name in this file; presumably it is
    # re-exported through 'from .pdfs import *' -- confirm

    if data_name == "mnist":
        samples = (util.logistic(samples) - data.alpha) / (1 - 2 * data.alpha)

    elif data_name == "bsds300":
        # append one extra column equal to minus the sum of the others
        samples = np.hstack([samples, -np.sum(samples, axis=1)[:, np.newaxis]])

    elif data_name == "cifar10":
        samples = (util.logistic(samples) - data.alpha) / (1 - 2 * data.alpha)

        # split the flat vector into three equal parts and stack them along a new last axis
        D = data.n_dims // 3
        r = samples[:, :D]
        g = samples[:, D : 2 * D]
        b = samples[:, 2 * D :]
        samples = np.stack([r, g, b], axis=2)

    else:
        raise ValueError("non-image dataset")

    util.disp_imdata(samples, data.image_size, [5, 8])


def evaluate(
    predict_density_fun,
    split,
    n_samples=None,
    generate_samples_fun=None,
    is_conditional=False,
):
    """
    Evaluate a trained model: prints log probability statistics (and classification accuracy for
    conditional models) on a data split, and optionally displays samples generated from the model.
    Reported errors are 2 standard errors of the mean.
    :param predict_density_fun: callable returning one log probability per datapoint. It is called with the
        inputs x, or with the list [y, x] of labels and inputs if is_conditional is True
    :param split: string, the data split to evaluate on. Must be 'trn', 'val' or 'tst'
    :param n_samples: number of samples to generate from the model, or None for no samples
    :param generate_samples_fun: callable generating samples from the model. It is called as f(n_samples), or as
        f(y, n_samples) with a one-hot label y if is_conditional is True. Required if n_samples is not None
    :param is_conditional: bool, whether the model is conditioned on labels
    :raises ValueError: if samples are requested without generate_samples_fun, if the split is invalid,
        or if samples are requested for a non-image dataset
    """

    # fail early instead of crashing halfway through with "'NoneType' object is not callable"
    if n_samples is not None and generate_samples_fun is None:
        raise ValueError("generate_samples_fun must be given when n_samples is not None")

    assert is_data_loaded(), "Dataset hasn't been loaded"

    # choose which data split to evaluate on
    data_split = getattr(data, split, None)
    if data_split is None:
        raise ValueError("Invalid data split")

    if is_conditional:

        # calculate log probability
        logprobs = predict_density_fun([data_split.y, data_split.x])
        print(
            "logprob(x|y) = {0:.2f} +/- {1:.2f}".format(
                logprobs.mean(), 2 * logprobs.std() / np.sqrt(data_split.N)
            )
        )

        # classify the split: score every datapoint under each label and pick the best one
        logprobs = np.empty([data_split.N, data.n_labels])
        for i in range(data.n_labels):
            y = np.zeros([data_split.N, data.n_labels])
            y[:, i] = 1
            logprobs[:, i] = predict_density_fun([y, data_split.x])
        predict_label = np.argmax(logprobs, axis=1)
        accuracy = (predict_label == data_split.labels).astype(float)

        # average the per-label densities with equal weights, working in log space
        logprobs = scipy.special.logsumexp(logprobs, axis=1) - np.log(logprobs.shape[1])
        print(
            "logprob(x) = {0:.2f} +/- {1:.2f}".format(
                logprobs.mean(), 2 * logprobs.std() / np.sqrt(data_split.N)
            )
        )
        print(
            "classification accuracy = {0:.2%} +/- {1:.2%}".format(
                accuracy.mean(), 2 * accuracy.std() / np.sqrt(data_split.N)
            )
        )

        # generate data conditioned on label
        if n_samples is not None:
            for i in range(data.n_labels):

                # generate samples and sort according to log prob
                y = np.zeros(data.n_labels)
                y[i] = 1
                samples = generate_samples_fun(y, n_samples)
                lp_samples = predict_density_fun([np.tile(y, [n_samples, 1]), samples])

                _display_samples(_sort_samples_by_logprob(samples, lp_samples))

    else:

        # calculate average log probability
        logprobs = predict_density_fun(data_split.x)
        print(
            "logprob(x) = {0:.2f} +/- {1:.2f}".format(
                logprobs.mean(), 2 * logprobs.std() / np.sqrt(data_split.N)
            )
        )

        # generate data
        if n_samples is not None:

            # generate samples and sort according to log prob
            samples = generate_samples_fun(n_samples)
            lp_samples = predict_density_fun(samples)

            _display_samples(_sort_samples_by_logprob(samples, lp_samples))

    plt.show()


def evaluate_logprob(
    predict_density_fun,
    split,
    is_conditional,
    use_image_space=False,
    return_avg=True,
    batch=2000,
):
    """
    Evaluate a trained model only in terms of log probability.
    :param predict_density_fun: callable returning one log probability per datapoint. It is called with a batch
        of inputs x, or with the list [y, x] of one-hot labels and inputs if is_conditional is True
    :param split: string, the data split to evaluate on. Must be 'trn', 'val' or 'tst'
    :param is_conditional: bool, whether the model is conditioned on labels. If True, the density of x is
        obtained by averaging the conditional densities over all labels with equal weights
    :param use_image_space: bool, whether to report log probability in [0, 1] image space (only for cifar and mnist)
    :param return_avg: bool, whether to return average log prob with std error, or all log probs
    :param batch: batch size to use for computing log probability, must be positive
    :return: average log probability & standard error, or all log probs
    :raises ValueError: if batch is not positive or the split is invalid
    """

    # a non-positive batch size would never advance through the data
    if batch < 1:
        raise ValueError("batch must be a positive integer")

    assert is_data_loaded(), "Dataset hasn't been loaded"

    # choose which data split to evaluate on
    data_split = getattr(data, split, None)
    if data_split is None:
        raise ValueError("Invalid data split")

    # process data in batches to make sure they fit in memory
    batch_starts = range(0, data_split.N, batch)

    if is_conditional:

        logprobs = np.empty([data_split.N, data.n_labels])

        for i in range(data.n_labels):

            # create one-hot labels for class i
            y = np.zeros([data_split.N, data.n_labels])
            y[:, i] = 1

            for start in batch_starts:
                stop = start + batch
                logprobs[start:stop, i] = predict_density_fun(
                    [y[start:stop], data_split.x[start:stop]]
                )

        # average the per-label densities with equal weights, working in log space
        logprobs = scipy.special.logsumexp(logprobs, axis=1) - np.log(logprobs.shape[1])

    else:

        logprobs = np.empty(data_split.N)

        for start in batch_starts:
            stop = start + batch
            logprobs[start:stop] = predict_density_fun(data_split.x[start:stop])

    if use_image_space:
        assert data_name in ["mnist", "cifar10"]

        # add the correction term that converts the log probabilities to image space
        z = util.logistic(data_split.x)
        logprobs += data.n_dims * np.log(1 - 2 * data.alpha) - np.sum(
            np.log(z) + np.log(1 - z), axis=1
        )

    if return_avg:
        avg_logprob = logprobs.mean()
        std_err = logprobs.std() / np.sqrt(data_split.N)
        return avg_logprob, std_err

    else:
        return logprobs


def fit_and_evaluate_gaussian(
    split, cond=False, use_image_space=False, return_avg=True
):
    """
    Fits a gaussian baseline on the train data and scores the requested split with it.
    :param split: the data split to evaluate on, must be 'trn', 'val', or 'tst'
    :param cond: boolean, if True one gaussian is fitted per label and they are combined with equal weights
    :param use_image_space: bool, whether to report log probability in [0, 1] image space (only for cifar and mnist)
    :param return_avg: bool, whether to return average log prob with std error, or all log probs
    :return: average log probability & standard error, or all log probs
    """

    assert is_data_loaded(), "Dataset hasn't been loaded"

    # pick the split to score
    eval_split = getattr(data, split, None)
    if eval_split is None:
        raise ValueError("Invalid data split")

    if not cond:
        # a single gaussian over the whole train set
        model = pdfs.fit_gaussian(data.trn.x)

    else:
        # one gaussian per label, mixed with a uniform prior over the labels
        per_label = [
            pdfs.fit_gaussian(data.trn.x[data.trn.labels == label])
            for label in range(data.n_labels)
        ]
        uniform_prior = np.ones(data.n_labels, dtype=float) / data.n_labels
        model = pdfs.MoG(uniform_prior, xs=per_label)

    logprobs = model.eval(eval_split.x)

    if use_image_space:
        assert data_name in ["mnist", "cifar10"]
        z = util.logistic(eval_split.x)
        log_z_terms = np.sum(np.log(z) + np.log(1 - z), axis=1)
        logprobs += data.n_dims * np.log(1 - 2 * data.alpha) - log_z_terms

    if not return_avg:
        return logprobs

    avg_logprob = logprobs.mean()
    std_err = logprobs.std() / np.sqrt(eval_split.N)
    return avg_logprob, std_err
