# -*- coding: utf-8 -*-
import os
from pathlib import Path

import pandas as pd


def enumerate_files(path: str) -> tuple[list, list]:
    """Recursively collect every regular file entry below *path*.

    Returns two parallel lists in ``os.walk`` (top-down) order:
    the full ``Path`` of each file, and its bare file name.
    A non-existent *path* yields two empty lists.
    """
    file_paths: list = []
    file_names: list = []
    for dirpath, _subdirs, names in os.walk(path):
        # Names and paths are appended in lockstep so index i matches in both.
        file_names.extend(names)
        file_paths.extend(Path(dirpath) / name for name in names)
    return file_paths, file_names


def match_prompt_type(filename, prompt_types):
    """Return the prompt type whose result tag appears in *filename*.

    A prompt type such as ``"Disease Name"`` is matched when the lower-cased
    file name contains ``"disease_name_result"``. The first match in
    *prompt_types* order wins; ``""`` is returned when nothing matches.
    """
    lowered = filename.lower()
    for prompt_type in prompt_types:
        if prompt_type.lower().replace(" ", "_") + "_result" in lowered:
            return prompt_type
    return ""


def nested_dict(root, *keys):
    """Descend into *root* along *keys*, creating empty dicts as needed.

    Equivalent to ``root.setdefault(k1, {}).setdefault(k2, {})...``;
    returns the innermost dict so callers can fill it in place.
    """
    node = root
    for key in keys:
        node = node.setdefault(key, {})
    return node


def pick_metric(series, labels):
    """Return ``series[label]`` for the first of *labels* present in the index.

    Result CSVs spell the same metric differently (e.g. ``F1`` vs ``F1s``).
    If none of the labels is present, indexing with the last label raises
    ``KeyError``, exactly as a direct lookup would.
    """
    for label in labels[:-1]:
        if label in series:
            return series[label]
    return series[labels[-1]]


def iter_class_columns(df, disease_convert_dict):
    """Yield ``(class_name, column)`` for every disease column of *df*.

    Column names are lower-cased, and the ``mean``/``metric`` columns are
    skipped. Short names are mapped to canonical class names through
    *disease_convert_dict*. *df* itself is not modified.
    """
    lowered = df.set_axis(df.columns.str.lower(), axis=1)
    for raw_name in lowered.columns:
        if raw_name in ("mean", "metric"):
            continue
        # Always read the column under its own name; the canonical name is
        # only the output key (renaming in place previously hit column 0).
        yield disease_convert_dict.get(raw_name, raw_name), lowered[raw_name]


def summarize(metric_dict, prompt_types):
    """Build a class x prompt-type table with a trailing ``mean`` row.

    Columns are ordered as in *prompt_types*. A ``KeyError`` is raised if a
    prompt type in *prompt_types* has no column in *metric_dict*.
    """
    df = pd.DataFrame(metric_dict)
    df.loc["mean"] = df.mean()
    return df[prompt_types]


if __name__ == "__main__":
    # Locate <root>/experiments, where <root> is the parent of this script's directory.
    cur_filepath = Path(__file__).resolve()
    cur_dir = cur_filepath.parent
    root_dir = cur_dir.parent
    exp_dir = root_dir / 'experiments'
    # Summary tables are written inside the experiments tree.
    output_dir = exp_dir / 'summary'
    output_dir.mkdir(exist_ok=True)

    prompt_types = ["Disease Name", "Disease Symptom", "Disease Attribute",
                    "Disease Description Plain English", "Disease Description Medical Style",
                    "Disease Description Radiologist Style", "Baseline"]
    interpretability_dict = {"Disease Name": 3, "Disease Symptom": 7, "Disease Attribute": 9,
                             "Disease Description Plain English": 8,
                             "Disease Description Medical Style": 10,
                             "Disease Description Radiologist Style": 9}
    # Maps short column names found in result CSVs to canonical class names.
    disease_convert_dict = {"nodule": "lung nodule", "mass": "lung mass",
                            "thicken": "pleural thicken", "effusion": "pleural effusion",
                            "tail_abnorm_obs": "fibrosis", "covid": "covid-19",
                            "infiltrate": "infiltration"}
    seen_classes = ["atelectasis", "cardiomegaly", "pleural effusion", "infiltration", "lung mass",
                    "lung nodule", "pneumonia", "pneumothorax", "consolidation", "edema",
                    "emphysema", "fibrosis", "pleural thicken", "hernia"]
    unseen_classes = ["covid-19"]

    seen_auc_dict = {}     # model -> prompt type -> class -> AUC
    unseen_auc_dict = {}   # model -> prompt type -> class -> AUC
    dataset_auc_dict = {}  # model -> dataset -> prompt type -> class -> AUC
    dataset_f1_dict = {}   # model -> dataset -> prompt type -> class -> F1
    dataset_acc_dict = {}  # model -> dataset -> prompt type -> class -> ACC

    exp_filepaths, exp_filenames = enumerate_files(exp_dir)
    for filepath, filename in zip(exp_filepaths, exp_filenames):
        if not filename.endswith(".csv"):
            continue
        file_prompt_type = match_prompt_type(filename, prompt_types)
        if not file_prompt_type:
            continue

        # Layout used here: experiments/<model>/<dataset>/.../<file>.csv
        parts = filepath.relative_to(exp_dir).parts
        model_name = parts[0]
        dataset_name = parts[1]
        if model_name == "BioViL_T":
            continue

        df = pd.read_csv(filepath)
        # First CSV column becomes the row index; rows are looked up below
        # by metric label (AUC, F1/F1s, ACC/ACCs/Accs).
        df = df.set_index(df.columns[0])

        # Create the per-file leaves up front so every result file gets a
        # column in the dataset tables even if it has no disease columns.
        auc_leaf = nested_dict(dataset_auc_dict, model_name, dataset_name, file_prompt_type)
        f1_leaf = nested_dict(dataset_f1_dict, model_name, dataset_name, file_prompt_type)
        acc_leaf = nested_dict(dataset_acc_dict, model_name, dataset_name, file_prompt_type)

        for col, values in iter_class_columns(df, disease_convert_dict):
            if col not in seen_classes and col not in unseen_classes:
                print(f"Unknown class: {col}")

            auc = values["AUC"]
            # Seen/unseen tables are pooled across datasets; the first AUC
            # encountered for a class (in os.walk order) is kept.
            if col in seen_classes:
                nested_dict(seen_auc_dict, model_name, file_prompt_type).setdefault(col, auc)
            if col in unseen_classes:
                nested_dict(unseen_auc_dict, model_name, file_prompt_type).setdefault(col, auc)

            auc_leaf[col] = auc
            f1_leaf[col] = pick_metric(values, ("F1", "F1s"))
            acc_leaf[col] = pick_metric(values, ("ACC", "ACCs", "Accs"))

    seen_auc_dfs = {}
    for model_name, auc_dict in seen_auc_dict.items():
        df = summarize(auc_dict, prompt_types)
        df.to_csv(output_dir / f"{model_name}_seen_auc.csv")
        seen_auc_dfs[model_name] = df

    unseen_auc_dfs = {}
    for model_name, auc_dict in unseen_auc_dict.items():
        df = summarize(auc_dict, prompt_types)
        df.to_csv(output_dir / f"{model_name}_unseen_auc.csv")
        unseen_auc_dfs[model_name] = df

    dataset_auc_dfs = {}
    dataset_f1_dfs = {}
    dataset_acc_dfs = {}
    # One table per (model, dataset, metric); written in AUC, F1, ACC order.
    for suffix, metric_dict, out_dfs in (
        ("auc", dataset_auc_dict, dataset_auc_dfs),
        ("f1", dataset_f1_dict, dataset_f1_dfs),
        ("acc", dataset_acc_dict, dataset_acc_dfs),
    ):
        for model_name, dataset_dict in metric_dict.items():
            for dataset_name, per_prompt in dataset_dict.items():
                df = summarize(per_prompt, prompt_types)
                df.to_csv(output_dir / f"{model_name}_{dataset_name}_{suffix}.csv")
                out_dfs[f"{model_name}_{dataset_name}"] = df

    # Single-row table of the per-prompt-type interpretability scores.
    interpretability_df = pd.DataFrame(interpretability_dict, index=["interpretability"])
    interpretability_df.to_csv(output_dir / "interpretability.csv")
