#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import time

import matplotlib.pyplot as plt
import numpy as np
import scipy.linalg
import scipy.signal
from PIL import Image
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import eigsh
from sklearn.cluster import KMeans

from helpers import get_sparse_laplacian

# Convergence tolerance passed as ``tol`` to ``eigsh`` in ``contour``.
EPS = 10**-7
# Machine epsilon for float64.
# NOTE(review): not referenced by any function visible in this part of the file.
eps = np.finfo(float).eps


def contour(laplacian, divisor, k=20, im_size=None, im_size_small=None,
            n_orient=8, verbose=False, pad=0):
    """ turns connectivity maps for local connections into a contour map
    analogous to the spectralPb method delivered with the code for the
    Berkeley Segmentation Database BSD500.

    The original help from MATLAB reads:
    % function [sPb] = spectralPb(mPb, orig_sz, outFile, nvec)
    %
    % description:
    %   global contour cue from local mPb.
    %
    % computes Intervening Contour with BSE code by Charless Fowlkes:
    %
    %http://www.cs.berkeley.edu/~fowlkes/BSE/
    %
    % Pablo Arbelaez <arbelaez@eecs.berkeley.edu>
    % December 2010


    As the inferences drawn with the MRF are immediately given as a
    connectivity matrix we use that one here as a start.

    Parameters
    ----------
    laplacian : scipy sparse matrix
        Square matrix with one row per pixel of the small image. It is
        converted to CSC and sandwiched as ``divisor @ laplacian @ divisor``
        before the eigen-decomposition.
    divisor : matrix
        Applied on both sides of ``laplacian`` and afterwards to the
        eigenvectors (presumably D^{-1/2} of the degree matrix -- confirm
        against ``helpers.get_sparse_laplacian``).
    k : int
        Number of eigenpairs requested from ``eigsh``. The pair with the
        smallest eigenvalue is skipped when building the contour map.
    im_size : sequence of two ints
        (rows, cols) of the output maps.
    im_size_small : sequence of two ints
        Shape each eigenvector is reshaped to. Its product must equal
        ``laplacian.shape[0]``.
    n_orient : int
        Number of filter orientations. The output channel permutation
        ``ch_per`` is hard-coded for 8 orientations, as in the MATLAB code.
    verbose : bool
        Print a message once the eigenproblem is solved.
    pad : int or sequence
        Passed to ``np.pad(..., mode='symmetric')`` for every reshaped
        eigenvector before it is resized to ``im_size``.

    Returns
    -------
    sPb : ndarray, shape (n_orient, im_size[0], im_size[1])
        Per orientation, the sum over eigenvectors of the absolute
        oriented-filter responses, each weighted by 1/sqrt(eigenvalue).
    """
    assert np.prod(im_size_small) == laplacian.shape[0]
    laplacian = divisor @ laplacian.asformat("csc") @ divisor
    eig_w, vectors = eigsh(laplacian, k=k, which='SM', tol=EPS)
    if verbose:
        print('EV problem solved')
    # The MATLAB original reverses the order returned by eigs so that the
    # eigenvalues ascend. eigsh does not document its output order, and the
    # code below relies on index 0 holding the smallest eigenvalue (the
    # near-constant vector that is skipped), so sort explicitly.
    order = np.argsort(eig_w)
    eig_w = eig_w[order]
    vectors = divisor @ vectors[:, order]

    # Reshape every eigenvector to the small image, pad it and resize it to
    # the output size. PIL's resize takes (width, height).
    out_shape = (int(im_size[0]), int(im_size[1]))
    vect = np.zeros((vectors.shape[1],) + out_shape)
    for i_v, vec in enumerate(vectors.T):
        v_small = np.pad(np.reshape(vec, im_size_small), pad,
                         mode='symmetric')
        if v_small.shape != out_shape:
            v_small = np.array(
                Image.fromarray(v_small).resize(out_shape[::-1]))
        vect[i_v] = v_small

    # spectral Pb
    # Rescale every eigenvector except the first one (close to constant) to
    # [0, 1]. A constant vector has zero range; leave it at 0 instead of
    # dividing by zero and propagating NaNs into sPb.
    rest = vect[1:]  # view: in-place updates modify vect
    if rest.size:
        rest -= rest.min(axis=(1, 2), keepdims=True)
        peak = rest.max(axis=(1, 2), keepdims=True)
        np.divide(rest, peak, out=rest, where=peak > 0)

    # Weight each eigenvector by 1/sqrt(eigenvalue) once, because the weight
    # does not depend on the orientation. Non-positive eigenvalues are skipped.
    weighted = [v / np.sqrt(w) for v, w in zip(vect[1:], eig_w[1:]) if w > 0]

    # OE parameters
    dtheta = np.pi / n_orient
    # Channel permutation copied from the MATLAB code. It is hard-coded for
    # 8 orientations: filter orientation o is written to channel ch_per[o] - 1.
    ch_per = [4, 3, 2, 1, 8, 7, 6, 5]

    sPb = np.zeros((n_orient,) + out_shape)
    for o in range(n_orient):
        f = oeFilter(sigma=1, support=3, theta=dtheta * o, deriv=1, hil=0)
        channel = ch_per[o] - 1
        for v_norm in weighted:
            sPb[channel] += np.abs(scipy.signal.convolve2d(
                v_norm, f, 'same', boundary='symm'))
    return sPb


def sparse_spectral_clustering(w_map, neighborhood, thresh=np.inf, k=25,
                               plot_on=False):
    """
    runs graph spectral clustering on a single w_map (remember to pass only
    one!)

    Uses L_sym in the nomenclature of Luxburg. That matrix is symmetric
    positive *semi*-definite: its smallest eigenvalue is 0, so it is not
    positive definite. This assumes ``get_sparse_laplacian`` returns
    L = D - W and divisor = D^{-1/2}; TODO confirm.

    Parameters
    ----------
    w_map : ndarray
        Connectivity map. ``w_map.shape[1]`` and ``w_map.shape[2]`` give the
        (rows, cols) of the returned label image.
    neighborhood :
        Passed through unchanged to ``get_sparse_laplacian``.
    thresh : float
        Eigenvectors whose eigenvalue magnitude is below ``thresh`` are used
        as k-means features. Their count is the number of clusters.
    k : int
        Number of smallest eigenpairs computed with ``eigsh``.
    plot_on : bool
        If true, plot the sorted eigenvalue magnitudes.

    Returns
    -------
    ndarray of shape (w_map.shape[1], w_map.shape[2]) holding cluster labels.

    Raises
    ------
    ValueError
        If no eigenvalue has magnitude below ``thresh``.
    """
    laplacian, divisor = get_sparse_laplacian(w_map, neighborhood)
    laplacian = divisor @ laplacian.asformat("csc") @ divisor
    eig_w, eig_v = eigsh(laplacian, k=k, which='SM')
    eig_v = divisor @ eig_v
    selected = np.abs(eig_w) < thresh
    n_clusters = int(np.count_nonzero(selected))
    if n_clusters == 0:
        # KMeans cannot be run with zero clusters; fail with a clear message.
        raise ValueError(
            f'no eigenvalue has magnitude below thresh={thresh}; '
            'cannot derive a number of clusters')
    kmeans = KMeans(n_clusters=n_clusters)
    kmeans.fit(np.real(eig_v[:, selected]))
    if plot_on:
        plt.figure()
        plt.plot(np.sort(np.abs(eig_w)))
        plt.show()
    return kmeans.labels_.reshape((w_map.shape[1], w_map.shape[2]))


def oeFilter(sigma, support=3, theta=0, deriv=0, hil=0, vis=False):
    """ oriented Gaussian(-derivative) filter, used in ``contour`` on the
    eigenvector images for contour detection.
    Original provided with berkeley segmentation code.

    original MATLAB help is:
    % function [f] = oeFilter(sigma,support,theta,deriv,hil,vis)
    %
    % Compute unit L1-norm 2D filter.
    % The filter is a Gaussian in the x direction.
    % The filter is a Gaussian derivative with optional Hilbert
    % transform in the y direction.
    % The filter is zero-meaned if deriv>0.
    %
    % INPUTS
    %   sigma       Scalar, or 2-element vector of [sigmaX sigmaY].
    %   [support]   Make filter +/- this many sigma.
    %   [theta]     Orientation of x axis, in radians.
    %   [deriv]     Degree of y derivative, one of {0,1,2}.
    %   [hil]       Do Hilbert transform in y direction?
    %   [vis]       Visualization for debugging?
    %
    % OUTPUTS
    %   f   Square filter.
    %
    % David R. Martin <dmartin@eecs.berkeley.edu>
    % March 2003

    Notes on this port
    ------------------
    - ``sigma`` may be a scalar or any 2-element sequence/array.
    - ``deriv`` defaults to 0, as in the MATLAB original.
    - ``vis`` is accepted for signature compatibility but ignored.
    - Returns an ndarray of shape (sz, sz) with
      sz = 2 * max(ceil(support * sigma)) + 1. Its absolute values sum to 1
      (unless the filter is all zeros), and it has zero mean when deriv > 0.
    - Rows of the result are indexed by the bin of the unrotated y sample
      coordinate and columns by x. NOTE(review): MATLAB's column-major
      reshape would presumably give the transpose; confirm the orientation
      convention expected by callers.

    Raises
    ------
    ValueError
        If ``deriv`` is outside [0, 2].
    """
    sigma = np.asarray(sigma, dtype=float).ravel()
    if sigma.size == 1:
        sigma = np.repeat(sigma, 2)
    if deriv < 0 or deriv > 2:
        raise ValueError('deriv must be in [0, 2]')

    # Calculate filter size; make sure it's odd.
    hsz = int(np.max(np.ceil(support * sigma)))
    sz = 2 * hsz + 1

    # Sampling limits
    maxsamples = 1000  # Max samples in each dimension.
    maxrate = 10       # Maximum sampling rate.
    frate = 10         # Over-sampling rate for function evaluation.

    # Calculate sampling rate and number of samples.
    rate = min(maxrate, max(1, maxsamples // sz))
    samples = sz * rate

    # The 2D sampling grid. With this r the spacing is exactly 1/rate and the
    # samples sit half a step inside the bin edges, so each unit bin receives
    # rate x rate samples.
    r = hsz + 0.5 * (1 - 1 / rate)
    dom = np.linspace(-r, r, samples)
    sx, sy = np.meshgrid(dom, dom)

    # 0-based linear bin index of every grid point (row = y bin, col = x bin).
    mx = np.round(sx)
    my = np.round(sy)
    membership = ((mx + hsz) + (my + hsz) * sz).astype(int)

    # Rotate the 2D sampling grid by theta
    su = sx * np.sin(theta) + sy * np.cos(theta)
    sv = sx * np.cos(theta) - sy * np.sin(theta)

    # Evaluate the function separably on a finer grid.
    R = r * np.sqrt(2) * 1.01                   # radius of domain, enlarged by >sqrt(2)
    fsamples = int(np.ceil(R * rate * frate))   # number of samples
    fsamples += (fsamples + 1) % 2              # must be odd
    fdom = np.linspace(-R, R, fsamples)         # domain for function evaluation
    gap = 2 * R / (fsamples - 1)                # distance between samples

    # The function is a Gaussian in the x direction...
    fx = np.exp(-fdom ** 2 / (2 * sigma[0] ** 2))
    # .. and a Gaussian derivative in the y direction...
    fy = np.exp(-fdom ** 2 / (2 * sigma[1] ** 2))
    if deriv == 1:
        fy = fy * (-fdom / (sigma[1] ** 2))
    elif deriv == 2:
        fy = fy * (fdom ** 2 / (sigma[1] ** 2) - 1)
    # ...with an optional Hilbert transform.
    if hil:
        fy = np.imag(scipy.signal.hilbert(fy))

    # Evaluate the function with NN interpolation. With 0-based indexing the
    # sample at fdom == 0 is fsamples // 2; the MATLAB original adds 1 here
    # only because MATLAB arrays are 1-based.
    center = fsamples // 2
    xi = (np.round(su / gap) + center).astype(int)
    yi = (np.round(sv / gap) + center).astype(int)
    f = fx[xi] * fy[yi]

    # Accumulate the samples into each bin.
    f = np.bincount(membership.ravel(), weights=f.ravel(), minlength=sz * sz)
    # zero mean
    if deriv > 0:
        f -= np.mean(f)
    f = f.reshape(sz, sz)

    # unit L1-norm
    sumf = np.sum(np.abs(f))
    if sumf > 0:
        f /= sumf
    return f


def _isum(x, idx, nbins):
    """ helper indexsum function.
    Original provided with berkeley segmentation code.

    original MATLAB help is:
    % function acc = isum(x,idx,nbins)
    %
    % Indexed sum reduction, where acc(i) contains the sum of
    % x(find(idx==i)).
    %
    % The mex version is 300x faster in R12, and 4x faster in R13.  As far
    % as I can tell, there is no way to do this efficiently in matlab R12.
    %
    % David R. Martin <dmartin@eecs.berkeley.edu>
    % March 2003

    In this port ``idx`` keeps the MATLAB 1-based convention. Entries with
    idx < 1 or idx > nbins are ignored. Returns a float ndarray of length
    ``int(nbins)``. The reduction is vectorised with ``np.bincount`` instead
    of a per-element Python loop.
    """
    nbins = int(nbins)
    x = np.asarray(x).ravel()
    idx = np.asarray(idx).ravel()
    if idx.size == 0:
        return np.zeros(nbins)
    # Drop out-of-range (1-based) indices, then shift to 0-based bins.
    keep = (idx >= 1) & (idx <= nbins)
    return np.bincount(idx[keep] - 1, weights=x[keep], minlength=nbins)
