#!/usr/bin/env python3
"""
WandB Sweep Wrapper for Football Multi-Agent RL Training
"""

import sys
import os
from pathlib import Path

# Make the training scripts importable: the "train" directory sits two levels
# above this file's directory and is placed first on the module search path.
train_path = Path(__file__).parent.joinpath("..", "..", "train")
sys.path.insert(0, os.fspath(train_path.resolve()))

import wandb

def validate_architecture(n_embd, n_head):
    """Check that the embedding width can be split evenly across heads.

    Args:
        n_embd: Embedding dimension.
        n_head: Number of attention heads.

    Returns:
        True when both values are positive and ``n_embd`` is divisible by
        ``n_head``; False otherwise.
    """
    # Non-positive sizes can never describe a usable architecture, and a
    # zero head count would otherwise raise ZeroDivisionError in the modulo
    # below instead of being reported as invalid.
    if n_embd <= 0 or n_head <= 0:
        return False
    return n_embd % n_head == 0

def run_training():
    """Run one training job configured from the active W&B run.

    Initializes a W&B run when none is active, reads the hyperparameters
    from ``wandb.config``, converts them into a command-line style argument
    list and passes that list to ``train_football.main``.

    If ``n_embd`` is not divisible by ``n_head``, the run summary is flagged
    with ``invalid_architecture = True`` and the function returns without
    starting training.

    Returns:
        None.

    Raises:
        Exception: Anything raised by ``train_football.main`` is re-raised
            after the error message and traceback have been printed.
    """
    # Initialize wandb if not already done (when called by a sweep agent,
    # this is expected to attach to the run the agent created).
    if wandb.run is None:
        wandb.init()

    # The config is only available once a run exists.
    config = wandb.config

    # Reject head/embedding combinations that cannot be split evenly and
    # record the reason on the run so it is visible in the sweep results.
    if not validate_architecture(config.n_embd, config.n_head):
        print(f"Invalid architecture: n_embd={config.n_embd} not divisible by n_head={config.n_head}")
        wandb.run.summary["invalid_architecture"] = True
        return

    # Default to GPU 0 only when no device selection is already present in
    # the environment. Overwriting an existing CUDA_VISIBLE_DEVICES would
    # discard the launcher's choice (e.g. one agent process per GPU).
    # NOTE: this is done before the training module is imported so the value
    # is in place before any of that module's code runs.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

    # Imported lazily: the module is only reachable through the sys.path
    # entry added at the top of this file, and is not needed for invalid
    # configurations.
    from train_football import main

    # Build arguments as a list (like command line args). Every value is
    # stringified because the list stands in for argv.
    args_list = [
        '--env_name', str(config.env_name),
        '--scenario_name', str(config.scenario_name),
        '--algorithm_name', str(config.algorithm_name),
        '--experiment_name', f"sweep_{wandb.run.id}",
        '--seed', str(config.seed),
        '--num_agents', str(config.num_agents),
        '--num_env_steps', str(config.num_env_steps),
        '--episode_length', str(config.episode_length),
        '--representation', str(config.representation),
        '--rewards', str(config.rewards),
        '--n_rollout_threads', str(config.n_rollout_threads),
        '--save_interval', str(config.save_interval),
        '--log_interval', str(config.log_interval),
        # Value-less switch.
        '--use_transformer_base_actor',
        # hidden_size is deliberately fed from n_embd, same as --n_embd below.
        '--hidden_size', str(config.n_embd),
        '--lr', str(config.lr),
        '--critic_lr', str(config.critic_lr),
        '--ppo_epoch', str(config.ppo_epoch),
        '--clip_param', str(config.clip_param),
        '--num_mini_batch', str(config.num_mini_batch),
        '--entropy_coef', str(config.entropy_coef),
        '--max_grad_norm', str(config.max_grad_norm),
        '--n_block', str(config.n_block),
        '--n_embd', str(config.n_embd),
        '--n_head', str(config.n_head),
        '--user_name', str(config.user_name),
        '--wandb_name', str(config.wandb_name),
    ]

    print("Starting training with sweep config:")
    print(f"  lr={config.lr}, critic_lr={config.critic_lr}")
    print(f"  entropy_coef={config.entropy_coef}, clip_param={config.clip_param}")
    print(f"  n_block={config.n_block}, n_embd={config.n_embd}, n_head={config.n_head}")

    # Call the main function directly, in-process.
    try:
        main(args_list)
    except Exception as e:
        # Print the failure for the job log, then re-raise so the caller
        # (and the process exit status) still sees the error.
        print(f"Training failed with error: {e}")
        import traceback
        traceback.print_exc()
        raise
    else:
        print(f"Training completed successfully for run {wandb.run.id}")

# Script entry point: perform a single training run when this file is
# executed directly; nothing runs on import.
if __name__ == "__main__":
    run_training()