#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
"""

import sys
from pathlib import Path
sys.path.append(str(Path('.').absolute().parent))

from sklearn.neighbors import kneighbors_graph
from sklearn.utils.graph_shortest_path import graph_shortest_path

from src.gwgan.model_mlp import weights_init_generator
from src.sse_sgw_utils import sgw_gpu, sse_gpu, distrib_sse, min_sse, distributional_min_sse
from src.risgw import risgw_gpu


from ot.gromov import gromov_wasserstein
import ot 
from scipy.spatial.distance import cdist

import torch  
import numpy as np

import itertools
from functools import partial
import numbers

from joblib import Parallel, effective_n_jobs
from sklearn.utils.fixes import delayed
from sklearn.utils import gen_even_slices


def my_pairwise_distances(X, Y=None, metric="metric_gw", n_jobs=None, **kwds):
    """Compute the matrix of pairwise ``metric`` values between X and Y.

    Parameters
    ----------
    X : indexable collection exposing ``.shape[0]``
        ``X[i]`` is passed to ``metric`` as its first argument.
    Y : indexable collection exposing ``.shape[0]``, optional
        ``Y[j]`` is passed to ``metric`` as its second argument. When
        ``None``, the samples of ``X`` are compared with each other.
    metric : callable or str, default="metric_gw"
        Either a callable ``metric(x, y, **kwds)`` or the name of one of the
        module-level ``metric_*`` functions of this file.
    n_jobs : int or None
        Forwarded to ``_my_parallel_pairwise``, which decides between the
        serial and the threaded code path.
    **kwds
        Extra keyword arguments forwarded to ``metric`` on every call.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(X.shape[0], Y.shape[0])`` holding the metric value
        of every (X sample, Y sample) pair.

    Raises
    ------
    ValueError
        If ``metric`` is a string that does not name a ``metric_*`` callable
        defined in this module.
    TypeError
        If ``metric`` is neither a string nor a callable.
    """
    if isinstance(metric, str):
        # The default is a metric *name*; it has to be resolved to the actual
        # function, otherwise the pairwise loop would try to call a string.
        resolved = globals().get(metric) if metric.startswith("metric_") else None
        if not callable(resolved):
            raise ValueError(
                "Unknown metric %r: expected a callable or the name of a "
                "metric_* function of this module." % (metric,)
            )
        metric = resolved
    elif not callable(metric):
        raise TypeError(
            "metric must be a callable or a metric name, got %s" % type(metric)
        )

    # kwds are bound here and passed again below; the later ones simply
    # override the identical bound values, so the metric receives them once.
    func = partial(_my_pairwise_callable, metric=metric, **kwds)
    return _my_parallel_pairwise(X, Y, func, n_jobs, **kwds)

#------------------------------------------------

def metric_gw(X0, X1, distance="geodesic"):
    """Gromov-Wasserstein discrepancy between two point clouds.

    Each cloud is summarised by its intra-cloud distance matrix, normalised
    by its largest entry, and both clouds are given uniform weights before
    calling POT's ``gromov_wasserstein`` with the ``'square_loss'``.

    Parameters
    ----------
    X0, X1 : array-like
        The two point clouds; passed to ``geodesic_distance`` or
        ``scipy.spatial.distance.cdist`` depending on ``distance``.
    distance : {"geodesic", "euclidean"}, default="geodesic"
        How the intra-cloud distance matrices are built.

    Returns
    -------
    The ``'gw_dist'`` entry of the log returned by ``gromov_wasserstein``.

    Raises
    ------
    ValueError
        If ``distance`` is not one of the supported values.
    """
    if distance == "geodesic":
        C0, C1 = geodesic_distance(X0), geodesic_distance(X1)
    elif distance == "euclidean":
        C0, C1 = cdist(X0, X0), cdist(X1, X1)
    else:
        # Without this branch C0/C1 would be unbound and the failure would
        # surface as an unrelated UnboundLocalError further down.
        raise ValueError(
            "Unknown distance %r: expected 'geodesic' or 'euclidean'."
            % (distance,)
        )

    # Rescale each matrix by its largest entry.
    # NOTE(review): a matrix whose maximum is 0 (e.g. a single point) yields
    # NaNs here -- confirm whether such inputs can occur.
    C0 /= C0.max()
    C1 /= C1.max()

    # Uniform weights over the points of each cloud.
    p0 = ot.unif(C0.shape[0])
    p1 = ot.unif(C1.shape[0])

    _, log0 = gromov_wasserstein(C0, C1, p0, p1, 'square_loss', verbose=False, log=True)

    return log0['gw_dist']

def metric_min_sse(X0, X1, s_latent2orig_net, t_latent2orig_net, opt_s, opt_t,
                   dim_latent=10, nproj_dist=10, max_iter=50):
    """Evaluate ``min_sse`` on two NumPy point clouds.

    Both inputs are cast to float32 and wrapped as torch tensors, then
    ``weights_init_generator`` is applied to the two networks (presumably to
    restart them from fresh weights for every pair -- confirm in
    ``src.gwgan.model_mlp``) before ``min_sse`` is called.

    Parameters
    ----------
    X0, X1 : numpy.ndarray
        The two point clouds.
    s_latent2orig_net, t_latent2orig_net
        Networks forwarded to ``min_sse``; they are modified in place by the
        ``.apply(weights_init_generator)`` calls.
    opt_s, opt_t
        Forwarded unchanged to ``min_sse``.
    dim_latent, nproj_dist, max_iter : int
        Forwarded unchanged to ``min_sse``.

    Returns
    -------
    The result of ``min_sse`` converted with ``.numpy()``.
    """
    source, target = (
        torch.from_numpy(cloud.astype(np.float32)) for cloud in (X0, X1)
    )

    for net in (s_latent2orig_net, t_latent2orig_net):
        net.apply(weights_init_generator)

    result = min_sse(source, target, s_latent2orig_net, t_latent2orig_net,
                     opt_s, opt_t, dim_latent, nproj_dist, max_iter)
    return result.numpy()

def metric_distrib_sse(X0, X1, s_net, t_net, opt_s, opt_t, nproj_dist=10, max_iter=50):
    """Evaluate ``distrib_sse`` on two NumPy point clouds.

    The inputs are cast to float32 torch tensors and
    ``weights_init_generator`` is applied to ``s_net`` and ``t_net`` before
    the call (presumably a re-initialisation -- confirm in
    ``src.gwgan.model_mlp``).

    Parameters
    ----------
    X0, X1 : numpy.ndarray
        The two point clouds.
    s_net, t_net
        Networks forwarded to ``distrib_sse``; modified in place by the
        ``.apply(weights_init_generator)`` calls.
    opt_s, opt_t
        Forwarded unchanged to ``distrib_sse``.
    nproj_dist, max_iter : int
        Forwarded unchanged to ``distrib_sse``.

    Returns
    -------
    The result of ``distrib_sse`` converted with ``.numpy()``.
    """
    source = torch.from_numpy(X0.astype(np.float32))
    target = torch.from_numpy(X1.astype(np.float32))

    for net in (s_net, t_net):
        net.apply(weights_init_generator)

    value = distrib_sse(source, target, s_net, t_net, opt_s, opt_t,
                        nproj_dist, max_iter)
    return value.numpy()

def metric_distrib_min_sse(X0, X1, transf_net, s_latent2orig_net, t_latent2orig_net,
                           opt_trannet, opt_s, opt_t, dim_latent, nproj_dist=10,
                           num_epochs=50, num_sup_iter=10, num_inf_iter=10):
    """Evaluate ``distributional_min_sse`` on two NumPy point clouds.

    The inputs are cast to float32 torch tensors and
    ``weights_init_generator`` is applied to the three networks (source,
    target, then transformation network) before the call. This is presumably
    a re-initialisation for each pair -- confirm in ``src.gwgan.model_mlp``.

    Parameters
    ----------
    X0, X1 : numpy.ndarray
        The two point clouds.
    transf_net, s_latent2orig_net, t_latent2orig_net
        Networks forwarded to ``distributional_min_sse``; modified in place
        by the ``.apply(weights_init_generator)`` calls.
    opt_trannet, opt_s, opt_t
        Forwarded unchanged to ``distributional_min_sse``.
    dim_latent, nproj_dist, num_epochs, num_sup_iter, num_inf_iter : int
        Forwarded unchanged to ``distributional_min_sse``.

    Returns
    -------
    The result of ``distributional_min_sse`` converted with ``.numpy()``.
    """
    source, target = (
        torch.from_numpy(cloud.astype(np.float32)) for cloud in (X0, X1)
    )

    # Same application order as before: source, target, then transformation.
    for net in (s_latent2orig_net, t_latent2orig_net, transf_net):
        net.apply(weights_init_generator)

    outcome = distributional_min_sse(
        source, target, transf_net, s_latent2orig_net, t_latent2orig_net,
        opt_trannet, opt_s, opt_t,
        dim_latent, nproj_dist, num_epochs, num_sup_iter, num_inf_iter,
    )
    return outcome.numpy()

def metric_sse(X0, X1, nproj=1000):
    """Evaluate ``sse_gpu`` on two NumPy point clouds, on the CPU.

    Parameters
    ----------
    X0, X1 : numpy.ndarray
        The two point clouds; cast to float32 torch tensors before the call.
    nproj : int, default=1000
        Forwarded to ``sse_gpu`` as ``nproj``.

    Returns
    -------
    The result of ``sse_gpu`` converted with ``.numpy()``.
    """
    source, target = (
        torch.from_numpy(cloud.astype(np.float32)) for cloud in (X0, X1)
    )
    value = sse_gpu(source, target, device='cpu', nproj=nproj)
    return value.numpy()

def metric_sgw(X0, X1, nproj=1000):
    """Evaluate ``sgw_gpu`` on two NumPy point clouds, on the CPU.

    Parameters
    ----------
    X0, X1 : numpy.ndarray
        The two point clouds; cast to float32 torch tensors before the call.
    nproj : int, default=1000
        Forwarded to ``sgw_gpu`` as ``nproj``.

    Returns
    -------
    The result of ``sgw_gpu`` converted with ``.numpy()``.
    """
    as_tensor = [torch.from_numpy(cloud.astype(np.float32)) for cloud in (X0, X1)]
    value = sgw_gpu(as_tensor[0], as_tensor[1], device='cpu', nproj=nproj)
    return value.numpy()


def metric_risgw(X0, X1, nproj=1000, lr=0.01, max_iter=500):
    """Evaluate ``risgw_gpu`` on two NumPy point clouds, on the CPU.

    Parameters
    ----------
    X0, X1 : numpy.ndarray
        The two point clouds; cast to float32 torch tensors before the call.
    nproj : int, default=1000
    lr : float, default=0.01
    max_iter : int, default=500
        All three are forwarded to ``risgw_gpu`` under the same names.

    Returns
    -------
    Whatever ``risgw_gpu`` returns; unlike the sibling wrappers, no
    ``.numpy()`` conversion is applied here.
    """
    source = torch.from_numpy(X0.astype(np.float32))
    target = torch.from_numpy(X1.astype(np.float32))

    options = dict(device='cpu', nproj=nproj, lr=lr, max_iter=max_iter)
    return risgw_gpu(source, target, **options)

#------------------------------------------------
def geodesic_distance(X, n_neighbors=3, metric='minkowski', p=2,
                      neighbors_algorithm='auto', path_method='auto', n_jobs=1):
    """Geodesic distance matrix of X along its k-nearest-neighbours graph.

    A distance-weighted k-neighbours graph is built on the points of X and
    the shortest-path length between every pair of points in that graph
    (treated as undirected) is returned.

    Parameters
    ----------
    X : array-like
        The points; passed to ``sklearn.neighbors.kneighbors_graph``.
    n_neighbors : int, default=3
        Number of neighbours of each point in the graph.
    metric : str, default='minkowski'
        Metric used to build the neighbours graph.
    p : int or float, default=2
        Power parameter forwarded to ``kneighbors_graph``. It used to be
        silently replaced by a hard-coded 2.
    neighbors_algorithm : str, default='auto'
        Currently unused; kept for backward compatibility of the signature.
    path_method : str, default='auto'
        Forwarded to ``graph_shortest_path`` as ``method``.
    n_jobs : int, default=1
        Forwarded to ``kneighbors_graph``.

    Returns
    -------
    The matrix returned by ``graph_shortest_path``.
    """
    # Weighted k-neighbours graph, densified because it is handed over as a
    # plain array to the shortest-path routine.
    kng = kneighbors_graph(X, n_neighbors,
                           metric=metric, p=p,
                           mode='distance', n_jobs=n_jobs).toarray()

    # Shortest-path lengths in the graph, i.e. the geodesic distances.
    dist_matrix = graph_shortest_path(kng, method=path_method, directed=False)

    return dist_matrix

#------------------------------------------------
def _num_samples(x):
    """Return the number of samples in the array-like ``x``.

    The size of the first axis is used when ``x`` exposes a ``shape`` whose
    first entry is an integer; otherwise ``len(x)`` is returned.

    Raises
    ------
    TypeError
        If ``x`` has a callable ``fit`` attribute, is neither sized, shaped
        nor convertible through ``__array__``, is a 0-dimensional array, or
        if ``len(x)`` itself raises ``TypeError``.
    """
    type_error_text = 'Expected sequence or array-like, got %s' % type(x)

    # Objects with a callable ``fit`` are rejected outright, even when they
    # define a length (e.g. an ensemble's length is not a sample count).
    if hasattr(x, 'fit') and callable(x.fit):
        raise TypeError(type_error_text)

    if not (hasattr(x, '__len__') or hasattr(x, 'shape')):
        # Last resort: let NumPy materialise objects exposing ``__array__``.
        if not hasattr(x, '__array__'):
            raise TypeError(type_error_text)
        x = np.asarray(x)

    shape = getattr(x, 'shape', None)
    if shape is not None:
        if len(shape) == 0:
            raise TypeError("Singleton array %r cannot be considered"
                            " a valid collection." % x)
        # Some containers (e.g. Dask dataframes) may report a non-numeric
        # first dimension; in that case fall through to ``len``.
        first_dim = shape[0]
        if isinstance(first_dim, numbers.Integral):
            return first_dim

    try:
        return len(x)
    except TypeError as len_failure:
        raise TypeError(type_error_text) from len_failure

        
def _dist_wrapper(dist_func, dist_matrix, slice_, *args, **kwargs):
    """Evaluate ``dist_func`` and store the result in a column block.

    ``dist_func(*args, **kwargs)`` is computed first and then written in
    place into ``dist_matrix[:, slice_]``; nothing is returned.
    """
    column_block = dist_func(*args, **kwargs)
    dist_matrix[:, slice_] = column_block

    
def _my_parallel_pairwise(X, Y, func, n_jobs, **kwds):
    """Compute ``func(X, Y)`` directly, or column block by column block.

    When joblib resolves ``n_jobs`` to a single worker, ``func(X, Y, **kwds)``
    is returned as is. Otherwise the samples of Y are split into one even
    slice per worker and each worker writes ``func(X, Y[slice], **kwds)``
    into its own columns of a shared result matrix.

    Parameters
    ----------
    X, Y : indexable collections exposing ``.shape[0]``
        ``Y`` defaults to ``X`` when ``None``.
    func : callable
        Called as ``func(X, Y_part, **kwds)``; must return the block of
        pairwise values for the given samples.
    n_jobs : int or None
        Number of jobs, interpreted by joblib.
    **kwds
        Forwarded to every ``func`` call.

    Returns
    -------
    The pairwise matrix; in the threaded path a Fortran-ordered float array
    of shape ``(X.shape[0], Y.shape[0])``.
    """
    if Y is None:
        Y = X

    n_workers = effective_n_jobs(n_jobs)
    if n_workers == 1:
        return func(X, Y, **kwds)

    result = np.empty((X.shape[0], Y.shape[0]), dtype='float', order='F')
    write_block = delayed(_dist_wrapper)
    tasks = (
        write_block(func, result, cols, X, Y[cols], **kwds)
        for cols in gen_even_slices(_num_samples(Y), n_workers)
    )
    # Threads share ``result`` directly, so no data has to be copied between
    # workers and each task only fills its own columns.
    Parallel(backend="threading", n_jobs=n_jobs)(tasks)

    return result


def _my_pairwise_callable(X, Y, metric, **kwds):
    """Evaluate a callable ``metric`` on every pair of samples of X and Y.

    Parameters
    ----------
    X, Y : indexable collections exposing ``.shape[0]``
        ``Y`` defaults to ``X`` when ``None``.
    metric : callable
        Called as ``metric(X[i], Y[j], **kwds)``.
    **kwds
        Forwarded to every ``metric`` call.

    Returns
    -------
    numpy.ndarray of shape ``(X.shape[0], Y.shape[0])``
        When ``X`` and ``Y`` are the same object, only the strict upper
        triangle and the diagonal are evaluated; the lower triangle is
        obtained by symmetry.
    """
    if Y is None:
        Y = X

    n_rows = X.shape[0]
    n_cols = Y.shape[0]

    if X is not Y:
        # General case: every cell is evaluated, row by row.
        dists = np.empty((n_rows, n_cols), dtype='float')
        for row in range(n_rows):
            for col in range(n_cols):
                dists[row, col] = metric(X[row], Y[col], **kwds)
        return dists

    # Same collection on both sides: evaluate the strict upper triangle ...
    upper = np.zeros((n_rows, n_cols), dtype='float')
    for row in range(n_rows):
        for col in range(row + 1, n_rows):
            upper[row, col] = metric(X[row], Y[col], **kwds)

    # ... mirror it into a new array (an in-place ``+=`` with the transposed
    # view would read cells that are being overwritten) ...
    dists = upper + upper.T

    # ... and evaluate the diagonal explicitly, since the metric of a sample
    # with itself is not assumed to be zero.
    for idx in range(n_rows):
        sample = X[idx]
        dists[idx, idx] = metric(sample, sample, **kwds)

    return dists