"""
Base experiment class and experiment runner.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List, Callable
from pathlib import Path
import time
import os

import numpy as np
from tqdm import tqdm
from joblib import Parallel, delayed

from ..utils.config import Config, get_config
from ..utils.logging import get_logger, ExperimentLogger
from ..utils.io import save_results, save_checkpoint, load_checkpoint, find_latest_checkpoint, ensure_dir


class BaseExperiment(ABC):
    """
    Base class for all experiments.

    Subclasses must implement:
    - name: Experiment name property
    - run(): Main experiment logic

    Subclasses may override:
    - plot_results(): Generate figures (the default implementation does nothing)
    """

    def __init__(
        self,
        config: Optional[Config] = None,
        output_dir: Optional[str] = None,
        random_state: Optional[int] = None
    ):
        """
        Initialize experiment.

        Args:
            config: Configuration (default: result of get_config())
            output_dir: Output directory (default: config.output.base_dir)
            random_state: Random seed (default: config.random_seed)
        """
        self.config = config or get_config()
        self.output_dir = Path(output_dir or self.config.output.base_dir)
        # Compare against None explicitly: 0 is a legitimate seed and must
        # not be silently replaced by the configured default.
        if random_state is None:
            random_state = self.config.random_seed
        self.random_state = random_state
        self.rng = np.random.default_rng(self.random_state)
        self.logger = get_logger(self.name)
        self.results: Dict[str, Any] = {}
        self._start_time: Optional[float] = None

    @property
    @abstractmethod
    def name(self) -> str:
        """Return experiment name."""
        pass

    @abstractmethod
    def run(self) -> Dict[str, Any]:
        """
        Execute the experiment.

        Returns:
            Results dictionary
        """
        pass

    def plot_results(self, output_dir: Optional[str] = None) -> None:
        """
        Generate figures from results.

        The base implementation is a no-op; subclasses override it to
        produce figures.

        Args:
            output_dir: Directory to save figures
        """
        pass

    def save_results(
        self,
        filename: Optional[str] = None,
        format: str = 'json'
    ) -> Path:
        """
        Save experiment results.

        Args:
            filename: Output filename (default: "<name>_<timestamp>.<format>")
            format: 'json' or 'pickle'; passed through to the I/O helper and
                also used as the extension of the default filename

        Returns:
            Path to saved file
        """
        if filename is None:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            filename = f"{self.name}_{timestamp}.{format}"

        output_path = self.output_dir / filename
        save_results(self.results, output_path, format=format)
        self.logger.info(f"Results saved to {output_path}")

        return output_path

    def checkpoint(
        self,
        state: Dict[str, Any],
        iteration: int
    ) -> None:
        """
        Save a checkpoint.

        Args:
            state: State to save
            iteration: Current iteration
        """
        checkpoint_dir = Path(self.config.output.checkpoints_dir)
        save_checkpoint(state, checkpoint_dir, self.name, iteration)

    def load_checkpoint(self) -> Optional[Dict[str, Any]]:
        """
        Load most recent checkpoint.

        Returns:
            Checkpoint state, or None when no checkpoint path is found
        """
        checkpoint_dir = Path(self.config.output.checkpoints_dir)
        checkpoint_path = find_latest_checkpoint(checkpoint_dir, self.name)

        if checkpoint_path:
            self.logger.info(f"Loading checkpoint from {checkpoint_path}")
            return load_checkpoint(checkpoint_path)

        return None

    def get_progress_bar(
        self,
        iterable,
        desc: str = "",
        total: Optional[int] = None
    ):
        """
        Get progress bar if enabled.

        Args:
            iterable: Iterable to wrap
            desc: Description
            total: Total count

        Returns:
            tqdm progress bar when config.computation.show_progress is truthy,
            otherwise the original iterable
        """
        if self.config.computation.show_progress:
            return tqdm(iterable, desc=desc, total=total)
        return iterable

    def get_n_jobs(self) -> int:
        """
        Get number of parallel jobs to use.

        Returns:
            config.computation.n_jobs when it is set; otherwise the CPU count
            minus one, but never less than 1
        """
        n_jobs = self.config.computation.n_jobs
        if n_jobs is not None:
            return n_jobs
        # os.cpu_count() may return None when the count cannot be determined;
        # fall back to a single CPU instead of failing on `None - 1`.
        cpu_total = os.cpu_count() or 1
        return max(1, cpu_total - 1)

    def parallel_map(
        self,
        func: Callable,
        items: List[Any],
        desc: str = "Processing",
        n_jobs: Optional[int] = None
    ) -> List[Any]:
        """
        Apply a function to every item, optionally in parallel.

        Args:
            func: Function to apply to each item
            items: Items to process; any iterable is accepted and is
                materialized into a list once so its length is known
            desc: Description for progress bar
            n_jobs: Number of parallel jobs (default: get_n_jobs())

        Returns:
            List of results
        """
        if n_jobs is None:
            n_jobs = self.get_n_jobs()

        # Materialize once so generators and other unsized iterables work
        # with the len() calls below.
        items = list(items)

        if n_jobs == 1:
            # Sequential execution; the progress bar is applied only when
            # enabled in the config (see get_progress_bar).
            return [
                func(item)
                for item in self.get_progress_bar(items, desc=desc, total=len(items))
            ]

        # Parallel execution. The bar wraps the task generator, so it advances
        # as items are pulled from that generator.
        if self.config.computation.show_progress:
            source = tqdm(items, desc=desc, total=len(items))
        else:
            source = items

        return Parallel(n_jobs=n_jobs, backend='loky')(
            delayed(func)(item) for item in source
        )

    def execute(self) -> Dict[str, Any]:
        """
        Execute experiment with logging and timing.

        On success the results returned by run() gain 'duration_seconds' and
        'success' = True. On failure 'error', 'success' = False and
        'duration_seconds' are stored in self.results and the exception is
        re-raised.

        Returns:
            Results dictionary

        Raises:
            Exception: Whatever run() (or the post-processing of its result)
                raised
        """
        self._start_time = time.time()

        with ExperimentLogger(self.name, self._get_config_dict(), self.logger):
            try:
                self.results = self.run()
                self.results['duration_seconds'] = time.time() - self._start_time
                self.results['success'] = True
            except Exception as e:
                self.logger.error(f"Experiment failed: {e}")
                # run() may have returned a non-dict, which is what made the
                # bookkeeping above fail; make sure the failure can still be
                # recorded instead of raising a second, masking error here.
                if not isinstance(self.results, dict):
                    self.results = {}
                self.results['error'] = str(e)
                self.results['success'] = False
                self.results['duration_seconds'] = time.time() - self._start_time
                raise

        return self.results

    def _get_config_dict(self) -> Dict[str, Any]:
        """Get experiment-relevant config (name and random seed) as dict."""
        return {
            'name': self.name,
            'random_seed': self.random_state,
        }


class ExperimentRunner:
    """
    Runner for executing multiple experiments.

    Results of every executed experiment (successful or failed) are kept in
    ``self.results``, keyed by experiment name.
    """

    def __init__(
        self,
        config: Optional[Config] = None,
        output_dir: Optional[str] = None
    ):
        """
        Initialize runner.

        Args:
            config: Configuration (default: result of get_config())
            output_dir: Output directory (default: config.output.base_dir)
        """
        self.config = config or get_config()
        self.output_dir = Path(output_dir or self.config.output.base_dir)
        self.logger = get_logger('runner')
        self.results: Dict[str, Dict] = {}

    def run_experiment(
        self,
        experiment: BaseExperiment,
        save: bool = True,
        plot: bool = True
    ) -> Dict[str, Any]:
        """
        Run a single experiment.

        Args:
            experiment: Experiment to run
            save: Whether to save results
            plot: Whether to generate plots

        Returns:
            Results dictionary

        Raises:
            Exception: Whatever experiment.execute() or
                experiment.save_results() raised. Plotting errors are logged
                as warnings and not propagated.
        """
        self.logger.info(f"Running experiment: {experiment.name}")

        try:
            results = experiment.execute()
        except Exception as e:
            # Record the failure before propagating it; otherwise a failed
            # experiment never reaches self.results and generate_summary()
            # cannot count it.
            partial = experiment.results
            failure = dict(partial) if isinstance(partial, dict) else {}
            failure['success'] = False
            failure.setdefault('error', str(e))
            self.results[experiment.name] = failure
            raise

        self.results[experiment.name] = results

        if save:
            experiment.save_results()

        if plot:
            # Plotting is best-effort: a figure failure must not discard
            # results that were already computed.
            try:
                figures_dir = Path(self.config.output.figures_dir)
                ensure_dir(figures_dir)
                experiment.plot_results(str(figures_dir))
            except Exception as e:
                self.logger.warning(f"Failed to generate plots: {e}")

        return results

    def run_all(
        self,
        experiments: List[BaseExperiment],
        save: bool = True,
        plot: bool = True
    ) -> Dict[str, Dict]:
        """
        Run multiple experiments.

        An experiment that raises is logged and skipped so the remaining
        experiments still run.

        Args:
            experiments: List of experiments
            save: Whether to save results
            plot: Whether to generate plots

        Returns:
            Dict mapping experiment name to results
        """
        for experiment in experiments:
            try:
                self.run_experiment(experiment, save=save, plot=plot)
            except Exception as e:
                self.logger.error(f"Experiment {experiment.name} failed: {e}")
                continue

        return self.results

    def generate_summary(self) -> Dict[str, Any]:
        """
        Generate summary of all experiment results.

        A result without a 'success' key is treated as failed, both in the
        aggregate counts and in its per-experiment entry, so that
        'successful' + 'failed' always equals 'n_experiments'.

        Returns:
            Summary dictionary with keys 'n_experiments', 'successful',
            'failed', 'total_duration' and 'experiments'
        """
        per_experiment = {
            name: {
                'success': result.get('success', False),
                'duration': result.get('duration_seconds', 0),
            }
            for name, result in self.results.items()
        }
        n_successful = sum(1 for entry in per_experiment.values() if entry['success'])

        return {
            'n_experiments': len(self.results),
            'successful': n_successful,
            'failed': len(self.results) - n_successful,
            'total_duration': sum(entry['duration'] for entry in per_experiment.values()),
            'experiments': per_experiment,
        }
