from pathlib import Path
import pandas as pd
import numpy as np
from collections import defaultdict
import argparse
from plotting import plot_error_matrices_and_histograms
import matplotlib.pyplot as plt
import pickle
import notation_config

from hapi_matrix_graphs import generate_independent_df

# Global matplotlib styling shared by every figure this module produces.
font = dict(family="sans-serif", size=12)
plt.rc("font", **font)
plt.rcParams["figure.dpi"] = 300


class DDIAnalysis:
    """Split, plot and save failure-outcome analyses of the DDI dermatology data.

    The class keeps a set of named "splits": dicts that map a label to a function
    selecting either rows (skin tone, ground truth, agent accuracy / agreement) or
    columns (human vs. model agents) of the outcome matrix. The methods apply those
    splits, plot each resulting subset and pickle the plotted distributions.
    """

    def __init__(self, include_ham10k=False):
        """Set up the row/column splits, agent column groups and data/result folders.

        Parameters:
            - include_ham10k (bool): If true, "ham10k" is added to the model columns.
        """
        self.skin_tone_split = {  # Split on skin tone (Fitzpatrick), read from the 'skin_tone' index level
            "Fitzpatrick I & II": lambda df: df.xs(12, level='skin_tone'),
            "Fitzpatrick III & IV": lambda df: df.xs(34, level='skin_tone'),
            'Fitzpatrick V & VI': lambda df: df.xs(56, level='skin_tone'),
            'All': lambda df: df
        }
        self.ground_truth_split = {  # Split on ground truth (malignancy or benign)
            "malignancy": lambda df: df.xs(1, level="ground_truth"),
            "no malignancy": lambda df: df.xs(0, level="ground_truth"),
        }
        self.column_name_map = {  # To be used by self.agent_split
            "humans": ["derm1", "derm2"],
            "models": ["deepderm", "model_derm"],
        }
        if include_ham10k:
            self.column_name_map["models"].append("ham10k")

        # Split on agent type (model or human or all). The lambdas read
        # self.column_name_map when they are called, not when they are defined.
        self.agent_split = {
            "models": lambda df: df[self.column_name_map["models"]],
            "humans": lambda df: df[self.column_name_map["humans"]],
            "all": lambda df: df[
                self.column_name_map["models"] + self.column_name_map["humans"]
            ],
        }
        # Both of these are filled in by format_DDI_data: {agent_type: {name: row filter}}
        self.accuracy_split = defaultdict(dict)
        self.confidence_split = defaultdict(dict)

        # Options for generating the reference distribution. So far only independence is used.
        self.reference_options = ["independence"]
        # Folder containing the dermatology data
        self.derm_data_folder_path = Path("../data/dermatology/")
        # Folder where results will be saved
        self.derm_output_folder = Path("../results/dermatology")

        self.Notation = notation_config.NotationConfig()

    def data_loader(self):
        """Load the dermatology CSV and return it formatted by format_DDI_data.

        Make sure the file name below is correct.

        Returns:
            - The failure outcome matrix produced by format_DDI_data.
        """
        # NOTE(review): the spelling ("outtputs") and the missing extension are kept
        # exactly as they were -- confirm they match the file on disk.
        file_name = "final_outtputs_1_26_2023"
        derm_data = pd.read_csv(self.derm_data_folder_path / file_name, index_col=0)
        return self.format_DDI_data(derm_data)

    def _save_figure(self, dataset_dir, title):
        """Save the current matplotlib figure as <dataset_dir>/<title>.<default format>."""
        # The extension is appended explicitly because a title may contain "."
        # (e.g. a numeric group name such as 0.5), and savefig would otherwise take
        # whatever follows the last "." as the requested file format.
        file_format = plt.rcParams["savefig.format"]
        plt.savefig(f"{dataset_dir / title}.{file_format}", bbox_inches="tight")

    def plot_results(self, results_dict, dir_name, dict_is_split=False, palette="Set1", orientation="horizontal"):
        """Plot a histogram figure for every entry of results_dict and pickle the plotted data.

        Figures and a 'results_data.pickle' file are written to
        self.derm_output_folder / dir_name (created if needed).

        Parameters:
            - results_dict (dict): Output of split_by_metadata. With dict_is_split=False it has
                the form {col_name: {row_name: dataframe}} and one figure is produced per col_name.
                With dict_is_split=True it has the form
                {col_name: {row_name: {"Observed": df, "Expected": df}}} and one figure is
                produced per (col_name, row_name) pair.
            - dir_name (str or Path): Results sub-directory. Its last two path components are
                used in the figure titles.
            - dict_is_split (bool): Selects which of the two results_dict forms is expected.
            - palette (str or list): Name of a matplotlib colormap, or a list of colors.
            - orientation (str): Forwarded to the plotting helper when dict_is_split is False.
                The split form always plots with orientation='vertical'.
        """
        dir_name = Path(dir_name)  # accept the documented str form as well as a Path
        dataset_dir = self.derm_output_folder / dir_name
        dataset_dir.mkdir(parents=True, exist_ok=True)

        all_results_data = defaultdict(dict)
        for key1, df in results_dict.items():
            if dict_is_split:
                color_index = 0
                if isinstance(palette, str):
                    orig_colors = plt.colormaps[palette].colors
                else:
                    orig_colors = list(palette)
                for key2, result_df in df.items():
                    sample_size = result_df["Observed"].shape[0]
                    title = f"rows={dir_name.parent.name},by={dir_name.name}, cols={key1}, group={key2}"
                    # Colors are cycled here, outside the plotting helper, so that they
                    # change across plots.
                    colors = orig_colors[color_index:]
                    # Groups whose name ends in "sample" do not advance the offset, so the
                    # next group starts from the same colors. Group names are not always
                    # strings (format_DDI_data can fall back to a numeric value), hence str().
                    if str(key2).split(" ")[-1] != "sample":
                        color_index += len(result_df)
                    _, _, results_data, _ = plot_error_matrices_and_histograms(
                        result_df,
                        title,
                        additional_plot=None,
                        palette=colors,
                        plot_matrices=False,
                        orientation="vertical",
                        display_title=False,
                        y_max=0.6,
                    )
                    self._save_figure(dataset_dir, title)

                    # For compatibility with pandas DataFrames (used for visualization), the key
                    # represents the row name and the value the probability mass of one outcome.
                    for outcome, probability_mass in results_data.items():
                        all_results_data[(key1, f"{key2} (N={sample_size})", outcome)] = probability_mass
            else:
                title = f"rows={dir_name.parent.name},by={dir_name.name}, keys={key1} "
                _, _, results_data, _ = plot_error_matrices_and_histograms(
                    df,
                    title,
                    additional_plot=None,
                    palette=palette,
                    plot_matrices=False,
                    orientation=orientation,
                )
                self._save_figure(dataset_dir, title)

                for outcome, probability_mass in results_data.items():
                    all_results_data[(key1, outcome)] = probability_mass

            # Release the figures of this column group before plotting the next one.
            plt.close("all")

        # Written once, after everything has been plotted.
        data_file = dataset_dir / "results_data"
        with open(data_file.with_suffix(".pickle"), "wb") as handle:
            pickle.dump(all_results_data, handle)
            print(f"Results saved to {handle.name}")

    def split_by_metadata(self, ddi_df, col_split, row_split, include_ref=True, split_dicts=False):
        """Apply every (row split, column split) combination to the outcome matrix.

        Parameters:
            - ddi_df (pd.DataFrame): Outcome matrix in wide format (rows are instances,
                columns are prediction agents).
            - col_split (dict): Dict of form {str: function} where the function takes a dataframe
                and returns a subset of its columns (prediction agents). The most useful splits are
                created in __init__ (e.g. self.agent_split for humans vs models).
            - row_split (dict): Like col_split, but the function returns a subset of the rows
                (instances), e.g. the rows of one skin tone group.
            - include_ref (bool): If true, a reference dataframe is generated for each subset with
                generate_independent_df (outcomes under 'independence'). The reference can
                occasionally have rounding errors; this usually doesn't matter, but if the subset
                has very few rows, rounding can materially affect the plotted proportions.
            - split_dicts (bool): Determines the form of the result.
                If true: {col_name: {row_name: {"Observed": observed_df, "Expected": expected_df}}}.
                If false: {col_name: {row_name: observed_df, f"{row_name}_reference": expected_df}}.
                The "Expected" / "_reference" entries exist only when include_ref is true.

        Returns:
            - results_dict (defaultdict): Nested dict keyed first by column split name, then by
                row split name, in the form selected by split_dicts.
        """
        if split_dicts:
            results_dict = defaultdict(lambda: defaultdict(dict))
        else:
            results_dict = defaultdict(dict)

        for row_name, row_split_fn in row_split.items():
            row_filtered_df = row_split_fn(ddi_df)
            for col_name, col_split_fn in col_split.items():
                col_filtered_df = col_split_fn(row_filtered_df)
                if split_dicts:
                    results_dict[col_name][row_name]["Observed"] = col_filtered_df
                else:
                    results_dict[col_name][row_name] = col_filtered_df

                print(f"Splitting by {row_name} and {col_name}")
                if not include_ref:
                    continue

                reference_df = generate_independent_df(col_filtered_df)
                if split_dicts:
                    results_dict[col_name][row_name]["Expected"] = reference_df
                else:
                    results_dict[col_name][f"{row_name}_reference"] = reference_df

        return results_dict

    def format_DDI_data(self, derm_data):
        """Turn the long-format DDI predictions into a wide failure outcome matrix.

        derm_data comes in 'long' format, one row per (instance, agent) prediction. It is
        pivoted into a matrix indexed by (DDI_ID, patient_id, skin_tone, ground_truth) with
        one column per algorithm_name, holding 1 where the prediction differs from the ground
        truth and 0 where it matches. For every agent type in self.column_name_map two
        columns are appended: '<agent_type>_error' (mean failure over that type's columns) and
        '<agent_type>_agreement' (mean raw prediction over that type's columns).

        Side effects:
            - Adds 'is_correct' and 'is_failure' columns to derm_data in place.
            - Fills self.accuracy_split[agent_type] and self.confidence_split[agent_type] with
              one row filter per distinct error / agreement value. self.accuracy_split also gets
              a '<name> random sample' entry that draws as many random rows as the matching
              filter selects.

        Parameters:
            - derm_data (pd.DataFrame): Long-format data with columns DDI_ID, patient_id,
                skin_tone, ground_truth, algorithm_name and algorithm_output.

        Returns:
            - pd.DataFrame: The failure matrix with the added error / agreement columns. Only
                failures are returned because failures are what homogenization is measured on.
        """
        derm_data["is_correct"] = derm_data["algorithm_output"] == derm_data["ground_truth"]
        # Failures are represented as 1s, successes as 0s.
        derm_data["is_failure"] = (~derm_data["is_correct"]).astype(int)

        derm_data_formatted = derm_data.pivot(
            index=["DDI_ID", "patient_id", "skin_tone", "ground_truth"],
            columns="algorithm_name",
            values=["is_failure", "algorithm_output"],
        )

        # Failures / successes per agent. Copied because columns are added below and the
        # pivoted frame should not be written to through a selection of it.
        derm_data_failures = derm_data_formatted["is_failure"].copy()

        # Raw predictions (not whether they are correct); used to calculate agreement.
        derm_data_predictions = derm_data_formatted["algorithm_output"]

        # Calculate human and model error and agreement
        for agent_type, column_names in self.column_name_map.items():
            agent_error = derm_data_failures[column_names].mean(axis=1)
            error_key = f"{agent_type}_error"
            derm_data_failures[error_key] = agent_error
            print(f"agent type: {agent_type}, column names: {column_names}")
            # Register one row filter per observed error value so these rows can be selected
            # elsewhere. Loop values are bound as lambda defaults to avoid late binding.
            for accuracy_val in sorted(agent_error.unique()):
                accuracy_name = self.Notation.by_human_accuracy_semantic_dict.get(accuracy_val, accuracy_val)
                self.accuracy_split[agent_type][accuracy_name] = (
                    lambda df, key=error_key, val=accuracy_val: df[df[key] == val]
                )
                self.accuracy_split[agent_type][f'{accuracy_name} random sample'] = (
                    lambda df, agent_type=agent_type, accuracy_name=accuracy_name: df.sample(
                        n=self.accuracy_split[agent_type][accuracy_name](df).shape[0]
                    )
                )

            agent_confidence = derm_data_predictions[column_names].mean(axis=1)
            confidence_key = f"{agent_type}_agreement"
            derm_data_failures[confidence_key] = agent_confidence
            for confidence_val in sorted(agent_confidence.unique()):
                confidence_name = self.Notation.by_human_confidence_semantic_dict.get(confidence_val, confidence_val)
                self.confidence_split[agent_type][confidence_name] = (
                    lambda df, key=confidence_key, val=confidence_val: df[df[key] == val]
                )

        return derm_data_failures

    def run_metadata_analysis(self, df, dir_path, include_ref=True):
        """Run and plot every metadata split of df.

        Splits the data by nothing ('no_split'), skin tone ('by_race'), ground truth,
        human agreement ('by_confidence'), human error and model error, and saves the plots
        of each split to a sub-directory of dir_path. The accuracy and agreement splits are
        the ones registered by format_DDI_data.

        Parameters:
            - df (pd.DataFrame): Outcome matrix as returned by format_DDI_data (or a row subset).
            - dir_path (str or Path): Directory, relative to the results folder, for the plots.
            - include_ref (bool): Whether to include a reference distribution in the plots.
        """
        dir_path = Path(dir_path)
        palette = "tab20" if include_ref else "Set1"
        split_dicts = True

        all_rows = self.split_by_metadata(
            df,
            col_split=self.agent_split,
            row_split={'all_rows': lambda df: df},
            split_dicts=split_dicts,
            include_ref=include_ref,
        )
        self.plot_results(
            all_rows, dir_path / 'no_split', palette=palette, orientation='vertical', dict_is_split=split_dicts
        )

        by_race_split = self.split_by_metadata(
            df,
            col_split=self.agent_split,
            row_split=self.skin_tone_split,
            split_dicts=split_dicts,
            include_ref=include_ref,
        )
        self.plot_results(by_race_split, dir_path / 'by_race', dict_is_split=split_dicts, palette=palette)

        # Selecting on ground truth raises KeyError when df has already been filtered on it
        # (the index level is gone). Only the split is guarded so that an unrelated KeyError
        # raised while plotting is not silently reported as "already filtered".
        try:
            by_ground_truth_split = self.split_by_metadata(
                df,
                col_split=self.agent_split,
                row_split=self.ground_truth_split,
                include_ref=include_ref,
            )
        except KeyError:
            print("Already have filtered on ground truth")
        else:
            self.plot_results(
                by_ground_truth_split, dir_path / "by_ground_truth", palette=palette
            )

        # Only models are included here since we condition on human accuracy / confidence
        models_split = {"models": self.agent_split["models"]}
        by_confidence_split = self.split_by_metadata(
            df,
            col_split=models_split,
            row_split=self.confidence_split["humans"],
            include_ref=include_ref,
        )
        self.plot_results(by_confidence_split, dir_path / "by_confidence", palette=palette)

        by_human_error_split = self.split_by_metadata(
            df,
            col_split=models_split,
            row_split=self.accuracy_split['humans'],
            split_dicts=split_dicts,
            include_ref=include_ref,
        )
        self.plot_results(
            by_human_error_split, dir_path / 'by_human_accuracy', dict_is_split=split_dicts, palette=palette
        )

        # Conversely, only humans are included when conditioning on model accuracy
        humans_split = {"humans": self.agent_split["humans"]}
        by_model_error_split = self.split_by_metadata(
            df,
            col_split=humans_split,
            row_split=self.accuracy_split["models"],
            include_ref=include_ref,
        )
        self.plot_results(by_model_error_split, dir_path / "by_model_accuracy", palette=palette)

def main(dir_name, include_ham10k=False):
    """Run the metadata analysis on all lesions, then on the malignant and benign subsets.

    Parameters:
        - dir_name (str): Name used for the results folders. Each subset is written to its
            own sub-folder ('all', 'malignancy', 'benign') below it.
        - include_ham10k (bool): Passed on to DDIAnalysis to include ham10k among the models.
    """
    results_root = Path(dir_name)
    print(include_ham10k)
    analysis = DDIAnalysis(include_ham10k=include_ham10k)
    print(analysis.column_name_map)

    outcomes = analysis.data_loader()

    # Both ground-truth subsets are computed up front, before any analysis runs.
    subsets = {
        "all": outcomes,
        "malignancy": analysis.ground_truth_split["malignancy"](outcomes),
        "benign": analysis.ground_truth_split["no malignancy"](outcomes),
    }
    for folder_name, subset_df in subsets.items():
        analysis.run_metadata_analysis(subset_df, results_root / folder_name)


if __name__ == "__main__":
    # Command-line entry point: a positional results folder name plus an optional
    # --ham10k switch. The switch is off by default; per the original author's note,
    # ham10k has recall near 0 and is therefore often excluded from the models.
    cli = argparse.ArgumentParser()
    cli.add_argument("dir_name", help="name to be used for naming results folders")
    cli.add_argument("--ham10k", action="store_true")
    options = cli.parse_args()

    main(options.dir_name, options.ham10k)
