#!/usr/bin/env python
"""
Run DANCEST on all available region-day combinations in the dataset

This script identifies all region-day combinations in the [ANONYMIZED]_lp_corrosion.csv
dataset and runs the DANCEST workflow on each, consolidating results for analysis.
"""

import os
import sys
import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
import subprocess
import json
from concurrent.futures import ProcessPoolExecutor, as_completed

# Make modules located next to this script importable by appending the
# script's own directory to the import search path (only once).
project_root = Path(__file__).resolve().parent
if all(entry != str(project_root) for entry in sys.path):
    sys.path.append(str(project_root))

def _extract_region_day_pairs(frame, region_col, day_col):
    """Return the set of unique ``(region, int(day))`` pairs found in *frame*.

    Rows where either column is missing are skipped. Days are converted with
    ``int`` so every caller gets plain Python integers.
    """
    pairs = frame[[region_col, day_col]].dropna().drop_duplicates()
    return {
        (region, int(day))
        for region, day in pairs.itertuples(index=False, name=None)
    }


def identify_region_day_combinations(
    corrosion_file="[ANONYMIZED]_lp_dataset/[ANONYMIZED]_lp_corrosion.csv",
    test_file="adapted_test.csv",
    chunk_threshold_mb=100,
    chunksize=5000,
):
    """Find all available region-day combinations in the dataset.

    Sources are tried in order until one yields at least one combination:

    1. ``corrosion_file`` using its ``spatial_point`` / ``time_point``
       columns. Files larger than ``chunk_threshold_mb`` are read in chunks
       of ``chunksize`` rows.
    2. ``test_file`` using the first present column of
       ``region`` / ``spatial_point`` / ``spatial`` and of
       ``day`` / ``time_point`` / ``time``.
    3. A built-in default list of four combinations.

    Errors raised while reading a file are printed and treated as "no
    combinations found" for that source.

    Args:
        corrosion_file: Path of the primary corrosion CSV.
        test_file: Path of the fallback CSV.
        chunk_threshold_mb: File size (in MB) above which the corrosion file
            is processed in chunks.
        chunksize: Number of rows per chunk when chunking.

    Returns:
        A list of ``(region, day)`` tuples sorted by region, then day. ``day``
        is a Python ``int`` whichever branch produced it.
    """
    print("Identifying available region-day combinations...")

    combinations = []

    # First try the corrosion CSV
    corrosion_file = Path(corrosion_file)
    if corrosion_file.exists():
        try:
            file_size = corrosion_file.stat().st_size / (1024 * 1024)  # Size in MB

            if file_size > chunk_threshold_mb:
                # Process large file in chunks to bound memory use
                print(f"Processing large corrosion file ({file_size:.1f} MB) in chunks...")
                region_days = set()
                for chunk in pd.read_csv(corrosion_file, chunksize=chunksize):
                    if 'spatial_point' in chunk.columns and 'time_point' in chunk.columns:
                        region_days |= _extract_region_day_pairs(
                            chunk, 'spatial_point', 'time_point')

                combinations = list(region_days)
                print(f"Found {len(combinations)} unique region-day combinations")
            else:
                # Read whole file at once
                df = pd.read_csv(corrosion_file)
                if 'spatial_point' in df.columns and 'time_point' in df.columns:
                    combinations = list(
                        _extract_region_day_pairs(df, 'spatial_point', 'time_point'))
                    print(f"Found {len(combinations)} unique region-day combinations")
        except Exception as e:
            # Best effort: report the problem and fall through to the next source
            print(f"Error processing corrosion file: {e}")
            combinations = []

    # If no combinations found, try the fallback test CSV
    if not combinations:
        test_file = Path(test_file)
        if test_file.exists():
            try:
                df = pd.read_csv(test_file)
                # Look for region/spatial columns
                region_col = next((col for col in ['region', 'spatial_point', 'spatial']
                                   if col in df.columns), None)
                # Look for day/time columns
                day_col = next((col for col in ['day', 'time_point', 'time']
                                if col in df.columns), None)

                if region_col is not None and day_col is not None:
                    combinations = list(
                        _extract_region_day_pairs(df, region_col, day_col))
                    print(f"Found {len(combinations)} unique region-day combinations from test file")
            except Exception as e:
                print(f"Error processing test file: {e}")
                combinations = []

    # If still no combinations found, use default set
    if not combinations:
        print("No data files found with region-day combinations, using defaults")
        combinations = [
            ('s65', 180), ('s123', 210), ('s126', 210), ('s9', 180)
        ]

    # Sort by region then day for consistent ordering
    combinations.sort()

    return combinations

def run_dancest(region, day, script_path="DANCEST_model/run_with_agents_verbose.py"):
    """Run the DANCEST model for a specific region and day in a child process.

    ``script_path`` is launched with the current interpreter and the
    arguments ``--region <region> --day <day>``. The child's stdout is
    scanned for lines containing ``corrosion depth =``; the first
    whitespace-delimited token after that marker is parsed as a float. When
    several lines match, the last parsable value is kept.

    Args:
        region: Region identifier, passed to the script as a string.
        day: Day value, passed to the script as a string.
        script_path: Script to execute; defaults to the DANCEST runner.

    Returns:
        A dict with keys ``region``, ``day``, ``prediction`` (float or
        ``None`` if no value could be parsed), ``execution_time`` (seconds)
        and ``success``. On failure ``success`` is ``False`` and an ``error``
        key holds the exception text plus the last stderr line, if any.
    """
    print(f"Running DANCEST for region {region}, day {day}...")

    start_time = time.time()
    cmd = [sys.executable, str(script_path),
           "--region", str(region), "--day", str(day)]

    try:
        # Run the process and capture output
        process = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: non-zero exit status. OSError: the process could
        # not be launched at all. Either way record a failure, don't abort.
        error = str(e)
        stderr_lines = (getattr(e, "stderr", None) or "").strip().splitlines()
        if stderr_lines:
            # The last stderr line usually carries the actual exception message
            error = f"{error}; stderr: {stderr_lines[-1].strip()}"
        print(f"Error running DANCEST for region {region}, day {day}: {error}")
        return {
            "region": region,
            "day": day,
            "prediction": None,
            "execution_time": time.time() - start_time,
            "success": False,
            "error": error
        }

    # Parse the output to extract the prediction
    final_val = None
    marker = "corrosion depth ="
    for line in process.stdout.splitlines():
        _, found, remainder = line.partition(marker)
        if not found:
            continue
        tokens = remainder.split()
        if not tokens:
            # Marker present but nothing after it; nothing to parse
            continue
        try:
            final_val = float(tokens[0])
        except ValueError:
            continue

    if final_val is None:
        print(f"Warning: no corrosion depth found in output for region {region}, day {day}")

    return {
        "region": region,
        "day": day,
        "prediction": final_val,
        "execution_time": time.time() - start_time,
        "success": True
    }

def _save_prediction_plot(plot_df, plot_file):
    """Draw predictions as a heatmap (4+ rows) or bar chart and save to *plot_file*."""
    use_heatmap = len(plot_df) >= 4
    fig, ax = plt.subplots(figsize=(12, 8) if use_heatmap else (10, 6))
    try:
        if use_heatmap:
            regions = sorted(plot_df['region'].unique())
            days = sorted(plot_df['day'].unique())
            region_idx = {region: i for i, region in enumerate(regions)}
            day_idx = {day: i for i, day in enumerate(days)}

            # NaN marks region/day cells without a prediction so they are not
            # rendered as a genuine 0.0 depth.
            heatmap_data = np.full((len(regions), len(days)), np.nan)
            for region, day, prediction in zip(
                    plot_df['region'], plot_df['day'], plot_df['prediction']):
                heatmap_data[region_idx[region], day_idx[day]] = prediction

            image = ax.imshow(heatmap_data, cmap='viridis')
            fig.colorbar(image, ax=ax, label='Corrosion Depth (mm)')
            ax.set_xticks(range(len(days)))
            ax.set_xticklabels(days)
            ax.set_yticks(range(len(regions)))
            ax.set_yticklabels(regions)
            ax.set_xlabel('Day')
            ax.set_ylabel('Region')
            ax.set_title('DANCEST Corrosion Depth Predictions')

            # Add values to the cells that actually hold a prediction
            for (i, j), val in np.ndenumerate(heatmap_data):
                if not np.isnan(val):
                    ax.text(j, i, f'{val:.3f}', ha='center', va='center',
                            color='white' if val > 0.3 else 'black')
        else:
            # Simple bar chart for fewer predictions
            plot_df.plot(x='region', y='prediction', kind='bar', ax=ax)
            ax.set_title('DANCEST Predictions by Region')
            ax.set_ylabel('Corrosion Depth (mm)')
            fig.tight_layout()

        fig.savefig(plot_file)
    finally:
        # Always release the figure so repeated calls do not accumulate figures
        plt.close(fig)


def generate_parallel_report(results, results_dir="DANCEST_model/batch_results"):
    """Generate a summary report of all runs.

    Writes three timestamped artefacts into ``results_dir`` (created with any
    missing parents): a CSV of all results, a PNG visualization (only when at
    least one successful run has a numeric prediction) and a text report.

    Args:
        results: List of result dicts as produced by ``run_dancest``; each
            must provide ``region``, ``day``, ``prediction``,
            ``execution_time`` and ``success``.
        results_dir: Directory that receives the output files.

    Returns:
        The report text, or ``"No results to report."`` when ``results`` is
        empty (in which case nothing is written).
    """
    if not results:
        return "No results to report."

    # Create output directories if they don't exist
    results_dir = Path(results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)

    # Convert to DataFrame for analysis
    df = pd.DataFrame(results)

    # Basic statistics
    successful_runs = df[df['success'].eq(True)]
    success_rate = len(successful_runs) / len(df) if len(df) > 0 else 0

    # Save detailed results to CSV
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_file = results_dir / f"dancest_batch_results_{timestamp}.csv"
    df.to_csv(csv_file, index=False)

    # Create visualization of predictions. Only rows with a numeric prediction
    # can be drawn: a run may succeed without a parsable prediction, and pandas
    # stores such None values as NaN/None in the column.
    plot_file = None
    if len(successful_runs) > 0:
        numeric_predictions = pd.to_numeric(successful_runs['prediction'], errors='coerce')
        plot_df = successful_runs.assign(prediction=numeric_predictions)
        plot_df = plot_df.dropna(subset=['prediction'])
        if len(plot_df) > 0:
            plot_file = results_dir / f"dancest_predictions_{timestamp}.png"
            _save_prediction_plot(plot_df, plot_file)
            print(f"Saved predictions visualization to {plot_file}")

    if len(successful_runs) > 0:
        avg_time_text = f"{successful_runs['execution_time'].mean():.2f} seconds"
    else:
        avg_time_text = "n/a (no successful runs)"

    # Generate report text
    report = f"""
DANCEST Batch Processing Report
===============================
Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}

Summary Statistics:
------------------
Total runs: {len(df)}
Successful runs: {len(successful_runs)} ({success_rate:.1%})
Average execution time: {avg_time_text}
Total processing time: {df['execution_time'].sum():.2f} seconds

Results saved to: {csv_file}
"""
    if plot_file is not None:
        report += f"Visualization saved to: {plot_file}\n"

    # Save report to file
    report_file = results_dir / f"dancest_batch_report_{timestamp}.txt"
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write(report)

    print(f"Saved batch processing report to {report_file}")
    return report

def main():
    """Interactively run DANCEST on the discovered region-day combinations.

    Discovers the combinations, optionally truncates the list (asked only
    when more than 10 were found), asks whether to run in parallel and with
    how many workers, executes every run, then generates and prints the
    batch report.
    """
    print("Starting DANCEST batch processing on all data...")

    # Get all region-day combinations
    combinations = identify_region_day_combinations()

    # Ask user if they want to process all combinations
    print(f"\nFound {len(combinations)} region-day combinations to process.")
    print("First 5 combinations:", combinations[:5])

    if len(combinations) > 10:
        max_to_process = input(
            f"Enter maximum number of combinations to process (default: all {len(combinations)}): "
        ).strip()
        # isdecimal() guarantees int() will accept the text
        if max_to_process.isdecimal():
            combinations = combinations[:int(max_to_process)]

    # Ask about parallel processing; any answer starting with "n" disables it
    answer = input("Use parallel processing? (y/n, default: y): ").strip().lower()
    use_parallel = not answer.startswith('n')

    # os.cpu_count() may return None when the count cannot be determined
    cpu_count = os.cpu_count() or 1
    max_workers = max(1, cpu_count - 1)

    if use_parallel:
        workers = input(f"Enter number of parallel workers (default: {max_workers}): ").strip()
        if workers.isdecimal():
            # ProcessPoolExecutor rejects max_workers <= 0, so clamp to 1
            max_workers = max(1, int(workers))

    start_time = time.time()
    results = []

    if use_parallel:
        print(f"\nProcessing {len(combinations)} combinations with {max_workers} parallel workers...")
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            future_to_combo = {
                executor.submit(run_dancest, region, day): (region, day)
                for region, day in combinations
            }

            for i, future in enumerate(as_completed(future_to_combo)):
                region, day = future_to_combo[future]
                try:
                    result = future.result()
                    results.append(result)
                    print(f"Completed {i+1}/{len(combinations)}: {region}, day {day}")
                except Exception as e:
                    # Keep the batch going; record the failure for the report
                    print(f"Error processing {region}, day {day}: {e}")
                    results.append({
                        "region": region,
                        "day": day,
                        "prediction": None,
                        "execution_time": 0,
                        "success": False,
                        "error": str(e)
                    })
    else:
        print(f"\nProcessing {len(combinations)} combinations sequentially...")
        for i, (region, day) in enumerate(combinations):
            print(f"Processing {i+1}/{len(combinations)}: {region}, day {day}")
            result = run_dancest(region, day)
            results.append(result)

    total_time = time.time() - start_time

    # Generate and print report
    print("\n" + "="*80)
    print(f"Completed processing {len(combinations)} combinations in {total_time:.2f} seconds")
    report = generate_parallel_report(results)
    print("\n" + report)

# Run the interactive batch workflow only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()