import argparse
import wandb

from typing import Optional, Callable, Any

from umfavi.utils.reproducibility import seed_everything
from umfavi.utils.logging import console_log_eval_metrics
from umfavi.types import FeedbackType
from umfavi.utils.argparser_utils import create_parser
from umfavi.utils.train_utils import (
    get_git_commit_hash,
    validate_args,
    setup_experiment,
    train_epoch,
    run_validation,
    run_final_evaluation,
    compute_total_val_loss,
    visualize_epoch,
    save_model_checkpoint,
    load_model_checkpoint,
    wandb_run,
)


def get_default_args() -> argparse.Namespace:
    """
    Build a namespace populated purely with argument defaults.

    Intended for programmatic callers that want the standard configuration
    as a starting point and then overwrite individual attributes.

    Returns:
        argparse.Namespace obtained by parsing an empty argument list with
        the parser from create_parser() (presumably an
        argparse.ArgumentParser), so the process command line is not read.
    """
    # An explicit empty list stops argparse from falling back to sys.argv.
    no_cli_tokens = []
    default_parser = create_parser()
    return default_parser.parse_args(no_cli_tokens)


def run_experiment(
    args: argparse.Namespace,
    db_callback: Optional[Callable[[int, dict], None]] = None,
) -> dict[str, Any]:
    """
    Run a single experiment with the given configuration.

    This is the main entry point for programmatic experiment execution.
    It supports optional database callbacks for logging evaluations.

    High-level flow, in order:
        1. Seed everything from args.seed and look up the git commit hash.
        2. Build the per-feedback-type sample-count config and validate args.
        3. Enter the wandb_run context (the yielded run may be None).
        4. Set up the experiment components and train for args.num_epochs,
           periodically validating, checkpointing and visualizing.
        5. Reload the best checkpoint (if one was recorded) and run the
           final evaluation.

    Model saving is controlled via args:
        - args.model_save_dir: Directory to save models and policies (None to disable)
        - args.save_behavior: "best" (save only best model) or "all" (save at every eval)

    Args:
        args: Namespace containing all experiment configuration.
        db_callback: Optional callback function called at each evaluation epoch.
                    Signature: db_callback(epoch: int, metrics: dict) -> None
                    `epoch` is the zero-based loop index. `metrics` is the
                    validation metrics dict including "eval/total_val_loss";
                    when args.log_wandb is truthy it also contains "epoch"
                    and "relative_step" (see the NOTE(review) in the body).

    Returns:
        Dict containing:
            - "best_model_path": Path to saved best reward model (or None)
            - "best_policy_path": Path to saved best estimated policy trained on learned reward (or None)
            - "best_epoch": Initialised to None and never assigned in this
              function; presumably filled in by save_model_checkpoint,
              which receives this dict - TODO confirm
            - "wandb_run_id": The wandb run ID (or None if not logging)
            - "final_metrics": Dict of final evaluation metrics
            - "git_commit_hash": Value returned by get_git_commit_hash()
              (falsy when it could not be determined)
    """
    # Reproducibility
    seed_everything(args.seed)

    # Get git commit hash for reproducibility
    git_commit_hash = get_git_commit_hash()
    if git_commit_hash:
        print(f"Git commit hash: {git_commit_hash}")
    else:
        print("Warning: Could not determine git commit hash (not in a git repo or git not available)")

    # Initialize result tracking. The path/epoch entries are not assigned
    # anywhere in this function; the dict is handed to save_model_checkpoint
    # below, which presumably updates them - TODO confirm.
    result = {
        "best_model_path": None,
        "best_policy_path": None,
        "best_epoch": None,
        "wandb_run_id": None,
        "final_metrics": {},
        "git_commit_hash": git_commit_hash
    }

    # Prepare feedback configuration: requested sample count per feedback type
    feedback_config = {
        FeedbackType.PREFERENCE: args.n_pref_samples,
        FeedbackType.DEMONSTRATION: args.n_demo_samples,
        FeedbackType.RATING: args.n_rating_samples,
        FeedbackType.RANKING: args.n_ranking_samples,
        FeedbackType.STOP: args.n_stop_samples,
    }

    # Validate arguments; the return value is used below as the collection
    # of feedback types to build dataloader iterators for.
    active_feedback_types = validate_args(args, feedback_config)

    # Prepare wandb config from a shallow copy of the args, so adding the
    # commit hash does not modify `args` itself.
    wandb_config = vars(args).copy()
    if git_commit_hash:
        wandb_config["git_commit_hash"] = git_commit_hash

    with wandb_run(args, wandb_config) as run:
        # `run` may be None (e.g. presumably when wandb logging is disabled).
        if run is not None:
            result["wandb_run_id"] = run.id

        # Setup experiment
        components = setup_experiment(args, active_feedback_types)

        # Initialize training state: one iterator per active feedback type,
        # created once and shared across all epochs via train_epoch.
        dloader_iters = {k: iter(components.train_dataloaders[k]) for k in active_feedback_types}
        best_val_loss = float("inf")

        print(f"Starting training for {args.num_epochs} epochs")

        # Training loop
        for epoch in range(args.num_epochs):
            global_step = train_epoch(components, args, epoch, dloader_iters)
            # Progress expressed in units of epochs; only used for wandb logging.
            relative_step = (global_step + 1) / components.steps_per_epoch

            # Validation: runs on every val_every_n_epochs-th epoch (1-based
            # count); a falsy val_every_n_epochs disables it. With
            # skip_first_val_epoch set, epoch 0 is skipped, which only
            # matters when val_every_n_epochs == 1.
            should_eval = (
                args.val_every_n_epochs
                and (epoch + 1) % args.val_every_n_epochs == 0
                and (not args.skip_first_val_epoch or epoch > 0)
            )

            if should_eval:
                eval_metrics = run_validation(components, args)

                # Compute total validation loss for model selection
                val_loss = compute_total_val_loss(eval_metrics, args.kl_weight, args.td_error_weight)
                eval_metrics["eval/total_val_loss"] = val_loss

                # Logging
                # NOTE(review): `|=` mutates eval_metrics in place, so the
                # "epoch" and "relative_step" keys are also seen by the
                # console logger, db_callback and save_model_checkpoint
                # below - but only when args.log_wandb is truthy.
                if args.log_wandb:
                    eval_metrics |= {"epoch": epoch, "relative_step": relative_step}
                    wandb.log(eval_metrics, step=global_step)

                console_log_eval_metrics(eval_metrics)

                if db_callback is not None:
                    db_callback(epoch, eval_metrics)

                # Model saving (based on validation loss, not regret).
                # The helper returns the value to carry forward as
                # best_val_loss and is given `result` to record outputs in.
                best_val_loss = save_model_checkpoint(
                    args, components.fb_model, components.optimizer,
                    epoch, val_loss, eval_metrics, best_val_loss, result
                )

                # Presumably switches the model back to training mode after
                # validation - TODO confirm against run_validation.
                components.fb_model.train()

            # Visualization: same 1-based cadence rule as validation; a falsy
            # vis_every_n_epochs disables it.
            should_visualize = args.vis_every_n_epochs and (epoch + 1) % args.vis_every_n_epochs == 0
            if should_visualize:
                visualize_epoch(args, components.env, components.fb_model, epoch, global_step)

        print(f"\n{'='*60}")
        print("Training completed!")
        print(f"{'='*60}\n")

        # Final evaluation: load best model and compute regret
        if result["best_model_path"] is not None:
            print("Loading best model for final evaluation...")
            load_model_checkpoint(result["best_model_path"], components.fb_model)
        else:
            print("No best model saved - using final model for evaluation...")

        print("Running final evaluation (computing regret)...")
        # The second element of the returned pair is not used here.
        final_metrics, _ = run_final_evaluation(components, args)

        # Store a copy of the final-evaluation metrics in the result. No
        # merging with earlier validation metrics happens here.
        result["final_metrics"] = final_metrics.copy()

        # Log final metrics to wandb under a "final/" prefix, with any
        # "eval/" substring removed from the key. No explicit step is passed.
        if args.log_wandb:
            final_log = {f"final/{k.replace('eval/', '')}": v for k, v in final_metrics.items()}
            wandb.log(final_log)

        console_log_eval_metrics(final_metrics)

        print(f"\n{'='*60}")
        print("Final evaluation completed!")
        print(f"  Regret: {final_metrics.get('eval/regret', 'N/A')}")
        print(f"  Mean Reward: {final_metrics.get('eval/mean_rew', 'N/A')}")
        print(f"{'='*60}\n")

    return result


def main(args: argparse.Namespace):
    """
    CLI entry point.

    Kept as a thin backwards-compatible wrapper: it forwards the parsed
    arguments to run_experiment without a database callback and discards
    the result dict that run_experiment returns.
    """
    run_experiment(args=args)


if __name__ == "__main__":
    # Script mode: parse the process command line with the project parser
    # and hand the resulting namespace to main().
    cli_args = create_parser().parse_args()
    main(cli_args)
