# -*- coding: utf-8 -*-
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D


def get_directories():
    """Resolve the summary-input and plot-output directories.

    The project root is taken to be two directory levels above this file.
    The ``plots`` output directory is created under that root if it does not
    already exist.

    Returns:
        tuple[Path, Path]: ``(summary_dir, output_dir)``, where ``summary_dir``
        is ``<root>/experiments/summary`` and ``output_dir`` is ``<root>/plots``.
    """
    project_root = Path(__file__).resolve().parent.parent
    plots_dir = project_root / 'plots'
    plots_dir.mkdir(exist_ok=True)
    return project_root / 'experiments' / 'summary', plots_dir


def take_first_4_chars_or_2_plus_2_chars(s):
    # if space not in string, take first 4 characters
    # else take first 2 characters before space and first 2 characters after space, add a dot in between
    # capitalize all characters
    if s == "Pneumonia":
        return "PN.IA"
    elif s == "Pneumothorax":
        return "PN.OR"
    if ' ' in s:
        s = s[:2].capitalize() + '.' + s[s.index(' ') + 1:s.index(' ') + 3].capitalize()
    else:
        s = s[:4].capitalize()
    return s.upper()


def load_and_prepare_data(summary_dir):
    """Load the three per-model seen-disease AUC tables.

    Each CSV in ``summary_dir`` is read with its first column as the row
    index. The row labelled ``"mean"`` is dropped, and every remaining row
    label is capitalised.

    Args:
        summary_dir: ``Path`` of the directory containing the CSV files.

    Returns:
        list[pd.DataFrame]: the BioViL, MedKLIP and KAD tables, in that order.

    Raises:
        KeyError: if a table has no ``"mean"`` row.
    """
    tables = []
    for csv_name in ('BioViL_seen_auc.csv', 'MedKLIP_seen_auc.csv', 'KAD_seen_auc.csv'):
        table = pd.read_csv(summary_dir / csv_name, index_col=0).drop(index="mean")
        table.index = table.index.str.capitalize()
        tables.append(table)
    return tables


def calculate_average_performance(df, exclude_columns=None):
    """Return the row-wise mean of the non-excluded columns and the baseline column.

    Args:
        df: AUC table with one column per prompt style, including a
            ``'Baseline'`` column.
        exclude_columns: column label(s) left out of the average. Defaults to
            ``['Baseline']``. ``None`` is used as the default, rather than a
            list literal, so that no mutable object is shared between calls.

    Returns:
        tuple[pd.Series, pd.Series]: ``(row-wise mean of the remaining columns,
        df['Baseline'])``.
    """
    if exclude_columns is None:
        exclude_columns = ['Baseline']
    other_styles_mean = df.drop(columns=exclude_columns).mean(axis=1)
    baseline = df['Baseline']
    return other_styles_mean, baseline


def calculate_percentage_diff_from_baseline(df):
    """Express every prompt-style column as a percentage change from 'Baseline'.

    Each value becomes ``(value - baseline) / baseline * 100``, where
    ``baseline`` is the same row's ``'Baseline'`` entry. The ``'Baseline'``
    column itself is not part of the returned frame.
    """
    reference = df["Baseline"]
    styles = df.drop(columns=["Baseline"])
    return styles.sub(reference, axis=0).div(reference, axis=0) * 100


def rename_prompt_styles(df):
    """Return a new DataFrame whose prompt-style columns carry short display names.

    Columns that are not in the mapping keep their original names.
    """
    short_names = {
        "Disease Name": "Name",
        "Disease Symptom": "Symptom",
        "Disease Attribute": "Attribute",
        "Disease Description Plain English": "Plain ENG",
        "Disease Description Medical Style": "MED Style",
        "Disease Description Radiologist Style": "RAD Style",
        "Baseline": "Baseline",
    }
    return df.rename(columns=short_names)


def create_scatter_plot(biovil_data, medklip_data, kad_data, disease_classes, output_dir):
    """Create, save and show a scatter plot of baseline vs. averaged-style AUC.

    Each point is one (model, disease) pair. Its x value is the baseline AUC
    and its y value is the averaged AUC of the other prompt styles. Marker
    shape encodes the model and colour encodes the disease. A dashed y = x
    line marks parity.

    Args:
        biovil_data, medklip_data, kad_data: mappings with ``'baseline'`` and
            ``'avg_vs_baseline'`` entries, each a Series indexed by disease label.
        disease_classes: sequence of disease labels to plot. Each model's values
            are looked up by these labels.
        output_dir: ``Path`` receiving
            ``seen_classes_non_baseline_vs_baseline_auc.png``.
    """
    models = [(biovil_data, 'o', 'BioViL'), (medklip_data, 's', 'MedKLIP'),
              (kad_data, '^', 'KAD')]

    fig, ax = plt.subplots(figsize=(12, 8))
    colors = plt.get_cmap('tab20', len(disease_classes))

    for i, disease_class in enumerate(disease_classes):
        for data, marker, _ in models:
            # Look values up by label rather than position, so each point is
            # coloured by the disease it actually belongs to even if a model's
            # rows are ordered differently from ``disease_classes``.
            ax.scatter(data['baseline'].loc[disease_class],
                       data['avg_vs_baseline'].loc[disease_class],
                       color=colors(i), marker=marker, s=175)

    # Titles and axis labels
    ax.set_title('AUC of Original Prompt VS. Averaged Six Style Prompt on Seen Diseases', fontsize=19,
                 weight='bold', ha='left', x=0)
    ax.set_xlabel('AUC of Original Prompt', fontsize=20, weight='bold')
    ax.set_ylabel('Average AUC of Six Style Prompt', fontsize=20, weight='bold')
    ax.plot([0.4, 1], [0.4, 1], 'k--', lw=2)  # y = x parity line

    # Bold tick labels
    ax.tick_params(axis='both', which='major', labelsize=14, width=3)
    plt.setp(ax.get_xticklabels(), weight='bold')
    plt.setp(ax.get_yticklabels(), weight='bold')

    # Thicker plot border
    for spine in ax.spines.values():
        spine.set_linewidth(2)

    # Equal limits on both axes and a square plot area
    ax.set_xlim(0.4, 1)
    ax.set_ylim(0.4, 1)
    ax.set_aspect('equal', adjustable='box')

    # Explicit legend handles: hollow black markers for the models, then one
    # coloured line per disease.
    model_handles = [
        Line2D([0], [0], marker=marker, color='black', markerfacecolor='none', markersize=10,
               label=name)
        for _, marker, name in models]
    disease_handles = [Line2D([0], [0], color=colors(i), lw=4, label=disease_class)
                       for i, disease_class in enumerate(disease_classes)]

    # ``prop`` sets the entry font size; a separate ``fontsize`` would be ignored.
    legend = ax.legend(handles=model_handles + disease_handles, loc='center left',
                       bbox_to_anchor=(1.05, 0.5), title="Models & Disease Classes",
                       title_fontsize=18, frameon=True, prop={'size': 18},
                       edgecolor='black')
    legend.get_frame().set_linewidth(2)  # Thicker legend border

    plt.tight_layout(rect=(0, 0, 0.9, 1))
    plt.savefig(output_dir / 'seen_classes_non_baseline_vs_baseline_auc.png')
    plt.show()
    plt.close(fig)  # Release the figure once it has been saved and shown


def create_heatmaps(biovil_data, medklip_data, kad_data, output_dir):
    """Create, save and show side-by-side AUC heatmaps for the three models.

    Columns are renamed to short prompt-style names and ordered with Baseline
    first. All three panels share one colour scale. In every row, the maximum
    value(s) are annotated in italics and outlined with a black box. Ties are
    all marked.

    Args:
        biovil_data, medklip_data, kad_data: AUC tables with one column per
            prompt style.
        output_dir: ``Path`` receiving ``seen_classes_auc_heatmaps.png``.
    """

    def highlight_cells(ax, mask):
        """Outline every heatmap cell where ``mask`` is True with a thick black box."""
        # Heatmap cell (row i, column j) occupies the unit square anchored at (j, i).
        for i, j in zip(*np.nonzero(mask)):
            ax.add_patch(plt.Rectangle((j, i), 1, 1, fill=False, edgecolor='black', lw=3))

    prompt_style_mapping_dict = {"Disease Name": "Name", "Disease Symptom": "Symptom",
                                 "Disease Attribute": "Attribute",
                                 "Disease Description Plain English": "Plain ENG",
                                 "Disease Description Medical Style": "MED Style",
                                 "Disease Description Radiologist Style": "RAD Style",
                                 "Baseline": "Baseline"}
    prompt_style_order = ["Baseline", "Name", "Symptom", "Attribute", "Plain ENG", "MED Style",
                          "RAD Style"]

    # Short column names in a fixed order
    biovil_data = prepare_data(biovil_data, prompt_style_mapping_dict, prompt_style_order)
    medklip_data = prepare_data(medklip_data, prompt_style_mapping_dict, prompt_style_order)
    kad_data = prepare_data(kad_data, prompt_style_mapping_dict, prompt_style_order)

    # Panel widths proportional to each model's number of prompt styles
    widths = [len(data.columns) for data in [biovil_data, medklip_data, kad_data]]

    # Common colour scale across all panels
    vmin = min(biovil_data.min().min(), medklip_data.min().min(), kad_data.min().min())
    vmax = max(biovil_data.max().max(), medklip_data.max().max(), kad_data.max().max())

    fig, axes = plt.subplots(1, 3, figsize=(18, 9),
                             gridspec_kw={'width_ratios': widths, 'wspace': 0.06})

    # NOTE(review): y tick labels are drawn only on the first panel, so the three
    # tables are assumed to list diseases in the same row order - confirm.
    for ax, data, title in zip(axes, [biovil_data, medklip_data, kad_data],
                               ["BioViL", "MedKLIP", "KAD"]):
        # One mask of row maxima (NaN-aware) drives both the italic annotation
        # and the outline boxes, so tied maxima are marked consistently.
        data_arr = data.to_numpy()
        max_val_mask = data_arr == np.nanmax(data_arr, axis=1)[:, None]

        # First layer: only the row maxima, annotated in italics
        sns.heatmap(data, mask=~max_val_mask, annot=True, fmt=".2f", cmap="coolwarm", ax=ax,
                    vmin=vmin, vmax=vmax, linewidths=0, linecolor='black', cbar=False,
                    annot_kws={"size": 15, "style": "italic"})
        # Second layer: all remaining cells, annotated normally
        sns.heatmap(data, mask=max_val_mask, annot=True, fmt=".2f", cmap="coolwarm", ax=ax,
                    vmin=vmin, vmax=vmax, linewidths=0, linecolor='black', cbar=False,
                    annot_kws={"size": 15})

        highlight_cells(ax, max_val_mask)
        ax.set_title(f"{title} AUC on Seen Classes", fontsize=16, weight='bold')
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='center', fontsize=16,
                           weight='bold')
        ax.tick_params(axis='x', labelsize=16, labelrotation=45, width=3, length=5)
        ax.tick_params(axis='y', labelsize=16, width=3, length=5)
        if ax != axes[0]:
            # Only the leftmost panel shows the disease labels
            ax.set_ylabel('')
            ax.set_yticks([])
        else:
            ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=16, weight='bold')
        for spine in ax.spines.values():
            spine.set_linewidth(0)
            spine.set_visible(True)

    # Extra space at the bottom and left for the rotated / disease labels
    plt.subplots_adjust(bottom=0.16, left=0.15)
    plt.savefig(output_dir / 'seen_classes_auc_heatmaps.png')
    plt.show()
    plt.close(fig)  # Release the figure once it has been saved and shown


def create_percentage_diff_barplot(biovil_diff, medklip_diff, kad_diff, output_dir):
    """Create, save and show vertically stacked bar plots of percentage difference from baseline.

    There is one panel per model (BioViL, MedKLIP, KAD). Each panel groups
    bars by row and colours them by prompt style. All panels share the same
    y-axis limits, and those limits always include 0.

    Args:
        biovil_diff, medklip_diff, kad_diff: percentage-difference tables, one
            column per prompt style. Columns are renamed to short display names.
        output_dir: ``Path`` receiving ``combined_seen_percentage_diff_bar.png``.
    """
    # Short display names for the prompt styles
    biovil_diff = rename_prompt_styles(biovil_diff)
    medklip_diff = rename_prompt_styles(medklip_diff)
    kad_diff = rename_prompt_styles(kad_diff)

    # Common y-axis limits so the panels are directly comparable. Zero is always
    # included: bars are drawn from 0, so if every difference had the same sign
    # and 0 were excluded, all bars would be clipped and the y = 0 line hidden.
    ymin = min(biovil_diff.min().min(), medklip_diff.min().min(), kad_diff.min().min(), 0)
    ymax = max(biovil_diff.max().max(), medklip_diff.max().max(), kad_diff.max().max(), 0)

    fig, axes = plt.subplots(3, 1, figsize=(14, 9))  # Three vertically stacked panels
    colormap = plt.get_cmap("tab20")

    for ax, diff_data, model_name in zip(axes, [biovil_diff, medklip_diff, kad_diff],
                                         ["BioViL", "MedKLIP", "KAD"]):
        colors = [colormap(i) for i in range(len(diff_data.columns))]

        diff_data.plot(kind='bar', width=0.85, color=colors, ax=ax)

        # Dashed separators between neighbouring bar groups
        for i in range(1, len(diff_data.index)):
            ax.axvline(x=i - 0.5, color='gray', linestyle='--', linewidth=1)

        # Reference line at zero difference
        ax.axhline(y=0, color='black', linestyle='-', linewidth=1)

        # A single y label, on the middle panel
        if ax == axes[1]:
            ax.set_ylabel("Percentage Difference (%)", fontsize=18, weight='bold')

        ax.set_title(f"{model_name} Percentage Difference from Baseline on Seen Classes",
                     fontsize=18, weight='bold')
        ax.set_ylim([ymin, ymax])

        # x tick labels only on the bottom panel
        if ax != axes[-1]:
            ax.set_xticklabels([])
        else:
            ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='center', fontsize=16,
                               weight='bold')
            ax.tick_params(axis='x', rotation=45, labelsize=16, width=3)

        # ``prop`` sets the legend entry size (14); a separate ``fontsize`` would be ignored.
        ax.legend(title="Prompt Style", title_fontsize=16, loc='center left',
                  bbox_to_anchor=(1, 0.5), prop={'size': 14}, edgecolor='black')
        ax.get_legend().get_frame().set_linewidth(2)  # Thicker legend border

        ax.tick_params(axis='y', labelsize=16, width=3)

        # Thicker plot border
        for spine in ax.spines.values():
            spine.set_linewidth(2)

    plt.tight_layout()
    plt.savefig(output_dir / "combined_seen_percentage_diff_bar.png")
    plt.show()
    plt.close(fig)  # Release the figure once it has been saved and shown


def prepare_data(data, prompt_style_mapping_dict, prompt_style_order):
    """Rename columns with the given mapping, then select them in the given order.

    Names in ``prompt_style_order`` that do not exist after renaming are
    skipped. Columns not listed in ``prompt_style_order`` are dropped.
    """
    renamed = data.rename(columns=prompt_style_mapping_dict)
    ordered_present = list(filter(lambda name: name in renamed.columns, prompt_style_order))
    return renamed[ordered_present]


if __name__ == "__main__":
    summary_dir, output_dir = get_directories()
    biovil_seen_auc, medklip_seen_auc, kad_seen_auc = load_and_prepare_data(summary_dir)

    # Extract the disease classes (index) once
    disease_classes = biovil_seen_auc.index

    # Drop the 'Disease Name' column from kad_seen_auc, as it is identical to Baseline
    kad_seen_auc = kad_seen_auc.drop(columns='Disease Name')

    # Prepare data for the scatter plot, computing each model's averages only once
    scatter_inputs = []
    for auc_table in (biovil_seen_auc, medklip_seen_auc, kad_seen_auc):
        style_mean, baseline = calculate_average_performance(auc_table)
        scatter_inputs.append({'avg_vs_baseline': style_mean, 'baseline': baseline})
    biovil_data, medklip_data, kad_data = scatter_inputs

    create_scatter_plot(biovil_data, medklip_data, kad_data, disease_classes, output_dir)
    create_heatmaps(biovil_seen_auc, medklip_seen_auc, kad_seen_auc, output_dir)

    # Percentage difference from baseline for each model
    biovil_percentage_diff = calculate_percentage_diff_from_baseline(biovil_seen_auc)
    medklip_percentage_diff = calculate_percentage_diff_from_baseline(medklip_seen_auc)
    kad_percentage_diff = calculate_percentage_diff_from_baseline(kad_seen_auc)

    # Average AUC across all diseases for each model
    for model_name, auc_table in (("BioViL", biovil_seen_auc), ("MedKLIP", medklip_seen_auc),
                                  ("KAD", kad_seen_auc)):
        print(f"{model_name} average AUC for each prompt style:\n{auc_table.mean()}")

    # Average percentage diff across all models, diseases and prompt styles.
    # The tables are NOT added element-wise: pandas aligns on labels, and since
    # KAD has no 'Disease Name' column, that column would become NaN and the
    # BioViL/MedKLIP 'Disease Name' values would be silently excluded.
    # Instead, every available cell is pooled and averaged.
    pooled_percentage_diffs = np.concatenate(
        [diff.to_numpy(dtype=float).ravel()
         for diff in (biovil_percentage_diff, medklip_percentage_diff, kad_percentage_diff)])
    avg_percentage_diff = np.nanmean(pooled_percentage_diffs)
    print(
        f"Average percentage difference across all models and diseases: {avg_percentage_diff:.2f}%")

    # Average percentage diff of each prompt style, per model
    for model_name, diff in (("BioViL", biovil_percentage_diff),
                             ("MedKLIP", medklip_percentage_diff),
                             ("KAD", kad_percentage_diff)):
        print(f"{model_name} average percentage difference for each prompt style:\n{diff.mean()}")

    create_percentage_diff_barplot(biovil_percentage_diff, medklip_percentage_diff,
                                   kad_percentage_diff, output_dir)
