import os
import sys
import gc
import torch
import wandb
from ultralytics import YOLO
import numpy as np
import glob
os.system('clear')

# Label used for the WandB run name and for the validation output directory.
model_name = 'yolov11s'

# Project root, plus the dataset config and output directory derived from it.
path = "/home/ubuntu/thesis-Intersection/yolo/"
data_path = os.path.join(path, "data", "data.yaml")
save_dir = os.path.join(path, "run", "val", model_name)

# Weights under evaluation: an absolute path to one specific training run.
model_path = "/home/ubuntu/thesis-Intersection/yolo/run/baseline/yolo11s/yolo11s-16/weights/best.pt"

# Open the WandB run up front; the plots and summary metrics produced later
# in this script are logged through this handle.
val_run = wandb.init(
    project="Road-Intersection-rtx2080-Validation",
    name=f"{model_name}-validation-16-new",
)

# Load the trained model from the weights file.
model = YOLO(model_path)

# Total number of validation passes. One more than the number of measured
# passes, because the first pass is treated as warm-up and discarded.
num_runs = 11

# Per-run speed figures (one entry per measured pass).
preprocess_times, inference_times, postprocess_times, total_times = [], [], [], []

# Per-run detection quality figures (one entry per measured pass).
map50_95_values, map50_values, precision_values, recall_values, f1_values = [], [], [], [], []

# Run validation multiple times for timing and performance analysis.
#
# Ordered (filename substring, WandB key) rules for the known validation plots.
# Order matters: "R_curve" is a substring of "PR_curve", so the PR rule must be
# checked before the R rule, otherwise PR curves get logged as recall curves.
PLOT_KEY_RULES = (
    ('confusion_matrix', "validation_plots/confusion_matrix"),
    ('F1_curve', "validation_plots/f1_curve"),
    ('PR_curve', "validation_plots/pr_curve"),
    ('P_curve', "validation_plots/precision_curve"),
    ('R_curve', "validation_plots/recall_curve"),
    ('results', "validation_plots/results"),
)


def _plot_log_key(filename):
    """Return the WandB key under which the plot file *filename* is logged.

    The first rule in ``PLOT_KEY_RULES`` whose substring occurs in *filename*
    wins. Files matching no rule are logged under their own name with the
    ``.png`` / ``.jpg`` extension text removed.
    """
    for marker, key in PLOT_KEY_RULES:
        if marker in filename:
            return key
    clean_name = filename.replace('.png', '').replace('.jpg', '')
    return f"validation_plots/{clean_name}"


for i in range(num_runs):
    print(f"Running validation {i+1}/{num_runs}...")

    if i == num_runs - 1:
        # Final run: enable plots and all save options, then upload the
        # generated images to WandB.
        # NOTE(review): this run is also included in the timing averages
        # below even though saving is enabled - confirm this is intended.
        print("Final run - enabling full logging with plots...")
        val_results = model.val(
            data=data_path,
            batch=1,
            split='test',
            plots=True,        # Request confusion matrix, PR curves, etc.
            save_json=True,    # Save results to JSON
            save_txt=True,     # Save results to TXT
            save_conf=True,    # Save confidence scores
            save=True,         # Save images with predictions
            verbose=True,      # Show detailed output
            project=save_dir,  # Specify project directory
            name=model_name,   # Specific name for this run
        )
        print(f"Final run completed. Results saved to: {val_results.save_dir if hasattr(val_results, 'save_dir') else 'Unknown'}")

        # Upload any generated plots, if the results expose a save directory.
        results_save_dir = getattr(val_results, 'save_dir', None)
        if results_save_dir:
            results_save_dir = str(results_save_dir)
            print(f"Looking for plots in: {results_save_dir}")

            plot_files = glob.glob(f"{results_save_dir}/*.png") + glob.glob(f"{results_save_dir}/*.jpg")
            print(f"Found {len(plot_files)} plot files: {plot_files}")

            if plot_files:
                for plot_file in plot_files:
                    log_key = _plot_log_key(os.path.basename(plot_file))
                    val_run.log({log_key: wandb.Image(plot_file)})

                print(f"Uploaded {len(plot_files)} plots to WandB")
            else:
                print("Warning: No plot files were generated despite plots=True")
                # List all files in the directory for debugging
                all_files = glob.glob(f"{results_save_dir}/*")
                print(f"All files in save directory: {all_files}")
    else:
        # Intermediate runs: disable saves and plots to avoid clutter.
        val_results = model.val(data=data_path, batch=1, split='test', save=False, plots=False, verbose=False)

    speed = val_results.speed

    # The first iteration is a warm-up pass; its numbers are discarded.
    if i == 0:
        continue

    # Collect timing data for this run.
    preprocess = speed['preprocess']
    inference = speed['inference']
    postprocess = speed['postprocess']
    total_time = preprocess + inference + postprocess

    preprocess_times.append(preprocess)
    inference_times.append(inference)
    postprocess_times.append(postprocess)
    total_times.append(total_time)

    # Collect performance metrics; a missing key is recorded as 0.
    metrics = val_results.results_dict
    p = metrics.get('metrics/precision(B)', 0)
    r = metrics.get('metrics/recall(B)', 0)
    map50_95_values.append(metrics.get('metrics/mAP50-95(B)', 0))
    map50_values.append(metrics.get('metrics/mAP50(B)', 0))
    precision_values.append(p)
    recall_values.append(r)

    # F1 is the harmonic mean of precision and recall, defined as 0 when
    # both are 0 to avoid dividing by zero.
    f1 = 2 * (p * r) / (p + r) if (p + r) > 0 else 0
    f1_values.append(f1)

def _mean(values):
    """Return the arithmetic mean of *values* (must be non-empty)."""
    return sum(values) / len(values)


# Averages over the measured validation runs.
avg_inference = _mean(inference_times)
avg_map50_95 = _mean(map50_95_values)
avg_map50 = _mean(map50_values)
avg_precision = _mean(precision_values)
avg_recall = _mean(recall_values)
avg_f1 = _mean(f1_values)

# Snapshot of the most recent run, kept apart from the averages.
# A metric missing from the results dict is reported as 0.
last_run_metrics = val_results.results_dict
last_run_individual = {
    "last_run/mAP50-95": last_run_metrics.get('metrics/mAP50-95(B)', 0),
    "last_run/mAP50": last_run_metrics.get('metrics/mAP50(B)', 0),
    "last_run/precision": last_run_metrics.get('metrics/precision(B)', 0),
    "last_run/recall": last_run_metrics.get('metrics/recall(B)', 0),
    "last_run/inference_ms": inference_times[-1],
}

# Only the averaged figures are sent to WandB; the last-run snapshot is
# printed below but not logged.
combined_data = {
    "avg_metrics/mAP50-95": avg_map50_95,
    "avg_metrics/mAP50": avg_map50,
    "avg_metrics/precision": avg_precision,
    "avg_metrics/recall": avg_recall,
    "avg_metrics/f1": avg_f1,
    "speed/avg_inference_ms": avg_inference,
}
val_run.log(combined_data)

# Console report: averaged figures first, then the last run on its own.
print(
    f"\n=== Average Performance Metrics (over {num_runs-1} runs) ===",
    f"mAP50-95: {avg_map50_95:.4f}",
    f"mAP50:    {avg_map50:.4f}",
    f"Precision: {avg_precision:.4f}",
    f"Recall:    {avg_recall:.4f}",
    f"F1:        {avg_f1:.4f}",
    f"Inference: {avg_inference:.1f}ms",
    sep="\n",
)

print(
    "\n=== Last Run Individual Metrics ===",
    f"mAP50-95: {last_run_individual['last_run/mAP50-95']:.4f}",
    f"mAP50:    {last_run_individual['last_run/mAP50']:.4f}",
    f"Precision: {last_run_individual['last_run/precision']:.4f}",
    f"Recall:    {last_run_individual['last_run/recall']:.4f}",
    f"Inference: {last_run_individual['last_run/inference_ms']:.1f}ms",
    sep="\n",
)

# Close the WandB run, then echo which weights file was evaluated.
val_run.finish()
print(model_path)
