"""
Linear Gaussian Structural Equation Model (SEM) implementation.

This module provides the LinearGaussianSEM class for generating data
from linear Gaussian structural equation models and computing
their covariance matrices analytically.
"""

from __future__ import annotations

from typing import Dict, Tuple, Optional, List
import numpy as np
from scipy import linalg

from .dag import DAG


class LinearGaussianSEM:
    """
    Linear Gaussian Structural Equation Model.

    Represents the model:
        X_i = Σ_{j ∈ Pa(i)} β_ji * X_j + ε_i
    where ε_i ~ N(0, σ_i²) are independent noise terms.

    The model has no intercept terms, so every X_i has mean zero.

    Attributes:
        dag: The underlying causal DAG
        coefficients: Edge coefficients β_ji for each edge (j, i)
        noise_variances: Noise variance σ_i² for each node

    The coefficient, covariance and precision matrices are computed lazily
    and cached. Their accessors always return copies, so callers cannot
    corrupt the cache.
    """

    def __init__(
        self,
        dag: DAG,
        coefficients: Dict[Tuple[int, int], float],
        noise_variances: np.ndarray
    ):
        """
        Initialize a Linear Gaussian SEM.

        Args:
            dag: The causal DAG structure
            coefficients: Dict mapping (parent, child) to coefficient β_parent,child
            noise_variances: Array-like of noise variances σ_i², one per node

        Raises:
            ValueError: If coefficients don't match edges or variances invalid
        """
        self._dag = dag
        # Defensive copies: later mutation of the caller's objects must not
        # change this model (the cached matrices below would go stale).
        self._coefficients = dict(coefficients)
        self._noise_variances = np.array(noise_variances, dtype=np.float64)

        # Validation
        self._validate()

        # Lazily computed caches (filled by the corresponding accessors)
        self._covariance_matrix: Optional[np.ndarray] = None
        self._precision_matrix: Optional[np.ndarray] = None
        self._coefficient_matrix: Optional[np.ndarray] = None

    def _validate(self) -> None:
        """
        Validate the SEM parameters.

        Raises:
            ValueError: If the noise variances are not a 1-D array with one
                finite, strictly positive entry per node, if a DAG edge has
                no coefficient, if a coefficient refers to a non-existent
                edge, or if a coefficient is not finite.
        """
        d = self._dag.num_nodes()
        variances = self._noise_variances

        # Check noise variances
        if variances.ndim != 1:
            raise ValueError(
                f"Noise variances must be a 1-D array, got shape {variances.shape}"
            )
        if len(variances) != d:
            raise ValueError(
                f"Expected {d} noise variances, got {len(variances)}"
            )
        # NaN compares False against everything, so the `<= 0` test alone
        # would let NaN (and +inf) through.
        if not np.all(np.isfinite(variances)):
            raise ValueError("All noise variances must be finite")
        if np.any(variances <= 0):
            raise ValueError("All noise variances must be positive")

        # Check coefficients match edges. The edge set gives O(1) membership
        # tests however the DAG stores its edges; the list keeps error
        # reporting in the DAG's own edge order.
        edge_list = list(self._dag.edges)
        edge_set = set(edge_list)
        for edge in edge_list:
            if edge not in self._coefficients:
                raise ValueError(f"Missing coefficient for edge {edge}")

        for edge, coef in self._coefficients.items():
            if edge not in edge_set:
                raise ValueError(f"Coefficient for non-existent edge {edge}")
            if not np.isfinite(coef):
                raise ValueError(f"Coefficient for edge {edge} must be finite")

    @property
    def dag(self) -> DAG:
        """Return the underlying DAG."""
        return self._dag

    @property
    def coefficients(self) -> Dict[Tuple[int, int], float]:
        """Return a copy of the coefficients dictionary."""
        return dict(self._coefficients)

    @property
    def noise_variances(self) -> np.ndarray:
        """Return a copy of the noise variances array."""
        return self._noise_variances.copy()

    def num_nodes(self) -> int:
        """Return the number of nodes."""
        return self._dag.num_nodes()

    def get_coefficient(self, parent: int, child: int) -> float:
        """
        Get the coefficient for a specific edge.

        Args:
            parent: Parent node
            child: Child node

        Returns:
            The coefficient β_parent,child

        Raises:
            KeyError: If edge doesn't exist
        """
        edge = (parent, child)
        if edge not in self._coefficients:
            raise KeyError(f"Edge {parent} -> {child} does not exist")
        return self._coefficients[edge]

    def coefficient_matrix(self) -> np.ndarray:
        """
        Return the coefficient matrix B.

        B[j, i] = β_ji if edge j → i exists, else 0.
        So X = B^T X + ε, or equivalently X = (I - B^T)^{-1} ε

        Returns:
            d×d coefficient matrix (a copy of the cached matrix)
        """
        if self._coefficient_matrix is not None:
            return self._coefficient_matrix.copy()

        d = self._dag.num_nodes()
        B = np.zeros((d, d), dtype=np.float64)

        for (parent, child), coef in self._coefficients.items():
            B[parent, child] = coef

        self._coefficient_matrix = B
        return B.copy()

    def covariance_matrix(self) -> np.ndarray:
        """
        Compute the covariance matrix analytically.

        For linear SEM: X = (I - B^T)^{-1} ε
        where B[j,i] = β_ji.

        Cov(X) = (I - B^T)^{-1} D (I - B^T)^{-T}
        where D = diag(σ_1², ..., σ_d²)

        Returns:
            d×d covariance matrix Σ (a copy of the cached matrix)
        """
        if self._covariance_matrix is not None:
            return self._covariance_matrix.copy()

        d = self._dag.num_nodes()
        B = self.coefficient_matrix()
        D = np.diag(self._noise_variances)

        # (I - B^T)^{-1}; invertible whenever the graph is acyclic, since B
        # can then be permuted to strictly triangular form.
        I_minus_BT = np.eye(d) - B.T
        I_minus_BT_inv = linalg.inv(I_minus_BT)

        # Σ = (I - B^T)^{-1} D (I - B^T)^{-T}
        Sigma = I_minus_BT_inv @ D @ I_minus_BT_inv.T

        # Ensure exact symmetry despite floating-point round-off
        Sigma = (Sigma + Sigma.T) / 2

        self._covariance_matrix = Sigma
        return Sigma.copy()

    def precision_matrix(self) -> np.ndarray:
        """
        Compute the precision matrix (inverse covariance).

        Θ = Σ^{-1} = (I - B^T)^T D^{-1} (I - B^T)

        Returns:
            d×d precision matrix Θ (a copy of the cached matrix)
        """
        if self._precision_matrix is not None:
            return self._precision_matrix.copy()

        d = self._dag.num_nodes()
        B = self.coefficient_matrix()
        D_inv = np.diag(1.0 / self._noise_variances)

        I_minus_BT = np.eye(d) - B.T

        # Θ = (I - B^T)^T D^{-1} (I - B^T): closed form, no inversion needed
        Theta = I_minus_BT.T @ D_inv @ I_minus_BT

        # Ensure exact symmetry despite floating-point round-off
        Theta = (Theta + Theta.T) / 2

        self._precision_matrix = Theta
        return Theta.copy()

    def sample(
        self,
        n: int,
        random_state: Optional[int] = None
    ) -> np.ndarray:
        """
        Generate n i.i.d. samples from the SEM.

        Uses ancestral sampling in topological order:
        For each node i in topological order:
            X_i = Σ_{j ∈ Pa(i)} β_ji * X_j + ε_i

        Args:
            n: Number of samples to generate
            random_state: Random seed for reproducibility

        Returns:
            n×d data matrix where each row is a sample

        Raises:
            ValueError: If n < 1
        """
        if n < 1:
            raise ValueError(f"Number of samples must be positive, got {n}")

        rng = np.random.default_rng(random_state)
        d = self._dag.num_nodes()

        # Initialize data matrix
        X = np.zeros((n, d), dtype=np.float64)

        # Sample standard normal noise, then scale each column by its std dev
        noise = rng.normal(0, 1, size=(n, d))
        noise = noise * np.sqrt(self._noise_variances)

        # Topological order guarantees every parent column is filled in
        # before its children read it.
        topo_order = self._dag.topological_sort()

        for i in topo_order:
            # Start with noise
            X[:, i] = noise[:, i]

            # Add weighted parent contributions
            for parent in self._dag.parents(i):
                coef = self._coefficients[(parent, i)]
                X[:, i] += coef * X[:, parent]

        return X

    def min_edge_coefficient(self) -> float:
        """Return the minimum absolute edge coefficient (ε); inf if there are no edges."""
        if not self._coefficients:
            return float('inf')
        return min(abs(c) for c in self._coefficients.values())

    def max_edge_coefficient(self) -> float:
        """Return the maximum absolute edge coefficient (B); 0.0 if there are no edges."""
        if not self._coefficients:
            return 0.0
        return max(abs(c) for c in self._coefficients.values())

    def min_noise_variance(self) -> float:
        """Return the minimum noise variance (σ_min²)."""
        return float(np.min(self._noise_variances))

    def max_noise_variance(self) -> float:
        """Return the maximum noise variance (σ_max²)."""
        return float(np.max(self._noise_variances))

    def is_well_conditioned(
        self,
        beta_min: float = 0.1,
        sigma_min: float = 0.1,
        max_eigenvalue: float = 100.0
    ) -> bool:
        """
        Check if the SEM is spectrally well-conditioned.

        Per Proposition 7.1, a well-conditioned SEM has:
        - |β_ij| ≥ beta_min for all edges
        - σ_i² ≥ sigma_min for all nodes
        - λ_max(Σ) ≤ max_eigenvalue

        Args:
            beta_min: Minimum edge coefficient threshold
            sigma_min: Minimum noise variance threshold (compared against σ_i²)
            max_eigenvalue: Maximum covariance eigenvalue threshold

        Returns:
            True if all conditions are satisfied
        """
        if self.min_edge_coefficient() < beta_min:
            return False
        if self.min_noise_variance() < sigma_min:
            return False

        # Σ is symmetric, so the symmetric eigenvalue solver applies
        Sigma = self.covariance_matrix()
        eigenvalues = linalg.eigvalsh(Sigma)
        if np.max(eigenvalues) > max_eigenvalue:
            return False

        return True

    def conditional_variance(self, node: int, conditioning_set: set) -> float:
        """
        Compute Var(X_node | X_S) where S is the conditioning set.

        Uses the formula:
        Var(X_i | X_S) = Σ_ii - Σ_{i,S} Σ_{S,S}^{-1} Σ_{S,i}

        Args:
            node: Node to compute conditional variance for
            conditioning_set: Set of nodes to condition on

        Returns:
            Conditional variance as a Python float, clipped at 0 to absorb
            round-off
        """
        Sigma = self.covariance_matrix()

        if not conditioning_set:
            return float(Sigma[node, node])

        S = sorted(conditioning_set)

        # Extract submatrices
        Sigma_ii = Sigma[node, node]
        Sigma_iS = Sigma[node, S]
        Sigma_SS = Sigma[np.ix_(S, S)]
        Sigma_Si = Sigma[S, node]

        # Solve Σ_SS w = Σ_Si instead of forming Σ_SS^{-1} explicitly; this
        # is cheaper and numerically more stable. Σ_SS is a principal
        # submatrix of a covariance, so a positive-definite solver is used.
        try:
            weights = linalg.solve(Sigma_SS, Sigma_Si, assume_a='pos')
        except linalg.LinAlgError:
            # Fall back to the pseudo-inverse if Σ_SS is (numerically) singular
            weights = linalg.pinv(Sigma_SS) @ Sigma_Si

        cond_var = float(Sigma_ii - Sigma_iS @ weights)
        return max(0.0, cond_var)  # Ensure non-negative

    @classmethod
    def random(
        cls,
        dag: DAG,
        beta_range: Tuple[float, float] = (0.3, 0.6),
        sigma_range: Tuple[float, float] = (1.0, 1.0),
        random_state: Optional[int] = None,
        sign_distribution: str = 'random'
    ) -> 'LinearGaussianSEM':
        """
        Create a random Linear Gaussian SEM for a given DAG.

        Args:
            dag: The causal DAG structure
            beta_range: (min, max) for edge coefficient magnitudes
            sigma_range: (min, max) for noise variances
            random_state: Random seed for reproducibility
            sign_distribution: 'random', 'positive', or 'negative' for coefficient signs

        Returns:
            A new LinearGaussianSEM with random parameters

        Raises:
            ValueError: If sign_distribution is not one of the accepted values
                (checked up front, so it is reported even for edgeless DAGs)
        """
        if sign_distribution not in ('random', 'positive', 'negative'):
            raise ValueError(f"Unknown sign_distribution: {sign_distribution}")

        rng = np.random.default_rng(random_state)
        d = dag.num_nodes()

        # Generate edge coefficients. The order of RNG draws (magnitude, then
        # sign) is part of the seeded-reproducibility contract.
        coefficients = {}
        beta_min, beta_max = beta_range

        for edge in dag.edges:
            magnitude = rng.uniform(beta_min, beta_max)

            if sign_distribution == 'random':
                sign = rng.choice([-1, 1])
            elif sign_distribution == 'positive':
                sign = 1
            else:  # 'negative' (validated above)
                sign = -1

            coefficients[edge] = sign * magnitude

        # Generate noise variances; a degenerate range consumes no randomness
        sigma_min, sigma_max = sigma_range
        if sigma_min == sigma_max:
            noise_variances = np.full(d, sigma_min)
        else:
            noise_variances = rng.uniform(sigma_min, sigma_max, size=d)

        return cls(dag, coefficients, noise_variances)

    @classmethod
    def from_matrices(
        cls,
        coefficient_matrix: np.ndarray,
        noise_variances: np.ndarray
    ) -> 'LinearGaussianSEM':
        """
        Create a LinearGaussianSEM from a coefficient matrix.

        Every non-zero entry B[j, i] becomes an edge j → i with coefficient
        B[j, i].

        Args:
            coefficient_matrix: d×d array-like B where B[j,i] = β_ji
            noise_variances: Array of noise variances

        Returns:
            A new LinearGaussianSEM

        Raises:
            ValueError: If the coefficient matrix is not square, or if the
                resulting parameters fail validation
        """
        B = np.asarray(coefficient_matrix, dtype=np.float64)
        if B.ndim != 2 or B.shape[0] != B.shape[1]:
            raise ValueError(
                f"Coefficient matrix must be square, got shape {B.shape}"
            )
        d = B.shape[0]

        # Extract edges from non-zero entries. np.argwhere yields indices in
        # row-major order, i.e. the same (parent, child) order as a nested
        # loop; convert to plain ints so edges are ordinary tuples.
        edges = [(int(j), int(i)) for j, i in np.argwhere(B != 0)]
        coefficients = {(j, i): float(B[j, i]) for j, i in edges}

        dag = DAG(d, edges)
        return cls(dag, coefficients, noise_variances)

    def copy(self) -> 'LinearGaussianSEM':
        """Create a deep copy of this SEM (DAG, coefficients and variances)."""
        return LinearGaussianSEM(
            self._dag.copy(),
            dict(self._coefficients),
            self._noise_variances.copy()
        )

    def with_modified_coefficient(
        self,
        parent: int,
        child: int,
        new_coefficient: float
    ) -> 'LinearGaussianSEM':
        """
        Create a copy with a modified edge coefficient.

        Args:
            parent: Parent node
            child: Child node
            new_coefficient: New coefficient value

        Returns:
            New SEM with modified coefficient

        Raises:
            KeyError: If edge doesn't exist
        """
        new_coefficients = dict(self._coefficients)
        edge = (parent, child)

        if edge not in new_coefficients:
            raise KeyError(f"Edge {parent} -> {child} does not exist")

        new_coefficients[edge] = new_coefficient

        return LinearGaussianSEM(
            self._dag.copy(),
            new_coefficients,
            self._noise_variances.copy()
        )

    def interventional_distribution(
        self,
        intervened_nodes: Dict[int, float]
    ) -> 'LinearGaussianSEM':
        """
        Create the interventional distribution do(X_i = v_i).

        This creates a new SEM where intervened nodes have their incoming
        edges removed and their noise variance set to a near-zero value
        (1e-10), approximating a point mass.

        NOTE: the intervention values v_i are currently not used. The model
        has no mean/intercept terms, so only the structural part of the
        intervention (graph surgery and near-zero variance) is represented.

        Args:
            intervened_nodes: Dict mapping node to intervention value

        Returns:
            New SEM representing the interventional distribution

        Raises:
            IndexError: If an intervened node is outside [0, num_nodes)
        """
        d = self._dag.num_nodes()

        # Reject out-of-range nodes up front. A negative index would otherwise
        # wrap around in the variance array and silently alter a different
        # node while leaving that node's incoming edges in place.
        for node in intervened_nodes:
            if not 0 <= node < d:
                raise IndexError(
                    f"Intervened node {node} out of range for {d} nodes"
                )

        # Remove incoming edges to intervened nodes
        new_edges = [
            e for e in self._dag.edges
            if e[1] not in intervened_nodes
        ]

        new_dag = DAG(d, new_edges)

        # Copy coefficients for remaining edges
        new_coefficients = {
            e: c for e, c in self._coefficients.items()
            if e[1] not in intervened_nodes
        }

        # Near-zero (not exactly zero) variance: validation requires strictly
        # positive variances.
        new_variances = self._noise_variances.copy()
        for node in intervened_nodes:
            new_variances[node] = 1e-10

        return LinearGaussianSEM(new_dag, new_coefficients, new_variances)

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"LinearGaussianSEM(nodes={self.num_nodes()}, "
            f"edges={self._dag.num_edges()})"
        )

    def __str__(self) -> str:
        """Human-readable string representation."""
        lines = [f"LinearGaussianSEM with {self.num_nodes()} nodes:"]
        lines.append(f"  Edges: {self._dag.num_edges()}")
        lines.append(f"  Min |β|: {self.min_edge_coefficient():.4f}")
        lines.append(f"  Max |β|: {self.max_edge_coefficient():.4f}")
        lines.append(f"  Noise σ² range: [{self.min_noise_variance():.4f}, "
                     f"{self.max_noise_variance():.4f}]")
        return "\n".join(lines)


def compute_implied_covariance(
    dag: DAG,
    coefficients: Dict[Tuple[int, int], float],
    noise_variances: np.ndarray
) -> np.ndarray:
    """
    Return the covariance matrix implied by the given SEM parameters.

    Shortcut for building a throwaway ``LinearGaussianSEM`` and asking it
    for its covariance, so callers need not keep the model object around.
    The same parameter validation as the constructor applies.

    Args:
        dag: The causal DAG
        coefficients: Edge coefficients keyed by (parent, child)
        noise_variances: Node noise variances

    Returns:
        Implied d×d covariance matrix Σ

    Raises:
        ValueError: Propagated from ``LinearGaussianSEM`` when the
            parameters are invalid or inconsistent with the DAG
    """
    model = LinearGaussianSEM(
        dag=dag, coefficients=coefficients, noise_variances=noise_variances
    )
    return model.covariance_matrix()


def marginal_variance(sem: LinearGaussianSEM, node: int) -> float:
    """Return Var(X_node), i.e. the diagonal entry Σ[node, node] of the SEM's covariance."""
    return sem.covariance_matrix()[node, node]


def marginal_covariance(sem: LinearGaussianSEM, node1: int, node2: int) -> float:
    """Return Cov(X_node1, X_node2), i.e. the entry Σ[node1, node2] of the SEM's covariance."""
    return sem.covariance_matrix()[node1, node2]
