#!/usr/bin/env python3
"""
MPC-Specific XAI Baseline Methods
Implements sensitivity analysis and policy tree extraction for Model Predictive Control
"""

import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional, Callable
import logging
from pathlib import Path
import json

# Configure logging BEFORE probing the optional dependencies.  The module-level
# logging.warning() helper implicitly installs a default root handler when none
# exists, which would turn a later basicConfig(format=...) call into a no-op
# and silently drop the timestamped format whenever a dependency is missing.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# scikit-learn is optional: it is only needed for policy tree extraction.
try:
    from sklearn.tree import DecisionTreeRegressor, export_text
    from sklearn.ensemble import RandomForestRegressor
    SKLEARN_AVAILABLE = True
except ImportError:
    SKLEARN_AVAILABLE = False
    logger.warning("sklearn not available - policy tree features disabled")

# CasADi is optional: it is only needed for the analytical (symbolic) sensitivity.
try:
    import casadi as ca
    CASADI_AVAILABLE = True
except ImportError:
    CASADI_AVAILABLE = False
    logger.warning("CasADi not available - analytical sensitivity disabled")


class MPCSensitivityAnalysis:
    """
    Sensitivity analysis for MPC decisions
    Computes ∂u*/∂x (control sensitivity to state)
    and ∂u*/∂p (control sensitivity to parameters/disturbances)
    """
    
    def __init__(self, 
                 state_dim: int,
                 control_dim: int,
                 cost_function: Optional[Callable] = None,
                 dynamics: Optional[Callable] = None):
        """
        Set up the analyzer for a problem of the given size.
        
        Args:
            state_dim: Number of state variables
            control_dim: Number of control inputs
            cost_function: Optional stage cost J(x, u); stored as given
            dynamics: Optional dynamics function f(x, u); stored as given
        """
        # Problem dimensions
        self.state_dim, self.control_dim = state_dim, control_dim

        # Optional model callables supplied by the caller
        self.cost_function = cost_function
        self.dynamics = dynamics

        # Most recently computed sensitivities; both start out unset:
        #   du_dx -> ∂u*/∂x,  du_dp -> ∂u*/∂p (parameters)
        self.du_dx = self.du_dp = None
        
    def compute_finite_difference_sensitivity(self,
                                             state: np.ndarray,
                                             control: np.ndarray,
                                             epsilon: float = 1e-6,
                                             mpc_solver: Optional[Callable] = None) -> Dict:
        """
        Estimate the control sensitivity ∂u*/∂x numerically.

        Three cases, in order of preference:
          1. ``mpc_solver`` given: each state component is perturbed by
             ±epsilon and the solver is re-run (central difference).
          2. No solver but a cost function: the unconstrained optimality
             condition gives ∂u*/∂x = -(∂²J/∂u²)^(-1) ∂²J/∂u∂x, with both
             Hessian blocks estimated by finite differences of J.
          3. Neither available: the sensitivity is reported as zero and a
             warning is logged.

        Args:
            state: Current state
            control: Current optimal control (only used in case 2)
            epsilon: Perturbation size (must be non-zero)
            mpc_solver: Function that solves MPC: u* = mpc_solver(x)

        Returns:
            Dictionary with 'du_dx' (control_dim × state_dim matrix),
            'sensitivity_norm' (column norms of du_dx) and
            'most_influential_states' (state indices sorted by decreasing norm)

        Raises:
            ValueError: If epsilon is zero
        """
        if epsilon == 0:
            raise ValueError("epsilon must be non-zero")

        logger.info("Computing finite difference sensitivity")

        # Float copy: an integer-typed state would reject (or truncate) the
        # in-place epsilon perturbation below.
        x0 = np.array(state, dtype=float)

        du_dx = np.zeros((self.control_dim, self.state_dim))

        if mpc_solver is not None:
            for i in range(self.state_dim):
                x_forward = x0.copy()
                x_forward.flat[i] += epsilon
                x_backward = x0.copy()
                x_backward.flat[i] -= epsilon

                try:
                    u_forward = np.asarray(mpc_solver(x_forward), dtype=float).reshape(-1)
                    u_backward = np.asarray(mpc_solver(x_backward), dtype=float).reshape(-1)

                    # Central difference
                    du_dx[:, i] = (u_forward - u_backward) / (2 * epsilon)
                except Exception as e:
                    # Best effort: a failed solve leaves this column at zero
                    # instead of aborting the whole analysis.
                    logger.warning(f"MPC solver failed for perturbation {i}: {e}")
                    du_dx[:, i] = 0.0
        elif self.cost_function is not None:
            try:
                du_dx = self._cost_based_sensitivity(x0, control, epsilon)
            except Exception as e:
                # The cost callable may not accept plain numeric arrays
                # (e.g. a symbolic-only function); fall back to zeros.
                logger.warning(f"Cost-based sensitivity estimate failed: {e}")
                du_dx = np.zeros((self.control_dim, self.state_dim))
        else:
            logger.warning("No MPC solver or cost function available - "
                           "sensitivity reported as zero")

        self.du_dx = du_dx

        sensitivity_norm = np.linalg.norm(du_dx, axis=0)
        return {
            'du_dx': du_dx,
            'sensitivity_norm': sensitivity_norm,
            'most_influential_states': np.argsort(sensitivity_norm)[::-1]
        }

    def _cost_based_sensitivity(self, state, control, epsilon):
        """Estimate ∂u*/∂x = -H_uu^(-1) H_ux from finite-difference Hessians of the cost."""
        x0 = np.asarray(state, dtype=float).reshape(-1)
        u0 = np.asarray(control, dtype=float).reshape(-1)
        n_x, n_u = self.state_dim, self.control_dim
        if x0.size != n_x or u0.size != n_u:
            raise ValueError("state/control size does not match the configured dimensions")

        cost = self.cost_function

        def J(x, u):
            # .item() rejects non-scalar cost values instead of guessing
            return np.asarray(cost(x, u), dtype=float).item()

        # Second differences divide by h²; steps much below 1e-4 (relative)
        # are dominated by round-off in double precision.
        base_step = max(abs(epsilon), 1e-4)
        hx = base_step * np.maximum(1.0, np.abs(x0))
        hu = base_step * np.maximum(1.0, np.abs(u0))
        step_x = np.diag(hx)  # row i: perturbation along state i
        step_u = np.diag(hu)  # row a: perturbation along control a
        zero_x = np.zeros(n_x)
        zero_u = np.zeros(n_u)

        def cross(dx_a, du_a, dx_b, du_b):
            # Four-point central stencil for a mixed second derivative
            return (J(x0 + dx_a + dx_b, u0 + du_a + du_b)
                    - J(x0 + dx_a - dx_b, u0 + du_a - du_b)
                    - J(x0 - dx_a + dx_b, u0 - du_a + du_b)
                    + J(x0 - dx_a - dx_b, u0 - du_a - du_b))

        H_uu = np.empty((n_u, n_u))
        H_ux = np.empty((n_u, n_x))
        for a in range(n_u):
            for b in range(a, n_u):
                value = cross(zero_x, step_u[a], zero_x, step_u[b]) / (4 * hu[a] * hu[b])
                H_uu[a, b] = H_uu[b, a] = value
            for i in range(n_x):
                H_ux[a, i] = cross(zero_x, step_u[a], step_x[i], zero_u) / (4 * hu[a] * hx[i])

        # pinv keeps the estimate defined even when H_uu is singular
        return -np.linalg.pinv(H_uu) @ H_ux
    
    def compute_analytical_sensitivity(self,
                                      state: np.ndarray,
                                      control: np.ndarray,
                                      constraints: Optional[List] = None) -> Dict:
        """
        Compute the analytical sensitivity from the optimality conditions.

        For the unconstrained problem, ∇_u J = 0 at the optimum and therefore
        ∂u*/∂x = -(∂²J/∂u²)^(-1) ∂²J/∂u∂x.  The Hessian blocks are built
        symbolically with CasADi and evaluated at (state, control).

        NOTE: ``constraints`` is accepted for interface compatibility but is
        not used; the constraint terms of the Lagrangian
        (Σ λ_i ∇_u g_i) are not included.  A warning is logged when
        constraints are passed.

        Args:
            state: Current state
            control: Current optimal control
            constraints: List of constraint functions g_i(x, u) (currently ignored)

        Returns:
            Dictionary with 'du_dx', 'H_uu', 'H_ux', 'condition_number',
            'sensitivity_norm' and 'most_influential_states'; an empty
            dictionary if CasADi is unavailable or the computation fails
        """
        if not CASADI_AVAILABLE:
            logger.warning("CasADi required for analytical sensitivity")
            return {}

        logger.info("Computing analytical sensitivity via KKT conditions")

        if constraints:
            # Be explicit rather than silently returning an unconstrained result.
            logger.warning("Constraints are not included in the analytical sensitivity - "
                           "computing the unconstrained sensitivity")

        try:
            # Symbolic variables
            x_sym = ca.SX.sym('x', self.state_dim)
            u_sym = ca.SX.sym('u', self.control_dim)

            if self.cost_function is None:
                # Default quadratic cost x'Qx + u'Ru with Q = 10·I, R = I
                Q = np.eye(self.state_dim) * 10
                R = np.eye(self.control_dim) * 1
                J = ca.mtimes([x_sym.T, Q, x_sym]) + ca.mtimes([u_sym.T, R, u_sym])
            else:
                # User-provided cost, evaluated on the symbolic variables
                J = self.cost_function(x_sym, u_sym)

            # Hessian blocks of the cost
            H_uu = ca.hessian(J, u_sym)[0]  # ∂²J/∂u²
            H_ux = ca.jacobian(ca.jacobian(J, u_sym), x_sym)  # ∂²J/∂u∂x

            # Evaluate at current point
            H_uu_val = np.array(ca.Function('H_uu', [x_sym, u_sym], [H_uu])(state, control))
            H_ux_val = np.array(ca.Function('H_ux', [x_sym, u_sym], [H_ux])(state, control))

            # Sensitivity: ∂u*/∂x = -H_uu^(-1) * H_ux (pseudo-inverse tolerates singular H_uu)
            H_uu_inv = np.linalg.pinv(H_uu_val)
            du_dx = -H_uu_inv @ H_ux_val

            self.du_dx = du_dx

            sensitivity_norm = np.linalg.norm(du_dx, axis=0)
            return {
                'du_dx': du_dx,
                'H_uu': H_uu_val,
                'H_ux': H_ux_val,
                'condition_number': np.linalg.cond(H_uu_val),
                'sensitivity_norm': sensitivity_norm,
                # Same key as the finite-difference variant, for interchangeable use
                'most_influential_states': np.argsort(sensitivity_norm)[::-1]
            }

        except Exception as e:
            logger.error(f"Analytical sensitivity computation failed: {e}")
            return {}
    
    def generate_sensitivity_explanation(self,
                                        sensitivity_results: Dict,
                                        feature_names: Optional[List[str]] = None) -> str:
        """
        Generate natural language explanation from sensitivity analysis
        
        Args:
            sensitivity_results: Output from compute_*_sensitivity
            feature_names: Names of state and control variables
            
        Returns:
            Explanation string
        """
        explanation = "**MPC Sensitivity Analysis:**\n\n"
        
        if 'du_dx' not in sensitivity_results:
            return explanation + "Sensitivity computation failed.\n"
        
        du_dx = sensitivity_results['du_dx']
        
        explanation += "The MPC control decision's sensitivity to state changes reveals:\n\n"
        
        # Identify most influential state variables
        state_influence = np.linalg.norm(du_dx, axis=0)  # Influence of each state on all controls
        top_states = np.argsort(state_influence)[::-1]
        
        explanation += "**Most Influential State Variables:**\n"
        for i, state_idx in enumerate(top_states[:5], 1):
            state_name = (feature_names[state_idx] if feature_names 
                         else f"State {state_idx}")
            influence = state_influence[state_idx]
            explanation += f"{i}. {state_name}: Sensitivity = {influence:.4f}\n"
        
        explanation += "\n**Control Response Patterns:**\n"
        
        # For each control, identify which states it's most sensitive to
        for u_idx in range(self.control_dim):
            control_name = (feature_names[self.state_dim + u_idx] 
                          if feature_names and len(feature_names) > self.state_dim + u_idx
                          else f"Control {u_idx}")
            
            # Find state with largest sensitivity for this control
            state_sensitivities = np.abs(du_dx[u_idx, :])
            max_state_idx = np.argmax(state_sensitivities)
            max_sens = du_dx[u_idx, max_state_idx]
            
            state_name = (feature_names[max_state_idx] if feature_names
                         else f"State {max_state_idx}")
            
            direction = "increases" if max_sens > 0 else "decreases"
            explanation += (f"- {control_name} {direction} by {abs(max_sens):.4f} "
                          f"per unit change in {state_name}\n")
        
        if 'condition_number' in sensitivity_results:
            cond = sensitivity_results['condition_number']
            explanation += f"\n**Numerical Conditioning:** "
            if cond < 100:
                explanation += f"Well-conditioned (κ={cond:.1f})\n"
            elif cond < 1000:
                explanation += f"Moderately conditioned (κ={cond:.1f})\n"
            else:
                explanation += f"Ill-conditioned (κ={cond:.1e}) - sensitivity may be unreliable\n"
        
        explanation += "\n**Interpretation:**\n"
        explanation += "The control action was chosen to optimally respond to the current state. "
        explanation += "The sensitivity analysis quantifies how the optimal control would change "
        explanation += "if states were slightly different, revealing which variables are driving the decision."
        
        return explanation


class MPCPolicyTreeExtractor:
    """
    Extract interpretable decision trees from MPC policies
    Approximates the MPC controller u* = π(x) with a tree structure
    """
    
    def __init__(self,
                 state_dim: int,
                 control_dim: int,
                 max_depth: int = 5,
                 min_samples_split: int = 10):
        """
        Set up an (unfitted) policy tree extractor.
        
        Args:
            state_dim: Number of state variables
            control_dim: Number of control inputs
            max_depth: Depth limit stored for the trees
            min_samples_split: Minimum-samples-per-split setting stored for the trees
        """
        # Problem dimensions
        self.state_dim, self.control_dim = state_dim, control_dim

        # Tree hyper-parameters
        self.max_depth, self.min_samples_split = max_depth, min_samples_split

        # Fitted trees and their text dumps, one entry per control dimension.
        # Two distinct lists on purpose - they are appended to independently.
        self.trees = list()
        self.tree_texts = list()
        
    def fit_policy_tree(self,
                       states: np.ndarray,
                       controls: np.ndarray,
                       feature_names: Optional[List[str]] = None) -> bool:
        """
        Fit decision trees to approximate the MPC policy

        One regression tree is fitted per control dimension.  The fit is
        all-or-nothing: on failure ``self.trees`` and ``self.tree_texts`` are
        left empty rather than partially filled.

        Args:
            states: Historical state observations (n_samples × state_dim)
            controls: Corresponding optimal controls (n_samples × control_dim);
                a 1-D array is accepted and treated as a single control column
            feature_names: Names of state variables; entries missing from a
                too-short list fall back to "x_i"

        Returns:
            True if fitting successful, False if sklearn is unavailable or
            any tree could not be fitted
        """
        if not SKLEARN_AVAILABLE:
            logger.error("sklearn required for policy tree extraction")
            return False

        logger.info(f"Fitting policy trees from {len(states)} samples")

        # Drop any earlier fit up front so a failure never leaves stale or
        # half-built trees behind.
        self.trees = []
        self.tree_texts = []

        try:
            X = np.asarray(states, dtype=float)
            U = np.asarray(controls, dtype=float)
            if U.ndim == 1:
                U = U.reshape(-1, 1)

            # export_text needs exactly one name per feature, so pad a short
            # (or absent) name list with generic labels.  Loop-invariant.
            provided = list(feature_names) if feature_names is not None else []
            state_features = [provided[i] if i < len(provided) else f"x_{i}"
                              for i in range(self.state_dim)]

            fitted_trees = []
            fitted_texts = []

            # Fit separate tree for each control dimension
            for u_idx in range(self.control_dim):
                targets = U[:, u_idx]
                tree = DecisionTreeRegressor(
                    max_depth=self.max_depth,
                    min_samples_split=self.min_samples_split,
                    random_state=42
                )
                tree.fit(X, targets)

                fitted_texts.append(export_text(tree, feature_names=state_features))
                fitted_trees.append(tree)

                # Report training-set fit quality
                train_score = tree.score(X, targets)
                logger.info(f"Control {u_idx}: Tree R² = {train_score:.3f}, "
                          f"depth = {tree.get_depth()}, "
                          f"leaves = {tree.get_n_leaves()}")

        except Exception as e:
            logger.error(f"Policy tree fitting failed: {e}")
            return False

        # Publish only after every tree was fitted successfully
        self.trees = fitted_trees
        self.tree_texts = fitted_texts
        return True
    
    def explain_decision(self,
                        state: np.ndarray,
                        control: np.ndarray,
                        feature_names: Optional[List[str]] = None) -> str:
        """
        Generate explanation using policy tree
        Shows decision path through tree
        
        Args:
            state: Current state
            control: Optimal control
            feature_names: Names of variables
            
        Returns:
            Explanation string
        """
        if not self.trees:
            return "Policy tree not fitted.\n"
        
        explanation = "**MPC Policy Tree Analysis:**\n\n"
        explanation += "The MPC policy can be approximated by interpretable decision rules:\n\n"
        
        for u_idx, tree in enumerate(self.trees):
            control_name = (feature_names[self.state_dim + u_idx]
                          if feature_names and len(feature_names) > self.state_dim + u_idx
                          else f"Control {u_idx}")
            
            explanation += f"**{control_name} Decision Path:**\n"
            
            # Get decision path for this state
            node_indicator = tree.decision_path(state.reshape(1, -1))
            leaf_id = tree.apply(state.reshape(1, -1))
            
            # Extract path through tree
            feature_idx = tree.tree_.feature
            threshold = tree.tree_.threshold
            
            path_nodes = node_indicator.indices[node_indicator.indptr[0]:
                                               node_indicator.indptr[1]]
            
            for node_id in path_nodes:
                # Check if not leaf
                if leaf_id[0] != node_id:
                    feat_idx = feature_idx[node_id]
                    feat_name = (feature_names[feat_idx] if feature_names
                               else f"x_{feat_idx}")
                    
                    if state[feat_idx] <= threshold[node_id]:
                        decision = f"{feat_name} ≤ {threshold[node_id]:.3f}"
                        branch = "left"
                    else:
                        decision = f"{feat_name} > {threshold[node_id]:.3f}"
                        branch = "right"
                    
                    explanation += f"  ├─ {decision} → {branch}\n"
            
            # Final prediction
            prediction = tree.predict(state.reshape(1, -1))[0]
            explanation += f"  └─ Predicted {control_name} = {prediction:.4f}\n"
            explanation += f"     Actual {control_name} = {control[u_idx]:.4f}\n\n"
        
        explanation += "**Simplified Decision Rules:**\n"
        explanation += "The policy tree reveals that the MPC controller's decisions "
        explanation += "can be approximated by simple threshold-based rules on key state variables. "
        explanation += "This makes the complex optimization-based logic more interpretable.\n"
        
        return explanation
    
    def extract_rules(self, 
                     tree_idx: int = 0,
                     feature_names: Optional[List[str]] = None,
                     max_rules: int = 10) -> List[str]:
        """
        Extract human-readable IF-THEN rules from tree
        
        Args:
            tree_idx: Index of tree (control dimension)
            feature_names: Names of state variables
            max_rules: Maximum number of rules to extract
            
        Returns:
            List of rule strings
        """
        if tree_idx >= len(self.trees):
            return []
        
        tree = self.trees[tree_idx]
        
        # Extract all paths to leaves
        paths = self._get_all_paths(tree, feature_names)
        
        # Sort by sample coverage
        paths.sort(key=lambda x: x['n_samples'], reverse=True)
        
        # Format as rules
        rules = []
        for i, path in enumerate(paths[:max_rules], 1):
            rule = f"Rule {i}: IF " + " AND ".join(path['conditions'])
            rule += f" THEN control = {path['value']:.4f}"
            rule += f" (covers {path['n_samples']} samples)"
            rules.append(rule)
        
        return rules
    
    def _get_all_paths(self, tree, feature_names: Optional[List[str]] = None) -> List[Dict]:
        """Extract all paths from root to leaves"""
        paths = []
        
        def recurse(node_id, path_conditions):
            if tree.tree_.feature[node_id] == -2:  # Leaf node
                paths.append({
                    'conditions': path_conditions.copy(),
                    'value': tree.tree_.value[node_id][0][0],
                    'n_samples': tree.tree_.n_node_samples[node_id]
                })
            else:
                feat_idx = tree.tree_.feature[node_id]
                threshold = tree.tree_.threshold[node_id]
                feat_name = (feature_names[feat_idx] if feature_names
                           else f"x_{feat_idx}")
                
                # Left child (<=)
                left_cond = f"{feat_name} ≤ {threshold:.3f}"
                recurse(tree.tree_.children_left[node_id],
                       path_conditions + [left_cond])
                
                # Right child (>)
                right_cond = f"{feat_name} > {threshold:.3f}"
                recurse(tree.tree_.children_right[node_id],
                       path_conditions + [right_cond])
        
        recurse(0, [])
        return paths


def main():
    """
    Demonstration of MPC XAI methods

    Builds a synthetic greenhouse example in which a simple rule-based
    controller stands in for the MPC law, then runs the sensitivity analysis
    and the policy tree extraction on it and prints the explanations.
    """
    banner = "=" * 60
    logger.info(banner)
    logger.info("MPC-Specific XAI Methods Demonstration")
    logger.info(banner)

    # Example greenhouse MPC
    state_dim = 3  # Temperature, Humidity, CO2
    control_dim = 4  # Ventilation, CO2 injection, Heating, Cooling

    feature_names = ['Temperature', 'Humidity', 'CO2',
                    'Ventilation', 'CO2_injection', 'Heating', 'Cooling']

    def rule_policy(x):
        """Rule-based stand-in for the MPC law; maps (n, 3) states to (n, 4) controls."""
        x = np.atleast_2d(np.asarray(x, dtype=float))
        temperature, co2 = x[:, 0], x[:, 2]
        return np.column_stack([
            np.maximum(0, temperature - 20),  # Vent if temp > 20
            np.maximum(0, 400 - co2),         # Inject if CO2 < 400
            np.maximum(0, 18 - temperature),  # Heat if temp < 18
            np.maximum(0, temperature - 25),  # Cool if temp > 25
        ])

    # Generate synthetic data around a realistic operating point, in the same
    # physical units as the policy thresholds.  Standard-normal states would
    # never cross 20 °C / 400 ppm, leaving some controls identically zero.
    rng = np.random.default_rng(42)
    n_samples = 1000
    operating_point = np.array([22.0, 65.0, 450.0])
    spread = np.array([5.0, 10.0, 100.0])

    states = operating_point + spread * rng.standard_normal((n_samples, state_dim))
    controls = rule_policy(states)

    # 1. Sensitivity Analysis
    logger.info("\n" + "="*60)
    logger.info("1. Sensitivity Analysis")
    logger.info("="*60)

    sens_analyzer = MPCSensitivityAnalysis(state_dim, control_dim)

    current_state = np.array([22.0, 65.0, 450.0])
    current_control = rule_policy(current_state)[0]

    # Finite difference sensitivity, using the rule policy as the "solver" so
    # the reported sensitivities reflect an actual control law
    sens_results = sens_analyzer.compute_finite_difference_sensitivity(
        current_state, current_control,
        mpc_solver=lambda x: rule_policy(x)[0]
    )

    explanation = sens_analyzer.generate_sensitivity_explanation(
        sens_results, feature_names
    )
    print(explanation)

    # 2. Policy Tree Extraction
    logger.info("\n" + "="*60)
    logger.info("2. Policy Tree Extraction")
    logger.info("="*60)

    tree_extractor = MPCPolicyTreeExtractor(state_dim, control_dim, max_depth=4)

    success = tree_extractor.fit_policy_tree(states, controls, feature_names)

    if success:
        explanation = tree_extractor.explain_decision(
            current_state, current_control, feature_names
        )
        print(explanation)

        # Extract rules
        logger.info("\n**Extracted Rules for Ventilation:**")
        rules = tree_extractor.extract_rules(tree_idx=0, feature_names=feature_names)
        for rule in rules[:5]:
            print(rule)

    logger.info("\n" + "="*60)
    logger.info("MPC XAI Demonstration Complete")
    logger.info("="*60)

# Run the demonstration only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
