'''Contain other utilities, including class for training and evaluation'''
import pdb
import torch.nn as nn
import torch
import numpy as np
from sklearn.datasets import make_circles, make_moons
import pandas as pd
from PIL import Image
import os
from sklearn.preprocessing import StandardScaler
import itertools as it
# Module-wide torch device: CUDA when torch reports it as available, otherwise CPU.
# Used further down when creating tensors (e.g. the MMD bandwidths).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

'''0. For NeuralODE
From torchdiffeq
'''


def _flat_to_shape(tensor, length, shapes):
    tensor_list = []
    total = 0
    for shape in shapes:
        next_total = total + shape.numel()
        # It's important that this be view((...)), not view(...). Else when length=(), shape=() it fails.
        tensor_list.append(
            tensor[..., total:next_total].view((*length, *shape)))
        total = next_total
    return tuple(tensor_list)


class _TupleFunc(torch.nn.Module):
    def __init__(self, base_func, shapes):
        super(_TupleFunc, self).__init__()
        self.base_func = base_func
        self.shapes = shapes

    def forward(self, t, y):
        f = self.base_func(t, _flat_to_shape(y, (), self.shapes))
        return flatten_cat(f)


def _check_inputs(func, y0):
    is_tuple = not isinstance(y0, torch.Tensor)
    shapes = None
    if is_tuple:
        assert isinstance(
            y0, tuple), 'y0 must be either a torch.Tensor or a tuple'
        shapes = [y0_.shape for y0_ in y0]
        y0 = flatten_cat(y0)
        func = _TupleFunc(func, shapes)
    return shapes, func, y0


def flatten_cat(y):
    '''Flatten every tensor in the iterable `y` to 1-D and concatenate them.'''
    flat_parts = []
    for part in y:
        flat_parts.append(part.reshape(-1))
    return torch.cat(flat_parts)


'''1. Network related'''


def divergence_bf(dx, x):
    '''
        Brute-force divergence: sum over the diagonal entries of d(dx)/d(x),
        obtained with one autograd call per diagonal entry.

        x with 2 dims: tabular data (N,C); one call per column.
        x with 3 dims: graph data (N,V,C) with C channels; one call per
            channel, so the change in divergence is over ALL nodes per
            feature dimension (author note: could be changed to mirror the
            image case if results are poor; looked fine on toy two-moon).
        otherwise: image data (N,in_channel,h,w); one call per (c,i,j).

        Returns the accumulated sum viewed as (N,1).
    '''
    everything = slice(None)
    rank = len(x.shape)
    if rank == 2:
        diag_slots = ((everything, i) for i in range(x.shape[1]))
    elif rank == 3:
        diag_slots = ((everything, everything, i)
                      for i in range(x.shape[2]))
    else:
        in_c, h, w = x.shape[-3], x.shape[-2], x.shape[-1]
        diag_slots = ((everything, c, i, j)
                      for c, i, j in it.product(range(in_c), range(h), range(w)))

    sum_diag = 0.
    for slot in diag_slots:
        # Gradient of one output component w.r.t. the whole input; only the
        # matching input component (the diagonal entry) is kept.
        full_grad = torch.autograd.grad(
            dx[slot].sum(), x, create_graph=True)[0]
        sum_diag += full_grad[slot]
    return sum_diag.view(x.shape[0], 1)


def divergence_approx(dx, x, e_ls=[]):
    '''
        Stochastic trace (divergence) estimate, taken from FFJORD.

        e_ls = [e_1,e_2,...] where every e_i has the same dimension as x.
            Several e_i are used to get a more accurate estimate of the trace.
        The e_i are meant to be i.i.d. Rademacher draws, e.g.
            e = torch.randint(low=0, high=2, size=x.shape).to(x) * 2 - 1
        and are pre-stored, as explained in FFJORD:
            "To keep the dynamics deterministic within each call to the ODE solver, we can use a fixed noise vector ε for the duration of each solve without introducing bias:"

        Returns the per-sample estimate averaged over all e_i (one value per
        row of x).
    '''
    batch_size = x.shape[0]

    def _one_probe(e):
        # Vector-Jacobian product e^T (d dx / d x), then dotted with e again.
        # See https://blog.csdn.net/waitingwinter/article/details/105774720
        # for how torch.autograd.grad works.
        vjp = torch.autograd.grad(dx, x, e, create_graph=True)[0]
        return (vjp * e).view(batch_size, -1).sum(dim=1)

    estimates = [_one_probe(e) for e in e_ls]
    return torch.vstack(estimates).mean(dim=0)


def divergence_approx_e(dx, x, e_ls=[]):
    '''
        Vectorized variant of the stochastic trace estimate: all probe
        vectors in `e_ls` are stacked and pushed through a single batched
        autograd call (author note: not really much faster).

        The sum over dim=2 of the stacked result means the stacked probes are
        treated as 3-D, i.e. every probe is expected to be 2-D.
        Returns the estimate averaged over the probes.
    '''
    probes = torch.stack(e_ls, dim=0)
    batched_vjp = torch.autograd.grad(
        dx, x, probes, create_graph=True, is_grads_batched=True)[0]
    per_probe_trace = (batched_vjp * probes).sum(dim=2)
    return per_probe_trace.mean(dim=0)


def reparam_t(args, iter):
    '''
        Select the integration end time of block `iter`.

        args.T_ls has length B and stores T_1,...,T_B for B ends of the
        integrals; the chosen entry is written to args.T (nothing is returned).
    '''
    block_end_time = args.T_ls[iter]
    args.T = block_end_time


class mySequential(nn.Sequential):
    '''
        nn.Sequential variant whose layers take several positional inputs:
        the output of each layer is unpacked with `*` into the next layer.
    '''

    def forward(self, *input):
        carried = input
        for layer in self._modules.values():
            # Each layer must return something unpackable into the next call.
            carried = layer(*carried)
        return carried


def map_for_or_back(input, num_blocks, FlowNet, args, reverse=True, return_dlogpx=False, refinement=False):
    '''
        Map `input` through all `num_blocks` blocks of `FlowNet`, forward or
        backward, without tracking gradients.

        input: starting points fed to the first visited block
        num_blocks: number of blocks to traverse
        FlowNet: indexable collection of block models (FlowNet[b])
        args: run configuration; args.T is overwritten per block via reparam_t
        reverse: if True, blocks are visited from last to first and each
            block is called with reverse=True and test=False; otherwise
            first to last with reverse=False and test=True
        return_dlogpx: if True, also return the per-block log-lik changes
        refinement: forwarded to refinement_map

        Returns the stacked trajectory (torch.vstack over all blocks), and,
        if return_dlogpx, the stacked per-block means of the last dlogpx entry.
    '''
    with torch.no_grad():
        est_ls = []
        dlogpx_ls = []
        for j in range(num_blocks):
            # Block index: counts down when reversing, up otherwise
            which_b = num_blocks-1-j if reverse else j
            test = False if reverse else True
            model = FlowNet[which_b]
            model.final_viz = True  # Need as o/w, dlogpx is 0
            # Sets args.T = args.T_ls[which_b] for this block
            reparam_t(args, which_b)
            input_tmp, dlogpx_block_ls = refinement_map(
                model, input, args, test=test, reverse=reverse, refinement=refinement, store_intermediate=True)
            # Setting to False is important, as o/w during later training, loglik of earlier blocks are tracked and thus, takes longer to train
            model.final_viz = False
            # The end point of this block is the start point of the next one
            input = input_tmp[-1]
            # Because the end of previous block = start of next block, to avoid repetition
            start_idx = 0 if j == 0 else 1
            est_ls.append(input_tmp[start_idx:])
            # How much this block changes log-lik
            dlogpx_ls.append(dlogpx_block_ls[-1].mean())
        est_ls = torch.vstack(est_ls)
        dlogpx_ls = torch.vstack(dlogpx_ls)
        if return_dlogpx:
            return est_ls, dlogpx_ls
        else:
            return est_ls


def refinement_map(model, xinput, args, test=False, reverse=False, refinement=False, store_intermediate=False):
    '''
        Push `xinput` through one block `model`, optionally in C=3 sub-steps.

        refinement=False:
            a single call `model(xinput, args, test=test, reverse=reverse)`;
            its (predz, dlogpx) is returned unchanged.
        refinement=True (prune-and-refine):
            [0,T_b] is broken into C equal pieces. For every piece the model
            is called with args.T = T_b/C and
            args.num_int_pts = int(0.7 * num_int_pts), starting from the end
            point of the previous piece. The last dlogpx entry of each piece
            is accumulated. This may be used if pruning is too aggressive.

        store_intermediate (only used when refinement=True):
            True  -> predz is torch.vstack of all intermediate states (the
                     first state of pieces 2..C is dropped, since it repeats
                     the end of the previous piece). Wanted for trajectory
                     visualization, see `map_for_or_back`.
            False -> predz is the list [copy of the original xinput, final state].

        args.T and args.num_int_pts are always restored to their original
        values before returning, also when the model raises.

        Returns (predz, dlogpx); with refinement=True dlogpx is a one-element
        list holding the accumulated value.
    '''
    if not refinement:
        predz, dlogpx = model(xinput, args, test=test, reverse=reverse)
        return predz, dlogpx

    C = 3
    old_T, old_num_int_pts = args.T, args.num_int_pts
    # Out-of-place division on purpose: `args.T /= C` would modify a tensor
    # valued T in place (and with it `old_T`), making the restore a no-op.
    piece_T = old_T / C
    # i.e., how we break [0,T_b/C] to be smaller pieces
    piece_num_int_pts = int(old_num_int_pts * 0.7)
    xinput_orig = xinput.clone()
    store_mid_pred = []
    dlogpx = []
    try:
        for c in range(C):
            args.T, args.num_int_pts = piece_T, piece_num_int_pts
            predz_mid, dlogpx_mid = model(
                xinput, args, test=test, reverse=reverse)
            xinput = predz_mid[-1]
            # NOTE, somehow need this as o/w the model would have "element 0 has no grad error"
            xinput.requires_grad_(True)
            xinput.retain_grad()
            if c == 0:
                dlogpx = [dlogpx_mid[-1]]
            else:
                dlogpx[-1] += dlogpx_mid[-1]
            if store_intermediate:
                # Pieces after the first start where the previous one ended,
                # so their first state is skipped.
                store_mid_pred.append(predz_mid if c == 0 else predz_mid[1:])
    finally:
        # Restore, o/w causes error (later code relies on the original values).
        args.T, args.num_int_pts = old_T, old_num_int_pts

    if store_intermediate:
        predz = torch.vstack(store_mid_pred)
    else:
        predz = [xinput_orig, predz_mid[-1]]
    return predz, dlogpx


'''2. Data'''


def inf_train_gen(args, train=True):
    '''
        Generate a toy data set as a float tensor.

        The number of samples is args.train_data_size if `train`, else
        args.test_data_size. args.w and args.h are reset to 1 (and later
        overwritten by gen_data_from_img in the image-mask case).

        args.Xdim == 1:
            1D mixture of 8 Gaussians (component picked uniformly per sample).
        otherwise, depending on args.word:
            ''          -> sklearn make_circles, rescaled with scale_data
            'two_moon'  -> sklearn make_moons, standardized
            anything else -> points sampled from image masks. The masks are
                cached in args.image_masks; when that is empty they are
                loaded either from the image file args.word (if it contains
                'img') or, per character of args.word, from 'masks.pkl'.

        Returns X_full, or [X_full, labels] when args.color_X is set.
        Raises ValueError when args.color_X is set but the chosen data set
        has no labels (only make_circles / make_moons provide them).
    '''
    train_data_size = args.train_data_size if train else args.test_data_size
    args.w, args.h = 1, 1
    # Only the sklearn generators below produce labels.
    y_vals = None
    if args.Xdim == 1:
        rng = np.random.RandomState()
        centers = [-3.6, -2.8, -2, 0, 1, 2, 3, 4]
        stds = [0.4, 0.4, 0.4, 0.6, 0.6, 0.6, 0.6, 0.6]
        dataset = []
        for _ in range(train_data_size):
            point = rng.randn(1)
            idx = rng.randint(8)
            dataset.append(point * stds[idx] + centers[idx])
        dataset = np.array(dataset, dtype="float32")
    else:
        if args.word == '':
            dataset, y_vals = make_circles(n_samples=train_data_size, factor=0.5,
                                           noise=0.025)
            dataset = scale_data(dataset, fac=0.6)
        elif args.word == 'two_moon':
            dataset, y_vals = make_moons(noise=0.05,
                                         n_samples=train_data_size, random_state=1103)
            dataset = StandardScaler().fit_transform(dataset)
        else:
            # More complex 2D data, using masks from images
            if len(args.image_masks) == 0:
                if 'img' in args.word:
                    # Context manager so the image file handle is closed.
                    with Image.open(args.word) as img:
                        # transpose(0): presumably PIL's FLIP_LEFT_RIGHT -- TODO confirm
                        img_masks = [np.array(
                            img.rotate(180).transpose(0).convert('L'))]
                else:
                    # Read the pickle once instead of once per character.
                    # NOTE: pickle files must come from a trusted source.
                    masks = pd.read_pickle('masks.pkl')
                    img_masks = [
                        masks[masks['letter'] == char.upper()]['mask'].values[0]
                        for char in args.word
                    ]
                args.image_masks = img_masks
            # Split the sample budget evenly over the masks
            train_data_size = int(train_data_size/len(args.image_masks))
            dataset = gen_data_from_img(args, train_data_size)
    X_full = torch.from_numpy(dataset).float()
    if args.color_X:
        # Default FALSE
        if y_vals is None:
            raise ValueError(
                'args.color_X requires labels, which are only available for '
                "the make_circles (word='') and 'two_moon' data sets")
        return [X_full, torch.from_numpy(y_vals).float()]
    return X_full


def inf_train_gen_cond_gen(args, train=True):
    # TODO
    '''
        Labelled two-moon data for conditional generation.

        By default,
        X = (N, V, C) (# obs X # nodes X # feature dim )
            If no graph, V=1 (this is what is produced here)
        Y = (N, V) (# obs X # nodes, with one label per node)

        N is args.train_data_size if `train`, else args.test_data_size.
        Returns (X, Y) as float tensors.
    '''
    n_obs = args.train_data_size if train else args.test_data_size
    # TODO Later: add various conditional data AND consider graph networks
    points, labels = make_moons(noise=0.05, n_samples=n_obs)
    points = scale_data(points, fac=0.6)
    n_nodes = 1
    n_feat = points.shape[1]
    X = torch.from_numpy(points.reshape(n_obs, n_nodes, n_feat)).float()
    Y = torch.from_numpy(labels.reshape(n_obs, n_nodes)).float()
    return X, Y


def gen_data_from_img(args, train_data_size):
    '''
        From FFJORD: sample 2D points whose density follows image masks.

        For every mask in args.image_masks, darker pixels (after inverting
        with `max - img`) get a higher probability. `train_data_size` pixel
        centres on a [-4,4] x [-4,4] grid are drawn with these probabilities
        and perturbed with Gaussian noise of std (8/w/2, 8/h/2).
        The i-th mask is shifted by 8*i along x; with more than one mask the
        stacked result is rescaled with scale_data(., 0.5).

        Side effect: args.h, args.w are set to the size of the last mask.
        Returns the stacked samples as a numpy array.
    '''
    per_mask_samples = []
    for pos, mask in enumerate(args.image_masks):
        h, w = mask.shape
        grid_x, grid_y = np.meshgrid(
            np.linspace(-4, 4, w), np.linspace(-4, 4, h))
        means = np.concatenate(
            [grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)], 1)
        inverted = mask.max() - mask
        probs = inverted.reshape(-1) / inverted.sum()
        std = np.array([8 / w / 2, 8 / h / 2])
        args.h, args.w = h, w
        # Pick pixel centres by probability, then jitter them.
        picked = np.random.choice(
            int(probs.shape[0]), int(train_data_size), p=probs)
        centres = means[picked]
        subdata = np.random.randn(*centres.shape) * std + centres
        if pos > 0:
            subdata[:, 0] += pos*8  # distance between the letters
        per_mask_samples.append(subdata)
    full_data = np.vstack(per_mask_samples)
    if pos > 0:
        full_data = scale_data(full_data, 0.5)
    return full_data


def scale_data(data, fac):
    '''
        Centre `data` column-wise (axis 0) and divide by `fac` times the
        column-wise standard deviation.
    '''
    centred = data - data.mean(axis=0)
    spread = fac*data.std(axis=0)
    return centred / spread


''' 3. Losses '''
# Two-sample tests:


def get_MMD_dict(self, X_test, X_test_hat):
    '''
        Compute the MMD between X_test and X_test_hat and store the result
        (a dict keyed by kernel bandwidth alpha) in self.args.MMD_test.
        Nothing is returned.

        Two bandwidths are used: 0.5 and self.alpha_MMD. If self.alpha_MMD is
        None it is first set with the median trick, 1/(2*median^2) of the
        pairwise distances among the first min(N,200) rows of X_test.
        See http://www.datasciencecourse.org/notes/nonlinear_modeling/

        With more than 2000 samples the statistic is computed on
        int(N/2000) consecutive batches of 2000 rows and averaged (rows past
        the last full batch are not used), because MMDStatistic is memory
        intensive.
    '''
    N1 = X_test.shape[0]
    # Flatten every sample to a vector
    X_test = X_test.view(N1, -1)
    X_test_hat = X_test_hat.view(N1, -1)

    # Bandwidth from the median trick, computed once and cached on self
    if self.alpha_MMD is None:
        subset = X_test[:min(N1, 200)]
        distances = pdist(subset, subset, norm=2)
        self.alpha_MMD = 1/(2*torch.median(distances)**2)
    alphas = [torch.tensor(0.5).to(device), self.alpha_MMD]

    batch_size = 2000
    if N1 <= batch_size:
        result = two_sample_mtd(
            X_test, X_test_hat, alphas=alphas, method='MMD')
    else:
        n_batches = int(N1/batch_size)
        # One slot per batch for every bandwidth
        per_batch = {alpha.item(): torch.ones(n_batches).to(device)
                     for alpha in alphas}
        for i in range(n_batches):
            lo, hi = i*batch_size, (i+1)*batch_size
            batch_stats = two_sample_mtd(
                X_test[lo:hi], X_test_hat[lo:hi], alphas=alphas, method='MMD')
            for key in per_batch:
                per_batch[key][i] = batch_stats[key]
        result = {key: values.mean() for key, values in per_batch.items()}
    self.args.MMD_test = result


def two_sample_mtd(x, y, alphas=[1.0], method='MMD'):
    """
        Two-sample statistic between the samples x and y.

        method == 'MMD': MMDStatistic evaluated with the given alphas
            (returns one value per alpha)
        method == 'Energy': EnergyStatistic (alphas is not used)
        Any other method yields None.
    """
    n_x, n_y = x.shape[0], y.shape[0]
    if method == 'MMD':
        return MMDStatistic(n_x, n_y)(x, y, alphas)
    elif method == 'Energy':
        return EnergyStatistic(n_x, n_y)(x, y)
    return None


def pdist(sample_1, sample_2, norm=2, eps=1e-5):
    r"""Compute the matrix of all pairwise l_p distances (not squared).
    Arguments
    ---------
    sample_1 : torch.Tensor
        The first sample, should be of shape ``(n_1, d)``.
    sample_2 : torch.Tensor
        The second sample, should be of shape ``(n_2, d)``.
    norm : float
        The l_p norm to be used.
    eps : float
        Small constant added before taking the root.
    Returns
    -------
    torch.Tensor
        Matrix of shape (n_1, n_2). The [i, j]-th entry is equal to
        ``|| sample_1[i, :] - sample_2[j, :] ||_p`` up to the ``eps`` term."""
    p = float(norm)
    n_1, n_2 = sample_1.size(0), sample_2.size(0)

    if p != 2.:
        # Generic l_p: explicit (n_1, n_2, d) difference tensor.
        dim = sample_1.size(1)
        lhs = sample_1.unsqueeze(1).expand(n_1, n_2, dim)
        rhs = sample_2.unsqueeze(0).expand(n_1, n_2, dim)
        powered = torch.abs(lhs - rhs) ** p
        inner = torch.sum(powered, dim=2, keepdim=False)
        return (eps + inner) ** (1. / p)

    # l_2: use |a|^2 + |b|^2 - 2 a.b; abs() guards against tiny negative
    # values caused by round-off.
    sq_norms_1 = torch.sum(sample_1**2, dim=1, keepdim=True)
    sq_norms_2 = torch.sum(sample_2**2, dim=1, keepdim=True)
    norms = (sq_norms_1.expand(n_1, n_2)
             + sq_norms_2.transpose(0, 1).expand(n_1, n_2))
    distances_squared = norms - 2 * sample_1.mm(sample_2.t())
    return torch.sqrt(eps + torch.abs(distances_squared))


class MMDStatistic:
    r"""The MMD test of :cite:`gretton2012kernel`.
    The kernel used is equal to:
    .. math ::
        k(x, x') = \sum_{j=1}^k e^{-\alpha_j\|x - x'\|^2},
    for the :math:`\alpha_j` provided in :py:meth:`~.MMDStatistic.__call__`.

    NOTE(review): the original text calls this the *unbiased* test, but the
    weights below are 1/n^2 and the kernel sums include the diagonal terms;
    confirm which estimator is intended.

    Arguments
    ---------
    n_1: int
        The number of points in the first sample.
    n_2: int
        The number of points in the second sample."""

    def __init__(self, n_1, n_2):
        self.n_1 = n_1
        self.n_2 = n_2

        # The three constants used in the test.
        self.a00 = 1. / (n_1 * n_1)
        self.a11 = 1. / (n_2 * n_2)
        self.a01 = - 1. / (n_1 * n_2)

    def __call__(self, sample_1, sample_2, alphas, ret_matrix=False):
        r"""Evaluate the statistic, separately for every alpha.
        The kernel used for one alpha is
        .. math::
            k(x, x') = e^{-\alpha \|x - x'\|^2}.
        Arguments
        ---------
        sample_1: torch.Tensor
            The first sample, of size ``(n_1, d)``.
        sample_2: torch.Tensor
            The second sample, of size ``(n_2, d)``.
        alphas : list of :class:`float` or 0-dim tensors
            The kernel parameters.
        ret_matrix: bool
            If set, the call will also return the kernel matrix of the LAST
            alpha (``None`` when ``alphas`` is empty).
        Returns
        -------
        :class:`dict`
            Maps every alpha (as a Python float) to its test statistic.
        torch.Tensor or None
            Returned only if ``ret_matrix`` was set to true."""
        sample_12 = torch.cat((sample_1, sample_2), 0)
        distances = pdist(sample_12, sample_12, norm=2)
        # Same for every alpha, so computed once.
        distances_sq = distances**2
        mmd_dict = {}
        kernels = None
        for alpha in alphas:
            kernels = torch.exp(- alpha * distances_sq)
            k_1 = kernels[:self.n_1, :self.n_1]
            k_2 = kernels[self.n_1:, self.n_1:]
            k_12 = kernels[:self.n_1, self.n_1:]
            mmd = self.a00 * k_1.sum() + self.a11 * k_2.sum() + 2 * self.a01 * k_12.sum()
            # Plain floats have no .item(); accept both floats and tensors.
            if isinstance(alpha, torch.Tensor):
                key = alpha.item()
            else:
                key = float(alpha)
            mmd_dict[key] = mmd
        if ret_matrix:
            return mmd_dict, kernels
        return mmd_dict


class EnergyStatistic:
    r"""The energy test of :cite:`szekely2013energy`.

    Arguments
    ---------
    n_1: int
        The number of points in the first sample.
    n_2: int
        The number of points in the second sample."""

    def __init__(self, n_1, n_2):
        self.n_1 = n_1
        self.n_2 = n_2

        # Weights of the within-sample (negative) and cross-sample terms.
        self.a00 = - 1. / (n_1 * n_1)
        self.a11 = - 1. / (n_2 * n_2)
        self.a01 = 1. / (n_1 * n_2)

    def __call__(self, sample_1, sample_2, ret_matrix=False):
        r"""Evaluate the statistic from the pairwise l_2 distances.

        Arguments
        ---------
        sample_1: torch.Tensor
            The first sample, of size ``(n_1, d)``.
        sample_2: torch.Tensor
            The second sample, of size ``(n_2, d)``.
        ret_matrix: bool
            If set, the call will also return the distance matrix of the
            pooled sample.

        Returns
        -------
        torch.Tensor
            The test statistic.
        torch.Tensor
            Returned only if ``ret_matrix`` was set to true."""
        pooled = torch.cat((sample_1, sample_2), 0)
        distances = pdist(pooled, pooled, norm=2)

        # Rows/columns of the first and of the second sample.
        first = slice(None, self.n_1)
        second = slice(-self.n_2, None)
        d_1 = distances[first, first].sum()
        d_2 = distances[second, second].sum()
        d_12 = distances[first, second].sum()

        loss = 2 * self.a01 * d_12 + self.a00 * d_1 + self.a11 * d_2
        return (loss, distances) if ret_matrix else loss


def quick_l2(input):
    '''
        Half of the mean squared l2 norm over the first axis.

        A tensor with shape (N,M1,M2,...) is first flattened to
        (N,M1*M2*...); it is then treated as N vectors, the squared l2 norm
        of each is computed, and 0.5 times their mean is returned.
    '''
    if len(input.size()) > 2:
        vectors = input.view(input.shape[0], -1)
    else:
        vectors = input
    return 0.5*vectors.pow(2).sum(axis=1).mean()


'''4. Others minor ones'''


def makedirs(dirname):
    '''
        Create the directory `dirname` (including parents) if nothing exists
        at that path yet. An already existing path is left untouched.
    '''
    if not os.path.exists(dirname):
        # exist_ok=True: another process may create the directory between
        # the check above and this call; that must not raise.
        os.makedirs(dirname, exist_ok=True)


def load_saved_checkpoint(loader_args):
    '''
        Build the checkpoint file name from the run settings in `loader_args`
        (int_mtd, Xdim, reparam_type, word, netname, continuous_param, p).
        Only the name is returned; nothing is read from disk here.
    '''
    word = loader_args.word
    netname = loader_args.netname
    # Network prefix: netname without its first and its last three characters
    netpre = netname[1:-3]
    if word == '':
        netname = ''
    param_type = '_cont_param' if loader_args.continuous_param else ''
    return (f'JKO_{loader_args.int_mtd}_{netpre}_Xdim={loader_args.Xdim}'
            f'_reparam={loader_args.reparam_type}{word}{netname}{param_type}'
            f'_phase{loader_args.p}')
############
############
############
############
############
############
############
############
