from pathlib import Path
import pandas as pd
import numpy as np
from collections import defaultdict
import argparse
from plotting import plot_error_matrices_and_histograms
import matplotlib.pyplot as plt
import pickle
import notation_config

from hapi_improvements import generate_independent_df

# Global matplotlib defaults applied to every figure produced by this module:
# a 12pt sans-serif font and 300-dpi figures.
font = dict(family="sans-serif", size=12)
plt.rc("font", **font)
plt.rcParams.update({"figure.dpi": 300})


class DDIAnalysis:
    """Analysis helper for the DDI dermatology outcome data.

    Holds the row/column split functions used to slice an outcome matrix
    (rows are instances, columns are prediction agents) together with the
    methods that load the raw data, build the per-split result dictionaries,
    plot them and save the plotted data.
    """

    def __init__(self, dir_name, include_ham10k=False):
        """Define the split functions and the input/output folders.

        Parameters:
            - dir_name (str): Name of the results folder; output is written under
                ../results/<dir_name>/dermatology.
            - include_ham10k (bool): If True, "ham10k" is added to the list of model columns.
        """
        self.skin_tone_split = {  # Split on skin tone (fitzpatrick)
            "Fitzpatrick I & II": lambda df: df.xs(12, level="skin_tone"),
            "Fitzpatrick III & IV": lambda df: df.xs(34, level="skin_tone"),
            "Fitzpatrick V & VI": lambda df: df.xs(56, level="skin_tone"),
            "All": lambda df: df,
        }
        self.ground_truth_split = {  # Split on ground truth (malignancy or benign)
            "malignancy": lambda df: df.xs(1, level="ground_truth"),
            "no malignancy": lambda df: df.xs(0, level="ground_truth"),
        }
        self.column_name_map = {  # To be used by self.agent_split
            "humans": ["derm1", "derm2"],
            "models": ["deepderm", "model_derm"],
        }
        if include_ham10k:
            self.column_name_map["models"].append("ham10k")

        self.agent_split = {  # Split on agent type (model or human or all)
            "models": lambda df: df[self.column_name_map["models"]],
            "humans": lambda df: df[self.column_name_map["humans"]],
            "all": lambda df: df[
                self.column_name_map["models"] + self.column_name_map["humans"]
            ],
        }
        # Both of these are filled in by format_DDI_data, keyed by agent type.
        self.accuracy_split = defaultdict(dict)
        self.confidence_split = defaultdict(dict)

        # Options for generating the reference distribution. So far only independence is used.
        self.reference_options = ["independence"]
        # Folder containing the dermatology data
        self.derm_data_folder_path = Path("../data/dermatology/")
        # Folder where results will be saved
        self.derm_output_folder = Path("../results") / dir_name / "dermatology"

        self.Notation = notation_config.NotationConfig()

    def data_loader(self, file_name="final_outtputs_1_26_2023"):
        """Load the dermatology data and format it (see format_DDI_data for details).

        Parameters:
            - file_name (str): Name of the CSV file inside self.derm_data_folder_path.
                Make sure this is the correct file name.
        Returns:
            - The dataframe returned by format_DDI_data.
        """
        derm_data = pd.read_csv(self.derm_data_folder_path / file_name, index_col=0)
        return self.format_DDI_data(derm_data)

    @staticmethod
    def _palette_colors(palette):
        """Return the colors of `palette`, given either a matplotlib colormap name or a sequence of colors."""
        if isinstance(palette, str):
            # Only listed colormaps (e.g. "Set1", "tab20") expose `.colors`.
            return plt.colormaps[palette].colors
        return list(palette)

    def plot_results(
        self,
        results_dict,
        dir_name,
        dict_is_split=False,
        palette="Set1",
        orientation="horizontal",
    ):
        """Plot every entry of results_dict, save the figures and pickle the plotted data.

        Each dataframe corresponds to a different color in the grouped histogram.
        Figures and a single "results_data.pickle" file are written into dir_name.
        Parameters:
            - results_dict (dict): Output of split_by_metadata, i.e. {str: {str: dataframe}} or,
                when dict_is_split is True, {str: {str: {"Observed"/"Baseline": dataframe}}}.
            - dir_name (str or Path): Directory where the results will be saved (created if missing).
                Its name and its parent's name are used in the plot titles / file names.
            - dict_is_split (bool): Set to True if 'split_dicts' in split_by_metadata was set to True. Else, set to False.
            - palette (str or list): Name of the palette to be used for the grouped histogram. Can also be a list of colors.
            - orientation ('horizontal' or 'vertical'): orientation of barplot. Only used when
                dict_is_split is False; the split branch always plots vertically.
        """
        dir_name = Path(dir_name)  # accept plain strings as well as Path objects
        dir_name.mkdir(parents=True, exist_ok=True)

        all_results_data = defaultdict(dict)
        for key1, df in results_dict.items():
            if dict_is_split:
                color_index = 0
                orig_colors = self._palette_colors(palette)
                for key2, result_df in df.items():
                    sample_size = result_df["Observed"].shape[0]
                    title = f"rows={dir_name.parent.name},by={dir_name.name},cols={key1},group={key2}"
                    arxiv_title = title.replace(" ", "_")
                    # The colors are cycled here, outside plot_error_matrices_and_histograms,
                    # so that they change across plots.
                    colors = orig_colors[color_index:]
                    # key2 may be a number (an accuracy value without a semantic name),
                    # so convert it before inspecting its last word.
                    if str(key2).split(" ")[-1] != "sample":
                        color_index += len(result_df)
                    _, _, results_data, _ = plot_error_matrices_and_histograms(
                        result_df,
                        title,
                        additional_plot=None,
                        palette=colors,
                        plot_matrices=False,
                        orientation="vertical",
                        display_title=False,
                        y_max=0.6,
                    )
                    plt.savefig(dir_name / arxiv_title, bbox_inches="tight")

                    # For compatibility with Pandas Dataframes (which we use for visualization), it's nice for the
                    # key to represent the row name and the value to be a dictionary representing the
                    # distribution over system-level outcomes
                    for outcome, probability_mass in results_data.items():
                        all_results_data[
                            (key1, f"{key2} (N={sample_size})", outcome)
                        ] = probability_mass
            else:
                title = f"rows={dir_name.parent.name},by={dir_name.name},keys={key1}"
                arxiv_title = title.replace(" ", "_")
                _, _, results_data, _ = plot_error_matrices_and_histograms(
                    df,
                    title,
                    additional_plot=None,
                    palette=palette,
                    plot_matrices=False,
                    orientation=orientation,
                )
                plt.savefig(dir_name / arxiv_title, bbox_inches="tight")

                for outcome, probability_mass in results_data.items():
                    all_results_data[(key1, outcome)] = probability_mass

            # Release the figures created for this key before moving on.
            plt.close("all")

        # Write the accumulated data once, after every key has been plotted.
        data_file = dir_name / "results_data"
        with open(data_file.with_suffix(".pickle"), "wb") as handle:
            pickle.dump(all_results_data, handle)
            print(f"Results saved to {handle.name}")

    def split_by_metadata(
        self, ddi_df, col_split, row_split, include_ref=True, split_dicts=False
    ):
        """Slice the outcome matrix by every (row split, column split) combination.

        Parameters:
            - ddi_df (pd.DataFrame): Outcome matrix in wide (not long) format: one row per instance,
                one column per prediction agent.
            - col_split (dict): Dict of form {str: function} where the function takes in a dataframe and
                returns a subset of its columns (prediction agents). The most useful splits are defined in
                __init__ and available as instance attributes (e.g. self.agent_split: humans vs models).
            - row_split (dict): Similar to col_split, but the function should return a subset of the rows
                (instances). For example, a subset of rows based on the skin tone of the input instance.
            - include_ref (bool): If True, includes a reference dataframe for each split, produced by
                generate_independent_df (outcomes under 'independence'). The reference dataframe can
                occasionally have rounding errors; this usually doesn't matter, but if the dataframe has very
                few rows then rounding can materially affect the proportions shown during plotting.
            - split_dicts (bool): Determines the form of the returned dict.
                If True: {col_name: {row_name: {"Baseline": expected_df, "Observed": observed_df}}}.
                If False: {col_name: {f"{row_name}_reference": expected_df, row_name: observed_df}}.
        Returns:
            - results_dict (dict): Nested dict keyed first by column-split name, then by row-split name.
        """
        if split_dicts:
            results_dict = defaultdict(lambda: defaultdict(dict))
        else:
            results_dict = defaultdict(dict)

        for row_name, row_split_fn in row_split.items():
            row_filtered_df = row_split_fn(ddi_df)
            for col_name, col_split_fn in col_split.items():
                col_filtered_df = col_split_fn(row_filtered_df)
                if split_dicts:
                    results_dict[col_name][row_name]["Observed"] = col_filtered_df
                else:
                    results_dict[col_name][row_name] = col_filtered_df

                if include_ref:
                    reference_df = generate_independent_df(col_filtered_df)
                    if split_dicts:
                        results_dict[col_name][row_name]["Baseline"] = reference_df
                    else:
                        results_dict[col_name][f"{row_name}_reference"] = reference_df

        return results_dict

    @staticmethod
    def _rows_where_equal(key, val):
        """Return a function that keeps the rows of a dataframe whose column `key` equals `val`."""
        return lambda df: df[df[key] == val]

    def format_DDI_data(self, derm_data):
        """Reformat the raw 'long' dermatology data into an outcome matrix of failures.

        derm_data comes in 'long' format, where each (instance, agent) prediction and ground truth has
        its own row. It is pivoted into a wide matrix indexed by
        (DDI_ID, patient_id, skin_tone, ground_truth) with one column per algorithm_name. Predictions are
        turned into failures (1) / successes (0). For every agent type in self.column_name_map two columns
        are added: "<agent_type>_error" (mean failure across that type's columns) and
        "<agent_type>_agreement" (mean prediction across that type's columns). For each distinct value of
        those columns a row filter is registered in self.accuracy_split / self.confidence_split.

        The input dataframe is not modified.

        Parameters:
            - derm_data (pd.DataFrame): Long-format data with columns DDI_ID, patient_id, skin_tone,
                ground_truth, algorithm_name and algorithm_output.
        Returns:
            - pd.DataFrame: The failure matrix (plus the added error/agreement columns). Only failures are
                returned because failures are what is used to measure homogenization.
        """
        is_correct = derm_data["algorithm_output"] == derm_data["ground_truth"]
        # Failures are represented as 1s: correct -> 0, incorrect -> 1.
        # assign() returns a new frame, so the caller's dataframe is left untouched.
        derm_data = derm_data.assign(
            is_correct=is_correct, is_failure=(~is_correct).astype(int)
        )

        derm_data_formatted = derm_data.pivot(
            index=["DDI_ID", "patient_id", "skin_tone", "ground_truth"],
            columns="algorithm_name",
            values=["is_failure", "algorithm_output"],
        )

        # These columns contain information on failures/successes. Copied because new columns
        # are added below and we do not want to write into a slice of the pivot.
        derm_data_failures = derm_data_formatted["is_failure"].copy()

        # These columns contain predictions, not whether the prediction is correct.
        # Used to calculate agreement between agents.
        derm_data_predictions = derm_data_formatted["algorithm_output"]

        # Calculate human and model accuracy and agreement
        for agent_type, column_names in self.column_name_map.items():
            agent_error = derm_data_failures[column_names].mean(axis=1)
            error_key = f"{agent_type}_error"
            derm_data_failures[error_key] = agent_error
            print(f"agent type: {agent_type}, column names: {column_names}")
            # Register a filter per distinct value so these rows can be selected elsewhere in the code.
            # NaN is skipped: an equality filter on NaN can never match a row.
            for accuracy_val in sorted(agent_error.dropna().unique()):
                accuracy_name = self.Notation.by_human_accuracy_semantic_dict.get(
                    accuracy_val, accuracy_val
                )
                self.accuracy_split[agent_type][accuracy_name] = self._rows_where_equal(
                    error_key, accuracy_val
                )

            agent_confidence = derm_data_predictions[column_names].mean(axis=1)
            confidence_key = f"{agent_type}_agreement"
            derm_data_failures[confidence_key] = agent_confidence
            for confidence_val in sorted(agent_confidence.dropna().unique()):
                confidence_name = self.Notation.by_human_confidence_semantic_dict.get(
                    confidence_val, confidence_val
                )
                self.confidence_split[agent_type][
                    confidence_name
                ] = self._rows_where_equal(confidence_key, confidence_val)

        return derm_data_failures

    def run_metadata_analysis(self, df, dir_path, include_ref=True):
        """Split the data several ways, then plot and save each split in its own subdirectory.

        The splits are: no split, skin tone, ground truth, human agreement, human error and model error.
        Parameters:
            - df: dataframe containing the dermatology outcome data (as returned by data_loader)
            - dir_path (Path): directory under which the per-split subdirectories are created
            - include_ref: whether to include a reference distribution in plots
        """
        palette = "tab20" if include_ref else "Set1"
        split_dicts = True

        all_rows = self.split_by_metadata(
            df,
            col_split=self.agent_split,
            row_split={"all_rows": lambda df: df},
            split_dicts=split_dicts,
            include_ref=include_ref,
        )
        no_split_dir = dir_path / "no_split"
        self.plot_results(
            all_rows,
            no_split_dir,
            palette=palette,
            orientation="vertical",
            dict_is_split=split_dicts,
        )

        by_race_split = self.split_by_metadata(
            df,
            col_split=self.agent_split,
            row_split=self.skin_tone_split,
            split_dicts=split_dicts,
            include_ref=include_ref,
        )
        by_race_dir = dir_path / "by_race"
        self.plot_results(
            by_race_split, by_race_dir, dict_is_split=split_dicts, palette=palette
        )

        # When df was already filtered on ground truth, the "ground_truth" index level is gone
        # and the split raises KeyError. Only the split is guarded so that a KeyError raised
        # while plotting is not mistaken for that situation.
        try:
            by_ground_truth_split = self.split_by_metadata(
                df,
                col_split=self.agent_split,
                row_split=self.ground_truth_split,
                include_ref=include_ref,
            )
        except KeyError:
            print("Already have filtered on ground truth")
        else:
            by_ground_truth_dir = dir_path / "by_ground_truth"
            self.plot_results(
                by_ground_truth_split, by_ground_truth_dir, palette=palette
            )

        # Only include models here since we condition on human accuracy / confidence
        models_split = {"models": self.agent_split["models"]}
        by_confidence_dir = dir_path / "by_confidence"
        by_confidence_split = self.split_by_metadata(
            df,
            col_split=models_split,
            row_split=self.confidence_split["humans"],
            include_ref=include_ref,
        )
        self.plot_results(by_confidence_split, by_confidence_dir, palette=palette)

        by_human_error_dir = dir_path / "by_human_accuracy"
        by_human_error_split = self.split_by_metadata(
            df,
            col_split=models_split,
            row_split=self.accuracy_split["humans"],
            split_dicts=split_dicts,
            include_ref=include_ref,
        )
        self.plot_results(
            by_human_error_split,
            by_human_error_dir,
            dict_is_split=split_dicts,
            palette=palette,
        )

        # Conversely, only include humans when conditioning on model accuracy
        humans_split = {"humans": self.agent_split["humans"]}
        by_model_error_dir = dir_path / "by_model_accuracy"
        by_model_error_split = self.split_by_metadata(
            df,
            col_split=humans_split,
            row_split=self.accuracy_split["models"],
            include_ref=include_ref,
        )
        self.plot_results(by_model_error_split, by_model_error_dir, palette=palette)


def main(dir_name, include_ham10k=False):
    """Run the metadata analysis on all instances, then on the malignant and benign subsets.

    Parameters:
        - dir_name (str): name used for the results folder (passed to DDIAnalysis)
        - include_ham10k (bool): whether ham10k is included among the models
    """
    analysis = DDIAnalysis(dir_name, include_ham10k=include_ham10k)
    outcomes = analysis.data_loader()

    # Output sub-folder -> dataframe analysed into it. The ground-truth subsets are
    # derived up front; the analyses then run in this order.
    subsets = {
        "all": outcomes,
        "malignancy": analysis.ground_truth_split["malignancy"](outcomes),
        "benign": analysis.ground_truth_split["no malignancy"](outcomes),
    }
    for folder_name, subset_df in subsets.items():
        analysis.run_metadata_analysis(
            subset_df, analysis.derm_output_folder / folder_name
        )


if __name__ == "__main__":
    # Command-line entry point: a positional results-folder name plus one optional flag.
    cli = argparse.ArgumentParser()
    cli.add_argument("dir_name", help="name to be used for naming results folders")
    # When --ham10k is passed, ham10k is included in the models.
    # Ham10k has recall near 0, so it is often excluded.
    cli.add_argument("--ham10k", action="store_true")
    options = cli.parse_args()

    main(dir_name=options.dir_name, include_ham10k=options.ham10k)
