from sklearn.metrics.pairwise import pairwise_distances
from copy import copy
from time import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import ripserplusplus as rpp_py

def h1sum(barc):
    """Return the total length (death - birth) of all dimension-1 bars.

    ``barc`` is indexed by homology dimension; each entry is a sequence of
    bars where ``bar[0]`` is birth and ``bar[1]`` is death.  When there is
    no dimension-1 entry the result is ``0.0``.
    """
    if 1 not in barc:
        return 0.0
    return sum(bar[1] - bar[0] for bar in barc[1])

def pdist_gpu(a, b, device = 'cuda:0'):
    """Return the (len(a), len(b)) matrix of Euclidean distances via torch.

    Both inputs are copied into float64 tensors.  ``a`` is moved to
    ``device`` once.  The rows of ``b`` are processed in contiguous chunks so
    that only one slice of ``b`` is on the device at a time.  The result is
    collected into a NumPy float64 array on the host.
    """
    lhs = torch.tensor(a, dtype = torch.float64)
    rhs = torch.tensor(b, dtype = torch.float64)
    n_rows = lhs.shape[0]
    n_cols = rhs.shape[0]

    # Heuristic input size, in billions of elements.  Above the 0.2 limit,
    # b is split into enough chunks to bring each piece under it.
    size = (n_rows + n_cols) * lhs.shape[1] / 1e9
    max_size = 0.2
    parts = int(size / max_size) + 1 if size > max_size else 1

    result = np.zeros((n_rows, n_cols))
    lhs_dev = lhs.to(device)

    for part in range(parts):
        start = int(part * n_cols / parts)
        stop = min(int((part + 1) * n_cols / parts), n_cols)

        chunk = rhs[start:stop].to(device)
        dist = torch.cdist(lhs_dev, chunk)
        result[:, start:stop] = dist.cpu()

        # Free the per-chunk device buffers before the next iteration.
        del chunk, dist
        torch.cuda.empty_cache()

    del lhs_dev

    return result

def sep_dist(a, b, pdist_device = 'cpu'):
    """Return a square (len(a) + len(b)) matrix holding only distances from b.

    Layout, with ``na = a.shape[0]``:

        rows [:na]  -> all zeros
        rows [na:]  -> [ dist(b, a) | dist(b, b) ]

    ``pdist_device == 'cpu'`` uses sklearn's ``pairwise_distances`` (40 jobs).
    Any other value is treated as a torch device for ``pdist_gpu``.
    """
    if pdist_device != 'cpu':
        cross = pdist_gpu(b, a, device = pdist_device)
        within_b = pdist_gpu(b, b, device = pdist_device)
    else:
        cross = pairwise_distances(b, a, n_jobs = 40)
        within_b = pairwise_distances(b, b, n_jobs = 40)

    n_a = a.shape[0]
    total = n_a + b.shape[0]

    out = np.zeros((total, total))
    out[n_a:, :n_a] = cross
    out[n_a:, n_a:] = within_b

    return out

def barc2array(barc):
    """Convert a barcode mapping into a list of (k, 2) float arrays.

    The output is ordered by ascending key (homology dimension).  Row ``j``
    of each array is ``(bar[0], bar[1])`` for the j-th bar of that key.  An
    empty bar list yields an array of shape (0, 2).
    """
    def _to_table(bars):
        # Copy the first two components of each bar into a float64 table.
        table = np.zeros((len(bars), 2))
        for row, bar in enumerate(bars):
            table[row] = (bar[0], bar[1])
        return table

    return [_to_table(barc[dim]) for dim in sorted(barc.keys())]
    
def calc_embed_dist(a, b, dim = 1, pdist_device = 'cuda:0', verbose = False, norm = 'median', metric = 'euclidean', use_max = False):
    """Build a joint 2n x 2n distance matrix for two point clouds and run ripser++.

    Parameters
    ----------
    a, b : array-like
        Two point clouds with the same number of rows ``n``.  Distances are
        combined element-wise (``np.minimum`` / ``np.maximum``), so row ``i``
        of ``a`` and row ``i`` of ``b`` must describe the same item.
    dim : int
        Passed to ripser++ as ``--dim``.
    pdist_device : str
        ``'cpu'`` computes distances with sklearn's ``pairwise_distances``.
        Any other value is used as the torch device for ``pdist_gpu``.
    verbose : bool
        Print progress messages.
    norm : str
        ``'median'`` divides each distance matrix by its median.
        ``'quantile'`` divides it by its 0.9-quantile.
        Any other value leaves the distances unnormalised.
    metric : str
        Metric for ``pairwise_distances``.  ``pdist_gpu`` only computes
        Euclidean distances, so any other metric is always computed on the CPU.
    use_max : bool
        Selects the block layout of the joint matrix (see comment below).

    Returns
    -------
    The value returned by ``rpp_py.run``.  It is presumably a dict mapping
    homology dimension to (birth, death) pairs, as consumed by ``h1sum`` and
    ``barc2array`` (verify against ripser++ docs).

    Raises
    ------
    ValueError
        If ``a`` and ``b`` have different numbers of rows.
    """
    n = a.shape[0]
    if np.shape(b)[0] != n:
        raise ValueError(
            'a and b must have the same number of rows, got %d and %d'
            % (n, np.shape(b)[0])
        )

    # pdist_gpu is torch.cdist (Euclidean only).  Previously a non-Euclidean
    # metric was silently ignored on the GPU path, so route those to sklearn.
    use_cpu = pdist_device == 'cpu' or metric != 'euclidean'

    if use_cpu:
        if verbose:
            print('pdist on cpu start')
        r1 = pairwise_distances(a, a, n_jobs = 40, metric = metric)
        r2 = pairwise_distances(b, b, n_jobs = 40, metric = metric)
    else:
        if verbose:
            print('pdist on gpu start')
        r1 = pdist_gpu(a, a, device = pdist_device)
        r2 = pdist_gpu(b, b, device = pdist_device)

    if norm == 'median':
        r1 = r1 / np.median(r1)
        r2 = r2 / np.median(r2)
    elif norm == 'quantile':
        r1 = r1 / np.quantile(r1, 0.9)
        r2 = r2 / np.quantile(r2, 0.9)

    if verbose:
        print('pairwise distances calculated')

    # Block layout of the 2n x 2n matrix (top-left block is always zero):
    #   use_max=False:  [ 0   r1          ]
    #                   [ r1  min(r1, r2) ]
    #   use_max=True:   [ 0            max(r1, r2) ]
    #                   [ max(r1, r2)  r2          ]
    d = np.zeros((2 * n, 2 * n))

    if not use_max:
        d[n:, :n] = r1
        d[:n, n:] = r1
        d[n:, n:] = np.minimum(r1, r2)
    else:
        d[n:, :n] = np.maximum(r1, r2)
        d[:n, n:] = np.maximum(r1, r2)
        d[n:, n:] = r2

    # Snap entries below 1e-6 x (mean off-diagonal-block distance) to exact
    # zero -- presumably to remove floating-point noise before ripser++.
    m = d[n:, :n].mean()
    d[d < m*(1e-6)] = 0
    # Strictly-lower-triangular entries in row-major order, matching
    # ripser++'s "lower-distance" input format.
    d_tril = d[np.tril_indices(d.shape[0], k = -1)]

    if verbose:
        print('matrix prepared')

    barc = rpp_py.run("--format lower-distance --dim %d" % dim, d_tril)

    return barc
   
def plot_barcodes(arr, color_list = ('deepskyblue', 'limegreen', 'darkkhaki'), dark_color_list = None, title = '', hom = None, ax=None, fig=None):
    """Draw persistence barcodes, one horizontal segment per bar.

    Parameters
    ----------
    arr : sequence of (k, 2) arrays
        ``(birth, death)`` rows per homology dimension, e.g. the output of
        ``barc2array``.
    color_list : sequence of str
        ``color_list[i]`` colours dimension ``i``.  It is repeated when
        shorter than ``arr``.  Never modified.
    dark_color_list : sequence of str or None
        Colour used for the three longest bars of each dimension.  It is
        passed to ``ax.plot`` as the positional format string.  Defaults to
        ``color_list`` and is repeated when shorter than ``arr``.
    title : str
        Axes title.
    hom : container of int or None
        If given, only dimensions ``i`` with ``i in hom`` are drawn.
    ax : matplotlib Axes or None
        Axes to draw into.  If None, a new 6x4 figure is created, summary
        statistics are printed, and ``plt.show()`` is called.
    fig : ignored
        Kept for backward compatibility.
    """
    if ax is None:
        # Pass figsize directly.  Setting rcParams after the figure exists
        # has no effect on it and leaks global state into later figures.
        fig, ax = plt.subplots(1, figsize = (6, 4))
        show = True
    else:
        show = False

    n_dims = len(arr)

    # Work on copies.  The old ``color_list *= n`` extended the list in
    # place, growing the shared default (and the caller's list) every call.
    colors = list(color_list)
    if len(colors) < n_dims:
        colors = colors * n_dims

    if dark_color_list is None:
        dark_colors = colors
    else:
        dark_colors = list(dark_color_list)
        if len(dark_colors) < n_dims:
            dark_colors = dark_colors * n_dims

    step = 0
    for i in range(n_dims):
        if hom is not None and i not in hom:
            continue

        barc = arr[i].copy()
        bsorted = np.sort(np.subtract(barc[:, 1], barc[:, 0]))
        nbarc = bsorted.shape[0]
        if show:
            print('H%d: num barcodes %d' % (i, nbarc))
        if not nbarc:
            continue

        if show:
            print('max0,976Barcode', i, '=', bsorted[nbarc*976//1000])
            print('maxBarcode', i, '=', bsorted[-1])
            print('middleBarcode', i, '=', bsorted[nbarc//2])

        # Lengths of the (up to) three longest bars; matching bars get the
        # dark colour.
        top_lengths = bsorted[-3:]

        # Drawn at the same height as the first bar of the loop below (step
        # is not advanced in between); it exists to carry the legend label.
        ax.plot(barc[0], np.ones(2)*step, color = colors[i], label = 'H{}'.format(i))
        for b in barc:
            if b[1] - b[0] in top_lengths:
                ax.plot(b, np.ones(2)*step, dark_colors[i])
            else:
                ax.plot(b, np.ones(2)*step, color = colors[i])
            step += 1

    # Raw string: '\e' is an invalid escape sequence in a normal literal.
    ax.set_xlabel(r'$\epsilon$ (time)')
    ax.set_ylabel('segment')
    ax.set_title(title)
    ax.set_xlim(0, 1)
    ax.legend(loc = 'lower right')
    if show:
        plt.show()

def get_score(elem, h_idx, kind = ''):
    """Summarise one homology dimension of a barcode collection as a scalar.

    Parameters
    ----------
    elem : sequence of (k, 2) arrays
        One array of ``(birth, death)`` rows per homology dimension.  This is
        the list returned by ``barc2array``; a NumPy object array also works.
    h_idx : int
        Index (homology dimension) to score.
    kind : str
        One of:
        ``'nbarc'`` (number of bars), ``'largest'`` (longest bar),
        ``'quantile'`` (0.976-quantile bar length), ``'sum_length'``,
        ``'sum_sq_length'``.

    Returns
    -------
    The requested score.  Returns 0 when ``elem`` has no entry at ``h_idx``.
    'largest' and 'quantile' also return 0 when that entry has no bars.

    Raises
    ------
    ValueError
        For an unknown ``kind``.  This is only checked when ``h_idx`` is
        present, as before.
    """
    # len() works for lists as well as arrays (``.shape`` rejected lists).
    if len(elem) < h_idx + 1:
        return 0

    barc = np.asarray(elem[h_idx])
    lengths = np.sort(np.subtract(barc[:, 1], barc[:, 0]))
    n_bars = lengths.shape[0]

    # number of barcodes
    if kind == 'nbarc':
        return n_bars

    # largest barcode
    if kind == 'largest':
        return lengths[-1] if n_bars else 0

    # quantile; exact integer index avoids float rounding in 0.976 * n
    if kind == 'quantile':
        return lengths[n_bars * 976 // 1000] if n_bars else 0

    # sum of length
    if kind == 'sum_length':
        return np.sum(lengths)

    # sum of squared length
    if kind == 'sum_sq_length':
        return np.sum(lengths**2)

    raise ValueError('Unknown kind of score')