import os
import sys
import gc
import torch
import wandb
from ultralytics import YOLO
# Clear the terminal so training output starts on an empty screen.
# NOTE: this runs at import time, not only when the file is executed as a script.
os.system('clear')

# Root directory of the experiment; the dataset YAML and the run output
# directory are both resolved relative to it.
path = '/home/ubuntu/thesis-Intersection/yolo'
# Model family prefix; a variant suffix is appended to form the weights name.
yolo_model = 'yolo12'
# Variant suffixes to train, in order. The commented-out line is the wider sweep.
# variants = ['n', 's',  'm', 'l'] # , 'x']
variants = ['m']
batch_size = 16  # Batch size passed to model.train() and embedded in the W&B run name



# GPU memory management helper
def clear_gpu_memory():
    """Collect garbage, then release cached CUDA memory if CUDA is available.

    Garbage collection runs first so that objects kept alive only by
    reference cycles are destroyed before the CUDA caching allocator is asked
    to release its unused blocks; emptying the cache first would leave the
    blocks owned by such objects in place.  The collection is unconditional,
    so host memory is also reclaimed on machines without CUDA.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def train_all_variants(variant):
    """Train one YOLO variant and log the run to Weights & Biases.

    Despite the plural name, a single call trains exactly one model; the
    caller is expected to loop over the variants.

    Args:
        variant: Suffix appended to the module-level ``yolo_model`` prefix to
            form the weights name that is loaded as ``f"{model_name}.pt"``.

    Raises:
        Exception: Any error raised during setup or training is printed and
            re-raised.  The W&B run (if one was opened) is closed and GPU
            memory is cleared before the exception propagates.
    """
    # Resolved before the try block so the error handler can always name the model.
    model_name = f"{yolo_model}{variant}"
    train_run = None
    model = None
    succeeded = False
    try:
        # Start from a clean GPU state.
        clear_gpu_memory()

        # Open the W&B run that tracks this training session.
        train_run = wandb.init(
            project="Road-Intersection-rtx2080",
            name=f"{model_name}-training-{batch_size}",
        )

        # Output location.
        # NOTE(review): `project` already ends in the model name and `name`
        # is the model name again, so outputs presumably land in
        # <path>/run/baseline/<model_name>/<model_name> - confirm this nesting
        # is intended before changing it, since other tooling may rely on it.
        save_dir = os.path.join(path, 'run/baseline')
        run_name = model_name
        run_dir = os.path.join(save_dir, run_name)

        # Load the pretrained weights and train.
        model = YOLO(f"{model_name}.pt")
        model.train(
            data=os.path.join(path, "data/data.yaml"),
            epochs=100,
            patience=15,
            close_mosaic=10,

            batch=batch_size,  # Fixed batch size taken from the module-level setting
            imgsz=640,
            project=run_dir,
            name=run_name,
            save=True,
            verbose=True,
            resume=False,

            # Regularization
            # NOTE(review): the Ultralytics docs appear to describe `dropout`
            # as classification-only - confirm it has an effect for this task.
            dropout=0.2,
            weight_decay=0.0005,

            # Learning rate schedule
            lr0=0.01,    # Initial learning rate
            lrf=0.01,    # Final learning rate factor

            # Loss weights (class ratio intersection:roundabout = 770:585 = 1.32:1)
            cls=0.5,     # Classification loss weight
            box=7.5,     # Box loss weight

            # Augmentation
            # NOTE(review): `erasing` and `copy_paste` appear to be documented
            # as classification-only and segmentation-only respectively -
            # confirm against the installed Ultralytics version.
            mosaic=0.7,
            mixup=0.15,
            erasing=0.4,
            degrees=10,
            shear=5,
            scale=0.2,
            translate=0.1,
            copy_paste=0.3,
        )
        succeeded = True
    except Exception as e:
        print(f"Error training {model_name}: {e}")
        raise
    finally:
        # Close the W&B run on every exit path; a non-zero exit code marks an
        # unsuccessful run as failed instead of reporting it as finished.
        if train_run is not None:
            if succeeded:
                train_run.finish()
            else:
                train_run.finish(exit_code=1)
        # Drop the model reference first so its memory can actually be
        # reclaimed, including on the error path.
        model = None
        clear_gpu_memory()

# Script entry point: train every configured variant in sequence
if __name__ == "__main__":
    # CUDA debugging/stability settings, set before any training work starts.
    # CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous so errors are
    # reported at the failing call; expect slower training while it is set.
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    # NOTE(review): device-side assertions presumably also require a PyTorch
    # build compiled with this option - confirm the variable has an effect.
    os.environ['TORCH_USE_CUDA_DSA'] = '1'
    os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'  # Limit CUDA connections

    # PYTHONUNBUFFERED is read at interpreter start-up, so setting it here only
    # reaches child processes.  Switch this process's own streams to line
    # buffering as well so log lines appear immediately.
    os.environ['PYTHONUNBUFFERED'] = '1'
    for stream in (sys.stdout, sys.stderr):
        # Replaced streams (notebooks, some loggers) may lack reconfigure().
        reconfigure = getattr(stream, "reconfigure", None)
        if reconfigure is not None:
            reconfigure(line_buffering=True)

    # Release any cached CUDA blocks before the first run.
    torch.cuda.empty_cache()

    for variant in variants:
        train_all_variants(variant)
