#!/usr/bin/env python3
"""
Nonlinear Systems Implementation
Stability analysis of nonlinear systems
"""

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.linalg import solve_discrete_are
from scipy import stats
import pandas as pd
from dataclasses import dataclass
from typing import Optional, Callable, Tuple
import warnings

# Set style for journal-quality plots
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")
plt.rcParams.update({
    'font.size': 12,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.titlesize': 16,
    'lines.linewidth': 2,
    'axes.grid': True,
    'grid.alpha': 0.3
})

warnings.filterwarnings('ignore')

# Import necessary classes from the original code
from dataclasses import dataclass
from typing import Type

@dataclass
class NeuralControllerConfig:
    """Architecture settings for the MLP policies built by NeuralController."""
    hidden_size: int = 32  # width of every hidden nn.Linear layer
    num_layers: int = 2  # number of hidden Linear+activation pairs (values below 1 act like 1)
    activation: Type[nn.Module] = nn.Tanh  # module class; instantiated once per hidden layer

@dataclass
class TrainingConfig:
    """Optimisation settings forwarded to NeuralController.train_on_regime."""
    episodes: int = 100  # number of training episodes (one Adam step per episode)
    lr: float = 0.001  # Adam learning rate

@dataclass
class NonlinearRegime:
    """One operating regime: a dynamics model plus quadratic cost weights.

    The per-step cost is ``x^T Q x + u^T R u``. When built by
    RobustNonlinearEnvironment, Q and R are already multiplied by its
    ``cost_scaling`` factor.
    """
    name: str
    dynamics: 'NonlinearDynamics'  # string forward reference: the class is defined later in this module
    Q: np.ndarray  # state cost matrix, (state_dim, state_dim)
    R: np.ndarray  # control cost matrix, (control_dim, control_dim)


# =============================================
# ROBUST NONLINEAR DYNAMICS WITH BETTER CONDITIONING
# =============================================

class NonlinearDynamics:
    """Base class for discrete-time dynamics with numerical linearization.

    Subclasses implement :meth:`dynamics`, mapping ``(state, action, dt)`` to
    the next state. :meth:`linearize` approximates the Jacobians
    ``A = d x_next / d state`` and ``B = d x_next / d action`` by forward
    finite differences.
    """

    def __init__(self, state_dim: int, control_dim: int):
        self.state_dim = state_dim
        self.control_dim = control_dim

    def dynamics(self, state: np.ndarray, action: np.ndarray, dt: float = 0.1) -> np.ndarray:
        """Return the next state; must be overridden by subclasses."""
        raise NotImplementedError

    def linearize(self, state: np.ndarray, action: np.ndarray, dt: float = 0.1,
                  eps: float = 1e-6) -> Tuple[np.ndarray, np.ndarray]:
        """Finite-difference linearization of :meth:`dynamics` around ``(state, action)``.

        Args:
            state: Operating-point state, length ``state_dim``.
            action: Operating-point action, length ``control_dim``.
            dt: Time step forwarded to :meth:`dynamics`.
            eps: Forward-difference perturbation size.

        Returns:
            ``(A, B)`` with shapes ``(state_dim, state_dim)`` and
            ``(state_dim, control_dim)``. ``A`` is passed through
            :meth:`_condition_matrix`. If its spectral radius is >= 1 it is
            rescaled, so it is then no longer the true Jacobian. ``B`` is
            never rescaled.
        """
        # Work on float copies: perturbing an integer array in place would
        # truncate ``eps`` to zero and produce an all-zero Jacobian column.
        state = np.array(state, dtype=float)
        action = np.array(action, dtype=float)

        x_next = np.asarray(self.dynamics(state, action, dt), dtype=float)

        A = np.zeros((self.state_dim, self.state_dim))
        for i in range(self.state_dim):
            state_plus = state.copy()
            state_plus[i] += eps
            A[:, i] = (self.dynamics(state_plus, action, dt) - x_next) / eps

        B = np.zeros((self.state_dim, self.control_dim))
        for i in range(self.control_dim):
            action_plus = action.copy()
            action_plus[i] += eps
            B[:, i] = (self.dynamics(state, action_plus, dt) - x_next) / eps

        # Only the state-transition matrix A has a meaningful spectrum to
        # condition. B is an input matrix; a squareness test would wrongly
        # rescale it whenever control_dim == state_dim.
        A = self._condition_matrix(A)

        return A, B

    def _condition_matrix(self, M: np.ndarray, max_eigenval: float = 0.95) -> np.ndarray:
        """Rescale a square matrix so that its spectral radius stays below one.

        If any eigenvalue of ``M`` has magnitude >= 1, the whole matrix is
        multiplied by ``max_eigenval / spectral_radius``. Non-square or empty
        input is returned unchanged.
        """
        if M.ndim == 2 and M.shape[0] == M.shape[1] and M.size:
            spectral_radius = np.max(np.abs(np.linalg.eigvals(M)))
            if spectral_radius >= 1.0:
                M = M * (max_eigenval / spectral_radius)
        return M

class RobustPendulumDynamics(NonlinearDynamics):
    """Damped, torque-driven pendulum integrated with explicit Euler.

    State is ``[theta, theta_dot]`` with theta wrapped to [-pi, pi]. The
    single control is a torque clipped to [-5, 5].
    """

    def __init__(self, mass=1.0, length=1.0, damping=0.5, gravity=9.81):  # Increased damping
        super().__init__(state_dim=2, control_dim=1)
        self.m = mass
        self.l = length
        self.b = damping
        self.g = gravity

    def dynamics(self, state: np.ndarray, action: np.ndarray, dt: float = 0.05) -> np.ndarray:  # Smaller dt
        """Advance one explicit-Euler step of length ``dt``."""
        theta, theta_dot = state
        torque = np.clip(action[0] if len(action) > 0 else 0.0, -5.0, 5.0)  # Limit torque

        # Use the exact sin(theta) everywhere. A piecewise small-angle
        # substitution (sin(theta) ~= theta for |theta| < 0.1) makes the
        # dynamics discontinuous at |theta| = 0.1. The finite-difference
        # linearization amplifies that jump into spurious large Jacobian entries.
        inertia = self.m * self.l ** 2
        theta_ddot = (-self.g / self.l * np.sin(theta)
                      - self.b / inertia * theta_dot
                      + torque / inertia)

        # Explicit Euler integration (no state saturation).
        theta_next = theta + theta_dot * dt
        theta_dot_next = theta_dot + theta_ddot * dt

        # Wrap theta back into [-pi, pi].
        theta_next = np.arctan2(np.sin(theta_next), np.cos(theta_next))

        return np.array([theta_next, theta_dot_next])

class RobustCartPoleDynamics(NonlinearDynamics):
    """Cart-pole with theta measured from the upright position, explicit Euler.

    State is ``[x, x_dot, theta, theta_dot]``. The single control is a
    horizontal force on the cart, clipped to [-10, 10]. Theta is clipped to
    [-pi/4, pi/4] after each step.
    """

    def __init__(self, cart_mass=1.0, pole_mass=0.1, pole_length=0.25, gravity=9.81):  # Shorter pole
        super().__init__(state_dim=4, control_dim=1)
        self.mc = cart_mass
        self.mp = pole_mass
        self.l = pole_length
        self.g = gravity
        self.total_mass = self.mc + self.mp

    def dynamics(self, state: np.ndarray, action: np.ndarray, dt: float = 0.02) -> np.ndarray:  # Much smaller dt
        """Advance one explicit-Euler step of length ``dt``."""
        x, x_dot, theta, theta_dot = state
        force = np.clip(action[0] if len(action) > 0 else 0.0, -10.0, 10.0)

        # Use exact trig everywhere. A piecewise small-angle branch
        # (|theta| < 0.2) makes the dynamics jump at the switching angle. The
        # finite-difference linearization turns that jump into spurious
        # Jacobian entries.
        sin_theta = np.sin(theta)
        cos_theta = np.cos(theta)

        # Full nonlinear equations of motion (not linearized). The
        # denominator equals mc + mp*sin^2(theta) >= mc. The floor is only a
        # safeguard against a degenerate zero cart mass.
        denominator = max(self.total_mass - self.mp * cos_theta ** 2, 0.1)

        x_ddot = (force + self.mp * self.l * theta_dot ** 2 * sin_theta
                  - self.mp * self.g * cos_theta * sin_theta) / denominator

        theta_ddot = (self.g * sin_theta - cos_theta * x_ddot) / self.l

        x_next = x + x_dot * dt
        x_dot_next = x_dot + x_ddot * dt
        theta_next = theta + theta_dot * dt
        theta_dot_next = theta_dot + theta_ddot * dt

        # Keep the pole within +/-45 degrees (theta_dot is not reset).
        theta_next = np.clip(theta_next, -np.pi / 4, np.pi / 4)

        return np.array([x_next, x_dot_next, theta_next, theta_dot_next])

class DoubleIntegratorDynamics(NonlinearDynamics):
    """Damped double integrator with a mild cubic restoring term.

    Continuous model: ``pos' = vel`` and
    ``vel' = u - damping*vel - nonlinearity*pos^3``, with ``u`` clipped to
    [-5, 5]. One explicit-Euler step is taken, then both states are clipped
    to [-10, 10].
    """

    def __init__(self, damping=0.1, nonlinearity=0.1):
        super().__init__(state_dim=2, control_dim=1)
        self.damping = damping
        self.nonlinearity = nonlinearity

    def dynamics(self, state: np.ndarray, action: np.ndarray, dt: float = 0.05) -> np.ndarray:
        """Advance one Euler step of length ``dt``."""
        position, velocity = state
        requested = action[0] if len(action) > 0 else 0.0
        force = np.clip(requested, -5.0, 5.0)

        acceleration = force - self.damping * velocity - self.nonlinearity * position**3

        next_position = np.clip(position + velocity * dt, -10, 10)
        next_velocity = np.clip(velocity + acceleration * dt, -10, 10)
        return np.array([next_position, next_velocity])

class RobustVanDerPolDynamics(NonlinearDynamics):
    """Forced Van der Pol oscillator with a mild default ``mu``.

    Continuous model: ``x1' = x2`` and ``x2' = mu*(1 - x1^2)*x2 - x1 + u``,
    with ``u`` clipped to [-2, 2]. Both derivatives are clipped to [-10, 10]
    before an explicit-Euler step.
    """

    def __init__(self, mu=0.5):  # Reduced mu for less aggressive dynamics
        super().__init__(state_dim=2, control_dim=1)
        self.mu = mu

    def dynamics(self, state: np.ndarray, action: np.ndarray, dt: float = 0.05) -> np.ndarray:
        """Advance one Euler step of length ``dt`` using saturated derivatives."""
        position, velocity = state
        control = np.clip(action[0] if len(action) > 0 else 0.0, -2.0, 2.0)

        raw_acceleration = self.mu * (1 - position**2) * velocity - position + control
        position_rate = np.clip(velocity, -10, 10)
        velocity_rate = np.clip(raw_acceleration, -10, 10)

        return np.array([position + position_rate * dt, velocity + velocity_rate * dt])

# =============================================
# ROBUST NONLINEAR ENVIRONMENT
# =============================================

class RobustNonlinearEnvironment:
    """Multi-regime nonlinear control environment with quadratic costs.

    Each regime pairs a dynamics object with a state-cost matrix ``Q`` and a
    control-cost matrix ``R`` (both multiplied by ``cost_scaling``). The
    per-step cost is ``x^T Q x + u^T R u``, evaluated on the pre-step state
    and the clipped action. Episodes end after 100 steps.
    """

    def __init__(self, regimes_config: dict, noise_std=0.01, control_limit=5.0,
                 cost_scaling=1.0, dt=0.05):  # Reduced noise and control limits
        """Build the regimes.

        Args:
            regimes_config: Mapping of regime name to a dict with keys
                ``'dynamics'``, ``'Q'`` and ``'R'``.
            noise_std: Std-dev of the Gaussian process noise added each step
                when no replay noise sequence is active.
            control_limit: Symmetric bound applied to every action component.
            cost_scaling: Multiplier applied to both Q and R.
            dt: Time step passed to the dynamics and to linearization.

        Raises:
            ValueError: If ``regimes_config`` is empty. State and control
                dimensions are taken from the first regime.
        """
        if not regimes_config:
            raise ValueError("regimes_config must contain at least one regime")

        self.noise_std = noise_std
        self.control_limit = control_limit
        self.cost_scaling = cost_scaling
        self.dt = dt

        self.regimes = {
            name: NonlinearRegime(
                name=name,
                dynamics=config['dynamics'],
                Q=config['Q'] * cost_scaling,
                R=config['R'] * cost_scaling,
            )
            for name, config in regimes_config.items()
        }

        self.current_regime = None
        self.state = None
        self.time = 0
        self._noise_sequence = None
        self._current_noise_step = 0

        first_regime = next(iter(self.regimes.values()))
        self.state_dim = first_regime.dynamics.state_dim
        self.control_dim = first_regime.dynamics.control_dim

    @staticmethod
    def _quadratic_form(x, M):
        """Return ``x^T M x`` for a NumPy array or torch tensor ``x``.

        1-D vectors skip the transpose. It is a no-op for them, and PyTorch
        deprecates ``.T`` on non-2-D tensors.
        """
        return x @ M @ x if x.ndim == 1 else x.T @ M @ x

    def reset(self, regime_name: str, initial_state=None, noise_sequence=None):
        """Start a new episode in ``regime_name`` and return a copy of the initial state.

        Without ``initial_state``, the state is drawn from N(0, 0.1^2) per
        component. ``noise_sequence`` is replayed step by step. Once it is
        exhausted, fresh Gaussian noise is drawn.
        """
        self.current_regime = self.regimes[regime_name]
        if initial_state is not None:
            # np.array always copies and also accepts lists/tuples.
            self.state = np.array(initial_state, dtype=float)
        else:
            # Small initial states for better stability.
            self.state = np.random.normal(0, 0.1, self.state_dim)

        self.time = 0
        self._noise_sequence = noise_sequence
        self._current_noise_step = 0
        return self.state.copy()

    def step(self, action):
        """Apply ``action`` and return ``(next_state, reward, done, info)``.

        The cost is charged on the state *before* the transition and on the
        clipped action. ``reward`` is ``-cost``.

        Raises:
            ValueError: If called before :meth:`reset`.
        """
        if self.current_regime is None:
            raise ValueError("Environment not reset")

        action = np.clip(action, -self.control_limit, self.control_limit)

        if self._noise_sequence is not None and self._current_noise_step < len(self._noise_sequence):
            noise = self._noise_sequence[self._current_noise_step]
            self._current_noise_step += 1
        else:
            noise = np.random.normal(0, self.noise_std, self.state_dim)

        current_state = self.state.copy()

        # Best effort: if the dynamics raise, the state is left unchanged.
        try:
            next_state = self.current_regime.dynamics.dynamics(self.state, action, self.dt)
            self.state = np.clip(next_state + noise, -50, 50)
        except Exception as e:
            print(f"Dynamics failed: {e}, using previous state")

        Q = self.current_regime.Q
        R = self.current_regime.R
        state_cost = float(self._quadratic_form(current_state, Q))
        control_cost = float(self._quadratic_form(action, R))
        cost = state_cost + control_cost

        self.time += 1
        done = self.time >= 100  # Shorter episodes

        info = {
            'regime': self.current_regime.name,
            'cost': cost,
            'state_cost': state_cost,
            'control_cost': control_cost
        }

        return self.state.copy(), -cost, done, info

    def get_optimal_policy(self, regime_name: str, state=None):
        """Infinite-horizon discrete LQR gain for the dynamics linearized at ``state``.

        Linearizes around ``(state, 0)`` (default: the current state) and
        solves the discrete algebraic Riccati equation. It returns ``K`` such
        that ``u = -K x``. Any failure (uncontrollable pair, non-PD Riccati
        solution, singular gain system, or any exception) returns
        :meth:`_fallback_controller`. Gains with an entry above 100 in
        magnitude are rescaled so their largest entry is 10.

        Raises:
            KeyError: If ``regime_name`` is unknown.
        """
        if state is None:
            state = self.state

        regime = self.regimes[regime_name]

        try:
            zero_action = np.zeros(self.control_dim)
            A, B = regime.dynamics.linearize(state, zero_action, self.dt)

            if not self._is_controllable(A, B):
                print(f"System not controllable for {regime_name}, using fallback")
                return self._fallback_controller()

            P = solve_discrete_are(A, B, regime.Q, regime.R)

            # P is symmetric in exact arithmetic. Symmetrize and use eigvalsh
            # so rounding noise cannot yield complex eigenvalues.
            if np.any(np.linalg.eigvalsh(0.5 * (P + P.T)) <= 0):
                print(f"ARE solution not positive definite for {regime_name}")
                return self._fallback_controller()

            try:
                K_optimal = np.linalg.solve(regime.R + B.T @ P @ B, B.T @ P @ A)
            except np.linalg.LinAlgError:
                print(f"Gain computation failed for {regime_name}")
                return self._fallback_controller()

            if np.any(np.abs(K_optimal) > 100):
                print(f"Large gains detected for {regime_name}, scaling down")
                K_optimal = K_optimal / np.max(np.abs(K_optimal)) * 10

            return K_optimal

        except Exception as e:
            print(f"LQR computation failed for {regime_name}: {e}")
            return self._fallback_controller()

    def _is_controllable(self, A, B, tol=1e-10):
        """Kalman rank test: ``[B, AB, ..., A^(n-1) B]`` has full row rank."""
        n = A.shape[0]
        C = np.hstack([B] + [np.linalg.matrix_power(A, i) @ B for i in range(1, n)])
        return np.linalg.matrix_rank(C, tol=tol) == n

    def _fallback_controller(self):
        """Random gain matrix returned when LQR fails.

        NOTE(review): random gains are not a principled fallback and may
        destabilize the system. The call also advances the global NumPy RNG.
        Consider zero gains instead.
        """
        return np.random.normal(0, 0.1, (self.control_dim, self.state_dim))

    def compute_cost(self, regime_name: str, state, action):
        """Quadratic cost ``x^T Q x + u^T R u`` for ``regime_name``.

        Works on torch tensors (keeping the autograd graph) if either input
        is a tensor. Otherwise it works on array-likes and returns a float.
        Non-finite results are replaced by a 1e6 penalty (a constant tensor
        in the torch branch).
        """
        Q = self.regimes[regime_name].Q
        R = self.regimes[regime_name].R

        if isinstance(state, torch.Tensor) or isinstance(action, torch.Tensor):
            Q_tensor = torch.FloatTensor(Q)
            R_tensor = torch.FloatTensor(R)
            state_tensor = state if isinstance(state, torch.Tensor) else torch.FloatTensor(state)
            action_tensor = action if isinstance(action, torch.Tensor) else torch.FloatTensor(action)

            total_cost = (self._quadratic_form(state_tensor, Q_tensor)
                          + self._quadratic_form(action_tensor, R_tensor))

            if not torch.isfinite(total_cost):
                return torch.tensor(1e6)  # Large penalty
            return total_cost

        state = np.asarray(state, dtype=float)
        action = np.asarray(action, dtype=float)
        total_cost = self._quadratic_form(state, Q) + self._quadratic_form(action, R)

        if not np.isfinite(total_cost):
            return 1e6  # Large penalty
        return float(total_cost)

    def compute_optimal_cost(self, regime_name: str, state):
        """One-step cost of the saturated linearized-LQR action at ``state`` (1e6 on failure)."""
        try:
            K_opt = self.get_optimal_policy(regime_name, state)
            optimal_action = -K_opt @ state
            saturated_optimal_action = np.clip(optimal_action, -self.control_limit, self.control_limit)
            return self.compute_cost(regime_name, state, saturated_optimal_action)
        except Exception as e:
            print(f"Optimal cost computation failed: {e}")
            return 1e6

# =============================================
# CONTROLLER IMPLEMENTATIONS
# =============================================

class LQRController:
    """Static state-feedback law ``u = -K x``.

    ``K`` starts as small random gains (entries drawn from N(0, 0.1^2), shape
    ``(control_dim, state_dim)``) and is normally overwritten through
    :meth:`set_gains` before use.
    """

    def __init__(self, state_dim, control_dim, regime_name=None):
        self.state_dim, self.control_dim = state_dim, control_dim
        self.regime_name = regime_name
        initial_gain_shape = (control_dim, state_dim)
        self.K = np.random.normal(0, 0.1, initial_gain_shape)

    def set_gains(self, K):
        """Store a private copy of ``K`` so later mutations by the caller do not leak in."""
        self.K = K.copy()

    def predict(self, state):
        """Return the control ``-K @ state``."""
        gain = self.K
        return -gain @ state

class NeuralController(nn.Module):
    """Multi-layer perceptron policy mapping a state vector to a control vector.

    Architecture: ``Linear(state_dim, hidden)`` + activation, then
    ``num_layers - 1`` further hidden Linear + activation pairs, then a linear
    output layer of size ``control_dim`` with no output nonlinearity.
    """

    def __init__(self, state_dim: int, control_dim: int, config: 'NeuralControllerConfig',
                 regime_name: Optional[str] = None):
        super().__init__()
        self.state_dim = state_dim
        self.control_dim = control_dim
        self.regime_name = regime_name
        self.config = config

        layers = [nn.Linear(state_dim, config.hidden_size), config.activation()]
        for _ in range(config.num_layers - 1):
            layers += [nn.Linear(config.hidden_size, config.hidden_size), config.activation()]
        layers.append(nn.Linear(config.hidden_size, control_dim))
        self.network = nn.Sequential(*layers)

        # Near-zero initial weights so the untrained policy outputs ~0 control.
        for layer in self.network:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight, gain=0.01)
                nn.init.zeros_(layer.bias)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the raw network output. NumPy input is converted to float32."""
        if isinstance(state, np.ndarray):
            state = torch.FloatTensor(state)
        return self.network(state)

    def predict(self, state: np.ndarray) -> np.ndarray:
        """Return the action as a NumPy array without tracking gradients."""
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state) if isinstance(state, np.ndarray) else state
            return self.network(state_tensor).numpy()

    def train_on_regime(self, env, regime_name: str, episodes: int = 100, lr: float = 0.001,
                        episode_length: int = 50):
        """Train with Adam, one optimizer step per episode on the summed episode cost.

        Args:
            env: Environment providing ``reset``, ``step``, ``compute_cost``,
                ``noise_std`` and ``state_dim``.
            regime_name: Regime to train on.
            episodes: Number of episodes (and optimizer steps).
            lr: Adam learning rate.
            episode_length: Maximum steps per training episode. An episode
                also stops early if ``env.step`` reports ``done``.

        Returns:
            List of per-episode summed costs (pre-update values).

        Raises:
            ValueError: If ``episode_length < 1``.

        NOTE(review): ``state_tensor`` is rebuilt from NumPy each step, so
        the state-cost term carries no gradient. Only the control-cost term
        (via the network output) is optimized, which drives actions toward
        zero rather than toward regulating the state.
        """
        if episode_length < 1:
            raise ValueError("episode_length must be at least 1")

        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        losses = []

        for _ in range(episodes):
            state_np = env.reset(regime_name)
            noise_sequence = [np.random.normal(0, env.noise_std, env.state_dim)
                              for _ in range(episode_length)]
            env.reset(regime_name, initial_state=state_np.copy(), noise_sequence=noise_sequence.copy())

            step_costs = []
            for _ in range(episode_length):
                state_tensor = torch.FloatTensor(state_np)
                action_tensor = self.forward(state_tensor)
                next_state_np, _, done, _ = env.step(action_tensor.detach().numpy())

                step_costs.append(env.compute_cost(regime_name, state_tensor, action_tensor))

                if done:
                    break
                state_np = next_state_np

            # Sequential sum (same order as step-by-step accumulation). There
            # is no ambiguous ``== 0`` sentinel on a tensor.
            episode_loss = sum(step_costs[1:], step_costs[0])

            optimizer.zero_grad()
            episode_loss.backward()
            optimizer.step()
            losses.append(episode_loss.item())

        return losses

# =============================================
# ENSEMBLE IMPLEMENTATIONS 
# =============================================

class BayesianEnsembleWeights:
    """Dirichlet-posterior ensemble weighting with sampled (Thompson-style) weights.

    Each update shifts the Dirichlet concentration ``alpha`` by
    ``learning_rate`` times a relative score. The best model scores ~0 and
    others score minus their excess cost. Fresh weights are then sampled
    from ``Dirichlet(alpha)``.
    """

    def __init__(self, n_models, alpha_prior=1.0, learning_rate=0.01):
        self.n_models = n_models
        self.alpha_prior = alpha_prior
        self.learning_rate = learning_rate
        self.weights = np.ones(n_models) / n_models
        self.alpha = np.ones(n_models) * alpha_prior

    def update_weights(self, individual_costs, ensemble_cost):
        """Update the posterior from per-model costs and return new sampled weights.

        Args:
            individual_costs: Array-like of per-model costs (lower is better).
            ensemble_cost: Unused; kept for interface compatibility with
                other weight learners.

        Returns:
            Non-negative weights summing to 1. If no cost is finite, the
            posterior and current weights are left unchanged. Non-finite
            costs receive the maximal penalty (alpha floored at 1e-9).
        """
        costs = np.asarray(individual_costs, dtype=float)
        finite = np.isfinite(costs)
        if not finite.any():
            # No usable signal this step.
            return self.weights

        best = np.min(costs[finite])
        rewards = (best - costs) + 1e-6
        rewards[~finite] = -np.inf

        self.alpha = np.maximum(self.alpha + rewards * self.learning_rate, 1e-9)

        sample = np.abs(np.random.dirichlet(self.alpha))
        total = np.sum(sample)
        if not np.isfinite(total) or total <= 0:
            # Sampling can degenerate for tiny concentrations; fall back to
            # the posterior mean.
            sample, total = self.alpha, np.sum(self.alpha)
        self.weights = sample / total
        return self.weights

class FixedUniformWeights:
    """Non-adaptive weighting: every model keeps weight ``1 / n_models``."""

    def __init__(self, n_models):
        self.n_models = n_models
        uniform = np.ones(n_models)
        self.weights = uniform / n_models

    def update_weights(self, individual_costs, ensemble_cost):
        """Ignore the feedback and return the (unchanged) uniform weights."""
        del individual_costs, ensemble_cost  # feedback is intentionally unused
        return self.weights

class EnsembleController:
    """Blend member controllers' actions using the weight learner's current weights."""

    def __init__(self, controllers, weight_learner):
        self.controllers = controllers
        self.weight_learner = weight_learner
        self.n_models = len(controllers)

    def predict(self, state):
        """Return ``(blended_action, member_actions)``.

        ``member_actions`` stacks each controller's prediction along axis 0.
        The blend is the weight-averaged sum over that axis.
        """
        member_actions = np.array([member.predict(state) for member in self.controllers])
        column_weights = self.weight_learner.weights[:, np.newaxis]
        blended = np.sum(column_weights * member_actions, axis=0)
        return blended, member_actions

class LinearizedLQREnsemble(EnsembleController):
    """Ensemble of per-regime LQR controllers whose gains are re-linearized every call."""

    def __init__(self, env, regime_names, fixed_weights=False):
        members = [LQRController(env.state_dim, env.control_dim, name) for name in regime_names]
        n_members = len(members)
        if fixed_weights:
            weight_learner = FixedUniformWeights(n_members)
        else:
            weight_learner = BayesianEnsembleWeights(n_members)
        super().__init__(members, weight_learner)
        self.env = env
        self.regime_names = regime_names

    def predict(self, state):
        """Refresh every member's gain from the LQR at ``state``, then blend actions."""
        for member, name in zip(self.controllers, self.regime_names):
            member.set_gains(self.env.get_optimal_policy(name, state))
        return super().predict(state)

class NeuralEnsemble(EnsembleController):
    """Ensemble of per-regime neural policies, each trained at construction time."""

    def __init__(self, env, regime_names, nn_config: NeuralControllerConfig, training_config: TrainingConfig, fixed_weights=False):
        members = [
            self._train_member(env, name, nn_config, training_config)
            for name in regime_names
        ]
        n_members = len(members)
        if fixed_weights:
            weight_learner = FixedUniformWeights(n_members)
        else:
            weight_learner = BayesianEnsembleWeights(n_members)
        super().__init__(members, weight_learner)

        self.env = env
        self.regime_names = regime_names
        self.nn_config = nn_config
        self.training_config = training_config

    @staticmethod
    def _train_member(env, regime_name, nn_config, training_config):
        """Create one NeuralController and train it on ``regime_name``."""
        member = NeuralController(env.state_dim, env.control_dim, config=nn_config, regime_name=regime_name)
        print(f"  Training Neural Controller for '{regime_name}'...")
        member.train_on_regime(env, regime_name, episodes=training_config.episodes, lr=training_config.lr)
        return member

# =============================================
# NONLINEAR EXPERIMENT FRAMEWORK
# =============================================

class NonlinearConvexityExperiment:
    """Paired comparison of linearized-LQR and neural ensembles on nonlinear regimes.

    Every episode runs the linearized LQR ensemble, the neural ensemble and
    a per-step linearized-LQR "oracle" from the same initial state with the
    same pre-drawn noise sequence. The per-episode "convexity violation" is
    ``(neural_cost - oracle_cost) - (linear_cost - oracle_cost)``.
    """

    def __init__(self, env, n_trials=10, n_seeds=3,
                 nn_config: Optional['NeuralControllerConfig'] = None,
                 training_config: Optional['TrainingConfig'] = None,
                 episode_length: int = 100, regime_repeats: int = 5):
        """
        Args:
            env: A RobustNonlinearEnvironment-like object.
            n_trials: Trials per seed (each trial retrains the neural ensemble).
            n_seeds: Number of NumPy/torch seeds (0 .. n_seeds-1).
            nn_config: Neural architecture (default NeuralControllerConfig()).
            training_config: Training settings (default TrainingConfig()).
            episode_length: Steps per evaluation episode. Episodes can still
                end earlier if ``env.step`` reports ``done``.
            regime_repeats: How many times the regime list is cycled per trial.
        """
        self.env = env
        self.n_trials = n_trials
        self.n_seeds = n_seeds
        self.nn_config = nn_config or NeuralControllerConfig()
        self.training_config = training_config or TrainingConfig()
        self.episode_length = episode_length
        self.regime_repeats = regime_repeats
        self.regime_names = list(env.regimes.keys())
        self.results = {}

    def _reset_episode(self, regime_name, initial_state, noise_sequence):
        """Reset the env identically so every controller sees the same start and noise."""
        self.env.reset(regime_name, initial_state=initial_state.copy(), noise_sequence=noise_sequence.copy())

    def run_single_trial(self, trial_idx):
        """Build fresh ensembles and evaluate them over the cycled regime sequence."""
        print(f"Running trial {trial_idx + 1}/{self.n_trials}")

        linear_ensemble = LinearizedLQREnsemble(self.env, self.regime_names, fixed_weights=False)
        neural_ensemble = NeuralEnsemble(self.env, self.regime_names,
                                         self.nn_config, self.training_config, fixed_weights=False)

        regime_sequence = self.regime_names * self.regime_repeats
        episode_length = self.episode_length

        trial_results = {
            'linear_costs': [], 'neural_costs': [], 'oracle_costs': [],
            'linear_trajectories': [], 'neural_trajectories': [], 'oracle_trajectories': [],
            'convexity_violations': []
        }

        for regime_name in regime_sequence:
            initial_state = np.random.normal(0, 0.5, self.env.state_dim)
            noise_sequence = [np.random.normal(0, self.env.noise_std, self.env.state_dim)
                              for _ in range(episode_length)]

            self._reset_episode(regime_name, initial_state, noise_sequence)
            linear_cost, linear_traj = self._run_episode(linear_ensemble, regime_name, episode_length)

            self._reset_episode(regime_name, initial_state, noise_sequence)
            neural_cost, neural_traj = self._run_episode(neural_ensemble, regime_name, episode_length)

            self._reset_episode(regime_name, initial_state, noise_sequence)
            oracle_cost, oracle_traj = self._run_oracle_episode(regime_name, episode_length)

            trial_results['linear_costs'].append(linear_cost)
            trial_results['neural_costs'].append(neural_cost)
            trial_results['oracle_costs'].append(oracle_cost)
            trial_results['linear_trajectories'].append(linear_traj)
            trial_results['neural_trajectories'].append(neural_traj)
            trial_results['oracle_trajectories'].append(oracle_traj)

            linear_gap = linear_cost - oracle_cost
            neural_gap = neural_cost - oracle_cost
            trial_results['convexity_violations'].append(neural_gap - linear_gap)

        return trial_results

    def _run_episode(self, ensemble, regime_name, episode_length):
        """Roll out ``ensemble`` and update its weights online. Returns (total cost, states)."""
        episode_cost = 0.0
        trajectory = []

        for _ in range(episode_length):
            current_state = self.env.state.copy()
            trajectory.append(current_state)

            action, individual_actions = ensemble.predict(current_state)
            individual_costs = np.array([self.env.compute_cost(regime_name, current_state, a)
                                         for a in individual_actions])

            _, _, done, info = self.env.step(action)
            episode_cost += info['cost']

            ensemble.weight_learner.update_weights(individual_costs, info['cost'])

            if done:
                break

        return episode_cost, np.array(trajectory)

    def _run_oracle_episode(self, regime_name, episode_length):
        """Roll out the per-step linearized LQR for the true regime. Returns (total cost, states)."""
        episode_cost = 0.0
        trajectory = []

        for _ in range(episode_length):
            current_state = self.env.state.copy()
            trajectory.append(current_state)

            K_opt = self.env.get_optimal_policy(regime_name, current_state)
            oracle_action = -K_opt @ current_state

            _, _, done, info = self.env.step(oracle_action)
            episode_cost += info['cost']

            if done:
                break

        return episode_cost, np.array(trajectory)

    def run_experiment(self):
        """Run all seeds x trials, aggregate, print a summary, and return the raw trial results."""
        print("="*60)
        print("NONLINEAR CONVEXITY VIOLATION EXPERIMENT")
        print(f"System: {type(list(self.env.regimes.values())[0].dynamics).__name__}")
        print("="*60)

        all_results = []

        for seed in range(self.n_seeds):
            print(f"\nSeed {seed + 1}/{self.n_seeds}")
            np.random.seed(seed)
            torch.manual_seed(seed)

            all_results.extend(self.run_single_trial(trial) for trial in range(self.n_trials))

        self.aggregate_results(all_results)
        return all_results

    def aggregate_results(self, all_results):
        """Pool per-episode costs across trials, compute gaps and a paired t-test.

        Raises:
            ValueError: If ``all_results`` is empty.
        """
        if not all_results:
            raise ValueError("aggregate_results() needs at least one trial result")

        def _pooled(key):
            return np.concatenate([r[key] for r in all_results])

        linear_costs = _pooled('linear_costs')
        neural_costs = _pooled('neural_costs')
        oracle_costs = _pooled('oracle_costs')
        convexity_violations = _pooled('convexity_violations')

        linear_mean = np.mean(linear_costs)
        neural_mean = np.mean(neural_costs)
        oracle_mean = np.mean(oracle_costs)

        # Paired test: episodes are matched by initial state and noise.
        if len(neural_costs) == len(linear_costs):
            t_stat, p_value = stats.ttest_rel(neural_costs, linear_costs)
        else:
            t_stat, p_value = np.nan, np.nan

        self.results = {
            'linear_mean_cost': linear_mean,
            'neural_mean_cost': neural_mean,
            'oracle_mean_cost': oracle_mean,
            'linear_optimality_gap': linear_mean - oracle_mean,
            'neural_optimality_gap': neural_mean - oracle_mean,
            'mean_convexity_violation': np.mean(convexity_violations),
            # Population std (ddof=0).
            'std_convexity_violation': np.std(convexity_violations),
            't_statistic': t_stat,
            'p_value': p_value,
            'all_results': all_results
        }

        self.print_summary()

    def print_summary(self):
        """Print the aggregated results (or a notice if nothing has been aggregated yet)."""
        if not self.results:
            print("No results to summarize; run run_experiment() first.")
            return

        print("\n" + "="*60)
        print("NONLINEAR SYSTEM RESULTS")
        print("="*60)

        print(f"Oracle Mean Cost: {self.results['oracle_mean_cost']:.4f}")
        print(f"Linearized LQR Mean Cost: {self.results['linear_mean_cost']:.4f}")
        print(f"Neural Ensemble Mean Cost: {self.results['neural_mean_cost']:.4f}")

        print(f"\nOptimality Gaps:")
        print(f"  Linearized LQR Gap: {self.results['linear_optimality_gap']:.4f}")
        print(f"  Neural Gap: {self.results['neural_optimality_gap']:.4f}")

        print(f"\nConvexity Violation: {self.results['mean_convexity_violation']:.4f} ± {self.results['std_convexity_violation']:.4f}")
        print(f"Statistical significance: p = {self.results.get('p_value', np.nan):.6f}")

# =============================================
# IMPROVED NEURAL CONTROLLER
# =============================================

class ImprovedNeuralController(nn.Module):
    """Feed-forward policy network producing bounded control actions.

    Architecture: ``max(config.num_layers, 1)`` hidden blocks of
    ``Linear -> LayerNorm -> config.activation() -> Dropout(0.1)``, followed by
    ``Linear -> Tanh``. The tanh output is multiplied by ``action_scale``, so
    actions lie in ``[-action_scale, action_scale]`` ([-2, 2] by default).

    Args:
        state_dim: Length of the state vector.
        control_dim: Length of the action vector.
        config: Object exposing ``hidden_size``, ``num_layers`` and
            ``activation`` (a zero-argument ``nn.Module`` factory such as
            ``nn.Tanh``), e.g. ``NeuralControllerConfig``.
        regime_name: Optional label stored on the instance; not used here.
        action_scale: Multiplier applied to the tanh output.
    """

    def __init__(self, state_dim: int, control_dim: int, config,
                 regime_name: Optional[str] = None, action_scale: float = 2.0):
        super().__init__()
        self.state_dim = state_dim
        self.control_dim = control_dim
        self.regime_name = regime_name
        self.config = config
        self.action_scale = action_scale

        layers = []
        in_features = state_dim
        # The first hidden block is always created, then num_layers - 1 more.
        for _ in range(max(config.num_layers, 1)):
            layers.append(nn.Linear(in_features, config.hidden_size))
            # LayerNorm instead of BatchNorm1d: rollouts evaluate one state at a
            # time in training mode, and BatchNorm1d raises on a batch of one.
            layers.append(nn.LayerNorm(config.hidden_size))
            layers.append(config.activation())
            layers.append(nn.Dropout(0.1))
            in_features = config.hidden_size

        layers.append(nn.Linear(in_features, control_dim))
        layers.append(nn.Tanh())  # Bounded output in [-1, 1] before scaling

        self.network = nn.Sequential(*layers)

        # Small initial weights (gain 0.1) and zero biases keep early actions modest.
        for layer in self.network:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight, gain=0.1)
                nn.init.zeros_(layer.bias)

    def forward(self, state) -> torch.Tensor:
        """Map a state ``(state_dim,)`` or batch ``(batch, state_dim)`` to scaled actions.

        Non-tensor input (numpy array, list) is converted to float32. A 1-D
        input yields a 1-D output of length ``control_dim``.
        """
        if not isinstance(state, torch.Tensor):
            state = torch.as_tensor(state, dtype=torch.float32)

        single = state.dim() == 1
        if single:
            state = state.unsqueeze(0)

        output = self.network(state)

        if single:
            output = output.squeeze(0)

        return output * self.action_scale

    def predict(self, state: np.ndarray) -> np.ndarray:
        """Return the deterministic action for ``state`` as a numpy array.

        Runs in eval mode (dropout disabled) without gradient tracking, and
        restores the module's previous train/eval mode afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                action = self.forward(state)
        finally:
            self.train(was_training)
        return action.numpy()

    def train_on_regime(self, env, regime_name: str, episodes: int = 200, lr: float = 0.0001,
                        episode_length: int = 25, exploration_std: float = 0.3):
        """Train the policy on one regime with REINFORCE (score-function gradients).

        The environment is a black box (``env.reset(regime_name)`` and
        ``env.step(action) -> (next_state, reward, done, info)`` with
        ``info['cost']``), so its costs carry no gradient. Each step therefore
        samples an action from a Gaussian centred on the network output, and
        the update minimises ``mean(advantage * log_prob)``, where the
        advantage is the normalised cost-to-go of each valid step.

        Args:
            env: Environment exposing ``reset``, ``step`` and ``control_limit``.
            regime_name: Regime name passed to ``env.reset``.
            episodes: Number of rollouts (each followed by at most one update).
            lr: Adam learning rate; StepLR multiplies it by 0.8 every 50 updates.
            episode_length: Maximum number of steps per rollout.
            exploration_std: Std of the Gaussian exploration noise; must be > 0.

        Returns:
            One loss value per episode. ``1e6`` marks episodes that collected
            too few valid steps (<= 5) to update on.

        Raises:
            ValueError: If ``exploration_std`` is not positive.
        """
        if exploration_std <= 0:
            raise ValueError("exploration_std must be positive")

        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.8)
        losses = []

        for _ in range(episodes):
            self.train()
            state_np = env.reset(regime_name)
            log_probs, costs = self._rollout(env, state_np, episode_length, exploration_std)

            if len(costs) <= 5:  # Too few samples for a meaningful update
                losses.append(1e6)
                continue

            total_loss = self._reinforce_loss(log_probs, costs)

            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            losses.append(total_loss.item())

        return losses

    def _rollout(self, env, state_np, episode_length: int, exploration_std: float):
        """Run one stochastic episode; return log-probs and costs of its valid steps."""
        log_probs = []
        costs = []

        for _ in range(episode_length):
            mean = self.forward(torch.as_tensor(state_np, dtype=torch.float32))
            dist = torch.distributions.Normal(mean, exploration_std)
            action = dist.sample()
            action_np = np.clip(action.numpy(), -env.control_limit, env.control_limit)

            try:
                next_state_np, reward, done, info = env.step(action_np)
                cost = info['cost']
                # Skip non-finite or exploding costs rather than training on them.
                if np.isfinite(cost) and cost < 1e6:
                    log_probs.append(dist.log_prob(action).sum())
                    costs.append(float(cost))
            except Exception as e:
                print(f"Training step failed: {e}")
                break

            state_np = next_state_np
            if done:
                break

        return log_probs, costs

    @staticmethod
    def _reinforce_loss(log_probs, costs) -> torch.Tensor:
        """Score-function surrogate: mean(normalised cost-to-go * log_prob)."""
        step_costs = torch.tensor(costs, dtype=torch.float32)
        # Cost-to-go G_t = sum_{k >= t} c_k, via a reversed cumulative sum.
        cost_to_go = torch.flip(torch.cumsum(torch.flip(step_costs, [0]), 0), [0])
        advantage = (cost_to_go - cost_to_go.mean()) / (cost_to_go.std() + 1e-8)
        # Minimising this lowers the probability of actions with above-average cost-to-go.
        return torch.mean(advantage * torch.stack(log_probs))

# =============================================
# IMPROVED VISUALIZATION WITH LOG SCALE
# =============================================

def create_robust_visualizations(experiments_results: dict, save_figs=True,
                                 filename: str = 'robust_nonlinear_analysis.png'):
    """Draw the multi-panel summary figure for all experiment results.

    Layout (3x3 grid):
      * row 0: log-scale cost comparison (2 cols) and relative performance loss;
      * row 1: statistical significance (2 cols) and neural/linear cost ratios;
      * row 2: one text summary per system, for at most 3 systems (one per column).

    Args:
        experiments_results: Mapping system name -> results dict; missing
            metrics are expected to be NaN or absent.
        save_figs: When True, save the figure to ``filename`` at 300 dpi.
        filename: Output path used when ``save_figs`` is True.

    Returns:
        The created matplotlib figure.
    """
    n_rows, n_cols = 3, 3
    fig = plt.figure(figsize=(20, 12))
    gs = fig.add_gridspec(n_rows, n_cols, hspace=0.4, wspace=0.3)

    # Row 0: absolute costs (log scale) and relative loss vs. linearized LQR.
    create_log_scale_comparison(fig.add_subplot(gs[0, :2]), experiments_results)
    create_relative_performance_robust(fig.add_subplot(gs[0, 2:]), experiments_results)

    # Row 1: significance of the paired test and neural/linear cost ratios.
    create_filtered_significance(fig.add_subplot(gs[1, :2]), experiments_results)
    create_cost_ratios(fig.add_subplot(gs[1, 2:]), experiments_results)

    # Row 2: one summary panel per system. Grid cells without a system are
    # never turned into axes, so they stay blank and need no cleanup.
    summarized = list(experiments_results.items())[:n_cols]
    for col, (system, results) in enumerate(summarized):
        create_system_summary(fig.add_subplot(gs[2, col]), system, results)

    names = [str(name) for name in experiments_results]
    if len(names) > 1:
        systems_label = ", ".join(names[:-1]) + " and " + names[-1]
    else:
        systems_label = "".join(names)
    title = 'Robust Nonlinear Systems Analysis'
    if systems_label:
        title += f' ({systems_label})'
    fig.suptitle(title, fontsize=18, fontweight='bold')

    if save_figs:
        fig.savefig(filename, dpi=300, bbox_inches='tight')

    plt.show()
    return fig

def create_log_scale_comparison(ax, experiments_results):
    """Draw grouped bars of oracle, linearized-LQR and neural mean costs on a log y-axis.

    A missing or NaN cost is drawn at 1e-3, and any cost below 1e-3 is raised
    to 1e-3, so the log axis never receives a zero, negative or NaN height.
    """
    floor = 1e-3

    def bar_height(results, key):
        value = results.get(key, np.nan)
        if np.isnan(value):
            return floor
        return max(value, floor)

    systems = list(experiments_results.keys())
    per_series = {
        key: [bar_height(experiments_results[name], key) for name in systems]
        for key in ('oracle_mean_cost', 'linear_mean_cost', 'neural_mean_cost')
    }

    positions = np.arange(len(systems))
    bar_width = 0.25

    ax.bar(positions - bar_width, per_series['oracle_mean_cost'], bar_width,
           label='Oracle', alpha=0.8)
    ax.bar(positions, per_series['linear_mean_cost'], bar_width,
           label='Linearized LQR', alpha=0.8)
    ax.bar(positions + bar_width, per_series['neural_mean_cost'], bar_width,
           label='Neural Ensemble', alpha=0.8)

    ax.set_yscale('log')
    ax.set_xlabel('System Type')
    ax.set_ylabel('Mean Episode Cost (Log Scale)')
    ax.set_title('Performance Comparison (Log Scale)')
    ax.set_xticks(positions)
    ax.set_xticklabels(systems, rotation=45)
    ax.legend()
    ax.grid(True, alpha=0.3)

def create_relative_performance_robust(ax, experiments_results):
    """Bar chart of how much larger the neural optimality gap is than the linear one.

    For each system the value is ``(neural_gap - linear_gap) / linear_gap * 100``,
    where ``gap = mean cost - oracle mean cost``. A system is included only when
    none of its three costs is NaN, its linear cost is positive, and its linear
    gap is positive. Values above 1000% are capped at 1000 (there is no lower
    cap). With no eligible system a "No valid data" placeholder is drawn.
    """
    title = 'Relative Performance Loss (Capped at 1000%)'
    labels = []
    gaps = []

    for name, res in experiments_results.items():
        lin = res.get('linear_mean_cost', np.nan)
        neu = res.get('neural_mean_cost', np.nan)
        ora = res.get('oracle_mean_cost', np.nan)

        if np.isnan(lin) or lin <= 0 or np.isnan(neu) or np.isnan(ora):
            continue

        neural_excess = neu - ora
        linear_excess = lin - ora
        if not linear_excess > 0:
            continue

        pct = (neural_excess - linear_excess) / linear_excess * 100
        labels.append(name)
        gaps.append(min(pct, 1000))  # Cap at 1000%

    if not labels:
        ax.text(0.5, 0.5, "No valid data for plot", horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
        ax.set_title(title)
        ax.axis('off')
        return

    bars = ax.bar(labels, gaps, color='orange', alpha=0.7)
    ax.set_ylabel('Relative Performance Loss (%)')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

    # Label each bar with its (capped) percentage just above the bar top.
    for bar, gap in zip(bars, gaps):
        x_center = bar.get_x() + bar.get_width() / 2
        ax.annotate(f'{gap:.0f}%',
                    xy=(x_center, bar.get_height()),
                    xytext=(0, 3),
                    textcoords="offset points",
                    ha='center', va='bottom')


def create_filtered_significance(ax, experiments_results):
    """Plot -log10(p-value) per system with a dashed p=0.05 reference line.

    Systems whose ``p_value`` is missing or NaN are omitted; p-values below
    1e-16 are clamped to 1e-16 before taking the log. With no remaining
    system a "No valid data" placeholder is drawn instead.
    """
    title = 'Statistical Significance'
    candidates = [(name, res.get('p_value', np.nan)) for name, res in experiments_results.items()]
    kept = [(name, p) for name, p in candidates if not np.isnan(p)]

    if not kept:
        ax.text(0.5, 0.5, "No valid data for plot", horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
        ax.set_title(title)
        ax.axis('off')
        return

    names = [name for name, _ in kept]
    scores = [-np.log10(max(p, 1e-16)) for _, p in kept]  # Cap very small p-values

    ax.bar(names, scores, color='purple', alpha=0.7)
    ax.axhline(y=-np.log10(0.05), color='red', linestyle='--', label='p=0.05 threshold')
    ax.set_ylabel('-log10(p-value)')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


def create_cost_ratios(ax, experiments_results):
    """Plot neural/linear mean-cost ratios on a log axis, capped at 100x.

    A system is included only when its linear cost is non-NaN and positive
    and its neural cost is non-NaN. A dashed line at 1 marks equal
    performance. With no eligible system a "No valid data" placeholder is drawn.
    """
    title = 'Cost Ratios (Neural vs Linear)'
    names = []
    capped = []

    for name, res in experiments_results.items():
        lin = res.get('linear_mean_cost', np.nan)
        neu = res.get('neural_mean_cost', np.nan)
        usable = not np.isnan(lin) and lin > 0 and not np.isnan(neu)
        if usable:
            names.append(name)
            capped.append(min(neu / lin, 100))  # Cap at 100x

    if not names:
        ax.text(0.5, 0.5, "No valid data for plot", horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
        ax.set_title(title)
        ax.axis('off')
        return

    ax.bar(names, capped, color='green', alpha=0.7)
    ax.axhline(y=1, color='red', linestyle='--', label='Equal Performance')
    ax.set_ylabel('Neural Cost / Linear Cost Ratio')
    ax.set_title(title)
    ax.set_yscale('log')
    ax.legend()
    ax.grid(True, alpha=0.3)


def create_system_summary(ax, system_name, results):
    """Render a text-only panel with one system's mean costs and convexity violation.

    The panel shows the system name in bold, then oracle, linear and neural mean
    costs and the mean convexity violation to one decimal place. Missing or NaN
    values appear as "N/A"; a missing linear cost appears as a red "FAILED" and
    a missing violation as a red "N/A".
    """
    ax.text(0.1, 0.8, f"{system_name}", fontsize=14, fontweight='bold', transform=ax.transAxes)

    # (y position, label, value, text when NaN, colour when NaN)
    rows = (
        (0.6, "Oracle", results.get('oracle_mean_cost', np.nan), "N/A", None),
        (0.5, "Linear", results.get('linear_mean_cost', np.nan), "FAILED", 'red'),
        (0.4, "Neural", results.get('neural_mean_cost', np.nan), "N/A", None),
        (0.3, "Violation", results.get('mean_convexity_violation', np.nan), "N/A", 'red'),
    )

    for y, label, value, missing_text, missing_color in rows:
        if not np.isnan(value):
            ax.text(0.1, y, f"{label}: {value:.1f}", transform=ax.transAxes)
        elif missing_color is None:
            ax.text(0.1, y, f"{label}: {missing_text}", transform=ax.transAxes)
        else:
            ax.text(0.1, y, f"{label}: {missing_text}", color=missing_color, transform=ax.transAxes)

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis('off')

# =============================================
# IMPROVED MAIN FUNCTION
# =============================================

def _print_banner(title: str) -> None:
    """Print ``title`` framed by 60-character '=' rules, preceded by a blank line."""
    print(f"\n{'=' * 60}")
    print(title)
    print(f"{'=' * 60}")


def _default_systems_configs() -> dict:
    """Build the default experiment grid: system -> regime -> {'dynamics', 'Q', 'R'}.

    Every regime gets its own dynamics instance. CartPole
    (RobustCartPoleDynamics) is intentionally not part of the default grid.
    """
    def pendulum():
        return RobustPendulumDynamics(mass=1.0, length=1.0, damping=0.5)

    def van_der_pol():
        return RobustVanDerPolDynamics(mu=0.5)

    return {
        'Pendulum': {
            'tracking': {'dynamics': pendulum(), 'Q': np.diag([10, 1]), 'R': np.array([[1]])},
            'regulation': {'dynamics': pendulum(), 'Q': np.diag([1, 1]), 'R': np.array([[5]])},
            'stabilization': {'dynamics': pendulum(), 'Q': np.diag([5, 2]), 'R': np.array([[2]])},
        },
        'VanDerPol': {
            'tracking': {'dynamics': van_der_pol(), 'Q': np.diag([5, 1]), 'R': np.array([[1]])},
            'regulation': {'dynamics': van_der_pol(), 'Q': np.diag([1, 1]), 'R': np.array([[5]])},
            'stabilization': {'dynamics': van_der_pol(), 'Q': np.diag([2, 1]), 'R': np.array([[2]])},
        },
    }


def _run_system_experiment(regimes_config: dict) -> dict:
    """Build one system's environment, smoke-test it, run the experiment, return its results."""
    env = RobustNonlinearEnvironment(regimes_config, noise_std=0.01, cost_scaling=0.1)

    # Smoke test: a single zero-action step so a broken environment fails
    # before any training starts.
    env.reset(list(regimes_config.keys())[0])
    env.step(np.zeros(env.control_dim))

    experiment = NonlinearConvexityExperiment(
        env=env,
        n_trials=3,  # Small trial/seed counts keep a run short
        n_seeds=2,
        nn_config=NeuralControllerConfig(hidden_size=32, num_layers=2, activation=nn.Tanh),
        training_config=TrainingConfig(episodes=100, lr=0.0001),
    )
    experiment.run_experiment()
    return experiment.results


def _failed_system_results() -> dict:
    """All-NaN placeholder results recorded for a system whose experiment raised."""
    keys = ('oracle_mean_cost', 'linear_mean_cost', 'neural_mean_cost',
            'mean_convexity_violation', 'std_convexity_violation', 'p_value')
    return {key: np.nan for key in keys}


def _summary_table(experiments_results: dict) -> pd.DataFrame:
    """Tabulate per-system metrics as display strings ("FAILED"/"N/A" for NaN)."""
    def fmt(value, spec, missing):
        return format(value, spec) if not np.isnan(value) else missing

    rows = []
    for system, results in experiments_results.items():
        rows.append({
            'System': system,
            'Oracle Cost': fmt(results.get('oracle_mean_cost', np.nan), '.1f', "FAILED"),
            'Linear Cost': fmt(results.get('linear_mean_cost', np.nan), '.1f', "FAILED"),
            'Neural Cost': fmt(results.get('neural_mean_cost', np.nan), '.1f', "FAILED"),
            'Violation': fmt(results.get('mean_convexity_violation', np.nan), '.1f', "N/A"),
            'P-value': fmt(results.get('p_value', np.nan), '.2e', "N/A"),
        })
    return pd.DataFrame(rows)


def run_robust_nonlinear_comparison(systems_configs: Optional[dict] = None,
                                    save_figs: bool = True):
    """Compare oracle, linearized-LQR and neural controllers on nonlinear systems.

    By default two 2-D systems are tested, each under three cost regimes
    ('tracking', 'regulation', 'stabilization'):
      1. Pendulum: trigonometric dynamics with gravity and damping.
      2. VanDerPol: limit-cycle oscillator (mu=0.5).

    A system whose setup or experiment raises is reported and recorded with
    all-NaN placeholder results, so the remaining systems still run.

    Args:
        systems_configs: Optional mapping system -> regime -> {'dynamics', 'Q',
            'R'} overriding the default grid.
        save_figs: Forwarded to ``create_robust_visualizations``.

    Returns:
        Dict mapping system name -> results dict.
    """
    if systems_configs is None:
        systems_configs = _default_systems_configs()

    experiments_results = {}
    for system_name, regimes_config in systems_configs.items():
        _print_banner(f"Running experiments for {system_name}")
        try:
            experiments_results[system_name] = _run_system_experiment(regimes_config)
        except Exception as e:  # Top-level boundary: keep going with other systems
            print(f"System {system_name} failed: {e}")
            experiments_results[system_name] = _failed_system_results()

    _print_banner("Creating robust visualizations...")
    create_robust_visualizations(experiments_results, save_figs=save_figs)

    _print_banner("ROBUST SUMMARY TABLE")
    print(_summary_table(experiments_results).to_string(index=False))

    return experiments_results

# Script entry point: run the full comparison when this file is executed
# directly; the returned dict (system name -> results) is bound to the
# module-level name `results`.
if __name__ == "__main__":
    results = run_robust_nonlinear_comparison()