#%%
from utils.pdfs import *

# region head
import sys
import itertools
import numpy as np
from utils.maf_utils import load_data
import utils
from utils import pdfs
import data.maf_data
from data import maf_data as datasets
import pickle as pkl
from time import time
import os

# set paths
root_output = "output/"  # where to save trained models
root_data = "data/"  # where the datasets are; load_data() copies this into datasets.root

# holders for the datasets
# Module-level state filled in by load_data(); None means no dataset has been loaded yet.
data = None
data_name = None

# parameters for training
# NOTE(review): none of these are read in the visible part of this file -- confirm they are still used elsewhere.
minibatch = 100
patience = 30
monitor_every = 1
weight_decay_rate = 1.0e-6
a_made = 1.0e-3  # presumably the step size used for MADE models -- TODO confirm
a_flow = 1.0e-4  # presumably the step size used for flow models -- TODO confirm


def load_data(name):
    """
    Loads the named dataset into the module-level holders ``data`` and ``data_name``.
    Has to be called before anything else. Asking again for the dataset that is
    already loaded returns immediately without reloading it.
    :param name: string, the dataset's name
    :raises ValueError: if the name does not match any known dataset
    """
    # NOTE(review): this module defines load_data a second time further down with the
    # same body; that later definition is the one bound to the name after import.
    global data, data_name

    assert isinstance(name, str), "Name must be a string"
    datasets.root = root_data

    # the requested dataset is already the active one
    if data_name == name:
        return

    # (dataset name, zero-argument builder) pairs; the lambdas postpone both the
    # attribute lookup on `datasets` and the construction until a name matches
    builders = (
        ("mnist", lambda: datasets.MNIST(logit=True, dequantize=True)),
        ("bsds300", lambda: datasets.BSDS300()),
        ("cifar10", lambda: datasets.CIFAR10(logit=True, flip=True, dequantize=True)),
        ("power", lambda: datasets.POWER()),
        ("gas", lambda: datasets.GAS()),
        ("hepmass", lambda: datasets.HEPMASS()),
        ("miniboone", lambda: datasets.MINIBOONE()),
    )

    for known_name, build in builders:
        if name == known_name:
            loaded = build()
            break
    else:
        raise ValueError("Unknown dataset")

    # only record the name once the dataset object was built successfully
    data = loaded
    data_name = name


def is_data_loaded():
    """
    Tells whether load_data has already stored a dataset in this module.
    :return: boolean, True once ``data_name`` has been set to something other than None
    """
    loaded = data_name is not None
    return loaded


def load_data(name):
    """
    Loads the named dataset into the module-level holders ``data`` and ``data_name``.
    Has to be called before anything else. Asking again for the dataset that is
    already loaded returns immediately without reloading it.
    :param name: string, the dataset's name
    :raises ValueError: if the name does not match any known dataset
    """
    # NOTE(review): this rebinds the name load_data, replacing both the import from
    # utils.maf_utils and the earlier definition with the same body in this module.
    global data, data_name

    assert isinstance(name, str), "Name must be a string"
    datasets.root = root_data

    # the requested dataset is already the active one
    if data_name == name:
        return

    # (dataset name, zero-argument factory) pairs; the lambdas postpone both the
    # attribute lookup on `datasets` and the construction until a name matches
    factories = (
        ("mnist", lambda: datasets.MNIST(logit=True, dequantize=True)),
        ("bsds300", lambda: datasets.BSDS300()),
        ("cifar10", lambda: datasets.CIFAR10(logit=True, flip=True, dequantize=True)),
        ("power", lambda: datasets.POWER()),
        ("gas", lambda: datasets.GAS()),
        ("hepmass", lambda: datasets.HEPMASS()),
        ("miniboone", lambda: datasets.MINIBOONE()),
    )

    make_dataset = next(
        (factory for known_name, factory in factories if name == known_name),
        None,
    )
    if make_dataset is None:
        raise ValueError("Unknown dataset")

    # build first, then record the name, so a failed build leaves the old state intact
    data = make_dataset()
    data_name = name


# endregion


def return_train_test_maf_data(data_list):
    """
    Loads a dataset through load_data and returns its train, validation and test inputs.
    :param data_list: string, the dataset's name as accepted by load_data
        (despite the parameter name, load_data requires a single string, not a list)
    :return: tuple (data.trn.x, data.val.x, data.tst.x)
    :raises ValueError: propagated from load_data when the name is unknown
    """

    # %%
    # region load params
    load_data(data_list)
    # endregion

    # %%
    # region fit gaussian model
    # NOTE(review): this helper is defined on every call but is never invoked or
    # returned by the enclosing function; it is kept only for reference/interactive use.
    def fit_and_evaluate_gaussian(
        split, cond=False, use_image_space=False, return_avg=True
    ):
        """
        Fits a gaussian to the train data and evaluates it on the given split.
        :param split: the data split to evaluate on, must be 'trn', 'val', or 'tst'
        :param cond: boolean, whether to fit a gaussian per conditional
        :param use_image_space: bool, whether to report log probability in [0, 1] image space (only for cifar and mnist)
        :param return_avg: bool, whether to return average log prob with std error, or all log probs
        :return: average log probability & standard error, or all log probs
        """

        assert is_data_loaded(), "Dataset hasn't been loaded"

        # choose which data split to evaluate on
        data_split = getattr(data, split, None)
        if data_split is None:
            raise ValueError("Invalid data split")

        if cond:
            # one gaussian per label, combined into a mixture with a uniform prior
            comps = [
                pdfs.fit_gaussian(data.trn.x[data.trn.labels == i])
                for i in range(data.n_labels)
            ]
            prior = np.ones(data.n_labels, dtype=float) / data.n_labels
            model = pdfs.MoG(prior, xs=comps)

        else:
            model = pdfs.fit_gaussian(data.trn.x)

        logprobs = model.eval(data_split.x)

        if use_image_space:
            # the loaded dataset's name is tracked by this module's own `data_name`;
            # the previous lookup went through `utils.af_utils`, which this file never imports
            assert data_name in ["mnist", "cifar10"], (
                "Image-space log probability is only available for mnist and cifar10"
            )
            z = utils.logistic(data_split.x)
            # presumably the change-of-variables correction for the logit
            # preprocessing applied by the dataset classes -- TODO confirm
            logprobs += data.n_dims * np.log(1 - 2 * data.alpha) - np.sum(
                np.log(z) + np.log(1 - z), axis=1
            )

        if return_avg:
            avg_logprob = logprobs.mean()
            std_err = logprobs.std() / np.sqrt(data_split.N)
            return avg_logprob, std_err

        return logprobs

    # endregion

    return data.trn.x, data.val.x, data.tst.x
