import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoModelForSequenceClassification
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import DPOConfig, DPOTrainer, SFTConfig, SFTTrainer, PPOConfig, PPOTrainer, GRPOConfig, GRPOTrainer
from .custom_rewards import LLMRewardFunction, RewardModelFunction, CustomRewardModel
from .custom_ppo_trainer import CustomPPOTrainer
from .custom_grpo_trainer import CustomGRPOTrainer
# from .wandb_callback import SampleGenerationCallback, RewardMetricsCallback
from typing import Optional, Dict, Any

try:
    from .constants import GRPO_SAVE_CHECKPOINTS_DICT, DATASET_NAMES_DICT
except ImportError:
    from constants import GRPO_SAVE_CHECKPOINTS_DICT, DATASET_NAMES_DICT

# Create reward function wrapper for GRPO
# def grpo_reward_function(completions, **kwargs):
#     """
#     Wrapper function for GRPO trainer that calls the appropriate reward function.
    
#     Args:
#         completions: List of completion dictionaries with 'content' key
    
#     Returns:
#         List of float rewards
#     """
#     # Extract queries and responses from completions
#     # GRPO provides completions as list of lists of dicts with 'role' and 'content'
#     responses = []
#     queries = []
    
#     for completion in completions:
#         # Each completion is a list of messages
#         # Find the last user message as query and assistant message as response
#         query = ""
#         response = ""
        
#         for msg in completion:
#             if msg.get('role') == 'user':
#                 query = msg.get('content', '')
#             elif msg.get('role') == 'assistant':
#                 response = msg.get('content', '')
        
#         # If no proper structure, treat the whole as response
#         if not response and completion:
#             response = completion[0].get('content', '') if isinstance(completion[0], dict) else str(completion[0])
        
#         queries.append(query if query else "")  # Use empty string if no query found
#         responses.append(response)
    
#     # Compute rewards using the initialized reward function
#     rewards_tensor = reward_function_instance.compute_reward(queries, responses)
    
#     # Convert to list of floats
#     return rewards_tensor.cpu().tolist()

def supports_flash_attention(device_id):
    """Report whether the CUDA device's compute capability passes the FlashAttention check.

    Args:
        device_id: CUDA device index forwarded to
            ``torch.cuda.get_device_capability``.

    Returns:
        ``False`` when CUDA is unavailable. Otherwise ``True`` only for
        compute capability 8.x (any minor revision) or exactly 9.0.
    """
    if not torch.cuda.is_available():
        return False

    sm_major, sm_minor = torch.cuda.get_device_capability(device_id)

    # SM 8.x: every minor revision is accepted.
    if sm_major == 8:
        return sm_minor >= 0
    # SM 9.x: only revision 9.0 is accepted.
    if sm_major == 9:
        return sm_minor == 0
    return False

# Import-time report: whether CUDA device 0 passes the FlashAttention capability check.
print('Flash Attention Supported:', supports_flash_attention(device_id=0))

def _build_bnb_config(args, compute_dtype):
    """Create the BitsAndBytes quantization config used for every model load.

    Args:
        args: Run configuration; reads ``load_in_4bit``,
            ``bnb_4bit_use_double_quant`` and ``bnb_4bit_quant_type``.
        compute_dtype: Torch dtype passed as ``bnb_4bit_compute_dtype``.
    """
    return BitsAndBytesConfig(
        load_in_4bit=args.load_in_4bit,
        bnb_4bit_use_double_quant=args.bnb_4bit_use_double_quant,
        bnb_4bit_quant_type=args.bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
    )


def _build_lora_config(args, with_target_modules=True):
    """Create the LoRA adapter config from ``args``.

    Args:
        args: Run configuration; reads the ``lora_*`` attributes.
        with_target_modules: When False, ``target_modules`` is not passed at
            all (and ``args.lora_target_modules`` is not read). The PPO path
            uses this.
    """
    lora_kwargs = dict(
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        r=args.lora_r,
        bias=args.lora_bias,
        task_type=args.lora_task_type,
    )
    if with_target_modules:
        lora_kwargs['target_modules'] = args.lora_target_modules
    return LoraConfig(**lora_kwargs)


def _load_quantized_model(model_cls, model_id, args, bnb_config, model_dtype,
                          use_cache, attn_implementation=None, **extra_kwargs):
    """Call ``model_cls.from_pretrained`` with the shared loading options.

    Args:
        model_cls: Auto-model class exposing ``from_pretrained``.
        model_id: Checkpoint identifier passed positionally.
        args: Run configuration; reads ``device_map`` and, when
            ``attn_implementation`` is None, ``attn_implementation``.
        bnb_config: Quantization config for ``quantization_config``.
        model_dtype: Torch dtype for ``torch_dtype``.
        use_cache: Value for ``use_cache``.
        attn_implementation: Override for the attention implementation.
        **extra_kwargs: Forwarded unchanged (e.g. ``num_labels``).
    """
    if attn_implementation is None:
        attn_implementation = args.attn_implementation
    return model_cls.from_pretrained(
        model_id,
        device_map=args.device_map,
        use_cache=use_cache,
        attn_implementation=attn_implementation,
        torch_dtype=model_dtype,
        quantization_config=bnb_config,
        **extra_kwargs,
    )


def _cast_fp32_modules(model, dtype):
    """Cast float32 Linear, Embedding and ``*norm*`` modules of ``model`` to ``dtype``.

    Only modules whose own ``weight`` is float32 are touched, so weights
    stored in any other dtype are skipped. Modules are modified in place.
    """
    for module in model.modules():
        is_candidate = (
            isinstance(module, (torch.nn.Linear, torch.nn.Embedding))
            # Matches LayerNorm as well as custom classes such as RMSNorm.
            or 'norm' in module.__class__.__name__.lower()
        )
        if not is_candidate:
            continue
        # Norm layers may carry no weight at all; skip instead of crashing.
        weight = getattr(module, 'weight', None)
        if weight is not None and weight.dtype == torch.float32:
            module.to(dtype)


def _prepare_kbit_model(model, compute_dtype):
    """Run ``prepare_model_for_kbit_training`` and cast float32 modules to ``compute_dtype``."""
    model = prepare_model_for_kbit_training(model)
    _cast_fp32_modules(model, compute_dtype)
    return model


def _align_peft_model_dtypes(peft_model, dtype):
    """Cast ``lm_head``, ``embed_tokens`` and norm layers of a PEFT-wrapped model to ``dtype``.

    Walks ``peft_model.base_model.model`` (and its inner ``.model``); each
    level is skipped silently when the attribute is missing. Every conversion
    is printed with the dtype the module had *before* the cast.
    """
    base_model = getattr(peft_model, 'base_model', None)
    actual_model = getattr(base_model, 'model', None)
    if actual_model is None:
        return

    lm_head = getattr(actual_model, 'lm_head', None)
    if lm_head is not None and lm_head.weight.dtype != dtype:
        # Capture the dtype first: after .to() it would already be the target.
        previous_dtype = lm_head.weight.dtype
        actual_model.lm_head = lm_head.to(dtype)
        print(f"Converted lm_head from {previous_dtype} to {dtype}")

    # Inner model (e.g. the LlamaModel inside a LlamaForCausalLM).
    inner_model = getattr(actual_model, 'model', None)
    if inner_model is None:
        return

    embed_tokens = getattr(inner_model, 'embed_tokens', None)
    if embed_tokens is not None and embed_tokens.weight.dtype != dtype:
        previous_dtype = embed_tokens.weight.dtype
        inner_model.embed_tokens = embed_tokens.to(dtype)
        print(f"Converted embed_tokens from {previous_dtype} to {dtype}")

    for name, module in inner_model.named_modules():
        if 'norm' not in module.__class__.__name__.lower():
            continue
        weight = getattr(module, 'weight', None)
        if weight is None or weight.dtype == dtype:
            continue
        previous_dtype = weight.dtype
        module.to(dtype)
        print(f"Converted {name} from {previous_dtype} to {dtype}")


def _base_training_kwargs(args, output_dir):
    """Return the trainer-config keyword arguments shared by all four methods."""
    return dict(
        output_dir=output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        logging_strategy=args.logging_strategy,
        logging_steps=args.logging_steps,
        save_strategy=args.save_strategy,
        save_steps=args.save_steps,
        eval_strategy=args.eval_strategy,
        eval_steps=args.eval_steps,
        report_to=args.report_to,
    )


def _adapter_training_kwargs(args, output_dir):
    """Return the base kwargs plus the options shared by the DPO, SFT and PPO configs."""
    config_kwargs = _base_training_kwargs(args, output_dir)
    config_kwargs.update(
        gradient_checkpointing=args.gradient_checkpointing,
        learning_rate=args.learning_rate,
        save_total_limit=args.save_total_limit,
        bf16=args.bf16,
        tf32=args.tf32,
        push_to_hub=args.push_to_hub,
    )
    return config_kwargs


def _optimizer_kwargs(args):
    """Return the optimizer/scheduler options passed to the DPO and PPO configs."""
    return dict(
        optim=args.optim,
        max_grad_norm=args.max_grad_norm,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.lr_scheduler_type,
    )


def _init_dpo_trainer(args, tokenizer, train_dataset, val_dataset,
                      output_dir, compute_dtype, model_dtype):
    """Build a ``DPOTrainer`` over a quantized causal LM with a LoRA adapter."""
    bnb_config = _build_bnb_config(args, compute_dtype)
    model = _load_quantized_model(
        AutoModelForCausalLM, args.model_id, args, bnb_config, model_dtype,
        use_cache=args.use_cache,
    )
    peft_config = _build_lora_config(args)

    config_kwargs = _adapter_training_kwargs(args, output_dir)
    config_kwargs.update(_optimizer_kwargs(args))
    config_kwargs.update(
        beta=args.beta,
        loss_type=args.loss_type,
        max_length=args.max_seq_length,
        max_prompt_length=args.prompt_length,
    )
    training_args = DPOConfig(**config_kwargs)

    return DPOTrainer(
        model,
        ref_model=None,  # set to None since we use peft
        peft_config=peft_config,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        processing_class=tokenizer,
    )


def _init_sft_trainer(args, tokenizer, train_dataset, val_dataset,
                      output_dir, compute_dtype, model_dtype):
    """Build an ``SFTTrainer`` over a quantized causal LM with a LoRA adapter."""
    bnb_config = _build_bnb_config(args, compute_dtype)
    model = _load_quantized_model(
        AutoModelForCausalLM, args.model_id, args, bnb_config, model_dtype,
        use_cache=args.use_cache,
    )
    # Keep the model's padding token in sync with the tokenizer's.
    model.config.pad_token_id = tokenizer.pad_token_id
    peft_config = _build_lora_config(args)

    config_kwargs = _adapter_training_kwargs(args, output_dir)
    config_kwargs.update(
        dataset_text_field=args.dataset_text_field,
        completion_only_loss=args.completion_only_loss,
        max_length=args.max_length,
    )
    training_args = SFTConfig(**config_kwargs)

    return SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )


def _build_ppo_reward_model(args, tokenizer):
    """Create the LLM-based reward function and wrap it in ``CustomRewardModel``."""
    # Combiner-specific options are forwarded only when present on ``args``.
    optional_combiner_args = {
        'mlp_hidden_sizes': 'hidden_sizes',
        'mlp_dropout_rate': 'dropout_rate',
        'gb_n_estimators': 'n_estimators',
        'gb_max_depth': 'max_depth',
    }
    combiner_kwargs = {
        keyword: getattr(args, attr)
        for attr, keyword in optional_combiner_args.items()
        if hasattr(args, attr)
    }

    llm_reward_function = LLMRewardFunction(
        model_name=args.reward_model_name,
        use_api=getattr(args, 'use_api', True),
        combiner_type=getattr(args, 'reward_combiner_type', 'linear'),
        manual_weights=getattr(args, 'reward_manual_weights', None),
        manual_bias=getattr(args, 'reward_manual_bias', 0.0),
        device=args.device_map,
        max_length=getattr(args, 'reward_max_length', 512),
        objective_names=getattr(args, 'reward_objectives', None),
        reward_combiner=None,
        init_args=args,
        max_concurrent=getattr(args, 'max_concurrent', 50),
        **combiner_kwargs
    )

    return CustomRewardModel(
        llm_reward_function=llm_reward_function,
        tokenizer=tokenizer,
        query_response_separator=getattr(args, 'query_response_separator', 'Assistant:'),
    )


def _init_ppo_trainer(args, tokenizer, train_dataset, val_dataset,
                      output_dir, compute_dtype, model_dtype):
    """Build a ``CustomPPOTrainer`` with policy, reference, reward and value models.

    The policy, reference and value models all go through the same
    load -> ``prepare_model_for_kbit_training`` -> float32 cast sequence.
    LoRA is applied to the policy model here, so the trainer receives
    ``peft_config=None``.
    """
    bnb_config = _build_bnb_config(args, compute_dtype)

    # Policy model (the model being trained); the cache is disabled for PPO.
    policy_model = _prepare_kbit_model(
        _load_quantized_model(
            AutoModelForCausalLM, args.model_id, args, bnb_config, model_dtype,
            use_cache=False,
        ),
        compute_dtype,
    )

    # The PPO adapter config deliberately passes no ``target_modules``.
    peft_config = _build_lora_config(args, with_target_modules=False)

    # Reference model: a second load of the same checkpoint.
    ref_model = _prepare_kbit_model(
        _load_quantized_model(
            AutoModelForCausalLM, args.model_id, args, bnb_config, model_dtype,
            use_cache=False,
        ),
        compute_dtype,
    )

    reward_model = _build_ppo_reward_model(args, tokenizer)

    config_kwargs = _adapter_training_kwargs(args, output_dir)
    config_kwargs.update(_optimizer_kwargs(args))
    config_kwargs.update(
        response_length=args.max_response_length,
        num_sample_generations=getattr(args, 'num_sample_generations', 10),
        missing_eos_penalty=1.0,  # penalty for a missing EOS token
        gamma=getattr(args, 'gamma', 1.0),
        lam=getattr(args, 'lam', 0.95),
        cliprange=getattr(args, 'cliprange', 0.2),
        cliprange_value=getattr(args, 'cliprange_value', 0.2),
        vf_coef=getattr(args, 'vf_coef', 0.1),
    )
    training_args = PPOConfig(**config_kwargs)

    # Value model: a sequence classifier with a single output label.
    value_model = _prepare_kbit_model(
        _load_quantized_model(
            AutoModelForSequenceClassification, args.value_model_id, args,
            bnb_config, model_dtype, use_cache=False, num_labels=1,
        ),
        compute_dtype,
    )

    # Apply PEFT to the policy model before handing it to the trainer, then
    # re-align dtypes on the now-wrapped model.
    policy_model = get_peft_model(policy_model, peft_config)
    _align_peft_model_dtypes(policy_model, compute_dtype)

    return CustomPPOTrainer(
        model=policy_model,
        ref_model=ref_model,
        reward_model=reward_model,
        value_model=value_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        processing_class=tokenizer,
        peft_config=None,  # PEFT already applied above
    )


def _load_prefitted_reward_combiner(args):
    """Load a saved reward combiner when ``args.reward_combiner_path`` points at one.

    Returns:
        The loaded combiner, or ``None`` when no path is configured or
        ``<path>_model.pkl`` does not exist.
    """
    reward_combiner_path = getattr(args, 'reward_combiner_path', None)
    if not (reward_combiner_path and os.path.exists(reward_combiner_path + "_model.pkl")):
        return None

    from src.reward_combiner import create_reward_combiner

    # Create a combiner of the configured type, then load the saved state into it.
    reward_combiner = create_reward_combiner(
        combiner_type=args.reward_combiner_type,
        objective_names=getattr(args, 'reward_objectives', []),
        manual_weights=args.reward_manual_weights,
        manual_bias=getattr(args, 'reward_manual_bias', 0.0),
    )
    reward_combiner.load(reward_combiner_path, reward_combiner.combination_function)
    reward_combiner.combination_function.objective_names = getattr(args, 'reward_objectives', [])
    print(f"Loaded pre-fitted reward combiner from: {reward_combiner_path}")
    return reward_combiner


def _build_grpo_reward_function(args, output_dir):
    """Create the GRPO reward function selected by ``args.reward_type``.

    ``'reward_model'`` selects ``RewardModelFunction``; any other value
    selects ``LLMRewardFunction``, optionally with a pre-fitted combiner.
    Both are created with ``normalize_scores=False``.
    """
    if args.reward_type == 'reward_model':
        return RewardModelFunction(
            model_name=args.reward_model_name,
            device=args.device_map,
            max_length=getattr(args, 'reward_max_length', 4096),
            use_quantization=getattr(args, 'use_quantization', False),
            normalize_scores=False,  # keep the original (un-normalized) scores
        )

    reward_combiner = _load_prefitted_reward_combiner(args)
    if reward_combiner:
        print(f"  Loaded pre-fitted reward combiner of type: {type(reward_combiner.combination_function).__name__}")

    return LLMRewardFunction(
        model_name=getattr(args, 'reward_model_name', 'gpt-4o-mini'),
        use_api=getattr(args, 'use_api', True),
        combiner_type=getattr(args, 'reward_combiner_type', 'linear'),
        objective_names=args.reward_objectives,
        manual_weights=args.reward_manual_weights,
        manual_bias=getattr(args, 'reward_manual_bias', 0.0),
        reward_combiner=reward_combiner,  # None unless a saved combiner was loaded
        device=args.device_map,
        max_length=getattr(args, 'reward_max_length', 4096),
        dataset_type=DATASET_NAMES_DICT[args.dataset_name],
        use_detailed_rubric=getattr(args, 'use_detailed_rubric', True),
        normalize_scores=False,  # keep the original (un-normalized) scores
        cache_dir=getattr(args, 'cache_dir', None),
        save_dir=output_dir,
        max_concurrent=getattr(args, 'max_concurrent', 50),
    )


def _resolve_grpo_save_steps(args):
    """Look up custom checkpoint steps for the run in ``GRPO_SAVE_CHECKPOINTS_DICT``.

    The first dictionary key that occurs as a substring of ``args.run_name``
    wins. Returns ``None`` when nothing matches.
    """
    custom_save_steps = None
    if hasattr(args, 'run_name'):
        run_name = args.run_name
        for key in GRPO_SAVE_CHECKPOINTS_DICT:
            if key in run_name:
                custom_save_steps = GRPO_SAVE_CHECKPOINTS_DICT[key]
                print(f"Found checkpoint save steps for {key}: {custom_save_steps}")
                break

    if custom_save_steps is None:
        print(f"No specific checkpoint steps found for run_name '{getattr(args, 'run_name', 'N/A')}', will use default save_strategy")
    return custom_save_steps


def _init_grpo_trainer(args, tokenizer, train_dataset, val_dataset,
                       output_dir, compute_dtype, model_dtype):
    """Build a ``CustomGRPOTrainer`` over a quantized causal LM with a LoRA adapter."""
    bnb_config = _build_bnb_config(args, compute_dtype)

    # Fall back to SDPA attention when device 0 fails the FlashAttention check.
    attn_implementation = args.attn_implementation if supports_flash_attention(0) else 'sdpa'
    model = _load_quantized_model(
        AutoModelForCausalLM, args.model_id, args, bnb_config, model_dtype,
        use_cache=False,
        attn_implementation=attn_implementation,
    )
    peft_config = _build_lora_config(args)

    reward_function_instance = _build_grpo_reward_function(args, output_dir)

    config_kwargs = _base_training_kwargs(args, output_dir)
    config_kwargs.update(
        max_steps=args.max_steps,
        log_completions=True,  # log prompts and completions
    )
    training_args = GRPOConfig(**config_kwargs)

    custom_save_steps = _resolve_grpo_save_steps(args)

    # The reward function instance is handed to the trainer directly.
    return CustomGRPOTrainer(
        model=model,
        reward_funcs=reward_function_instance,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
        custom_save_steps=custom_save_steps,
    )


def init_trainer(args, tokenizer, train_dataset, val_dataset):
    """Create the trainer selected by ``args.method``.

    Args:
        args: Run configuration (attribute-style access). ``method`` selects
            the trainer; ``root_save_dir`` and ``run_name`` form the output
            directory; ``bnb_4bit_compute_dtype`` and ``torch_dtype`` are
            dtype names (``'bfloat16'``, ``'float16'``, ``'float32'``), with
            unknown names falling back to ``torch.bfloat16``.
        tokenizer: Passed to the trainer as ``processing_class``.
        train_dataset: Training dataset handed to the trainer.
        val_dataset: Evaluation dataset handed to the trainer.

    Returns:
        A configured trainer for ``'dpo'``, ``'sft'``, ``'ppo'`` or ``'grpo'``.

    Raises:
        NotImplementedError: For any other method, including the reserved
            but unimplemented ``'cai'`` and ``'reward_train'``.
    """
    output_dir = os.path.join(args.root_save_dir, args.run_name)

    # Convert string dtype names to torch dtypes.
    torch_dtype_map = {
        'bfloat16': torch.bfloat16,
        'float16': torch.float16,
        'float32': torch.float32,
    }
    compute_dtype = torch_dtype_map.get(args.bnb_4bit_compute_dtype, torch.bfloat16)
    model_dtype = torch_dtype_map.get(args.torch_dtype, torch.bfloat16)

    if args.method == 'dpo':
        build_trainer = _init_dpo_trainer
    elif args.method == 'sft':
        build_trainer = _init_sft_trainer
    elif args.method == 'ppo':
        build_trainer = _init_ppo_trainer
    elif args.method == 'grpo':
        build_trainer = _init_grpo_trainer
    else:
        # 'cai' and 'reward_train' are known method names without an
        # implementation; fail explicitly instead of returning an unbound name.
        raise NotImplementedError(f"Method {args.method} is not implemented.")

    return build_trainer(
        args, tokenizer, train_dataset, val_dataset,
        output_dir, compute_dtype, model_dtype,
    )