"""
Experiment 3: Comparison with Alternative Complexity Proxies.

Demonstrates that Fisher dimension outperforms simpler structural proxies.
"""

from __future__ import annotations

from typing import Dict, Any, Optional, List, Tuple
from dataclasses import dataclass
import numpy as np
from scipy.stats import spearmanr

from .base import BaseExperiment
from ..core.dag import DAG
from ..core.sem import LinearGaussianSEM
from ..core.mec import CPDAG
from ..core.fisher_dimension import compute_fisher_dimension
from ..generators.factory import generate_dag
from ..algorithms.pc import PCAlgorithm
from ..algorithms.curvature import compute_curvature_matrix, estimate_fisher_dimension_from_curvature
from ..metrics.shd import structural_hamming_distance
from ..metrics.complexity_proxies import compute_all_proxies


@dataclass
class ProxyTask:
    """Task for processing a single graph in proxy comparison."""
    graph_type: str
    d: int
    seed: int
    er_p: Optional[float]
    beta_range: Tuple[float, float]
    sigma_range: Tuple[float, float]
    n_values: List[int]
    n_trials: int
    success_threshold: float
    pc_alpha: float
    sign_distribution: str = 'positive'


def _process_proxy_task(task: ProxyTask) -> Optional[Dict[str, Any]]:
    """Compute every complexity proxy and the empirical sample complexity for one graph.

    Generates a DAG and a linear-Gaussian SEM from ``task`` (both seeded with
    ``task.seed``). It then evaluates the structural proxies from
    ``compute_all_proxies`` plus the Fisher dimension, ``rho_min`` and a
    curvature-based estimate. Finally it searches ``task.n_values`` for the first
    sample size at which PC meets ``task.success_threshold``.

    Args:
        task: Fully specified work item for a single graph.

    Returns:
        A flat record with ``'d'``, ``'graph_type'``, ``'num_edges'``,
        ``'empirical_n'`` and every proxy value, or ``None`` when the generated
        DAG has no edges. Proxy keys are merged last, so they take precedence
        on a name clash.
    """
    # Only the Erdos-Renyi generator takes an edge probability.
    extra_dag_kwargs: Dict[str, Any] = (
        {'p': task.er_p} if task.graph_type == 'erdos_renyi' else {}
    )
    dag = generate_dag(task.graph_type, task.d, random_state=task.seed, **extra_dag_kwargs)

    # An edgeless graph has no structure to recover, so it is dropped.
    if dag.num_edges() == 0:
        return None

    # NOTE(review): the same seed drives DAG generation, the SEM parameter draw
    # and the trial-seed RNG in the sample-complexity search; confirm that this
    # coupling of random streams is intended.
    sem = LinearGaussianSEM.random(
        dag,
        beta_range=task.beta_range,
        sigma_range=task.sigma_range,
        random_state=task.seed,
        sign_distribution=task.sign_distribution,
    )

    proxy_values = compute_all_proxies(dag, sem)

    # Partial-correlation based Fisher dimension and its rho_min diagnostic.
    fisher = compute_fisher_dimension(dag, sem)
    proxy_values['fisher_dimension'] = fisher.fisher_dimension
    proxy_values['rho_min'] = fisher.rho_min

    # The curvature estimate is best-effort: any failure is recorded as NaN.
    try:
        curvature_value = estimate_fisher_dimension_from_curvature(
            compute_curvature_matrix(dag)
        )
    except Exception:
        curvature_value = float('nan')
    proxy_values['curvature_based'] = curvature_value

    empirical_n = _find_sample_complexity_proxy(
        dag,
        sem,
        task.n_values,
        task.n_trials,
        task.success_threshold,
        task.pc_alpha,
        task.seed,
    )

    record: Dict[str, Any] = {
        'd': task.d,
        'graph_type': task.graph_type,
        'num_edges': dag.num_edges(),
        'empirical_n': empirical_n,
    }
    record.update(proxy_values)
    return record


def _find_sample_complexity_proxy(
    dag: DAG,
    sem: LinearGaussianSEM,
    n_values: List[int],
    n_trials: int,
    success_threshold: float,
    pc_alpha: float,
    base_seed: int
) -> float:
    """Find the smallest sample size at which PC recovers the true CPDAG reliably.

    Each ``n`` in ``n_values`` is tried in the given order. For each one,
    ``n_trials`` datasets of size ``n`` are sampled from ``sem`` and PC is run
    on each. A trial succeeds when the structural Hamming distance to the true
    CPDAG is zero; a trial whose PC fit (or SHD computation) raises counts as a
    failure. The first ``n`` whose success rate reaches ``success_threshold``
    is returned.

    Trials for a given ``n`` stop early once the outcome is decided: either the
    threshold is already met, or it can no longer be met even if every
    remaining trial succeeds. All trial seeds for ``n`` are still drawn up front
    from ``base_seed``'s generator. As a result, early stopping never changes
    the seeds used at later sample sizes, and the returned value is unchanged.

    Args:
        dag: Ground-truth DAG; its CPDAG is the recovery target.
        sem: Model used to sample the datasets.
        n_values: Candidate sample sizes, tried in order.
        n_trials: Number of datasets (PC runs) per sample size.
        success_threshold: Required fraction of exact recoveries.
        pc_alpha: ``alpha`` passed to ``PCAlgorithm``.
        base_seed: Seed for the generator that produces per-trial seeds.

    Returns:
        The first qualifying ``n`` as a float. If no ``n`` qualifies,
        ``float(n_values[-1])`` is returned. That fallback is right-censored:
        it is indistinguishable from a genuine success at the largest ``n``.

    Raises:
        ValueError: If ``n_values`` is empty or ``n_trials`` is not positive.
    """
    if not n_values:
        raise ValueError("n_values must contain at least one sample size")
    if n_trials < 1:
        raise ValueError(f"n_trials must be positive, got {n_trials}")

    true_cpdag = CPDAG.from_dag(dag)
    pc = PCAlgorithm(alpha=pc_alpha)
    rng = np.random.default_rng(base_seed)

    for n in n_values:
        # One scalar draw per trial, exactly as a draw-per-trial loop would do.
        # This keeps the generator state independent of early stopping below.
        trial_seeds = [rng.integers(0, 2**31) for _ in range(n_trials)]

        successes = 0
        for completed, seed in enumerate(trial_seeds, start=1):
            X = sem.sample(n, random_state=seed)
            try:
                result = pc.fit(X)
                if structural_hamming_distance(true_cpdag, result.cpdag) == 0:
                    successes += 1
            except Exception:
                # Best-effort: a failed PC run counts as an unsuccessful trial.
                pass

            if successes / n_trials >= success_threshold:
                return float(n)
            # Even if every remaining trial succeeded, the threshold is out of reach.
            if (successes + (n_trials - completed)) / n_trials < success_threshold:
                break

    return float(n_values[-1])


class Experiment3Proxies(BaseExperiment):
    """
    Experiment 3: Comparison with Alternative Complexity Proxies.

    Setup:
    - For each graph type, generate graphs with varied beta ranges
    - Compute multiple predictors including Fisher dimension
    - Measure empirical sample complexity per graph type

    Expected outcome: Fisher dimension achieves highest correlation within each graph type.
    """

    @property
    def name(self) -> str:
        """Experiment identifier; also used as the figure file-name prefix."""
        return "exp3_proxies"

    def run(self) -> Dict[str, Any]:
        """Execute the experiment.

        A ``ProxyTask`` is built for every graph type, every ``d`` in
        ``d_values`` and every replicate. Replicates cycle through fixed beta
        ranges so the Fisher dimension varies within each graph type. The tasks
        are processed with ``self.parallel_map``. For each graph type, Spearman
        correlations between every proxy and the empirical sample complexity are
        then computed, and the proxy with the largest non-NaN ``|r|`` is recorded.

        Returns:
            A dict with ``'config'``, ``'by_graph_type'``, ``'fisher_wins'`` and
            ``'meets_expectation'``. Each ``by_graph_type[gt]`` holds ``'data'``
            and ``'correlations'``. When correlations exist it also holds
            ``'fisher_correlation'``. It holds ``'best_proxy'`` and
            ``'best_correlation'`` only if at least one correlation is not NaN.
        """
        exp_config = self.config.experiments.exp3_proxies

        d_values = exp_config.d_values
        graphs_per_d = exp_config.graphs_per_d
        n_values = exp_config.n_values
        n_trials = exp_config.n_trials
        success_threshold = exp_config.success_threshold

        # 3 graph types for multi-panel comparison
        graph_types = ['tree', 'chain', 'erdos_renyi']

        # Varied beta ranges to get Fisher dimension variation within each type
        beta_ranges = [
            (0.2, 0.4),
            (0.3, 0.5),
            (0.4, 0.6),
            (0.5, 0.7),
            (0.6, 0.9),
        ]

        self.logger.info(f"Running Experiment 3 with d={d_values}, {graphs_per_d} graphs per (type, d)")
        self.logger.info(f"Graph types: {graph_types}")
        self.logger.info(f"Using {self.get_n_jobs()} parallel workers")

        results = {
            'config': {
                'd_values': d_values,
                'graphs_per_d': graphs_per_d,
                'n_values': n_values,
                'n_trials': n_trials,
                'success_threshold': success_threshold,
                'graph_types': graph_types,
            },
            'by_graph_type': {
                gt: {'data': [], 'correlations': {}}
                for gt in graph_types
            },
        }

        # Create tasks for all (graph_type, d) combinations
        tasks = []
        for gt in graph_types:
            for d in d_values:
                for i in range(graphs_per_d):
                    seed = self.rng.integers(0, 2**31)
                    er_p = 0.2 if gt == 'erdos_renyi' else None
                    beta_range = beta_ranges[i % len(beta_ranges)]  # Cycle through beta ranges

                    task = ProxyTask(
                        graph_type=gt,
                        d=d,
                        seed=seed,
                        er_p=er_p,
                        beta_range=beta_range,
                        sigma_range=(self.config.sem.sigma_min, self.config.sem.sigma_max),
                        n_values=n_values,
                        n_trials=n_trials,
                        success_threshold=success_threshold,
                        pc_alpha=self.config.pc.alpha,
                        sign_distribution=self.config.sem.sign_distribution,
                    )
                    tasks.append(task)

        # Process all tasks in parallel
        task_results = self.parallel_map(
            _process_proxy_task,
            tasks,
            desc="Processing graphs"
        )

        # Collect results by graph type (None marks a skipped, edgeless graph)
        for result in task_results:
            if result is not None:
                gt = result['graph_type']
                results['by_graph_type'][gt]['data'].append(result)

        # Compute correlations and the best proxy per graph type
        for gt in graph_types:
            gt_results = results['by_graph_type'][gt]
            correlations = self._compute_correlations(gt_results['data'])
            gt_results['correlations'] = correlations

            if not correlations:
                continue

            best = self._select_best_proxy(correlations)
            fisher_corr = correlations.get('fisher_dimension', float('nan'))
            if best is not None:
                best_name, best_corr = best
                gt_results['best_proxy'] = best_name
                gt_results['best_correlation'] = best_corr
            gt_results['fisher_correlation'] = fisher_corr

            if best is None:
                # All correlations are NaN (e.g. too few graphs or a constant
                # empirical n*), so no proxy can be credited as best.
                self.logger.warning(
                    f"{gt}: no finite proxy correlations; best proxy undefined"
                )
            else:
                self.logger.info(f"{gt}: best={best[0]} (r={best[1]:.3f}), fisher_dim r={fisher_corr:.3f}")

        # Overall check: Fisher dimension should be best or near-best in each type
        fisher_wins = self._count_fisher_wins(results['by_graph_type'], graph_types)

        results['fisher_wins'] = fisher_wins
        results['meets_expectation'] = fisher_wins >= 2  # Win in at least 2 of 3 types

        self.logger.info(f"Fisher dimension wins/near-wins: {fisher_wins}/{len(graph_types)}")
        self.logger.info(f"Meets expectation: {results['meets_expectation']}")

        return results

    @staticmethod
    def _select_best_proxy(correlations: Dict[str, float]) -> Optional[Tuple[str, float]]:
        """Return the ``(name, r)`` pair with the largest ``|r|`` among non-NaN correlations.

        Ties go to the earliest entry in ``correlations``. Returns ``None`` when
        every correlation is NaN, so an all-NaN graph type is never credited to
        whichever proxy happens to be listed first.
        """
        candidates = [(proxy, r) for proxy, r in correlations.items() if not np.isnan(r)]
        if not candidates:
            return None
        return max(candidates, key=lambda item: abs(item[1]))

    @staticmethod
    def _count_fisher_wins(by_graph_type: Dict[str, Dict[str, Any]], graph_types: List[str]) -> int:
        """Count graph types where Fisher dimension is the best proxy or within 90% of the best ``|r|``."""
        fisher_wins = 0
        for gt in graph_types:
            gt_results = by_graph_type[gt]
            if gt_results.get('best_proxy') == 'fisher_dimension':
                fisher_wins += 1
            elif not np.isnan(gt_results.get('fisher_correlation', float('nan'))):
                # Check if within 90% of best
                best_corr = abs(gt_results.get('best_correlation', 0))
                fisher_corr = abs(gt_results.get('fisher_correlation', 0))
                if best_corr > 0 and fisher_corr >= 0.9 * best_corr:
                    fisher_wins += 1
        return fisher_wins

    def _compute_correlations(self, data: List[Dict]) -> Dict[str, float]:
        """Compute Spearman correlations between each proxy and empirical n*.

        Records missing a proxy, or with a non-finite proxy or ``empirical_n``
        value, are skipped for that proxy. A proxy with fewer than 3 usable
        records gets NaN. ``spearmanr`` also yields NaN when either side is
        constant. Returns ``{}`` for empty ``data``.
        """
        if not data:
            return {}

        empirical_n = [d['empirical_n'] for d in data]

        proxy_names = [
            'fisher_dimension',
            'graph_density',
            'max_in_degree',
            'avg_markov_blanket_size',
            'num_v_structures',
            'mec_size_estimate',
            'curvature_based',
            'num_edges',
            'max_out_degree',
            'avg_degree',
        ]

        correlations = {}
        for proxy_name in proxy_names:
            proxy_values = [d.get(proxy_name, float('nan')) for d in data]

            valid = [
                (p, e) for p, e in zip(proxy_values, empirical_n)
                if np.isfinite(p) and np.isfinite(e)
            ]

            if len(valid) >= 3:
                p_vals, e_vals = zip(*valid)
                corr, _ = spearmanr(p_vals, e_vals)
                correlations[proxy_name] = corr
            else:
                correlations[proxy_name] = float('nan')

        return correlations

    def plot_results(self, output_dir: Optional[str] = None) -> None:
        """Generate experiment figures.

        Saves the per-graph-type proxy-correlation figure as
        ``<output_dir>/exp3_proxies_correlations``. If ``output_dir`` is not
        given, ``self.config.output.figures_dir`` is used. Logs a warning and
        returns when no results are available.
        """
        from ..utils.visualization import plot_exp3_by_graph_type, save_figure
        from pathlib import Path

        if not self.results:
            self.logger.warning("No results to plot")
            return

        output_dir = Path(output_dir or self.config.output.figures_dir)

        # Plot 3x1 panel of proxy correlations per graph type
        fig, axes = plot_exp3_by_graph_type(self.results['by_graph_type'])
        save_figure(fig, output_dir / f"{self.name}_correlations")

        self.logger.info(f"Saved figure to {output_dir}")
