import os
import subprocess
import argparse
import glob
import json
import sys

# --- Configuration (Must match train script) ---
# Experiment data lives in "<WORK_DIR_BASE>_<experiment>"; MMROTATE_ROOT is the
# MMRotate checkout that provides the model configs and tools/test.py.
WORK_DIR_BASE = "/data/aaa/OminiControl/mmrotate_mar20_workdir"
MMROTATE_ROOT = "/data/aaa/OminiControl/mmrotate"

# Target class names (11 entries).
CLASSES = ("A2", "A3", "A4", "A6", "A7", "A8", "A9", "A10", "A11", "A14", "A18")

# (model name, config path relative to MMROTATE_ROOT), one row per detector.
_MODEL_CONFIGS = (
    ("oriented_rcnn", "configs/oriented_rcnn/oriented_rcnn_r50_fpn_fp16_1x_dota_le90.py"),
    ("redet", "configs/redet/redet_re50_refpn_1x_dota_ms_rr_le90.py"),
    ("rotated_faster_rcnn", "configs/rotated_faster_rcnn/rotated_faster_rcnn_r50_fpn_1x_dota_le90.py"),
    ("s2anet", "configs/s2anet/s2anet_r50_fpn_fp16_1x_dota_le135.py"),
)

# Models Configuration: one {"name": ..., "config": <absolute config path>} per detector.
MODELS = [
    {"name": model_name, "config": os.path.join(MMROTATE_ROOT, rel_cfg)}
    for model_name, rel_cfg in _MODEL_CONFIGS
]

# Experiment list: "baseline" followed by each variant at suffixes r1.0 .. r5.0
# (all batch runs, then all copypaste runs, then all solar runs).
_EXPERIMENT_VARIANTS = ("batch", "copypaste", "solar")
EXPERIMENTS = ["baseline"] + [
    f"{variant}_r{float(step)}"
    for variant in _EXPERIMENT_VARIANTS
    for step in range(1, 6)
]

def get_dataset_dir(exp_name, base_dir=None):
    """Return the data/work directory for one experiment.

    The directory is ``f"{base_dir}_{exp_name}"``. The experiment name is used
    verbatim as the suffix, ``"baseline"`` included.

    Args:
        exp_name: Experiment name, e.g. an entry of ``EXPERIMENTS``.
        base_dir: Path prefix. Defaults to ``WORK_DIR_BASE`` when None.

    Returns:
        The directory path as a string.
    """
    if base_dir is None:
        base_dir = WORK_DIR_BASE
    return f"{base_dir}_{exp_name}"

def find_best_checkpoint(run_dir):
    """Pick the checkpoint to evaluate in ``run_dir``.

    Preference order:
      1. ``best_mAP_epoch_<N>.pth``. If several exist, take the one with the
         largest epoch ``N``. Ties, and names whose ``N`` is not a plain
         integer, are broken by the newest modification time.
      2. ``latest.pth``, with a warning printed.

    Args:
        run_dir: Directory of a single training run.

    Returns:
        Path of the chosen checkpoint, or None if neither kind exists.
    """
    prefix, suffix = "best_mAP_epoch_", ".pth"
    # Escape run_dir so glob metacharacters in the path ("[", "*", "?") are literal.
    best_ckpts = glob.glob(os.path.join(glob.escape(run_dir), f"{prefix}*{suffix}"))

    if best_ckpts:
        def _rank(path):
            # The epoch in the filename survives copies/rsyncs that reset
            # mtimes, so it is the primary key; mtime is only a tiebreak.
            stem = os.path.basename(path)[len(prefix):-len(suffix)]
            epoch = int(stem) if stem.isdecimal() else -1
            return epoch, os.path.getmtime(path)

        return max(best_ckpts, key=_rank)

    # Fallback to latest.pth
    latest_ckpt = os.path.join(run_dir, "latest.pth")
    if os.path.exists(latest_ckpt):
        print(f"  ⚠️ 'best_mAP' checkpoint not found in {run_dir}, using 'latest.pth'")
        return latest_ckpt

    return None

# Metric computation runs in a fresh interpreter via `python -c`. All inputs
# arrive through sys.argv (config, img_prefix, ann_file, classes-as-JSON,
# pkl path, json path) instead of being spliced into the source, so paths
# containing quotes or backslashes cannot break the generated code.
_EVAL_SCRIPT = """
import json
import sys

import mmcv
from mmcv import Config
from mmrotate.datasets import build_dataset


def main():
    cfg_path, img_prefix, ann_file, classes_json, pkl_path, json_path = sys.argv[1:7]
    try:
        cfg = Config.fromfile(cfg_path)

        # Point the test split at this experiment's data and class list.
        cfg.merge_from_dict(dict(
            data=dict(
                test=dict(
                    img_prefix=img_prefix,
                    ann_file=ann_file,
                    classes=tuple(json.loads(classes_json)),
                )
            )
        ))

        dataset = build_dataset(cfg.data.test)
        results = mmcv.load(pkl_path)

        print("Calculating mAP...")
        metric_dict = dataset.evaluate(results, metric='mAP')

        # Save to JSON: numeric values as floats, anything else as strings.
        safe_dict = {}
        for k, v in metric_dict.items():
            try:
                safe_dict[k] = float(v)
            except (TypeError, ValueError):
                safe_dict[k] = str(v)

        with open(json_path, 'w') as f:
            json.dump(safe_dict, f, indent=4)

    except Exception as e:
        print(f"Eval Script Error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
"""


def _num_classes_overrides(model_name, num_classes):
    """Return ``--cfg-options`` entries that set every classification head to ``num_classes``."""
    if model_name == 'redet':
        # ReDet's config has two bbox heads, addressed by list index 0 and 1.
        heads = ("model.roi_head.bbox_head.0", "model.roi_head.bbox_head.1")
    elif model_name == 's2anet':
        # S2ANet has two heads (fam_head, odm_head), each with its own num_classes.
        heads = ("model.fam_head", "model.odm_head")
    else:
        heads = ("model.roi_head.bbox_head",)
    return [f"{head}.num_classes={num_classes}" for head in heads]


def _run_inference(config_path, checkpoint_path, pkl_result_path, cfg_options, env):
    """Run MMRotate's tools/test.py to dump raw detections to ``pkl_result_path``.

    Returns True on success. Returns False (after printing the error) if the
    subprocess exits non-zero.
    """
    cmd = [
        # Same interpreter as the eval step, not whatever "python" is on PATH.
        sys.executable, os.path.join(MMROTATE_ROOT, "tools/test.py"),
        config_path,
        checkpoint_path,
        "--out", pkl_result_path,
        "--cfg-options",
        *cfg_options,
    ]
    try:
        # stdout is discarded as noisy; stderr still reaches the console.
        subprocess.run(cmd, env=env, check=True, stdout=subprocess.DEVNULL)
    except subprocess.CalledProcessError as e:
        print(f"  ❌ Inference failed: {e}")
        return False
    return True


def _compute_metrics(config_path, img_prefix, ann_file, classes,
                     pkl_result_path, json_result_path, env):
    """Compute mAP from ``pkl_result_path`` into ``json_result_path`` in a subprocess.

    Deletes the ``.pkl`` after a successful run. If the subprocess fails, it
    prints an error and leaves the ``.pkl`` in place.
    """
    cmd = [
        sys.executable, "-c", _EVAL_SCRIPT,
        config_path,
        img_prefix,
        ann_file,
        json.dumps(list(classes)),
        pkl_result_path,
        json_result_path,
    ]
    try:
        subprocess.run(cmd, env=env, check=True)
    except subprocess.CalledProcessError:
        print("  ❌ Evaluation script failed.")
        return

    print(f"  ✅ Metrics saved to {json_result_path}")
    # The raw detections are only an intermediate artifact.
    if os.path.exists(pkl_result_path):
        os.remove(pkl_result_path)


def evaluate_single_run(model_info, exp_name, device):
    """Evaluate one trained model on one experiment's test split.

    Steps:
      1. Locate the run's checkpoint.
      2. Run MMRotate's ``tools/test.py`` to dump detections to
         ``test_results.pkl``.
      3. Compute mAP in a separate interpreter and write it to
         ``test_metrics.json`` in the run directory, overwriting any previous
         file. The ``.pkl`` is removed afterwards.

    A missing directory, missing checkpoint, or failed subprocess is reported
    on stdout and the function returns early.

    Args:
        model_info: An entry of ``MODELS`` (uses its ``"name"`` and ``"config"`` keys).
        exp_name: Experiment name, mapped to a data directory by ``get_dataset_dir``.
        device: GPU index, exported as ``CUDA_VISIBLE_DEVICES`` to both subprocesses.
    """
    dataset_dir = get_dataset_dir(exp_name)
    run_work_dir = os.path.join(dataset_dir, "runs", model_info['name'])

    if not os.path.exists(run_work_dir):
        print(f"❌ Run directory not found: {run_work_dir}")
        return

    # 1. Find Checkpoint
    checkpoint_path = find_best_checkpoint(run_work_dir)
    if not checkpoint_path:
        print(f"❌ No checkpoint found in {run_work_dir}")
        return

    print(f"  -> Found checkpoint: {os.path.basename(checkpoint_path)}")

    # 2. Prepare Config Overrides
    abs_data = os.path.abspath(dataset_dir)
    img_prefix = os.path.join(abs_data, 'test/images')
    ann_file = os.path.join(abs_data, 'test/labelTxt')

    pkl_result_path = os.path.join(run_work_dir, "test_results.pkl")
    json_result_path = os.path.join(run_work_dir, "test_metrics.json")

    # NOTE(review): mmcv's DictAction parses these values and appears to strip
    # spaces and split on commas, so the paths must not contain either.
    # Verify against the installed mmcv.
    cfg_options = [
        f"data.test.img_prefix={img_prefix}",
        f"data.test.ann_file={ann_file}",
        f"data.test.classes={CLASSES}",
        # Presumably avoids fetching pretrained backbone weights at build
        # time; the evaluated weights come from checkpoint_path.
        "model.backbone.pretrained=None",
        # Head sizes follow CLASSES instead of a hard-coded count.
        *_num_classes_overrides(model_info['name'], len(CLASSES)),
    ]

    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(device)

    # 3. Run Inference (Generate .pkl)
    print("  -> Running inference...")
    if not _run_inference(model_info['config'], checkpoint_path,
                          pkl_result_path, cfg_options, env):
        return

    # 4. Calculate Metrics
    print("  -> Calculating metrics...")
    _compute_metrics(model_info['config'], img_prefix, ann_file, CLASSES,
                     pkl_result_path, json_result_path, env)

def main():
    """Parse CLI arguments and evaluate the selected models on the selected experiments.

    ``--model_idx`` restricts the run to one entry of ``MODELS``, and ``--exp``
    restricts it to a subset of ``EXPERIMENTS``. Omitting either evaluates all
    of them. ``--device`` is the GPU id passed to ``evaluate_single_run``.
    """
    parser = argparse.ArgumentParser(description="Evaluate MMrotate Models")
    # Valid indices follow MODELS, so adding a model needs no CLI change.
    parser.add_argument('--model_idx', type=int, choices=range(len(MODELS)),
                        help='Specific model index to run (optional)')
    parser.add_argument('--exp', nargs='+', choices=EXPERIMENTS, metavar='EXP',
                        help='Experiment name(s) to evaluate (optional; default: all)')
    parser.add_argument('--device', type=int, default=0, help='GPU Device ID')
    args = parser.parse_args()

    models_to_run = MODELS if args.model_idx is None else [MODELS[args.model_idx]]
    experiments_to_run = EXPERIMENTS if args.exp is None else args.exp

    for model_info in models_to_run:
        print(f"\n{'='*60}")
        print(f"📊 Evaluating Model: {model_info['name']}")
        print(f"{'='*60}")

        for exp in experiments_to_run:
            print(f"\n>>> Experiment: {exp}")
            evaluate_single_run(model_info, exp, args.device)


if __name__ == "__main__":
    main()