"""
Experiment 5: Information-Theoretic Lower Bound Verification.

Empirically verifies the lower bound from Theorem 4.2.
"""

from __future__ import annotations

from typing import Dict, Any, Optional, List, Tuple
from dataclasses import dataclass
import numpy as np
from scipy.stats import norm

from .base import BaseExperiment
from ..core.dag import DAG
from ..core.sem import LinearGaussianSEM
from ..core.mec import CPDAG
from ..core.fisher_dimension import compute_fisher_dimension
from ..core.partial_correlation import partial_correlation
from ..generators.factory import generate_dag
from ..algorithms.pc import PCAlgorithm
from ..metrics.shd import structural_hamming_distance


@dataclass
class LowerBoundTask:
    """Work item describing one candidate DAG pair.

    Instances are built in ``Experiment5LowerBound.run`` and consumed by
    ``_process_lower_bound_task``.
    """
    # Size argument handed to the DAG generator and the DAG constructor.
    d: int
    # Seed used both for constructing the pair and for the power simulation.
    seed: int
    # Candidate sample sizes to evaluate.
    n_values: List[int]
    # Number of simulated datasets drawn per candidate sample size.
    n_trials: int
    # Detection rate a sample size must reach to count as sufficient.
    target_power: float


def _process_lower_bound_task(task: LowerBoundTask) -> Optional[Dict[str, Any]]:
    """Evaluate one hard DAG pair against the theoretical lower bound.

    Builds the pair from ``task.d`` and ``task.seed``, computes the Fisher
    dimension of the model that contains the target edge, and searches
    ``task.n_values`` for a sample size whose estimated test power reaches
    ``task.target_power``.

    Args:
        task: Parameters for this pair.

    Returns:
        ``None`` when no pair could be constructed. Otherwise a dict with the
        keys ``edge``, ``rho``, ``num_edges_g0``, ``fisher_dim``,
        ``theoretical_lower`` and ``empirical_n``.
    """
    pair_result = _construct_hard_pair(task.d, task.seed)
    if pair_result is None:
        return None

    # Only the graph/SEM containing the target edge is needed below; the
    # edge-free counterparts are returned by the constructor but unused here.
    dag0, _dag1, sem0, _sem1, edge_info = pair_result

    fisher_dim = compute_fisher_dimension(dag0, sem0).fisher_dimension

    # Theoretical lower bound: a fixed constant times the Fisher dimension.
    c_lower = 2.0
    theoretical_lower = c_lower * fisher_dim

    # Empirical sample complexity from the simulated power curve.
    empirical_n = _find_distinguishing_sample_size(
        sem0, edge_info, task.n_values, task.n_trials, task.target_power, task.seed
    )

    return {
        'edge': edge_info['edge'],
        'rho': edge_info['partial_correlation'],
        'num_edges_g0': dag0.num_edges(),
        'fisher_dim': fisher_dim,
        'theoretical_lower': theoretical_lower,
        'empirical_n': empirical_n,
    }


def _construct_hard_pair(d: int, seed: int) -> Optional[Tuple]:
    """Build two DAGs that differ in exactly one edge, plus matching SEMs.

    A sparse Erdos-Renyi DAG (``p=0.2``) is generated from ``seed``; one of
    its edges is picked at random as the target. ``G0`` keeps every edge,
    ``G1`` drops the target. The target edge gets a coefficient of magnitude
    0.1, every other edge a magnitude drawn uniformly from [0.3, 0.7), each
    with a random sign. All noise scales are 1.0, and ``G1``'s SEM reuses
    ``G0``'s coefficients on the shared edges.

    Returns:
        ``None`` if the generated DAG has fewer than two edges or the two
        CPDAGs compare equal. Otherwise the tuple
        ``(dag0, dag1, sem0, sem1, edge_info)`` where ``edge_info`` holds the
        keys ``edge``, ``partial_correlation``, ``beta`` and
        ``conditioning_set``.
    """
    generator = np.random.default_rng(seed)

    # Common starting point for both graphs
    base_dag = generate_dag('erdos_renyi', d, random_state=seed, p=0.2)
    if base_dag.num_edges() < 2:
        return None

    all_edges = list(base_dag.edges)
    target_edge = all_edges[generator.integers(len(all_edges))]
    parent, child = target_edge

    # G0 carries the target edge, G1 does not
    reduced_edges = [e for e in all_edges if e != target_edge]
    dag0 = DAG(d, all_edges)
    dag1 = DAG(d, reduced_edges)

    # Reject pairs whose Markov equivalence classes coincide
    if CPDAG.from_dag(dag0) == CPDAG.from_dag(dag1):
        return None

    # Weak coefficient on the target edge, moderate ones elsewhere.
    # Draw order per edge (magnitude first, then sign) is kept fixed so the
    # random stream is consumed identically for a given seed.
    small_beta = 0.1
    betas0 = {}
    for e in all_edges:
        magnitude = small_beta if e == target_edge else generator.uniform(0.3, 0.7)
        sign = generator.choice([-1, 1])
        betas0[e] = magnitude * sign

    sigmas0 = dict.fromkeys(range(d), 1.0)
    sem0 = LinearGaussianSEM(dag0, betas0, sigmas0)

    betas1 = {e: betas0[e] for e in reduced_edges}
    sem1 = LinearGaussianSEM(dag1, betas1, dict(sigmas0))

    # Partial correlation of the target pair given the child's other parents
    S = dag0.parents(child) - {parent}
    rho = partial_correlation(sem0.covariance_matrix(), parent, child, S)

    edge_info = {
        'edge': target_edge,
        'partial_correlation': rho,
        'beta': betas0[target_edge],
        'conditioning_set': list(S),
    }
    return dag0, dag1, sem0, sem1, edge_info


def _find_distinguishing_sample_size(
    sem0: LinearGaussianSEM,
    edge_info: Dict,
    n_values: List[int],
    n_trials: int,
    target_power: float,
    base_seed: int
) -> float:
    """Find minimum sample size to distinguish SEMs."""
    parent, child = edge_info['edge']
    S = set(edge_info['conditioning_set'])
    rng = np.random.default_rng(base_seed)

    for n in n_values:
        power = _estimate_test_power(sem0, parent, child, S, n, n_trials, rng)
        if power >= target_power:
            return float(n)

    return float(n_values[-1])


def _estimate_test_power(
    sem0: LinearGaussianSEM,
    i: int,
    j: int,
    S: set,
    n: int,
    n_trials: int,
    rng: np.random.Generator
) -> float:
    """Estimate test power for distinguishing edge presence."""
    alpha = 0.05
    correct_detections = 0

    for trial in range(n_trials):
        seed = rng.integers(0, 2**31)
        X = sem0.sample(n, random_state=seed)

        sample_cov = np.cov(X, rowvar=False)

        try:
            rho_hat = partial_correlation(sample_cov, i, j, S)
            z = 0.5 * np.log((1 + rho_hat) / (1 - rho_hat + 1e-10))
            z_se = 1.0 / np.sqrt(n - len(S) - 3) if n > len(S) + 3 else 1.0
            z_stat = abs(z) / z_se
            p_value = 2 * (1 - norm.cdf(z_stat))

            if p_value < alpha:
                correct_detections += 1
        except Exception:
            continue

    return correct_detections / n_trials


class Experiment5LowerBound(BaseExperiment):
    """
    Experiment 5: Information-Theoretic Lower Bound Verification.

    Setup:
    - Construct pairs of DAGs with different MECs that differ by one edge
    - Give the differing edge a small coefficient so the two models are
      hard to tell apart
    - Estimate sample complexity required to distinguish

    Expected outcome: Empirical complexity exceeds the lower bound.
    """

    @property
    def name(self) -> str:
        """Identifier of this experiment; also used as the figure file prefix."""
        return "exp5_lower_bound"

    def run(self) -> Dict[str, Any]:
        """Execute the experiment.

        Draws one seed per pair from ``self.rng``, processes all pairs via
        ``self.parallel_map`` and summarises how the empirical sample sizes
        compare with the theoretical lower bounds.

        Returns:
            Dict with the keys ``config``, ``pairs``, ``fisher_dims``,
            ``theoretical_lower``, ``empirical_n``, ``analysis`` and
            ``meets_expectation``. Pairs that could not be constructed are
            omitted from the per-pair lists.
        """
        exp_config = self.config.experiments.exp5_lower_bound

        d = exp_config.d
        num_pairs = exp_config.num_pairs
        n_values = exp_config.n_values
        n_trials = exp_config.n_trials
        target_power = exp_config.target_power

        self.logger.info(f"Running Experiment 5 with d={d}, {num_pairs} DAG pairs")
        self.logger.info(f"Using {self.get_n_jobs()} parallel workers")

        results = {
            'config': {
                'd': d,
                'num_pairs': num_pairs,
                'n_values': n_values,
                'n_trials': n_trials,
                'target_power': target_power,
            },
            'pairs': [],
            'fisher_dims': [],
            'theoretical_lower': [],
            'empirical_n': [],
        }

        # One task per pair, each with its own seed drawn in sequence
        tasks = [
            LowerBoundTask(
                d=d,
                seed=self.rng.integers(0, 2**31),
                n_values=n_values,
                n_trials=n_trials,
                target_power=target_power,
            )
            for _ in range(num_pairs)
        ]

        # Process all tasks in parallel
        task_results = self.parallel_map(
            _process_lower_bound_task,
            tasks,
            desc="Processing DAG pairs"
        )

        # Collect results, skipping pairs that could not be constructed
        for result in task_results:
            if result is None:
                continue

            results['pairs'].append({
                'edge': result['edge'],
                'rho': result['rho'],
                'num_edges_g0': result['num_edges_g0'],
            })
            results['fisher_dims'].append(result['fisher_dim'])
            results['theoretical_lower'].append(result['theoretical_lower'])
            results['empirical_n'].append(result['empirical_n'])

        # Analysis
        fisher_dims = np.array(results['fisher_dims'])
        theoretical_lower = np.array(results['theoretical_lower'])
        empirical_n = np.array(results['empirical_n'])

        valid = np.isfinite(fisher_dims) & np.isfinite(empirical_n)
        if np.sum(valid) > 0:
            exceeds_lower = empirical_n[valid] >= theoretical_lower[valid]
            # The small offset guards against division by a zero bound
            ratios = empirical_n[valid] / (theoretical_lower[valid] + 1e-10)
            # Plain floats keep both branches of this `if` type-consistent
            fraction_exceeds = float(np.mean(exceeds_lower))
            mean_ratio = float(np.mean(ratios))
            min_ratio = float(np.min(ratios))
        else:
            fraction_exceeds = float('nan')
            mean_ratio = float('nan')
            min_ratio = float('nan')

        results['analysis'] = {
            'fraction_exceeds_lower_bound': fraction_exceeds,
            'mean_ratio_to_lower_bound': mean_ratio,
            'min_ratio_to_lower_bound': min_ratio,
        }

        # A NaN fraction compares False, i.e. "expectation not met"
        results['meets_expectation'] = bool(
            fraction_exceeds >= exp_config.expected_fraction_exceeding
        )

        self.logger.info(f"Fraction exceeding lower bound: {fraction_exceeds:.2%}")
        self.logger.info(f"Mean ratio to lower bound: {mean_ratio:.2f}")
        self.logger.info(f"Meets expectation: {results['meets_expectation']}")

        return results

    def plot_results(self, output_dir: Optional[str] = None) -> None:
        """Generate experiment figures.

        Args:
            output_dir: Directory for the figure; defaults to
                ``self.config.output.figures_dir``.

        Logs a warning and returns without plotting when ``self.results`` is
        empty.
        """
        from ..utils.visualization import plot_lower_bound_verification, save_figure
        from pathlib import Path
        import matplotlib.pyplot as plt

        if not self.results:
            self.logger.warning("No results to plot")
            return

        output_dir = Path(output_dir or self.config.output.figures_dir)

        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            fisher_dims = np.array(self.results['fisher_dims'])
            empirical_n = np.array(self.results['empirical_n'])
            theoretical_lower = np.array(self.results['theoretical_lower'])

            plot_lower_bound_verification(
                fisher_dims, empirical_n, theoretical_lower, ax=ax
            )

            analysis = self.results.get('analysis', {})
            ax.annotate(
                f"Exceeds lower: {analysis.get('fraction_exceeds_lower_bound', 0):.1%}\n"
                f"Mean ratio: {analysis.get('mean_ratio_to_lower_bound', 0):.2f}x",
                xy=(0.05, 0.95), xycoords='axes fraction',
                verticalalignment='top',
                fontsize=10,
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
            )

            save_figure(fig, output_dir / f"{self.name}_verification")
        finally:
            # pyplot keeps every figure alive until it is closed explicitly
            plt.close(fig)

        self.logger.info(f"Saved figure to {output_dir}")
