# -*- coding: utf-8 -*-
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def get_directories():
    """Return ``(summary_dir, output_dir)`` for this project.

    The project root is the parent of the directory that contains this file.
    Summary tables are read from ``<root>/experiments/summary`` and plots are
    written to ``<root>/plots``, which is created here if it does not exist yet
    (its parent must already exist).
    """
    project_root = Path(__file__).resolve().parents[1]
    summaries = project_root.joinpath('experiments', 'summary')
    plots = project_root.joinpath('plots')

    plots.mkdir(exist_ok=True)
    return summaries, plots


def load_and_prepare_data(summary_dir):
    """Load the summary tables and strip their aggregate 'mean' row.

    Each CSV is read with its first column as the index. A row labelled
    ``"mean"``, if present, is removed so only per-class rows remain.

    Args:
        summary_dir: Directory containing the summary CSVs. Accepts a ``str``
            or any path-like object.

    Returns:
        A list of four DataFrames, in this order: BioViL unseen AUC, MedKLIP
        unseen AUC, KAD unseen AUC, interpretability.
    """
    # Normalise so plain-string directories work with the ``/`` operator below.
    summary_dir = Path(summary_dir)

    def load_csv(filename):
        df = pd.read_csv(summary_dir / filename, index_col=0)
        # errors="ignore" leaves tables without a 'mean' row unchanged.
        return df.drop(index="mean", errors="ignore")

    filenames = ['BioViL_unseen_auc.csv', 'MedKLIP_unseen_auc.csv', 'KAD_unseen_auc.csv',
                 'interpretability.csv']
    return [load_csv(filename) for filename in filenames]


def rename_prompt_styles(df):
    """Return a copy of *df* whose prompt-style column names are shortened.

    Columns that are not known prompt styles keep their original name.
    """
    long_to_short = {
        "Disease Name": "Name",
        "Disease Symptom": "Symptom",
        "Disease Attribute": "Attribute",
        "Disease Description Plain English": "Plain ENG",
        "Disease Description Medical Style": "MED Style",
        "Disease Description Radiologist Style": "RAD Style",
        "Baseline": "Baseline",
    }

    def shorten(column):
        return long_to_short.get(column, column)

    renamed = df.rename(columns=shorten)
    return renamed


def _display_disease_name(disease):
    """Return the human-readable title form of a disease index label."""
    # str.capitalize() would turn the acronym into 'Covid-19', so special-case it.
    if disease == "covid-19":
        return "COVID-19"
    return disease.capitalize()


def _draw_auc_bars(ax, x, biovil_values, medklip_values, kad_values, width=0.2):
    """Draw BioViL / MedKLIP / KAD AUC bars side by side, centred on positions ``x``."""
    ax.bar(x - width, biovil_values, width, label='BioViL AUC', color='skyblue')
    ax.bar(x, medklip_values, width, label='MedKLIP AUC', color='lightcoral')
    ax.bar(x + width, kad_values, width, label='KAD AUC', color='lightgreen')


def _thicken_spines(ax, linewidth=2):
    """Set the line width of all four borders of ``ax``."""
    for side in ('bottom', 'left', 'top', 'right'):
        ax.spines[side].set_linewidth(linewidth)


def create_stacked_bar_line_plots(biovil_unseen_auc, medklip_unseen_auc, kad_unseen_auc,
                                  interpretability_df, output_dir):
    """Plot per-disease AUC bars for three models with an interpretability overlay.

    The prompt-style columns are shortened via ``rename_prompt_styles`` and the
    'Baseline' style is dropped. For every disease row, BioViL, MedKLIP and KAD
    AUC bars are drawn per prompt style. The first row of
    ``interpretability_df`` is drawn as a line on a secondary y-axis.

    With more than one disease the plots are stacked vertically, share the
    x-axis and share an AUC range derived from the data. With exactly one
    disease a single, larger-font plot with a fixed AUC range is drawn.

    The figure is saved to ``output_dir / "unseen_auc_interpretability_plots.png"``,
    shown, and then closed.

    Args:
        biovil_unseen_auc, medklip_unseen_auc, kad_unseen_auc: DataFrames indexed
            by disease with one column per prompt style, including 'Baseline'.
            Rows of the MedKLIP/KAD tables are looked up by the BioViL index, so
            the three are expected to share index and columns.
        interpretability_df: DataFrame whose first row holds one score per
            plotted prompt style. Presumably these are in the same order as the
            AUC columns; this is not checked here.
        output_dir: Directory for the PNG (``str`` or path-like).

    Raises:
        ValueError: If ``biovil_unseen_auc`` has no disease rows.
        KeyError: If a table lacks the 'Baseline' column (from ``DataFrame.drop``).
    """
    # Shorten prompt-style names and drop the Baseline style from every AUC table.
    biovil_unseen_auc, medklip_unseen_auc, kad_unseen_auc = [
        rename_prompt_styles(df).drop(columns='Baseline')
        for df in (biovil_unseen_auc, medklip_unseen_auc, kad_unseen_auc)
    ]

    n_diseases = len(biovil_unseen_auc.index)
    if n_diseases == 0:
        # Fail loudly here instead of with an IndexError deep inside the plotting code.
        raise ValueError("biovil_unseen_auc has no disease rows to plot")

    # Assumes a single row holding interpretability scores (only row 0 is used).
    interpretability_scores = interpretability_df.iloc[0].values

    # One group of bars per prompt style.
    x = np.arange(len(biovil_unseen_auc.columns))

    if n_diseases > 1:
        # Shared AUC limits so the stacked plots are directly comparable.
        min_auc = min(biovil_unseen_auc.min().min(), medklip_unseen_auc.min().min(),
                      kad_unseen_auc.min().min())
        max_auc = max(biovil_unseen_auc.max().max(), medklip_unseen_auc.max().max(),
                      kad_unseen_auc.max().max())

        # sharex=True already hides the x tick labels on every axes except the bottom one.
        fig, axes = plt.subplots(n_diseases, 1, figsize=(12, 10), sharex=True)

        for ax1, disease in zip(axes, biovil_unseen_auc.index):
            _draw_auc_bars(ax1, x,
                           biovil_unseen_auc.loc[disease].values.flatten(),
                           medklip_unseen_auc.loc[disease].values.flatten(),
                           kad_unseen_auc.loc[disease].values.flatten())

            ax1.set_ylabel('AUC')
            ax1.set_title(
                f'AUC and Interpretability Scores for {_display_disease_name(disease)}')
            ax1.set_ylim(min_auc - 0.05, max_auc + 0.05)
            ax1.legend(loc='upper left')

            # Secondary y-axis for the interpretability line.
            ax2 = ax1.twinx()
            ax2.plot(x, interpretability_scores, color='darkblue', marker='o',
                     label='Interpretability')
            ax2.set_ylabel('Interpretability')
            # Scores are on a 1-10 scale; 11.9 leaves headroom above the top value.
            ax2.set_ylim(0, 11.9)
            ax2.legend(loc='upper right')

        # Prompt-style labels only on the bottom plot.
        axes[-1].set_xlabel('Prompt Styles')
        axes[-1].set_xticks(x)
        axes[-1].set_xticklabels(biovil_unseen_auc.columns, rotation=45, ha="right")
    else:
        fig, ax1 = plt.subplots(figsize=(12, 6))

        _draw_auc_bars(ax1, x,
                       biovil_unseen_auc.values.flatten(),
                       medklip_unseen_auc.values.flatten(),
                       kad_unseen_auc.values.flatten())

        disease = _display_disease_name(biovil_unseen_auc.index[0])

        ax1.set_ylabel('AUC', fontsize=16, weight='bold')
        ax1.set_title(f'AUC and Interpretability Scores on Unseen Classes({disease})',
                      fontsize=16, weight='bold')
        # Fixed AUC window (not derived from the data): values outside
        # [0.45, 0.65] are clipped.
        ax1.set_ylim(0.45, 0.65)
        ax1.legend(loc='upper left', fontsize=14)
        ax1.set_yticks(np.linspace(ax1.get_ylim()[0], ax1.get_ylim()[1], 5))
        ax1.tick_params(axis='y', labelsize=16, width=3)
        plt.setp(ax1.get_yticklabels(), fontweight='bold')
        _thicken_spines(ax1)

        # Secondary y-axis for the interpretability line.
        ax2 = ax1.twinx()
        ax2.plot(x, interpretability_scores, color='darkblue', marker='o', label='Interpretability')
        ax2.set_ylabel('Interpretability', fontsize=16, weight='bold')
        _thicken_spines(ax2)
        # Scores are on a 1-10 scale; 11.9 leaves headroom above the top value.
        ax2.set_ylim(0, 11.9)
        ax2.tick_params(axis='y', labelsize=16, width=3)
        plt.setp(ax2.get_yticklabels(), fontweight='bold')
        ax2.legend(loc='upper right', fontsize=14)

        ax1.set_xticks(x)
        ax1.set_xticklabels(biovil_unseen_auc.columns, ha="center", fontweight='bold')
        ax1.tick_params(axis='x', pad=20, labelsize=16)

    # Adjust layout to avoid overlap between titles, labels and legends.
    fig.tight_layout()

    plot_filename = Path(output_dir) / "unseen_auc_interpretability_plots.png"
    fig.savefig(plot_filename)

    plt.show()
    # Release the figure; non-interactive backends would otherwise keep it alive.
    plt.close(fig)


if __name__ == "__main__":
    # Locate the input/output directories, load the four summary tables
    # (BioViL, MedKLIP, KAD AUCs, then interpretability) and render the figure.
    summary_dir, output_dir = get_directories()
    summary_tables = load_and_prepare_data(summary_dir)
    create_stacked_bar_line_plots(*summary_tables, output_dir)
