
from typing import Dict, Optional, List, Union, Tuple, Callable, Any
import numpy as np
import pandas as pd
import torch
import tqdm
# import config
import os
import pickle
import evaluate
import transformers
import matplotlib

from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from sentence_transformers import CrossEncoder
import matplotlib.pyplot as plt
import seaborn as sns

# ROUGE metric loaded once at import time via Hugging Face `evaluate` and
# reused by check_answer_correctness, which reads its per-example "rougeL" scores.
ROUGE = evaluate.load('rouge')

def check_answer_correctness(
    correct_answers: List[str],
    model_answers: List[str],
    rouge_threshold: float = 0.3,
) -> List[int]:
    """
    Score each model answer against its reference answer.

    An answer counts as correct when its ROUGE-L score against the reference
    is at least ``rouge_threshold``, or when the upper-cased reference occurs
    as a substring of the upper-cased model answer.

    Parameters
    ----------
    correct_answers: List[str]
        Reference answers.
    model_answers: List[str]
        Model answers, aligned index-by-index with ``correct_answers``.
    rouge_threshold: float
        Minimum ROUGE-L score for an answer to be accepted.

    Returns
    -------
    results: List[int]
        1 for each correct answer and 0 otherwise. Ints rather than bools, so
        callers can sum or average them directly.

    Raises
    ------
    ValueError
        If the two lists have different lengths.
    """
    if len(correct_answers) != len(model_answers):
        # zip() would silently drop the unmatched tail, so fail loudly instead.
        raise ValueError(
            f"got {len(correct_answers)} reference answers but "
            f"{len(model_answers)} model answers"
        )
    if not correct_answers:
        # Nothing to score; skip the metric call entirely.
        return []

    rouge_l_scores = ROUGE.compute(
        predictions=model_answers,
        references=correct_answers,
        use_aggregator=False,  # one score per example, not an aggregate
    )["rougeL"]

    return [
        int(score >= rouge_threshold or reference.upper() in answer.upper())
        for score, reference, answer in zip(
            rouge_l_scores, correct_answers, model_answers
        )
    ]

def loop_dataloader(dataloader: DataLoader):
    """
    Loop through a dataloader infinitely.

    Each pass re-iterates ``dataloader`` from the start, and passes are
    chained together forever.

    Parameters
    ----------
    dataloader: DataLoader
        Dataloader (or any re-iterable object) to be looped through.

    Yields
    ------
    batch: Any
        Each batch produced by ``dataloader``, unchanged.

    Raises
    ------
    ValueError
        If a complete pass over ``dataloader`` produces no batches. Without
        this check the generator would spin forever without yielding.
    """
    while True:
        produced_batch = False
        for batch in dataloader:
            produced_batch = True
            yield batch
        if not produced_batch:
            # An empty or already-exhausted one-shot iterable would otherwise
            # turn this loop into a silent infinite busy-wait.
            raise ValueError("dataloader yielded no batches; cannot loop over it")

def _reliability_bin_stats(confidences, correctness, num_bins):
    """
    Split confidences into equal-width bins and summarise correctness per bin.

    Parameters
    ----------
    confidences: np.ndarray
        Predicted confidence per example.
    correctness: np.ndarray
        Correctness per example (0/1).
    num_bins: int
        Number of equal-width bins covering [0, 1).

    Returns
    -------
    bins: np.ndarray
        Left edges of the ``num_bins`` bins.
    bin_values: np.ndarray
        Mean correctness per bin (0 for empty bins).
    bin_sizes: np.ndarray
        Number of predictions per bin (0 for empty bins).
    """
    # linspace guarantees exactly num_bins edges. A float-step np.arange can
    # emit an extra edge for some step sizes (e.g. 1/49), which would make the
    # edges and per-bin values disagree in length when plotting.
    bins = np.linspace(0.0, 1.0, num_bins, endpoint=False)
    # For confidences >= 0, digitize returns bin ids 1..num_bins; values at or
    # above the last edge (including 1.0) land in bin num_bins. Negative values
    # get id 0 and are dropped by the reindex below.
    bins_per_prediction = np.digitize(confidences, bins)
    df = pd.DataFrame({"y": correctness, "pred_bins": bins_per_prediction})

    per_bin = df.groupby("pred_bins")["y"]
    bin_ids = range(1, num_bins + 1)
    # Reindex so every bin has an entry; empty bins get accuracy 0 and count 0.
    bin_values = per_bin.mean().reindex(bin_ids, fill_value=0).values
    bin_sizes = per_bin.count().reindex(bin_ids, fill_value=0).values
    return bins, bin_values, bin_sizes

def analyze_results(
    all_confidences,
    all_correctness,
    num_bins,
    save_path,
):
    """
    Compute calibration metrics and save a reliability diagram.

    Confidences are split into ``num_bins`` equal-width bins. Each bin is drawn
    as a bar whose height is the mean correctness (accuracy) of its
    predictions. The bar is shaded by, and annotated with, the share of
    predictions falling in that bin. A dashed diagonal marks perfect
    calibration, and the ECE / AUC / Brier values are written on the plot.

    Parameters
    ----------
    all_confidences: array-like of float
        Predicted confidence per example, expected in [0, 1].
    all_correctness: array-like of int or bool
        Per-example correctness (1 / True if correct), aligned with
        ``all_confidences``.
    num_bins: int
        Number of equal-width confidence bins.
    save_path: str or os.PathLike
        Where to write the figure. A trailing ``.csv`` extension is swapped for
        ``.png``; any other path is passed to ``savefig`` unchanged.

    Returns
    -------
    results: Dict[str, float]
        Keys ``acc``, ``ece``, ``auc`` and ``brier``.

    Raises
    ------
    ValueError
        If ``num_bins`` < 1, no predictions are given, or the two inputs have
        different lengths.
    """
    confidences = np.asarray(all_confidences, dtype=float)
    correctness = np.asarray(all_correctness, dtype=float)
    if num_bins < 1:
        raise ValueError(f"num_bins must be >= 1, got {num_bins}")
    if confidences.size == 0:
        raise ValueError("analyze_results needs at least one prediction")
    if len(confidences) != len(correctness):
        raise ValueError(
            f"got {len(confidences)} confidences but "
            f"{len(correctness)} correctness labels"
        )

    # NOTE(review): cal_ece, roc_auc_score and brier_score_loss are not defined
    # or imported in the visible part of this file (the latter two look like
    # sklearn.metrics). Confirm they are in scope at module level.
    ece = cal_ece(all_correctness, all_confidences, num_bins)
    auc = roc_auc_score(all_correctness, all_confidences)
    brier = brier_score_loss(all_correctness, all_confidences)
    acc = float(np.mean(correctness))
    results = {'acc': acc, 'ece': ece, 'auc': auc, 'brier': brier}

    bins, bin_values, bin_sizes = _reliability_bin_stats(
        confidences, correctness, num_bins
    )
    step_size = 1.0 / num_bins
    centers = bins + step_size / 2

    total = bin_sizes.sum()
    # Share of predictions per bin. total is 0 only when every confidence is
    # negative (all fall outside the plotted bins); avoid dividing by zero.
    fractions = bin_sizes / total if total else np.zeros(num_bins)

    fig = plt.figure(figsize=(4, 4), dpi=200)
    try:
        ax = fig.gca()
        ax.grid(visible=True, axis="both", which="major", linestyle=":", color="grey")

        # Darker bars hold more predictions; the +0.2 offset keeps near-empty
        # bins visible and the cap keeps the value inside the colormap range.
        cmap = plt.get_cmap("Blues")
        bar_colors = [cmap(min(0.9999, frac + 0.2)) for frac in fractions]
        ax.bar(
            centers,
            bin_values,
            width=0.9 * step_size,  # leave a small gap between adjacent bars
            alpha=0.8,
            color=bar_colors,
            edgecolor="black",
        )
        # Perfect-calibration reference line.
        ax.plot([0, 1], [0, 1], color="black", alpha=0.4, linestyle="--")

        # Label every bar with the percentage of predictions in its bin.
        for center, frac in zip(centers, fractions):
            bin_percentage = round(frac * 100, 2)
            ax.annotate(
                f"{bin_percentage} %",
                xy=(center, 0.5),
                ha="center",
                va="top",
                rotation=90,
                color="black",
                alpha=0.7 if bin_percentage > 40 else 0.8,
                fontsize=10,
            )

        ax.text(0.4, 0.8, f'ece:{ece:.4f} \nauc:{auc:.4f} \nbrier:{brier:.4f}')
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.set_xlabel("Confidence", fontsize=14, alpha=0.8)
        ax.set_ylabel("Accuracy", fontsize=14, alpha=0.8)
        fig.tight_layout()

        # Only swap a trailing .csv extension; .replace() would also rewrite
        # ".csv" occurring elsewhere in the path.
        root, ext = os.path.splitext(save_path)
        if ext == ".csv":
            save_path = root + ".png"
        fig.savefig(save_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    return results