import os
# Configure process-wide environment variables.
# These are set before the remaining imports so that any library reading them at
# import/initialisation time (e.g. torch.distributed / NCCL) sees these values.
os.environ['NCCL_NVLS_ENABLE'] = '0'  # NOTE(review): presumably turns off NCCL's NVLS (NVLink SHARP) path — confirm against NCCL docs
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'  # NOTE(review): '0' should keep CUDA kernel launches asynchronous ('1' is the debugging setting) — confirm
os.environ['TORCH_CUDA_ARCH_LIST'] = '9.0'  # for flashinfer-python; NOTE(review): presumably restricts builds to compute capability 9.0 — confirm target GPUs
# Echo the NCCL setting so it shows up in the logs. This executes at import time,
# unconditionally, i.e. in every process that loads this module.
print(f"NCCL_NVLS_ENABLE is set to: {os.getenv('NCCL_NVLS_ENABLE')}")

import argparse
import re
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM
from src.ref_trainer import RefGuidedVITrainer
from src.ref_config import RefGuidedVIConfig
from src.utils import get_rank, is_main_process, print_project_structure
from src.reward_func import think_format_reward_func, answer_accuracy_reward_func, no_reference_answer_leakage_reward_func, valid_reasoning_reward_func


def none_or_float(value: str):
    """Argparse ``type`` callable for an optional float.

    The literals 'none' and 'null' (any casing, surrounding whitespace ignored)
    map to Python ``None``; an actual ``None`` is passed straight through.
    Every other value is converted with ``float``, so an unparsable string
    raises ``ValueError``.
    """
    if value is None:
        return None
    text = str(value).strip()
    return None if text.lower() in ("none", "null") else float(text)
def parse_args(argv=None):
    """
    Parse command-line arguments for the RefGuidedVI training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, which is the behaviour
            of a plain ``parse_args()`` call. Passing an explicit list allows
            the parser to be driven programmatically (tests, notebooks).

    Returns:
        argparse.Namespace: One attribute per option, with dashes in the flag
        name replaced by underscores (e.g. ``--model-path`` -> ``model_path``).
    """
    parser = argparse.ArgumentParser(description="RefGuidedVI Training Script")

    # Model settings
    parser.add_argument("--model-path", type=str, default="./modelscope",
                        help="Base path for models")
    parser.add_argument("--policy-model", type=str, default="Qwen/Qwen2.5-1.5B-Instruct",
                        help="Policy model name")

    # Dataset settings
    parser.add_argument("--dataset-path", type=str, default="./datasets",
                        help="Base path for datasets")
    parser.add_argument("--train-dataset-file", type=str, default="rlpr/test/gpqa_diamond_Avg4.parquet",
                        help="Training dataset file name")
    parser.add_argument("--dataset-split", type=str, default="train",
                        help="Dataset split to use")

    # Evaluation dataset settings
    parser.add_argument("--eval-dataset-files", type=str, default=None,
                        help="Comma-separated list of evaluation dataset files (parquet format)")
    parser.add_argument("--eval-dataset-split", type=str, default="train",
                        help="Dataset split to use for evaluation datasets")

    # Training settings
    parser.add_argument("--output-dir", type=str, default=None,
                        help="Output directory for training logs and checkpoints")
    parser.add_argument("--num-train-epochs", type=int, default=5,
                        help="Number of training epochs")
    parser.add_argument("--warmup-ratio", type=float, default=0.0,
                        help="Warmup ratio for learning rate scheduler")
    parser.add_argument("--learning-rate", type=float, default=1e-6,
                        help="Learning rate for the optimizer")
    parser.add_argument("--lr-scheduler-type", type=str, default="constant",
                        help="Learning rate scheduler type")
    # The flag is an int, so "None" cannot be passed on the command line; the
    # caller maps non-positive values to None (= never empty the cache).
    parser.add_argument("--torch-empty-cache-steps", type=int, default=-1,
                        help="Number of steps to wait before empty_cache. "
                             "Values <= 0 (the default) disable cache emptying.")

    # Batching settings
    parser.add_argument("--per-device-eval-batch-size", type=int, default=4,
                        help="Evaluation batch size per device")
    parser.add_argument("--per-device-train-batch-size", type=int, default=4,
                        help="Training batch size per device")
    parser.add_argument("--gradient-accumulation-steps", type=int, default=1,
                        help="Number of gradient accumulation steps")
    parser.add_argument("--steps-per-generation", type=int, default=None,
                        help="Maximum number of steps per generation")
    parser.add_argument("--generation-batch-size", type=int, default=None,
                        help="Maximum number of generations per batch")
    parser.add_argument("--num-iterations", type=int, default=1,
                        help="Number of iterations per batch")

    # Sampling settings
    parser.add_argument("--max-prompt-length", type=int, default=1024,
                        help="Maximum prompt length")
    parser.add_argument("--max-completion-length", type=int, default=3072,
                        help="Maximum completion length")
    parser.add_argument("--num-generations", type=int, default=8,
                        help="Number of generations per prompt")
    parser.add_argument("--mask-truncated-completions", action="store_true", default=False,
                        help="Mask truncated completions (set to 0) in the loss computation")

    # Compute settings
    parser.add_argument("--bf16", action="store_true", default=False,
                        help="Use bf16 precision")
    parser.add_argument("--gradient-checkpointing", action="store_true", default=False,
                        help="Enable gradient checkpointing")
    parser.add_argument("--ddp-find-unused-parameters", action="store_true", default=False,
                        help="Find unused parameters in DDP")
    parser.add_argument("--use-liger-kernel", action="store_true", default=False,
                        help="Enable liger kernel")
    parser.add_argument("--torch-compile", action="store_true", default=False,
                        help="Enable torch compile")

    # VLLM settings
    parser.add_argument("--use-vllm", action="store_true", default=False,
                        help="Use VLLM for inference")
    parser.add_argument("--vllm-mode", type=str, default="colocate",
                        help="VLLM mode")
    parser.add_argument("--vllm-gpu-memory-utilization", type=float, default=0.3,
                        help="VLLM GPU memory utilization")
    parser.add_argument("--vllm-tensor-parallel-size", type=int, default=4,
                        help="VLLM tensor parallel size")

    # Generation settings
    parser.add_argument("--temperature", type=float, default=1.2,
                        help="Generation temperature")
    parser.add_argument("--top-p", type=float, default=1.0,
                        help="Top-p sampling parameter")
    parser.add_argument("--top-k", type=int, default=-1,
                        help="Top-k sampling parameter")

    # Reward settings
    parser.add_argument("--reward-weights", type=float, nargs='*', default=[1.0, 1.0, 1.0, 1.0],
                        help="Weights for reward functions")
    parser.add_argument("--prob-reward-weight", type=float, default=0.8,
                        help="Weights for reference answer probability reward function")
    parser.add_argument("--min-prob-reward-ratio", type=float, default=1.0,
                        help="Minimum value for the probability reward ratio")
    parser.add_argument("--max-prob-reward-ratio", type=float, default=10.0,
                        help="Maximum value for the probability reward ratio")
    parser.add_argument("--prob-model", type=str, default="self",
                        help="Model to use for computing the probability reward (self or ref)")
    parser.add_argument("--format-wrong-reward", type=float, default=-1.0,
                        help="Reward for format errors in the response (default: -1.0)")
    parser.add_argument("--reference-leakage-reward", type=float, default=-0.5,
                        help="Reward for reference leakage in the response (default: -0.5)")
    parser.add_argument("--invalid-reasoning-in-response-reward", type=float, default=-0.5,
                        help="Reward for invalid reasoning in the response (default: -0.5)")
    parser.add_argument("--prob-reward-baseline", type=str, default="naive_group_mean",
                        help="Baseline method for q(z|x,y) reward baseline")
    parser.add_argument("--answer-prefix", type=str, default="simple_prefix",
                        help="The prefix before the answer, can be 'simple_prefix' or 'none'.")

    # Loss settings
    parser.add_argument("--epsilon", type=float, default=0.2,  # 3e-4 for GSPO
                        help="Epsilon value for clipping.")
    parser.add_argument("--epsilon-high", type=float, default=0.28,  # 4e-4 for GSPO
                        help="Upper-bound epsilon value for clipping.")
    parser.add_argument("--importance-sampling-level", type=str, default="token",
                        help="Controls whether importance sampling ratios are computed at the 'token' or 'sequence' level.")
    parser.add_argument("--beta", type=float, default=0.0,
                        help="Weights for ref kl loss")
    parser.add_argument("--z-kl-beta", type=float, default=1e-6,
                        help="Weights for z kl")
    parser.add_argument("--sft-beta", type=float, default=1e0,
                        help="Weights for sft")
    parser.add_argument("--kl-estimator", type=str, default="k3",
                        help="KL divergence estimator (k1 or k3)")
    parser.add_argument("--min-r", type=float, default=1e-2,
                        help="The minimum value for the r in logr in kl computing.")
    parser.add_argument("--max-r", type=float, default=100.0,
                        help="The maximum value for the r in logr in kl computing.")
    parser.add_argument("--sync-ref-model", action='store_true',
                        help="Whether to sync the reference model with the policy model.")
    parser.add_argument("--ref-model-mixup-alpha", type=float, default=0.6,
                        help="Mixup alpha for the reference model.")
    parser.add_argument("--ref-model-sync-steps", type=int, default=512,
                        help="how frequently the current policy is synchronized with the reference policy.")
    parser.add_argument("--top-entropy-quantile", type=float, default=1.0,
                        help="The quantile of the top entropy to use for the reference model.")
    parser.add_argument("--loss-type", type=str, default="grpo",
                        help="Type of loss to use (grpo, bnpo, or dr_grpo)")
    parser.add_argument("--z-kl-sample-weight", type=str, default="clipped_prob_gain",
                        help="z kl sample weight.")
    # none_or_float lets the literal 'None' on the command line switch a term off.
    parser.add_argument("--z-kl-constraint-coef", type=none_or_float, default=0.5,
                        help="Coefficient for the z kl constraint (pass 'None' to disable constraint term).")
    parser.add_argument("--z-kl-learning-coef", type=none_or_float, default=0.5,
                        help="Coefficient for the z kl learning (pass 'None' to disable learning term).")
    parser.add_argument("--p-grpo-loss-coef", type=float, default=1.0,
                        help="Coefficient for the P grpo loss.")
    parser.add_argument("--q-grpo-loss-coef", type=float, default=1.0,
                        help="Coefficient for the Q grpo loss.")

    # Evaluation strategy settings
    parser.add_argument("--bf16-full-eval", action="store_true", default=False,
                        help="Enable bf16 full evaluation")
    parser.add_argument("--eval-strategy", type=str, default="steps",
                        help="Evaluation strategy (steps, epoch, no)")
    parser.add_argument("--eval-steps", type=int, default=50,
                        help="Number of steps between evaluations")
    parser.add_argument("--save-strategy", type=str, default="steps",
                        help="Save strategy (steps, epoch, no)")
    parser.add_argument("--save-steps", type=int, default=50,
                        help="Number of steps between saves")

    # Logging settings
    parser.add_argument("--logging-steps", type=int, default=1,
                        help="Logging frequency in steps")
    parser.add_argument("--logging-strategy", type=str, default="steps",
                        help="Logging strategy")
    parser.add_argument("--report-to", type=str, default="wandb",
                        help="Reporting destination")
    parser.add_argument("--log-completions", action="store_true", default=False,
                        help="Log completions to the reporting destination")
    parser.add_argument("--num-completions-to-print", type=int, default=1,
                        help="Number of completions to print")
    parser.add_argument("--run-name", type=str, default="undefined",
                        help="Run name for tracking")

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)


def extract_dataset_name(file_path):
    """
    Return the file name of ``file_path`` without its directory part and
    without its last extension.

    Examples:
    "rlpr/test/gpqa_diamond_Avg4.parquet" -> "gpqa_diamond_Avg4"
    "math_dataset.parquet" -> "math_dataset"
    """
    # basename drops the directories, splitext peels off only the final suffix.
    stem, _extension = os.path.splitext(os.path.basename(file_path))
    return stem


def generate_dynamic_output_dir(args):
    """
    Build the nested output directory that encodes the run's hyperparameters.

    Resulting layout (each level is one path component):
    <base_dir>/            args.output_dir, or "logs" when that is None
    └── RefVI/             fixed component
        └── model_name/    last '/'-separated piece of args.policy_model
            └── dataset_name/    training file name without directory/extension
                └── training_config/
                    └── generation_config/
                        └── loss_config/
                            └── reward_config/

    Each *_config component is a "_"-joined list of "<tag><value>" items.

    Args:
        args: Parsed command-line arguments.

    Returns:
        str: The hierarchical output directory path.
    """
    def join_tags(pairs):
        # Render every (tag, value) pair as "<tag><value>" and glue them with "_".
        return "_".join(f"{tag}{value}" for tag, value in pairs)

    root = "logs" if args.output_dir is None else args.output_dir

    # Model and dataset levels
    model_name = args.policy_model.rsplit('/', 1)[-1]
    dataset_name = extract_dataset_name(args.train_dataset_file)

    # Optimisation / batching hyperparameters
    training_config = join_tags([
        ("lr", args.learning_rate),
        ("sch", args.lr_scheduler_type),
        ("epoch", args.num_train_epochs),
        ("warmup", args.warmup_ratio),
        ("trainBS", args.per_device_train_batch_size),
        ("gradAccum", args.gradient_accumulation_steps),
    ])

    # Sampling hyperparameters
    generation_config = join_tags([
        ("maxPromptLen", args.max_prompt_length),
        ("maxCompLen", args.max_completion_length),
        ("NumGen", args.num_generations),
        ("temp", args.temperature),
        ("topP", args.top_p),
        ("topK", args.top_k),
    ])

    # Loss hyperparameters
    loss_config = join_tags([
        ("loss", args.loss_type),
        ("eps", args.epsilon),
        ("epsH", args.epsilon_high),
        ("refklbeta", args.beta),
        ("pGrpoCoef", args.p_grpo_loss_coef),
        ("qGrpoCoef", args.q_grpo_loss_coef),
        ("zklbeta", args.z_kl_beta),
        ("zklCon", args.z_kl_constraint_coef),
        ("zklLearn", args.z_kl_learning_coef),
        ("zklSampleWeight", args.z_kl_sample_weight),
        ("sft", args.sft_beta),
        ("minR", args.min_r),
        ("maxR", args.max_r),
        ("klEst", args.kl_estimator),
    ])

    # Reward hyperparameters (reference_leakage_reward and
    # invalid_reasoning_in_response_reward are not part of the path).
    reward_config = join_tags([
        ("probW", args.prob_reward_weight),
        ("minProbR", args.min_prob_reward_ratio),
        ("maxProbR", args.max_prob_reward_ratio),
        ("probMdl", args.prob_model),
        ("formatWrong", args.format_wrong_reward),
        ("probRewardBaseline", args.prob_reward_baseline),
        ("answerPrefix", args.answer_prefix),
    ])

    levels = (
        root,
        "RefVI",
        model_name,
        dataset_name,
        training_config,
        generation_config,
        loss_config,
        reward_config,
    )
    return os.path.join(*levels)


def load_eval_datasets(args):
    """
    Load the evaluation datasets listed in ``args.eval_dataset_files``.

    The option is a comma-separated list of file names relative to
    ``args.dataset_path``; blank entries are ignored. The ``datasets`` builder
    is chosen from each file's extension (".parquet" -> "parquet",
    ".csv" -> "csv"); any other extension keeps the "json" builder that was
    previously used for every file, so JSON/JSONL inputs load as before.

    Args:
        args: Parsed command-line arguments. Reads ``eval_dataset_files``,
            ``dataset_path`` and ``eval_dataset_split``.

    Returns:
        dict | None: ``None`` when no evaluation file is configured, otherwise
        a dict keyed by dataset name (file name without directory/extension).
        If two files share the same name, the later one replaces the earlier.
        Example: {
            "gpqa_diamond_Avg4": Dataset(...),
            "math_dataset": Dataset(...),
            "science_qa": Dataset(...)
        }
    """
    if not args.eval_dataset_files:
        return None

    # Parse the comma-separated file list, dropping empty entries
    eval_file_list = [
        entry.strip() for entry in args.eval_dataset_files.split(',') if entry.strip()
    ]
    if not eval_file_list:
        return None

    # Extension -> builder name; unknown extensions fall back to "json".
    builder_by_extension = {".parquet": "parquet", ".csv": "csv"}

    eval_datasets = {}
    for eval_file in eval_file_list:
        dataset_name = extract_dataset_name(eval_file)
        dataset_full_path = os.path.join(args.dataset_path, eval_file)

        extension = os.path.splitext(eval_file)[1].lower()
        builder = builder_by_extension.get(extension, "json")

        dataset = load_dataset(builder,
                               data_files=dataset_full_path,
                               split=args.eval_dataset_split)
        eval_datasets[dataset_name] = dataset

        if get_rank() == 0:
            print(f"✅ Loaded eval dataset: {dataset_name} ({len(dataset)} samples)")

    return eval_datasets


def main():
    """
    Main entry point: parse the CLI, load the datasets, build the
    RefGuidedVIConfig and run RefGuidedVITrainer.train().

    Notes:
        * The ``datasets`` builder for the training file is picked from its
          extension (".parquet" -> "parquet", ".csv" -> "csv", anything else
          -> "json") so that the default parquet training file can be loaded.
        * ``--dataset-split`` is not consulted here; the training file is
          always registered and read under the "train" key.
        * ``args.output_dir`` is overwritten with the hyperparameter-derived
          directory returned by ``generate_dynamic_output_dir``.
        * When no evaluation dataset is configured, the evaluation strategy is
          forced to "no", since evaluation cannot run without a dataset.
    """
    # Parse command-line arguments
    args = parse_args()

    model_path = os.path.join(args.model_path, args.policy_model)
    train_dataset_full_path = os.path.join(args.dataset_path, args.train_dataset_file)

    # If rank 0, print key paths
    if get_rank() == 0:
        print("=" * 50)
        print(f"Model path: {model_path}")
        print(f"Training dataset path: {train_dataset_full_path}")
        print("=" * 50)

    # Load training dataset with the builder matching the file extension;
    # unknown extensions keep the original "json" builder.
    train_extension = os.path.splitext(train_dataset_full_path)[1].lower()
    train_builder = {".parquet": "parquet", ".csv": "csv"}.get(train_extension, "json")
    data_files = {"train": train_dataset_full_path}
    train_dataset = load_dataset(train_builder, data_files=data_files)["train"]

    # dict of name -> dataset, or None when no eval files were given
    eval_datasets = load_eval_datasets(args)

    if get_rank() == 0:
        print(f"✅ Loaded training dataset: {len(train_dataset)} samples")
        if eval_datasets:
            for dataset_name, dataset in eval_datasets.items():
                print(f"✅ Loaded eval dataset: {dataset_name} - {len(dataset)} samples")
        else:
            print("ℹ️  No evaluation dataset specified")

    # Without an evaluation dataset there is nothing to evaluate; requesting
    # evaluation anyway is expected to make the underlying HF Trainer fail
    # (NOTE(review): confirm RefGuidedVITrainer inherits that check).
    eval_strategy = args.eval_strategy
    if not eval_datasets and eval_strategy != "no":
        if get_rank() == 0:
            print(f"ℹ️  eval_strategy '{eval_strategy}' ignored because no evaluation dataset was given; evaluation disabled")
        eval_strategy = "no"

    # Set output directory (dynamically generated with hyperparameters)
    args.output_dir = generate_dynamic_output_dir(args)

    if get_rank() == 0:
        print(f"📁 Output dir: {args.output_dir}")
        print("=" * 50)

    # Create training config
    training_args = RefGuidedVIConfig(
        output_dir=args.output_dir,
        report_to=args.report_to,
        # Training settings
        num_train_epochs=args.num_train_epochs,
        warmup_ratio=args.warmup_ratio,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        # Non-positive CLI values mean "never empty the cache" (None)
        torch_empty_cache_steps=args.torch_empty_cache_steps if args.torch_empty_cache_steps > 0 else None,
        # Batching settings
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        generation_batch_size=args.generation_batch_size,
        steps_per_generation=args.steps_per_generation,
        num_iterations=args.num_iterations,
        # Sampling settings
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        num_generations=args.num_generations,
        mask_truncated_completions=args.mask_truncated_completions,
        # Compute settings
        bf16=args.bf16,
        gradient_checkpointing=args.gradient_checkpointing,
        torch_compile=args.torch_compile,
        torch_compile_backend="inductor" if args.torch_compile else None,
        torch_compile_mode="default" if args.torch_compile else None,
        use_liger_kernel=args.use_liger_kernel,
        ddp_find_unused_parameters=args.ddp_find_unused_parameters,
        # VLLM settings
        use_vllm=args.use_vllm,
        vllm_mode=args.vllm_mode,
        vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization,
        vllm_tensor_parallel_size=args.vllm_tensor_parallel_size,
        # Generation settings
        generation_kwargs={
            "temperature": args.temperature,
            "top_p": args.top_p,
            "top_k": args.top_k,
        },
        # Reward settings
        # (--reference-leakage-reward and --invalid-reasoning-in-response-reward
        # are parsed by the CLI but not forwarded to the config here.)
        reward_weights=args.reward_weights,
        prob_reward_weight=args.prob_reward_weight,  # Weight for reference-answer probability reward
        min_prob_reward_ratio=args.min_prob_reward_ratio,  # Min ratio for probability reward
        max_prob_reward_ratio=args.max_prob_reward_ratio,  # Max ratio for probability reward
        prob_model=args.prob_model,  # Model used to compute probability reward
        format_wrong_reward=args.format_wrong_reward,  # Reward value for format errors
        prob_reward_baseline=args.prob_reward_baseline,  # Baseline for probability reward
        answer_prefix=args.answer_prefix,  # Answer prefix
        # Loss settings
        epsilon=args.epsilon,  # Epsilon for clipping
        epsilon_high=args.epsilon_high,  # Upper epsilon for clipping
        beta=args.beta,  # Weight for reference KL loss
        z_kl_beta=args.z_kl_beta,  # Weight for z KL
        sft_beta=args.sft_beta,  # Weight for SFT
        z_kl_constraint_coef=args.z_kl_constraint_coef,  # Coefficient for z KL constraint
        z_kl_learning_coef=args.z_kl_learning_coef,  # Coefficient for z KL learning
        z_kl_sample_weight=args.z_kl_sample_weight,  # z KL sample weighting method
        kl_estimator=args.kl_estimator,  # KL divergence estimator
        max_r=args.max_r,  # Max r for logr clipping
        min_r=args.min_r,  # Min r for logr clipping
        importance_sampling_level=args.importance_sampling_level,  # Importance sampling level
        sync_ref_model=args.sync_ref_model,  # Whether to sync the reference model
        ref_model_mixup_alpha=args.ref_model_mixup_alpha,  # Mixup alpha for reference model
        ref_model_sync_steps=args.ref_model_sync_steps,  # Sync frequency for reference model
        top_entropy_quantile=args.top_entropy_quantile,  # Top-entropy quantile for masking
        loss_type=args.loss_type,  # Loss type (grpo, bnpo, dr_grpo)
        p_grpo_loss_coef=args.p_grpo_loss_coef,  # Coefficient for p GRPO loss
        q_grpo_loss_coef=args.q_grpo_loss_coef,  # Coefficient for q GRPO loss
        # Evaluation config
        bf16_full_eval=args.bf16_full_eval,
        eval_strategy=eval_strategy,
        eval_steps=args.eval_steps if eval_strategy == "steps" else None,
        save_strategy=args.save_strategy,
        save_steps=args.save_steps if args.save_strategy == "steps" else None,
        # Logging
        logging_steps=args.logging_steps,
        logging_strategy=args.logging_strategy,
        run_name=args.run_name,  # Run name
        log_completions=args.log_completions,
        num_completions_to_print=args.num_completions_to_print,  # Number of completions to print
    )

    # The trainer receives the model as a path string; the order of
    # reward_funcs lines up positionally with --reward-weights.
    trainer = RefGuidedVITrainer(
        model=model_path,
        reward_funcs=[
            think_format_reward_func,
            answer_accuracy_reward_func,
            no_reference_answer_leakage_reward_func,
            valid_reasoning_reward_func,
        ],
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_datasets,
    )

    # Start training
    trainer.train()


# Run training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
