"""Base classes and utilities for plotting."""

import logging
from abc import ABC, abstractmethod
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm, PowerNorm, Normalize

from data_processor import (
    preprocess_data,
    subsample_and_aggregate_n,
    generate_sample_sizes,
    _pareto_frontier,
    get_n_schedule
)
from config import DEFAULT_PLOT_PARAMS


def setup_color_normalization(color_scale: str, values: np.ndarray):
    """Build a matplotlib norm spanning ``values.min()`` to ``values.max()``.

    Args:
        color_scale: ``"log"`` selects ``LogNorm``; ``"sqrt"`` selects
            ``PowerNorm`` with ``gamma=0.5``; any other value falls back to a
            linear ``Normalize``.
        values: Array whose minimum and maximum become the norm's bounds.

    Returns:
        The constructed norm instance.
    """
    lo, hi = values.min(), values.max()
    if color_scale == "log":
        # NOTE(review): LogNorm presumably requires lo > 0; non-positive data
        # would fail once the norm is applied — confirm inputs are positive.
        return LogNorm(lo, hi)
    if color_scale == "sqrt":
        return PowerNorm(gamma=0.5, vmin=lo, vmax=hi)
    return Normalize(lo, hi)


def get_schedule_points(y: np.ndarray, opt_flops: np.ndarray, sampling_prefill_flops: np.ndarray,
                        sampling_generation_flops: np.ndarray, return_ratio: bool = False,
                        n_smoothing: int = 1, cumulative: bool = False, schedule_type: str = "uniform",
                        *, seed: int = 42):
    """Generate one aggregated point per (sample budget, step count) schedule.

    For every ``j`` in ``1..total_samples`` (outer loop) and every ``i`` in
    ``1..n_steps`` (inner loop), a per-step sample vector is built with
    ``get_n_schedule(i, j, schedule_type)``. That vector is passed to
    ``subsample_and_aggregate_n`` together with the FLOP arrays and one random
    generator shared by all calls.

    Args:
        y: Array unpacked as ``(n_runs, n_steps, total_samples)``.
        opt_flops: Forwarded unchanged to ``subsample_and_aggregate_n``.
        sampling_prefill_flops: Forwarded unchanged.
        sampling_generation_flops: Forwarded unchanged.
        return_ratio: Forwarded unchanged.
        n_smoothing: Forwarded unchanged.
        cumulative: Accepted for API compatibility; currently ignored.
        schedule_type: Passed to ``get_n_schedule``.
        seed: Seed for the shared generator. The default of 42 reproduces the
            results produced before this parameter existed.

    Returns:
        List of the values returned by ``subsample_and_aggregate_n``, in loop
        order (``j`` outer, ``i`` inner).
    """
    _, n_steps, total_samples = y.shape
    # One generator for all points, so the loop order determines which
    # random draws each point receives.
    rng = np.random.default_rng(seed)
    return [
        subsample_and_aggregate_n(
            get_n_schedule(i, j, schedule_type), y, opt_flops, sampling_prefill_flops,
            sampling_generation_flops, rng, return_ratio, n_smoothing
        )
        for j in range(1, total_samples + 1)
        for i in range(1, n_steps + 1)
    ]


class BasePlotter(ABC):
    """Base class for all plotters."""

    def __init__(self, **kwargs):
        """Initialize plotter with default parameters."""
        self.params = {**DEFAULT_PLOT_PARAMS, **kwargs}
        self.logger = logging.getLogger(self.__class__.__name__)

    def _validate_inputs(self, results: dict, baseline: dict|None = None):
        """Validate input data."""
        if not results:
            raise ValueError("Results dictionary cannot be empty")

        required_keys = ["flops_sampling_prefill_cache", "flops_sampling_generation"]
        for key in required_keys:
            if key not in results:
                raise ValueError(f"Missing required key in results: {key}")

    def _preprocess_data(self, results: dict, metric: tuple, threshold: float|None) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Common data preprocessing."""
        return preprocess_data(results, metric, threshold)

    def _setup_figure(self, title: str, figsize: tuple = (10, 8)) -> tuple:
        """Setup matplotlib figure and axes."""
        fig, ax = plt.subplots(figsize=figsize)
        ax.set_title(title)
        return fig, ax

    @abstractmethod
    def plot(self, results: dict, baseline: dict|None = None,
             title: str = "", **kwargs) -> list:
        """Abstract method for plotting. Must be implemented by subclasses."""
        pass


class FrontierPlotter(BasePlotter):
    """Intermediate base for plots built around a frontier curve (e.g. Pareto).

    Provides helpers to compute a frontier, draw it as a line, scatter the
    underlying points through a colormap, and configure log-x FLOPs axes.
    """

    def _compute_frontier(self, xs: np.ndarray, ys: np.ndarray,
                          method: str = "basic") -> tuple:
        """Return the result of ``_pareto_frontier(xs, ys, method=method)``."""
        frontier = _pareto_frontier(xs, ys, method=method)
        return frontier

    def _plot_frontier(self, ax, xs: np.ndarray, ys: np.ndarray,
                       label: str = "", color: str = "blue", **kwargs):
        """Draw the frontier on ``ax`` as a line; extra ``kwargs`` go to ``ax.plot``."""
        line_opts = {"label": label, "color": color}
        line_opts.update(kwargs)  # kwargs cannot contain label/color (named params)
        ax.plot(xs, ys, **line_opts)

    def _plot_points(self, ax, xs: np.ndarray, ys: np.ndarray, colors: np.ndarray,
                     norm, cmap: str = "viridis", **kwargs):
        """Scatter ``(xs, ys)`` on ``ax`` colored by ``colors`` through ``norm``/``cmap``.

        Returns:
            The collection returned by ``ax.scatter`` (usable for a colorbar).
        """
        return ax.scatter(xs, ys, c=colors, norm=norm, cmap=cmap, **kwargs)

    def _setup_axes(self, ax, xlabel: str = "FLOPs", ylabel: str = r"$\mathbb{E}[h(Y)]$"):
        """Label the axes, switch the x-axis to log scale and add a faint grid."""
        ax.set(xlabel=xlabel, ylabel=ylabel)
        ax.set_xscale('log')
        ax.grid(True, alpha=0.3)


class HistogramPlotter(BasePlotter):
    """Intermediate base for histogram-style plots."""

    def _setup_histogram_axes(self, ax, xlabel: str = "Value", ylabel: str = "Count"):
        """Label both axes of ``ax`` and enable a faint (alpha 0.3) grid."""
        ax.set(xlabel=xlabel, ylabel=ylabel)
        ax.grid(True, alpha=0.3)


class AnalysisPlotter(BasePlotter):
    """Intermediate base for analysis-style plots (ratio, breakdown, ...)."""

    def _setup_analysis_axes(self, ax, xlabel: str = "X", ylabel: str = "Y"):
        """Apply the given axis labels to ``ax`` and enable a faint (alpha 0.3) grid."""
        ax.set(xlabel=xlabel, ylabel=ylabel)
        ax.grid(True, alpha=0.3)


class PlotterMixin:
    """Mixin class providing common plotting utilities."""

    @staticmethod
    def get_sample_levels(total_samples: int,
                         custom_levels: Tuple[int, ...]|None = None) -> Tuple[int, ...]:
        """Get sample levels for plotting."""
        if custom_levels is not None:
            return custom_levels
        return generate_sample_sizes(total_samples)

    @staticmethod
    def format_flops(flops: float) -> str:
        """Format FLOPs value for display."""
        if flops >= 1e15:
            return f"{flops/1e15:.1f}P"
        elif flops >= 1e12:
            return f"{flops/1e12:.1f}T"
        elif flops >= 1e9:
            return f"{flops/1e9:.1f}G"
        elif flops >= 1e6:
            return f"{flops/1e6:.1f}M"
        elif flops >= 1e3:
            return f"{flops/1e3:.1f}K"
        else:
            return f"{flops:.0f}"

    def add_colorbar(self, fig, scatter, label: str = "Samples"):
        """Add colorbar to figure."""
        cbar = fig.colorbar(scatter)
        cbar.set_label(label)
        return cbar