#!/usr/bin/env python
"""
Constraint Projection Audit Tool for DANCE-ST.

This script audits the behavior of the Phase III constraint projection by
generating a set of sample predictions and applying the box constraint projection.
It visually demonstrates that predictions violating the physical constraints
are correctly projected onto the feasible space, while valid predictions
remain unchanged, directly verifying the implementation in `Core/dr_solver.py`.

USAGE:
  python analysis/constraint_projection_audit.py [--plot OUTPUT.png]
"""

import argparse
import sys
import os
import numpy as np
import matplotlib.pyplot as plt

# Add the parent directory to the path to allow importing from Core. The path is
# derived from this file's location, so the current working directory is irrelevant.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

try:
    from Core.dr_solver import project_onto_box_constraints
except ImportError as exc:
    # Surface the underlying cause: an ImportError here can also come from a
    # missing third-party dependency imported *by* Core.dr_solver, not only
    # from Core being absent from sys.path.
    print("Error: Could not import 'project_onto_box_constraints' from Core.dr_solver.", file=sys.stderr)
    print(f"Cause: {exc}", file=sys.stderr)
    print("Please ensure this script sits one directory below the project root (e.g. in 'analysis/') "
          "and that the dependencies of Core/dr_solver.py are installed.", file=sys.stderr)
    sys.exit(1)


def create_sample_predictions(num_points=1000, center=1000, spread=500, rng=None):
    """Generate a sample of intermediate predictions (f_int).

    Values are drawn from a normal distribution; with a wide ``spread`` many of
    them fall outside the feasible range, which exercises the projection.

    Args:
        num_points: Number of samples to draw.
        center: Mean of the normal distribution.
        spread: Standard deviation of the normal distribution (must be >= 0).
        rng: Optional ``numpy.random.Generator`` or integer seed for
            reproducible draws. When ``None`` (the default) NumPy's global
            random state is used, exactly as before, so ``np.random.seed``
            still controls the output.

    Returns:
        ``numpy.ndarray`` of shape ``(num_points,)``.
    """
    if rng is None:
        # Legacy path: global RNG, kept for backward compatibility.
        return np.random.normal(loc=center, scale=spread, size=num_points)
    # default_rng returns a Generator unchanged and builds one from a seed.
    generator = np.random.default_rng(rng)
    return generator.normal(loc=center, scale=spread, size=num_points)


def run_projection_audit(predictions, lower_bound, upper_bound):
    """Project ``predictions`` onto the box ``[lower_bound, upper_bound]``.

    Prints a two-line header describing the audit, then delegates to
    ``project_onto_box_constraints`` and returns its result as-is.

    Args:
        predictions: Sized collection of intermediate predictions (f_int).
        lower_bound: Lower edge of the feasible box.
        upper_bound: Upper edge of the feasible box.

    Returns:
        Whatever ``project_onto_box_constraints`` returns for these inputs.
    """
    point_count = len(predictions)
    banner = (
        f"Running projection audit on {point_count} data points.",
        f"Constraints: lower_bound = {lower_bound}, upper_bound = {upper_bound}",
    )
    for message in banner:
        print(message)

    return project_onto_box_constraints(
        f_intermediate=predictions,
        lower_bound=lower_bound,
        upper_bound=upper_bound,
    )


def _report_projection(original, projected, lower_bound, upper_bound, atol=1e-9):
    """Print the audit summary and return True when the projection is correct.

    The Euclidean projection onto a box is the element-wise clip, so a correct
    projection leaves nothing outside ``[lower_bound, upper_bound]``, leaves
    already-feasible values unchanged, and moves infeasible values onto the
    nearest bound -- i.e. it matches ``np.clip(original, lower_bound, upper_bound)``
    to within ``atol``.
    """
    original = np.asarray(original, dtype=float)
    projected = np.asarray(projected, dtype=float)
    total = original.size

    print("\n--- Audit Summary ---")
    if projected.shape != original.shape:
        # Element-wise comparisons below would be meaningless (or raise).
        print(f"\n[FAILURE] Projection changed the shape: {original.shape} -> {projected.shape}.")
        return False

    infeasible = (original < lower_bound) | (original > upper_bound)
    feasible = ~infeasible
    violations_before = int(np.count_nonzero(infeasible))
    violations_after = int(np.count_nonzero((projected < lower_bound) | (projected > upper_bound)))
    # Guard the percentage against an empty sample.
    share_before = violations_before / total if total else 0.0

    altered_feasible = int(np.count_nonzero(
        ~np.isclose(projected[feasible], original[feasible], rtol=0.0, atol=atol)))
    expected = np.clip(original, lower_bound, upper_bound)
    off_projection = int(np.count_nonzero(~np.isclose(projected, expected, rtol=0.0, atol=atol)))

    print(f"Constraint Violations Before Projection: {violations_before} / {total} ({share_before:.1%})")
    print(f"Constraint Violations After Projection:  {violations_after} / {projected.size}")
    print(f"Feasible Values Altered by Projection:   {altered_feasible} / {int(np.count_nonzero(feasible))}")
    print(f"Values Differing From Exact Box Projection: {off_projection} / {total}")

    if violations_after == 0 and off_projection == 0:
        print("\n[SUCCESS] All projected values are within the defined constraint boundaries.")
        return True
    if violations_after:
        print("\n[FAILURE] The projection failed to enforce all constraints.")
    if off_projection:
        print("\n[FAILURE] Projected values do not match the exact box projection (element-wise clip).")
    return False


def _plot_projection(original, projected, lower_bound, upper_bound, path):
    """Scatter f_int against f_proj with the identity line and bounds; save to ``path``."""
    # The seaborn styles were renamed in matplotlib 3.6; use whichever exists
    # and fall back to the default style rather than crashing.
    for style in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
        if style in plt.style.available:
            plt.style.use(style)
            break

    fig, ax = plt.subplots(figsize=(12, 7))
    points = ax.scatter(original, projected, c=projected, cmap='viridis', alpha=0.6, s=20)

    # Inside the box the projection should be the identity.
    x_line = np.linspace(lower_bound, upper_bound, 100)
    ax.plot(x_line, x_line, 'r--', linewidth=2, label='Identity (No Projection)')

    # Highlight constraint boundaries on both axes.
    ax.axhline(y=lower_bound, color='k', linestyle=':', label=f'Lower Bound ({lower_bound})')
    ax.axhline(y=upper_bound, color='k', linestyle=':', label=f'Upper Bound ({upper_bound})')
    ax.axvline(x=lower_bound, color='k', linestyle=':')
    ax.axvline(x=upper_bound, color='k', linestyle=':')

    ax.set_xlabel("Intermediate Prediction ($f_{int}$)", fontsize=12)
    ax.set_ylabel("Projected Prediction ($f_{proj}$)", fontsize=12)
    ax.set_title("Verification of Constraint Projection Behavior", fontsize=14, weight='bold')
    ax.legend(fontsize=10)
    fig.colorbar(points, ax=ax, label='Projected Value')
    ax.grid(True)
    fig.tight_layout()

    fig.savefig(path)
    # Release the figure; it is not shown interactively.
    plt.close(fig)


def main(argv=None):
    """Run the constraint projection audit.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]`` when None.

    Returns:
        Process exit code: 0 when the projection passes every check, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Audit the behavior of the box constraint projection.")
    parser.add_argument("--plot", help="Save a visualization of the projection to this path (e.g., projection_audit.png)")
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed NumPy's global RNG for a reproducible sample (default: unseeded)")
    args = parser.parse_args(argv)

    # Define constraints based on the Turbine-500 example in the paper
    LOWER_BOUND = 0.0
    UPPER_BOUND = 1200.0

    if args.seed is not None:
        np.random.seed(args.seed)

    # 1. Generate sample data
    # We center the data around a typical operating temperature, with a large spread
    # to ensure many points fall outside the [0, 1200] feasible range.
    sample_f_int = create_sample_predictions(num_points=5000, center=950, spread=400)

    # 2. Run the projection
    projected_f = run_projection_audit(sample_f_int, LOWER_BOUND, UPPER_BOUND)

    # 3. Analyze and verify the results
    passed = _report_projection(sample_f_int, projected_f, LOWER_BOUND, UPPER_BOUND)

    # 4. Create visualization if requested
    if args.plot:
        _plot_projection(sample_f_int, projected_f, LOWER_BOUND, UPPER_BOUND, args.plot)
        print(f"\nSaved visualization to {args.plot}")

    return 0 if passed else 1


if __name__ == "__main__":
    # Propagate the audit result so CI and shell scripts can detect failures.
    sys.exit(main())
