"""
Simple demonstration of DANCEST fusion using trained models.
Shows how neural and symbolic predictions are combined based on uncertainty.
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import json
import sys
from datetime import datetime
from pathlib import Path

# Add project root to Python path
project_root = Path(__file__).resolve().parent.parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# Import models
from DANCEST_model.real_models import build_real_estimators

def _build_sample_vertices(n_blades=4):
    """Return ``n_blades`` synthetic blade vertices.

    Alloy type and surface coating cycle through fixed lists by index; the
    remaining properties (initial thickness, chromium content) are constant.
    """
    alloy_types = ['Inconel-718', 'Rene-77', 'GTD-111', 'Waspaloy']
    coatings = ['None', 'Type-A', 'Type-B', 'Type-C']
    return [
        {
            'type': 'blade',
            'blade_id': f'blade_{i}',
            'alloy_type': alloy_types[i % len(alloy_types)],
            'initial_thickness_mm': 3.5,
            'chromium_content_pct': 18.0,
            'surface_coating': coatings[i % len(coatings)],
        }
        for i in range(n_blades)
    ]


def _load_min_symbolic_weight(calibration_path, default=0.4):
    """Return the minimum symbolic fusion weight, clipped to [0, 1].

    Reads ``min_symbolic_weight`` from the JSON object stored at
    ``calibration_path``. Returns ``default`` silently when the file does not
    exist, and with a printed warning when it exists but cannot be read,
    parsed, or holds a non-numeric / non-finite value.
    """
    calibration_path = Path(calibration_path)
    if not calibration_path.exists():
        return default
    try:
        with open(calibration_path, 'r', encoding='utf-8') as f:
            calibration = json.load(f)
        value = float(calibration.get('min_symbolic_weight', default))
        if not np.isfinite(value):
            raise ValueError(f"non-finite min_symbolic_weight: {value}")
    except (OSError, ValueError, TypeError, AttributeError) as err:
        # Best effort: a bad calibration file should not abort the demo, but
        # it should not be ignored silently either.
        print(f"Warning: could not use {calibration_path} ({err}); "
              f"falling back to min_symbolic_weight={default}")
        return default
    return float(np.clip(value, 0.0, 1.0))


def _fusion_weights(neural_vars, symbolic_vars, min_symbolic_weight):
    """Return per-point neural weights Ω = σ²_s / (σ²_n + σ²_s), capped.

    The cap ``1 - min_symbolic_weight`` guarantees the symbolic model keeps at
    least ``min_symbolic_weight`` of the fused prediction. Points where the
    total variance is not positive (ratio undefined) get an even 0.5 split
    before the cap instead of NaN.
    """
    neural_vars = np.asarray(neural_vars, dtype=float)
    symbolic_vars = np.asarray(symbolic_vars, dtype=float)
    total = neural_vars + symbolic_vars
    weights = np.divide(symbolic_vars, total,
                        out=np.full_like(total, 0.5), where=total > 0)
    return np.minimum(weights, 1.0 - min_symbolic_weight)


def main():
    """Run the DANCEST fusion demo and save results plus plots.

    Builds the neural and symbolic estimators, evaluates them on four synthetic
    blades at time points 1..9, and fuses each pair of predictions by inverse
    variance (with a minimum symbolic weight taken from the calibration file
    when available). Writes a timestamped JSON file and PNG plots under
    ``./reports/Phase2_reports``.

    Returns:
        dict: ``'time_points'`` plus, per time point, lists of per-blade
        neural/symbolic predictions, their standard deviations, the neural
        fusion weights Ω, and the fused predictions.
    """
    # Create output directory
    output_dir = Path('./reports/Phase2_reports')
    output_dir.mkdir(parents=True, exist_ok=True)

    # Build estimators
    print("Loading neural and symbolic models...")
    neural_estimator, symbolic_estimator = build_real_estimators()

    # Create sample vertices (blades) with different alloys and coatings
    print("Creating test points...")
    vertices = _build_sample_vertices()

    # Time points to evaluate
    time_points = list(range(1, 10))

    # The calibration file does not change during the run, so read it once
    # rather than once per time point.
    min_symbolic_weight = _load_min_symbolic_weight(
        Path('./DANCEST_model/results/calibration_factors.json'))

    all_results = {
        'time_points': time_points,
        'neural_predictions': [],
        'neural_uncertainties': [],
        'symbolic_predictions': [],
        'symbolic_uncertainties': [],
        'fusion_weights': [],
        'fused_predictions': []
    }

    for t in time_points:
        print(f"Processing time point t={t}")

        # Predictions and variances from both models, coerced to float arrays
        # so the arithmetic and .tolist() below also work for list outputs.
        neural_preds, neural_vars = neural_estimator(vertices, t)
        symbolic_preds, symbolic_vars = symbolic_estimator(vertices, t)
        neural_preds = np.asarray(neural_preds, dtype=float)
        neural_vars = np.asarray(neural_vars, dtype=float)
        symbolic_preds = np.asarray(symbolic_preds, dtype=float)
        symbolic_vars = np.asarray(symbolic_vars, dtype=float)

        # DANCEST weights Ω = σ²_s / (σ²_n + σ²_s), capped so the symbolic
        # model keeps at least min_symbolic_weight.
        fusion_weights = _fusion_weights(neural_vars, symbolic_vars, min_symbolic_weight)

        print(f"Symbolic model weight: {1-fusion_weights.mean():.4f}")

        # Apply fusion: f* = Ω·f_n + (1-Ω)·f_s
        fused_preds = fusion_weights * neural_preds + (1 - fusion_weights) * symbolic_preds

        # Store results for this time point (variances converted to std devs)
        all_results['neural_predictions'].append(neural_preds.tolist())
        all_results['neural_uncertainties'].append(np.sqrt(neural_vars).tolist())
        all_results['symbolic_predictions'].append(symbolic_preds.tolist())
        all_results['symbolic_uncertainties'].append(np.sqrt(symbolic_vars).tolist())
        all_results['fusion_weights'].append(fusion_weights.tolist())
        all_results['fused_predictions'].append(fused_preds.tolist())

    # Save results to file
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    result_path = output_dir / f'dancest_fusion_demo_{timestamp}.json'
    with open(result_path, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2)
    print(f"Saved results to {result_path}")

    # Average weights over all blades and time points
    avg_neural_weight = np.mean(np.array(all_results['fusion_weights']))
    avg_symbolic_weight = 1 - avg_neural_weight

    print("\nDANCEST Fusion Results:")
    print(f"Average neural weight (Ω): {avg_neural_weight:.4f}")
    print(f"Average symbolic weight (1-Ω): {avg_symbolic_weight:.4f}")

    # Plot results
    create_visualizations(all_results, output_dir, timestamp)

    return all_results

def create_visualizations(results, output_dir, timestamp):
    """Create and save visualizations of the fusion process.

    Saves two 300-dpi PNG files into ``output_dir``:

    * ``dancest_fusion_plot_{timestamp}.png`` -- for the middle entry of
      ``results['time_points']``: per-blade neural and symbolic predictions
      with ±1 std error bars plus the fused prediction (top), and the stacked
      neural (Ω) / symbolic (1-Ω) weights (bottom).
    * ``dancest_time_plot_{timestamp}.png`` -- predictions averaged over all
      blades at every time point.

    Args:
        results: Dict with key 'time_points' and keys 'neural_predictions',
            'neural_uncertainties', 'symbolic_predictions',
            'symbolic_uncertainties', 'fusion_weights', 'fused_predictions',
            each holding one per-blade list per time point.
        output_dir: Destination directory (str or Path); must already exist.
        timestamp: Suffix used in the output file names.
    """
    output_dir = Path(output_dir)
    time_points = results['time_points']
    if not time_points:
        # Nothing to plot; indexing the middle time point below would fail.
        print("No time points to visualize; skipping plots.")
        return

    # Select a middle time point to display
    t_idx = len(time_points) // 2
    t = time_points[t_idx]

    # Extract data for this time point
    neural_preds = np.asarray(results['neural_predictions'][t_idx])
    neural_stds = np.asarray(results['neural_uncertainties'][t_idx])
    symb_preds = np.asarray(results['symbolic_predictions'][t_idx])
    symb_stds = np.asarray(results['symbolic_uncertainties'][t_idx])
    weights = np.asarray(results['fusion_weights'][t_idx])
    fused_preds = np.asarray(results['fused_predictions'][t_idx])
    indices = np.arange(len(neural_preds))

    fig = plt.figure(figsize=(12, 8))
    try:
        # Predictions with uncertainties
        ax_pred = fig.add_subplot(2, 1, 1)
        ax_pred.errorbar(indices, neural_preds, yerr=neural_stds, fmt='o', label='Neural', alpha=0.7)
        ax_pred.errorbar(indices, symb_preds, yerr=symb_stds, fmt='s', label='Symbolic', alpha=0.7)
        ax_pred.plot(indices, fused_preds, 'x', markersize=8, label='Fused', alpha=0.9)
        ax_pred.set_xlabel('Blade Index')
        ax_pred.set_ylabel('Corrosion Depth (mm)')
        ax_pred.set_title(f'Neural vs Symbolic vs Fused Predictions at t={t}')
        ax_pred.legend()
        ax_pred.grid(True, alpha=0.3)

        # Stacked fusion weights (neural on bottom, symbolic on top)
        ax_w = fig.add_subplot(2, 1, 2)
        ax_w.bar(indices, weights, alpha=0.7, label='Neural Weight (Ω)')
        ax_w.bar(indices, 1 - weights, bottom=weights, alpha=0.7, label='Symbolic Weight (1-Ω)')
        ax_w.set_xlabel('Blade Index')
        ax_w.set_ylabel('Weight')
        ax_w.set_title(f'Fusion Weights: Ω*(s,t) = σ²_s / (σ²_n + σ²_s) at t={t}')
        ax_w.legend()
        ax_w.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(output_dir / f'dancest_fusion_plot_{timestamp}.png', dpi=300)
    finally:
        # pyplot keeps every figure alive until closed; release it explicitly.
        plt.close(fig)
    print(f"Saved visualization to {output_dir}/dancest_fusion_plot_{timestamp}.png")

    # Predictions over time, averaged across all blades
    avg_neural = [np.mean(preds) for preds in results['neural_predictions']]
    avg_symbolic = [np.mean(preds) for preds in results['symbolic_predictions']]
    avg_fused = [np.mean(preds) for preds in results['fused_predictions']]

    fig = plt.figure(figsize=(12, 8))
    try:
        ax = fig.add_subplot(1, 1, 1)
        ax.plot(time_points, avg_neural, 'o-', label='Neural Model')
        ax.plot(time_points, avg_symbolic, 's-', label='Symbolic Model')
        ax.plot(time_points, avg_fused, 'x-', linewidth=2, label='DANCE-ST Fusion')
        ax.set_xlabel('Time Point')
        ax.set_ylabel('Average Corrosion Depth (mm)')
        ax.set_title('Corrosion Progression Over Time')
        ax.legend()
        ax.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(output_dir / f'dancest_time_plot_{timestamp}.png', dpi=300)
    finally:
        plt.close(fig)
    print(f"Saved time series visualization to {output_dir}/dancest_time_plot_{timestamp}.png")

if __name__ == "__main__":
    # Run the demo only when executed as a script (not on import); the
    # results dict returned by main() is not used here.
    main()