import argparse
import os
import subprocess
import sys
import time

# --- Configuration ---
# Must match generate_datasets.py
WORK_DIR_BASE = "/data/aaa/OminiControl/mmrotate_mar20_workdir"
MMROTATE_ROOT = "/data/aaa/OminiControl/mmrotate"

# Target classes (filtered subset). Passed as the `classes` setting of the
# train/val/test datasets when training is launched.
CLASSES = (
    "A2", "A3", "A4", "A6", "A7", "A8",
    "A9", "A10", "A11", "A14", "A18",
)

# Detectors to fine-tune, built from (short name, config path, checkpoint
# path) triples whose paths are relative to MMROTATE_ROOT. Each resulting
# dict exposes "name", "config" and "weight"; the list position is the value
# selected by --model_idx.
MODELS = [
    {
        "name": model_name,
        "config": os.path.join(MMROTATE_ROOT, config_rel),
        "weight": os.path.join(MMROTATE_ROOT, weight_rel),
    }
    for model_name, config_rel, weight_rel in (
        (
            "oriented_rcnn",
            "configs/oriented_rcnn/oriented_rcnn_r50_fpn_fp16_1x_dota_le90.py",
            "weights/oriented_rcnn_r50_fpn_fp16_1x_dota_le90-57c88621.pth",
        ),
        (
            "redet",
            "configs/redet/redet_re50_refpn_1x_dota_ms_rr_le90.py",
            "weights/redet_re50_fpn_1x_dota_ms_rr_le90-fc9217b5.pth",
        ),
        (
            "rotated_faster_rcnn",
            "configs/rotated_faster_rcnn/rotated_faster_rcnn_r50_fpn_1x_dota_le90.py",
            "weights/rotated_faster_rcnn_r50_fpn_1x_dota_le90-0393aa5c.pth",
        ),
        (
            "s2anet",
            "configs/s2anet/s2anet_r50_fpn_fp16_1x_dota_le135.py",
            "weights/s2anet_r50_fpn_fp16_1x_dota_le135-5cac515c.pth",
        ),
    )
]

# Experiment List
# Full sweep. Each entry is a dataset suffix resolved by get_dataset_dir();
# the datasets themselves are produced by generate_datasets.py.
ALL_EXPERIMENTS = [
    "baseline",
    "batch_r1.0", "batch_r2.0", "batch_r3.0", "batch_r4.0", "batch_r5.0",
    "copypaste_r1.0", "copypaste_r2.0", "copypaste_r3.0", "copypaste_r4.0", "copypaste_r5.0",
    "solar_r1.0", "solar_r2.0", "solar_r3.0", "solar_r4.0", "solar_r5.0",
]

# Experiments actually run by this script. Currently narrowed to a single
# entry; use `EXPERIMENTS = list(ALL_EXPERIMENTS)` to run the whole sweep.
EXPERIMENTS = [
    "solar_r1.0",
]

def get_dataset_dir(exp_name, base=None):
    """Return the dataset directory for experiment *exp_name*.

    The path is ``f"{base}_{exp_name}"``. Every experiment, including
    ``"baseline"``, uses its own name as the suffix, so no special-casing
    is needed.

    Args:
        exp_name: Experiment identifier, e.g. ``"baseline"`` or ``"solar_r1.0"``.
        base: Path prefix to use; defaults to ``WORK_DIR_BASE``. Resolved at
            call time so the module constant can still be changed.

    Returns:
        The dataset directory path as a string.
    """
    if base is None:
        base = WORK_DIR_BASE
    return f"{base}_{exp_name}"

def _num_class_overrides(model_name, num_classes):
    """Return cfg-options that set the class count of *model_name*'s head(s).

    Head layouts differ per architecture, so the override keys differ too.
    """
    if model_name == 'redet':
        # ReDet's roi_head.bbox_head has two indexed entries (0 and 1);
        # both must be resized.
        return [
            f"model.roi_head.bbox_head.0.num_classes={num_classes}",
            f"model.roi_head.bbox_head.1.num_classes={num_classes}",
        ]
    if model_name == 's2anet':
        # S2ANet has fam_head and odm_head, NO bbox_head.
        return [
            f"model.fam_head.num_classes={num_classes}",
            f"model.odm_head.num_classes={num_classes}",
        ]
    # Standard ROI-based models (Oriented RCNN, Rotated Faster RCNN).
    return [f"model.roi_head.bbox_head.num_classes={num_classes}"]


def _build_cfg_options(model_info, dataset_dir, run_work_dir):
    """Return the ``--cfg-options`` key=value strings for one training run.

    Args:
        model_info: One entry of MODELS (reads "name" and "weight").
        dataset_dir: Dataset root with train/val/test splits, each holding
            ``images`` and ``labelTxt`` sub-directories.
        run_work_dir: Directory the training run writes its outputs to.

    Returns:
        A list of ``key=value`` strings.
    """
    abs_data = os.path.abspath(dataset_dir)
    classes_str = str(CLASSES)

    data_opts = []
    for split in ("train", "val", "test"):
        img_prefix = os.path.join(abs_data, f"{split}/images")
        ann_file = os.path.join(abs_data, f"{split}/labelTxt")
        data_opts += [
            f"data.{split}.img_prefix={img_prefix}",
            f"data.{split}.ann_file={ann_file}",
            f"data.{split}.classes={classes_str}",
        ]

    return [
        f"work_dir={run_work_dir}",
        f"load_from={model_info['weight']}",
        *data_opts,

        # Training Schedule
        "runner.max_epochs=100",
        "lr_config.step=[65, 90]",

        # Batch Size
        "data.samples_per_gpu=16",
        "data.workers_per_gpu=4",

        # Evaluation & Checkpoint. The checkpoint interval (200) exceeds
        # max_epochs (100), so no periodic checkpoints are reached.
        "evaluation.interval=1",
        "evaluation.metric=mAP",
        "evaluation.save_best=mAP",
        "checkpoint_config.interval=200",

        # The child only sees the GPU chosen via CUDA_VISIBLE_DEVICES, which
        # it enumerates as index 0 regardless of the physical device id.
        "gpu_ids=[0]",

        # Disable backbone pretrained check for ALL models
        "model.backbone.pretrained=None",
        # Resize the head(s) to match CLASSES instead of a hard-coded count.
        *_num_class_overrides(model_info['name'], len(CLASSES)),
    ]


def train_single_model(model_idx, device):
    """
    Train a SINGLE model across ALL datasets sequentially.

    For each name in EXPERIMENTS, run MMRotate's tools/train.py as a child
    process for MODELS[model_idx] on that experiment's dataset. Missing
    datasets and runs that already have a latest.pth are skipped. A failed
    run is reported and the loop continues with the next dataset.

    Args:
        model_idx: Index into MODELS.
        device: Physical GPU id, exposed to the child via CUDA_VISIBLE_DEVICES.
    """
    model_info = MODELS[model_idx]
    print(f"\n{'='*80}")
    print(f"🚀 Starting Sequence for Model: {model_info['name']}")
    print(f"{'='*80}\n")

    for exp_name in EXPERIMENTS:
        dataset_dir = get_dataset_dir(exp_name)

        if not os.path.exists(dataset_dir):
            print(f"❌ Dataset not found: {dataset_dir}. Run generate_datasets.py first.")
            continue

        print(f"\n>>> Processing Dataset: {exp_name}")

        run_work_dir = os.path.join(dataset_dir, "runs", model_info['name'])

        # An existing latest.pth is treated as a finished run.
        # NOTE(review): with checkpoint_config.interval > max_epochs this link
        # presumably only appears once the final checkpoint is saved; confirm
        # against the MMCV version in use.
        latest_pth = os.path.join(run_work_dir, "latest.pth")
        if os.path.exists(latest_pth):
            print(f"  -> Run {run_work_dir} exists. Skipping Training...")
            continue

        cmd = [
            # Use the running interpreter so the child gets the same
            # environment (venv/conda) instead of whatever "python" is on PATH.
            sys.executable,
            os.path.join(MMROTATE_ROOT, "tools/train.py"),
            model_info['config'],
            "--cfg-options",
            *_build_cfg_options(model_info, dataset_dir, run_work_dir),
        ]

        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(device)

        print(f"  -> Training {model_info['name']} on {exp_name}...")
        try:
            subprocess.run(cmd, env=env, check=True)
        except subprocess.CalledProcessError as e:
            # Best-effort sweep: report the failure and move on.
            print(f"  ❌ Error training {model_info['name']} on {exp_name}: {e}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Sequential MMRotate Training")
    # Derive valid indices and help text from MODELS so the CLI cannot drift
    # out of sync when a model is added or removed.
    model_choices = list(range(len(MODELS)))
    model_help = ", ".join(f"{i}={m['name']}" for i, m in enumerate(MODELS))
    parser.add_argument('--model_idx', type=int, default=0, choices=model_choices,
                        help=f'Index of model to train: {model_help}')
    parser.add_argument('--device', type=int, default=0, help='GPU Device ID')

    args = parser.parse_args()

    train_single_model(args.model_idx, args.device)