"""Analysis engine to orchestrate plotting operations."""

import logging
import os
from typing import Any, Dict, List, Optional

import matplotlib.pyplot as plt

from config import (
    ATTACKS,
    DATASET_IDX,
    GROUP_BY,
    METRIC,
    MODELS,
    PLOT_OUTPUT_DIR,
)
from data_processor import fetch_data
from plotters import (
    AbsoluteBarChartPlotter,
    BarChartPlotter,
    ComparativeParetoPlotter,
    FlopsBreakdownPlotter,
    FlopsRatioPlotter,
    HistogramTwoPlotter,
    IdealRatioPlotter,
    MultiAttackNonCumulativePlotter,
    NonCumulativeParetoPlotter,
    OptimizationProgressPlotter,
    ParetoPlotter,
    RatioPlotter,
    RidgePlotter,
    RidgeSideBySidePlotter,
    StandardHistogramPlotter,
)
from src.io_utils import num_model_params


class AnalysisEngine:
    """Main engine for orchestrating analysis and plotting operations."""

    def __init__(self, output_dir: str|None = None):
        """Initialize the analysis engine.

        Parameters
        ----------
        output_dir : str, optional
            Directory to save plots to. If None, uses PLOT_OUTPUT_DIR from config.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.output_dir = output_dir or PLOT_OUTPUT_DIR
        os.makedirs(self.output_dir, exist_ok=True)
        self._plotters = self._initialize_plotters()

    def _initialize_plotters(self) -> Dict[str, Any]:
        """Initialize all plotter instances."""
        return {
            'pareto': ParetoPlotter(),
            'non_cumulative_pareto': NonCumulativeParetoPlotter(),
            'flops_ratio': FlopsRatioPlotter(),
            'ideal_ratio': IdealRatioPlotter(),
            'flops_breakdown': FlopsBreakdownPlotter(),
            'histogram': StandardHistogramPlotter(),
            'histogram_2': HistogramTwoPlotter(),
            'ridge': RidgePlotter(),
            'ridge_side_by_side': RidgeSideBySidePlotter(),
            'ratio_plot': RatioPlotter(),
            'comparative_pareto': ComparativeParetoPlotter(),
            'multi_attack_non_cumulative_pareto': MultiAttackNonCumulativePlotter(),
            'optimization_progress': OptimizationProgressPlotter(),
            'bar_chart': BarChartPlotter(),
            'absolute_bar_chart': AbsoluteBarChartPlotter(),
        }

    def _save_plots(self, result: Any, analysis_type: str, model: str, atk_name: str, cfg: dict):
        """Save generated plots to files.

        Parameters
        ----------
        result : Any
            Plot result (figure or tuple of figures)
        analysis_type : str
            Type of analysis
        model : str
            Model identifier
        atk_name : str
            Attack name
        cfg : dict
            Attack configuration
        """
        # Create output directory
        os.makedirs(self.output_dir, exist_ok=True)

        # Generate base filename
        title_suffix = cfg.get('title_suffix', '').replace(' ', '_').replace('(', '').replace(')', '')
        model_clean = model.replace('/', '_').replace('-', '_')

        # Different output location from before, but this is wanted for the new plots.
        base_filename = f"{analysis_type}/{model_clean}_{atk_name}_{title_suffix}".replace('__', '_')
        os.makedirs(os.path.join(self.output_dir, analysis_type), exist_ok=True)

        # Handle different result types
        if isinstance(result, list):
            # Multiple figures from bar chart plotter
            chart_names = ["asr_delta", "flops_efficiency", "speedup"]
            for i, fig in enumerate(result):
                if hasattr(fig, 'savefig'):
                    chart_suffix = chart_names[i] if i < len(chart_names) else f"chart_{i}"
                    filename = f"{base_filename}_{chart_suffix}.pdf"
                    filepath = os.path.join(self.output_dir, filename)
                    fig.savefig(filepath, dpi=300, bbox_inches='tight')
                    self.logger.info(f"Plot saved to {filepath}")
        elif isinstance(result, tuple) and len(result) == 2:
            if isinstance(result[0], tuple) and len(result[0]) == 2:
                # Multiple plots (e.g., pareto with/without threshold)
                fig1, ax1 = result[0]
                fig2, ax2 = result[1]

                # Save first plot
                filename1 = f"{base_filename}_with_threshold.pdf"
                filepath1 = os.path.join(self.output_dir, filename1)
                fig1.savefig(filepath1, dpi=300, bbox_inches='tight')
                self.logger.info(f"Plot saved to {filepath1}")

                # Save second plot
                filename2 = f"{base_filename}_no_threshold.pdf"
                filepath2 = os.path.join(self.output_dir, filename2)
                fig2.savefig(filepath2, dpi=300, bbox_inches='tight')
                self.logger.info(f"Plot saved to {filepath2}")
            elif isinstance(result[0], list) and isinstance(result[1], list):
                # Multiple bar chart figures for threshold/no-threshold versions
                chart_names = ["asr_delta", "flops_efficiency", "speedup"]
                figs_with_threshold, figs_no_threshold = result

                # Save figures with threshold (0.5)
                for i, fig in enumerate(figs_with_threshold):
                    if hasattr(fig, 'savefig'):
                        chart_suffix = chart_names[i] if i < len(chart_names) else f"chart_{i}"
                        filename = f"{base_filename}_{chart_suffix}_t=0.5.pdf"
                        filepath = os.path.join(self.output_dir, filename)
                        fig.savefig(filepath, dpi=300, bbox_inches='tight')
                        self.logger.info(f"Plot saved to {filepath}")

                # Save figures without threshold (None)
                for i, fig in enumerate(figs_no_threshold):
                    if hasattr(fig, 'savefig'):
                        chart_suffix = chart_names[i] if i < len(chart_names) else f"chart_{i}"
                        filename = f"{base_filename}_{chart_suffix}.pdf"
                        filepath = os.path.join(self.output_dir, filename)
                        fig.savefig(filepath, dpi=300, bbox_inches='tight')
                        self.logger.info(f"Plot saved to {filepath}")
            else:
                # Single plot with (fig, ax)
                fig, ax = result
                filename = f"{base_filename}.pdf"
                filepath = os.path.join(self.output_dir, filename)
                fig.savefig(filepath, dpi=300, bbox_inches='tight')
                self.logger.info(f"Plot saved to {filepath}")
        elif hasattr(result, 'savefig'):
            # Single figure object
            filename = f"{base_filename}.pdf"
            filepath = os.path.join(self.output_dir, filename)
            result.savefig(filepath, dpi=300, bbox_inches='tight')  # type: ignore
            self.logger.info(f"Plot saved to {filepath}")
        plt.close()

    def _save_comparative_plots(self, result: Any, analysis_type: str, model: str, model_title: str, threshold: Optional[float]):
        """Save comparative analysis plots to files.

        Parameters
        ----------
        result : Any
            Plot result (figure or tuple of figures)
        analysis_type : str
            Type of comparative analysis
        model : str
            Model identifier
        model_title : str
            Human-readable model title
        threshold : float, optional
            Threshold value used in analysis
        """
        # Create output directory
        os.makedirs(self.output_dir, exist_ok=True)

        # Generate base filename
        model_clean = model.replace('/', '_').replace('-', '_')
        base_filename = f"{analysis_type}/{model_clean}"
        os.makedirs(os.path.join(self.output_dir, analysis_type), exist_ok=True)

        # Handle different result types
        if isinstance(result, tuple) and len(result) == 2:
            if isinstance(result[0], tuple) and len(result[0]) == 2:
                # Multiple plots (e.g., comparative_pareto with/without threshold)
                fig1, ax1 = result[0]
                fig2, ax2 = result[1]

                # Save first plot (no threshold)
                filename1 = f"{base_filename}_no_threshold.pdf"
                filepath1 = os.path.join(self.output_dir, filename1)
                fig1.savefig(filepath1, dpi=300, bbox_inches='tight')
                self.logger.info(f"Comparative plot saved to {filepath1}")

                # Save second plot (with threshold)
                filename2 = f"{base_filename}_with_threshold.pdf"
                filepath2 = os.path.join(self.output_dir, filename2)
                fig2.savefig(filepath2, dpi=300, bbox_inches='tight')
                self.logger.info(f"Comparative plot saved to {filepath2}")
            else:
                # Single plot with (fig, ax)
                fig, ax = result
                threshold_suffix = f"_threshold_{threshold}" if threshold is not None else ""
                filename = f"{base_filename}{threshold_suffix}.pdf"
                filepath = os.path.join(self.output_dir, filename)
                fig.savefig(filepath, dpi=300, bbox_inches='tight', transparent=True)
                self.logger.info(f"Comparative plot saved to {filepath}")
        elif hasattr(result, 'savefig'):
            # Single figure object
            threshold_suffix = f"_threshold_{threshold}" if threshold is not None else ""
            filename = f"{base_filename}{threshold_suffix}.pdf"
            filepath = os.path.join(self.output_dir, filename)
            # Type assertion - we know result has savefig method from hasattr check
            result.savefig(filepath, dpi=300, bbox_inches='tight')  # type: ignore
            self.logger.info(f"Comparative plot saved to {filepath}")

    def run_analysis(self, model: str, model_title: str, atk_name: str,
                    cfg: dict, analysis_type: str = "pareto", schedule: str = 'end') -> Optional[Any]:
        """
        Run analysis for a single model-attack combination.

        Parameters
        ----------
        model : str
            Model identifier
        model_title : str
            Human-readable model title
        atk_name : str
            Attack name
        cfg : dict
            Attack configuration
        analysis_type : str
            Type of analysis to run

        Returns
        -------
        Figure or tuple of figures, or None if analysis failed
        """
        self.logger.info(f"{analysis_type.title()} Analysis: {atk_name} {cfg.get('title_suffix', '')}")

        try:
            # Get the appropriate plotter
            if analysis_type not in self._plotters:
                raise ValueError(f"Unknown analysis type: {analysis_type}")

            plotter = self._plotters[analysis_type]

            # Fetch sampled data
            sampled_data = fetch_data(
                model,
                cfg.get("attack_override", atk_name),
                cfg["sample_params"](),
                DATASET_IDX,
                GROUP_BY
            )

            # Apply post-processing if needed
            if post := cfg.get("postprocess"):
                post(sampled_data, METRIC)


            # Handle different analysis types
            if analysis_type in ["histogram", "histogram_2", "ridge", "ridge_side_by_side"]:
                # Histogram-based plots don't need baseline
                result = self._run_histogram_analysis(
                    plotter, analysis_type, sampled_data, model_title, atk_name, cfg, schedule
                )
            else:
                # Other plots need baseline data
                baseline_data = None
                baseline_attack = cfg.get("baseline_attack", atk_name)
                baseline_data = fetch_data(
                    model,
                    baseline_attack,
                    cfg["baseline_params"](),
                    DATASET_IDX,
                    GROUP_BY
                )

                result = self._run_standard_analysis(
                    plotter, analysis_type, sampled_data, baseline_data,
                    model_title, atk_name, cfg, schedule
                )

            # Save the generated plots
            if result:
                self._save_plots(result, analysis_type, model, atk_name, cfg)

            return result

        except Exception as e:
            self.logger.error(f"Analysis failed for {atk_name}: {e}")
            return None

    def _run_histogram_analysis(self, plotter, analysis_type: str, sampled_data: dict,
                              model_title: str, atk_name: str, cfg: dict, schedule: str) -> Any:
        """Run histogram-based analysis."""
        if analysis_type == "histogram":
            return plotter.plot(sampled_data, model_title, atk_name, cfg, threshold=None, schedule=schedule)
        elif analysis_type == "histogram_2":
            return plotter.plot(sampled_data, model_title, cfg, threshold=None, schedule=schedule)
        elif analysis_type == "ridge":
            return plotter.plot(sampled_data, model_title, cfg, threshold=None, schedule=schedule)
        elif analysis_type == "ridge_side_by_side":
            return plotter.plot(sampled_data, model_title, cfg, threshold=None, schedule=schedule)
        else:
            raise ValueError(f"Unknown histogram analysis type: {analysis_type}")

    def _run_standard_analysis(self, plotter, analysis_type: str, sampled_data: dict,
                             baseline_data: dict, model_title: str, atk_name: str,
                             cfg: dict, schedule: str) -> Any:
        """Run standard analysis with baseline comparison."""
        title = f"{model_title} {cfg['title_suffix']}"

        common_params = {
            'metric': METRIC,
            'color_scale': 'sqrt',
        }

        if analysis_type == "pareto":
            # Run both with and without threshold
            fig1, ax1 = plotter.plot(
                sampled_data, baseline_data, title=title,
                cumulative=cfg["cumulative"], threshold=0.5, **common_params, schedule=schedule
            )
            fig2, ax2 = plotter.plot(
                sampled_data, baseline_data, title=title,
                cumulative=cfg["cumulative"], threshold=None, **common_params, schedule=schedule
            )
            return (fig1, ax1), (fig2, ax2)

        elif analysis_type == "non_cumulative_pareto":
            # Run both with and without threshold
            fig1, ax1 = plotter.plot(
                sampled_data, baseline_data, title=title,
                threshold=0.5, **common_params, schedule=schedule
            )
            fig2, ax2 = plotter.plot(
                sampled_data, baseline_data, title=title,
                threshold=None, **common_params, schedule=schedule
            )
            return (fig1, ax1), (fig2, ax2)

        elif analysis_type == "flops_ratio":
            return plotter.plot(
                sampled_data, baseline_data,
                title=f"{title} FLOPs Ratio",
                cumulative=cfg["cumulative"], threshold=None, **common_params, schedule=schedule
            )

        elif analysis_type == "ideal_ratio":
            return plotter.plot(
                sampled_data, baseline_data,
                title=f"{title} Ideal Ratio",
                cumulative=cfg["cumulative"], threshold=None,
                metric=METRIC, schedule=schedule
            )

        elif analysis_type == "flops_breakdown":
            return plotter.plot(
                sampled_data, baseline_data,
                title=f"{title} FLOPs Breakdown",
                cumulative=cfg["cumulative"], threshold=None, **common_params, schedule=schedule
            )

        elif analysis_type == "optimization_progress":
            return plotter.plot(
                sampled_data,
                title=f"{title} Optimization Progress",
                cumulative=False, metric=METRIC, threshold=None, schedule=schedule
            )

        elif analysis_type == "ratio_plot":
            return plotter.plot(
                sampled_data, model_title, cfg,
                threshold=None, metric=METRIC
            )

        elif analysis_type == "bar_chart":
            # Run both with and without threshold
            figs1 = plotter.plot(
                sampled_data, baseline_data,
                title=title,
                threshold=0.5, **common_params, schedule=schedule
            )
            figs2 = plotter.plot(
                sampled_data, baseline_data,
                title=title,
                threshold=None, **common_params, schedule=schedule
            )
            return figs1, figs2
        elif analysis_type == "absolute_bar_chart":
            figs1 = plotter.plot(
                sampled_data, baseline_data,
                title=title,
                threshold=0.5, **common_params, schedule=schedule
            )
            figs2 = plotter.plot(
                sampled_data, baseline_data,
                title=title,
                threshold=None, **common_params, schedule=schedule
            )
            return figs1, figs2

        else:
            raise ValueError(f"Unknown standard analysis type: {analysis_type}")

    def run_comparative_analysis(self, model: str, model_title: str,
                               analysis_type: str = "comparative_pareto",
                               threshold: Optional[float] = None,
                               schedule: str = 'end') -> Optional[Any]:
        """
        Run comparative analysis across multiple attacks for a single model.

        Parameters
        ----------
        model : str
            Model identifier
        model_title : str
            Human-readable model title
        analysis_type : str
            Type of comparative analysis
        threshold : float, optional
            Threshold for binary classification

        Returns
        -------
        Figure or tuple of figures, or None if analysis failed
        """
        self.logger.info(f"{analysis_type.title()} Analysis: {model_title}")

        try:
            # Get the appropriate plotter
            if analysis_type not in self._plotters:
                raise ValueError(f"Unknown analysis type: {analysis_type}")

            plotter = self._plotters[analysis_type]

            # Collect data from all attacks for this model
            attacks_data = {}

            for atk_name, cfg in ATTACKS:
                try:
                    # Fetch attack data
                    sampled_data = fetch_data(
                        model,
                        cfg.get("attack_override", atk_name),
                        cfg["sample_params"](),
                        DATASET_IDX,
                        GROUP_BY
                    )

                    # Apply post-processing if needed
                    if post := cfg.get("postprocess"):
                        post(sampled_data, METRIC)

                    # Use title_suffix as key to distinguish between different configs
                    config_key = cfg['title_suffix']
                    attacks_data[config_key] = (sampled_data, cfg)

                except Exception as e:
                    self.logger.warning(f"Could not load data for {atk_name} ({cfg.get('title_suffix', 'unknown config')}): {e}")
                    continue

            if not attacks_data:
                self.logger.error(f"No attack data loaded for model {model}")
                return None

            # Generate comparative plot
            if analysis_type == "comparative_pareto":
                # Run both with and without threshold
                baseline_attacks = {"gcg", "beast", "pair", "autodan", "gcg_reinforce"}

                fig1, ax1 = plotter.plot(
                    model=model,
                    model_title=model_title,
                    attacks_data=attacks_data,
                    title=f"{model_title}",
                    metric=METRIC,
                    flops_per_step_fns=flops_per_step_fns,
                    threshold=None,
                    baseline_attacks=baseline_attacks,
                    schedule=schedule,
                )

                fig2, ax2 = plotter.plot(
                    model=model,
                    model_title=model_title,
                    attacks_data=attacks_data,
                    title=f"{model_title} (threshold=0.5)",
                    metric=METRIC,
                    flops_per_step_fns=flops_per_step_fns,
                    threshold=0.5,
                    baseline_attacks=baseline_attacks,
                    schedule=schedule,
                )

                result = (fig1, ax1), (fig2, ax2)

            elif analysis_type == "multi_attack_non_cumulative_pareto":
                result = plotter.plot(
                    model=model,
                    model_title=model_title,
                    attacks_data=attacks_data,
                    title=f"{model_title} Multi-Attack Non-Cumulative",
                    metric=METRIC,
                    flops_per_step_fns=flops_per_step_fns,
                    threshold=threshold,
                    schedule=schedule,
                )

            else:
                raise ValueError(f"Unknown comparative analysis type: {analysis_type}")

            # Save the generated comparative plots
            if result:
                self._save_comparative_plots(result, analysis_type, model, model_title, threshold)

            return result

        except Exception as e:
            self.logger.error(f"Comparative analysis failed for {model}: {e}")
            return None

    def run_full_analysis(self, analysis_types: Optional[List[str]] = None,
                          fail_on_error: bool = False, schedule: str = 'end') -> Dict[str, List[Any]]:
        """
        Run full analysis across all models and attacks.

        Parameters
        ----------
        analysis_types : list of str, optional
            List of analysis types to run. If None, runs all available types.
        fail_on_error : bool
            Whether to raise exceptions on individual analysis failures

        Returns
        -------
        Dict mapping analysis types to lists of results
        """
        if analysis_types is None:
            analysis_types = list(self._plotters.keys())

        results = {analysis_type: [] for analysis_type in analysis_types}

        for analysis_type in analysis_types:
            self.logger.info("\n" + "="*80)
            self.logger.info(f"GENERATING {analysis_type.upper().replace('_', ' ')} PLOTS")
            self.logger.info("="*80)

            try:
                if analysis_type in ["comparative_pareto", "multi_attack_non_cumulative_pareto"]:
                    # Comparative analyses run across all attacks for each model
                    for model_key, model_title in MODELS.items():
                        self.logger.info(f"Model: {model_key}")
                        result = self.run_comparative_analysis(model_key, model_title, analysis_type, None, schedule)
                        if result:
                            results[analysis_type].append(result)
                else:
                    # Standard analyses run for each model-attack combination
                    for model_key, model_title in MODELS.items():
                        self.logger.info(f"Model: {model_key}")
                        for atk_name, atk_cfg in ATTACKS:
                            result = self.run_analysis(model_key, model_title, atk_name, atk_cfg, analysis_type, schedule)
                            if result:
                                results[analysis_type].append(result)

            except Exception as e:
                self.logger.error(f"Failed to run {analysis_type} analysis: {e}")
                if fail_on_error:
                    raise

        return results