import numpy as np
from scipy.spatial.distance import cdist
from skimage.measure import block_reduce


def reduce_feats(prev, curr):
    """Average-pool ``prev`` over its last two axes toward ``curr``'s spatial size.

    The pooling window along axes 2 and 3 is the integer ratio of the two
    arrays' sizes on those axes; axes 0 and 1 are left untouched (window 1).
    Pooling is done with ``skimage.measure.block_reduce`` using ``np.mean``.

    NOTE(review): only axis 2 is checked to be strictly larger; if the sizes
    are not exact multiples the pooled shape may not equal ``curr``'s — verify.

    Raises:
        AssertionError: if ``prev.shape[2] <= curr.shape[2]`` (an ``assert``,
            so it is skipped when Python runs with ``-O``).
    """
    assert prev.shape[2] > curr.shape[2]

    pool_h = prev.shape[2] // curr.shape[2]
    pool_w = prev.shape[3] // curr.shape[3]
    return block_reduce(prev, block_size=(1, 1, pool_h, pool_w), func=np.mean)


def flatten_features(feats):
    """Collapse a 4-D activation array into one row per entry of axis 1.

    Axes are reordered ``(0, 1, 2, 3) -> (1, 0, 2, 3)`` and everything after the
    new leading axis is flattened, so shape ``(A, C, H, W)`` becomes
    ``(C, A * H * W)``. Presumably ``A`` is the batch and ``C`` the channel
    axis — TODO confirm with callers.
    """
    num_channels = feats.shape[1]
    channel_major = np.transpose(feats, axes=(1, 0, 2, 3))
    return channel_major.reshape(num_channels, -1)


def get_feat_corr(features, flattened=False):
    """Correlate each layer's channels with the previous layer's channels.

    Args:
        features: Sequence of per-layer activation arrays. When ``flattened`` is
            False each entry goes through ``flatten_features`` (4-D input,
            presumably ``(batch, channels, H, W)`` — TODO confirm); a previous
            layer whose trailing spatial shape differs is first pooled with
            ``reduce_feats``. When ``flattened`` is True each entry is 2-D with
            one column per channel and is simply transposed.
        flattened: Whether ``features`` are already 2-D.

    Returns:
        List whose entry 0 is ``None`` and whose entry ``i`` (``i >= 1``) is an
        array of shape ``(channels_i, channels_{i-1})`` of Pearson correlations:
        rows index layer ``i``'s channels, columns layer ``i-1``'s.
    """
    correlations = [None]
    for earlier, later in zip(features[:-1], features[1:]):
        if flattened:
            earlier_rows, later_rows = earlier.T, later.T
        else:
            if earlier.shape[2:] != later.shape[2:]:
                earlier = reduce_feats(earlier, later)
            earlier_rows = flatten_features(earlier)
            later_rows = flatten_features(later)

        n_later = later_rows.shape[0]
        # np.corrcoef stacks its two arguments row-wise, giving the block matrix
        #   [[corr(later, later),   corr(later, earlier)],
        #    [corr(earlier, later), corr(earlier, earlier)]]
        # and we keep only the top-right cross-layer block.
        full_matrix = np.corrcoef(later_rows, earlier_rows)
        correlations.append(full_matrix[:n_later, n_later:])
    return correlations


def get_act_iou(features, threshold, mode="contr"):
    """Compute IoU between thresholded channel masks of consecutive layers.

    For each adjacent pair of layers:

    * the previous layer is pooled with ``reduce_feats`` when its trailing
      spatial shape differs from the current layer's;
    * both are reshaped to one row per channel with ``flatten_features``;
    * each row is binarised as ``value > threshold``;
    * the result is ``1 - jaccard_distance`` for every
      (current channel, previous channel) pair.

    Args:
        features: Sequence of per-layer activation arrays accepted by
            ``flatten_features`` (presumably ``(batch, channels, H, W)`` —
            TODO confirm with callers).
        threshold: Sequence parallel to ``features``. Entry ``i`` is either a
            per-channel array with one value per channel of layer ``i``, or a
            scalar applied to every channel of that layer.
        mode: ``"contr"`` compares the current layer's active mask with the
            previous layer's active mask. ``"inhib"`` compares the current
            layer's NOT-active mask with the previous layer's active mask.

    Returns:
        List whose entry 0 is ``None``, followed by one array of shape
        ``(channels_i, channels_{i-1})`` per consecutive layer pair. ``zip``
        stops at the shorter of ``features`` and ``threshold``.

    Raises:
        ValueError: if ``mode`` is not ``"contr"`` or ``"inhib"``.
    """
    if mode not in {"contr", "inhib"}:
        raise ValueError(mode)

    ious = [None]
    for prev, curr, prev_thresh, curr_thresh in zip(
        features, features[1:], threshold, threshold[1:]
    ):
        if prev.shape[2:] != curr.shape[2:]:
            prev = reduce_feats(prev, curr)
        prev = flatten_features(prev)
        curr = flatten_features(curr)

        # reshape(-1, 1) turns a per-channel threshold vector into a column
        # that broadcasts along each channel's row. A scalar becomes (1, 1)
        # and applies to every channel.
        curr_acts = curr > np.reshape(curr_thresh, (-1, 1))
        prev_acts = prev > np.reshape(prev_thresh, (-1, 1))

        if mode == "inhib":
            # Keep the mask boolean. ``1 - mask`` yields an integer array,
            # which makes cdist convert both inputs to float64 instead of
            # using its boolean Jaccard path; the distances are identical.
            curr_acts = np.logical_not(curr_acts)

        iou = 1 - cdist(curr_acts, prev_acts, metric="jaccard")
        ious.append(iou)
    return ious


def get_act_iou_inhib(features, threshold):
    """Shortcut for ``get_act_iou(features, threshold, mode="inhib")``.

    In that mode the current layer's activation mask is inverted before its
    IoU with the previous layer's activation mask is computed.
    """
    inhibitory_mode = "inhib"
    return get_act_iou(features, threshold, mode=inhibitory_mode)


def get_weights(modules):
    """Return each module's weight, after the first, averaged over axes 2 and 3.

    Each module's ``weight`` is read via ``.detach().cpu().numpy()``; the
    modules are presumably torch layers with 4-D conv kernels
    ``(out, in, kh, kw)`` — TODO confirm. The first module is skipped and
    represented by a leading ``None`` placeholder, so entry ``i`` of the result
    corresponds to ``modules[i]``.
    """
    kernels = [module.weight.detach().cpu().numpy() for module in modules[1:]]
    # Average over the receptive field: reduce axis 2, then the former axis 3
    # (which is axis 2 after the first reduction).
    averaged = [kernel.mean(axis=2).mean(axis=2) for kernel in kernels]
    return [None] + averaged


def threshold_contributors(weights, n=None, alpha=None, alpha_global=None):
    """Mark the lowest and highest weights of each layer by quantile.

    Exactly one selection rule must be given:

    * ``alpha``: quantiles are taken along axis 1 of each layer, so for every
      row the entries strictly below its ``alpha`` quantile are flagged as
      inhibitory and those strictly above its ``1 - alpha`` quantile as
      contributing.
    * ``alpha_global``: the same, but with quantiles over the whole layer.
    * ``n``: reserved for a top-``n`` rule; not implemented.

    Args:
        weights: Sequence whose entry 0 is ignored (e.g. the ``None``
            placeholder produced by ``get_weights``); the remaining entries are
            arrays with at least two axes (``alpha`` reduces along axis 1).
        n: Unsupported top-``n`` selector.
        alpha: Per-row quantile level.
        alpha_global: Whole-layer quantile level.

    Returns:
        List of ``(contr_mask, inhib_mask)`` tuples, one per entry of
        ``weights``; entry 0 is ``(None, None)``. Each mask is a boolean array
        shaped like the corresponding layer.

    Raises:
        ValueError: if not exactly one of ``n``, ``alpha``, ``alpha_global``
            is given.
        NotImplementedError: if ``n`` is given (raised even when ``weights``
            has no layers to process).
    """
    nones = sum(arg is None for arg in (n, alpha, alpha_global))
    if nones != 2:
        raise ValueError("Must specify exactly one of n, alpha, or alpha_global")
    if n is not None:
        raise NotImplementedError("threshold_contributors(n=...) is not implemented")

    # The quantile levels and reduction axis are the same for every layer.
    if alpha_global is not None:
        levels, axis = [alpha_global, 1 - alpha_global], None
    else:
        levels, axis = [alpha, 1 - alpha], 1

    contr = [None]
    inhib = [None]
    for curr in weights[1:]:
        # keepdims=True leaves the reduced axes as size-1 dims, so each
        # threshold broadcasts back onto ``curr``.
        inhib_threshold, contr_threshold = np.quantile(
            curr, levels, axis=axis, keepdims=True
        )
        inhib.append(curr < inhib_threshold)
        contr.append(curr > contr_threshold)

    return list(zip(contr, inhib))
