"""
Misspecified data post-processors used in the experiments.
All scenarios build on IIDSimulation; all except lingam then apply an additional transformation:

1. pnl: IIDSimulation followed by a post-nonlinear transformation (x^3 by default).
2. lingam: IIDSimulation with the LiNGAM parameter settings handled upstream.
3. confounded: IIDSimulation plus a confounder-style perturbation (rho=0.2 by default).
4. measure_err: IIDSimulation with measurement noise (gamma=0.8 by default).
5. unfaithful: IIDSimulation with edge cancellations to induce unfaithfulness.
6. timino: IIDSimulation with autoregressive temporal dependencies.
"""

import numpy as np
import random


class MisspecificationPostProcessor:
    """Common parent of the misspecification post-processors.

    Stores the seed and, at construction time, seeds the global `random`
    and `numpy.random` generators with it.
    """

    def __init__(self, seed=42):
        self.seed = seed
        self.set_seeds()

    def set_seeds(self):
        """Re-seed Python's `random` module and NumPy's global RNG with `self.seed`."""
        for seeder in (random.seed, np.random.seed):
            seeder(self.seed)


class ConfoundedPostProcessor(MisspecificationPostProcessor):
    """Add confounder-style perturbations to existing data.

    Every column pair (i, j) with i < j is selected with probability `rho`.
    A selected pair receives the same Gaussian noise vector (standard
    deviation `noise_scale`), scaled by 0.5, on both of its columns.
    """

    def __init__(self, rho=0.2, noise_scale=1.0, seed=42):
        super().__init__(seed)
        self.rho = rho
        self.noise_scale = noise_scale

    def apply(self, X, adjacency_matrix):
        """Return a copy of `X` with shared noise added to random column pairs.

        Args:
            X: 2-D array-like of shape (n_samples, n_variables). It is not
                modified; the perturbation is applied to a copy.
            adjacency_matrix: Unused here; accepted so that every
                post-processor shares the same `apply` signature.

        Returns:
            A new ndarray with the same shape as `X`. Floating-point input
            keeps its dtype; any other dtype is converted to float64.
        """
        data = np.asarray(X)
        # `astype` always copies, so the caller's array is never touched.
        # Non-float input (e.g. integers) is promoted so that adding float
        # noise in place cannot fail on a dtype cast.
        data = data.astype(
            data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
        )
        n, d = data.shape

        # The order of RNG draws (one uniform per pair, one normal vector per
        # selected pair) is kept so results stay reproducible for a seed.
        for i in range(d):
            for j in range(i + 1, d):
                if np.random.random() < self.rho:
                    shared_noise = np.random.normal(0, self.noise_scale, n)
                    data[:, i] += 0.5 * shared_noise
                    data[:, j] += 0.5 * shared_noise

        return data





class MeasureErrorPostProcessor(MisspecificationPostProcessor):
    """Add measurement error to observations.

    Each column receives zero-mean Gaussian noise whose standard deviation
    is `gamma` times the standard deviation of that column.
    """

    def __init__(self, gamma=0.8, seed=42):
        super().__init__(seed)
        self.gamma = gamma

    def apply(self, X, adjacency_matrix):
        """Return a copy of `X` with per-column Gaussian measurement noise.

        Args:
            X: 2-D array-like of shape (n_samples, n_variables). It is not
                modified; the noise is added to a copy.
            adjacency_matrix: Unused here; accepted so that every
                post-processor shares the same `apply` signature.

        Returns:
            A new ndarray with the same shape as `X`. Floating-point input
            keeps its dtype; any other dtype is converted to float64.
        """
        data = np.asarray(X)
        # `astype` always copies, so the caller's array is never touched.
        # Non-float input is promoted so the in-place float addition is valid.
        data = data.astype(
            data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
        )
        n, d = data.shape

        # Column spreads are measured once, before any noise is added.
        column_std = np.std(data, axis=0)

        # One normal draw per column, in column order, so the RNG stream
        # (and therefore the output for a given seed) is unchanged.
        for j in range(d):
            error_scale = self.gamma * column_std[j]
            data[:, j] += np.random.normal(0, error_scale, n)

        return data


class TiminoPostProcessor(MisspecificationPostProcessor):
    """Add temporal dependencies (TiMINo model).

    Rows are treated as consecutive time steps and each column gets its own
    lag-1 coefficient drawn uniformly from [-0.8, 0.8).
    """

    def __init__(self, seed=42):
        super().__init__(seed)

    def apply(self, X, adjacency_matrix):
        """Return a copy of `X` with lag-1 effects added row by row.

        Args:
            X: 2-D array-like of shape (n_samples, n_variables). It is not
                modified; the lagged effects are added to a copy.
            adjacency_matrix: Unused here; accepted so that every
                post-processor shares the same `apply` signature.

        Returns:
            A new ndarray with the same shape as `X`. Floating-point input
            keeps its dtype; any other dtype is converted to float64.
        """
        data = np.asarray(X)
        # `astype` always copies, so the caller's array is never touched.
        # Non-float input is promoted so the in-place float addition is valid.
        data = data.astype(
            data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
        )
        n, d = data.shape
        linear_coeffs = np.random.uniform(-0.8, 0.8, d)

        # X(t) = X(t) + c * X(t-1). Row t-1 has already been updated when
        # row t is processed, so the lagged effect accumulates over time;
        # this sequential dependence is why the loop is not vectorized.
        for t in range(1, n):
            data[t] += linear_coeffs * data[t - 1]

        return data


class UnfaithfulPostProcessor(MisspecificationPostProcessor):
    """Create unfaithful effects by cancellation.

    For edges i -> j found in the upper triangle of the adjacency matrix,
    a random multiple of column i is subtracted from column j with
    probability `p_unfaithful`.
    """

    def __init__(self, p_unfaithful=0.3, seed=42):
        super().__init__(seed)
        self.p_unfaithful = p_unfaithful

    def apply(self, X, adjacency_matrix):
        """Return a copy of `X` with cancelling terms subtracted along edges.

        Args:
            X: 2-D array-like of shape (n_samples, n_variables). It is not
                modified; the cancellation is applied to a copy.
            adjacency_matrix: 2-D array indexed as `[i, j]`. Only entries
                with i < j that are exactly equal to 1 are treated as edges.

        Returns:
            A new ndarray with the same shape as `X`. Floating-point input
            keeps its dtype; any other dtype is converted to float64.
        """
        data = np.asarray(X)
        # `astype` always copies, so the caller's array is never touched.
        # Non-float input is promoted so the in-place float subtraction is valid.
        data = data.astype(
            data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
        )
        n, d = data.shape

        for i in range(d):
            for j in range(i + 1, d):
                # Short-circuit on purpose: the uniform draw happens only for
                # existing edges, which keeps the RNG stream unchanged.
                if adjacency_matrix[i, j] == 1 and np.random.random() < self.p_unfaithful:
                    cancel_strength = np.random.uniform(0.5, 1.0)
                    # Column i may itself have been altered by an earlier
                    # pair; the pairs are processed sequentially.
                    data[:, j] -= cancel_strength * data[:, i]

        return data


class PNLPostProcessor(MisspecificationPostProcessor):
    """Post-nonlinear distortion: raise every entry to a fixed power."""

    def __init__(self, exponent=3.0, seed=42):
        super().__init__(seed)
        self.exponent = exponent

    def apply(self, X, adjacency_matrix):
        """Return `X` raised element-wise to `self.exponent` (x^3 by default).

        The power is always taken on |X|. When `exponent % 2 == 1` the sign
        of each entry is restored afterwards; otherwise the non-negative
        magnitude is returned as is. `adjacency_matrix` is not used.
        """
        magnitude = np.power(np.abs(X), self.exponent)

        is_odd = self.exponent % 2 == 1
        if is_odd:
            # Odd power: put the original sign back on each entry.
            return np.sign(X) * magnitude

        return magnitude


def create_postprocessor(scenario, **kwargs):
    """Build the post-processor registered for `scenario`.

    Keyword arguments are forwarded to the processor's constructor.
    Returns None when no processor is registered for `scenario`
    (vanilla and lingam don't need post-processing).
    """
    registry = {
        'confounded': ConfoundedPostProcessor,
        'measure_err': MeasureErrorPostProcessor,
        'timino': TiminoPostProcessor,
        'unfaithful': UnfaithfulPostProcessor,
        'pnl': PNLPostProcessor,
    }

    processor_cls = registry.get(scenario)
    if processor_cls is None:
        return None

    return processor_cls(**kwargs)


def apply_misspecification(X, adjacency_matrix, scenario, **kwargs):
    """Run the post-processing step that belongs to `scenario` on `X`.

    "vanilla" (pure IIDSimulation) and "lingam" (handled by IIDSimulation
    linear_rate + linear_sem_type) return `X` untouched, as does any
    scenario for which no post-processor exists. Otherwise the processor is
    built with `kwargs` and the result of its `apply` is returned.
    """
    passthrough = scenario == "vanilla" or scenario == "lingam"
    if passthrough:
        return X

    processor = create_postprocessor(scenario, **kwargs)
    return X if processor is None else processor.apply(X, adjacency_matrix)
