# Copyright 2024-2025
# [ANONYMIZED_INSTITUTION],
# [ANONYMIZED_FACULTY],
# [ANONYMIZED_DEPARTMENT]
#
# Authors:
# AUTHOR_1 (author1@example.com)
# AUTHOR_2 (author2@example.com)
#
# Code generation tools and workflows:
# First versions of this code were potentially generated
# with the help of AI writing assistants including
# GitHub Copilot, ChatGPT, Microsoft Copilot, Google Gemini.
# Afterwards, the generated segments were manually reviewed and edited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Plots of local estimates over model checkpoints with different seeds."""

import copy
import logging
import pathlib
from collections.abc import Sequence
from dataclasses import dataclass, field
from itertools import chain, cycle
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import rcParams
from matplotlib.ticker import FixedLocator, FuncFormatter
from tqdm import tqdm

from topollm.data_processing.dictionary_handling import (
    filter_list_of_dictionaries_by_key_value_pairs,
    generate_fixed_parameters_text_from_dict,
)
from topollm.logging.log_dataframe_info import log_dataframe_info
from topollm.path_management.convert_object_to_valid_path_part import convert_list_to_path_part
from topollm.path_management.embeddings.protocol import EmbeddingsPathManager
from topollm.plotting.plot_size_config import AxisLimits, OutputDimensions, PlotSizeConfigFlat, PlotSizeConfigNested
from topollm.task_performance_analysis.plotting.distribution_violinplots_and_distribution_boxplots import TicksAndLabels
from topollm.task_performance_analysis.plotting.parameter_combinations_and_loaded_data_handling import (
    add_base_model_data,
    construct_mean_plots_over_model_checkpoints_output_dir_from_filter_key_value_pairs,
    derive_base_model_partial_name,
    get_fixed_parameter_combinations,
)
from topollm.task_performance_analysis.plotting.score_loader.score_loader import (
    EmotionClassificationScoreLoader,
    TrippyRScoreLoader,
)
from topollm.typing.enums import Verbosity

# Module-level logger; the functions in this module use it whenever the caller does not pass one.
default_logger: logging.Logger = logging.getLogger(__name__)

# --------------------------------------------------------------------------- #
# Colors and legend labels for the secondary-axis metrics.
# The keys of both tables must match the strings in `scores_data.columns_to_plot`.
# --------------------------------------------------------------------------- #

# All loss columns share one color and one legend label.
_LOSS_COLUMN_NAMES: tuple[str, ...] = (
    "loss",
    "train_loss",
    "validation_loss",
    "test_loss",
)

_METRIC_COLOR_GREEN: str = "#2ca02c"  # used consistently for losses
_METRIC_COLOR_ORANGE: str = "#ff7f0e"  # used for performance measures (which are used for model selection)
_METRIC_COLOR_RED: str = "#d62728"
_METRIC_COLOR_BLUE: str = "#1f77b4"

METRIC_COLORS: dict[str, str] = {
    **dict.fromkeys(
        _LOSS_COLUMN_NAMES,
        _METRIC_COLOR_GREEN,
    ),
    # Colors for the Trippy-R performance measures:
    "jga": _METRIC_COLOR_ORANGE,
    # Colors for the ERC performance measures:
    "Macro F1 (w/o Neutral)": _METRIC_COLOR_RED,
    "Weighted F1 (w/o Neutral)": _METRIC_COLOR_ORANGE,
    # Other:
    "accuracy": _METRIC_COLOR_BLUE,
    "recall": _METRIC_COLOR_RED,
}

COLUMN_NAMES_TO_LEGEND_LABEL: dict[str, str] = {
    **dict.fromkeys(
        _LOSS_COLUMN_NAMES,
        "Loss",
    ),
    # Labels for the Trippy-R performance measures:
    "jga": "Joint Goal Accuracy",
    # Labels for the ERC performance measures:
    "Weighted F1 (w/o Neutral)": "Weighted F1",
    "Macro F1 (w/o Neutral)": "Macro F1",
    # Other:
    "accuracy": "Accuracy",
    "recall": "Recall",
}

# Build a repeatable fallback cycle from Matplotlib's default prop-cycle
_fallback_cycle = cycle(rcParams["axes.prop_cycle"].by_key()["color"])

# Fallback colors that were already handed out, keyed by metric name.
# Remembering them keeps the color of an unknown metric stable across calls.
_assigned_fallback_colors: dict[str, str] = {}


def get_metric_color(
    metric: str,
) -> str:
    """Return a fixed color for `metric`, or a deterministic fallback.

    Metrics listed in `METRIC_COLORS` always get their configured color.
    Any other metric receives the next color of the fallback cycle the first
    time it is requested and the very same color on every later request.

    Args:
        metric: Name of the metric, i.e. the column name that is plotted.

    Returns:
        The color assigned to `metric`.

    """
    fixed_color: str | None = METRIC_COLORS.get(metric)
    if fixed_color is not None:
        return fixed_color

    # Only advance the cycle on a real miss; passing `next(...)` as the default of
    # `dict.get` would consume a fallback color on every lookup of a known metric.
    if metric not in _assigned_fallback_colors:
        _assigned_fallback_colors[metric] = next(_fallback_cycle)

    return _assigned_fallback_colors[metric]


def get_metric_legend_label(
    metric: str,
) -> str:
    """Look up the legend label of `metric`; an unknown metric is labelled with its own name."""
    try:
        return COLUMN_NAMES_TO_LEGEND_LABEL[metric]
    except KeyError:
        # No dedicated label configured: fall back to the raw metric name.
        return metric


@dataclass
class ScoresData:
    """Scores dataframe together with the names of the columns that should be plotted."""

    # Both fields are None when no scores are available.
    df: pd.DataFrame | None
    columns_to_plot: list[str] | None

    def save_df(
        self,
        save_dir: pathlib.Path,
        verbosity: Verbosity = Verbosity.NORMAL,
        logger: logging.Logger = default_logger,
    ) -> None:
        """Write the scores dataframe to `scores_df.csv` inside `save_dir`.

        Without a dataframe, only a warning and an info message are logged and nothing is written.
        The target directory is created if it does not exist yet.
        """
        scores_df = self.df

        # Guard clause: nothing to write.
        if scores_df is None:
            logger.warning(msg="No scores data available to save.")
            logger.info(msg="Skipping saving scores data.")
            return

        scores_df_save_path = pathlib.Path(save_dir) / "scores_df.csv"
        scores_df_save_path.parent.mkdir(parents=True, exist_ok=True)

        log_progress = verbosity >= Verbosity.NORMAL

        if log_progress:
            logger.info(
                msg=f"Saving combined scores to {scores_df_save_path = } ...",  # noqa: G004 - low overhead
            )

        scores_df.to_csv(path_or_buf=scores_df_save_path, index=False)

        if log_progress:
            logger.info(
                msg=f"Saving combined scores to {scores_df_save_path = } DONE",  # noqa: G004 - low overhead
            )


@dataclass
class PlotInputData:
    """Container for plot input data.

    Bundles the dataframe of local estimates with the scores data.
    """

    # Dataframe holding the local estimates.
    local_estimates_df: pd.DataFrame
    # Scores container; its `df` and `columns_to_plot` fields may both be None.
    scores: ScoresData


@dataclass
class PlotConfig:
    """Container for plot configuration.

    The first three fields are required; all remaining fields have defaults.
    """

    # Axis labels and x-tick labels.
    ticks_and_labels: TicksAndLabels
    # Primary and secondary axis limits together with the output dimensions.
    plot_size_config_nested: PlotSizeConfigNested
    # NOTE(review): presumably the model seeds that are plotted - confirm against the plotting code.
    seeds: np.ndarray
    # NOTE(review): presumably the dataframe column that provides the x-values - confirm against the plotting code.
    x_column_name: str = "model_checkpoint"
    # Defaults to a fresh empty dict per instance.
    filter_key_value_pairs: dict = field(
        default_factory=dict,
    )
    base_model_model_partial_name: str | None = None
    # NOTE(review): presumably None disables saving the plots - confirm against the plotting code.
    plots_output_dir: pathlib.Path | None = None
    show_plots: bool = False

    add_legend: bool = True  # Add legend to the plot
    publication_ready: bool = False  # Exclude additional debug information in the publication ready version


def create_mean_plots_over_model_checkpoints_with_different_seeds(
    loaded_data: list[dict],
    array_key_name: str,
    output_root_dir: pathlib.Path,
    plot_size_configs_list: list[PlotSizeConfigFlat],
    embeddings_path_manager: EmbeddingsPathManager,
    *,
    restrict_to_model_seeds: list[int] | None = None,
    maximum_x_value: int | None = None,
    fixed_keys: list[str] | None = None,
    additional_fixed_params: dict[str, Any] | None = None,
    save_plot_raw_data: bool = True,
    publication_ready: bool = False,
    add_legend: bool = True,
    verbosity: Verbosity = Verbosity.NORMAL,
    logger: logging.Logger = default_logger,
) -> None:
    """Create mean plots over model checkpoints with different seeds.

    For every combination of the fixed parameters, the matching entries of `loaded_data`
    are selected, the base model entries are added, and the entries are sorted by checkpoint.
    The mean of the array stored under `array_key_name` is computed per entry,
    the scores of the model are loaded, and one plot call is issued for every pair of
    plot size config and secondary-axis limits.

    Note:
        Entries whose "model_checkpoint" is None are updated in place to -1 before sorting.
        These dictionaries may be shared with `loaded_data`.

    Args:
        loaded_data:
            List of dictionaries; each selected entry must provide "model_checkpoint"
            and the array stored under `array_key_name`.
        array_key_name:
            Key of the array whose mean is computed; the mean is stored in the column
            `f"{array_key_name}_mean"`.
        output_root_dir:
            Root directory from which the per-combination output directory is derived.
        plot_size_configs_list:
            Flat plot size configs; each one provides the primary axis limits of a group of plots.
        embeddings_path_manager:
            Path manager that is passed on to `load_scores`.
        restrict_to_model_seeds:
            Forwarded to the plotting function.
        maximum_x_value:
            Forwarded to the plotting function.
        fixed_keys:
            Keys that are held fixed within one plot; None selects the default list below.
            Every resulting combination must provide "model_partial_name".
        additional_fixed_params:
            Additional fixed key-value pairs; None selects `{"tokenizer_add_prefix_space": "False"}`.
        save_plot_raw_data:
            Whether to save the sorted data and the scores as CSV files next to the plots.
        publication_ready:
            Selects the y-label and the smaller output dimensions of the publication version.
        add_legend:
            Forwarded to the plotting function.
        verbosity:
            Verbosity level.
        logger:
            Logger to use.

    """
    if fixed_keys is None:
        fixed_keys = [
            # Notes:
            # - Do NOT fix the model seed, as we want to plot the mean over different seeds.
            # - If you want plots that combine estimates for different data subsamplings,
            #   you need to remove "data_subsampling_full" from the fixed_keys list and add only the split.
            "data_full",
            # > Example value for "data_subsampling_full": 'split=test_samples=7000_sampling=random_sampling-seed=41'
            # "data_subsampling_full",
            # > Example value for "data_subsampling_split": 'test'
            "data_subsampling_split",
            "data_dataset_seed",
            "model_layer",  # model_layer needs to be an integer
            "model_partial_name",
            "local_estimates_desc_full",
        ]

    if additional_fixed_params is None:
        additional_fixed_params = {
            "tokenizer_add_prefix_space": "False",  # tokenizer_add_prefix_space needs to be a string
        }

    # Materialize the fixed parameter combinations so that the progress bar knows the total.
    combinations = list(
        get_fixed_parameter_combinations(
            loaded_data=loaded_data,
            fixed_keys=fixed_keys,
            additional_fixed_params=additional_fixed_params,
        ),
    )
    total_combinations = len(combinations)

    if verbosity >= Verbosity.NORMAL:
        # Log available options
        for fixed_param in fixed_keys:
            options = {entry[fixed_param] for entry in loaded_data if fixed_param in entry}
            logger.info(
                msg=f"{fixed_param=} options: {options=}",  # noqa: G004 - low overhead
            )

    model_checkpoint_column_name = "model_checkpoint"

    # Models whose checkpoints are counted in epochs instead of steps.
    epoch_based_model_partial_names: tuple[str, ...] = (
        "model=bert-base-uncased-ContextBERT-ERToD_emowoz_basic_setup_debug=-1_use_context=False",
        "model=bert-base-uncased-ContextBERT-ERToD_emowoz_basic_setup_debug--1_ep-50_use_context-False",
    )

    # Use the truthiness of the flag: `match` on `True`/`False` compares by identity,
    # so a non-bool flag (e.g. `numpy.bool_`) would match no case and leave these names unbound.
    # For publication ready plots, make the dimensions smaller so that the text is easier to read.
    if publication_ready:
        ylabel = "Mean local dimension"
        output_pdf_width, output_pdf_height = 500, 300
    else:
        ylabel = array_key_name
        output_pdf_width, output_pdf_height = 2_500, 1_500

    for filter_key_value_pairs in tqdm(
        iterable=combinations,
        total=total_combinations,
        desc="Plotting different choices for model checkpoints",
    ):
        filtered_data: list[dict] = filter_list_of_dictionaries_by_key_value_pairs(
            list_of_dicts=loaded_data,
            key_value_pairs=filter_key_value_pairs,
        )

        if not filtered_data:
            logger.warning(
                msg=f"No data found for {filter_key_value_pairs = }.",  # noqa: G004 - low overhead
            )
            logger.warning(
                msg="Skipping this combination of parameters.",
            )
            continue

        # The identifier of the base model.
        # This value will be used to select the models for the correlation analysis
        # and add the estimates of the base model for the model checkpoint analysis.
        model_partial_name = filter_key_value_pairs["model_partial_name"]
        base_model_model_partial_name: str = derive_base_model_partial_name(
            model_partial_name=model_partial_name,
        )

        filtered_data_with_added_base_model: list[dict] = add_base_model_data(
            loaded_data=loaded_data,
            base_model_model_partial_name=base_model_model_partial_name,
            filter_key_value_pairs=filter_key_value_pairs,
            filtered_data=filtered_data,
            logger=logger,
        )

        # Sort the entries by increasing model checkpoint.
        # From this point on, the dataframe rows and the tick labels are in the same order.

        # 1. Step: Replace None model checkpoints with -1 (in place) so that they sort first.
        for single_dict in filtered_data_with_added_base_model:
            if single_dict[model_checkpoint_column_name] is None:
                single_dict[model_checkpoint_column_name] = -1

        # 2. Step: Sort numerically; the checkpoints may be stored as strings.
        sorted_data: list[dict] = sorted(
            filtered_data_with_added_base_model,
            key=lambda single_dict: int(single_dict[model_checkpoint_column_name]),
        )

        sorted_data_df = pd.DataFrame(
            data=sorted_data,
        )

        model_checkpoint_str_list: list[str] = [
            str(object=single_dict[model_checkpoint_column_name]) for single_dict in sorted_data
        ]

        # Load the corresponding model performance metrics.
        scores_data: ScoresData = load_scores(
            embeddings_path_manager=embeddings_path_manager,
            filter_key_value_pairs=filter_key_value_pairs,
            verbosity=verbosity,
            logger=logger,
        )

        # Create the column with the per-entry means.
        sorted_data_df[f"{array_key_name}_mean"] = sorted_data_df[array_key_name].apply(
            func=lambda x: np.mean(x),
        )

        if verbosity >= Verbosity.NORMAL:
            log_dataframe_info(
                df=sorted_data_df,
                df_name="sorted_data_df",
                logger=logger,
            )

        # Save locations and saving the data.
        plots_output_dir: pathlib.Path = (
            construct_mean_plots_over_model_checkpoints_output_dir_from_filter_key_value_pairs(
                output_root_dir=output_root_dir,
                filter_key_value_pairs=filter_key_value_pairs,
                verbosity=verbosity,
                logger=logger,
            )
        )

        # Save the sorted data (including the arrays) and the scores as CSV files.
        if save_plot_raw_data:
            plot_raw_data_save_dir = pathlib.Path(
                plots_output_dir,
                "raw_data",
            )
            plot_raw_data_save_dir.mkdir(
                parents=True,
                exist_ok=True,
            )

            sorted_data_df_save_path = pathlib.Path(
                plot_raw_data_save_dir,
                "sorted_data_df.csv",
            )

            if verbosity >= Verbosity.NORMAL:
                logger.info(
                    msg=f"Saving sorted data to {sorted_data_df_save_path = } ...",  # noqa: G004 - low overhead
                )
            sorted_data_df.to_csv(
                path_or_buf=sorted_data_df_save_path,
                index=False,
            )
            if verbosity >= Verbosity.NORMAL:
                logger.info(
                    msg=f"Saving sorted data to {sorted_data_df_save_path = } DONE",  # noqa: G004 - low overhead
                )

            # Save the scores data if available
            scores_data.save_df(
                save_dir=plot_raw_data_save_dir,
                verbosity=verbosity,
                logger=logger,
            )

        # Only the x-label depends on the model; everything else is shared.
        xlabel = "Epoch" if model_partial_name in epoch_based_model_partial_names else "Steps"
        ticks_and_labels: TicksAndLabels = TicksAndLabels(
            xlabel=xlabel,
            ylabel=ylabel,
            xticks_labels=model_checkpoint_str_list,
        )

        # Create plots
        for plot_size_config in plot_size_configs_list:
            # Fresh limit objects for every plot size config.
            secondary_axis_limits_list: list[AxisLimits] = [
                AxisLimits(),  # Automatic scaling
                AxisLimits(
                    y_min=0.0,
                    y_max=1.1,
                ),
                # This is for the measures in the Trippy-R setup
                AxisLimits(
                    y_min=0.0,
                    y_max=0.8,
                ),
                # This is for the measures in the Emotion setup:
                AxisLimits(
                    y_min=0.7,
                    y_max=0.8,
                ),
                AxisLimits(
                    y_min=0.1,
                    y_max=0.95,
                ),
                AxisLimits(
                    y_min=0.25,
                    y_max=0.8,
                ),
            ]

            for secondary_axis_limits in secondary_axis_limits_list:
                # Convert the PlotSizeConfigFlat object into the nested dataclass format.
                plot_size_config_nested = PlotSizeConfigNested(
                    primary_axis_limits=AxisLimits(
                        x_min=plot_size_config.x_min,
                        x_max=plot_size_config.x_max,
                        y_min=plot_size_config.y_min,
                        y_max=plot_size_config.y_max,
                    ),
                    secondary_axis_limits=secondary_axis_limits,
                    output_dimensions=OutputDimensions(
                        output_pdf_width=output_pdf_width,
                        output_pdf_height=output_pdf_height,
                    ),
                )

                plot_local_estimates_with_individual_seeds_and_aggregated_over_seeds(
                    local_estimates_df=sorted_data_df,
                    ticks_and_labels=ticks_and_labels,
                    plot_size_config_nested=plot_size_config_nested,
                    scores_data=scores_data,
                    restrict_to_model_seeds=restrict_to_model_seeds,
                    maximum_x_value=maximum_x_value,
                    x_column_name=model_checkpoint_column_name,
                    filter_key_value_pairs=filter_key_value_pairs,
                    base_model_model_partial_name=base_model_model_partial_name,
                    plots_output_dir=plots_output_dir,
                    publication_ready=publication_ready,
                    add_legend=add_legend,
                    verbosity=verbosity,
                    logger=logger,
                )


def load_scores(
    embeddings_path_manager: EmbeddingsPathManager,
    filter_key_value_pairs: dict,
    verbosity: Verbosity = Verbosity.NORMAL,
    logger: logging.Logger = default_logger,
) -> ScoresData:
    """Load the performance scores of all seeds of the selected model.

    The model is selected via `filter_key_value_pairs["model_partial_name"]`.
    For the known EmoLoop and Trippy-R models, the scores of every seed are loaded,
    tagged with a "model_seed" column, and concatenated into one dataframe.
    For any other model, a warning is logged and an empty `ScoresData` is returned.

    Args:
        embeddings_path_manager:
            Path manager; its `data_dir` is the root of all score file locations.
        filter_key_value_pairs:
            Dictionary that must contain the key "model_partial_name".
        verbosity:
            Verbosity level.
        logger:
            Logger to use.

    Returns:
        The combined scores and the names of the columns to plot (in first-seen order),
        or a `ScoresData` with both fields set to None if no loader exists for the model.

    """
    model_partial_name = filter_key_value_pairs["model_partial_name"]

    # EmoLoop emotion models: model name -> run directory below "models/EmoLoop/output_dir/".
    # Note: For the EmoLoop models, you need to have prepared the parsed_data files.
    emoloop_run_dir_by_model: dict[str, str] = {
        "model=bert-base-uncased-ContextBERT-ERToD_emowoz_basic_setup_debug=-1_use_context=False": (
            "debug=-1/use_context=False/ep=5/"
        ),
        "model=bert-base-uncased-ContextBERT-ERToD_emowoz_basic_setup_debug--1_ep-50_use_context-False": (
            "debug=-1/use_context=False/ep=50/"
        ),
    }

    # Trippy-R models: model name -> (path parts below the checkpoints directory, model seeds).
    trippy_r_runs_by_model: dict[str, tuple[tuple[str, ...], tuple[int, ...]]] = {
        "model=roberta-base-trippy_r_multiwoz21": (
            (),
            tuple(range(40, 45)),
        ),
        "model=roberta-base-trippy_r_multiwoz21_50-0.020-constant_schedule_with_warmup": (
            (
                "model_output/num_train_epochs=50/warmup_proportion=0.020/lr_scheduler_type=constant_schedule_with_warmup",
            ),
            (*range(40, 45), 1111),
        ),
        "model=roberta-base-trippy_r_multiwoz21_50-0.020-linear_schedule_with_warmup": (
            (
                "model_output/num_train_epochs=50/warmup_proportion=0.020/lr_scheduler_type=linear_schedule_with_warmup",
            ),
            (*range(40, 45), 1111),
        ),
    }

    # Determine which loader to use and where the scores of every seed are located.
    if model_partial_name in emoloop_run_dir_by_model:
        if verbosity >= Verbosity.NORMAL:
            logger.info(
                msg="Loading scores for EmoLoop emotion models.",
            )

        use_emotion_loader = True
        base_directory = pathlib.Path(
            embeddings_path_manager.data_dir,
            "models/EmoLoop/output_dir/",
            emoloop_run_dir_by_model[model_partial_name],
        )
        seed_and_path_pairs: list[tuple[int, pathlib.Path]] = [
            (
                seed,
                pathlib.Path(
                    base_directory,
                    f"seed={seed}/",
                    "parsed_data/raw_data/parsed_data.csv",
                ),
            )
            for seed in range(50, 54)
        ]
    elif model_partial_name in trippy_r_runs_by_model:
        use_emotion_loader = False
        run_dir_parts, seeds = trippy_r_runs_by_model[model_partial_name]
        data_dir = embeddings_path_manager.data_dir
        seed_and_path_pairs = [
            (
                seed,
                pathlib.Path(
                    data_dir,
                    "models/trippy_r_checkpoints/multiwoz21/all_checkpoints/",
                    *run_dir_parts,
                    f"results.{seed}",
                ),
            )
            for seed in seeds
        ]
    else:
        # Note: This is where you would implement score loading for language models
        # (with performance given by loss).
        logger.warning(
            msg=f"No specific model performance data loader implemented for "  # noqa: G004 - low overhead
            f"{filter_key_value_pairs['model_partial_name'] = }.",
        )
        logger.info(
            msg="Setting combined_scores_df to None for this model.",
        )
        return ScoresData(
            df=None,
            columns_to_plot=None,
        )

    seed_dfs: list[pd.DataFrame] = []
    # A dict is used as an ordered set: the columns keep their first-seen order,
    # which makes the result independent of string hashing.
    columns_to_plot_ordered: dict[str, None] = {}

    for seed, scores_location in seed_and_path_pairs:
        if use_emotion_loader:
            file_loader = EmotionClassificationScoreLoader(
                filepath=scores_location,
            )
        else:
            file_loader = TrippyRScoreLoader(
                results_folder_for_given_seed_path=scores_location,
                verbosity=verbosity,
                logger=logger,
            )

        scores_df: pd.DataFrame = file_loader.get_scores()
        scores_df["model_seed"] = seed  # Tag the dataframe with the current seed
        seed_dfs.append(scores_df)

        columns_to_plot_ordered.update(
            dict.fromkeys(file_loader.get_columns_to_plot()),
        )

    # Concatenate all seed dataframes into one
    combined_scores_df: pd.DataFrame | None
    if not seed_dfs:
        logger.warning(
            msg="No seed dataframes found.",
        )
        logger.info(
            msg="Setting combined_scores_df to None for this model.",
        )
        combined_scores_df = None
    else:
        combined_scores_df = pd.concat(
            objs=seed_dfs,
            ignore_index=True,
        )

    return ScoresData(
        df=combined_scores_df,
        columns_to_plot=list(columns_to_plot_ordered),
    )


def get_data_subsampling_split_from_data_subsampling_full(
    data_subsampling_full: str,
) -> str:
    """Return the split name encoded in a full data subsampling description.

    The description consists of "_"-separated fields; the first field has the form
    'split=<name>' and `<name>` is returned.

    Example:
    - 'split=test_samples=10000_sampling=random_sampling-seed=778' yields 'test'.
    """
    # Everything before the first underscore is the split field, e.g. 'split=test'.
    split_field, _, _ = data_subsampling_full.partition("_")

    # The split name is the second "="-separated piece of that field.
    key_and_value_pieces: list[str] = split_field.split("=")

    return key_and_value_pieces[1]


def plot_local_estimates_with_individual_seeds_and_aggregated_over_seeds(
    local_estimates_df: pd.DataFrame,
    ticks_and_labels: TicksAndLabels,
    plot_size_config_nested: PlotSizeConfigNested,
    scores_data: ScoresData,
    *,
    restrict_to_model_seeds: list[int] | None = None,
    maximum_x_value: int | None = None,
    x_column_name: str = "model_checkpoint",
    filter_key_value_pairs: dict,
    base_model_model_partial_name: str | None = None,
    plots_output_dir: pathlib.Path | None = None,
    publication_ready: bool = False,
    add_legend: bool = True,
    do_create_seedwise_estimate_visualization: bool = False,
    show_plots: bool = False,
    verbosity: Verbosity = Verbosity.NORMAL,
    logger: logging.Logger = default_logger,
) -> None:
    """Plot local estimates for each model seed and a summary plot.

    Args:
        df:
            Input dataframe with columns:
            - 'model_checkpoint',
            - 'model_seed',
            - 'file_data_mean'

    """
    model_seed_column_name = "model_seed"

    # # # #
    # Make deep copies of the plotting data so that we do not modify the original data
    local_estimates_df = local_estimates_df.copy(deep=True)
    scores_data = copy.deepcopy(scores_data)

    # # # #
    # Pre-process the local estimates data

    local_estimates_plot_data_df: pd.DataFrame = local_estimates_df[
        [
            x_column_name,
            model_seed_column_name,
            "file_data_mean",
        ]
    ]

    if restrict_to_model_seeds is not None:
        # Restrict the data to the given model seeds
        local_estimates_plot_data_df = local_estimates_plot_data_df[
            local_estimates_plot_data_df[model_seed_column_name].isin(
                values=restrict_to_model_seeds,
            )
            | local_estimates_plot_data_df[model_seed_column_name].isna()
        ]

    seeds: np.ndarray = local_estimates_plot_data_df[model_seed_column_name].dropna().unique().astype(dtype=int)

    # Separate the checkpoint -1 data (no seeds associated)
    checkpoint_neg1_selected_rows = local_estimates_plot_data_df[local_estimates_plot_data_df[x_column_name] == -1]

    if checkpoint_neg1_selected_rows.empty:
        logger.warning(
            msg="No checkpoint -1 data found in the local estimates DataFrame.",
        )
        logger.warning(
            msg="We will not add any checkpoint -1 data to the plot.",
        )
        neg1_data_emulated = None
    else:
        checkpoint_neg1 = checkpoint_neg1_selected_rows.iloc[0]

        # Emulate checkpoint -1 data for each seed
        neg1_data_emulated = pd.DataFrame(
            data={
                x_column_name: [-1] * len(seeds),
                model_seed_column_name: seeds,
                "file_data_mean": [checkpoint_neg1["file_data_mean"]] * len(seeds),
            },
        )

    # Drop original -1 checkpoint and append the emulated data
    local_estimates_plot_data_df = local_estimates_plot_data_df[
        local_estimates_plot_data_df[x_column_name] != -1
    ].dropna()
    local_estimates_plot_data_df = pd.concat(
        objs=[
            neg1_data_emulated,
            local_estimates_plot_data_df,
        ],
        ignore_index=True,
    )

    # Ensure correct types
    local_estimates_plot_data_df = local_estimates_plot_data_df.astype(
        dtype={
            x_column_name: int,
            model_seed_column_name: int,
            "file_data_mean": float,
        },
    )

    # # # #
    # Pre-process the scores data

    if "data_subsampling_split" in filter_key_value_pairs:
        if verbosity >= Verbosity.NORMAL:
            logger.info(
                "Getting data_subsampling_split directly from filter_key_value_pairs.",
            )
        data_subsampling_split: str = filter_key_value_pairs["data_subsampling_split"]
    elif "data_subsampling_full" in filter_key_value_pairs:
        if verbosity >= Verbosity.NORMAL:
            logger.info(
                "Deriving data_subsampling_split from data_subsampling_full.",
            )
        data_subsampling_split: str = get_data_subsampling_split_from_data_subsampling_full(
            data_subsampling_full=filter_key_value_pairs["data_subsampling_full"],
        )
    else:
        logger.warning(
            msg="No 'data_subsampling_split' key found in filter_key_value_pairs "
            "and cannot be derived from 'data_subsampling_full'.",
        )
        logger.info(
            msg="Skipping this plot call and returning from function now.",
        )
        return

    if verbosity >= Verbosity.NORMAL:
        logger.info(
            msg=f"Filtering scores_df based on {data_subsampling_split = } ...",  # noqa: G004 - low overhead
        )

    if scores_data.df is not None:
        if "data_subsampling_split" not in scores_data.df.columns:
            logger.warning(
                msg="No data_subsampling_split column found in scores_data.df.",
            )
            logger.info(
                msg="Will not modify the scores_df.",
            )
        else:
            if verbosity >= Verbosity.NORMAL:
                logger.info(
                    msg=f"Filtering scores_df based on {data_subsampling_split = }",  # noqa: G004 - low overhead
                )
                logger.info(
                    msg=f"Shape before filtering: {scores_data.df.shape = }",  # noqa: G004 - low overhead
                )

            # Filter the scores_df based on the data_subsampling_split
            scores_data.df = scores_data.df[scores_data.df["data_subsampling_split"] == data_subsampling_split]

            if verbosity >= Verbosity.NORMAL:
                logger.info(
                    msg=f"Shape after filtering: {scores_data.df.shape = }",  # noqa: G004 - low overhead
                )

    if verbosity >= Verbosity.NORMAL:
        logger.info(
            msg=f"Filtering scores_df based on {data_subsampling_split = } DONE",  # noqa: G004 - low overhead
        )

    # # # #
    # Restrict the scores data to the given model seeds
    if restrict_to_model_seeds is not None and scores_data.df is not None:
        scores_data.df = scores_data.df[scores_data.df["model_seed"].isin(restrict_to_model_seeds)]

    # # # #
    # Set the maximum x value for the plot
    if maximum_x_value is not None:
        local_estimates_plot_data_df = local_estimates_plot_data_df[
            local_estimates_plot_data_df[x_column_name] <= maximum_x_value
        ]
    if scores_data.df is not None:
        scores_data.df = scores_data.df[scores_data.df[x_column_name] <= maximum_x_value]

    # Increase epoch number by 1 for the ERC models on the x-axis labels
    # (so that the first epoch is 1 instead of 0)
    model_partial_name = filter_key_value_pairs["model_partial_name"]
    match model_partial_name:
        case (
            "model=bert-base-uncased-ContextBERT-ERToD_emowoz_basic_setup_debug=-1_use_context=False"
            | "model=bert-base-uncased-ContextBERT-ERToD_emowoz_basic_setup_debug--1_ep-50_use_context-False"
        ):
            if verbosity >= Verbosity.NORMAL:
                logger.info(
                    msg="Increasing x-axis values by 1 for ERC models.",
                )
            local_estimates_plot_data_df[x_column_name] = local_estimates_plot_data_df[x_column_name] + 1
            if scores_data.df is not None:
                scores_data.df[x_column_name] = scores_data.df[x_column_name] + 1
        case _:
            if verbosity >= Verbosity.NORMAL:
                logger.info(
                    msg="Using default x-axis labels.",
                )

    # # # #
    # Add additional information to the output path
    if plots_output_dir is not None:
        plots_output_dir = pathlib.Path(
            plots_output_dir,
            f"{publication_ready=}",
            f"{add_legend=}",
            f"restrict_to_model_seeds={convert_list_to_path_part(input_list=restrict_to_model_seeds)}",
        )

    # # # #
    # Create the containers for the plot data and plot configuration

    plot_input_data = PlotInputData(
        local_estimates_df=local_estimates_plot_data_df,
        scores=scores_data,
    )

    plot_config = PlotConfig(
        ticks_and_labels=ticks_and_labels,
        plot_size_config_nested=plot_size_config_nested,
        seeds=seeds,
        x_column_name=x_column_name,
        filter_key_value_pairs=filter_key_value_pairs,
        base_model_model_partial_name=base_model_model_partial_name,
        plots_output_dir=plots_output_dir,
        show_plots=show_plots,
        add_legend=add_legend,
        publication_ready=publication_ready,
    )

    # ========================================================= #
    # Plots: Individual by seed using a figure and axis.
    # ========================================================= #
    if do_create_seedwise_estimate_visualization:
        create_seedwise_estimate_visualization(
            data=plot_input_data,
            config=plot_config,
            verbosity=verbosity,
            logger=logger,
        )

    # ========================================================= #
    # Plots: Aggregated over seeds
    # ========================================================= #

    create_aggregate_estimate_visualization(
        data=plot_input_data,
        config=plot_config,
        verbosity=verbosity,
        logger=logger,
    )


def label_every_n(
    ax,  # noqa: ANN001 - problem with plt.Axes type
    axis: str = "x",
    keep_every: int = 2,
    *,
    tick_positions: Sequence[float] | None = None,
) -> None:
    """Keep the label on every `keep_every`-th major tick and blank out the others.

    Args:
        ax:
            The axes object whose ticks are modified.
        axis:
            "x" selects the x-axis; any other value selects the y-axis.
        keep_every:
            Ticks whose position index is a multiple of this value keep their label.
        tick_positions:
            Positions at which the major ticks are fixed.
            If not given, the current tick positions of the selected axis are used.

    """
    use_x_axis: bool = axis == "x"

    if tick_positions is None:
        tick_positions = ax.get_xticks() if use_x_axis else ax.get_yticks()

    def _format_tick_label(
        value: float,
        pos: int | None,
    ) -> str:
        """Return the tick label text, or an empty string for ticks which should not be labelled."""
        if pos is not None and pos % keep_every != 0:
            return ""
        return f"{value:g}"

    selected_axis = ax.xaxis if use_x_axis else ax.yaxis
    selected_axis.set_major_locator(
        FixedLocator(
            tick_positions,  # type: ignore - problem with type
        ),
    )
    selected_axis.set_major_formatter(
        FuncFormatter(_format_tick_label),
    )


def create_aggregate_estimate_visualization(
    data: PlotInputData,
    config: PlotConfig,
    verbosity: Verbosity = Verbosity.NORMAL,
    logger: logging.Logger = default_logger,
) -> None:
    """Create a plot of the mean local estimates over checkpoints with standard deviation bands."""
    # Set parameters based on draft or publication mode
    match config.publication_ready:
        case False:
            grid_alpha: float = 1.0
            local_estimates_label = "Mean file_data_mean across seeds"
            legend_title = "Legend"
        case True:
            grid_alpha: float = 0.25  # Make the grid weaker in publication-ready mode
            local_estimates_label = "Mean local dim."
            legend_title = None
        case _:
            raise ValueError(
                msg=f"Unknown value for {config.publication_ready = }",
            )

    # Set the marker size depending on the number of values on the x-axis
    if len(data.local_estimates_df[config.x_column_name].unique()) > 30:
        # If there are more than 20 unique values, use a smaller marker size
        markersize: int = 4
    else:
        markersize: int = 6

    if "index" in data.local_estimates_df.columns:
        data.local_estimates_df = data.local_estimates_df.drop(
            columns=["index"],
        )

    # Create summary for local estimates with mean and standard deviation
    summary_local_estimates: pd.DataFrame = (
        data.local_estimates_df.groupby(
            by=config.x_column_name,
            as_index=False,
        )["file_data_mean"]
        .agg(
            func=[
                "mean",
                "std",
            ],
        )
        .reset_index()
    )

    if summary_local_estimates.empty:
        logger.warning(
            msg="No data available for plotting.",
        )
        logger.info(
            msg="Skipping this plot call and returning from function now.",
        )
        return

    if "index" in summary_local_estimates.columns:
        summary_local_estimates = summary_local_estimates.drop(
            columns=["index"],
        )

    summary_local_estimates.columns = [
        config.x_column_name,
        "mean",
        "std",
    ]

    # Explicitly handle NaNs in standard deviation (set to 0)
    summary_local_estimates["std"] = summary_local_estimates["std"].fillna(0)

    # Convert to NumPy arrays explicitly for matplotlib
    checkpoints = summary_local_estimates[config.x_column_name].to_numpy()
    means = summary_local_estimates["mean"].to_numpy()
    stds = summary_local_estimates["std"].to_numpy()

    # Create a summary figure and axis
    (
        fig,
        ax1,
    ) = plt.subplots(
        figsize=(
            config.plot_size_config_nested.output_dimensions.output_pdf_width / 100,
            config.plot_size_config_nested.output_dimensions.output_pdf_height / 100,
        ),
    )
    ax1.plot(
        checkpoints,
        means,
        marker="o",
        markersize=markersize,
        color="blue",
        label=local_estimates_label,
    )
    ax1.fill_between(
        x=checkpoints,
        y1=means - stds,
        y2=means + stds,
        color="blue",
        alpha=0.2,
        # No label for the standard deviation
    )

    match config.publication_ready:
        case False:
            ax1.set_title(
                label="Mean Local Estimates Over Checkpoints with Standard Deviation Band",
            )
        case True:
            if verbosity >= Verbosity.NORMAL:
                logger.info(
                    msg="Skipping the title in the plot for publication-ready mode.",
                )
        case _:
            raise ValueError(
                msg=f"Unknown value for {config.publication_ready = }",
            )

    ax1.grid(
        visible=True,
        alpha=grid_alpha,
    )

    ax1.set_xlabel(
        xlabel=config.ticks_and_labels.xlabel,
    )
    ax1.set_ylabel(
        ylabel=config.ticks_and_labels.ylabel,
    )

    # # # #
    # Plot the additional data if available

    # Add second y-axis for scores
    ax2 = ax1.twinx()

    if data.scores.df is not None and data.scores.columns_to_plot is not None:
        # Note:
        # - summary_scores is a DataFrame with a multilevel index
        summary_scores: pd.DataFrame = (
            data.scores.df.groupby(
                by=config.x_column_name,
                as_index=False,
            )[data.scores.columns_to_plot]
            .agg(
                func=[
                    "mean",
                    "std",
                ],
            )
            .reset_index()
        )

        if verbosity >= Verbosity.NORMAL:
            log_dataframe_info(
                df=summary_scores,
                df_name="summary_scores",
                logger=logger,
            )

        for column in data.scores.columns_to_plot:
            if column not in summary_scores.columns:
                logger.warning(
                    msg=f"{column=} not found in summary_scores DataFrame.",  # noqa: G004
                )
                continue

            checkpoints = summary_scores[config.x_column_name].to_numpy()
            means = summary_scores[
                column,
                "mean",
            ].to_numpy()
            stds = summary_scores[
                column,
                "std",
            ].to_numpy()

            color = get_metric_color(
                metric=column,
            )
            match config.publication_ready:
                case False:
                    label: str = f"{column} (mean)"
                case True:
                    label = get_metric_legend_label(
                        metric=column,
                    )

            ax2.plot(
                checkpoints,
                means,
                linestyle="--",
                marker="x",
                markersize=markersize,
                label=label,
                color=color,
            )
            ax2.fill_between(
                x=checkpoints,
                y1=means - stds,
                y2=means + stds,
                alpha=0.2,
                color=color,
                # No label for the standard deviation
            )

    # Optional: Set axis label once
    ax2.set_ylabel(
        ylabel="Evaluation measures",
        color="tab:red",
    )  # Customize label and color if desired
    ax2.tick_params(
        axis="y",
        labelcolor="tab:red",
    )

    match config.publication_ready:
        case False:
            if verbosity >= Verbosity.NORMAL:
                logger.info(
                    msg="Will not modify the axes labels in the debug mode.",
                )
        case True:
            label_every_n(ax=ax1, axis="x", keep_every=2)  # show every 2-nd x-label
            # label_every_n(ax=ax1, axis="y", keep_every=2)  # show every 2-nd y-label
            # label_every_n(ax=ax2, axis="y", keep_every=2)  # show every 2-nd y-label

    match config.add_legend:
        case True:
            # Combine legends from both axes
            (
                lines_1,
                labels_1,
            ) = ax1.get_legend_handles_labels()
            (
                lines_2,
                labels_2,
            ) = ax2.get_legend_handles_labels()

            # This code would add the legend without de-duplication of the labels:
            # > ax1.legend(
            # >     handles=lines_1 + lines_2,
            # >     labels=labels_1 + labels_2,
            # >     title=legend_title,
            # > )

            # If a label occurs multiple times, remove it from the legend.
            # This for example might happen for the losses, if you have
            # "train_loss", "validation_loss", "test_loss"
            handles = lines_1 + lines_2
            labels = labels_1 + labels_2
            by_label = dict(
                zip(
                    labels,
                    handles,
                    strict=True,
                ),
            )
            # Place the legend on the second axis, so that it is drawn on top of potential curves
            legend = ax2.legend(
                handles=by_label.values(),
                labels=by_label.keys(),
                title=legend_title,
            )
            legend.set_zorder(
                level=200,  # Bring legend to the front
            )
            legend.get_frame().set_facecolor(
                color="white",
            )
            legend.get_frame().set_alpha(
                alpha=0.5,
            )
        case False:
            if verbosity >= Verbosity.NORMAL:
                logger.info(
                    msg="Skipping the legend in the plot.",
                )

    # Set the y-axis limits
    ax1 = config.plot_size_config_nested.primary_axis_limits.set_y_axis_limits(
        axis=ax1,
    )
    ax2 = config.plot_size_config_nested.secondary_axis_limits.set_y_axis_limits(
        axis=ax2,
    )

    fixed_params_text: str = generate_fixed_parameters_text_from_dict(
        filters_dict=config.filter_key_value_pairs,
    )

    if fixed_params_text is not None:
        match config.publication_ready:
            case False:
                # Add information about the fixed parameters into the plot
                ax1.text(
                    x=1.05,
                    y=0.25,
                    s=f"Fixed Parameters:\n{fixed_params_text}",
                    transform=plt.gca().transAxes,
                    fontsize=6,
                    verticalalignment="top",
                    bbox={
                        "boxstyle": "round",
                        "facecolor": "wheat",
                        "alpha": 0.3,
                    },
                )
            case True:
                if verbosity >= Verbosity.NORMAL:
                    logger.info(
                        msg="Skipping the fixed parameters text in the plot for publication-ready mode.",
                    )

    # Add info about the base model if available into the bottom left corner of the plot
    if config.base_model_model_partial_name is not None:
        match config.publication_ready:
            case False:
                ax1.text(
                    x=0.01,
                    y=0.01,
                    s=f"{config.base_model_model_partial_name=}",
                    transform=plt.gca().transAxes,
                    fontsize=6,
                    verticalalignment="bottom",
                    horizontalalignment="left",
                    bbox={
                        "boxstyle": "round",
                        "facecolor": "wheat",
                        "alpha": 0.3,
                    },
                )
            case True:
                if verbosity >= Verbosity.NORMAL:
                    logger.info(
                        msg="Skipping the information about the base model in publication-ready mode.",
                    )

    fig.tight_layout()

    # Make sure everything is written to the plot
    fig.canvas.draw()
    fig.canvas.flush_events()

    # Save the figure
    if config.plots_output_dir is not None:
        plot_name: str = f"local_estimates_aggregate_{config.plot_size_config_nested.y_range_description}"
        plot_save_path = pathlib.Path(
            config.plots_output_dir,
            "aggregate",
            f"{plot_name}.pdf",
        )

        if verbosity >= Verbosity.NORMAL:
            logger.info(
                msg=f"Saving plot to {plot_save_path = } ...",  # noqa: G004 - low overhead
            )

        plot_save_path.parent.mkdir(
            parents=True,
            exist_ok=True,
        )
        # Set `bbox_inches` and `pad_inches` to ensure the plot is saved without extra whitespace
        fig.savefig(
            fname=plot_save_path,
            bbox_inches="tight",
            pad_inches=0,
        )

        if verbosity >= Verbosity.NORMAL:
            logger.info(
                msg=f"Saving plot to {plot_save_path = } DONE",  # noqa: G004 - low overhead
            )

    if config.show_plots:
        fig.show()

    # Close the figure
    plt.close(fig)


def create_seedwise_estimate_visualization(
    data: PlotInputData,
    config: PlotConfig,
    verbosity: Verbosity = Verbosity.NORMAL,
    logger: logging.Logger = default_logger,
) -> None:
    """Visualize seed-wise estimates over model checkpoints.

    For every seed in `config.seeds`, one curve of the local estimates is drawn on the primary y-axis.
    If scores are available, one curve per seed and score column is drawn on a secondary y-axis.
    The figure is optionally saved as a PDF and shown, and it is closed at the end.

    Args:
        data:
            Container with the local estimates dataframe and the optional scores.
        config:
            Container with the plot configuration.
        verbosity:
            Verbosity level of the logging.
        logger:
            Logger instance.

    """
    model_seed_column_name = "model_seed"

    (
        fig1,
        ax1,
    ) = plt.subplots(
        figsize=(
            config.plot_size_config_nested.output_dimensions.output_pdf_width / 100,
            config.plot_size_config_nested.output_dimensions.output_pdf_height / 100,
        ),
    )

    # # # #
    # Plot the mean of local estimates
    for seed in config.seeds:
        seed_data = data.local_estimates_df[data.local_estimates_df[model_seed_column_name] == seed]
        ax1.plot(
            seed_data[config.x_column_name],
            seed_data["file_data_mean"],
            marker="o",
            label=f"{seed=}",
        )

    ax1.set_xlabel(
        xlabel=config.ticks_and_labels.xlabel,
    )
    ax1.set_ylabel(
        ylabel=config.ticks_and_labels.ylabel,
    )
    ax1.set_title(label="Local Estimates Over Checkpoints by Model Seed (including checkpoint -1)")

    ax1.grid(
        visible=True,
    )

    # # # #
    # Plot the additional data if available

    # Add second y-axis for scores
    ax2 = ax1.twinx()

    if data.scores.df is not None and data.scores.columns_to_plot is not None:
        for seed in config.seeds:
            seed_scores = data.scores.df[data.scores.df[model_seed_column_name] == seed]
            for column in data.scores.columns_to_plot:
                ax2.plot(
                    seed_scores[config.x_column_name],
                    seed_scores[column],
                    linestyle="--",
                    marker="x",
                    label=f"{column} (seed={seed})",
                )

    # Optional: Set axis label once
    ax2.set_ylabel(
        ylabel="Scores",
        color="tab:red",
    )  # Customize label and color if desired
    ax2.tick_params(
        axis="y",
        labelcolor="tab:red",
    )

    # Combine legends from both axes
    (
        lines_1,
        labels_1,
    ) = ax1.get_legend_handles_labels()
    (
        lines_2,
        labels_2,
    ) = ax2.get_legend_handles_labels()
    ax1.legend(
        handles=lines_1 + lines_2,
        labels=labels_1 + labels_2,
        title="Legend",
    )

    # Set the y-axis limits
    ax1 = config.plot_size_config_nested.primary_axis_limits.set_y_axis_limits(
        axis=ax1,
    )
    ax2 = config.plot_size_config_nested.secondary_axis_limits.set_y_axis_limits(
        axis=ax2,
    )

    fixed_params_text: str = generate_fixed_parameters_text_from_dict(
        filters_dict=config.filter_key_value_pairs,
    )

    if fixed_params_text is not None:
        # Use the transform of the axis the text is added to instead of relying on the
        # global pyplot state (`plt.gca()` is the twin axis at this point).
        ax1.text(
            x=1.05,
            y=0.25,
            s=f"Fixed Parameters:\n{fixed_params_text}",
            transform=ax1.transAxes,
            fontsize=6,
            verticalalignment="top",
            bbox={
                "boxstyle": "round",
                "facecolor": "wheat",
                "alpha": 0.3,
            },
        )

    # Add info about the base model if available into the bottom left corner of the plot
    if config.base_model_model_partial_name is not None:
        ax1.text(
            x=0.01,
            y=0.01,
            s=f"{config.base_model_model_partial_name=}",
            transform=ax1.transAxes,
            fontsize=6,
            verticalalignment="bottom",
            horizontalalignment="left",
            bbox={
                "boxstyle": "round",
                "facecolor": "wheat",
                "alpha": 0.3,
            },
        )

    fig1.tight_layout()

    # Save the figure
    if config.plots_output_dir is not None:
        plot_name: str = f"local_estimates_by_model_seed_{config.plot_size_config_nested.y_range_description}"
        plot_save_path = pathlib.Path(
            config.plots_output_dir,
            "separate_seeds",
            f"{plot_name}.pdf",
        )

        if verbosity >= Verbosity.NORMAL:
            logger.info(
                msg=f"Saving plot to {plot_save_path = } ...",  # noqa: G004 - low overhead
            )

        plot_save_path.parent.mkdir(
            parents=True,
            exist_ok=True,
        )
        fig1.savefig(
            fname=plot_save_path,
        )

        if verbosity >= Verbosity.NORMAL:
            logger.info(
                msg=f"Saving plot to {plot_save_path = } DONE",  # noqa: G004 - low overhead
            )

    if config.show_plots:
        fig1.show()

    # Close the figure, so that figures do not accumulate in the pyplot registry over repeated calls
    plt.close(fig1)
