"""
Experiment 6: Benchmark Graph Analysis.

Computes Fisher dimension for standard benchmark causal graphs.
"""

from __future__ import annotations

from typing import Dict, Any, Optional, List, Tuple
from dataclasses import dataclass
import numpy as np

from .base import BaseExperiment
from ..core.dag import DAG
from ..core.sem import LinearGaussianSEM
from ..core.mec import CPDAG
from ..core.fisher_dimension import compute_fisher_dimension
from ..generators.benchmark import (
    generate_sachs_network,
    generate_child_network,
    generate_alarm_network,
    generate_insurance_network,
)
from ..algorithms.pc import PCAlgorithm
from ..algorithms.curvature import compute_curvature_matrix, estimate_fisher_dimension_from_curvature
from ..metrics.shd import structural_hamming_distance


@dataclass
class BenchmarkTask:
    """Work item: evaluate one randomly drawn SEM on one benchmark graph.

    Instances are built in ``Experiment6Benchmark.run`` and consumed by
    ``_process_benchmark_task``.

    Attributes:
        benchmark_name: Label of the benchmark; copied into the task result.
        dag: Benchmark graph the SEM is drawn on.
        sem_idx: Index of this SEM among those drawn for the same benchmark.
        seed: Seed for drawing the SEM; also the base seed of the sampling
            trials.
        beta_range: ``beta_range`` forwarded to ``LinearGaussianSEM.random``.
        sigma_range: ``sigma_range`` forwarded to ``LinearGaussianSEM.random``.
        n_values: Candidate sample sizes, tried in the given order.
        n_trials: Number of sampling/recovery trials per sample size.
        success_threshold: Fraction of trials that must recover the true
            CPDAG exactly for a sample size to be accepted.
        pc_alpha: ``alpha`` passed to ``PCAlgorithm``.
        sign_distribution: ``sign_distribution`` forwarded to
            ``LinearGaussianSEM.random``.
    """

    # Graph identity
    benchmark_name: str
    dag: DAG

    # SEM identity and randomness
    sem_idx: int
    seed: int

    # SEM parameter ranges
    beta_range: Tuple[float, float]
    sigma_range: Tuple[float, float]

    # Sample-complexity search settings
    n_values: List[int]
    n_trials: int
    success_threshold: float
    pc_alpha: float

    sign_distribution: str = "positive"


def _process_benchmark_task(task: BenchmarkTask) -> Dict[str, Any]:
    """Evaluate one benchmark task: draw its SEM and measure it.

    A random linear-Gaussian SEM is drawn on ``task.dag`` using ``task.seed``,
    its Fisher dimension is computed directly, and the empirical sample
    complexity is searched with the same seed as base seed.

    Args:
        task: Description of the benchmark graph, SEM parameters and search
            settings.

    Returns:
        Dict with keys ``benchmark_name``, ``sem_idx``, ``fisher_dim``,
        ``rho_min`` and ``n_star``.
    """
    graph = task.dag

    # Draw the random SEM described by the task.
    model = LinearGaussianSEM.random(
        graph,
        beta_range=task.beta_range,
        sigma_range=task.sigma_range,
        random_state=task.seed,
        sign_distribution=task.sign_distribution,
    )

    # Direct (SEM-dependent) Fisher dimension.
    fisher = compute_fisher_dimension(graph, model)

    # Empirical sample complexity of exact recovery.
    sample_complexity = _find_sample_complexity_benchmark(
        graph,
        model,
        task.n_values,
        task.n_trials,
        task.success_threshold,
        task.pc_alpha,
        task.seed,
    )

    summary: Dict[str, Any] = {}
    summary['benchmark_name'] = task.benchmark_name
    summary['sem_idx'] = task.sem_idx
    summary['fisher_dim'] = fisher.fisher_dimension
    summary['rho_min'] = fisher.rho_min
    summary['n_star'] = sample_complexity
    return summary


def _find_sample_complexity_benchmark(
    dag: DAG,
    sem: LinearGaussianSEM,
    n_values: List[int],
    n_trials: int,
    success_threshold: float,
    pc_alpha: float,
    base_seed: int
) -> float:
    """Find minimum sample size for successful recovery."""
    true_cpdag = CPDAG.from_dag(dag)
    pc = PCAlgorithm(alpha=pc_alpha)
    rng = np.random.default_rng(base_seed)

    for n in n_values:
        successes = 0
        for trial in range(n_trials):
            seed = rng.integers(0, 2**31)
            X = sem.sample(n, random_state=seed)
            try:
                result = pc.fit(X)
                shd = structural_hamming_distance(true_cpdag, result.cpdag)
                if shd == 0:
                    successes += 1
            except Exception:
                continue

        if successes / n_trials >= success_threshold:
            return float(n)

    return float(n_values[-1])


class Experiment6Benchmark(BaseExperiment):
    """
    Experiment 6: Benchmark Graph Analysis.

    Setup:
    - Analyze standard benchmark graphs: Sachs, Child, Alarm, Insurance
    - Compute Fisher dimension using direct and curvature methods
    - Compare predictions and validate against known difficulty

    Expected outcome: Fisher dimension correctly ranks relative difficulty.
    """

    @property
    def name(self) -> str:
        """Identifier of this experiment; used as prefix of the saved figures."""
        return "exp6_benchmark"

    def run(self) -> Dict[str, Any]:
        """Execute the experiment.

        Loads every benchmark graph, draws ``n_sems_per_graph`` random SEMs per
        graph, evaluates them through ``parallel_map`` and aggregates the
        per-benchmark statistics.

        Returns:
            Dict with the keys ``config``, ``benchmarks`` (per-benchmark
            summary), the parallel lists ``names``, ``fisher_direct``,
            ``fisher_curvature`` and ``empirical_n``, plus ``analysis`` and
            ``meets_expectation``.
        """
        exp_config = self.config.experiments.exp6_benchmark

        n_sems_per_graph = exp_config.n_sems_per_graph
        n_values = exp_config.n_values
        n_trials = exp_config.n_trials
        success_threshold = exp_config.success_threshold

        self.logger.info("Running Experiment 6: Benchmark Graph Analysis")
        self.logger.info(f"Using {self.get_n_jobs()} parallel workers")

        # Benchmark graph generators
        benchmarks = {
            'Sachs': generate_sachs_network,
            'Child': generate_child_network,
            'Alarm': generate_alarm_network,
            'Insurance': generate_insurance_network,
        }

        results = {
            'config': {
                'n_sems_per_graph': n_sems_per_graph,
                'n_values': n_values,
                'n_trials': n_trials,
                'success_threshold': success_threshold,
            },
            'benchmarks': {},
            'names': [],
            'fisher_direct': [],
            'fisher_curvature': [],
            'empirical_n': [],
        }

        # Pre-generate all DAGs and compute curvature-based Fisher dim
        dag_data = self._load_benchmark_graphs(benchmarks)

        # Create tasks for all SEMs across all benchmarks
        tasks = self._build_tasks(
            dag_data, n_sems_per_graph, n_values, n_trials, success_threshold
        )

        # Process all tasks in parallel
        task_results = self.parallel_map(
            _process_benchmark_task,
            tasks,
            desc="Processing benchmark SEMs"
        )

        # Aggregate results by benchmark
        benchmark_results = {name: {'fisher_dims': [], 'n_stars': []} for name in dag_data}

        for result in task_results:
            per_benchmark = benchmark_results[result['benchmark_name']]
            per_benchmark['fisher_dims'].append(result['fisher_dim'])
            per_benchmark['n_stars'].append(result['n_star'])

        # Compute summary statistics
        for name, data in dag_data.items():
            br = benchmark_results[name]

            # NaN marks a benchmark for which no SEM result is available.
            fisher_direct_mean = np.mean(br['fisher_dims']) if br['fisher_dims'] else float('nan')
            fisher_direct_std = np.std(br['fisher_dims']) if br['fisher_dims'] else float('nan')
            empirical_n_mean = np.mean(br['n_stars']) if br['n_stars'] else float('nan')
            empirical_n_std = np.std(br['n_stars']) if br['n_stars'] else float('nan')

            results['benchmarks'][name] = {
                'd': data['d'],
                'num_edges': data['num_edges'],
                'num_v_structures': data['num_v_structures'],
                'fisher_direct_mean': fisher_direct_mean,
                'fisher_direct_std': fisher_direct_std,
                'fisher_curvature': data['fisher_curvature'],
                'empirical_n_mean': empirical_n_mean,
                'empirical_n_std': empirical_n_std,
                'fisher_dims_all': br['fisher_dims'],
                'empirical_ns_all': br['n_stars'],
            }

            results['names'].append(name)
            results['fisher_direct'].append(fisher_direct_mean)
            results['fisher_curvature'].append(data['fisher_curvature'])
            results['empirical_n'].append(empirical_n_mean)

            self.logger.info(f"  {name}: d={data['d']}, edges={data['num_edges']}, "
                           f"F_direct={fisher_direct_mean:.2f}, "
                           f"F_curvature={data['fisher_curvature']:.2f}, "
                           f"n*={empirical_n_mean:.0f}")

        # Analysis
        results['analysis'] = self._analyze_rankings(results)
        # NOTE(review): with fewer than two usable benchmarks the ranking is
        # reported as (vacuously) consistent, so this flag is True as well.
        results['meets_expectation'] = results['analysis'].get('ranking_consistent', False)

        return results

    def _load_benchmark_graphs(self, benchmarks: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """Build every benchmark DAG and collect its SEM-independent statistics.

        A benchmark whose generator or graph statistics raise is logged and
        skipped. A failing curvature computation is logged too, but only sets
        ``fisher_curvature`` to NaN for that benchmark.

        Args:
            benchmarks: Mapping from benchmark name to a zero-argument callable
                returning its DAG.

        Returns:
            Mapping from benchmark name to a dict with keys ``dag``, ``d``,
            ``num_edges``, ``num_v_structures`` and ``fisher_curvature``.
        """
        dag_data = {}
        for name, generator in benchmarks.items():
            self.logger.info(f"Loading benchmark: {name}")
            try:
                dag = generator()
                d = dag.num_nodes()

                # Compute curvature-based Fisher dimension
                try:
                    C = compute_curvature_matrix(dag)
                    fisher_curvature = estimate_fisher_dimension_from_curvature(C)
                except Exception as e:
                    self.logger.warning(f"Curvature computation failed for {name}: {e}")
                    fisher_curvature = float('nan')

                dag_data[name] = {
                    'dag': dag,
                    'd': d,
                    'num_edges': dag.num_edges(),
                    'num_v_structures': len(dag.v_structures()),
                    'fisher_curvature': fisher_curvature,
                }
            except Exception as e:
                self.logger.warning(f"Failed to generate {name}: {e}")
                continue
        return dag_data

    def _build_tasks(
        self,
        dag_data: Dict[str, Dict[str, Any]],
        n_sems_per_graph: int,
        n_values: List[int],
        n_trials: int,
        success_threshold: float,
    ) -> List[BenchmarkTask]:
        """Create one task per (benchmark, SEM index) pair.

        Seeds are drawn from ``self.rng`` in benchmark order, then SEM order,
        so the task list is determined by the state of that generator.
        """
        sem_config = self.config.sem
        beta_range = (sem_config.beta_min, sem_config.beta_max)
        sigma_range = (sem_config.sigma_min, sem_config.sigma_max)
        sign_distribution = sem_config.sign_distribution
        pc_alpha = self.config.pc.alpha

        tasks = []
        for name, data in dag_data.items():
            # Sample sizes are scaled with the number of nodes of the graph.
            adjusted_n = self._get_adjusted_n_values(n_values, data['d'])

            for sem_idx in range(n_sems_per_graph):
                seed = self.rng.integers(0, 2**31)
                tasks.append(BenchmarkTask(
                    benchmark_name=name,
                    dag=data['dag'],
                    sem_idx=sem_idx,
                    seed=seed,
                    beta_range=beta_range,
                    sigma_range=sigma_range,
                    n_values=adjusted_n,
                    n_trials=n_trials,
                    success_threshold=success_threshold,
                    pc_alpha=pc_alpha,
                    sign_distribution=sign_distribution,
                ))
        return tasks

    def _get_adjusted_n_values(self, base_n_values: List[int], d: int) -> List[int]:
        """Adjust sample size range based on graph size.

        Every base value is multiplied by ``max(1, d / 10)`` and truncated to
        an int; 2x, 4x and 8x the largest scaled value are appended.

        Args:
            base_n_values: Sample sizes configured for the experiment.
            d: Number of nodes of the graph.

        Returns:
            Sorted list of distinct sample sizes.

        Raises:
            ValueError: If ``base_n_values`` is empty.
        """
        if len(base_n_values) == 0:
            raise ValueError("base_n_values must contain at least one sample size")

        scale_factor = max(1, d / 10)
        adjusted = {int(n * scale_factor) for n in base_n_values}
        max_n = max(adjusted)
        adjusted.update((max_n * 2, max_n * 4, max_n * 8))
        return sorted(adjusted)

    def _analyze_rankings(self, results: Dict) -> Dict[str, Any]:
        """Analyze consistency of rankings between methods.

        Args:
            results: Dict holding the parallel lists ``names``,
                ``fisher_direct``, ``fisher_curvature`` and ``empirical_n``.

        Returns:
            Dict that always contains ``ranking_consistent`` and
            ``ranking_inversion_rate``. With at least three benchmarks having
            finite values it also contains the Spearman/Kendall statistics of
            direct Fisher dimension vs. empirical sample complexity, and the
            Spearman statistics of direct vs. curvature Fisher dimension.
        """
        from itertools import combinations
        from scipy.stats import spearmanr, kendalltau

        fisher_direct = results['fisher_direct']
        fisher_curvature = results['fisher_curvature']
        empirical_n = results['empirical_n']

        analysis = {}

        # Benchmarks with both a direct Fisher dimension and an empirical n*.
        valid_pairs = [
            (fd, emp) for fd, emp in zip(fisher_direct, empirical_n)
            if np.isfinite(fd) and np.isfinite(emp)
        ]
        fd_valid = [fd for fd, _ in valid_pairs]
        emp_valid = [emp for _, emp in valid_pairs]

        if len(valid_pairs) >= 3:
            corr_direct, p_direct = spearmanr(fd_valid, emp_valid)
            analysis['direct_empirical_correlation'] = corr_direct
            analysis['direct_empirical_pvalue'] = p_direct

            tau_direct, p_tau = kendalltau(fd_valid, emp_valid)
            analysis['direct_empirical_kendall'] = tau_direct
            analysis['direct_empirical_kendall_pvalue'] = p_tau

        # Benchmarks with both Fisher-dimension estimates available.
        method_pairs = [
            (fd, fc) for fd, fc in zip(fisher_direct, fisher_curvature)
            if np.isfinite(fd) and np.isfinite(fc)
        ]

        if len(method_pairs) >= 3:
            fd_both = [fd for fd, _ in method_pairs]
            fc_both = [fc for _, fc in method_pairs]

            corr_methods, p_methods = spearmanr(fd_both, fc_both)
            analysis['direct_curvature_correlation'] = corr_methods
            analysis['direct_curvature_pvalue'] = p_methods

        if len(valid_pairs) >= 2:
            # Count strictly discordant pairs on the values themselves. Tied
            # values are neither concordant nor discordant, so the outcome no
            # longer depends on the order in which benchmarks are listed.
            pairs = list(combinations(valid_pairs, 2))
            inversions = sum(
                1 for (fd_a, emp_a), (fd_b, emp_b) in pairs
                if (fd_a < fd_b and emp_a > emp_b) or (fd_a > fd_b and emp_a < emp_b)
            )

            analysis['ranking_consistent'] = inversions == 0
            analysis['ranking_inversion_rate'] = inversions / len(pairs)
        else:
            # Fewer than two benchmarks: nothing to compare.
            analysis['ranking_consistent'] = True
            analysis['ranking_inversion_rate'] = 0

        return analysis

    def plot_results(self, output_dir: Optional[str] = None) -> None:
        """Generate experiment figures.

        Saves three figures named ``<name>_comparison``, ``<name>_scatter`` and
        ``<name>_table`` through ``save_figure``. Nothing is plotted when there
        are no results or no benchmark was analysed.

        Args:
            output_dir: Target directory; defaults to
                ``self.config.output.figures_dir``.
        """
        from ..utils.visualization import plot_benchmark_comparison, save_figure
        from pathlib import Path
        import matplotlib.pyplot as plt

        if not self.results:
            self.logger.warning("No results to plot")
            return

        names = self.results['names']
        if not names:
            # An empty summary table cannot be rendered (Axes.table needs at
            # least one row), so stop before creating any figure.
            self.logger.warning("No benchmark results to plot")
            return

        output_dir = Path(output_dir or self.config.output.figures_dir)

        # NOTE(review): figures are not closed here; confirm that save_figure
        # releases them if many experiments are plotted in one process.

        # Plot 1: Fisher dimension comparison
        fig1, ax1 = plt.subplots(figsize=(12, 8))

        plot_benchmark_comparison(
            names,
            self.results['fisher_direct'],
            self.results['fisher_curvature'],
            ax=ax1
        )

        save_figure(fig1, output_dir / f"{self.name}_comparison")

        # Plot 2: Fisher dimension vs empirical n*
        fig2, ax2 = plt.subplots(figsize=(10, 8))

        fisher_direct = self.results['fisher_direct']
        empirical_n = self.results['empirical_n']

        d_values = [self.results['benchmarks'][name]['d'] for name in names]
        x = [f * np.log(d) for f, d in zip(fisher_direct, d_values)]

        colors = plt.cm.Set1(np.linspace(0, 1, len(names)))

        for i, name in enumerate(names):
            ax2.scatter(x[i], empirical_n[i], c=[colors[i]], s=100, label=name)
            ax2.annotate(name, (x[i], empirical_n[i]), xytext=(5, 5),
                        textcoords='offset points', fontsize=9)

        # Reference diagonal spanning the finite data range on both axes.
        finite_points = [
            (xi, ni) for xi, ni in zip(x, empirical_n)
            if np.isfinite(xi) and np.isfinite(ni)
        ]
        if finite_points:
            min_val = min(min(xi for xi, _ in finite_points), min(ni for _, ni in finite_points))
            max_val = max(max(xi for xi, _ in finite_points), max(ni for _, ni in finite_points))
            ax2.plot([min_val, max_val], [min_val, max_val], 'k--', alpha=0.5, label='y=x')

        ax2.set_xlabel(r'$\mathcal{F}([G]) \cdot \log d$')
        ax2.set_ylabel('Empirical Sample Complexity $n^*$')
        ax2.set_title('Benchmark Graphs: Fisher Dimension vs Sample Complexity')
        ax2.legend()

        save_figure(fig2, output_dir / f"{self.name}_scatter")

        # Plot 3: Summary table
        fig3, ax3 = plt.subplots(figsize=(14, 6))
        ax3.axis('off')

        table_data = []
        headers = ['Benchmark', 'd', 'Edges', 'V-structs', 'F (direct)', 'F (curvature)', 'n*']

        for name in names:
            bench = self.results['benchmarks'][name]
            row = [
                name,
                bench['d'],
                bench['num_edges'],
                bench['num_v_structures'],
                f"{bench['fisher_direct_mean']:.1f} +/- {bench['fisher_direct_std']:.1f}",
                f"{bench['fisher_curvature']:.1f}" if np.isfinite(bench['fisher_curvature']) else 'N/A',
                f"{bench['empirical_n_mean']:.0f} +/- {bench['empirical_n_std']:.0f}",
            ]
            table_data.append(row)

        table = ax3.table(
            cellText=table_data,
            colLabels=headers,
            cellLoc='center',
            loc='center',
            colColours=['#f0f0f0'] * len(headers)
        )
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1.2, 1.5)

        ax3.set_title('Benchmark Graph Summary', pad=20)

        save_figure(fig3, output_dir / f"{self.name}_table")

        self.logger.info(f"Saved figures to {output_dir}")
