import torch
import numpy as np
import os
import matplotlib.pyplot as plt
from helpers.logger import get_logger

# Module-level logger obtained from the project's helpers.logger; used by match_predictors and _test_allocations.
logger = get_logger()


def match_predictors(
        predictions: torch.Tensor
) -> torch.Tensor:
    """ Match predictors across trials.

    Reorganise predictions such that correlating estimators are lined up across trials. (Afterwards, estimator i
    of trial j will correlate with estimator i of all other trials.)

    Matching is greedy. Trial 0 defines the initial ordering (the "prototype"). For every later trial, the Pearson
    correlation between each prototype estimator and each estimator of that trial is computed. The
    highest-correlating (prototype, trial) pair is fixed first, then the highest pair among the remaining
    estimators, and so on. The prototype is then updated to the running mean of all matched trials so far.

    Args:
        predictions (torch.Tensor): 3D tensor of predictions [ trials x estimators x samples ]

    Returns:
        torch.Tensor: re-organised predictions [ trials x estimators x samples ]. This is a new tensor on the same
        device and with the same dtype as ``predictions``; the input is not modified.
    """
    num_trials = predictions.shape[0]
    num_predictors = predictions.shape[1]
    if num_trials == 0 or num_predictors == 0:
        # Nothing to match; still return a fresh tensor, never the input itself.
        return predictions.clone()

    device = predictions.device
    # Correlations lie in [-1, 1], so a masked entry can never be picked as the maximum again.
    ignore_val = -10
    # Index tensors must be int64 (older torch versions reject int32 indices). They must also live on the
    # same device as the tensors they index.
    idx = torch.zeros(num_trials, num_predictors, dtype=torch.long, device=device)

    # initialise prototype results with first trial
    prototype = predictions[0, :, :]
    idx[0, :] = torch.arange(num_predictors, device=device)

    for t in range(1, num_trials):
        # find pairing for trial t compared to prototype
        agreement = torch.concat((prototype, predictions[t, :, :]), dim=0).corrcoef()
        relevant = agreement[0:num_predictors, num_predictors:]
        # Zero-variance rows produce NaN correlations; treat them as perfect agreement.
        relevant = torch.nan_to_num(relevant, nan=1.0)
        for _ in range(num_predictors):
            # argmax over the flattened matrix returns the first maximum in row-major order.
            # Converting to Python ints avoids mixing devices when indexing below.
            proto_idx, trial_idx = divmod(int(torch.argmax(relevant)), num_predictors)
            if proto_idx != trial_idx:
                logger.info('Trial %d EXCEPTION: Matching %d and %d with %.3f correlation',
                            t, proto_idx, trial_idx, float(relevant[proto_idx, trial_idx]))
            idx[t, proto_idx] = trial_idx
            # remove the matched prototype row and trial column from further consideration
            relevant[proto_idx, :] = ignore_val
            relevant[:, trial_idx] = ignore_val

        # update prototype results: running mean over trials 0..t, each in matched order
        trial_predictions = predictions[t][idx[t], :]
        prototype = (t * prototype + trial_predictions) / (t + 1)

    # re-organise predictions based on generated index: new[t, p] = predictions[t, idx[t, p]]
    trial_ids = torch.arange(num_trials, device=device).unsqueeze(1)
    return predictions[trial_ids, idx]


def _test_allocations(seed=None):
    """Test script. Not for external use.

    Builds ``num_trials`` noisy copies of a random base set of estimator predictions. It shuffles the estimator
    order of every trial except the first, then checks that ``match_predictors`` restores the original order
    exactly.

    Args:
        seed (int, optional): seed for the random generator, so that a failing run can be reproduced. ``None``
            (the default) draws fresh entropy on every call.

    Returns:
        bool: True if every trial was matched back exactly, False otherwise. The outcome is also logged.
    """
    rng = np.random.default_rng(seed)
    num_trials = 20
    num_estimators = 5
    num_samples = 10
    base = rng.random((num_estimators, num_samples))
    test_predictions = np.empty((num_trials, num_estimators, num_samples))
    correctly_matched = np.empty((num_trials, num_estimators, num_samples))
    test_predictions[0] = base
    correctly_matched[0] = base
    for t in range(1, num_trials):
        permutation = rng.permutation(num_estimators)
        # noise in [0, 0.01) is intended to keep each estimator most correlated with its own base row
        noised = base + rng.random((num_estimators, num_samples)) / 100
        test_predictions[t] = noised[permutation]
        correctly_matched[t] = noised
    auto_matched = match_predictors(torch.from_numpy(test_predictions)).cpu().numpy()
    delta = np.sum(np.abs(auto_matched - correctly_matched))
    # written as "<=" so that a NaN delta counts as a failure rather than a pass
    passed = bool(delta <= np.finfo(float).eps)
    if passed:
        logger.info('match_predictors passed unit test')
    else:
        logger.error('Error matching predictions: %f', delta)
    logger.info('Done')
    return passed


def plot_predictions(
        predictions: torch.Tensor,
        capacity: int,
        num_samples: int,
        dir: str):
    """ Plot predictions per sample per estimator across trials.

    Draws a scatter plot with the sample ID on the x-axis. For each plotted sample there is one point per trial
    per estimator. Each estimator gets its own colour from the rainbow colormap and its own legend entry.
    The figure is saved as ``plot_values_<capacity>.png`` in ``dir`` at 300 dpi and is always closed afterwards,
    even if saving fails.

    Args:
        predictions (torch.Tensor): 3D tensor of predictions [ trials x estimators x samples ]
        capacity (int): capacity identifier, for example, number of estimators; used in the file name
        num_samples (int): number of leading samples to plot; clamped to the number of samples available
        dir (str): directory to save plots to (not created by this function)

    """
    # convert once instead of once per (sample, predictor) pair
    values = predictions.detach().cpu().numpy()
    num_trials, num_predictors, available_samples = values.shape
    num_to_plot = max(0, min(num_samples, available_samples))
    colours = plt.cm.rainbow(np.linspace(0, 1, num_predictors))
    # x coordinates: each sample id repeated once per trial (sample-major order, matching the y values below)
    sample_ids = np.repeat(np.arange(num_to_plot), num_trials)

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        for p, colour in enumerate(colours):
            # [trials x samples] transposed -> all trials of sample 0, then all trials of sample 1, ...
            ys = values[:, p, :num_to_plot].T.ravel()
            ax.scatter(sample_ids, ys, label=str(p), color=colour, marker='.', s=8)

        ax.set_xlabel("Sample ID")
        ax.set_ylabel("Sub-predictor prediction")
        ax.set_title("Variance among sub-predictors across trials")
        if num_to_plot > 0 and num_predictors > 0:
            ax.legend()
        ax.grid()
        path = os.path.join(dir, 'plot_values_' + str(capacity) + '.png')
        fig.savefig(path, dpi=300)
    finally:
        # release the figure even if saving fails, so repeated calls do not leak figures
        plt.close(fig)

# Example usage, assuming each entry of `results` stores the capacity at index 3 and its predictions tensor at index 4:
# for k in range(len(results)):
#     capacity = results[k][3]
#     predictions = results[k][4]
#     plot_predictions(predictions, capacity, num_samples, dir)
#
# --------------------------------------------------------------------------------------------------------------------#
