import warnings

import numpy as np
import pytest
import torch
from hypothesis import given
from hypothesis.strategies import integers, lists
from sklearn.calibration import calibration_curve as sklearn_calibration_curve
from sklearn.model_selection import train_test_split
from torch.nn import functional as F

from .._plot import plot_frac_pos_vs_scores
from .._utils import save_fig
from ._utils import (MyMaxCalibrator, calibrate_scores, grouping_loss_bias,
                     grouping_loss_lower_bound, piecewise_affine_mapping)
from .main import cluster_evaluate, cluster_evaluate_max


@given(
    test_size=lists(integers(0, 19), unique=True),
)
def test_calibrate(test_size):
    """Smoke-test ``calibrate_scores`` on random softmax scores.

    ``test_size`` is a hypothesis-drawn list of unique sample indices in
    ``[0, 19]``. ``calibrate_scores`` must return exactly those indices,
    sorted, as its second output. ``cluster_evaluate`` is run on the top-label
    confidences both before and after calibration (restricted to the returned
    indices for the calibrated pass).
    """
    n_samples, n_classes = 100, 3
    rng = np.random.RandomState(0)

    # Random logits turned into per-row probability vectors; labels are drawn
    # independently of the scores.
    logits = rng.uniform(-100, 100, size=(n_samples, n_classes))
    scores = np.array(F.softmax(torch.from_numpy(logits), dim=1))
    labels = rng.binomial(n=n_classes - 1, p=0.5, size=n_samples)

    # A single constant feature column: every sample falls in the same region.
    features = np.zeros((n_samples, 1))

    # Top-label confidence and whether the top label matches the true label.
    confidence = np.max(scores, axis=1)
    correct = (np.argmax(scores, axis=1) == labels).astype(int)

    _frac_pos, _counts, _mean_scores, *_ = cluster_evaluate(
        features, correct, confidence, bins=15)

    calibrated, test_idx = calibrate_scores(scores, labels, test_size=test_size)
    assert np.array_equal(np.sort(test_size), test_idx)

    # Keep only the held-out samples reported by calibrate_scores.
    calibrated = calibrated[test_idx]
    features = features[test_idx, :]
    correct = correct[test_idx]

    calibrated_confidence = np.max(calibrated, axis=1)
    _frac_pos, _counts, _mean_scores, *_ = cluster_evaluate(
        features, correct, calibrated_confidence, bins=15)


@pytest.mark.parametrize('n', [1, 2, 10])
@pytest.mark.parametrize('k', [0, 1, 2, 10])
def test_piecewise_affine_mapping(n, k):
    """Smoke-test ``piecewise_affine_mapping`` on random knots.

    The knot abscissae ``xs`` span ``[0, 1]`` with ``k`` random interior
    points, and the knot ordinates ``ys`` go from 0 to 1 with ``k`` random
    interior values. ``n`` random query points in ``[0, 1)`` are mapped.
    """
    rs = np.random.RandomState(0)
    # Sort the interior breakpoints: with k >= 2 raw uniform draws are not
    # increasing, so consecutive knots would not delimit valid affine pieces.
    # NOTE(review): assumes piecewise_affine_mapping expects increasing xs.
    xs = np.concatenate([[0], np.sort(rs.uniform(size=k)), [1]])
    ys = np.concatenate([[0], rs.uniform(size=k), [1]])
    x = rs.uniform(size=n)
    piecewise_affine_mapping(x, xs, ys)


def test_maxcalibrator():
    """Fit ``MyMaxCalibrator`` on a held-out split and compare reliability curves.

    Random softmax scores (with labels drawn independently of them) are split
    into a 10% calibration set and an evaluation set. The calibrator is fitted
    on the former and applied to the latter. sklearn's ``calibration_curve`` of
    the top-label confidence is computed before and after calibration, and
    both curves are checked for basic consistency.
    """
    import matplotlib.pyplot as plt

    n_bins = 15
    calibrator = MyMaxCalibrator(n_bins=n_bins)
    n = 1000
    K = 10
    rs = np.random.RandomState(0)
    y_scores = rs.uniform(-100, 100, size=(n, K))
    y_scores = np.array(F.softmax(torch.from_numpy(y_scores), dim=1))
    y_labels = rs.binomial(n=K-1, p=0.5, size=n)

    # Fixed random_state so the split, and therefore the test, is reproducible.
    (y_scores_cal, y_scores_test,
     y_labels_cal, y_labels_test) = train_test_split(
        y_scores, y_labels, train_size=0.1, random_state=0)

    calibrator.fit(y_scores_cal, y_labels_cal)
    y_scores_calibrated = calibrator.predict(y_scores_test)

    # Binary target: did the top-scoring class match the true label?
    y_pred = np.argmax(y_scores_test, axis=1)
    y_binarized = (y_pred == y_labels_test).astype(int)

    y_scores_max = np.max(y_scores_test, axis=1)
    prob_true, prob_pred = sklearn_calibration_curve(
        y_binarized, y_scores_max, n_bins=n_bins)

    y_scores_calibrated_max = np.max(y_scores_calibrated, axis=1)
    prob_true_cal, prob_pred_cal = sklearn_calibration_curve(
        y_binarized, y_scores_calibrated_max, n_bins=n_bins)

    # Each curve pairs a fraction of positives with a mean predicted
    # probability per non-empty bin; both must be valid probabilities.
    for frac, mean in ((prob_true, prob_pred), (prob_true_cal, prob_pred_cal)):
        assert len(frac) == len(mean)
        assert np.all((frac >= 0) & (frac <= 1))
        assert np.all((mean >= 0) & (mean <= 1))

    # Build the reliability plot on an explicit figure and close it instead of
    # calling plt.show(), which would block the test run under an interactive
    # backend and leak the figure otherwise.
    fig, ax = plt.subplots()
    ax.plot(prob_pred, prob_true, color='blue')
    ax.plot(prob_pred_cal, prob_true_cal, color='red')
    plt.close(fig)



@pytest.mark.parametrize('n', [10, 100])
def test_grouping_loss_bias(n):
    """Per-bin and reduced grouping-loss estimates must agree.

    Random features, binary labels and scores are binned with
    ``cluster_evaluate``. Both ``grouping_loss_bias`` and the debiased
    ``grouping_loss_lower_bound`` are evaluated with ``reduce_bin=False`` and
    ``reduce_bin=True``; summing the per-bin output must give the reduced one.
    """
    d = 2
    n_bins = 10
    rs = np.random.RandomState(0)
    X = rs.uniform(size=(n, d))
    y_labels = rs.randint(0, 2, size=n)
    y_scores = rs.uniform(size=n)

    bins = np.linspace(0, 1, n_bins + 1)

    # Calibration statistics (fraction of positives and counts) from
    # cluster_evaluate.
    frac_pos, counts, *_ = cluster_evaluate(X, y_labels, y_scores, bins=bins)

    bias_bin = grouping_loss_bias(frac_pos, counts, reduce_bin=False)
    bias = grouping_loss_bias(frac_pos, counts, reduce_bin=True)

    # nansum: the per-bin bias may contain NaN entries (presumably for sparsely
    # populated cells) that the reduced value skips.
    assert np.allclose(np.nansum(bias_bin), bias)

    lower_bound_bin = grouping_loss_lower_bound(
        frac_pos, counts, reduce_bin=False, debiased=True)
    lower_bound = grouping_loss_lower_bound(
        frac_pos, counts, reduce_bin=True, debiased=True)

    assert np.allclose(np.sum(lower_bound_bin), lower_bound)


def test_grouping_loss_bias_nans():
    """Grouping-loss estimators must cope with sparse and empty cells.

    The hand-built table has one cell with a single sample
    (``counts[0, 0] == 1``) and a fully empty row (``counts[2, :] == 0``).
    ``grouping_loss_bias`` must not emit any warning (warnings are escalated
    to errors), and for both estimators the per-bin output must sum to the
    reduced value.

    The test takes no parameters: the previous ``n`` parametrization was never
    used, so it only ran the identical test twice.
    """
    frac_pos = np.full((3, 2), 0.5)
    counts = np.full_like(frac_pos, 2)
    counts[0, 0] = 1
    counts[2, :] = 0

    # Any warning (e.g. a division by zero on the empty cells) fails the test.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        bias_bin = grouping_loss_bias(frac_pos, counts, reduce_bin=False)
        bias = grouping_loss_bias(frac_pos, counts, reduce_bin=True)

    # nansum: NaN entries are expected in the per-bin output for the
    # degenerate cells and are skipped by the reduced value.
    assert np.allclose(np.nansum(bias_bin), bias)

    lower_bound_bin = grouping_loss_lower_bound(
        frac_pos, counts, reduce_bin=False, debiased=True)
    lower_bound = grouping_loss_lower_bound(
        frac_pos, counts, reduce_bin=True, debiased=True)

    assert np.allclose(np.sum(lower_bound_bin), lower_bound)
