"""
Experiment 2: Tightness of Upper and Lower Bounds.

Verifies that the theoretical bounds are tight up to constant factors.
"""

from __future__ import annotations

from typing import Dict, Any, Optional, List, Tuple
from dataclasses import dataclass
import numpy as np
from scipy.stats import linregress

from .base import BaseExperiment
from ..core.dag import DAG
from ..core.sem import LinearGaussianSEM
from ..core.mec import CPDAG
from ..core.fisher_dimension import compute_fisher_dimension
from ..generators.factory import generate_dag
from ..algorithms.pc import PCAlgorithm
from ..metrics.shd import structural_hamming_distance


@dataclass
class BoundsTask:
    """Task for processing a single graph in bounds experiment.

    One instance describes everything ``_process_bounds_task`` needs to
    generate a random DAG, draw a linear-Gaussian SEM on it, and search for
    the empirical sample complexity of PC on that model.
    """
    # Graph generator name passed to ``generate_dag`` (e.g. 'erdos_renyi').
    graph_type: str
    # Label of the beta family this graph belongs to; copied into the result.
    family_name: str
    # Size argument passed to ``generate_dag``.
    d: int
    # Seed reused for DAG generation, SEM sampling and the per-trial RNG.
    seed: int
    # Forwarded as ``p`` to ``generate_dag`` only when graph_type is
    # 'erdos_renyi'; ignored otherwise.
    er_p: Optional[float]
    # Forwarded to ``LinearGaussianSEM.random`` as ``beta_range``.
    beta_range: Tuple[float, float]
    # Forwarded to ``LinearGaussianSEM.random`` as ``sigma_range``.
    sigma_range: Tuple[float, float]
    # Candidate sample sizes tried when searching for the sample complexity.
    n_values: List[int]
    # Number of independent datasets drawn per candidate sample size.
    n_trials: int
    # Fraction of trials that must reach SHD == 0 for a sample size to count.
    success_threshold: float
    # Forwarded to ``PCAlgorithm`` as ``alpha``.
    pc_alpha: float
    # Forwarded to ``LinearGaussianSEM.random`` as ``sign_distribution``.
    sign_distribution: str = 'positive'


def _process_bounds_task(task: BoundsTask) -> Optional[Dict[str, Any]]:
    """Run the bounds pipeline for one task.

    Generates the DAG described by ``task``, draws a random linear-Gaussian
    SEM on it, computes its Fisher dimension and searches for the empirical
    sample complexity.

    Returns:
        A dict with keys 'family_name', 'graph_type', 'fisher_dim' and
        'n_star', or ``None`` when the generated graph has no edges.
    """
    # Only the Erdős–Rényi generator receives an edge probability.
    generator_kwargs = {'p': task.er_p} if task.graph_type == 'erdos_renyi' else {}
    graph = generate_dag(
        task.graph_type, task.d, random_state=task.seed, **generator_kwargs
    )

    # Edgeless graphs are skipped.
    if graph.num_edges() == 0:
        return None

    # Random SEM on the graph, with coefficients drawn from the task's ranges.
    model = LinearGaussianSEM.random(
        graph,
        beta_range=task.beta_range,
        sigma_range=task.sigma_range,
        random_state=task.seed,
        sign_distribution=task.sign_distribution,
    )

    fisher_dim = compute_fisher_dimension(graph, model).fisher_dimension

    # Empirical sample complexity over the candidate sample sizes.
    n_star = _find_sample_complexity_bounds(
        graph,
        model,
        task.n_values,
        task.n_trials,
        task.success_threshold,
        task.pc_alpha,
        task.seed,
    )

    return {
        'family_name': task.family_name,
        'graph_type': task.graph_type,
        'fisher_dim': fisher_dim,
        'n_star': n_star,
    }


def _find_sample_complexity_bounds(
    dag: DAG,
    sem: LinearGaussianSEM,
    n_values: List[int],
    n_trials: int,
    success_threshold: float,
    pc_alpha: float,
    base_seed: int
) -> float:
    """Find minimum sample size for successful recovery."""
    true_cpdag = CPDAG.from_dag(dag)
    pc = PCAlgorithm(alpha=pc_alpha)
    rng = np.random.default_rng(base_seed)

    for n in n_values:
        successes = 0
        for trial in range(n_trials):
            seed = rng.integers(0, 2**31)
            X = sem.sample(n, random_state=seed)
            try:
                result = pc.fit(X)
                shd = structural_hamming_distance(true_cpdag, result.cpdag)
                if shd == 0:
                    successes += 1
            except Exception:
                continue

        if successes / n_trials >= success_threshold:
            return float(n)

    return float(n_values[-1])


class Experiment2Bounds(BaseExperiment):
    """
    Experiment 2: Tightness of Upper and Lower Bounds.

    Setup:
    - For d nodes (taken from the experiment config), use 3 graph types × 3
      beta families (9 combinations)
    - Compute theoretical predictions and empirical sample complexity
    - Fit constant C per (graph_type, family) combination

    Expected outcome: Predicted sample complexity matches empirical within factor of 2.
    """

    @property
    def name(self) -> str:
        """Identifier of this experiment, also used as the figure file prefix."""
        return "exp2_bounds"

    @staticmethod
    def _summarize_cell(
        fisher_dims: Any,
        empirical: Any,
        log_factor: float,
        max_n: float,
    ) -> Dict[str, Any]:
        """Fit C for one (graph_type, family) cell and score the predictions.

        Args:
            fisher_dims: Fisher dimensions of the graphs in the cell.
            empirical: Empirical sample complexities, aligned with
                ``fisher_dims``.
            log_factor: The log(d/δ) term of the theoretical bound.
            max_n: Largest sample size tried; observations at or above it
                are treated as censored and excluded from fit and scoring.

        Returns:
            Dict with 'predicted' (list of C * F * log_factor), 'fitted_C'
            (median empirical/raw-prediction ratio, 1.0 when nothing is
            usable) and 'within_factor_2' (fraction of scored graphs whose
            empirical/predicted ratio lies in [0.5, 2], NaN when none).
        """
        fisher_dims = np.asarray(fisher_dims, dtype=float)
        empirical = np.asarray(empirical, dtype=float)

        # Raw predictions: F([G]) * log(d/δ)
        pred_raw = fisher_dims * log_factor

        # Only finite, positive predictions with uncensored observations
        # take part in the fit of C.
        fit_mask = (
            np.isfinite(pred_raw)
            & np.isfinite(empirical)
            & (pred_raw > 0)
            & (empirical < max_n)
        )
        if fit_mask.any():
            fitted_c = float(np.median(empirical[fit_mask] / pred_raw[fit_mask]))
        else:
            fitted_c = 1.0

        predicted = fitted_c * pred_raw

        # Same criteria as the fit, applied to the scaled predictions.
        score_mask = (
            (predicted > 0)
            & np.isfinite(predicted)
            & np.isfinite(empirical)
            & (empirical < max_n)
        )
        if score_mask.any():
            ratios = empirical[score_mask] / predicted[score_mask]
            within_factor_2 = float(np.mean((ratios >= 0.5) & (ratios <= 2.0)))
        else:
            within_factor_2 = float('nan')

        return {
            'predicted': predicted.tolist(),
            'fitted_C': fitted_c,
            'within_factor_2': within_factor_2,
        }

    def run(self) -> Dict[str, Any]:
        """Execute the experiment.

        Builds one task per (graph type, beta family, replicate), processes
        them through ``parallel_map``, then fits a constant C per cell and
        measures how often the empirical sample complexity is within a
        factor of 2 of the fitted prediction.

        Returns:
            Results dict with keys 'config', 'by_graph_type',
            'stats_by_cell', 'overall' and 'meets_expectation'.

        Raises:
            ValueError: If ``success_threshold`` is >= 1, because then
                δ = 1 - success_threshold is not positive and the
                log(d/δ) term of the bound is undefined.
        """
        exp_config = self.config.experiments.exp2_bounds

        d = exp_config.d
        graphs_per_family = exp_config.graphs_per_family
        n_values = exp_config.n_values
        n_trials = exp_config.n_trials
        success_threshold = exp_config.success_threshold

        # Fail before the (expensive) parallel sweep rather than after it:
        # δ == 0 would otherwise raise ZeroDivisionError once all graphs are
        # processed, and δ < 0 would silently yield NaN predictions.
        delta = 1.0 - success_threshold
        if delta <= 0:
            raise ValueError(
                "success_threshold must be < 1 so that "
                f"delta = 1 - success_threshold is positive, got {success_threshold}"
            )

        # 3 graph types for multi-panel comparison
        graph_types = ['tree', 'chain', 'erdos_renyi']

        # Define beta families with different Fisher dimension ranges
        beta_families = {
            'low_F': (0.6, 0.9),      # Strong coefficients -> low Fisher dim
            'medium_F': (0.4, 0.6),
            'high_F': (0.25, 0.4),    # Weaker -> higher Fisher dim
        }

        self.logger.info(f"Running Experiment 2 with d={d}, {graphs_per_family} graphs per cell")
        self.logger.info(f"Graph types: {graph_types}")
        self.logger.info(f"Beta families: {list(beta_families.keys())}")
        self.logger.info(f"Using {self.get_n_jobs()} parallel workers")

        results = {
            'config': {
                'd': d,
                'graphs_per_family': graphs_per_family,
                'n_values': n_values,
                'n_trials': n_trials,
                'success_threshold': success_threshold,
                'graph_types': graph_types,
                'beta_families': dict(beta_families),
            },
            'by_graph_type': {
                gt: {
                    'families': {
                        fn: {'fisher_dims': [], 'empirical_n': [], 'predicted': []}
                        for fn in beta_families
                    }
                }
                for gt in graph_types
            },
        }

        # Create tasks for all (graph_type, family) combinations. One seed is
        # drawn per task, in this exact nesting order.
        sigma_range = (self.config.sem.sigma_min, self.config.sem.sigma_max)
        tasks = []
        for gt in graph_types:
            er_p = 0.2 if gt == 'erdos_renyi' else None
            for family_name, beta_range in beta_families.items():
                for _ in range(graphs_per_family):
                    seed = self.rng.integers(0, 2**31)
                    tasks.append(BoundsTask(
                        graph_type=gt,
                        family_name=family_name,
                        d=d,
                        seed=seed,
                        er_p=er_p,
                        beta_range=beta_range,
                        sigma_range=sigma_range,
                        n_values=n_values,
                        n_trials=n_trials,
                        success_threshold=success_threshold,
                        pc_alpha=self.config.pc.alpha,
                        sign_distribution=self.config.sem.sign_distribution,
                    ))

        # Process all tasks in parallel
        task_results = self.parallel_map(
            _process_bounds_task,
            tasks,
            desc="Processing graphs"
        )

        # Collect results by (graph_type, family); None marks a skipped graph.
        for result in task_results:
            if result is None:
                continue
            cell = results['by_graph_type'][result['graph_type']]['families'][result['family_name']]
            cell['fisher_dims'].append(result['fisher_dim'])
            cell['empirical_n'].append(result['n_star'])

        # Theory: n = C * F([G]) * log(d/δ)
        log_factor = np.log(d) + np.log(1.0 / delta)
        max_n = max(n_values)

        # Fit C and compute stats for each (graph_type, family) cell
        stats_by_cell = {}
        for gt in graph_types:
            stats_by_cell[gt] = {}
            for fn, cell_data in results['by_graph_type'][gt]['families'].items():
                summary = self._summarize_cell(
                    cell_data['fisher_dims'],
                    cell_data['empirical_n'],
                    log_factor,
                    max_n,
                )
                C_cell = summary['fitted_C']
                within_factor_2 = summary['within_factor_2']

                cell_data['predicted'] = summary['predicted']
                cell_data['fitted_C'] = C_cell
                cell_data['within_factor_2'] = within_factor_2

                stats_by_cell[gt][fn] = {
                    'fitted_C': C_cell,
                    'within_factor_2': within_factor_2,
                    'n_graphs': len(cell_data['fisher_dims']),
                }

                self.logger.info(f"{gt}/{fn}: C={C_cell:.1f}, within_2x={within_factor_2:.0%}")

        results['stats_by_cell'] = stats_by_cell

        # Overall metric: unweighted mean over cells that have a defined score.
        all_within_2x = []
        for gt_stats in stats_by_cell.values():
            for cell_stats in gt_stats.values():
                if not np.isnan(cell_stats['within_factor_2']):
                    all_within_2x.append(cell_stats['within_factor_2'])

        mean_within = float(np.mean(all_within_2x)) if all_within_2x else float('nan')
        results['overall'] = {
            'mean_within_factor_2': mean_within,
        }

        # Native bool (not np.bool_); a NaN mean compares False.
        results['meets_expectation'] = bool(mean_within > exp_config.expected_within_factor)

        self.logger.info(f"Mean within factor 2 across all cells: {results['overall']['mean_within_factor_2']:.1%}")
        self.logger.info(f"Meets expectation: {results['meets_expectation']}")

        return results

    def plot_results(self, output_dir: Optional[str] = None) -> None:
        """Generate experiment figures.

        Args:
            output_dir: Directory to write the figure to; defaults to
                ``self.config.output.figures_dir`` when omitted or empty.

        Logs a warning and returns without plotting when ``self.results``
        is empty.
        """
        from ..utils.visualization import plot_exp2_grid, save_figure
        from pathlib import Path

        if not self.results:
            self.logger.warning("No results to plot")
            return

        figures_dir = Path(output_dir or self.config.output.figures_dir)
        max_n = max(self.results['config']['n_values'])

        fig, axes = plot_exp2_grid(
            self.results['by_graph_type'],
            self.results['stats_by_cell'],
            max_n=max_n,
        )

        save_figure(fig, figures_dir / f"{self.name}_scatter")
        self.logger.info(f"Saved figure to {figures_dir}")
