import numpy as np
import pandas as pd

from .._plot import plot_frac_pos_vs_scores
from .._utils import save_fig
from ._utils import (binarize_multiclass_marginal, binarize_multiclass_max,
                     compute_calib_metrics, compute_multi_classif_metrics)
from .main import cluster_evaluate_marginals, cluster_evaluate_max


def _get_out_kwargs(clustering, n_bins, ci, name, hist, test_size,
                    calibrate, max_clusters_bin, min_samples_leaf, n_clusters,
                    min_cluster_size, extra_out_kwargs=dict(), order=None):
    "Helper function to build filename from arguments."

    out_kwargs = {
        'clustering': clustering,
        'n_bins': n_bins,
        'ci': ci,
        'hist': hist,
        'test_size': test_size,
        'calibrate': calibrate,
    }

    _order = ['clustering']

    if name is not None:
        out_kwargs['name'] = name
        _order.insert(0, 'name')

    if not isinstance(clustering, str):
        out_kwargs['clustering'] = 'manual'

    elif clustering == 'decision_tree':
        out_kwargs['min_samples_leaf'] = min_samples_leaf
        out_kwargs['max_clusters_bin'] = max_clusters_bin
        _order.append('max_clusters_bin')
        _order.append('test_size')
        _order.append('min_samples_leaf')

    elif clustering == 'kmeans':
        out_kwargs['n_clusters'] = n_clusters
        out_kwargs['min_cluster_size'] = min_cluster_size
        _order.append('n_clusters')

    out_kwargs.update(extra_out_kwargs)

    if order is not None:
        _order = order

    return out_kwargs, _order


def cluster_max(X, y_labels, y_scores, name=None, clustering='kmeans',
                out_dir='img/cluster_max/',
                breakout=False,
                n_bins=15, n_jobs=1, min_samples_leaf=None, n_clusters=2,
                ci='clopper', hist=True, min_cluster_size=10,
                max_clusters_bin=2, test_size=None, plot_cluster_id=False,
                calibrate=False, verbose=2, extra_out_kwargs=None, order=None):
    """Evaluate fraction of correct classification in clustered bins of
    maximum scores and plot results.

    Parameters
    ----------
    X : (n, d) array
        The data samples.

    y_labels : (n,) array
        The data labels. Must be integers in {0, ..., K-1}.

    y_scores : (n, K) or (n,) array
        The scores given to each of the K classes. A 1D array is treated as
        the score of class 1 of a binary problem and expanded to an (n, 2)
        array [1 - y_scores, y_scores].

    name : str or None
        Name of the method (only used for filename).

    clustering : str or (n,) array
        Clustering method to use. Choices: 'kmeans', 'decision_tree' or
        a size (n,) array of cluster assignations (ie all samples with the same
        value belong to the same cluster).

    out_dir : str
        Path to the output directory. Is created if does not exist.

    breakout : bool
        Whether to breakout per class after binning to run the clustering.

    n_bins : int
        Number of bins. The calibration metrics use n_bins equal-width bins
        on [0, 1].

    n_jobs : int
        Number of jobs to run in parallel. Only used when breakout=True.

    min_samples_leaf : int
        Parameters passed to DecisionTreeRegressor when
        clustering='decision_tree'. Ignored if max_clusters_bin is not None.

    n_clusters : int
        Number of clusters in each bin. Only used for clustering='kmeans'.

    ci : str
        Passed to the plotting function (default 'clopper'). Also recorded
        in the filename and the returned metrics.

    hist : bool
        Passed to the plotting function. Also recorded in the filename and
        the returned metrics.

    min_cluster_size : int
        Passed to the plotting function.

    max_clusters_bin : int
        Compute min_samples_leaf per leaf when clustering='decision_tree'.

    test_size : float or None
        Whether to train/test split data for the clustering. If float given,
        the size of the test set as a proportion. If None: no train/test split.

    plot_cluster_id : bool
        Passed to the plotting function.

    calibrate : bool
        Only recorded in the filename and the returned metrics; not otherwise
        used by this function.

    verbose : int
        Verbosity level.

    extra_out_kwargs : dict or None
        Extra entries added to the filename kwargs and the returned metrics.

    order : list of str or None
        Order of the filename kwargs passed to save_fig. If None, a default
        order derived from `name` and `clustering` is used.

    Returns
    -------
    pandas.DataFrame
        A single row holding the filename kwargs (plus 'breakout'), the
        multiclass metrics, the metrics of the binarized problem (prefixed
        with 'binarized_'), the calibration metrics and 'fig_path'.
    """
    # None instead of a shared mutable default dict.
    if extra_out_kwargs is None:
        extra_out_kwargs = {}

    # 1D scores: treat as the class-1 score of a binary problem.
    if y_scores.ndim == 1:
        y_scores = np.stack([1-y_scores, y_scores], axis=1)

    bins = np.linspace(0, 1, n_bins+1)

    (frac_pos,
     counts,
     mean_scores,
     *_
     ) = cluster_evaluate_max(X, y_labels, y_scores, breakout=breakout,
                              bins=n_bins, verbose=verbose,
                              n_jobs=n_jobs,
                              min_samples_leaf=min_samples_leaf,
                              max_clusters_bin=max_clusters_bin,
                              clustering=clustering,
                              n_clusters=n_clusters,
                              test_size=test_size,
                              )

    # Binarized view: x-axis is the maximum score, target is whether the
    # prediction was correct (see the axis labels below).
    y_pred_scores, y_well_guess = binarize_multiclass_max(y_scores, y_labels)

    fig = plot_frac_pos_vs_scores(frac_pos, counts, mean_scores,
                                  y_scores=y_pred_scores,
                                  y_labels=y_well_guess,
                                  ncol=1,
                                  legend_loc='upper left',
                                  xlim_margin=0.05,
                                  ylim_margin=0.05,
                                  min_cluster_size=min_cluster_size,
                                  title=None,
                                  hist=hist,
                                  ci=ci,
                                  mean_only=False,
                                  xlabel='Maximum confidence score',
                                  ylabel='Fraction of correct predictions',
                                  plot_cluster_id=plot_cluster_id,
                                  )

    out_kwargs, order = _get_out_kwargs(
        clustering, n_bins, ci, name, hist, test_size,
        calibrate, max_clusters_bin, min_samples_leaf, n_clusters,
        min_cluster_size, extra_out_kwargs, order)
    out_kwargs['breakout'] = breakout

    fig_path = save_fig(fig, out_dir, order=order, **out_kwargs)

    metrics = {}
    metrics.update(out_kwargs)
    metrics.update(compute_multi_classif_metrics(y_scores, y_labels))
    metrics_binarized = compute_multi_classif_metrics(y_pred_scores, y_well_guess)
    metrics.update({f'binarized_{key}': value
                    for key, value in metrics_binarized.items()})
    metrics.update(compute_calib_metrics(frac_pos, counts, y_pred_scores, y_well_guess, bins))
    metrics['fig_path'] = fig_path

    return pd.DataFrame([metrics])

def cluster_marginals(X, y_labels, y_scores, name=None,
                      clustering='kmeans',
                      out_dir='img/cluster_marginals/',
                      n_bins=15, n_jobs=1, min_samples_leaf=None, n_clusters=2,
                      ci='clopper', hist=True, min_cluster_size=10,
                      max_clusters_bin=2, test_size=None, plot_cluster_id=False,
                      calibrate=False, verbose=3, extra_out_kwargs=None,
                      order=None):
    """Evaluate fraction of predicted class in clustered bins of class scores
    and plot results, once per class.

    Parameters
    ----------
    X : (n, d) array
        The data samples.

    y_labels : (n,) array
        The data labels. Must be integers in {0, ..., K-1}.

    y_scores : (n, K) or (n,) array
        The scores given to each of the K classes. A 1D array is treated as
        the score of class 1 of a binary problem and expanded to an (n, 2)
        array [1 - y_scores, y_scores].

    name : str or None
        Name of the method (only used for filename).

    clustering : str or (n,) array
        Clustering method to use. Choices: 'kmeans', 'decision_tree' or
        a size (n,) array of cluster assignations (ie all samples with the same
        value belong to the same cluster).

    out_dir : str
        Path to the output directory. Is created if does not exist.

    n_bins : int
        Number of bins. The calibration metrics use n_bins equal-width bins
        on [0, 1].

    n_jobs : int
        Number of jobs to run in parallel; forwarded to
        cluster_evaluate_marginals.

    min_samples_leaf : int
        Parameters passed to DecisionTreeRegressor when
        clustering='decision_tree'. Ignored if max_clusters_bin is not None.

    n_clusters : int
        Number of clusters in each bin.

    ci : str
        Passed to the plotting function (default 'clopper'). Also recorded
        in the filename and the returned metrics.

    hist : bool
        Passed to the plotting function. Also recorded in the filename and
        the returned metrics.

    min_cluster_size : int
        Passed to the plotting function.

    max_clusters_bin : int
        Compute min_samples_leaf per leaf when clustering='decision_tree'.

    test_size : float or None
        Whether to train/test split data for the clustering. If float given,
        the size of the test set as a proportion. If None: no train/test split.

    plot_cluster_id : bool
        Passed to the plotting function.

    calibrate : bool
        Only recorded in the filename and the returned metrics; not otherwise
        used by this function.

    verbose : int
        Verbosity level.

    extra_out_kwargs : dict or None
        Extra entries added to the filename kwargs and the returned metrics.

    order : list of str or None
        Order of the filename kwargs passed to save_fig. 'class' is inserted
        at position 1 if absent. The caller's list is not modified.

    Returns
    -------
    pandas.DataFrame
        One row per class k, holding the filename kwargs (including
        'class'), the multiclass metrics, the metrics of the binarized
        problem for class k (prefixed with 'binarized_'), the calibration
        metrics and 'fig_path'.
    """
    # None instead of a shared mutable default dict.
    if extra_out_kwargs is None:
        extra_out_kwargs = {}

    # 1D scores: treat as the class-1 score of a binary problem.
    if y_scores.ndim == 1:
        y_scores = np.stack([1-y_scores, y_scores], axis=1)

    n_classes = y_scores.shape[1]

    bins = np.linspace(0, 1, n_bins+1)

    out_kwargs, order = _get_out_kwargs(
        clustering, n_bins, ci, name, hist, test_size,
        calibrate, max_clusters_bin, min_samples_leaf, n_clusters,
        min_cluster_size, extra_out_kwargs, order)

    # Copy before inserting so a caller-supplied `order` list is never
    # mutated by this function.
    order = list(order)
    if 'class' not in order:
        order.insert(1, 'class')

    # These metrics do not depend on the evaluated class: compute them once.
    multi_metrics = compute_multi_classif_metrics(y_scores, y_labels)

    L_metrics = []
    for k in range(n_classes):

        (frac_pos,
         counts,
         mean_scores,
         *_
         ) = cluster_evaluate_marginals(X, y_labels, y_scores,
                                        positive_class=k,
                                        bins=n_bins,
                                        clustering=clustering,
                                        n_clusters=n_clusters,
                                        test_size=test_size,
                                        min_samples_leaf=min_samples_leaf,
                                        max_clusters_bin=max_clusters_bin,
                                        verbose=verbose,
                                        n_jobs=n_jobs)

        (y_pred_scores,
         y_well_guess,
         ) = binarize_multiclass_marginal(y_scores, y_labels, k)

        fig = plot_frac_pos_vs_scores(frac_pos, counts, mean_scores,
                                      y_scores=y_pred_scores,
                                      y_labels=y_well_guess,
                                      ncol=1,
                                      legend_loc='upper left',
                                      xlim_margin=0.05,
                                      ylim_margin=0.05,
                                      min_cluster_size=min_cluster_size,
                                      title=None,
                                      hist=hist,
                                      ci=ci,
                                      mean_only=False,
                                      xlabel=f'Confidence score of class {k}',
                                      ylabel=f'Fraction of class {k}',
                                      plot_cluster_id=plot_cluster_id,
                                      )
        out_kwargs['class'] = k
        fig_path = save_fig(fig, out_dir, order=order, **out_kwargs)

        metrics = {}
        metrics.update(out_kwargs)
        metrics.update(multi_metrics)
        metrics_binarized = compute_multi_classif_metrics(y_pred_scores, y_well_guess)
        # Distinct names so the comprehension does not shadow the class index k.
        metrics.update({f'binarized_{key}': value
                        for key, value in metrics_binarized.items()})
        metrics.update(compute_calib_metrics(frac_pos, counts, y_pred_scores, y_well_guess, bins))
        metrics['fig_path'] = fig_path
        L_metrics.append(metrics)

    return pd.DataFrame(L_metrics)
