from tqdm import tqdm
import pandas as pd
import numpy as np
from pathlib import Path
import os
from entry import Individual, Role
from homogenization import measure_homogenization
from correspondence_configs import JLEConfig, JPEConfig, LabourEconomicsConfig
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.ticker import PercentFormatter


# Global matplotlib font size for every figure produced by this script.
plt.rcParams["font.size"] = 11

# Dataset name -> configuration object describing that correspondence dataset.
configs = dict(
    jle=JLEConfig(),
    jpe=JPEConfig(),
    laboureconomics=LabourEconomicsConfig(),
)


# Returns pd DataFrame for dataset
def load_dataset(dataset_name):
    path = Path("data")
    dataset_file = dataset_name + ".csv"
    path = path / "correspondence" / dataset_file
    encoding = "cp1252"
    df = pd.read_csv(path, encoding=encoding)
    return df


def group(df, identifiers, metadata_identifiers, data_type):
    """Group the rows of ``df`` into unique entries keyed by ``identifiers``.

    Args:
        df: DataFrame with one row per application.
        identifiers: column names that together identify an entry. They are
            sorted first, so the key tuple follows sorted column order.
        metadata_identifiers: columns copied into each entry's metadata. The
            values come from the first row in which the entry appears.
        data_type: "individual" to build ``Individual`` entries, "role" to
            build ``Role`` entries.

    Returns:
        (id2entry, index2id, index_counter) where
        id2entry maps identifier tuple -> unique entry,
        index2id maps entry index -> identifier tuple, and
        index_counter is the number of unique entries.
        Indices are assigned 0, 1, 2, ... in order of first appearance.

    Raises:
        NotImplementedError: if ``data_type`` is not "individual" or "role".
    """
    identifiers = sorted(identifiers)
    id2entry, index2id = {}, {}
    index_counter = 0
    for _, row in df.iterrows():
        entry_id = tuple(row[identifier] for identifier in identifiers)
        entry = id2entry.get(entry_id)
        if entry is None:
            index2id[index_counter] = entry_id
            metadata = {name: row[name] for name in metadata_identifiers}
            if data_type == "individual":
                entry = Individual(entry_id, index_counter, identifiers, metadata)
            elif data_type == "role":
                entry = Role(entry_id, index_counter, identifiers, metadata)
            else:
                raise NotImplementedError(
                    "Unsupported data_type: {!r}".format(data_type)
                )
            id2entry[entry_id] = entry
            index_counter += 1
        # Count every row belonging to this entry, not just the first one.
        # NOTE(review): assumes Individual/Role start with instances == 0 --
        # confirm in entry.py.
        entry.instances += 1
    return id2entry, index2id, index_counter


def create_AO_dict(
    df,
    individual_id2individual,
    individual_identifiers,
    role_id2role,
    role_identifiers,
    application_identifier,
    outcome_identifier,
):
    """Map (individual index, role index) to that pair's (application, outcome).

    Keys are the ``.index`` attributes of the individual/role entries, not the
    entries themselves. Identifier columns are sorted before the lookup tuples
    are built. ``application_identifier`` and ``outcome_identifier`` are
    callables applied to each row.

    If several rows share the same individual-role pair, the last row wins.
    The number of such duplicates is printed.
    """
    individual_columns = sorted(individual_identifiers)
    role_columns = sorted(role_identifiers)
    pair2result = {}
    duplicates = 0
    for _, row in df.iterrows():
        individual_key = tuple(row[column] for column in individual_columns)
        role_key = tuple(row[column] for column in role_columns)
        pair = (
            individual_id2individual[individual_key].index,
            role_id2role[role_key].index,
        )
        # A bool adds as 0/1, so this counts pairs that were already recorded.
        duplicates += pair in pair2result
        pair2result[pair] = (application_identifier(row), outcome_identifier(row))
    print("Number of duplicates (i.e. same pair of individual-role): ", duplicates)
    return pair2result


def assert_valid_matrices(A, O):
    """Check that application matrix A and outcome matrix O are consistent.

    Every cell of A must be 0 (no application) or 1 (applied). Where A is 0,
    O must be 0. Where A is 1, O must be -1 or 1.

    The checks are explicit ``raise`` statements rather than ``assert``, so
    they still run under ``python -O``. Problems are reported for the first
    offending cell in row-major order, matching a cell-by-cell scan.

    Raises:
        AssertionError: if the shapes differ, or an O entry disagrees with A.
        ValueError: if A is not 2-D, or contains a value other than 0 or 1.
    """
    A, O = np.asarray(A), np.asarray(O)
    if A.shape != O.shape:
        raise AssertionError(
            "A and O shapes differ: {} vs {}".format(A.shape, O.shape)
        )
    if A.ndim != 2:
        raise ValueError("Expected 2-D matrices, got shape {}".format(A.shape))

    not_applied = A == 0
    applied = A == 1
    invalid_application = ~(not_applied | applied)
    # NaN outcomes fail both checks, just as they did in the scalar comparisons.
    inconsistent_outcome = (not_applied & (O != 0)) | (
        applied & ~np.isin(O, (-1, 1))
    )
    offending = np.argwhere(invalid_application | inconsistent_outcome)
    if offending.size == 0:
        return

    i, j = offending[0]
    if invalid_application[i, j]:
        raise ValueError(
            "A[{}, {}] = {!r} is neither 0 nor 1".format(i, j, A[i, j])
        )
    raise AssertionError(
        "O[{}, {}] = {!r} is inconsistent with A[{}, {}] = {!r}".format(
            i, j, O[i, j], i, j, A[i, j]
        )
    )


def create_matrices(N, k, AO_dict):
    """Expand the sparse ``AO_dict`` mapping into dense N x k matrices A and O.

    ``AO_dict`` maps (row, column) -> (application, outcome). Cells absent
    from the mapping stay 0 in both matrices. Both matrices are float arrays.
    """
    shape = (N, k)
    A = np.zeros(shape)
    O = np.zeros(shape)
    for cell, values in AO_dict.items():
        row_idx, col_idx = cell
        application, outcome = values
        A[row_idx, col_idx] = application
        O[row_idx, col_idx] = outcome
    # The consistency check assert_valid_matrices(A, O) is disabled here.
    return A, O


def basic_results(N, k, M, A, C, print_statements=True):
    """Summarise applications and callbacks, and histogram callbacks per person.

    Args:
        N: number of individuals (rows of A and C).
        k: number of roles (columns of A and C).
        M: number of applications; used as the callback-rate denominator.
        A: N x k application matrix.
        C: N x k callback matrix.
        print_statements: when True, print a plain-text summary.

    Returns:
        observed_callback_histogram: list whose entry c counts individuals
            with exactly c callbacks. Its length is max_applications + 1.
        max_applications: largest per-individual row total of A, as an int.
        basic_stats: dict of headline counts, the callback rate, and the
            per-individual means of applications and callbacks.
    """

    def _totals_with_spread(matrix, axis):
        # Sum along ``axis`` and report (totals, min, max, mean) of the sums.
        totals = np.sum(matrix, axis=axis)
        return totals, np.amin(totals), np.amax(totals), np.mean(totals)

    callbacks = C.sum()

    _, A_min_indiv, A_max_indiv, A_mean_indiv = _totals_with_spread(A, 1)
    _, A_min_role, A_max_role, A_mean_role = _totals_with_spread(A, 0)
    C_indiv, C_min_indiv, C_max_indiv, C_mean_indiv = _totals_with_spread(C, 1)
    _, C_min_role, C_max_role, C_mean_role = _totals_with_spread(C, 0)

    max_applications = int(A_max_indiv)

    # Slot c holds the number of individuals who received exactly c callbacks.
    observed_callback_histogram = [0 for _ in range(max_applications + 1)]
    for callback_count in map(int, C_indiv):
        observed_callback_histogram[callback_count] += 1

    basic_stats = {
        "# individuals": N,
        "# roles": k,
        "# applications": M,
        "# callbacks": callbacks,
        "callback_rate": callbacks / M,
        "mean applications": A_mean_indiv,
        "mean callbacks": C_mean_indiv,
    }

    if print_statements:
        report = (
            (
                "{} individuals applied to {} roles, submitting {} applications of the {} possible for a rate of {}",
                (N, k, M, N * k, M / (N * k)),
            ),
            (
                "Of the {} applications, {} received callbacks with a callback rate of {}",
                (M, callbacks, callbacks / M),
            ),
            (
                "Individuals applied to between {} and {} roles, with an average of {}",
                (A_min_indiv, A_max_indiv, A_mean_indiv),
            ),
            (
                "Roles received between {} and {} applications, with an average of {}",
                (A_min_role, A_max_role, A_mean_role),
            ),
            (
                "Individuals received between {} and {} callbacks, with an average of {}",
                (C_min_indiv, C_max_indiv, C_mean_indiv),
            ),
            (
                "Roles gave between {} and {} callbacks, with an average of {}",
                (C_min_role, C_max_role, C_mean_role),
            ),
        )
        for template, values in report:
            print(template.format(*values))

    return observed_callback_histogram, max_applications, basic_stats


def save_results(dataset_name, split_name, individual_info, role_info, A, O):
    """Placeholder for persisting results under ``results/<dataset_name>``.

    Only the output directory is created (relative to the current working
    directory). Saving itself is not implemented.

    Raises:
        NotImplementedError: always, after the directory has been created.
    """
    Path("results", dataset_name).mkdir(parents=True, exist_ok=True)
    raise NotImplementedError


def main(df, dataset_config, t, threshold_type):
    """Run the homogenization analysis on one DataFrame of applications.

    Args:
        df: application rows (one row per application).
        dataset_config: dataset configuration. Identifier/metadata column
            names and the application/outcome callables are read from it.
        t, threshold_type: forwarded unchanged to ``measure_homogenization``.

    Returns:
        ([observed_hist, expected_hist, sampled_hist], max_applications, stats)
        where ``stats`` is the dict from ``basic_results`` extended with the
        "Homogenization", "Obs/Sample" and "Sample/Exp" values returned by
        ``measure_homogenization``.
    """
    individual_identifiers = dataset_config.individual_identifiers
    individual_metadata = dataset_config.individual_metadata_identifiers
    role_identifiers = dataset_config.role_identifiers
    role_metadata = dataset_config.role_metadata_identifiers
    applied_fn = dataset_config.application_identifier
    outcome_fn = dataset_config.outcome_identifier

    M = len(df)

    id2individual, index2individual, N = group(
        df, individual_identifiers, individual_metadata, "individual"
    )
    # individual_info / role_info are only consumed by the disabled
    # save_results call below.
    individual_info = dict(
        id2individual=id2individual,
        index2id=index2individual,
        identifiers=individual_identifiers,
        N=N,
    )
    id2role, index2role, k = group(df, role_identifiers, role_metadata, "role")
    role_info = dict(
        id2role=id2role,
        index2id=index2role,
        identifiers=role_identifiers,
        k=k,
    )

    AO_dict = create_AO_dict(
        df,
        id2individual,
        individual_identifiers,
        id2role,
        role_identifiers,
        applied_fn,
        outcome_fn,
    )
    A, O = create_matrices(N, k, AO_dict)

    # Callbacks C: outcome 1 stays 1; no application (0) and rejection (-1) become 0.
    C = np.maximum(O, 0)
    # Rejections R: the complement of C, so non-applications count as rejections too.
    R = 1 - C

    observed_hist, max_applications, stats = basic_results(N, k, M, A, C)

    # Per-role empirical rejection probability: explicit rejections / applications.
    rejections_per_role = np.sum(O == -1, axis=0)
    applications_per_role = np.sum(A, axis=0)
    expected_rejection_probs = rejections_per_role / applications_per_role
    assert k == len(expected_rejection_probs)

    H, obs_over_sample, sample_over_exp, expected_hist, sampled_hist = (
        measure_homogenization(
            A, R, N, k, expected_rejection_probs, t, threshold_type, verbose=True
        )
    )
    stats.update(
        {
            "Homogenization": H,
            "Obs/Sample": obs_over_sample,
            "Sample/Exp": sample_over_exp,
        }
    )
    # save_results(dataset_name, split_name, individual_info, role_info, A, O)
    return [observed_hist, expected_hist, sampled_hist], max_applications, stats


def visualize(fig, hist, max_applications):
    """Draw a normalised callbacks-per-individual histogram on ``fig``.

    Args:
        fig: matplotlib Axes to draw on (despite the name, an Axes object is
            used: ``hist``, ``set_xticks`` and ``yaxis`` are called on it).
        hist: sequence where ``hist[c]`` is the (possibly fractional) number of
            individuals with ``c`` callbacks.
        max_applications: largest number of applications by one individual.
            It fixes the x-range 0..max_applications.

    Bars are normalised to sum to 1 and the y axis is shown in percent.
    Entries of ``hist`` beyond ``max_applications`` fall outside the bins and
    are not drawn.
    """
    # Pass counts as weights instead of expanding them into a sample list.
    # This avoids a large temporary list and keeps fractional expected counts,
    # which int() used to truncate.
    weights = np.asarray(hist, dtype=float)
    values = np.arange(len(weights))
    # One unit-wide bin per integer 0..max_applications, centred on that
    # integer. range(max_applications) had one edge too few: it merged the
    # last two counts and dropped individuals with max_applications callbacks.
    edges = np.arange(max_applications + 2) - 0.5
    fig.hist(values, bins=edges, weights=weights, density=True)
    fig.set_xticks(range(0, max_applications + 1, 2))
    # density=True yields fractions; show them as percentages to match the
    # "Percent of individuals" axis label.
    fig.yaxis.set_major_formatter(PercentFormatter(xmax=1))


def save_stats(stats_dict, dataset_name, threshold_type, t, row_names):
    """Write per-split statistics as a CSV table and print it.

    Args:
        stats_dict: maps a column label (e.g. a (partition_name, split_name)
            tuple) to a stats dict. Each stats dict must contain every key in
            ``row_names``.
        dataset_name: output sub-directory name.
        threshold_type: encoded in the file name as "<threshold_type>=<t>.csv".
        t: threshold value, also encoded in the file name.
        row_names: stats keys to include, one table row each, in this order.

    The table goes to
    ``visualizations/tables/correspondence/<dataset_name>/<threshold_type>=<t>.csv``
    (relative to the working directory). Missing directories are created.
    """
    column_names = list(stats_dict.keys())
    data = {
        column: [stats_dict[column][row] for row in row_names]
        for column in column_names
    }
    # Row labels live in the index. The old "Quantity" column duplicated them
    # and was silently discarded by ``columns=`` anyway. Passing the index at
    # construction also keeps an empty ``stats_dict`` from failing.
    df = pd.DataFrame(data, columns=column_names, index=row_names)

    out_path = Path(
        "visualizations/tables/correspondence/{}/{}={}.csv".format(
            dataset_name, threshold_type, t
        )
    )
    # to_csv does not create parent directories.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(out_path)
    print(df)


if __name__ == "__main__":
    thresholds = [(0, "absolute"), (1, "absolute"), (2, "absolute")]
    # thresholds.extend([(percent / 100, 'percent') for percent in range(0, 31, 1)])

    for dataset_name in tqdm(["laboureconomics"]):
        # for dataset_name in tqdm(['jpe', 'laboureconomics', 'jle']):
        df = load_dataset(dataset_name)
        dataset_config = configs[dataset_name]
        # Drop rows that cannot be attributed to an individual or a role.
        df = df.dropna(
            subset=dataset_config.individual_identifiers
            + dataset_config.role_identifiers
        )
        # savefig and to_csv do not create missing output directories.
        figure_dir = "visualizations/figures/correspondence/{}".format(dataset_name)
        os.makedirs(figure_dir, exist_ok=True)
        os.makedirs(
            "visualizations/tables/correspondence/{}".format(dataset_name),
            exist_ok=True,
        )

        for t, threshold_type in tqdm(thresholds):
            # Label used only in output file names. Percent thresholds (floats)
            # are written as whole percentages. round() is needed because
            # int(0.29 * 100) == 28 would collide with 0.28. ``t`` itself is
            # never reassigned, so every split is analysed with the real
            # threshold, not a value scaled by 100.
            t_label = round(t * 100) if isinstance(t, float) else t
            stats_dict = {}

            all_histograms, max_applications_all, all_stats = main(
                df, dataset_config, t, threshold_type
            )
            _, expected_hist_all, _ = all_histograms
            # Table rows follow the stats keys produced for the full dataset.
            # Defined before the split loop so it exists even with no splits.
            row_names = list(all_stats.keys())

            for partition_name, split_fns in tqdm(
                list(dataset_config.all_splits.items())
            ):
                for split_name, split_fn in split_fns.items():
                    print(dataset_name, partition_name, split_name, threshold_type, t)
                    df_split = split_fn(df)
                    split_histograms, max_applications_split, split_stats = main(
                        df_split, dataset_config, t, threshold_type
                    )
                    (
                        observed_hist_split,
                        expected_hist_split,
                        sampled_hist_split,
                    ) = split_histograms

                    f, axs = plt.subplots(1, 4, figsize=(16, 6))
                    ax1, ax2, ax3, ax4 = axs
                    for ax in axs.flat:
                        ax.set(xlabel="Number of callbacks")
                    ax1.set_ylabel("Percent of individuals")
                    ax1.set_title(
                        "Observed callbacks per individual ({})".format(split_name)
                    )
                    ax2.set_title(
                        "Expected callbacks per individual ({})".format(split_name)
                    )
                    ax3.set_title(
                        "Sampled callbacks per individual ({})".format(split_name)
                    )
                    ax4.set_title("Expected callbacks per individual (all)")
                    visualize(ax1, observed_hist_split, max_applications_split)
                    visualize(ax2, expected_hist_split, max_applications_split)
                    visualize(ax3, sampled_hist_split, max_applications_split)
                    visualize(ax4, expected_hist_all, max_applications_all)

                    f.suptitle(
                        "Individual homogenization in {} ({} individuals by {}). {} threshold of {}.".format(
                            dataset_name, split_name, partition_name, threshold_type, t
                        )
                    )
                    f.savefig(
                        "{}/{}_{}_{}={}".format(
                            figure_dir, partition_name, split_name, threshold_type, t_label
                        ),
                        dpi=100,
                    )
                    # pyplot keeps every figure alive until closed. Without this,
                    # memory grows with each split and threshold.
                    plt.close(f)

                    stats_dict[(partition_name, split_name)] = split_stats
            save_stats(stats_dict, dataset_name, threshold_type, t_label, row_names)
