import hapi
from collections import defaultdict
import pandas as pd
import numpy as np
import argparse
import pickle
from pathlib import Path
from homogenization import Homogenization
from plotting import count_combinations, sort_col_row

# Substrings of API names to drop; generate_prediction_df excludes every API whose name contains one of these.
MODELS_TO_REMOVE = [  # We remove open source models because we want to focus on deployed models that are available through an API from a tech company.
    "vgg19",  # Open source model
    "deepspeech",  # Open source model
    "github",  # Open source model
    "vader",  # Open source model
    "bixin",  # Only has 1 year of data and open source model
]
# Two-digit year strings used as keys into the per-year prediction data; main() verifies that they are consecutive.
YEARS = ["20", "21", "22"]
SINGLE_LABEL_TASKS = [
    "sa",
    "scr",
    "fer",
]  # We only conduct analyses for tasks that have a single ground truth label. Focusing on these tasks ensures that each instance is either correct or incorrect, so we don't have to accommodate notions of partial incorrectness.


def main(
    path,
    only_single_label=False,
    homogenization_by_year=False,
    homogenization_by_hardness=False,
    polarization=False,
    leader_following=False,
):
    """Validate YEARS and run the requested analyses.

    Parameters:
    - path: name of the results sub-directory, forwarded to the analyses.
    - only_single_label: forwarded to the analyses to restrict the tasks used.
    - homogenization_by_year / homogenization_by_hardness: select which
      homogenization analyses to run (either one triggers homogenization_results).
    - polarization: if True, run measure_polarization_on_errors.
    - leader_following: forwarded to homogenization_results.

    Raises:
    Exception if the entries of YEARS are not consecutive integers.
    """
    # Every year must be exactly one greater than the one before it.
    for previous_year, current_year in zip(YEARS, YEARS[1:]):
        if int(current_year) != int(previous_year) + 1:
            raise Exception("YEARS must be sequential")

    if polarization:
        measure_polarization_on_errors(path, only_single_label)

    wants_homogenization = homogenization_by_year or homogenization_by_hardness
    if not wants_homogenization:
        return
    homogenization_results(
        path,
        only_single_label,
        homogenization_by_year,
        homogenization_by_hardness,
        leader_following,
    )


"""
Calculate homogenization results and save them to specified directory.

Parameters:
- path (str): path to directory to save the results
- homogenization_by_year (boolean): If True, generates "Finding #2" from our paper, "Ecosystem-level Analysis of Deployed ML ..."
- homogenization_by_hardness (boolean): if true, generates findings in appendix section A.3 in our paper.
- only_single_label: If true, will only generate results for datasets where instances have a single ground truth label.
- leader_following (boolean): If True, uses a slightly different way of formulating 'hardness' whereby hard examples are determined by whatever the best model in a system fails at.

Returns:
None

The function calculates the homogenization results for a range of tasks and datasets using the `calculate_homogenization_by_year` and/or `calculate_homogenization_by_hardness` functions. 
It then creates the results directory derived from the path parameter, if it does not already exist, and saves the results, error rates, and histograms 
to pickle files in that directory. If the directory already exists, it is reused and any existing pickle files in it are overwritten."""


def homogenization_results(
    path,
    only_single_label,
    homogenization_by_year,
    homogenization_by_hardness,
    leader_following,
):
    """Run the selected homogenization analyses and pickle their outputs.

    For each selected analysis ("by year" and/or "by hardness") the outputs are
    written to ../results/<path>/<analysis>/ as homogenization_results.pickle,
    error_rates.pickle and histograms.pickle. The directory is created when
    missing and reused otherwise.
    """
    requested = (
        ("by year", homogenization_by_year),
        ("by hardness", homogenization_by_hardness),
    )
    tests = [name for name, enabled in requested if enabled]

    for test in tests:
        results_dir = Path("../results") / path / test
        results_dir.mkdir(parents=True, exist_ok=True)

        task_dataset_dict = generate_task_dataset_dict(
            only_single_label=only_single_label
        )
        if test == "by year":
            outcome = calculate_homogenization_by_year(task_dataset_dict)
        else:  # "by hardness" is the only other value placed in tests
            outcome = calculate_homogenization_by_hardness(
                task_dataset_dict, leader_following=leader_following
            )
        all_results, all_error_rates, all_histograms = outcome

        # (message label, file stem, object to pickle), written in this order.
        outputs = (
            ("Results", "homogenization_results", all_results),
            ("Error rates", "error_rates", all_error_rates),
            ("Histograms", "histograms", all_histograms),
        )
        for label, stem, payload in outputs:
            target = (results_dir / stem).with_suffix(".pickle")
            with open(target, "wb") as handle:
                pickle.dump(payload, handle)
                print(f"{label} saved to {handle.name}")


"""
This function will generate the distribution of other-model outcomes for instances, conditional on observing model $h_i$ failing.
This distribution is compared to distribution of other-model outcomes without conditioning on model $h_i$ failing.
Parameters:
- dir_name: the name of the directory to save results to. Results are saved as a dict with the following form:
    {(task, dataset, year, model, model_ranking, simulation_type): {system_level_outcome: probability(sys_level_outcome)}}
- only_single_label: if true, only calculate results for tasks with a single ground truth label

"""


def measure_polarization_on_errors(dir_name, only_single_label=True):
    """Compute other-model outcome distributions, with and without conditioning on a model failing.

    For every (task, dataset, year, model) two distributions over the outcomes
    of the remaining models are stored: one restricted to the instances where
    the model failed (outcome == 1) and one over all instances. The resulting
    dict is pickled to ../results/<dir_name>/polarization/results.pickle.
    """
    task_dataset_dict = generate_task_dataset_dict(only_single_label=only_single_label)
    results_dir = Path("../results") / dir_name / "polarization"
    results_dir.mkdir(parents=True, exist_ok=True)
    output_path = (results_dir / "results").with_suffix(".pickle")

    polarization_results = {}
    for task in task_dataset_dict:
        print(f"*** TASK: {task} ***")
        for dataset in task_dataset_dict[task]:
            print(f"*** DATASET: {dataset} ***")
            R = generate_prediction_df(task, dataset, process=True)
            if R is None:
                continue
            for year in YEARS:
                year_preds, year_error_rates = generate_year_predictions(
                    R, year, return_numpy=False
                )
                if year_preds is None:
                    continue
                # Models ordered by ascending error rate; a model's position is its ranking.
                ranked_models = year_error_rates.sort_values().index
                for model in year_preds:
                    model_ranking = ranked_models.get_loc(model)
                    failed_rows = year_preds[model] == 1
                    # Keys are long strings because that is convenient when visualizing these results
                    scenarios = (
                        ("Conditional on model failing", year_preds[failed_rows]),
                        ("Not conditional on model failing", year_preds),
                    )
                    for simulation_type, outcomes in scenarios:
                        key = (
                            task,
                            dataset,
                            year,
                            model,
                            model_ranking,
                            simulation_type,
                        )
                        # The observed model itself is excluded from the system-level outcome.
                        polarization_results[key] = generate_combinations_data(
                            outcomes.drop(model, axis=1)
                        )

    with open(output_path, "wb") as handle:
        pickle.dump(polarization_results, handle)


"""
This function takes in a dataframe of outcomes where each row is an input instance 
and each column is the outcome from a single model. 
It returns a dictionary of the form: {system-level-outcome: probability of system-level-outcome}

"""


def generate_combinations_data(R):
    """Map each system-level outcome found in R to the proportion of rows showing it.

    R is converted to a numpy array, passed through sort_col_row, and the
    resulting outcome combinations are counted with count_combinations.
    """
    sorted_outcomes, _error_rates, _model_order = sort_col_row(np.array(R))
    outcomes, _, counts, _ = count_combinations(sorted_outcomes)
    total = sum(counts)
    return dict(zip(outcomes, counts / total))


"""
Given a string for the relevant task and dataset, this function identifies models that experience change over time on that dataset.
For models that experience change, it includes the outcome matrices for each year in the outcome_matrices dict, the error rates for each model for each year in error_rates,
and the change in error rates for each model in error_changes_over_time.

Parameters:
- task: the task to analyze (string)
- dataset: the dataset to analyze (string)
- threshold: the threshold for change in error rate. If the absolute change in a model's error rate between two years exceeds this threshold, we consider that model to have experienced change over time.

Returns:
- outcome_matrices: a dict of the form {(year: matrix)}
- error_rates: a dict of the form {(year: {model: error_rate})}. This will include error rates for all models and all years, not just models that experience change over time.
- error_changes_over_time: a dict of the form {(model: {(year1, year2): error_rate_change})}. Only includes error rate changes for models that experience change over time.
"""


def identify_change_over_time(task, dataset, threshold=0.005):
    """Find models whose error rate moves by more than `threshold` between years.

    Returns (outcome_matrices, error_rates, error_changes_over_time); all three
    are empty when no prediction dataframe is available for the dataset.
    """
    outcome_matrices = {}  # year -> outcome dataframe
    error_rates = {}  # year -> per-model error rates
    error_changes_over_time = defaultdict(dict)  # model -> {(year1, year2): change}

    R = generate_prediction_df(task, dataset, process=True)
    if R is None:
        return outcome_matrices, error_rates, error_changes_over_time

    for position, this_year in enumerate(YEARS[:-1]):
        next_year = YEARS[position + 1]
        this_preds, this_errors = generate_year_predictions(
            R, this_year, return_numpy=False
        )
        next_preds, next_errors = generate_year_predictions(
            R, next_year, return_numpy=False
        )

        # If the following year is missing, fall back to the year after it (when there is one).
        fallback_position = position + 2
        if next_preds is None and fallback_position < len(YEARS):
            next_year = YEARS[fallback_position]
            next_preds, next_errors = generate_year_predictions(
                R, next_year, return_numpy=False
            )

        if this_preds is None or next_preds is None:
            continue

        error_rates[this_year] = this_errors
        error_rates[next_year] = next_errors
        for model_name, change in (next_errors - this_errors).items():
            if abs(change) > threshold:
                error_changes_over_time[model_name][(this_year, next_year)] = change
                outcome_matrices[this_year] = this_preds
                outcome_matrices[next_year] = next_preds

    return outcome_matrices, error_rates, error_changes_over_time


"""
    Generate a dictionary of tasks and datasets from the HAPI repository.
    
    Parameters:
    - task_list (list): list of tasks to include in dictionary. If None, includes all tasks.
    - only_single_label (bool): if true, only adds single_label datasets to dictionary, otherwise adds all datasets.
    
    Returns:
    - task_dataset_dict (dict): dictionary containing the tasks and datasets
    
    The function generates a dictionary of tasks and datasets using the HAPI dataset. It first downloads the HAPI dataset if 
    it does not already exist, and then generates a summary dataframe containing the tasks and datasets. The function then 
    creates a dictionary of tasks and datasets by grouping the datasets by task."""


def generate_task_dataset_dict(task_list=None, only_single_label=False):
    """Build a {task: datasets} mapping from the HAPI summary table.

    The HAPI data is downloaded into ../data when ../data/tasks does not exist.
    When task_list is None every task in the summary is used. With
    only_single_label=True, tasks outside SINGLE_LABEL_TASKS are skipped.
    """
    data_dir = Path("../data")
    hapi.config.data_dir = str(data_dir)
    if not (data_dir / "tasks").exists():
        hapi.download()
    summary = hapi.summary()

    tasks = set(summary["task"]) if task_list is None else task_list

    task_dataset_dict = defaultdict(set)
    for task in tasks:
        if only_single_label and task not in SINGLE_LABEL_TASKS:
            print(f"Skipping multi-label task {task}")
            continue
        belongs_to_task = summary["task"] == task
        task_dataset_dict[task] = summary.loc[belongs_to_task, "dataset"].unique()

    return task_dataset_dict


"""
Generate homogenization results for a range of tasks and datasets.

Parameters:
- task_dataset_dict (dict): dictionary containing the tasks and datasets to generate the results for
- years (list): list of years to calculate the metrics for (default: YEARS)

Returns:
- all_results (dict): dictionary containing the results for all tasks and datasets
- all_error_rates (dict): dictionary containing the error rates for all tasks and datasets
- all_histograms (dict): dictionary containing the histograms for all tasks and datasets

The function generates the homogenization results for a range of tasks and datasets. It first generates the prediction 
dataframe for each task and dataset, filters out the models specified in the models_to_exclude list, and processes the 
prediction data to obtain the rejection data. It then calls the `measure_per_year` function to calculate the 
homogenization metrics for a range of years and stores the results, error rates, and histograms in the all_results, 
all_error_rates, and all_histograms dictionaries, respectively. """


def calculate_homogenization_by_year(task_dataset_dict, years=YEARS):
    """Run measure_per_year for every (task, dataset) pair and collect the outputs.

    Returns (all_results, all_error_rates, all_histograms). all_results is
    filled in place by measure_per_year; the other two are keyed by
    (task, dataset). Datasets without a prediction dataframe are skipped.
    """
    all_results = {}
    all_error_rates = defaultdict(dict)
    all_histograms = defaultdict(dict)

    for task, datasets in task_dataset_dict.items():
        print(f"*** TASK: {task} ***")
        for dataset in datasets:
            print(f"*** DATASET: {dataset} ***")
            R = generate_prediction_df(task, dataset, process=True)
            if R is None:
                continue
            pair = (task, dataset)
            all_error_rates[pair], all_histograms[pair] = measure_per_year(
                R, years, task, dataset, all_results
            )

    return all_results, all_error_rates, all_histograms


"""
Calculates the grid search over alpha and Delta and saves results to all_results. Additional details on the 
alpha and deltas abstraction is available in homogenization.py.

leader_following: If true, this specifies an alpha and delta regime whereby everything that the most accurate model 
    fails at is expected to be significantly harder for the less accurate models and everything that the most accurate
    model succeeds at is expected to be easier for the less accurate models.
"""


def calculate_homogenization_by_hardness(
    task_dataset_dict,
    leader_following=False,
):
    """Run the alpha/delta grid search for every (task, dataset) pair.

    Parameters:
    - task_dataset_dict (dict): maps each task to the datasets to analyze.
    - leader_following (bool): if True, alpha is set to the error rate of the
      most accurate model and each delta becomes a tuple whose first entry is
      derived from that error rate.

    Returns:
    - all_results (dict): {(task, dataset): result of measure_homogenization}
    - all_error_rates (dict): {(task, dataset): error rates sorted ascending}
    - all_histograms (dict): {(task, dataset): histograms}, when any were returned

    Each dataset is evaluated on the last entry of YEARS, except "afnet" which
    uses "21". The year is chosen per dataset so the afnet override does not
    carry over to the datasets processed after it. Datasets with no prediction
    dataframe, or with no data for the chosen year, are skipped.
    """
    specified_max_error_rates = []
    delta_lower_bound = 0.2
    delta_upper_bound = 5.0
    delta_step = 0.2

    # Express the bounds as whole multiples of the step; round() guards against float error.
    first_multiple = round(delta_lower_bound / delta_step)
    last_multiple = round(delta_upper_bound / delta_step)
    deltas = [
        round(i * delta_step, 2) for i in range(first_multiple, last_multiple + 1)
    ]
    print(deltas)
    default_year = YEARS[-1]
    default_alpha_range = (0.1, 0.5)  # hard fraction
    alpha_step = 0.1

    all_results = {}
    all_histograms = defaultdict(dict)
    all_error_rates = defaultdict(dict)

    for task in task_dataset_dict:
        print("*** TASK: {} ***".format(task))
        for dataset in task_dataset_dict[task]:
            print("*** DATASET: {} ***".format(dataset))
            R = generate_prediction_df(task, dataset, process=True)
            if R is None:
                continue
            # afnet lacks full data for 2022, so it is evaluated on 2021 instead.
            year = "21" if dataset == "afnet" else default_year
            year_data = generate_year_predictions(R, year, generate_I_year=True)
            # A missing year is reported with None as the first element of the returned
            # tuple, whatever its length, so check before unpacking three values.
            if year_data[0] is None:
                continue
            R_year_np, error_rates, I_year_np = year_data
            # NOTE(review): the error rates are sorted ascending here while the columns of
            # R_year_np / I_year_np keep their original order -- confirm that
            # measure_homogenization does not rely on positional alignment.
            sorted_error_rates = sorted(error_rates)
            all_error_rates[(task, dataset)] = sorted_error_rates

            num_users_per_model = np.sum(I_year_np, axis=0)
            num_users = np.max(num_users_per_model)
            num_models = np.count_nonzero(num_users_per_model)

            t = 0
            metrics = ["ProdExp"]
            # alpha_range and the deltas may be modified below; the modifications must only
            # persist within the context of a single dataset, so start from fresh values.
            alpha_range = default_alpha_range
            dataset_deltas = deltas.copy()
            if leader_following:
                failed_by_best_model = sorted_error_rates[0]
                alpha_range = (
                    failed_by_best_model,
                    failed_by_best_model,
                )  # Alpha will be set to the error rate of the best model
                # NOTE(review): a best-model error rate of 0 makes the division below
                # degenerate -- confirm this cannot occur for the datasets in use.
                dataset_deltas = [
                    (1 / failed_by_best_model - 1, delta) for delta in dataset_deltas
                ]  # Use user provided delta values for all models except best one which we force to have error rate of 1
            homog = Homogenization()
            result = homog.measure_homogenization(
                I_year_np,
                R_year_np,
                num_users,
                num_models,
                list(sorted_error_rates),
                t=t,
                threshold_type="absolute",
                alpha_range=alpha_range,
                alpha_step=alpha_step,
                deltas=dataset_deltas,
                specified_max_error_rates=specified_max_error_rates,
                metrics=metrics,
                verbose=True,
            )
            histogram = result.pop("Histograms", None)
            if histogram is not None:
                all_histograms[(task, dataset)] = histogram
            all_results[(task, dataset)] = result
    return all_results, all_error_rates, all_histograms


"""
Parameters:
    - task (str): Name of task
    - dataset (str): Name of dataset
    - process (bool): Whether or not to further modify the dataframe such that each row is an instance and each column is a model prediction outcome.
        It also replaces "True" and "False" with 0 and 1, respectively, so that model failures are encoded as a 1. This processing is required when using our homogenization code
    - models_to_exclude: Will drop these models from the prediction dataframe.

Returns: 
    - df_preds(pd.DataFrame): Model predictions and correct or not correct label from Hapi repository.


Generates prediction dataframe from Hapi repository. Calculates 'is_correct' based on 100% match between prediction and ground truth.
This breaks down in multi-label setting, but is sufficient in the single-label scenario.
"""


def generate_prediction_df(
    task, dataset, process=True, models_to_exclude=MODELS_TO_REMOVE
):
    """Assemble the HAPI predictions for (task, dataset) into one dataframe.

    Each prediction row gets its ground-truth label, an `is_correct` flag
    (exact match between prediction and ground truth), and `date`, `year` and
    `api` fields parsed from the prediction key. APIs whose name contains any
    entry of models_to_exclude are dropped. With process=True the result of
    prediction_to_failure_df is returned instead of the raw dataframe.
    """
    hapi_labels = hapi.get_labels(task, dataset)
    label_keys = list(hapi_labels.keys())
    # Ground truth is read from the first entry of the labels mapping.
    true_label_by_example = {}
    for entry in hapi_labels[label_keys[0]]:
        true_label_by_example[entry["example_id"]] = entry["true_label"]

    res = hapi.get_predictions(task=task, dataset=dataset)
    frames = [pd.DataFrame(res[key]).assign(api=key) for key in res]
    df_preds = pd.concat(frames)

    def _gt_label(row):
        return true_label_by_example[row["example_id"]]

    def _is_correct(row):
        return row["predicted_label"] == row["gt_label"]

    def _date(row):
        # The date is the last "/"-separated component of the prediction key.
        return row["api"].split("/")[-1]

    def _year(row):
        return row["date"].split("-")[0]

    def _api_name(row):
        # The API name is the second-to-last component of the prediction key.
        return row["api"].split("/")[-2]

    df_preds["gt_label"] = df_preds.apply(_gt_label, axis=1)
    df_preds["is_correct"] = df_preds.apply(_is_correct, axis=1)
    df_preds["date"] = df_preds.apply(_date, axis=1)
    df_preds["year"] = df_preds.apply(_year, axis=1)
    df_preds["api"] = df_preds.apply(_api_name, axis=1)

    def _is_kept(api_name):
        return not any(excluded in api_name for excluded in models_to_exclude)

    providers = [api_name for api_name in set(df_preds["api"]) if _is_kept(api_name)]
    df_preds = df_preds[df_preds["api"].isin(providers)]

    return prediction_to_failure_df(df_preds) if process else df_preds


"""
Parameters: df_preds (pd.DataFrame): Prediction dataframe generated from generate_prediction_df function

Returns: A modified version of the prediction dataframe with each row representing a user and each column representing a model outcome 
    (failure: encoded as a 1,  not failure: encoded as a 0).
"""


def prediction_to_failure_df(df_preds):
    """Pivot predictions into a (year, example_id) x api table of failure flags.

    A correct prediction (is_correct True) becomes 0 and an incorrect one
    becomes 1. Returns None, after printing the error, when the pivot raises
    ValueError.
    """
    row_keys = ["year", "example_id"]
    correct_to_failure = {True: 0, False: 1}  # failures are encoded as 1 for the game
    try:
        df_failure = df_preds.pivot(index=row_keys, columns="api", values="is_correct")
    except ValueError as ve:
        print("Cannot measure homogenization for task; error message: ", ve)
        return None
    else:
        df_failure.replace(correct_to_failure, inplace=True)
        return df_failure


def prediction_to_confidence_df(df_preds):
    """Summarize correctness and confidence per (year, example_id).

    Parameters:
    - df_preds (pd.DataFrame): prediction dataframe with `year`, `example_id`,
      `is_correct` and `confidence` columns.

    Returns:
    - pd.DataFrame indexed by (year, example_id) with the columns
      ("is_correct", "sum"), ("confidence", "mean") and ("confidence", "std").
    """
    # Aggregations are named by string rather than passed as numpy callables: pandas
    # deprecated numpy callables in groupby.agg, and the string names select the same
    # groupby implementations while keeping the same column labels.
    df_confidence = df_preds.groupby(["year", "example_id"]).agg(
        {"is_correct": "sum", "confidence": ["mean", "std"]}
    )
    return df_confidence


""" Takes the dataframe containing predictions for all years and returns a filtered version of the dataframe with only predictions
    from the specified year. Also returns error rates for that year.
Parameters:
- R_all_years (pandas DataFrame): dataframe containing the failure data for all years. A 1 represents the model failed for that user.
- year: a string corresponding to the year of predictions to generate. 
- return_numpy (optional bool): If false, returns as pandas dataframe. If true, returns as numpy array
- generate_I_year (optional bool): If true, will generate the interaction matrix for the given year 

Returns:
- R_year_np (numpy array): numpy array containing the failure data for a single year
- error_rates (list): list containing the error rates each model for that given year 
- if generate_I_year is true, I_year_np: numpy array containing interaction data for a single year"""


def generate_year_predictions(
    R_all_years, year, return_numpy=True, generate_I_year=False
):
    """Select one year of failure data and compute the per-model error rates.

    Parameters:
    - R_all_years (pd.DataFrame): failure data for all years; `year` is looked
      up with `.loc` on its index.
    - year (str): the year to select.
    - return_numpy (bool): if True, the matrices are returned as integer numpy
      arrays; otherwise as pandas dataframes.
    - generate_I_year (bool): if True, also return the interaction matrix.

    Returns:
    - (R_year, error_rates) or, when generate_I_year is True,
      (R_year, error_rates, I_year). When there is no data for `year`, every
      element is None and the tuple has the same length as on success, so
      callers can always unpack the number of values they asked for.
    """
    try:
        R_year = R_all_years.loc[year]
    except KeyError:  # possible that we dont have data for that year
        print("No data for year", year)
        # Match the arity of the success path so three-way unpacking does not fail.
        return (None, None, None) if generate_I_year else (None, None)
    R_year = R_year.dropna(
        axis=1, how="all"
    )  # remove columns where all values are NaN (eg. no model interaction)

    # Sum of rejections over number of predictions. Calculates for each API in a given year
    error_rates = R_year.apply(lambda x: x.sum() / x.count())

    I_year = None
    if generate_I_year:
        # Interaction matrix: 1 wherever an outcome was recorded, 0 where it is NaN
        # (np.nan represents no interaction).
        I_year = R_year.replace({0: 1, np.nan: 0}).astype(int)
        if return_numpy:
            I_year = I_year.to_numpy().astype(int)
    if return_numpy:
        # NOTE(review): any NaN remaining in R_year has no well-defined integer value --
        # confirm that partially missing columns cannot reach this conversion.
        R_year = R_year.to_numpy().astype(int)

    if generate_I_year:
        return R_year, error_rates, I_year
    return R_year, error_rates


""" Calculate homogenization metrics for a range of years for MinExp and ProdExp metrics

Parameters:
- R_all_years (pandas DataFrame): dataframe containing the failure data for all years. A 1 represents the model failed for that user.
- years (list): list of years to calculate homogenization for
- task (str): task name
- dataset (str): dataset name
- results_dict (dict): dict to store the results
- verbose (bool): flag to print intermediate results (default: True)
- by_threshold (bool): flag to calculate the metrics for a range of thresholds (default: False)

Returns:
- error_rates_dict (dict): dictionary containing the error rates for each year
- histograms (dict): dictionary containing the histograms for each year

The function calculates the homogenization metrics for a range of years. For each year, it calculates the error rates 
for each model and stores them in the error_rates_dict dictionary. If the by_threshold flag is set to True, it also 
calculates the homogenization metrics for a range of thresholds using the `measure_per_threshold` function and stores 
the results in results_dict. If the by_threshold flag is set to False, it calculates the homogenization 
metrics for threshold 0 and stores the results in the results_dict dataframe.
"""


def measure_per_year(
    R_all_years, years, task, dataset, results_dict, verbose=True, by_threshold=False
):
    """Calculate homogenization metrics for each year in `years`.

    Parameters:
    - R_all_years (pd.DataFrame): failure data for all years.
    - years (list): years to calculate homogenization for.
    - task (str), dataset (str): used to build the keys written to results_dict.
    - results_dict (dict): filled in place with
      {(task, dataset, year, threshold, metric): values}.
    - verbose (bool): forwarded to measure_homogenization when by_threshold is False.
    - by_threshold (bool): if True, delegate to measure_per_threshold; otherwise
      only threshold 0 is measured, for the MinExp and ProdExp metrics.

    Returns:
    - error_rates_dict (dict): {year: error rates}
    - histograms (dict): {year: histograms}, when any were returned

    Years with no data are skipped and do not appear in either returned dict.
    """
    error_rates_dict = {}
    histograms = {}
    for year in years:
        year_data = generate_year_predictions(
            R_all_years, year, generate_I_year=True
        )
        # A missing year is reported with None as the first element of the returned
        # tuple, whatever its length, so check before unpacking three values.
        if year_data[0] is None:
            continue
        R_year_np, error_rates, I_year_np = year_data
        error_rates_dict[year] = error_rates

        num_users_per_model = np.sum(I_year_np, axis=0)
        num_users = np.max(num_users_per_model)
        num_models = np.count_nonzero(num_users_per_model)

        results_key = (task, dataset, year)
        hist_key = year
        if by_threshold:
            measure_per_threshold(
                I_year_np,
                R_year_np,
                num_users,
                num_models,
                error_rates,
                results_dict,  # results saved into results dict
                results_key,
                histograms,
                hist_key,
            )
        else:
            threshold = 0
            metrics = ["MinExp", "ProdExp"]
            homog = Homogenization()
            result = homog.measure_homogenization(
                I_year_np,
                R_year_np,
                num_users,
                num_models,
                list(error_rates),
                t=threshold,
                threshold_type="absolute",
                metrics=metrics,
                verbose=verbose,
            )
            histogram = result.pop("Histograms", None)
            # Identity check: "!= None" would invoke the object's own comparison.
            if histogram is not None:
                histograms[hist_key] = histogram
            for metric, vals in result.items():
                key = results_key + (threshold, metric)
                results_dict[key] = vals
    return error_rates_dict, histograms


""" Calculate homogenization metrics for a range of thresholds between 0 and (num_models - 1).
    
    Parameters:
    - I (np.array): Interaction matrix
    - R (np.array): Rejection matrix
    - num_users (int): number of users
    - num_models (int): number of models
    - error_rates (list): list of error rates for each model. Should correspond to ordering of columns in R and I
    - transformations (dict): dict of transformations to pass to measure_homogenization (see Homogenization for details)
    - results_dict (dict): dictionary to store the results
    - results_key (tuple): tuple containing keys for the results dictionary.
    - hist_dict (dict): dictionary to store the histograms
    - hist_key (tuple): tuple containing the keys for the histograms dictionary
    
    Returns:
    None. Results are saved in the results_dict that is passed in. The function will append threshold and metric to the 
    passed in results key and then assign the homogenization results as the value to that key.
    
    The function calculates the homogenization metrics ProdExp and MinExp for a range of thresholds. 
    For threshold 0, it also calculates and stores the histograms in the hist_dict dictionary using the hist_key as the key. 
    The results and histograms are stored in the results_dict and hist_dict dictionaries, respectively, using the results_key 
    and hist_key as the keys. """


def measure_per_threshold(
    I,
    R,
    num_users,
    num_models,
    error_rates,
    results_dict,
    results_key,
    hist_dict,
    hist_key,
):
    """Calculate homogenization metrics for every threshold in range(num_models).

    Parameters:
    - I, R: interaction and rejection matrices, forwarded to measure_homogenization.
    - num_users (int), num_models (int): forwarded to measure_homogenization;
      num_models also bounds the thresholds.
    - error_rates: per-model error rates, passed on as a list.
    - results_dict (dict): filled in place with
      {results_key + (threshold, metric): values}.
    - results_key (tuple): prefix of every key written to results_dict.
    - hist_dict (dict): receives the histograms under hist_key when they are returned.
    - hist_key: key used in hist_dict.

    Returns:
    None. When measure_homogenization returns None for a threshold, an empty
    dict is recorded for each metric that was requested at that threshold.
    """
    homog = Homogenization()
    for threshold in range(num_models):
        if threshold == 0:  # Return histograms and calculate ProdExp and MinExp
            metrics = ["MinExp", "ProdExp", "ProdExp_rootk"]
            verbose = True
        else:
            metrics = ["ProdExp", "ProdExp_rootk"]
            verbose = False  # Won't return histogram

        result = homog.measure_homogenization(
            I,
            R,
            num_users,
            num_models,
            list(error_rates),
            metrics=metrics,
            t=threshold,
            threshold_type="absolute",
            verbose=verbose,
            sample_failures=False,
        )
        if result is None:
            # No result for this threshold: record an empty entry per requested metric.
            # Iterating over `metrics` avoids relying on a loop variable from below that
            # is undefined on the first pass and stale on later ones.
            for metric in metrics:
                results_dict[results_key + (threshold, metric)] = {}
            continue

        histogram = result.pop("Histograms", None)
        if histogram is not None:
            hist_dict[hist_key] = histogram
        for metric, vals in result.items():
            results_dict[results_key + (threshold, metric)] = vals


"""
    Command line has one positional argument, dir_name, which is the name that will be used for the directory within results.
    It also takes the optional flags --by_year, --by_hardness, --leader_following and --polarization_on_failures, which select the analyses to run.
    Results are always calculated for single-label tasks only (only_single_label=True is passed to main).
"""
if __name__ == "__main__":
    # Build the CLI: one positional directory name plus boolean switches for each analysis.
    cli = argparse.ArgumentParser()
    cli.add_argument("dir_name", help="name to be used for naming results folders")
    for switch in (
        "--by_year",
        "--by_hardness",
        "--leader_following",
        "--polarization_on_failures",
    ):
        cli.add_argument(switch, action="store_true")
    options = cli.parse_args()

    print("Running analysis on hapi datasets...")
    # Only single-label tasks are analyzed when run from the command line.
    main(
        options.dir_name,
        only_single_label=True,
        homogenization_by_year=options.by_year,
        homogenization_by_hardness=options.by_hardness,
        leader_following=options.leader_following,
        polarization=options.polarization_on_failures,
    )
    print("All results generated")
