"""
Worker process for distributed transfer experiment execution.

Each worker:
1. Claims a pending task from the transfer queue
2. Runs the transfer experiment (either reward model or policy evaluation)
3. Logs evaluation metrics (regret, mean_reward) to the queue
4. Repeats until no pending tasks remain

Usage:
    python -m umfavi.experiments.transfer_worker --queue-dir tasks_transfer
"""

import json
import os
import socket
import sys
import traceback
import argparse
from typing import Optional

import wandb

from umfavi.experiments.file_queue import FileTaskQueue


def get_worker_id() -> str:
    """Return this worker's identifier, formatted as ``<hostname>-<pid>``."""
    # Hostname first, then the PID of the current process, joined by a dash.
    return "-".join((socket.gethostname(), str(os.getpid())))


class TransferQueueCallback:
    """
    Bridge that records a transfer experiment's evaluation metrics in the queue.

    Transfer experiments are evaluated a single time (once the run has
    finished), so only one set of final metrics is reported per experiment.
    """

    def __init__(self, queue: FileTaskQueue, experiment_id: int):
        # Remember which experiment the metrics belong to and where to send them.
        self.experiment_id = experiment_id
        self.queue = queue

    def log_results(self, metrics: dict) -> None:
        """Write the final evaluation metrics for this experiment to the queue."""
        # Transfer experiments have no epochs, so everything is filed under epoch 0.
        target_queue = self.queue
        target_queue.log_evaluation(self.experiment_id, epoch=0, metrics=metrics)


def _extract_env_params(config: dict) -> dict:
    """Return the ``env_params.*`` entries of ``config`` with the prefix stripped."""
    prefix = "env_params."
    return {
        key[len(prefix):]: value
        for key, value in config.items()
        if key.startswith(prefix)
    }


def _build_transfer_args(config: dict, seed) -> argparse.Namespace:
    """Translate a queued experiment config into the namespace given to ``run_transfer``.

    Keys missing from ``config`` fall back to the defaults spelled out below;
    ``seed`` is stored unchanged as ``args.seed``.
    """
    return argparse.Namespace(
        env_id=config.get("env_id"),
        # Passed as a JSON string; an empty dict serialises to "{}".
        env_params=json.dumps(_extract_env_params(config)),
        fb_model_path=config.get("fb_model_path"),
        optimal_policy_path=config.get("optimal_policy_path"),
        mode=config.get("mode", "reward_model"),
        num_samples=config.get("num_samples", 100),
        max_num_steps=config.get("max_num_steps", 1000),
        gamma=config.get("gamma", 0.99),
        seed=seed,
        retrain_verbose=config.get("retrain_verbose", 0),
        no_progress_bar=config.get("no_progress_bar", True),
        retrain_reward_thresh=config.get("retrain_reward_thresh", None),
        # Model architecture settings
        hidden_sizes=config.get("hidden_sizes", [256, 256]),
        reward_domain=config.get("reward_domain", "sa"),
        # Feature transforms
        act_transform=config.get("act_transform", None),
        obs_transform=config.get("obs_transform", None),
        # Regret calculation settings
        use_avg_state_regret=config.get("use_avg_state_regret", False),
        # W&B settings (disabled by default for queue-based runs)
        log_wandb=config.get("log_wandb", False),
        wandb_project=config.get("wandb_project", "umfavi-transfer"),
        wandb_entity=config.get("wandb_entity", None),
        wandb_name=config.get("wandb_name", None),
    )


def _format_metric(results, key: str) -> str:
    """Render ``results[key]`` with 4 decimals; never raises (falls back to "N/A"/str)."""
    try:
        value = results[key]
    except (KeyError, IndexError, TypeError):
        return "N/A"
    if value is None:
        return "N/A"
    try:
        return f"{value:.4f}"
    except (TypeError, ValueError):
        return str(value)


def run_transfer_worker(
    queue_dir: str,
    max_experiments: Optional[int] = None,
    dry_run: bool = False,
    verbose: bool = True,
) -> dict:
    """
    Run the worker loop to process transfer experiments.

    The worker repeatedly claims a pending task from the queue, runs it via
    ``run_transfer``, logs the returned metrics to the queue and marks the
    task completed. Any exception raised while running or recording a task
    marks that task failed (with the traceback as the error message) and the
    loop moves on to the next task.

    Args:
        queue_dir: Path to the transfer task queue directory.
        max_experiments: Maximum number of experiments to run (None for
            unlimited). Completed and failed experiments both count.
        dry_run: If True, claim tasks but don't actually run them; each
            claimed task is marked completed immediately.
        verbose: Print progress information.

    Returns:
        Dict with counts: {"completed": N, "failed": M}
    """
    # Import here to avoid circular imports
    # NOTE(review): this is a top-level absolute import, while the module is
    # documented as being run as ``umfavi.experiments.transfer_worker`` --
    # confirm that ``transfer`` is importable under that name on sys.path.
    from transfer import run_transfer

    queue = FileTaskQueue(queue_dir)
    worker_id = get_worker_id()

    completed = 0
    failed = 0

    if verbose:
        print(f"Transfer Worker {worker_id} starting...")
        status = queue.get_status_summary()
        print(f"  Queue status: {status}")

    while True:
        # Check if we've hit the limit
        if max_experiments is not None and (completed + failed) >= max_experiments:
            if verbose:
                print(f"Reached max experiments limit ({max_experiments})")
            break

        # Try to claim a task
        experiment = queue.claim_pending_task(worker_id)

        if experiment is None:
            if verbose:
                print("No pending tasks remaining. Worker shutting down.")
            break

        if verbose:
            print(f"\n{'='*60}")
            print(f"Claimed transfer experiment {experiment.id} (seed={experiment.seed})")
            print(f"Config hash: {experiment.config_hash}")
            print(f"{'='*60}")

        if dry_run:
            if verbose:
                print(f"  [DRY RUN] Would run transfer: {experiment.config}")
            queue.mark_completed(experiment.id)
            completed += 1
            continue

        try:
            args = _build_transfer_args(experiment.config, experiment.seed)

            if verbose:
                print(f"  env_id: {args.env_id}")
                print(f"  env_params: {args.env_params}")
                print(f"  fb_model_path: {args.fb_model_path}")
                print(f"  optimal_policy_path: {args.optimal_policy_path}")
                print(f"  mode: {args.mode}")

            # Run the transfer experiment
            results = run_transfer(args)

            # Create callback and log results
            callback = TransferQueueCallback(queue, experiment.id)
            callback.log_results(results)

            # Mark as completed
            queue.mark_completed(experiment.id)

        except Exception as e:
            error_msg = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
            queue.mark_failed(experiment.id, error_msg)
            failed += 1

            # Clean up any active wandb run to avoid interference with next experiment
            if wandb.run is not None:
                wandb.finish(exit_code=1)

            if verbose:
                print(f"Transfer experiment {experiment.id} failed: {e}")
                traceback.print_exc()

        else:
            # Success bookkeeping lives outside the ``try`` so that a problem
            # while printing the summary can never flip an already-completed
            # task to "failed" or count it twice.
            completed += 1
            if verbose:
                print(f"Transfer experiment {experiment.id} completed:")
                print(f"  Regret: {_format_metric(results, 'regret')}")
                print(f"  Mean reward: {_format_metric(results, 'mean_reward')}")

    if verbose:
        print(f"\nTransfer Worker {worker_id} finished.")
        print(f"  Completed: {completed}")
        print(f"  Failed: {failed}")

    return {"completed": completed, "failed": failed}


def main():
    """Command-line entry point: parse options, run the worker, set the exit status."""
    cli = argparse.ArgumentParser(
        description="Run a worker to process transfer experiments from the queue"
    )

    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        (
            "--queue-dir",
            {
                "type": str,
                "default": "tasks_transfer",
                "help": "Path to the transfer task queue directory (default: tasks_transfer)",
            },
        ),
        (
            "--max-experiments",
            {
                "type": int,
                "default": None,
                "help": "Maximum number of experiments to run (default: unlimited)",
            },
        ),
        (
            "--dry-run",
            {
                "action": "store_true",
                "help": "Claim tasks but don't actually run them",
            },
        ),
        (
            "--quiet",
            {
                "action": "store_true",
                "help": "Suppress progress output",
            },
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    parsed = cli.parse_args()

    outcome = run_transfer_worker(
        queue_dir=parsed.queue_dir,
        max_experiments=parsed.max_experiments,
        dry_run=parsed.dry_run,
        verbose=not parsed.quiet,
    )

    # Nothing failed: return normally so the process exits with status 0.
    if outcome["failed"] <= 0:
        return

    # At least one experiment failed: signal it through the exit code.
    sys.exit(1)


# Run the CLI only when this file is executed as a script (or via
# ``python -m``), not when it is imported as a module.
if __name__ == "__main__":
    main()

