#!/usr/bin/env python3
"""
Experimental Framework: Convexity Violation Analysis in Ensemble Control
Primary Research Question: How much performance is lost when using neural network
ensembles compared to the theoretical bounds guaranteed by LQR ensembles?
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.linalg import solve_discrete_are
from sklearn.preprocessing import StandardScaler
from collections import defaultdict, deque
import warnings
import copy
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
# Silence all Python warnings so the experiment's printed log stays readable.
warnings.filterwarnings('ignore')

# Seed both global RNGs (PyTorch and NumPy) so runs are reproducible.
torch.manual_seed(42)
np.random.seed(42)

# =============================================
# MULTI-REGIME LQR ENVIRONMENT
# =============================================

@dataclass
class LQRRegime:
    """One discrete-time LQR regime ``x' = A x + B u`` with quadratic cost.

    On construction the infinite-horizon optimal gain ``K_optimal`` is computed
    from the discrete algebraic Riccati equation (DARE), so the optimal control
    is ``u = -K_optimal @ x``. If the DARE cannot be solved, a warning is
    printed and a small fallback gain (0.1 on the leading diagonal) is used.
    Any ``K_optimal`` passed to the constructor is overwritten.
    """
    name: str
    A: np.ndarray  # State transition matrix, (n, n)
    B: np.ndarray  # Control input matrix, (n, m)
    Q: np.ndarray  # State cost matrix, (n, n)
    R: np.ndarray  # Control cost matrix, (m, m)
    K_optimal: Optional[np.ndarray] = None  # Optimal LQR gain, (m, n); computed in __post_init__

    def __post_init__(self):
        """Warn if ``A`` has an eigenvalue on/outside the unit circle, then compute ``K_optimal``."""
        eigenvalues = np.linalg.eigvals(self.A)
        if np.any(np.abs(eigenvalues) >= 1):
            print(f"Warning: Regime '{self.name}' A matrix has eigenvalues >= 1: {eigenvalues}. System may be unstable.")

        # Q and R are kept unscaled here; any cost scaling is applied by the caller.
        try:
            P = solve_discrete_are(self.A, self.B, self.Q, self.R)
            # K = (R + B^T P B)^{-1} B^T P A, via a linear solve instead of an explicit inverse
            # for better numerical accuracy.
            self.K_optimal = np.linalg.solve(self.R + self.B.T @ P @ self.B, self.B.T @ P @ self.A)
        except (np.linalg.LinAlgError, ValueError) as e:
            print(f"Warning: Failed to compute optimal LQR for {self.name}: {e}")
            # Fallback: small identity-like gain of shape (m, n).
            self.K_optimal = np.eye(self.B.shape[1], self.A.shape[0]) * 0.1

class MultiRegimeLQREnvironment:
    """Linear-Gaussian environment that can be reset into one of several LQR regimes.

    In the active regime the dynamics are ``x_{t+1} = A x_t + B u_t + w_t``.
    ``u_t`` is clipped elementwise to ``[-control_limit, control_limit]``.
    ``w_t`` comes from a scripted noise sequence while one is available,
    otherwise from ``N(0, noise_std^2 I)``.

    The stage cost is the standard LQR cost ``x_t^T Q x_t + u_t^T R u_t``,
    evaluated on the state *before* the transition, with ``Q`` and ``R``
    multiplied by ``cost_scaling``. This is the same cost that
    ``compute_cost`` returns and that ``K_optimal`` minimises.

    Note: the regimes built by ``_create_regimes`` are fixed 4-state / 2-input
    systems, independent of ``state_dim`` / ``control_dim``.
    """

    def __init__(self, state_dim=4, control_dim=2, noise_std=0.1, control_limit=10.0, cost_scaling=1.0,
                 max_steps=200):
        self.state_dim = state_dim
        self.control_dim = control_dim
        self.noise_std = noise_std
        self.control_limit = control_limit  # Control saturation limit
        self.cost_scaling = cost_scaling  # Scaling factor for Q and R
        self.max_steps = max_steps  # Episode length: `done` is True once time >= max_steps

        # Define distinct regimes with different objectives
        self.regimes = self._create_regimes()
        self.current_regime = None
        self.state = None
        self.time = 0
        self._noise_sequence = None
        self._current_noise_step = 0

    def _create_regimes(self):
        """Create the 'tracking', 'regulation' and 'stabilization' regimes."""
        regimes = {}

        # Regime 1: Tracking (high state penalty, cheap control)
        A1 = np.array([[0.95, 0.05, 0, 0],
                       [0, 0.98, 0.05, 0],
                       [0, 0, 0.90, 0.05],
                       [0, 0, 0, 0.92]])
        B1 = np.array([[1, 0], [0, 1], [0.5, 0], [0, 0.5]])
        Q1 = np.diag([100, 100, 10, 10])
        R1 = np.diag([1, 1])
        regimes['tracking'] = LQRRegime('tracking', A1, B1, Q1, R1, None)

        # Regime 2: Regulation (low state penalty, expensive control)
        A2 = np.array([[0.92, 0, 0.05, 0],
                       [0.05, 0.90, 0, 0.05],
                       [0, 0, 0.95, 0],
                       [0, 0, 0.05, 0.98]])
        B2 = np.array([[0.8, 0.2], [0.2, 0.8], [1, 0], [0, 1]])
        Q2 = np.diag([1, 1, 1, 1])
        R2 = np.diag([50, 50])
        regimes['regulation'] = LQRRegime('regulation', A2, B2, Q2, R2, None)

        # Regime 3: Stabilization (moderate state and control penalties)
        A3 = np.array([[0.90, 0.02, 0, 0.02],
                       [0, 0.95, 0.02, 0],
                       [0.02, 0, 0.92, 0],
                       [0, 0.02, 0, 0.96]])
        B3 = np.array([[1, 0.1], [0.1, 1], [0.8, 0.2], [0.2, 0.8]])
        Q3 = np.diag([25, 25, 25, 25])
        R3 = np.diag([10, 10])
        regimes['stabilization'] = LQRRegime('stabilization', A3, B3, Q3, R3, None)

        return regimes

    def reset(self, regime_name='tracking', initial_state=None, noise_sequence=None):
        """Switch to ``regime_name`` and reset time, state and noise script.

        Args:
            regime_name: key into ``self.regimes``.
            initial_state: starting state (copied); drawn from N(0, I) if None.
            noise_sequence: optional sequence of noise vectors consumed one per
                step before falling back to random noise.

        Returns:
            A copy of the initial state.
        """
        self.current_regime = self.regimes[regime_name]
        if initial_state is None:
            self.state = np.random.normal(0, 1, self.state_dim)
        else:
            # Copy so the caller's array is never aliased by the environment.
            self.state = np.array(initial_state, dtype=float)
        self.time = 0
        self._noise_sequence = noise_sequence
        self._current_noise_step = 0
        return self.state.copy()

    def _next_noise(self):
        """Return the next scripted noise vector, or a fresh Gaussian draw once exhausted."""
        if self._noise_sequence is not None and self._current_noise_step < len(self._noise_sequence):
            noise = self._noise_sequence[self._current_noise_step]
            self._current_noise_step += 1
            return noise
        return np.random.normal(0, self.noise_std, self.state_dim)

    def _stage_cost(self, regime, state, action):
        """Return ``(x^T Q x, u^T R u)`` for ``regime`` with ``cost_scaling`` applied."""
        Q = regime.Q * self.cost_scaling
        R = regime.R * self.cost_scaling
        return state @ Q @ state, action @ R @ action

    def step(self, action):
        """Apply one (saturated) control step.

        The cost is charged on the pre-transition state and the applied
        (clipped) action, matching ``compute_cost``.

        Returns:
            ``(next_state, reward, done, info)`` where ``reward = -cost`` and
            ``info`` has keys 'regime', 'cost', 'state_cost', 'control_cost'.

        Raises:
            ValueError: if ``reset`` has not been called.
        """
        if self.current_regime is None:
            raise ValueError("Environment not reset")

        regime = self.current_regime
        action = np.clip(np.asarray(action, dtype=float), -self.control_limit, self.control_limit)

        state_cost, control_cost = self._stage_cost(regime, self.state, action)
        cost = state_cost + control_cost

        self.state = regime.A @ self.state + regime.B @ action + self._next_noise()
        self.time += 1
        done = self.time >= self.max_steps

        info = {
            'regime': regime.name,
            'cost': cost,
            'state_cost': state_cost,
            'control_cost': control_cost,
        }
        return self.state.copy(), -cost, done, info

    def get_optimal_policy(self, regime_name):
        """Return the optimal LQR gain ``K`` (control is ``-K @ x``) for ``regime_name``."""
        return self.regimes[regime_name].K_optimal

    def compute_cost(self, regime_name, state, action):
        """Stage cost ``x^T Q x + u^T R u`` in ``regime_name`` (no action clipping)."""
        state_cost, control_cost = self._stage_cost(self.regimes[regime_name], state, action)
        return state_cost + control_cost

    def compute_optimal_cost(self, regime_name, state):
        """Stage cost at ``state`` of the optimal action ``-K_optimal @ state``."""
        K_opt = self.get_optimal_policy(regime_name)
        optimal_action = -K_opt @ state
        return self.compute_cost(regime_name, state, optimal_action)


# =============================================
# CONTROLLER IMPLEMENTATIONS
# =============================================

class LQRController:
    """Linear state-feedback controller ``u = -K @ x``.

    ``K`` has shape ``(control_dim, state_dim)`` and starts with small random
    entries drawn from NumPy's global RNG. Use ``set_gains`` to install a known
    gain such as the DARE-optimal one.
    """

    def __init__(self, state_dim, control_dim, regime_name=None):
        self.state_dim = state_dim
        self.control_dim = control_dim
        self.regime_name = regime_name
        self.K = np.random.normal(0, 0.1, (control_dim, state_dim))
        self.parameter_count = control_dim * state_dim

    def set_gains(self, K):
        """Install a private copy of gain matrix ``K``."""
        self.K = np.array(K, dtype=float)

    def predict(self, state):
        """Return the control action ``-K @ state``."""
        return -self.K @ state

    def get_parameters(self):
        """Return a flattened copy of ``K``."""
        return self.K.flatten()

    def set_parameters(self, params):
        """Set ``K`` from a flat vector of ``control_dim * state_dim`` values.

        The values are copied, so later mutation of ``params`` does not affect
        the controller (consistent with ``set_gains``).

        Raises:
            ValueError: if ``params`` has the wrong number of elements.
        """
        self.K = np.array(params, dtype=float).reshape(self.control_dim, self.state_dim)

class NeuralController(nn.Module):
    """One-hidden-layer ``tanh`` MLP controller sized to roughly match an LQR gain.

    The hidden width ``h`` is the largest value such that the parameter count
    ``state_dim*h + h + h*control_dim + control_dim`` does not exceed the
    ``control_dim * state_dim`` entries of a linear gain. ``h`` has a floor of
    1, so very small problems can still exceed that count. ``parameter_count``
    records the actual number.
    """

    def __init__(self, state_dim, control_dim, regime_name=None):
        super().__init__()
        self.state_dim = state_dim
        self.control_dim = control_dim
        self.regime_name = regime_name

        lqr_params = control_dim * state_dim
        # Solve state_dim*h + h + h*control_dim + control_dim <= lqr_params for h.
        self.hidden_size = max(1, (lqr_params - control_dim) // (state_dim + control_dim + 1))

        self.network = nn.Sequential(
            nn.Linear(state_dim, self.hidden_size),
            nn.Tanh(),
            nn.Linear(self.hidden_size, control_dim),
        )

        # Small initial weights keep the initial control (and closed loop) close to open loop.
        for layer in self.network:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight, gain=0.1)
                nn.init.zeros_(layer.bias)

        self.parameter_count = sum(p.numel() for p in self.parameters())

    def forward(self, state):
        """Map a state (NumPy array or tensor) to a control tensor."""
        if isinstance(state, np.ndarray):
            state = torch.as_tensor(state, dtype=torch.float32)
        return self.network(state)

    def predict(self, state):
        """Return the control for ``state`` as a NumPy array, without gradient tracking."""
        with torch.no_grad():
            return self.forward(state).numpy()

    def train_on_regime(self, env, regime_name, episodes=100, lr=0.001, horizon=50):
        """Train on one regime by backpropagating through its linear dynamics.

        Each episode works as follows:
        - It starts from ``env.reset(regime_name)``.
        - It rolls out ``horizon`` steps of ``x' = A x + B clip(u) + w`` in
          PyTorch, using the regime's ``A``, ``B``, ``Q`` and ``R``, with
          ``Q``/``R`` scaled by ``env.cost_scaling``.
        - The noise ``w`` is Gaussian with std ``env.noise_std``, drawn from
          NumPy's global RNG.
        - The loss is the summed stage cost ``x^T Q x + u^T R u``.
        - One Adam step is taken per episode.

        Args:
            env: environment exposing ``regimes[regime_name]`` (with ``A``,
                ``B``, ``Q``, ``R``), ``reset``, ``noise_std``,
                ``control_limit`` and ``cost_scaling``.
            regime_name: key of the regime to train on.
            episodes: number of training episodes (one optimizer step each).
            lr: Adam learning rate.
            horizon: rollout length per episode.

        Returns:
            List of per-episode losses as Python floats. Episodes whose loss is
            not finite are recorded but skip the optimizer step.
        """
        def to_tensor(value):
            return torch.as_tensor(np.asarray(value, dtype=float), dtype=torch.float32)

        regime = env.regimes[regime_name]
        A = to_tensor(regime.A)
        B = to_tensor(regime.B)
        Q = to_tensor(regime.Q) * env.cost_scaling
        R = to_tensor(regime.R) * env.cost_scaling
        limit = float(env.control_limit)

        optimizer = optim.Adam(self.parameters(), lr=lr)
        losses = []

        for _ in range(episodes):
            # The loss must be built from network outputs with autograd intact:
            # costs returned by the NumPy environment carry no gradient back to
            # the network parameters, so stepping `env` cannot train the net.
            state = to_tensor(env.reset(regime_name))
            episode_loss = torch.zeros(())
            for _ in range(horizon):
                action = torch.clamp(self.forward(state), -limit, limit)
                episode_loss = episode_loss + state @ Q @ state + action @ R @ action
                noise = to_tensor(np.random.normal(0, env.noise_std, self.state_dim))
                state = A @ state + B @ action + noise

            losses.append(float(episode_loss.item()))
            if not torch.isfinite(episode_loss):
                continue  # a diverged rollout would poison the parameters with NaN/inf

            optimizer.zero_grad()
            episode_loss.backward()
            optimizer.step()

        return losses

# =============================================
# ENSEMBLE WEIGHT LEARNING
# =============================================

class BayesianEnsembleWeights:
    """Dirichlet pseudo-count weight learner for ensemble control.

    Each model ``i`` carries a Dirichlet concentration ``alpha[i]`` starting at
    ``alpha_prior``. Every call to ``update_weights`` does three things:
    - converts the per-model costs into pseudo-rewards in ``[0, 1]`` (1 for the
      cheapest model, 0 for the most expensive);
    - adds ``learning_rate * reward`` to ``alpha``;
    - samples fresh weights from ``Dirichlet(alpha)``.

    With a non-negative ``learning_rate`` the rewards never shrink ``alpha``,
    which keeps the Dirichlet well-conditioned.
    """

    def __init__(self, n_models, alpha_prior=1.0, learning_rate=0.01):
        self.n_models = n_models
        self.alpha_prior = alpha_prior
        self.learning_rate = learning_rate

        # Start from uniform weights.
        self.weights = np.ones(n_models) / n_models
        self.performance_history = []
        self.weight_history = []

        # Dirichlet concentration parameters (pseudo-counts).
        self.alpha = np.ones(n_models) * alpha_prior

    @staticmethod
    def _pseudo_rewards(costs):
        """Map costs to rewards in ``[0, 1]``; lower cost gives higher reward.

        Non-finite costs get reward 0. If all finite costs are equal, each of
        those models gets reward 1.
        """
        rewards = np.zeros_like(costs)
        finite = np.isfinite(costs)
        if not finite.any():
            return rewards
        lo = costs[finite].min()
        hi = costs[finite].max()
        if hi > lo:
            rewards[finite] = (hi - costs[finite]) / (hi - lo)
        else:
            rewards[finite] = 1.0
        return rewards

    def update_weights(self, individual_costs, ensemble_cost):
        """Record the observation, update ``alpha`` and resample the weights.

        Args:
            individual_costs: per-model costs (lower is better), length ``n_models``.
            ensemble_cost: cost of the blended action (only recorded in history).

        Returns:
            The new weight vector (non-negative, sums to 1).
        """
        costs = np.asarray(individual_costs, dtype=float)
        self.performance_history.append({
            'individual_costs': costs.copy(),
            'ensemble_cost': ensemble_cost,
            'weights': self.weights.copy(),
        })

        # Non-negative rewards: the best model gains the most pseudo-count and
        # no model's concentration is driven towards zero.
        self.alpha = self.alpha + self.learning_rate * self._pseudo_rewards(costs)
        # np.random.dirichlet requires strictly positive concentrations.
        self.alpha = np.maximum(self.alpha, 1e-9)

        weights = np.random.dirichlet(self.alpha)
        if not np.all(np.isfinite(weights)) or weights.sum() <= 0:
            # Tiny concentrations can underflow to 0/0; fall back to the posterior mean.
            weights = self.alpha.copy()
        self.weights = weights / weights.sum()

        self.weight_history.append(self.weights.copy())
        return self.weights

    def get_weights(self):
        """Return a copy of the current ensemble weights."""
        return self.weights.copy()

    def get_weights(self):
        """Get current ensemble weights"""
        return self.weights.copy()

class FixedUniformWeights:
    """Debugging weight learner whose weights are always uniform.

    Exposes ``update_weights`` / ``get_weights`` like the adaptive learner but
    never changes the weights. Each update only appends another copy of them
    to ``weight_history``.
    """

    def __init__(self, n_models):
        uniform = np.ones(n_models) / n_models
        self.n_models = n_models
        self.weights = uniform
        self.weight_history = [uniform.copy()]

    def update_weights(self, individual_costs, ensemble_cost):
        """Ignore the costs; log and return the unchanged uniform weights."""
        self.weight_history.append(self.get_weights())
        return self.weights

    def get_weights(self):
        """Return a copy of the (uniform) weight vector."""
        return self.weights.copy()


# =============================================
# ENSEMBLE CONTROLLERS
# =============================================

class EnsembleController:
    """Base class: blends the actions of several controllers with learned weights.

    Each controller must provide ``predict(state)``. The weight learner must
    provide ``get_weights()`` returning one weight per controller. The action
    is the weighted sum ``sum_i w_i * u_i``.
    """

    def __init__(self, controllers, weight_learner):
        self.controllers = controllers
        self.weight_learner = weight_learner
        self.n_models = len(controllers)

        # Performance-tracking containers; this class itself never writes to them.
        self.individual_costs = []
        self.ensemble_costs = []
        self.weights_over_time = []

    def predict(self, state):
        """Return ``(ensemble_action, individual_actions)`` for ``state``.

        ``individual_actions`` stacks each controller's action along axis 0.
        ``ensemble_action`` is their weighted sum under the current weights.
        """
        individual_actions = np.array([controller.predict(state) for controller in self.controllers])
        weights = self.weight_learner.get_weights()
        ensemble_action = weights @ individual_actions
        return ensemble_action, individual_actions

    def update(self, state, ensemble_action, individual_actions, next_state, cost):
        """No-op hook kept for API compatibility.

        The ensemble has no cost model of its own. Callers compute per-model
        costs themselves and call ``self.weight_learner.update_weights(...)``
        directly.
        """
        return None


class LQREnsemble(EnsembleController):
    """Ensemble of per-regime optimal LQR controllers.

    One ``LQRController`` per regime name gets the environment's optimal gain.
    Weights are uniform and fixed when ``fixed_weights`` is set, otherwise
    learned by ``BayesianEnsembleWeights``.
    """

    def __init__(self, env, regime_names, fixed_weights=False):
        def build(name):
            # Random init inside LQRController is immediately replaced by the optimal gain.
            ctrl = LQRController(env.state_dim, env.control_dim, name)
            ctrl.set_gains(env.get_optimal_policy(name))
            return ctrl

        members = [build(name) for name in regime_names]
        count = len(members)
        learner = FixedUniformWeights(count) if fixed_weights else BayesianEnsembleWeights(count)

        super().__init__(members, learner)

        self.env = env
        self.regime_names = regime_names

    def compute_theoretical_bound(self, state, regime_name):
        """Return ``sum_i w_i * cost(state, u_i)`` for the current weights.

        The stage cost is convex in the action, so this weighted sum upper-bounds
        the cost of the blended action at ``state`` (Jensen's inequality).
        """
        member_costs = [
            self.env.compute_cost(regime_name, state, member.predict(state))
            for member in self.controllers
        ]
        return np.sum(self.weight_learner.get_weights() * member_costs)

class NeuralEnsemble(EnsembleController):
    """Ensemble of per-regime neural controllers (no convexity guarantees).

    One ``NeuralController`` is created and trained on each regime in
    ``regime_names``. The per-episode training losses returned by
    ``train_on_regime`` are kept in ``self.training_losses``, keyed by regime
    name, for later inspection.
    """

    def __init__(self, env, regime_names, training_episodes=100, fixed_weights=False):
        controllers = []
        training_losses = {}
        for regime_name in regime_names:
            controller = NeuralController(env.state_dim, env.control_dim, regime_name)
            # Each controller is trained only on its own regime.
            training_losses[regime_name] = controller.train_on_regime(env, regime_name, training_episodes)
            controllers.append(controller)

        if fixed_weights:
            weight_learner = FixedUniformWeights(len(controllers))
        else:
            weight_learner = BayesianEnsembleWeights(len(controllers))

        super().__init__(controllers, weight_learner)

        self.env = env
        self.regime_names = regime_names
        self.training_losses = training_losses

# =============================================
# EXPERIMENTAL FRAMEWORK
# =============================================

class ConvexityViolationExperiment:
    """Main experimental framework for convexity violation analysis"""

    def __init__(self, state_dim=4, control_dim=2, n_trials=10, debug_fixed_weights=False):
        """Create the shared multi-regime environment and empty result storage.

        Args:
            state_dim: state dimension passed to the environment.
            control_dim: control dimension passed to the environment.
            n_trials: number of independent trials ``run_experiment`` performs.
            debug_fixed_weights: if True, ensembles use fixed uniform weights
                and no weight updates are made during trials.
        """
        self.state_dim, self.control_dim = state_dim, control_dim
        self.n_trials = n_trials
        self.debug_fixed_weights = debug_fixed_weights

        self.env = MultiRegimeLQREnvironment(state_dim, control_dim)
        self.regime_names = [name for name in self.env.regimes]

        self.results = defaultdict(list)

    def trace_single_optimal_lqr_step(self):
        """Print a one-step trace comparing ``compute_cost`` with ``env.step``'s cost.

        A fresh environment (same ``cost_scaling`` as ``self.env``) is reset
        into the first regime with a random initial state. One random noise
        draw is then scripted for the step. The optimal action is clipped to
        the control limit before both cost computations.
        """
        print("\nTracing Single Optimal LQR Step...")
        regime_name = self.regime_names[0]
        trace_env = MultiRegimeLQREnvironment(self.state_dim, self.control_dim, cost_scaling=self.env.cost_scaling)
        start = trace_env.reset(regime_name)
        noise = np.random.normal(0, trace_env.noise_std, trace_env.state_dim)

        gain = trace_env.get_optimal_policy(regime_name)
        optimal_action = -gain @ start
        limit = trace_env.control_limit
        saturated_optimal_action = np.clip(optimal_action, -limit, limit)

        computed_cost = trace_env.compute_cost(regime_name, start, saturated_optimal_action)

        # Replay from the same state with exactly this one scripted noise vector.
        trace_env.reset(regime_name, initial_state=start.copy(), noise_sequence=[noise])
        next_state, _, _, info = trace_env.step(saturated_optimal_action)
        env_step_cost = info['cost']

        print(f"  Regime: {regime_name}")
        print(f"  Initial State: {start}")
        print(f"  Optimal Action (pre-saturation): {optimal_action}")
        print(f"  Saturated Optimal Action: {saturated_optimal_action}")
        print(f"  Noise: {noise}")
        print(f"  Computed Cost (using compute_cost): {computed_cost:.4f}")
        print(f"  Environment Step Cost (from info): {env_step_cost:.4f}")
        print(f"  Cost Difference (Computed - Env Step): {computed_cost - env_step_cost:.4f}")
        print(f"  Next State (from env step): {next_state}")
        print("Single Optimal LQR Step Trace Complete.")


    def sanity_check_lqr_controllers(self):
        """Run each regime's optimal-gain LQR controller for one scripted-noise episode.

        For every regime the method does the following:
        - Builds a fresh environment with the same ``cost_scaling``.
        - Installs the optimal gain on an ``LQRController``.
        - Runs 100 steps from a random initial state with a pre-drawn noise
          sequence, summing ``info['cost']`` from ``env.step``.
        - On the same visited states, sums ``env.compute_optimal_cost``.

        Because the controller uses the optimal gain itself, the two totals
        differ only in how ``env.step`` computes its cost (e.g. action
        clipping).

        Returns:
            Dict mapping regime name to ``{'lqr_cost': total, 'oracle_cost': total}``,
            where each total is a single-episode sum.
        """
        print("\nRunning LQR Sanity Checks...")
        episode_length = 100
        sanity_check_results = {}
        for regime_name in self.regime_names:
            env = MultiRegimeLQREnvironment(self.state_dim, self.control_dim, cost_scaling=self.env.cost_scaling)
            lqr_controller = LQRController(self.state_dim, self.control_dim)
            lqr_controller.set_gains(env.get_optimal_policy(regime_name))

            # Fixed noise and initial state for this check.
            noise_sequence = [np.random.normal(0, env.noise_std, env.state_dim) for _ in range(episode_length)]
            initial_state = np.random.normal(0, 1, self.state_dim)

            state = env.reset(regime_name, initial_state=initial_state.copy(), noise_sequence=noise_sequence.copy())
            episode_cost = 0
            oracle_episode_cost = 0
            for _ in range(episode_length):
                action = lqr_controller.predict(state)
                next_state, _, done, info = env.step(action)
                episode_cost += info['cost']

                # Oracle cost evaluated at the same (pre-step) state.
                oracle_episode_cost += env.compute_optimal_cost(regime_name, state)

                state = next_state
                if done:
                    break
            # These are single-episode totals, not means.
            print(f"  LQR Controller for '{regime_name}': Episode Cost = {episode_cost:.4f}, Oracle Episode Cost = {oracle_episode_cost:.4f}, Difference = {episode_cost - oracle_episode_cost:.4f}")
            sanity_check_results[regime_name] = {'lqr_cost': episode_cost, 'oracle_cost': oracle_episode_cost}
        print("LQR Sanity Checks Complete.")
        return sanity_check_results

    def run_isolated_oracle_validation(self):
        """Cross-check the simulated oracle episode cost against an offline replay.

        In a fresh environment (first regime, same ``cost_scaling``) the
        optimal policy runs 100 steps with a scripted noise sequence, summing
        ``info['cost']``. The cost is then recomputed with ``compute_cost`` on
        the recorded pre-step states and unclipped optimal actions. Both totals
        and their absolute difference are printed.
        """
        print("\nRunning Isolated Oracle Validation...")
        env = MultiRegimeLQREnvironment(self.state_dim, self.control_dim, cost_scaling=self.env.cost_scaling)
        regime_name = self.regime_names[0]
        initial_state = env.reset(regime_name)
        episode_length = 100

        # Fixed noise so the rollout is reproducible.
        noise_sequence = [np.random.normal(0, env.noise_std, env.state_dim) for _ in range(episode_length)]

        gain = env.get_optimal_policy(regime_name)
        env.reset(regime_name, initial_state=initial_state, noise_sequence=noise_sequence)

        visited = []
        oracle_episode_cost = 0
        for _ in range(episode_length):
            x = env.state.copy()
            visited.append(x)
            _, _, done, info = env.step(-gain @ x)
            oracle_episode_cost += info['cost']
            if done:
                break

        recalculated_oracle_cost = sum(env.compute_cost(regime_name, x, -gain @ x) for x in visited)

        print(f"  Isolated Oracle Validation ({regime_name}):")
        print(f"    Oracle Episode Cost (Simulated): {oracle_episode_cost:.4f}")
        print(f"    Oracle Episode Cost (Recalculated): {recalculated_oracle_cost:.4f}")
        print(f"    Difference: {abs(oracle_episode_cost - recalculated_oracle_cost):.4f}")
        print("Isolated Oracle Validation Complete.")

    def run_single_controller_baseline(self):
        """Compare an optimal-gain controller episode with oracle costs on its trajectory.

        In a fresh environment (first regime, same ``cost_scaling``) an
        ``LQRController`` holding the optimal gain runs 100 steps with a
        scripted noise sequence, summing ``info['cost']``. The oracle total is
        recomputed with ``compute_cost`` on the visited states using the same
        gain. Any difference therefore comes from how ``env.step`` computes
        its cost (e.g. action clipping), not from the controller.
        """
        print("\nRunning Single Controller Baseline Check...")
        env = MultiRegimeLQREnvironment(self.state_dim, self.control_dim, cost_scaling=self.env.cost_scaling)
        regime_name = self.regime_names[0]
        initial_state = env.reset(regime_name)
        episode_length = 100

        # Fixed noise so the rollout is reproducible.
        noise_sequence = [np.random.normal(0, env.noise_std, env.state_dim) for _ in range(episode_length)]

        controller = LQRController(self.state_dim, self.control_dim)
        gain = env.get_optimal_policy(regime_name)
        controller.set_gains(gain)

        env.reset(regime_name, initial_state=initial_state.copy(), noise_sequence=noise_sequence.copy())
        trajectory = []
        lqr_episode_cost = 0
        for _ in range(episode_length):
            x = env.state.copy()
            trajectory.append(x)
            _, _, done, info = env.step(controller.predict(x))
            lqr_episode_cost += info['cost']
            if done:
                break

        # Oracle evaluated on the controller's own trajectory.
        oracle_episode_cost = sum(env.compute_cost(regime_name, x, -gain @ x) for x in trajectory)

        print(f"  Single Controller Baseline Check ({regime_name}):")
        print(f"    Optimal LQR Episode Cost: {lqr_episode_cost:.4f}")
        print(f"    Oracle Cost on LQR Trajectory: {oracle_episode_cost:.4f}")
        print(f"    Difference: {abs(lqr_episode_cost - oracle_episode_cost):.4f}")
        print("Single Controller Baseline Check Complete.")


    def _rollout_ensemble(self, ensemble, regime_name, initial_state, noise_sequence, episode_length):
        """Run one episode of ``ensemble`` from a fixed start state and noise script.

        At each step the method does the following:
        - Applies the blended action.
        - Evaluates each member's action with ``compute_cost`` at the current
          state.
        - Updates the weight learner, unless ``debug_fixed_weights`` is set.

        Returns:
            ``(episode_cost, episode_bound)``:
            - ``episode_cost`` sums ``info['cost']`` from ``env.step``.
            - ``episode_bound`` sums ``sum_i w_i * compute_cost(x_t, u_i)``,
              using the weights that produced each action. By convexity of the
              stage cost in the action, this upper-bounds the summed
              ``compute_cost`` of the blended actions at the visited states.
        """
        env = self.env
        env.reset(regime_name, initial_state=initial_state.copy(), noise_sequence=noise_sequence.copy())
        episode_cost = 0
        episode_bound = 0.0
        for _ in range(episode_length):
            current_state = env.state.copy()
            action, individual_actions = ensemble.predict(current_state)
            weights = ensemble.weight_learner.get_weights()  # weights used by predict above

            individual_costs = np.array([env.compute_cost(regime_name, current_state, a) for a in individual_actions])
            episode_bound += float(np.dot(weights, individual_costs))

            _, _, done, info = env.step(action)
            cost = info['cost']
            episode_cost += cost

            # Weights are updated *after* the step, only when they are learnable.
            if not self.debug_fixed_weights:
                ensemble.weight_learner.update_weights(individual_costs, cost)

            if done:
                break
        return episode_cost, episode_bound

    def _rollout_oracle(self, regime_name, initial_state, noise_sequence, episode_length):
        """Run one episode with the regime's optimal LQR gain; return the summed step cost."""
        env = self.env
        env.reset(regime_name, initial_state=initial_state.copy(), noise_sequence=noise_sequence.copy())
        gain = env.get_optimal_policy(regime_name)
        episode_cost = 0
        for _ in range(episode_length):
            _, _, done, info = env.step(-gain @ env.state.copy())
            episode_cost += info['cost']
            if done:
                break
        return episode_cost

    def run_single_trial(self, trial_idx):
        """Run one trial: fresh ensembles evaluated over a fixed regime sequence.

        For each episode, the LQR ensemble, the neural ensemble and the
        optimal-gain oracle are all run from the same initial state with the
        same scripted noise sequence.

        Returns:
            Dict of per-episode lists:
            - ``'lqr_costs'``, ``'neural_costs'``, ``'oracle_costs'``: episode
              costs.
            - ``'lqr_weights'``, ``'neural_weights'``: weights at episode end.
            - ``'theoretical_bounds'``: the LQR ensemble's summed per-step
              convexity bound.
            - ``'convexity_violations'``: neural optimality gap minus LQR
              optimality gap.
        """
        print(f"Running trial {trial_idx + 1}/{self.n_trials}")

        # Create ensembles, using fixed weights if in debug mode
        lqr_ensemble = LQREnsemble(self.env, self.regime_names, fixed_weights=self.debug_fixed_weights)
        neural_ensemble = NeuralEnsemble(self.env, self.regime_names, fixed_weights=self.debug_fixed_weights)

        regime_sequence = ['tracking', 'regulation', 'stabilization'] * 5
        episode_length = 100

        trial_results = {
            'lqr_costs': [],
            'neural_costs': [],
            'oracle_costs': [],
            'lqr_weights': [],
            'neural_weights': [],
            'theoretical_bounds': [],
            'convexity_violations': []
        }

        for regime_name in regime_sequence:
            # Shared initial state and noise so all three policies face identical conditions.
            initial_state = np.random.normal(0, 1, self.state_dim)
            noise_sequence = [np.random.normal(0, self.env.noise_std, self.state_dim) for _ in range(episode_length)]

            lqr_episode_cost, lqr_bound = self._rollout_ensemble(
                lqr_ensemble, regime_name, initial_state, noise_sequence, episode_length)
            neural_episode_cost, _ = self._rollout_ensemble(
                neural_ensemble, regime_name, initial_state, noise_sequence, episode_length)
            oracle_episode_cost = self._rollout_oracle(regime_name, initial_state, noise_sequence, episode_length)

            trial_results['lqr_costs'].append(lqr_episode_cost)
            trial_results['neural_costs'].append(neural_episode_cost)
            trial_results['oracle_costs'].append(oracle_episode_cost)
            trial_results['lqr_weights'].append(lqr_ensemble.weight_learner.get_weights().copy())
            trial_results['neural_weights'].append(neural_ensemble.weight_learner.get_weights().copy())
            trial_results['theoretical_bounds'].append(lqr_bound)

            # Convexity violation: excess optimality gap of the neural ensemble over the LQR ensemble.
            lqr_optimality_gap = lqr_episode_cost - oracle_episode_cost
            neural_optimality_gap = neural_episode_cost - oracle_episode_cost
            trial_results['convexity_violations'].append(neural_optimality_gap - lqr_optimality_gap)

        return trial_results

    def run_experiment(self):
        """Run the diagnostics, then every trial, then aggregate the results.

        Returns:
            List of per-trial result dicts, in trial order.
        """
        banner = "=" * 60
        for line in (banner,
                     "CONVEXITY VIOLATION EXPERIMENT",
                     "Research Question: Performance loss in NN vs LQR ensembles",
                     banner):
            print(line)

        # Debugging/validation checks; each builds its own fresh environment.
        diagnostics = (
            self.trace_single_optimal_lqr_step,
            self.run_isolated_oracle_validation,
            self.run_single_controller_baseline,
            self.sanity_check_lqr_controllers,
        )
        for check in diagnostics:
            check()

        all_results = [self.run_single_trial(trial) for trial in range(self.n_trials)]

        self.aggregate_results(all_results)
        return all_results

    def aggregate_results(self, all_results):
        """Pool per-trial results, compute summary statistics, store and print them.

        Concatenates the per-episode lists from every trial, derives mean costs
        and optimality gaps relative to the oracle, runs a paired t-test of
        neural vs. LQR episode costs, stores everything in ``self.results`` and
        finally calls ``self.print_summary()``.

        Args:
            all_results: list of per-trial dicts, each holding the per-episode
                lists 'lqr_costs', 'neural_costs', 'oracle_costs' and
                'convexity_violations'.

        Raises:
            ValueError: if ``all_results`` is empty (nothing to aggregate).

        Notes:
            ``relative_performance_loss`` is ``(neural_gap - lqr_gap) / lqr_gap``
            when the LQR gap is strictly positive. Otherwise the ratio is
            meaningless, and it holds the signed absolute difference
            ``neural_gap - lqr_gap`` instead. That difference is negative when
            the neural ensemble is better; it is no longer clamped to 0.
        """
        from scipy import stats

        if not all_results:
            raise ValueError("aggregate_results() needs at least one trial result")

        def _pool(key):
            # Flatten one per-episode metric across all trials.
            return np.concatenate([r[key] for r in all_results])

        lqr_costs = _pool('lqr_costs')
        neural_costs = _pool('neural_costs')
        oracle_costs = _pool('oracle_costs')
        convexity_violations = _pool('convexity_violations')

        # Key metrics
        lqr_mean_cost = np.mean(lqr_costs)
        neural_mean_cost = np.mean(neural_costs)
        oracle_mean_cost = np.mean(oracle_costs)

        lqr_optimality_gap = lqr_mean_cost - oracle_mean_cost
        neural_optimality_gap = neural_mean_cost - oracle_mean_cost
        performance_difference = neural_optimality_gap - lqr_optimality_gap

        # Small epsilon guards against dividing by a (near-)zero LQR gap.
        if lqr_optimality_gap > 1e-9:
            relative_performance_loss = performance_difference / lqr_optimality_gap
        else:
            # Report the signed difference so a better-performing neural ensemble
            # shows up as a negative number rather than a misleading 0.
            relative_performance_loss = performance_difference

        mean_convexity_violation = np.mean(convexity_violations)
        std_convexity_violation = np.std(convexity_violations)

        # Paired t-test: both ensembles are evaluated on the same episodes
        # (same initial states and noise sequences).
        t_stat, p_value = stats.ttest_rel(neural_costs, lqr_costs)

        self.results = {
            'lqr_mean_cost': lqr_mean_cost,
            'neural_mean_cost': neural_mean_cost,
            'oracle_mean_cost': oracle_mean_cost,
            'lqr_optimality_gap': lqr_optimality_gap,
            'neural_optimality_gap': neural_optimality_gap,
            'relative_performance_loss': relative_performance_loss,
            'mean_convexity_violation': mean_convexity_violation,
            'std_convexity_violation': std_convexity_violation,
            't_statistic': t_stat,
            'p_value': p_value,
            'all_results': all_results
        }

        self.print_summary()

    def print_summary(self):
        """Write a human-readable report of ``self.results`` to stdout.

        Covers mean costs, optimality gaps, convexity-violation statistics, the
        paired t-test outcome and a short conclusion. The relative loss is shown
        as a percentage only when the LQR optimality gap exceeds 1e-9; otherwise
        it is shown as a plain number.
        """
        res = self.results
        rule = "=" * 60
        gap_is_positive = res['lqr_optimality_gap'] > 1e-9
        significant = res['p_value'] < 0.05
        violation = res['mean_convexity_violation']
        loss = res['relative_performance_loss']

        print("\n" + rule)
        print("EXPERIMENTAL RESULTS SUMMARY")
        print(rule)

        for label, key in (
            ("Oracle (Optimal) Mean Cost", 'oracle_mean_cost'),
            ("LQR Ensemble Mean Cost", 'lqr_mean_cost'),
            ("Neural Ensemble Mean Cost", 'neural_mean_cost'),
        ):
            print(f"{label}: {res[key]:.4f}")

        print("\nOptimality Gaps:")
        print(f"  LQR Gap: {res['lqr_optimality_gap']:.4f}")
        print(f"  Neural Gap: {res['neural_optimality_gap']:.4f}")

        print("\nKey Findings:")
        print(f"  Mean Convexity Violation: {violation:.4f} ± {res['std_convexity_violation']:.4f}")
        # A percentage is only meaningful relative to a strictly positive LQR gap.
        if gap_is_positive:
            print(f"  Relative Performance Loss (vs LQR Gap): {loss:.2%}")
        else:
            print(f"  Performance Difference (Neural Gap - LQR Gap): {loss:.4f}")

        print("\nStatistical Significance:")
        print(f"  T-statistic: {res['t_statistic']:.4f}")
        print(f"  P-value: {res['p_value']:.6f}")
        print(f"  Significant difference: {'Yes' if significant else 'No'}")

        # Interpretation: pick the closing paragraph, then emit it line by line.
        if significant and violation > 0:
            closing = [
                "\n📊 CONCLUSION: Neural ensembles show a statistically significant",
                "   increase in cost compared to LQR ensembles, indicating convexity violation.",
            ]
            closing.append(
                f"   Relative performance loss: {loss:.1%}" if gap_is_positive
                else f"   Absolute performance difference: {loss:.4f}"
            )
        elif significant and violation < 0:
            closing = [
                "\n📊 CONCLUSION: Neural ensembles perform statistically significantly better",
                "   than LQR ensembles.",
            ]
        else:
            closing = ["\n📊 CONCLUSION: The difference in performance between Neural and LQR ensembles is not statistically significant."]

        for line in closing:
            print(line)


def create_visualization(experiment):
    """Render a 2x3 dashboard of the experiment's aggregated results.

    Panels: mean cost per controller, histogram of per-episode convexity
    violations, ensemble-weight trajectories for the first trial, optimality
    gaps, a boxplot of per-episode LQR vs. neural costs, and a text summary.
    The figure is saved to ``convexity_violation_analysis.png`` (300 dpi) and
    then displayed with ``plt.show()``.

    Args:
        experiment: object whose ``results`` dict was filled by
            ``aggregate_results``. The dict also carries the raw per-trial
            dicts under ``'all_results'``.
    """
    results = experiment.results
    all_results = results['all_results']
    # Same threshold as the textual summary: a relative (percentage) loss is
    # only meaningful when the LQR optimality gap is strictly positive.
    lqr_gap_positive = results['lqr_optimality_gap'] > 1e-9
    significant = results['p_value'] < 0.05
    performance_loss = results['relative_performance_loss']

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # 1. Performance Comparison
    ax1 = axes[0, 0]
    categories = ['Oracle\n(Optimal)', 'LQR\nEnsemble', 'Neural\nEnsemble']
    means = [results['oracle_mean_cost'], results['lqr_mean_cost'], results['neural_mean_cost']]
    colors = ['gold', 'lightblue', 'lightcoral']

    bars = ax1.bar(categories, means, color=colors, alpha=0.7, edgecolor='black')
    ax1.set_ylabel('Mean Episode Cost')
    ax1.set_title('Performance Comparison\n(Lower is Better)')
    ax1.grid(True, alpha=0.3)

    # Value labels sit just above positive bars and just below negative ones.
    for bar, mean in zip(bars, means):
        ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5 * np.sign(mean),
                 f'{mean:.2f}', ha='center', va='bottom' if mean >= 0 else 'top', fontweight='bold')

    # 2. Convexity Violations Distribution
    ax2 = axes[0, 1]
    all_violations = np.concatenate([r['convexity_violations'] for r in all_results])
    mean_violation = np.mean(all_violations)
    ax2.hist(all_violations, bins=20, alpha=0.7, color='purple', edgecolor='black')
    ax2.axvline(0, color='red', linestyle='--', linewidth=2, label='No Violation')
    ax2.axvline(mean_violation, color='orange', linestyle='-', linewidth=2,
                label=f'Mean: {mean_violation:.3f}')
    ax2.set_xlabel('Convexity Violation (Cost Difference)')
    ax2.set_ylabel('Frequency')
    ax2.set_title('Distribution of Convexity Violations')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # 3. Weight Evolution (Example Trial)
    ax3 = axes[0, 2]
    example_trial = all_results[0]
    lqr_weights = np.array(example_trial['lqr_weights'])
    neural_weights = np.array(example_trial['neural_weights'])

    episodes = range(len(lqr_weights))
    for i in range(lqr_weights.shape[1]):
        ax3.plot(episodes, lqr_weights[:, i], '--', alpha=0.7, label=f'LQR Model {i+1}')
        ax3.plot(episodes, neural_weights[:, i], '-', alpha=0.7, label=f'NN Model {i+1}')

    ax3.set_xlabel('Episode')
    ax3.set_ylabel('Ensemble Weight')
    ax3.set_title('Weight Evolution (Example Trial)')
    ax3.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax3.grid(True, alpha=0.3)

    # 4. Optimality Gap Analysis
    ax4 = axes[1, 0]
    gap_data = [results['lqr_optimality_gap'], results['neural_optimality_gap']]
    gap_labels = ['LQR\nOptimality Gap', 'Neural\nOptimality Gap']

    ax4.bar(gap_labels, gap_data, color=['lightblue', 'lightcoral'],
            alpha=0.7, edgecolor='black')
    ax4.set_ylabel('Optimality Gap (Cost Above Oracle)')
    ax4.set_title('Optimality Gap Comparison')
    ax4.grid(True, alpha=0.3)

    if lqr_gap_positive:
        annotation_text = f'Relative Performance Loss:\n{performance_loss:.1%}'
    else:
        annotation_text = f'Performance Difference:\n{performance_loss:.4f}'
    # Clamp at 0 so the label stays visible even when both gaps are negative.
    ax4.text(0.5, max(gap_data + [0]) * 0.8, annotation_text,
             ha='center', va='center', fontsize=12, fontweight='bold',
             bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.7))

    # 5. Statistical Significance
    ax5 = axes[1, 1]
    lqr_costs = np.concatenate([r['lqr_costs'] for r in all_results])
    neural_costs = np.concatenate([r['neural_costs'] for r in all_results])

    # Tick labels are set separately: boxplot's ``labels=`` keyword was renamed
    # ``tick_labels`` (and deprecated) in Matplotlib 3.9; this form works on all versions.
    ax5.boxplot([lqr_costs, neural_costs])
    ax5.set_xticks([1, 2])
    ax5.set_xticklabels(['LQR Ensemble', 'Neural Ensemble'])
    ax5.set_ylabel('Episode Cost')
    ax5.set_title(f'Statistical Comparison\n(p-value: {results["p_value"]:.4f})')
    ax5.grid(True, alpha=0.3)

    significance = "Significant" if significant else "Not Significant"
    if lqr_costs.size > 0 and neural_costs.size > 0:
        max_cost_val = max(lqr_costs.max(), neural_costs.max())
        min_cost_val = min(lqr_costs.min(), neural_costs.min())
    else:
        max_cost_val, min_cost_val = 1, 0
    # Place text near the top of the data range, or near the bottom if all costs are negative.
    text_y_pos = max_cost_val * 0.9 if max_cost_val > 0 else min_cost_val * 0.9

    # Boxes sit at x=1 and x=2, so x=1.5 centres the label between them
    # (x=0.5 would land on the left axis edge and be half clipped).
    ax5.text(1.5, text_y_pos,
             f'Difference: {significance}', ha='center', va='center',
             fontsize=12, fontweight='bold',
             bbox=dict(boxstyle='round', facecolor='lightgreen' if significant else 'lightgray', alpha=0.7))

    # 6. Research Question Summary
    ax6 = axes[1, 2]
    ax6.axis('off')

    if lqr_gap_positive:
        loss_line = f"• Relative performance loss: {performance_loss:.1%}"
    else:
        loss_line = f"• Performance difference: {performance_loss:.4f}"

    summary_text = f"""
RESEARCH QUESTION ANSWER:
How much performance is lost when using
neural network ensembles compared to
LQR ensemble theoretical bounds?

KEY FINDINGS:
• Mean convexity violation: {results['mean_convexity_violation']:.4f}
{loss_line}
• Statistical significance: p = {results['p_value']:.4f}

CONCLUSION:
"""

    if results['mean_convexity_violation'] > 0 and significant:
        conclusion = """Neural ensembles show a statistically significant
increase in cost compared to LQR ensembles,
indicating convexity violation."""
        if lqr_gap_positive:
            conclusion += f"\nRelative performance loss: {performance_loss:.1%}"
        else:
            conclusion += f"\nAbsolute performance difference: {performance_loss:.4f}"
    elif results['mean_convexity_violation'] < 0 and significant:
        conclusion = """Neural ensembles perform statistically significantly better
than LQR ensembles."""
    else:
        conclusion = """The difference in performance between Neural and LQR ensembles
is not statistically significant."""

    summary_text += conclusion

    ax6.text(0.05, 0.95, summary_text, transform=ax6.transAxes, fontsize=11,
             verticalalignment='top', fontfamily='monospace',
             bbox=dict(boxstyle='round,pad=0.5', facecolor='lightyellow', alpha=0.9))

    plt.tight_layout()
    plt.savefig('convexity_violation_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

# =============================================
# MAIN EXECUTION
# =============================================

def main():
    """Entry point: run the convexity-violation experiment and plot its results.

    Builds a ``ConvexityViolationExperiment`` (4-D state, 2-D control, 5 trials),
    runs it, renders the visualization dashboard, and returns both the
    experiment object and the per-trial results list from ``run_experiment``.
    """
    banner = "=" * 80
    print("Convexity Violation Analysis: LQR vs Neural Network Ensembles")
    print(banner)

    # n_trials kept small: reduced for demonstration.
    experiment = ConvexityViolationExperiment(state_dim=4, control_dim=2, n_trials=5)
    results = experiment.run_experiment()

    print("\nGenerating visualizations...")
    create_visualization(experiment)

    for closing_line in ("\nExperiment completed successfully!",
                         "Key research question addressed with statistical rigor."):
        print(closing_line)

    return experiment, results

if __name__ == "__main__":
    # Bind the experiment and its per-trial results at module scope so they stay
    # inspectable after the run (e.g. when launched with ``python -i``).
    (experiment, results) = main()