"""
Experiment 4: Scaling with Graph Size.

Verifies the log(d) dependence in sample complexity.
"""

from __future__ import annotations

from typing import Dict, Any, Optional, List, Tuple
from dataclasses import dataclass
import numpy as np
from scipy.optimize import curve_fit
from scipy.stats import pearsonr

from .base import BaseExperiment
from ..core.dag import DAG
from ..core.sem import LinearGaussianSEM
from ..core.mec import CPDAG
from ..core.fisher_dimension import compute_fisher_dimension
from ..generators.factory import generate_dag
from ..algorithms.pc import PCAlgorithm
from ..metrics.shd import structural_hamming_distance


@dataclass
class ScalingTask:
    """Task for processing a single graph in scaling experiment.

    One instance bundles everything ``_process_scaling_task`` needs to
    generate a graph, draw a linear Gaussian SEM on it, and search for the
    sample size at which PC recovers the CPDAG.
    """
    # Graph family name forwarded to ``generate_dag``; 'erdos_renyi'
    # additionally gets an edge probability of 0.2.
    family: str
    # Number of nodes requested from ``generate_dag``.
    d: int
    # Seed reused for graph generation, SEM parameter sampling, and as the
    # base seed of the per-trial data sampling stream.
    seed: int
    # Passed as ``beta_range`` to ``LinearGaussianSEM.random``.
    beta_range: Tuple[float, float]
    # Passed as ``sigma_range`` to ``LinearGaussianSEM.random``.
    sigma_range: Tuple[float, float]
    # Candidate sample sizes, tried in the given order.
    n_values: List[int]
    # Number of independent datasets drawn per candidate sample size.
    n_trials: int
    # Minimum fraction of trials with SHD == 0 for a sample size to count
    # as sufficient.
    success_threshold: float
    # Passed as ``alpha`` to ``PCAlgorithm``.
    pc_alpha: float
    # Passed as ``sign_distribution`` to ``LinearGaussianSEM.random``.
    sign_distribution: str = 'positive'


def _process_scaling_task(task: ScalingTask) -> Optional[Dict[str, Any]]:
    """Evaluate one scaling-experiment task end to end.

    Generates the graph described by ``task``, draws a random linear
    Gaussian SEM on it, computes its Fisher dimension and searches for the
    smallest sample size that meets the success threshold.

    Args:
        task: Full description of the graph, SEM and search settings.

    Returns:
        A dict with keys 'family', 'd', 'fisher_dim' and 'n_star', or
        ``None`` when the generated graph has no edges.
    """
    # Only the Erdos-Renyi generator is handed an explicit edge probability.
    generator_kwargs = {'p': 0.2} if task.family == 'erdos_renyi' else {}
    graph = generate_dag(
        task.family, task.d, random_state=task.seed, **generator_kwargs
    )

    # Edge-free graphs are dropped from the experiment.
    if graph.num_edges() == 0:
        return None

    # The same seed drives both graph generation and SEM parameter draws.
    model = LinearGaussianSEM.random(
        graph,
        beta_range=task.beta_range,
        sigma_range=task.sigma_range,
        random_state=task.seed,
        sign_distribution=task.sign_distribution,
    )

    fisher_dim = compute_fisher_dimension(graph, model).fisher_dimension

    n_star = _find_sample_complexity_scaling(
        graph,
        model,
        task.n_values,
        task.n_trials,
        task.success_threshold,
        task.pc_alpha,
        task.seed,
    )

    return dict(
        family=task.family,
        d=task.d,
        fisher_dim=fisher_dim,
        n_star=n_star,
    )


def _find_sample_complexity_scaling(
    dag: DAG,
    sem: LinearGaussianSEM,
    n_values: List[int],
    n_trials: int,
    success_threshold: float,
    pc_alpha: float,
    base_seed: int
) -> float:
    """Return the first sample size at which PC recovers the true CPDAG.

    For each candidate in ``n_values`` (in the given order), ``n_trials``
    datasets are sampled from ``sem`` and PC is run on each. A trial is a
    success when the structural Hamming distance to the true CPDAG is 0;
    any exception raised while fitting or scoring counts as a failure.

    Args:
        dag: Ground-truth graph; its CPDAG is the recovery target.
        sem: Model used to draw the datasets.
        n_values: Candidate sample sizes, tried in order.
        n_trials: Number of datasets per candidate.
        success_threshold: Required fraction of successful trials.
        pc_alpha: ``alpha`` passed to ``PCAlgorithm``.
        base_seed: Seed of the generator that produces per-trial seeds.

    Returns:
        The first candidate meeting the threshold, as a float. When no
        candidate does, the last entry of ``n_values`` is returned.
    """
    target = CPDAG.from_dag(dag)
    learner = PCAlgorithm(alpha=pc_alpha)
    seed_stream = np.random.default_rng(base_seed)

    def _trial_recovers(sample_size: int) -> bool:
        # One seed is consumed per trial, before sampling, so the stream
        # advances identically whether or not the fit succeeds.
        trial_seed = seed_stream.integers(0, 2**31)
        data = sem.sample(sample_size, random_state=trial_seed)
        try:
            fitted = learner.fit(data)
            distance = structural_hamming_distance(target, fitted.cpdag)
            return bool(distance == 0)
        except Exception:
            # A failed fit or comparison is simply an unsuccessful trial.
            return False

    for sample_size in n_values:
        hits = sum(1 for _ in range(n_trials) if _trial_recovers(sample_size))
        if hits / n_trials >= success_threshold:
            return float(sample_size)

    # Threshold never reached: fall back to the largest candidate tried.
    return float(n_values[-1])


class Experiment4Scaling(BaseExperiment):
    """
    Experiment 4: Scaling with Graph Size.

    Setup:
    - For each graph size d in the configured ``d_values`` (e.g. {5, 10, 20, 50})
    - Generate Erdos-Renyi, chain, and tree DAGs
    - Compute average Fisher dimension and sample complexity
    - Verify n* / (F([G]) * log(d)) is approximately constant

    Expected outcome: Confirms scaling predictions of Theorems 4.1-4.2.
    """

    # A family needs at least this many graph sizes with data before its
    # scaling behaviour is analysed.
    _MIN_POINTS_FOR_ANALYSIS = 3
    # A family is "consistent" when the coefficient of variation of
    # n* / (F * log d) across graph sizes stays below this value.
    _CV_CONSISTENCY_THRESHOLD = 0.5

    @property
    def name(self) -> str:
        """Identifier of this experiment, also used as figure file prefix."""
        return "exp4_scaling"

    def run(self) -> Dict[str, Any]:
        """Execute the experiment.

        Builds one task per (family, d, graph index), evaluates all tasks
        through ``self.parallel_map``, averages Fisher dimension and sample
        complexity per (family, d), and then checks for every family
        whether n* / (F * log d) is roughly constant across d.

        Returns:
            Dict with keys 'config', 'd_values', 'families',
            'data_by_family', 'sample_complexities', 'scaling_analysis',
            'overall_consistent' and 'meets_expectation'.
        """
        exp_config = self.config.experiments.exp4_scaling

        d_values = exp_config.d_values
        graphs_per_d = exp_config.graphs_per_d
        n_values = exp_config.n_values
        n_trials = exp_config.n_trials
        success_threshold = exp_config.success_threshold

        families = ['erdos_renyi', 'chain', 'tree']  # Excluding complete - too hard for PC

        self.logger.info(f"Running Experiment 4 with d={d_values}")
        self.logger.info(f"Using {self.get_n_jobs()} parallel workers")

        results = {
            'config': {
                'd_values': d_values,
                'graphs_per_d': graphs_per_d,
                'n_values': n_values,
                'n_trials': n_trials,
                'success_threshold': success_threshold,
                'families': families,
            },
            'd_values': d_values,
            'families': families,
            'data_by_family': {family: {'d': [], 'fisher_dims': [], 'sample_complexities': []}
                              for family in families},
            'sample_complexities': {},
            'scaling_analysis': {},
        }

        tasks = self._build_tasks(
            families, d_values, n_values, graphs_per_d, n_trials, success_threshold
        )

        # Process all tasks in parallel
        task_results = self.parallel_map(
            _process_scaling_task,
            tasks,
            desc="Processing graphs"
        )

        # Collect raw per-graph values by (family, d); tasks that returned
        # None (edge-free graphs) contribute nothing.
        raw_data = {family: {d: {'fisher_dims': [], 'n_stars': []} for d in d_values}
                    for family in families}

        for result in task_results:
            if result is None:
                continue
            cell = raw_data[result['family']][result['d']]
            cell['fisher_dims'].append(result['fisher_dim'])
            cell['n_stars'].append(result['n_star'])

        # Average per (family, d), keeping only graph sizes that have data.
        for family in families:
            family_d = []
            family_fisher = []
            family_n_star = []

            for d in d_values:
                cell = raw_data[family][d]
                if cell['fisher_dims']:
                    family_d.append(d)
                    family_fisher.append(np.mean(cell['fisher_dims']))
                    family_n_star.append(np.mean(cell['n_stars']))

            results['data_by_family'][family] = {
                'd': family_d,
                'fisher_dims': family_fisher,
                'sample_complexities': family_n_star,
            }

            if family_d:
                results['sample_complexities'][family] = np.array(family_n_star)

        # Analyze scaling for every family with enough graph sizes.
        for family in families:
            data = results['data_by_family'][family]
            if len(data['d']) < self._MIN_POINTS_FOR_ANALYSIS:
                continue
            results['scaling_analysis'][family] = self._analyze_family_scaling(family, data)

        consistent_count = sum(
            1 for analysis in results['scaling_analysis'].values()
            if analysis.get('is_consistent', False)
        )
        # "At least half of the families". Written without floor division:
        # ``len(families) // 2`` would let 1 of 3 families pass, and would
        # accept zero consistent families when only one family is run.
        results['overall_consistent'] = 2 * consistent_count >= len(families)
        results['meets_expectation'] = results['overall_consistent']

        self.logger.info(f"Scaling analysis: {consistent_count}/{len(families)} families consistent")
        for family, analysis in results['scaling_analysis'].items():
            self.logger.info(f"  {family}: CV={analysis['coefficient_of_variation']:.3f}, "
                           f"consistent={analysis['is_consistent']}")

        return results

    def _build_tasks(
        self,
        families: List[str],
        d_values: List[int],
        n_values: List[int],
        graphs_per_d: int,
        n_trials: int,
        success_threshold: float,
    ) -> List[ScalingTask]:
        """Create one ``ScalingTask`` per (family, d, graph index).

        Seeds are drawn from ``self.rng`` in family -> d -> graph order, so
        the task list is reproducible for a given generator state.
        """
        beta_range = (self.config.sem.beta_min, self.config.sem.beta_max)
        sigma_range = (self.config.sem.sigma_min, self.config.sem.sigma_max)

        tasks = []
        for family in families:
            for d in d_values:
                # Larger graphs get a proportionally larger sample-size grid.
                adjusted_n = self._get_adjusted_n_values(n_values, d)
                for _ in range(graphs_per_d):
                    tasks.append(ScalingTask(
                        family=family,
                        d=d,
                        seed=self.rng.integers(0, 2**31),
                        beta_range=beta_range,
                        sigma_range=sigma_range,
                        n_values=adjusted_n,
                        n_trials=n_trials,
                        success_threshold=success_threshold,
                        pc_alpha=self.config.pc.alpha,
                        sign_distribution=self.config.sem.sign_distribution,
                    ))
        return tasks

    def _analyze_family_scaling(self, family: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Summarise how well n* ~ C * F * log(d) describes one family.

        Args:
            family: Family name, used only in log messages.
            data: Entry of ``results['data_by_family']`` with parallel lists
                under 'd', 'fisher_dims' and 'sample_complexities'. The
                caller guarantees at least three points.

        Returns:
            Dict with the fitted constant, mean/std/CV of the normalized
            ratio, the Pearson correlation of n* with log(d), and the
            consistency flag.
        """
        d_arr = np.array(data['d'])
        f_arr = np.array(data['fisher_dims'])
        n_arr = np.array(data['sample_complexities'])

        log_d = np.log(d_arr)
        # Small epsilon guards against division by zero (e.g. log(1) == 0).
        normalized = n_arr / (f_arr * log_d + 1e-10)

        def scaling_model(x, C):
            d, f = x
            return C * f * np.log(d)

        try:
            # curve_fit accepts a (k, M) array for k predictors, so d and F
            # are stacked as two rows instead of being concatenated.
            popt, _ = curve_fit(
                scaling_model,
                np.vstack([d_arr, f_arr]),
                n_arr,
                p0=[1.0]
            )
            fitted_C = popt[0]
        except Exception as exc:
            # Best effort: keep the experiment running, but make the
            # fallback visible instead of swallowing the failure silently.
            self.logger.warning(
                f"curve_fit failed for {family} ({exc}); "
                f"using mean normalized ratio as fitted_C"
            )
            fitted_C = np.mean(normalized)

        normalized_std = np.std(normalized)
        normalized_mean = np.mean(normalized)
        cv = normalized_std / normalized_mean if normalized_mean > 0 else float('inf')

        corr_log_d, _ = pearsonr(log_d, n_arr)

        return {
            'fitted_C': fitted_C,
            'normalized_mean': normalized_mean,
            'normalized_std': normalized_std,
            'coefficient_of_variation': cv,
            'correlation_with_log_d': corr_log_d,
            # Plain bool rather than numpy.bool_.
            'is_consistent': bool(cv < self._CV_CONSISTENCY_THRESHOLD),
        }

    def _get_adjusted_n_values(self, base_n_values: List[int], d: int) -> List[int]:
        """Adjust sample size range based on graph size.

        Each base value is multiplied by ``max(1, d / 10)`` and truncated to
        an int; duplicates are removed and the result is sorted ascending.
        """
        scale_factor = max(1, d / 10)
        adjusted = {int(n * scale_factor) for n in base_n_values}
        return sorted(adjusted)

    def plot_results(self, output_dir: Optional[str] = None) -> None:
        """Generate experiment figures.

        Writes three figures (raw scaling, normalized sample complexity,
        Fisher dimension vs. d) via ``save_figure``. Does nothing except
        log a warning when ``self.results`` is empty.

        Args:
            output_dir: Target directory; defaults to
                ``self.config.output.figures_dir``.
        """
        from ..utils.visualization import plot_scaling, save_figure
        from pathlib import Path
        import matplotlib.pyplot as plt

        if not self.results:
            self.logger.warning("No results to plot")
            return

        output_dir = Path(output_dir or self.config.output.figures_dir)
        families = self.results['families']

        # Plot 1: Sample complexity scaling
        # NOTE(review): each 'sample_complexities' array only covers graph
        # sizes that produced data, so it can be shorter than 'd_values';
        # confirm plot_scaling tolerates that.
        fig1, ax1 = plt.subplots(figsize=(10, 8))
        plot_scaling(
            self.results['d_values'],
            self.results['sample_complexities'],
            ax=ax1
        )
        save_figure(fig1, output_dir / f"{self.name}_scaling")

        # Plot 2: Normalized sample complexity
        fig2, ax2 = plt.subplots(figsize=(10, 8))

        colors = plt.cm.tab10(np.linspace(0, 1, len(families)))

        for i, family in enumerate(families):
            data = self.results['data_by_family'][family]
            if not data['d']:
                continue

            d_arr = np.array(data['d'])
            f_arr = np.array(data['fisher_dims'])
            n_arr = np.array(data['sample_complexities'])

            normalized = n_arr / (f_arr * np.log(d_arr) + 1e-10)
            ax2.plot(d_arr, normalized, 'o-', color=colors[i], label=family, markersize=8)

        ax2.set_xlabel('Number of nodes $d$')
        ax2.set_ylabel(r'$n^* / (\mathcal{F}([G]) \cdot \log d)$')
        ax2.set_title('Normalized Sample Complexity (should be ~constant)')
        # Draw the reference line before building the legend; a legend only
        # lists artists that exist when legend() is called.
        ax2.axhline(y=1.0, color='k', linestyle='--', alpha=0.5, label='Reference')
        ax2.legend()

        save_figure(fig2, output_dir / f"{self.name}_normalized")

        # Plot 3: Fisher dimension scaling
        fig3, ax3 = plt.subplots(figsize=(10, 8))

        for i, family in enumerate(families):
            data = self.results['data_by_family'][family]
            if not data['d']:
                continue

            ax3.plot(data['d'], data['fisher_dims'], 'o-',
                    color=colors[i], label=family, markersize=8)

        ax3.set_xlabel('Number of nodes $d$')
        ax3.set_ylabel(r'Fisher Dimension $\mathcal{F}([G])$')
        ax3.set_title('Fisher Dimension Scaling with Graph Size')
        ax3.legend()

        save_figure(fig3, output_dir / f"{self.name}_fisher_scaling")

        self.logger.info(f"Saved figures to {output_dir}")
