#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import torch
from torch.nn.functional import pad
from torchvision.transforms import Resize
import numpy as np
from scipy.sparse import csc_matrix, coo_matrix

# Small constant added to the node degrees in edges_to_laplacian before
# computing 1 / sqrt(degree), so that zero-degree nodes do not divide by zero.
EPS = 10**-7


def infer_w(feat, neighbors, prec, priorw, coupling):
    """ infers connectivity maps from a feature map and parameters

    For every neighbor shift, the squared feature differences between each
    position and its shifted partner are weighted by the coupling precision
    and converted into the posterior ratio w=1/w=0. Positions whose partner
    would fall outside of the map keep the value NaN.

    The computation runs without gradient tracking (as in `infer_log_w`),
    so inputs that require gradients can be passed directly.

    Parameters
    ----------
    feat : torch.Tensor
        feature maps to infer from, indexed as
        batch x n_features x n_x x n_y
    neighbors : list or np.ndarray
        neighbor shift indices, n_neighbors x 2
    prec : torch.Tensor
        precision of the local factor
    priorw : torch.Tensor
        log-prior ratio for w, indexed by neighbor
    coupling : torch.Tensor
        matrix of log-coupling precisions n_neighbors x n_features

    Returns
    -------
    w_map : numpy.ndarray
        batch x n_neighbor x n_x x n_y array of posterior ratios w=1/w=0

    """
    with torch.no_grad():
        neighbors = np.array(neighbors)
        c = torch.exp(coupling)
        # Derived on paper:
        # factor is sqrt(det(prec))
        # det(prec) of coupled Gaussian for one feature is
        # (prec**2 + 2*prec*C)
        # det(prec) of uncoupled Gaussian is prec**2
        # thus normalizing factor between the two is:
        # sqrt((prec**2 + 2*prec*C) / prec**2)
        # = sqrt(1 + 2 * C/prec)
        normalizer = torch.sqrt(1 + 2 * c / prec).prod(dim=1)
        w_map = np.full((feat.shape[0], len(neighbors),
                         feat.shape[-2], feat.shape[-1]), np.nan)
        # add two trailing axes so c[i_neigh] broadcasts over the spatial dims
        c = torch.unsqueeze(torch.unsqueeze(c, -1), -1)
        for i_neigh in range(neighbors.shape[0]):
            fsmall, fshiftsmall = get_fshifts(feat, neighbors[i_neigh])
            fdiff = fshiftsmall - fsmall
            w_mapsmall = torch.exp(
                - torch.sum(c[i_neigh] * (fdiff ** 2), 1) / 2
                + priorw[i_neigh]) * normalizer[i_neigh]
            # putting w_maps into the right locations:
            # non-negative shifts fill the far end of an axis,
            # negative shifts fill the start of it.
            if neighbors[i_neigh, 0] >= 0:
                rows = slice(w_map.shape[2] - fdiff.shape[2], None)
            else:
                rows = slice(None, fdiff.shape[2])
            if neighbors[i_neigh, 1] >= 0:
                cols = slice(w_map.shape[3] - fdiff.shape[3], None)
            else:
                cols = slice(None, fdiff.shape[3])
            w_map[:, i_neigh, rows, cols] = \
                w_mapsmall.detach().cpu().numpy()
    return w_map


def infer_log_w(feat, neighbors, priorw, coupling, prec=1):
    """ infers log connectivity maps from a feature map and parameters

    Log-domain counterpart of the posterior ratio computation: for each
    neighbor shift the coupling-weighted squared feature differences are
    combined with the log-prior ratio and the log-normalizer. Everything is
    evaluated without gradient tracking. Positions whose shifted partner
    would lie outside of the map stay NaN.

    Parameters
    ----------
    feat : torch.Tensor
        feature maps to infer from, indexed as
        batch x n_features x n_x x n_y
    neighbors : list or np.ndarray
        neighbor shift indices, n_neighbors x 2
    priorw : torch.Tensor
        log-prior ratio for w, indexed by neighbor
    coupling : torch.Tensor
        matrix of log-coupling precisions n_neighbors x n_features
    prec : torch.Tensor
        precision of the local factor, default = 1

    Returns
    -------
    w_map : numpy.ndarray
        batch x n_neighbor x n_x x n_y array of log-posterior ratios
        w=1/w=0

    """
    with torch.no_grad():
        neighbors = np.array(neighbors)
        coupling_prec = torch.exp(coupling)
        # Ratio of the Gaussian normalizers (derived on paper):
        #   coupled:   det = prec**2 + 2*prec*C   (per feature)
        #   uncoupled: det = prec**2
        # so the factor is sqrt(1 + 2 * C / prec); here its logarithm,
        # summed over the features.
        log_norm = (torch.log(1 + 2 * coupling_prec / prec) / 2).sum(dim=1)
        out_shape = (feat.shape[0], len(neighbors),
                     feat.shape[-2], feat.shape[-1])
        w_map = np.full(out_shape, np.nan)
        # two trailing singleton axes for broadcasting over space
        coupling_prec = coupling_prec[..., None, None]
        features = feat.detach()
        for i_neigh, shift in enumerate(neighbors):
            f_here, f_there = get_fshifts(features, shift)
            fdiff = f_there - f_here
            log_ratio = (
                - torch.sum(coupling_prec[i_neigh] * (fdiff ** 2), 1) / 2
                + priorw[i_neigh] + log_norm[i_neigh])
            # non-negative shifts are written to the far end of an axis,
            # negative shifts to its start
            if shift[0] >= 0:
                rows = slice(w_map.shape[2] - fdiff.shape[2], None)
            else:
                rows = slice(None, fdiff.shape[2])
            if shift[1] >= 0:
                cols = slice(w_map.shape[3] - fdiff.shape[3], None)
            else:
                cols = slice(None, fdiff.shape[3])
            w_map[:, i_neigh, rows, cols] = log_ratio
    return w_map


def get_fshifts(feat, neighbor):
    """ cuts two equally sized views out of the feature maps such that
    entry [.., i, j] of the first view and entry [.., i, j] of the second
    view are `neighbor` apart in the original map.

    Parameters
    ----------
    feat : torch.Tensor or np.ndarray
        4D feature maps; the shift is applied to the last two axes
    neighbor : sequence of two numbers
        shift along axis 2 and axis 3 (may be negative)

    Returns
    -------
    fsmall : view of feat at the original points
    fshiftsmall : view of feat at the shifted (neighbor) points

    """
    def axis_slices(shift, length):
        # returns (slice for the original points, slice for the neighbors)
        if shift >= 0:
            offset = int(shift)
            return slice(None, length - offset), slice(offset, None)
        offset = int(-shift)
        return slice(offset, None), slice(None, length - offset)

    rows_small, rows_shift = axis_slices(neighbor[0], feat.shape[2])
    cols_small, cols_shift = axis_slices(neighbor[1], feat.shape[3])
    fsmall = feat[:, :, rows_small, cols_small]
    fshiftsmall = feat[:, :, rows_shift, cols_shift]
    return fsmall, fshiftsmall


def get_fullsize(feat, neigh):
    """pads a set of cut feature maps back to the original size

    The last axis is padded by |neigh[1]| and the second to last axis by
    |neigh[0]|. For a positive shift the padding is put in front of the
    data, for a negative shift behind it.

    Parameters
    ----------
    feat : smaller feature maps

    neigh : neighbors
        the two shift values that were used for cutting

    Returns
    -------
    torch.Tensor: original size feature maps

    """
    shift_rows = neigh[0]
    shift_cols = neigh[1]
    # `pad` expects the amounts for the last axis first:
    # [last-before, last-after, second-to-last-before, second-to-last-after]
    cols_before = shift_cols * (shift_cols > 0)
    cols_after = -shift_cols * (shift_cols < 0)
    rows_before = shift_rows * (shift_rows > 0)
    rows_after = -shift_rows * (shift_rows < 0)
    return pad(feat, [cols_before, cols_after, rows_before, rows_after])


def get_sparse_p(w_map, neighbors):
    """ converts connectivity maps into a sparse matrix of connectivities
    This only copies the values, i.e. if the input maps are log-posteriors
    the matrix will be log-posteriors.

    Pixels are numbered row by row (index = width * row + column). For every
    pixel and every neighbor shift whose target lies inside the map, the
    value stored at the target position of the corresponding map is used as
    edge weight, provided it is finite. Each edge is entered in both
    directions, so the matrix is symmetric.

    Parameters
    ----------
    w_map : numpy.ndarray
        w connectivity maps, indexed as n_neighbors x n_x x n_y
    neighbors : list or np.ndarray
        neighbor shifts, one [row shift, column shift] pair per map

    Returns
    -------
    edge_mat : sparse matrix ('csc')
        sparse connectivity matrix

    """
    n_maps, height, width = w_map.shape[0], w_map.shape[1], w_map.shape[2]
    n = height * width
    values = []
    sources = []
    targets = []
    for i in range(height):
        for j in range(width):
            for k in range(n_maps):
                neigh = neighbors[k]
                i_to = i + neigh[0]
                j_to = j + neigh[1]
                # skip edges that would leave the map
                if not (0 <= i_to < height and 0 <= j_to < width):
                    continue
                weight = w_map[k, i_to, j_to]
                if not np.isfinite(weight):
                    continue
                values.append(weight)
                sources.append(width * i + j)
                targets.append(width * i_to + j_to)
    # every edge appears once per direction
    edge_mat = csc_matrix(
        (values + values, (sources + targets, targets + sources)),
        shape=(n, n))
    return edge_mat


def get_sparse_laplacian(w_map, neighbors):
    """ builds a sparse graph laplacian directly from a stack of w_maps and
    the matching neighborhood information.
    Shorthand for edges_to_laplacian(get_sparse_p(w_map, neighbors)).

    Parameters
    ----------
    w_map : numpy.ndarray
        w connectivity maps
    neighbors : list or np.ndarray

    Returns
    -------
    laplacian : sparse matrix
        sparse graph laplacian as returned by edges_to_laplacian
    divisor : sparse matrix "divisor", sqrt of the inverse diagonal
    """
    sparse_edges = get_sparse_p(w_map, neighbors)
    return edges_to_laplacian(coo_matrix(sparse_edges))


def edges_to_laplacian(edge_mat):
    """ computes the graph laplacian from a sparse edge matrix.

    The laplacian is D - W, where W is the edge matrix and D the diagonal
    matrix of its row sums. Additionally a diagonal "divisor" matrix with
    entries 1 / sqrt(row sum + EPS) is returned.

    Parameters
    ----------
    edge_mat : sparse matrix ('csc')
        sparse connectivity matrix

    Returns
    -------
    laplacian : sparse matrix ('coo')
        sparse graph laplacian
    divisor : sparse matrix ('csc')
        "divisor", sqrt of the inverse diagonal
    """
    n = edge_mat.shape[0]
    # row sums = node degrees, flattened to a plain 1D array
    diag = np.array(np.sum(edge_mat, axis=1)).flatten()
    edge_mat = coo_matrix(edge_mat)
    node_idx = np.arange(n)
    # off-diagonal entries are the negated edges, the diagonal the degrees
    dat = np.concatenate([-edge_mat.data, diag])
    i_idx = np.concatenate([edge_mat.row, node_idx])
    j_idx = np.concatenate([edge_mat.col, node_idx])
    # builtin float instead of the alias np.float, which NumPy removed
    laplacian = coo_matrix((dat, (i_idx, j_idx)), shape=(n, n), dtype=float)
    # EPS keeps the divisor finite for nodes without any edges
    dat_divisor = 1 / np.sqrt(diag + EPS)
    divisor = coo_matrix((dat_divisor, (node_idx, node_idx)), shape=(n, n))
    divisor = divisor.asformat("csc")
    return laplacian, divisor


def align_w_maps(w_maps, neighbors, resolution, shifts, subsampling, interpolate=False):
    """
    This function takes a list of w_maps (at different positions and resolutions)
    as input and turns them into a common resolution. Values which are not
    set are set to 0, except for a border of the width of the shift on both
    ends of each spatial axis, which is set to NaN.

    Parameters
    ----------
    w_maps : list of np.array or torch.Tensor
        the w_maps for the different resolutions, each indexed as
        batch x n_neighbors x height x width
    neighbors : list
        for each w_map its list of neighbor shifts (m_i x 2)
    resolution : [height x width]
        target resolution
    shifts : list or np.array
        shift of each w_map relative to the starting corner
        should have dim: number of w_maps x 2
        (a single value per w_map is used for both axes)
    subsampling : list or np.array
        subsampling factor compared to the original image for each w_map
        (a single value per w_map is used for both axes)
    interpolate : bool
        if True each w_map is resized to cover its target area,
        otherwise its values are written to every subsamp-th pixel and the
        pixels in between stay 0. default = False

    Returns
    -------
    w_map: np.array
        the aligned overall w_map, all inputs concatenated along axis 1
    neighbors: np.array (mx2)
        the neighbor locations in the final resolution pixels

    """
    w_map_out = []
    neigh_out = []
    for w_map, neighbor, shift, subsamp in zip(w_maps, neighbors, shifts, subsampling):
        # normalize shift and subsampling to one value per spatial axis
        shift = np.array(shift)
        if shift.size == 1:
            shift = np.array([shift, shift])
        shift = shift.reshape(2)
        subsamp = np.array(subsamp)
        if subsamp.size == 1:
            subsamp = np.array([subsamp, subsamp])
        subsamp = subsamp.reshape(2)
        w_map_big = np.zeros((w_map.shape[0], w_map.shape[1], resolution[0], resolution[1]))
        # mark the border on both ends of each axis as undefined;
        # the `> 0` guard is needed because `-0:` would select everything
        if shift[0] > 0:
            w_map_big[:, :, :shift[0]] = np.nan
            w_map_big[:, :, -shift[0]:] = np.nan
        if shift[1] > 0:
            w_map_big[:, :, :, :shift[1]] = np.nan
            w_map_big[:, :, :, -shift[1]:] = np.nan
        stop = shift + np.array(w_map.shape[-2:]) * subsamp
        if interpolate:
            # the resized map must not exceed the target resolution
            stop = np.minimum(stop, w_map_big.shape[2:])
            interp = Resize([int(size) for size in stop - shift])
            w_map_t = torch.Tensor(w_map)
            w_map_big[:, :, shift[0]:stop[0], shift[1]:stop[1]] = \
                interp(w_map_t).detach().numpy()
        else:
            w_map_big[:, :, shift[0]:stop[0]:subsamp[0], shift[1]:stop[1]:subsamp[1]] = \
                w_map
        # neighbor shifts scale with the subsampling factor
        neigh_big = subsamp * np.array(neighbor)
        w_map_out.append(w_map_big)
        neigh_out.append(neigh_big)
    w_map_out = np.concatenate(w_map_out, 1)
    neigh_out = np.concatenate(neigh_out, 0)
    return w_map_out, neigh_out


def extract_cluster_features(cluster, feat, shift, subsampling):
    """
    selects the features which belong to a cluster given as a map in the
    original image space. The cluster map is brought to the resolution of
    the feature map using the shift and subsampling of the feature space.

    Parameters
    ----------
    cluster : np.array or torch.Tensor
        2D binary map for the cluster in the image
    feat : 2D, 3D or 4D torch.Tensor or np.array
        feature space to extract from, the last two axes are spatial
    shift : int or 2 ints
        shift of the feature map
    subsampling : int or 2 ints
        subsampling factor for the feature map relative to the image

    Returns
    -------
    feat : the features at all positions where the subsampled cluster map
        is > 0; the two spatial axes are replaced by a single axis

    Raises
    ------
    ValueError
        if feat does not have 2, 3 or 4 dimensions

    """
    def as_pair(value):
        # a single value is used for both spatial axes
        pair = np.array(value)
        if pair.size == 1:
            pair = np.array([pair, pair])
        return pair.reshape(2)

    shift = as_pair(shift)
    subsampling = as_pair(subsampling)
    cluster_small = cluster[shift[0]::subsampling[0], shift[1]::subsampling[1]]
    n_leading = feat.ndim - 2
    if n_leading not in (0, 1, 2):
        raise ValueError('feature vector has strange dimensionality')
    # crop the cluster map to the spatial extent of the feature map
    inside = cluster_small[:feat.shape[-2], :feat.shape[-1]] > 0
    return feat[(slice(None),) * n_leading + (inside,)]


def get_neigh(n_neigh):
    """ returns the list of neighbor shifts for a neighborhood size

    Only one shift of each opposite pair is listed, i.e. the returned list
    contains n_neigh / 2 shifts. Larger neighborhoods extend the smaller
    ones, so each list is a prefix of the next larger one.

    Parameters
    ----------
    n_neigh : int
        neighborhood size, one of 4, 8, 12, 20, 28, 40, 52, 68

    Returns
    -------
    neigh : list of [int, int]
        neighbor shifts; a new list is built on every call

    Raises
    ------
    ValueError
        if n_neigh is not one of the supported sizes

    """
    supported = (4, 8, 12, 20, 28, 40, 52, 68)
    # one row per neighborhood size, in increasing order
    all_shifts = (
        (0, 1), (1, 0),
        (1, 1), (1, -1),
        (2, 0), (0, 2),
        (2, 1), (1, 2), (2, -1), (1, -2),
        (3, 0), (2, 2), (0, 3), (2, -2),
        (4, 0), (3, 2), (0, 4), (3, -2), (2, 3), (2, -3),
        (5, 0), (4, 3), (0, 5), (4, -3), (3, 4), (3, -4),
        (6, 0), (5, 2), (4, 4), (2, 5), (0, 6), (5, -2), (4, -4), (2, -5),
    )
    if n_neigh not in supported:
        raise ValueError(
            'unsupported number of neighbors: {!r}, must be one of {}'.format(
                n_neigh, supported))
    return [list(shift) for shift in all_shifts[:int(n_neigh) // 2]]


def get_last_hparam(file):
    """ reads the last hyperparameter set from a file

    The text starting at the last '{' in the file is parsed as JSON. The
    last entry therefore has to be a flat JSON object, i.e. one without
    nested objects.

    Parameters
    ----------
    file : str or path
        file to read from

    Returns
    -------
    hparam : dict
        the parsed last JSON object of the file

    Raises
    ------
    json.JSONDecodeError
        if the text after the last '{' is not valid JSON

    """
    # the context manager closes the file even if reading fails
    with open(file, 'r') as f:
        string = f.read()
    last_opt = '{' + string.split(sep='{')[-1]
    return json.loads(last_opt)
