#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Minimal BLT Grid Search

Usage: python blt_sweep.py -m graph -bs 64
"""

import argparse
import concurrent.futures
import functools
import multiprocessing
import os
import subprocess
import sys

import wandb

def parse_args():
    """Parse the launcher's command-line options.

    Returns:
        argparse.Namespace with ``model`` (one of the BLT variants, or
        ``'all'`` for every variant; required) and ``batch_size`` (one of
        64, 128, 256; defaults to 64).
    """
    model_choices = ['all', 'main', 'graph', 'chemprop', 'feature', 'smiles']
    batch_size_choices = [64, 128, 256]

    cli = argparse.ArgumentParser(description="Run grid search for BLT models")
    cli.add_argument(
        "-m", "--model",
        choices=model_choices,
        required=True,
    )
    cli.add_argument(
        "-bs", "--batch_size",
        type=int,
        default=64,
        choices=batch_size_choices,
    )
    parsed = cli.parse_args()
    return parsed

def run_sweep_agent(sweep_id, model, prop_type, seed):
    """Execute one sweep trial: name the W&B run, then launch the training script.

    Intended as the ``function`` passed to ``wandb.agent``. It reads the
    hyperparameters the sweep assigned to this run from ``run.config``,
    encodes them in the run name, and runs the model's training script in a
    subprocess with the current interpreter.

    Args:
        sweep_id: ID of the sweep this trial belongs to (not used in the body;
            kept for interface compatibility).
        model: Model variant. ``'main'`` selects ``blt_main.py`` and reads
            ``model.mul_approx_train_deltas``; any other value selects
            ``blt_<model>_main.py`` and reads ``model.num_candidates``.
        prop_type: Property/dataset name passed as ``--prop_type``. Names
            ending in ``_x`` additionally get ``--dataset_split_type=scaffold``.
        seed: Fallback seed, used only when the run's sweep config has no
            ``seed`` entry.

    Raises:
        subprocess.CalledProcessError: if the training script exits with a
            non-zero status, so the trial is reported as failed rather than
            silently counted as a success.
    """
    run = wandb.init(project="transductive_learning")
    # Use this run's own config object instead of the module-global
    # wandb.config, so the values are tied to the run we just started.
    config = run.config

    lr = config.get("model.lr")
    epochs = config.get("model.n_epochs")
    batch_size = config.get("model.batch_size")
    # The sweep grid includes 'seed'. Honour the value assigned to this run so
    # the seed used for training matches the seed recorded in W&B.
    seed = config.get("seed", seed)

    # The candidate-count hyperparameter has a different name per model type.
    if model == 'main':
        candidates = config.get("model.mul_approx_train_deltas")
        candidate_param = f"approx{candidates}"
    else:
        candidates = config.get("model.num_candidates")
        candidate_param = f"cand{candidates}"

    # Assigning run.name is enough for the wandb SDK to persist it; the old
    # no-argument run.save() call is deprecated.
    run.name = f"{model}_{prop_type}_bs{batch_size}_e{epochs}_{candidate_param}_lr{lr}_s{seed}"

    script = "blt_main.py" if model == 'main' else f"blt_{model}_main.py"

    cmd = [
        sys.executable,
        script,
        '--wandb_log',
        f'--seed={seed}',
        f'--prop_type={prop_type}',
    ]
    # Property names ending in "_x" use a scaffold split.
    if prop_type.endswith("_x"):
        cmd.append('--dataset_split_type=scaffold')

    # NOTE(review): lr / epochs / batch_size / candidates are not passed on the
    # command line; presumably the training script obtains them through its
    # own wandb config (inherited sweep environment) -- confirm.
    print(f"Running: {' '.join(cmd)}")
    subprocess.run(cmd, check=True)

def _agent_worker(sweep_id, model, prop_type, seed):
    """Run one W&B sweep agent in the current worker process.

    Defined at module level so a process pool can pickle it by reference.
    The agent is started without ``count``, so it keeps pulling grid points
    from the sweep until the sweep controller has no points left.

    Args:
        sweep_id: ID returned by ``wandb.sweep``.
        model: Model variant forwarded to ``run_sweep_agent``.
        prop_type: Property type forwarded to ``run_sweep_agent``.
        seed: Value forwarded to ``run_sweep_agent`` as its ``seed`` argument.
    """
    # Spawned workers start with fresh wandb state. Log in again so the agent
    # can reach the API; this should not prompt when credentials are already
    # stored by the parent's login.
    wandb.login(anonymous="allow")
    wandb.agent(
        sweep_id,
        function=functools.partial(run_sweep_agent, sweep_id, model, prop_type, seed),
    )


def main():
    """Create one W&B grid sweep per (model, property type) and run it to completion.

    For every selected model and property type, this builds a grid over
    learning rate, epoch count, seed and candidate count. The batch size is
    fixed to the ``-bs`` argument. The grid is registered with
    ``wandb.sweep``, and ``len(seeds)`` agents run in parallel worker
    processes until every grid point has been executed.
    """
    args = parse_args()
    model = args.model
    batch_size = args.batch_size

    wandb.login(anonymous="allow")

    prop_types = ['bace', 'bace_x', 'freesolv', 'freesolv_x', 'esol', 'esol_x', 'lipo', 'lipo_x']
    models = ['main', 'graph', 'chemprop', 'feature', 'smiles'] if model == 'all' else [model]
    seeds = [29, 42, 48, 72, 95]
    epochs = [500, 1000, 1500]
    candidates = [3, 5, 10]

    # wandb keeps process-global state (wandb.run / wandb.config), so
    # concurrent agents must live in separate processes, not threads.
    # "spawn" gives each worker a clean interpreter instead of a fork of
    # this process's wandb state.
    mp_context = multiprocessing.get_context("spawn")

    for m in models:
        for prop_type in prop_types:
            print(f"\n{'='*50}")
            print(f"Creating sweep for {m} model with {prop_type} property")
            print(f"{'='*50}\n")

            # Set learning rates based on model (don't modify this part)
            learning_rates = [1e-5, 1e-6] if m == 'smiles' else [1e-4, 1e-5]

            # Correctly set parameter names based on model type
            candidate_param_name = "model.mul_approx_train_deltas" if m == 'main' else "model.num_candidates"

            # Sweep configuration
            sweep_config = {
                'method': 'grid',
                'name': f"{m}_{prop_type}_{batch_size}",
                'metric': {'name': 'eval', 'goal': 'minimize'},
                'parameters': {
                    'model.lr': {'values': learning_rates},
                    'model.batch_size': {'values': [batch_size]},
                    'model.n_epochs': {'values': epochs},
                    'seed': {'values': seeds},
                    candidate_param_name: {'values': candidates}
                }
            }

            # Total number of grid points (product of all value-list lengths).
            n_runs = 1
            for spec in sweep_config['parameters'].values():
                n_runs *= len(spec['values'])

            sweep_id = wandb.sweep(sweep_config, project="transductive_learning")
            print(f"Sweep created: {sweep_id} ({n_runs} grid points)")

            # Each agent runs without a count, so together the agents execute
            # every grid point instead of stopping after one trial each.
            with concurrent.futures.ProcessPoolExecutor(
                max_workers=len(seeds), mp_context=mp_context
            ) as executor:
                futures = [
                    executor.submit(_agent_worker, sweep_id, m, prop_type, seed)
                    for seed in seeds
                ]
                # Report agent failures instead of silently discarding them,
                # but keep going with the remaining sweeps.
                for future in concurrent.futures.as_completed(futures):
                    exc = future.exception()
                    if exc is not None:
                        print(f"Agent for {m}-{prop_type} failed: {exc!r}", file=sys.stderr)

            print(f"Sweep completed for {m}-{prop_type}!")

if __name__ == "__main__":
    main()
