"""
Experiment 1: Correlation between Fisher dimension and sample complexity.

Verifies that F([G]) predicts the empirical sample complexity of causal discovery.
"""

from __future__ import annotations

from typing import Dict, Any, Optional, List, Tuple
from dataclasses import dataclass
import numpy as np
from scipy.stats import pearsonr, spearmanr

from .base import BaseExperiment
from ..core.dag import DAG
from ..core.sem import LinearGaussianSEM
from ..core.mec import CPDAG
from ..core.fisher_dimension import compute_fisher_dimension
from ..generators.factory import generate_dag
from ..algorithms.pc import PCAlgorithm
from ..metrics.shd import structural_hamming_distance


@dataclass
class GraphTask:
    """Self-contained description of one graph to evaluate.

    Bundles every parameter that ``_process_single_graph`` needs, so a task
    can be handed to a worker without any access to the experiment object.
    """
    # Graph family name forwarded to ``generate_dag`` (e.g. 'tree', 'chain',
    # 'erdos_renyi').
    family: str
    # Size argument forwarded to ``generate_dag``.
    d: int
    # Seed reused for DAG generation, SEM sampling of parameters, and as the
    # base seed for the per-trial data draws.
    seed: int
    # Passed as ``p`` to ``generate_dag`` for the 'erdos_renyi' family only;
    # ignored for every other family.
    er_p: Optional[float]
    # Forwarded to ``LinearGaussianSEM.random`` as ``beta_range``.
    beta_range: Tuple[float, float]
    # Forwarded to ``LinearGaussianSEM.random`` as ``sigma_range``.
    sigma_range: Tuple[float, float]
    # Candidate sample sizes tried when searching for the sample complexity.
    n_values: List[int]
    # Number of independent datasets drawn per candidate sample size.
    n_trials: int
    # Minimum fraction of trials with SHD == 0 for a sample size to count as
    # sufficient.
    success_threshold: float
    # Passed as ``alpha`` to ``PCAlgorithm``.
    pc_alpha: float
    # Forwarded to ``LinearGaussianSEM.random`` as ``sign_distribution``.
    sign_distribution: str = 'positive'


def _process_single_graph(task: GraphTask) -> Optional[Dict[str, Any]]:
    """Evaluate one graph task (worker function for parallel execution).

    Generates a DAG of the requested family, draws a random linear Gaussian
    SEM on it, computes its Fisher dimension and searches for the smallest
    sample size at which recovery succeeds often enough.

    Args:
        task: Parameters describing the graph, the SEM and the search.

    Returns:
        A dict with keys ``'family'``, ``'beta_range'``, ``'num_edges'``,
        ``'num_v_structures'``, ``'fisher_dim'`` and ``'n_star'``, or
        ``None`` when the generated DAG has no edges (such graphs are
        skipped by the caller).
    """
    # Only the Erdos-Renyi generator receives an edge probability.
    generator_kwargs = {'p': task.er_p} if task.family == 'erdos_renyi' else {}
    dag = generate_dag(
        task.family, task.d, random_state=task.seed, **generator_kwargs
    )

    # Skip empty graphs: there is no structure to recover.
    num_edges = dag.num_edges()
    if num_edges == 0:
        return None

    # Random SEM on the DAG, seeded with the task seed for reproducibility.
    sem = LinearGaussianSEM.random(
        dag,
        beta_range=task.beta_range,
        sigma_range=task.sigma_range,
        random_state=task.seed,
        sign_distribution=task.sign_distribution
    )

    fisher_dim = compute_fisher_dimension(dag, sem).fisher_dimension

    n_star = _find_sample_complexity_worker(
        dag, sem, task.n_values, task.n_trials,
        task.success_threshold, task.pc_alpha, task.seed
    )

    return {
        'family': task.family,
        'beta_range': task.beta_range,
        'num_edges': num_edges,
        'num_v_structures': len(dag.v_structures()),
        'fisher_dim': fisher_dim,
        'n_star': n_star,
    }


def _find_sample_complexity_worker(
    dag: DAG,
    sem: LinearGaussianSEM,
    n_values: List[int],
    n_trials: int,
    success_threshold: float,
    pc_alpha: float,
    base_seed: int
) -> float:
    """Find minimum sample size for successful recovery (worker function)."""
    true_cpdag = CPDAG.from_dag(dag)
    pc = PCAlgorithm(alpha=pc_alpha)
    rng = np.random.default_rng(base_seed)

    for n in n_values:
        successes = 0

        for trial in range(n_trials):
            seed = rng.integers(0, 2**31)
            X = sem.sample(n, random_state=seed)

            try:
                result = pc.fit(X)
                shd = structural_hamming_distance(true_cpdag, result.cpdag)
                if shd == 0:
                    successes += 1
            except Exception:
                continue

        success_rate = successes / n_trials
        if success_rate >= success_threshold:
            return float(n)

    return float(n_values[-1])


class Experiment1Correlation(BaseExperiment):
    """
    Experiment 1: Correlation between Fisher dimension and sample complexity.

    Setup:
    - Generate graphs from multiple families
    - Compute Fisher dimension for each
    - Run PC algorithm at varying sample sizes
    - Find minimum n* where SHD=0 with probability >= threshold

    Expected outcome: Strong positive correlation (r > 0.8)

    Note: both Pearson and Spearman correlations are reported per graph
    type, but the ``meets_expectation`` flag is computed from the Spearman
    coefficients only.
    """

    @property
    def name(self) -> str:
        """Identifier of this experiment (also used as the figure prefix)."""
        return "exp1_correlation"

    @staticmethod
    def _fit_graph_type_stats(
        fisher_dims: List[float],
        sample_complexities: List[float],
        log_factor: float,
        max_n: float,
    ) -> Dict[str, Any]:
        """Fit the constant C and the correlations for one graph type.

        Args:
            fisher_dims: Fisher dimension of each graph.
            sample_complexities: Empirical n* of each graph (same order).
            log_factor: The factor log(d) + log(1/delta) that scales F.
            max_n: Largest tried sample size; graphs with n* at this ceiling
                are excluded from the fit.

        Returns:
            Dict with ``'fitted_C'``, ``'pearson_r'``, ``'spearman_r'`` and
            ``'n_graphs'`` (count of all graphs, including excluded ones).
        """
        # dtype=float keeps np.isfinite usable even if a value is missing.
        fisher_arr = np.asarray(fisher_dims, dtype=float)
        n_star_arr = np.asarray(sample_complexities, dtype=float)

        # x = F * log(d/δ)
        x = fisher_arr * log_factor

        # Exclude non-finite points, non-positive x, and graphs that hit the
        # sample-size ceiling (their true n* is unknown).
        valid = (
            np.isfinite(x)
            & np.isfinite(n_star_arr)
            & (x > 0)
            & (n_star_arr < max_n)
        )
        n_valid = int(np.sum(valid))

        # C is the median ratio n*/x; falls back to 1.0 with no valid point.
        if n_valid > 0:
            C_fit = float(np.median(n_star_arr[valid] / x[valid]))
        else:
            C_fit = 1.0

        # Correlations need at least three points to be computed here.
        if n_valid >= 3:
            pearson_r, _ = pearsonr(x[valid], n_star_arr[valid])
            spearman_r, _ = spearmanr(x[valid], n_star_arr[valid])
            # Store plain Python floats, like fitted_C.
            pearson_r, spearman_r = float(pearson_r), float(spearman_r)
        else:
            pearson_r, spearman_r = float('nan'), float('nan')

        return {
            'fitted_C': C_fit,
            'pearson_r': pearson_r,
            'spearman_r': spearman_r,
            'n_graphs': len(fisher_arr),
        }

    def run(self) -> Dict[str, Any]:
        """Execute the experiment.

        Returns:
            Dict with keys ``'config'`` (parameters used),
            ``'by_graph_type'`` (per type: Fisher dimensions, sample
            complexities and fitted C), ``'stats_by_type'`` (per type: fitted
            C, Pearson r, Spearman rho, graph count) and
            ``'meets_expectation'`` (True when every non-NaN Spearman
            coefficient exceeds the configured expected correlation).

        Raises:
            ValueError: If the configured ``success_threshold`` is not below
                1, since log(1/delta) with delta = 1 - threshold is then
                undefined.
        """
        exp_config = self.config.experiments.exp1_correlation

        # Parameters
        d = exp_config.d
        graphs_per_family = exp_config.graphs_per_family
        n_values = exp_config.n_values
        n_trials = exp_config.n_trials
        success_threshold = exp_config.success_threshold

        # Fail fast: delta <= 0 (or NaN) would otherwise only surface after
        # the whole parallel sweep, when log(1/delta) is evaluated.
        delta = 1.0 - success_threshold
        if not delta > 0.0:
            raise ValueError(
                f"success_threshold must be < 1 so that delta = 1 - threshold "
                f"is positive, got {success_threshold}"
            )

        # 3 graph types for multi-panel comparison
        graph_types = ['tree', 'chain', 'erdos_renyi']

        self.logger.info(f"Running Experiment 1 with d={d}, {graphs_per_family} graphs per type")
        self.logger.info(f"Graph types: {graph_types}")
        self.logger.info(f"Using {self.get_n_jobs()} parallel workers")

        # Use varied beta ranges to get Fisher dimension variation (per Proposition 7.1)
        beta_ranges = [
            (0.2, 0.4),   # Weak coefficients -> high Fisher dim
            (0.3, 0.5),   # Medium-weak
            (0.4, 0.6),   # Medium
            (0.5, 0.7),   # Medium-strong
            (0.6, 0.9),   # Strong coefficients -> low Fisher dim
        ]

        results = {
            'config': {
                'd': d,
                'graphs_per_family': graphs_per_family,
                'n_values': n_values,
                'n_trials': n_trials,
                'success_threshold': success_threshold,
                'graph_types': graph_types,
                'beta_ranges': beta_ranges,
            },
            'by_graph_type': {gt: {'fisher_dims': [], 'sample_complexities': []}
                             for gt in graph_types},
        }

        # Create tasks for all graph types. Seeds are drawn in a fixed order
        # (graph type, then index) so a run is reproducible from self.rng.
        sigma_range = (self.config.sem.sigma_min, self.config.sem.sigma_max)
        tasks = []
        for graph_type in graph_types:
            er_p = 0.2 if graph_type == 'erdos_renyi' else None
            for i in range(graphs_per_family):
                seed = self.rng.integers(0, 2**31)
                tasks.append(GraphTask(
                    family=graph_type,
                    d=d,
                    seed=seed,
                    er_p=er_p,
                    # Cycle through beta ranges
                    beta_range=beta_ranges[i % len(beta_ranges)],
                    sigma_range=sigma_range,
                    n_values=n_values,
                    n_trials=n_trials,
                    success_threshold=success_threshold,
                    pc_alpha=self.config.pc.alpha,
                    sign_distribution=self.config.sem.sign_distribution,
                ))

        # Process all graphs in parallel
        graph_results = self.parallel_map(
            _process_single_graph,
            tasks,
            desc="Processing graphs"
        )

        # Collect results by graph type (None marks a skipped, edgeless graph)
        for result in graph_results:
            if result is None:
                continue

            bucket = results['by_graph_type'][result['family']]
            bucket['fisher_dims'].append(result['fisher_dim'])
            bucket['sample_complexities'].append(result['n_star'])

        # Compute correlations per graph type
        log_factor = np.log(d) + np.log(1.0 / delta)
        max_n = max(n_values)

        stats_by_type = {}
        for gt, gt_data in results['by_graph_type'].items():
            stats = self._fit_graph_type_stats(
                gt_data['fisher_dims'],
                gt_data['sample_complexities'],
                log_factor,
                max_n,
            )
            C_fit = stats['fitted_C']
            pearson_r = stats['pearson_r']
            spearman_r = stats['spearman_r']

            # The fitted constant is also stored next to the raw data.
            gt_data['fitted_C'] = C_fit
            stats_by_type[gt] = stats

            self.logger.info(f"{gt}: C={C_fit:.1f}, r={pearson_r:.3f}, ρ={spearman_r:.3f}")

        results['stats_by_type'] = stats_by_type

        # Overall check: uses the Spearman coefficients; types without a
        # computable coefficient (NaN) are ignored, and no coefficient at all
        # counts as not meeting the expectation.
        all_correlations = [
            s['spearman_r'] for s in stats_by_type.values()
            if not np.isnan(s['spearman_r'])
        ]
        results['meets_expectation'] = bool(all_correlations) and all(
            r > exp_config.expected_correlation for r in all_correlations
        )

        self.logger.info(f"Meets expectation (all r > {exp_config.expected_correlation}): {results['meets_expectation']}")

        return results

    def plot_results(self, output_dir: Optional[str] = None) -> None:
        """Generate experiment figures.

        Args:
            output_dir: Directory for the figure; defaults to the configured
                ``output.figures_dir``.

        Logs a warning and returns without plotting when no results exist.
        """
        from ..utils.visualization import plot_exp1_by_graph_type, save_figure
        from pathlib import Path

        if not self.results:
            self.logger.warning("No results to plot")
            return

        output_dir = Path(output_dir or self.config.output.figures_dir)

        # Recompute the plotting parameters from the stored configuration.
        run_config = self.results['config']
        delta = 1.0 - run_config['success_threshold']
        max_n = max(run_config['n_values'])

        fig, axes = plot_exp1_by_graph_type(
            self.results['by_graph_type'],
            self.results['stats_by_type'],
            d=run_config['d'],
            delta=delta,
            max_n=max_n,
        )

        save_figure(fig, output_dir / f"{self.name}_scatter")
        self.logger.info(f"Saved figure to {output_dir}")
