"""
Run NMPC Optimization and Extract Dual Variables (Lagrangian Multipliers)

This script:
1. Loads trained TEP emulator and forecaster
2. Runs NMPC optimization on historical data
3. Extracts dual variables (λ) for the state min/max bound constraints
4. Saves results with dual columns for HCA analysis
"""

import numpy as np
import pandas as pd
import sys
from pathlib import Path
import logging

# Add this script's own directory to the path so sibling modules can be imported
sys.path.append(str(Path(__file__).parent))

from TEP_nmpc_controller import TEPEmulator, TEPDisturbanceForecaster, TEPMPC
from config import *

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def extract_dual_variables_from_solution(mpc: TEPMPC, sol_dict: dict, k: int) -> dict:
    """
    Collect the state-bound multipliers of one MPC solve into a flat dict.

    The multiplier vector is read from ``sol_dict['dual_variables']``. The
    first ``mpc.nx * mpc.N`` entries are skipped (treated as dynamics
    constraints); after that one entry is consumed for every finite bound,
    walking the configured state variables in order and taking the lower
    bound before the upper bound of each variable.

    NOTE(review): this layout is assumed to match the constraint ordering
    produced by TEPMPC._build_mpc_problem(), which is not visible from this
    file - confirm before relying on the per-variable attribution.

    Args:
        mpc: TEPMPC controller instance (only ``nx`` and ``N`` are read)
        sol_dict: Solution dictionary from mpc.solve()
        k: Current time step (accepted for interface compatibility; unused)

    Returns:
        Dictionary with a ``dual_min_<var>`` and a ``dual_max_<var>`` entry
        for every state variable. An entry is 0.0 when the bound is infinite
        or when the multiplier vector has no entry left to consume.
    """
    extracted = {}
    multipliers = sol_dict['dual_variables']

    # Read position: starts just past the block attributed to the dynamics.
    cursor = mpc.nx * mpc.N

    for name in DATA_CONFIG['data_columns']['state_vars']:
        lower, upper = STATE_METADATA.get(name, {}).get(
            'safe_range', (-np.inf, np.inf)
        )

        # Lower bound first, then upper bound, for each state variable.
        for side, is_finite in (('min', lower > -np.inf), ('max', upper < np.inf)):
            key = f'dual_{side}_{name}'
            if is_finite and cursor < len(multipliers):
                extracted[key] = multipliers[cursor]
                cursor += 1
            else:
                # Unbounded side, or multiplier vector exhausted.
                extracted[key] = 0.0

    return extracted


def _nmpc_result_columns(state_cols, input_cols, disturbance_cols):
    """Return the output column names in the order they appear in the CSV."""
    columns = ['timestamp', 'step', 'cost', 'success']
    for col in state_cols:
        columns += [f'{col}_actual', f'{col}_predicted']
    columns += [f'u_nmpc_{col}' for col in input_cols]
    columns += list(disturbance_cols)
    for col in state_cols:
        columns += [f'dual_min_{col}', f'dual_max_{col}']
    # Drop accidental duplicate names while keeping first-seen order.
    return list(dict.fromkeys(columns))


def _failed_step_row(timestamp, step, state_cols, input_cols, disturbance_cols):
    """Build the placeholder row recorded when a simulation step raises."""
    row = {'timestamp': timestamp, 'step': step, 'cost': np.nan, 'success': False}
    for col in state_cols:
        row[f'{col}_actual'] = np.nan
        row[f'{col}_predicted'] = np.nan
    for col in input_cols:
        row[f'u_nmpc_{col}'] = np.nan
    for col in disturbance_cols:
        row[col] = np.nan
    # Duals of a failed step are recorded as 0.0 (not NaN), as before.
    for col in state_cols:
        row[f'dual_min_{col}'] = 0.0
        row[f'dual_max_{col}'] = 0.0
    return row


def run_nmpc_with_duals(input_csv: str, output_csv: str,
                        n_steps: int = None,
                        subsample: int = 1):
    """
    Run NMPC optimization and save results with dual variables.

    For every simulated step the MPC problem is solved from the recorded
    state, and the first control move, the one-step-ahead predicted state,
    the cost, the solver status and the state-bound duals are stored as one
    row. A step that raises is logged with its traceback and stored as a
    placeholder row (NaN values, ``success=False``, duals 0.0), so the
    output always has exactly one row per step.

    Args:
        input_csv: File name of the processed TEP data, relative to DATA_DIR
        output_csv: File name for the results CSV, relative to DATA_DIR
        n_steps: Number of leading rows to keep before subsampling (None = all)
        subsample: Process every Nth sample (1 = all, 10 = every 10th)

    Returns:
        DataFrame with the per-step results, or None when no trained
        emulator is found or when the data is too short for the MPC horizon.
    """
    logger.info("="*80)
    logger.info("NMPC Dual Variable Generation")
    logger.info("="*80)

    # Load data
    logger.info(f"Loading data from {input_csv}...")
    df = pd.read_csv(DATA_DIR / input_csv, parse_dates=['timestamp'])

    if n_steps is not None:
        df = df.iloc[:n_steps]

    if subsample > 1:
        df = df.iloc[::subsample]
        logger.info(f"Subsampling: Using every {subsample} samples")

    logger.info(f"Data shape: {df.shape}")
    logger.info(f"Time range: {df['timestamp'].min()} to {df['timestamp'].max()}")

    # Check if models exist
    emulator_path = MODEL_DIR / "emulator.pth"
    if not emulator_path.exists():
        logger.error("❌ No trained emulator found!")
        logger.error("Please run: python TEP_nmpc_controller.py (train models first)")
        return None

    # Load models
    logger.info("Loading trained models...")
    emulator = TEPEmulator()
    emulator.load_model(MODEL_DIR)

    forecaster = TEPDisturbanceForecaster()
    forecaster.load_models(MODEL_DIR)

    # Initialize MPC
    mpc = TEPMPC(emulator, forecaster)
    logger.info(f"✅ MPC initialized with horizon N={mpc.N}")

    # Prepare columns
    state_cols = DATA_CONFIG['data_columns']['state_vars']
    input_cols = DATA_CONFIG['data_columns']['input_vars']
    disturbance_cols = DATA_CONFIG['data_columns']['disturbance_vars']
    columns = _nmpc_result_columns(state_cols, input_cols, disturbance_cols)

    # Number of iterations the closed loop actually runs. Without this guard
    # an empty run reached the summary and divided by zero.
    total_steps = len(df) - mpc.N - 1
    if total_steps <= 0:
        logger.error(
            f"❌ Not enough data: {len(df)} rows available, but more than "
            f"{mpc.N + 1} are required for horizon N={mpc.N}"
        )
        return None

    # Run NMPC simulation
    logger.info("\nStarting NMPC closed-loop simulation...")
    logger.info(f"Total steps: {total_steps}")

    # Reference trajectory (setpoints) does not depend on k, so it is built
    # once: one row per state, one column per horizon stage (N + 1 of them).
    x_ref = np.tile(
        np.array([STATE_METADATA[v]['setpoint'] for v in state_cols]),
        (mpc.N + 1, 1)
    ).T

    u_prev = np.ones(mpc.nu) * 50.0  # Initialize at 50% valve opening
    rows = []

    for k in range(total_steps):
        if k % 100 == 0:
            logger.info(f"Step {k}/{total_steps} ({k/total_steps*100:.1f}%)")

        timestamp = df['timestamp'].iloc[k]

        # Current state
        x0 = df[state_cols].iloc[k].values

        # Disturbance forecast
        d_current = df[disturbance_cols].iloc[k:k+mpc.N].values.T
        if d_current.shape[1] < mpc.N:
            # Defensive: pad with the last column if the window is short
            d_pad = np.tile(d_current[:, -1:], (1, mpc.N - d_current.shape[1]))
            d_current = np.hstack([d_current, d_pad])

        try:
            # Solve MPC
            sol = mpc.solve(x0, u_prev, x_ref, d_current)

            if not sol['success']:
                logger.warning(f"Step {k}: Solver failed!")

            # Extract first control action
            u_opt = sol['U_opt'][:, 0]
            u_prev = u_opt

            # Extract dual variables
            duals = extract_dual_variables_from_solution(mpc, sol, k)

            # Assemble the whole row before storing it, so a failure halfway
            # through cannot leave the result columns with unequal lengths.
            row = {
                'timestamp': timestamp,
                'step': k,
                'cost': sol['cost'],
                'success': sol['success'],
            }
            for i, col in enumerate(state_cols):
                row[f'{col}_actual'] = x0[i]
                row[f'{col}_predicted'] = sol['X_opt'][i, 1]
            for i, col in enumerate(input_cols):
                row[f'u_nmpc_{col}'] = u_opt[i]
            for col in disturbance_cols:
                row[col] = df[col].iloc[k]
            for state_var in state_cols:
                row[f'dual_min_{state_var}'] = duals.get(f'dual_min_{state_var}', 0.0)
                row[f'dual_max_{state_var}'] = duals.get(f'dual_max_{state_var}', 0.0)

        except Exception as e:
            # Best effort: keep simulating, but record the traceback.
            logger.exception(f"Step {k}: Error - {e}")
            row = _failed_step_row(
                timestamp, k, state_cols, input_cols, disturbance_cols
            )

        rows.append(row)

    # Convert to DataFrame
    logger.info("\nConverting results to DataFrame...")
    df_results = pd.DataFrame(rows, columns=columns)

    # Statistics
    n_done = len(df_results)
    n_success = df_results['success'].sum()
    logger.info("\n" + "="*80)
    logger.info("RESULTS SUMMARY")
    logger.info("="*80)
    logger.info(f"Total steps completed: {n_done}")
    logger.info(f"Successful optimizations: {n_success} ({n_success/n_done*100:.1f}%)")
    logger.info(f"Average cost: {df_results['cost'].mean():.2f}")

    logger.info("\nDual Variable Statistics:")
    dual_cols = [c for c in df_results.columns if c.startswith('dual_')]
    for col in dual_cols:
        magnitude = df_results[col].abs()
        active_count = (magnitude > 1e-7).sum()
        logger.info(f"  {col}:")
        logger.info(f"    Mean: {df_results[col].mean():.2e}")
        logger.info(f"    Max:  {magnitude.max():.2e}")
        logger.info(f"    Active (|λ| > 1e-7): {active_count} / {n_done} ({active_count/n_done*100:.1f}%)")

    # Save results
    output_path = DATA_DIR / output_csv
    logger.info(f"\nSaving results to {output_path}...")
    df_results.to_csv(output_path, index=False)
    logger.info("✅ Done!")

    return df_results


def parse_cli_args(argv=None):
    """
    Parse the command-line options of this script.

    Args:
        argv: Argument list to parse; None means the process arguments.

    Returns:
        argparse.Namespace with ``input``, ``output``, ``steps`` and
        ``subsample``. ``steps`` is None when ``--steps 0`` was given, so
        that "0" really means "use all rows" instead of selecting no rows.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Run NMPC and generate dual variables')
    parser.add_argument('--input', default='processed_train_run1.csv', help='Input CSV file')
    parser.add_argument('--output', default='processed_train_run1_nmpc_duals.csv', help='Output CSV file')
    parser.add_argument('--steps', type=int, default=500, help='Number of steps (0=all)')
    parser.add_argument('--subsample', type=int, default=1, help='Subsample rate (1=all, 10=every 10th)')

    args = parser.parse_args(argv)
    if not args.steps:
        # The banner advertises 0 as "All"; pass None so no row limit applies.
        args.steps = None
    return args


def main(argv=None):
    """Run the dual-variable generation from the command line and report the outcome."""
    args = parse_cli_args(argv)

    print("\n" + "="*80)
    print("TEP NMPC Dual Variable Generator")
    print("="*80)
    print(f"Input:  {args.input}")
    print(f"Output: {args.output}")
    print(f"Steps:  {args.steps if args.steps else 'All'}")
    print(f"Subsample: {args.subsample}")
    print("="*80 + "\n")

    df_results = run_nmpc_with_duals(
        input_csv=args.input,
        output_csv=args.output,
        n_steps=args.steps,
        subsample=args.subsample
    )

    # None signals that the run could not be performed; stay silent then.
    if df_results is not None:
        print("\n" + "="*80)
        print("✅ SUCCESS! Dual variables generated")
        print("="*80)
        print(f"\nOutput file: {DATA_DIR / args.output}")
        print(f"Total columns: {len(df_results.columns)}")
        print(f"Dual columns: {len([c for c in df_results.columns if c.startswith('dual_')])}")
        print("\nYou can now use this data for ablation study:")
        print(f"  cp data/{args.output} data/processed_train_run1.csv")
        print("  python run_real_tep_ablation.py")


if __name__ == "__main__":
    main()
