import os
import time
import subprocess
from datetime import timedelta
import pandas as pd

# ==== Configuration ====

# Checkpoints to train, in this order. Each entry is passed verbatim to the
# ``yolo`` CLI as ``model=``; its file stem is also used as the run name.
MODELS = [
    # YOLOv8
    'yolov8n.pt',
    'yolov8s.pt',
    'yolov8m.pt',
    'yolov8l.pt',
    'yolov8x.pt',
    # YOLOv9
    'yolov9t.pt',
    'yolov9s.pt',
    'yolov9m.pt',
    'yolov9c.pt',
    'yolov9e.pt',
    # YOLOv10
    'yolov10n.pt',
    'yolov10s.pt',
    'yolov10m.pt',
    'yolov10l.pt',
    'yolov10x.pt',
    # YOLO11
    'yolo11n.pt',
    'yolo11s.pt',
    'yolo11m.pt',
    'yolo11l.pt',
    'yolo11x.pt',
]

# Settings forwarded to the ``yolo`` CLI.
DATA = 'yolodata/YOLODataset/dataset.yaml'  # dataset YAML, passed as ``data=``
EPOCHS = 100                                # ``epochs=`` for every model
BATCH = 16                                  # ``batch=`` for training
IMGSZ = 640                                 # ``imgsz=`` (input image size)
DEVICE = '0'                                # ``device=`` argument
OUT_CSV = 'yolo_train_summary.csv'          # summary table written by main()

# ==== Utilities ====

def get_file_size_mb(path):
    """Return the size of the file at *path* in MiB, rounded to 2 decimals.

    Raises OSError (e.g. FileNotFoundError) if *path* cannot be stat'ed.
    """
    bytes_per_mib = 1024 * 1024
    size_in_bytes = os.stat(path).st_size
    return round(size_in_bytes / bytes_per_mib, 2)

def extract_best_losses(results_path):
    """Return (box_loss, dfl_loss) from the epoch with the highest mAP50.

    Reads the per-epoch training log CSV at *results_path*, finds the row with
    the maximum ``metrics/mAP50(B)`` and returns that row's box and DFL losses,
    each rounded to 4 decimals. "Best" therefore means "at the best-mAP50
    epoch", not the minimum loss over training.

    Column lookup tries the names this script originally expected
    (``loss/box`` / ``loss/dfl``) first, then the Ultralytics names
    ``train/box_loss`` / ``train/dfl_loss``. Header names are stripped of
    surrounding whitespace because some Ultralytics releases pad them.

    Returns:
        (float, float) on success, or ('N/A', 'N/A') if the file is missing,
        unreadable, empty, or lacks the required columns.
    """
    na = ('N/A', 'N/A')
    try:
        df = pd.read_csv(results_path)
    except (OSError, ValueError):
        # Missing/unreadable file, or empty/malformed CSV (pandas' EmptyDataError
        # and ParserError both subclass ValueError).
        return na

    df.columns = [str(col).strip() for col in df.columns]

    def first_present(*candidates):
        # First candidate column name that exists in the CSV, else None.
        return next((c for c in candidates if c in df.columns), None)

    map_col = first_present('metrics/mAP50(B)')
    box_col = first_present('loss/box', 'train/box_loss')
    dfl_col = first_present('loss/dfl', 'train/dfl_loss')
    if map_col is None or box_col is None or dfl_col is None:
        return na

    try:
        idx = df[map_col].idxmax()
        box = round(float(df.loc[idx, box_col]), 4)
        dfl = round(float(df.loc[idx, dfl_col]), 4)
    except (KeyError, TypeError, ValueError):
        # No data rows, all-NaN mAP column, or non-numeric values.
        return na
    return box, dfl

def parse_metrics(output):
    """Extract overall mAP50 and mAP50-95 from captured ``yolo detect val`` output.

    Scans *output* line by line for the first row whose first whitespace-
    separated token is exactly ``all`` and that has at least 7 fields. It
    returns that row's last two fields unconverted, as strings. These are
    assumed to be the mAP50 and mAP50-95 columns of the validation table.
    TODO confirm if the table layout differs, e.g. for tasks that add columns.

    Args:
        output: captured text of the validation run; None or "" is accepted.

    Returns:
        (mAP50, mAP50_95) as strings, or ('N/A', 'N/A') if no summary row is found.
    """
    for line in (output or '').splitlines():
        fields = line.split()
        # Exact token match, so a class name such as "alligator" cannot be
        # mistaken for the aggregate "all" row.
        if len(fields) >= 7 and fields[0] == 'all':
            return fields[-2], fields[-1]
    return 'N/A', 'N/A'

# ==== Training Process ====

def _unique_run_dir(project, name):
    """Return ``project/name``, or the first free ``project/nameN`` (N >= 2) if it exists."""
    # Ultralytics auto-increments an existing run name when exist_ok is False.
    # Picking the directory ourselves means we know exactly where this run's
    # weights and results.csv end up, instead of reading a stale earlier run.
    candidate = os.path.join(project, name)
    suffix = 2
    while os.path.exists(candidate):
        candidate = os.path.join(project, f'{name}{suffix}')
        suffix += 1
    return candidate


def _validate(weights_path):
    """Run ``yolo detect val`` on *weights_path*.

    Returns:
        (mAP50, mAP50-95, elapsed) as strings. *elapsed* is the wall-clock time
        of the whole validation subprocess, including startup, model loading
        and dataset scanning, not per-image inference latency.
    """
    cmd_val = [
        'yolo', 'detect', 'val',
        f'model={weights_path}',
        f'data={DATA}',
        # Validate with the same image size and device used for training.
        f'imgsz={IMGSZ}',
        f'device={DEVICE}',
    ]
    t0 = time.perf_counter()
    result = subprocess.run(cmd_val, capture_output=True, text=True)
    elapsed = time.perf_counter() - t0

    mAP50, mAP5095 = parse_metrics(result.stdout)
    if mAP50 == 'N/A':
        # Depending on logging configuration the table may land on stderr.
        mAP50, mAP5095 = parse_metrics(result.stderr)
    if result.returncode != 0:
        print(f'⚠️ Validation exited with code {result.returncode} for {weights_path}')
    return mAP50, mAP5095, f"{round(elapsed, 3)} s"


def run_training(model):
    """Train one model with the ``yolo`` CLI, validate its best weights and summarise.

    Args:
        model: checkpoint name/path passed to the CLI as ``model=``.

    Returns:
        [name, mAP50, mAP50-95, training time, best.pt size, validation time,
        box_loss, dfl_loss]. Unavailable values are 'N/A'.

    Raises:
        subprocess.CalledProcessError: if the training command exits non-zero.
        FileNotFoundError: if the ``yolo`` executable cannot be found.
    """
    name = os.path.splitext(os.path.basename(model))[0]
    project = os.path.join('runs', 'detect')
    run_dir = _unique_run_dir(project, name)
    print(f'\n🚀 Training model: {model}')

    cmd = [
        'yolo', 'detect', 'train',
        f'model={model}',
        f'data={DATA}',
        f'epochs={EPOCHS}',
        f'batch={BATCH}',
        f'imgsz={IMGSZ}',
        f'device={DEVICE}',
        # Pin the output location so the paths read below are the ones written.
        f'project={project}',
        f'name={os.path.basename(run_dir)}',
        'exist_ok=True',
        'save=True',
        'plots=True',
        'verbose=True',
    ]

    start = time.monotonic()  # monotonic: unaffected by wall-clock adjustments
    subprocess.run(cmd, check=True)
    duration = timedelta(seconds=round(time.monotonic() - start))

    best_path = os.path.join(run_dir, 'weights', 'best.pt')
    if os.path.exists(best_path):
        model_size = f"{get_file_size_mb(best_path)} MB"
        print(f'\n✅ Validating model: {best_path}')
        mAP50, mAP5095, infer_time = _validate(best_path)
    else:
        print(f'⚠️ No best.pt found at {best_path}; skipping validation')
        model_size = 'N/A'
        mAP50, mAP5095, infer_time = 'N/A', 'N/A', 'N/A'

    box_loss, dfl_loss = extract_best_losses(os.path.join(run_dir, 'results.csv'))

    return [name, mAP50, mAP5095, str(duration), model_size, infer_time, box_loss, dfl_loss]

def _write_summary(rows):
    """Write *rows* (header row first) to OUT_CSV, replacing any previous content."""
    import csv
    with open(OUT_CSV, 'w', newline='', encoding='utf-8') as f:
        csv.writer(f).writerows(rows)


def main():
    """Train every model in MODELS and record one summary row per model in OUT_CSV.

    If a model raises any Exception (e.g. subprocess.CalledProcessError from a
    failed ``yolo`` run), the error is printed and an ERROR row is recorded,
    and the loop continues with the next model. The CSV is rewritten after
    every model, so finished results survive if the script is interrupted
    part-way through a long sweep.
    """
    header = ['Model', 'mAP50', 'mAP50-95', 'Training Time', 'Model Size',
              'Inference Time', 'Best box_loss', 'Best dfl_loss']
    summary = [header]
    for model in MODELS:
        try:
            summary.append(run_training(model))
        except Exception as e:  # top-level boundary: report and move on
            print(f"❌ Failed training {model}: {e}")
            # Identify the model by its file stem, like successful rows do.
            stem = os.path.splitext(os.path.basename(model))[0]
            summary.append([stem] + ['ERROR'] * (len(header) - 1))
        _write_summary(summary)

    # Final write also covers an empty MODELS list (header-only file).
    _write_summary(summary)
    print(f'\n🎉 All training completed. Results saved to: {OUT_CSV}')


if __name__ == '__main__':
    main()