import matplotlib as mpl
import numpy as np
import seaborn as sns
from pathlib import Path
from typing import Union, List


class BasePlotter:
    """
    A base class for plotters to ensure consistent styling.

    Constructing a plotter creates ``output_dir`` (including missing parents)
    and applies a seaborn theme plus LaTeX text-rendering settings to
    matplotlib's *global* ``rcParams``. The styling therefore affects every
    figure created afterwards in the process, not only this instance's plots.
    """

    def __init__(
        self,
        output_dir: Union[str, Path] = "plots/mixture_analysis",
        style: str = "whitegrid",
        context: str = "paper",
        font_scale: float = 2,
        palette="tab20",
        y_name: Union[str, List[str]] = "Y",
    ):
        """
        Initializes the plotter and sets the visual style.

        Args:
            output_dir (str or Path): The directory where plots will be saved.
                It is created (with any missing parents) if it does not exist.
            style (str): The seaborn style to apply (e.g., 'whitegrid', 'darkgrid').
            context (str): The seaborn context (e.g., 'paper', 'notebook', 'talk').
            font_scale (float): Scaling factor for font sizes.
            palette (str, list, tuple or Colormap): The color palette for
                components. Either an explicit list/tuple of matplotlib colors
                (cycled when more colors are requested than supplied) or
                anything ``matplotlib.colormaps.get_cmap`` accepts, such as a
                colormap name.
            y_name (Union[str, List[str]]): The name of the target variable.
                For 1D targets, this should be a string (e.g., "Yield").
                For 2D targets, this can be a single string to be used as a base name
                (e.g., "Position", resulting in "Position Dim 1") or a list of two
                strings for individual axis labels (e.g., ["Latitude", "Longitude"]).

        Note:
            ``text.usetex`` is enabled, so rendering text requires a working
            LaTeX installation providing the packages loaded in the preamble.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Set a consistent, clean theme for all plots (global rcParams).
        sns.set_theme(style=style, context=context, font_scale=font_scale)
        # Applied after set_theme so these font settings override the font
        # family seaborn just configured.
        mpl.rcParams.update(
            {
                # NOTE(review): Biolinum is the sans-serif face of the Linux
                # Libertine family, yet it is set as the family and the serif
                # font. With text.usetex=True, matplotlib only honors generic
                # family names (and a few named fonts) here and otherwise falls
                # back to "serif"; the face actually used then comes from the
                # LaTeX preamble below (libertine). Confirm the intended font.
                "font.family": "biolinum",
                "font.serif": ["biolinum"],
                # Render all text through LaTeX.
                "text.usetex": True,
                "text.latex.preamble": (
                    r"\usepackage{mathrsfs}"  # \mathscr script letters
                    r"\usepackage[T1]{fontenc}"
                    r"\usepackage{libertine}"  # Libertine / Biolinum text fonts
                    r"\usepackage{amsmath}"
                    r"\usepackage{eulervm}"  # Euler math fonts
                ),
            }
        )
        self.palette = palette
        # Display names for labeling; presumably read by subclasses when
        # building axis labels (not visible here — confirm with callers).
        self.name_mapping = {
            "y_variable": y_name,
        }

    def _get_colors(self, n_colors):
        """
        Return ``n_colors`` colors drawn from the configured palette.

        Args:
            n_colors (int): Number of colors to return.

        Returns:
            For a list/tuple palette, a list of its entries cycled to length
            ``n_colors``. Otherwise, an ``(n_colors, 4)`` RGBA array sampled at
            evenly spaced points across the colormap, endpoints included.

        Raises:
            ValueError: If the palette is an empty list/tuple and
                ``n_colors`` is positive.
        """
        if isinstance(self.palette, (list, tuple)):
            # Guard against `i % 0` below, which would otherwise surface as an
            # obscure ZeroDivisionError.
            if not self.palette and n_colors > 0:
                raise ValueError("Cannot draw colors from an empty palette.")
            # Cycle the explicit colors when more are requested than supplied.
            return [self.palette[i % len(self.palette)] for i in range(n_colors)]
        cmap = mpl.colormaps.get_cmap(self.palette)
        # Evenly spaced sampling spans the whole map; for qualitative maps
        # (e.g. tab20) this may skip some entries or repeat others.
        return cmap(np.linspace(0, 1, n_colors))
