from os.path import exists, join, isdir
from dataclasses import dataclass, field
import sys
from tqdm import tqdm
import torch
import os

import argparse
import json
import yaml
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    set_seed,
    Seq2SeqTrainer,
    BitsAndBytesConfig,
    LlamaTokenizer
)
from peft import (
    prepare_model_for_kbit_training,
    LoraConfig,
    get_peft_model,
    PeftModel
)
import re
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LinearLR, SequentialLR
from peft import PeftModel, PeftConfig, LoraConfig, TaskType
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
from datasets import load_dataset, DatasetDict
from openai import OpenAI

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import src.constants as constants
from src.custom_rewards import LLMRewardFunction, RewardModelFunction
# from src.setup_datasets import load_datasets, prepare_dataset
from src.setup_datasets_clean import load_datasets
from src.reward_combiner_logging import log_test_reward_combiner
import string
from torch.utils.data import DataLoader
import wandb
import logging

_TRUE_TOKENS = frozenset({'true', 't', 'yes', 'y', '1', 'on'})
_FALSE_TOKENS = frozenset({'false', 'f', 'no', 'n', '0', 'off', ''})


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for any
    non-empty string (including "False"). This converter interprets the text
    instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in _TRUE_TOKENS:
        return True
    if token in _FALSE_TOKENS:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse command-line options for the PPO training script.

    Boolean options accept an optional value: ``--flag`` alone means True,
    ``--flag false`` (or ``0``/``no``/``off``) means False. This keeps the
    previous ``--flag True`` and bare ``--flag`` spellings working while making
    it possible to actually disable an option.

    If ``--config_file`` is given, its top-level keys are applied on top of the
    parsed namespace (config values win over command-line values).
    ``ground_truth_objectives`` and ``ground_truth_weights`` are decoded from
    JSON when they are still strings.

    Returns:
        argparse.Namespace: the parsed (and possibly overridden) arguments.

    Raises:
        ValueError: if the config file does not contain a mapping.
    """
    parser = argparse.ArgumentParser(description='PPO Training Script')

    # Shared settings for every boolean option (see docstring).
    bool_flag = dict(type=_str2bool, nargs='?', const=True, metavar='BOOL')

    # Basic arguments
    parser.add_argument('--base_model_path', type=str, default='meta-llama/Llama-3.1-8B', help='Base model path for PPO config')
    parser.add_argument('--model_id_path', type=str, default='', help='Path to the SFT model')
    parser.add_argument('--model_save_path', type=str, default='', help='Path to save the PPO model')
    parser.add_argument('--dataset_name', type=str, default='Anthropic/hh-rlhf', help='Dataset name')
    parser.add_argument('--dataset_max_length', type=int, default=512, help='Dataset max length')
    parser.add_argument('--multi_turn', default=False, help='Whether to use multi-turn dialogues', **bool_flag)
    parser.add_argument('--method', type=str, default='ppo', help='Method name for logging')
    parser.add_argument('--wandb_run_name', type=str, default='ppo-hh-thoroughness-2', help='Wandb run name')
    parser.add_argument('--local_wandb_dir', type=str, default='', help='Local directory for wandb logs')

    # Training parameters
    parser.add_argument('--use_sanity_check', default=False, help='Use sanity check mode', **bool_flag)
    parser.add_argument('--max_ppo_steps', type=int, default=2000, help='Maximum PPO training steps')
    parser.add_argument('--eval_freq', type=int, default=25, help='Evaluation frequency')
    parser.add_argument('--log_filename', type=str, default='logs.txt', help='Log filename')

    # PPO hyperparameters
    parser.add_argument('--learning_rate', type=float, default=2e-6, help='Learning rate')
    parser.add_argument('--max_ppo_epochs', type=int, default=4, help='Maximum PPO epochs')
    parser.add_argument('--mini_batch_size', type=int, default=4, help='Mini batch size')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
    parser.add_argument('--init_kl_coef', type=float, default=0.2, help='Initial KL coefficient')
    parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Maximum gradient norm')
    parser.add_argument('--target_kl', type=float, default=0.2, help='Target KL divergence')
    parser.add_argument('--adap_kl_ctrl', default=True, help='Use adaptive KL control', **bool_flag)
    parser.add_argument('--use_score_scaling', default=False, help='Use score scaling', **bool_flag)
    parser.add_argument('--use_score_norm', default=False, help='Use score normalization', **bool_flag)
    parser.add_argument('--early_stopping', default=False, help='Use early stopping', **bool_flag)
    parser.add_argument('--score_clip', type=float, default=None, help='Score clipping value')
    parser.add_argument('--whiten_rewards', default=True, help='Whiten rewards during PPO training', **bool_flag)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help='Gradient accumulation steps')
    parser.add_argument('--use_custom_scheduler', default=False, help='Use custom scheduler (linear warmup + cosine annealing)', **bool_flag)
    parser.add_argument('--warmup_ratio', type=float, default=0.02, help='Linear warmup + cosine annealing warmup ratio')

    # Generation parameters
    parser.add_argument('--max_new_tokens', type=int, default=512, help='Maximum new tokens to generate')

    # Reward configuration
    parser.add_argument('--ground_truth_type', type=str, default='llm_function', help='Type of ground truth: "llm_function" or "reward_model"')
    parser.add_argument('--ground_truth_objectives', type=str, default='["thoroughness"]', help='JSON list of objectives (for llm_function)')
    parser.add_argument('--ground_truth_weights', type=str, default='{"thoroughness": 1.0}', help='JSON dict of objective weights (for llm_function)')
    parser.add_argument('--reward_model_name', type=str, default='gpt-4o-mini', help='Model name for LLM scoring (llm_function) or HuggingFace model ID (reward_model)')
    # These three used action='store_true' with default=True, which made them
    # impossible to turn off from the command line.
    parser.add_argument('--use_api', default=True, help='Use API for reward scoring (llm_function only)', **bool_flag)
    parser.add_argument('--reward_combiner_type', type=str, default='linear', help='Reward combiner type (llm_function only)')
    parser.add_argument('--manual_bias', type=float, default=0.0, help='Manual bias for reward combiner (llm_function only)')
    parser.add_argument('--use_detailed_rubric', default=True, help='Use detailed rubric for scoring (llm_function only)', **bool_flag)
    parser.add_argument('--reward_combiner_path', type=str, default=None, help='Path to a pre-fitted reward combiner to load')
    parser.add_argument('--reward_model_max_length', type=int, default=2048, help='Max length for reward model input (reward_model only)')
    parser.add_argument('--use_quantization', default=True, help='Use 4-bit quantization for reward model (reward_model only)', **bool_flag)
    parser.add_argument('--cache_dir', type=str, default=None, help='Path to the local cache directory for examples and rubrics')
    parser.add_argument('--max_concurrent', type=int, default=50, help='Maximum concurrent API calls for async reward scoring')

    # EOS penalty
    parser.add_argument('--missing_eos_penalty', type=float, default=1.0, help='Penalty to subtract from reward when response lacks EOS token')

    # Config file support
    parser.add_argument('--config_file', type=str, default=None, help='Path to config file (json or yaml)')

    args = parser.parse_args()

    # Load config file if provided; its values replace the parsed ones.
    if args.config_file:
        with open(args.config_file, encoding='utf-8') as f:
            if args.config_file.endswith(".json"):
                override = json.load(f)
            else:
                override = yaml.safe_load(f)
        # An empty YAML document parses to None; treat it as "no overrides".
        if override is None:
            override = {}
        if not isinstance(override, dict):
            raise ValueError(
                f"Config file {args.config_file!r} must contain a mapping of option names to values"
            )
        for k, v in override.items():
            setattr(args, k, v)

    # Parse JSON string arguments (a config file may already supply real lists/dicts).
    if isinstance(args.ground_truth_objectives, str):
        args.ground_truth_objectives = json.loads(args.ground_truth_objectives)
    if isinstance(args.ground_truth_weights, str):
        args.ground_truth_weights = json.loads(args.ground_truth_weights)

    return args
        
# Create a custom formatter that only replaces some fields
class PartialFormatter(string.Formatter):
    """Formatter that leaves unknown named fields in place.

    A named field with no matching keyword argument is rendered back as
    ``{name}`` instead of raising ``KeyError``. Positional (integer) fields
    are resolved by the standard ``string.Formatter`` logic.
    """

    def get_value(self, key, args, kwargs):
        # Integer keys are positional fields: defer to the base class.
        if not isinstance(key, str):
            return super().get_value(key, args, kwargs)
        if key in kwargs:
            return kwargs[key]
        # Unknown name: echo the placeholder so it can be filled in later.
        return '{' + key + '}'

def freeze_trainable_parameters(model):
    """Freeze every trainable parameter of ``model`` and report the counts.

    Each parameter that currently has ``requires_grad=True`` has its name
    printed and is then switched to ``requires_grad=False``. A summary line
    with the number of previously-trainable parameters, the total parameter
    count and the trainable percentage is printed at the end.

    Args:
        model: any object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs where ``param`` has ``numel()`` and a
            writable ``requires_grad`` attribute (e.g. a ``torch.nn.Module``).

    Returns:
        None. The model is modified in place.
    """
    trainable_params = 0
    all_param = 0
    for name, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            print(name)
            trainable_params += count
            param.requires_grad = False
    # A model without parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}. Now froze all params."
    )

def supports_flash_attention(device_id):
    """Report whether the given CUDA device is treated as FlashAttention-capable.

    Returns False when CUDA is unavailable. Otherwise returns True only for
    compute capability 8.x (any minor version) or exactly 9.0.
    """
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(device_id)
        # SM 8.x of any minor revision, or SM 9.0 specifically.
        return major == 8 or (major, minor) == (9, 0)
    return False

print('Flash Attention Supported:', supports_flash_attention(0))

# Parse arguments
args = parse_args()

# Module-level aliases for the arguments used throughout the script.
base_model_path = args.base_model_path
model_id_path = args.model_id_path
model_save_path = args.model_save_path
dataset_name = args.dataset_name
use_sanity_check = args.use_sanity_check
max_ppo_steps = args.max_ppo_steps
eval_freq = args.eval_freq
log_filename = args.log_filename
wandb_run_name = args.wandb_run_name
ground_truth_objectives = args.ground_truth_objectives
ground_truth_weights = args.ground_truth_weights
cache_dir = args.cache_dir

# The default save path is the empty string, which os.makedirs rejects with an
# unhelpful "No such file or directory: ''"; fail with a clear message instead.
if not model_save_path:
    raise ValueError("--model_save_path must be set to a directory for checkpoints and logs")

# Create save directory if it doesn't exist (exist_ok avoids the check-then-create race).
os.makedirs(model_save_path, exist_ok=True)

# Save the effective configuration next to the checkpoints.
config_save_path = os.path.join(model_save_path, 'config.yaml')
with open(config_save_path, 'w', encoding='utf-8') as f:
    yaml.dump(vars(args), f, default_flow_style=False, sort_keys=False)
print(f"Config saved to {config_save_path}")

# Text log shared by the initial header, the reward-combiner analysis and validation summaries.
val_log_path = os.path.join(model_save_path, log_filename)

# Log initial configuration (appended, so earlier runs in the same directory are kept).
with open(val_log_path, 'a', encoding='utf-8') as f:
    f.write("--- PPO Run ---\n")
    f.write(f"Model: {model_id_path}\n")
    f.write(f"Dataset: {dataset_name}\n")
    if args.ground_truth_type == 'reward_model':
        f.write(f"Reward Model: {args.reward_model_name}\n")
    else:
        f.write("Ground-truth objectives:\n")
        for obj, weight in ground_truth_weights.items():
            f.write(f"  - {obj}: {weight:.2f}\n")
    f.write("="*80 + "\n\n")


# LOAD MODEL AND TOKENIZER
tokenizer = AutoTokenizer.from_pretrained(model_id_path, trust_remote_code=True)
# tokenizer.add_special_tokens({"pad_token": "[PAD]"})
# tokenizer.padding_side = "left"
if tokenizer.chat_template is None:
    # Fall back to TRL's simple query template when the tokenizer ships none.
    print("Setting chat template to SIMPLE_QUERY_CHAT_TEMPLATE")
    tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE

# Same attention backend for the base model and the value-head wrapper.
attn_implementation = 'flash_attention_2' if supports_flash_attention(0) else 'sdpa'

# 4-bit NF4 quantization with double quantization and bfloat16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
)

model = AutoModelForCausalLM.from_pretrained(
    model_id_path,
    use_safetensors=True,
    use_cache=False,
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_implementation,
    quantization_config=bnb_config,
)

# Keep the embedding matrix in step with the tokenizer's vocabulary size.
if len(tokenizer) != model.config.vocab_size:
    model.resize_token_embeddings(len(tokenizer))

# LoRA adapters on all linear layers; embeddings and LM head are saved in full.
peft_config = LoraConfig(
    lora_alpha=128,
    lora_dropout=0.05,
    r=256,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["embed_tokens", "lm_head"],
)
peft_model = get_peft_model(model, peft_config)

# Wrap the PEFT model with a value head for PPO, then build the reference model from it.
ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    peft_model,
    torch_dtype=torch.bfloat16,
    is_trainable=True,
    use_cache=False,
    attn_implementation=attn_implementation,
)
ref_model = create_reference_model(ppo_model)

# INITIALIZE CUSTOM REWARD FUNCTION
# Two back-ends are supported: a scalar reward model ('reward_model') or an
# LLM-scored objective function (any other value of ground_truth_type).
ground_truth_type = args.ground_truth_type

if ground_truth_type == 'reward_model':
    print("\nInitializing RewardModelFunction for reward scoring:")
    print(f"  Model: {args.reward_model_name}")
    print(f"  Max length: {args.reward_model_max_length}")
    print(f"  Use quantization: {args.use_quantization}")

    custom_reward_function = RewardModelFunction(
        model_name=args.reward_model_name,
        device="auto",
        max_length=args.reward_model_max_length,
        use_quantization=args.use_quantization,
        normalize_scores=False  # Keep original scores for PPO training
    )
    # The reward model is only used for scoring, so none of its weights should train.
    freeze_trainable_parameters(custom_reward_function.model)
else:
    # Check if we have a pre-fitted reward combiner to load
    reward_combiner = None
    if args.reward_combiner_path and os.path.exists(args.reward_combiner_path + "_model.pkl"):
        # Load the pre-fitted reward combiner
        from src.reward_combiner import create_reward_combiner

        # Create a combiner of the appropriate type
        reward_combiner = create_reward_combiner(
            combiner_type=args.reward_combiner_type,
            objective_names=ground_truth_objectives,
            manual_weights=ground_truth_weights,
            manual_bias=args.manual_bias
        )

        # Load the saved model
        reward_combiner.load(args.reward_combiner_path, reward_combiner.combination_function)
        reward_combiner.combination_function.objective_names = ground_truth_objectives
        print(f"Loaded pre-fitted reward combiner from: {args.reward_combiner_path}")

        # Log a detailed analysis of the loaded combiner into the validation log file.
        file_logger = logging.getLogger('reward_combiner_test')
        file_logger.setLevel(logging.INFO)

        # Remove any existing handlers to avoid duplicates
        file_logger.handlers = []

        # File handler appending to the same log file as validation,
        # without timestamps to match the existing log style.
        file_handler = logging.FileHandler(val_log_path, mode='a', encoding='utf-8')
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter('%(message)s'))
        file_logger.addHandler(file_handler)

        try:
            file_logger.info("\n" + "="*80)
            file_logger.info("LOADED REWARD COMBINER ANALYSIS")
            file_logger.info("="*80)
            log_test_reward_combiner(file_logger, reward_combiner)
        finally:
            # Always release the handler, even if the analysis raises, so the
            # log file is not left open and the handler does not linger.
            file_logger.removeHandler(file_handler)
            file_handler.close()

        print(f"Reward combiner analysis logged to: {val_log_path}")

    # Use LLMRewardFunction (existing behavior)
    print("\nInitializing LLMRewardFunction with ground-truth objectives:")
    for obj, weight in ground_truth_weights.items():
        print(f"  - {obj}: {weight:.2f}")
    if reward_combiner:
        print(f"  Loaded pre-fitted reward combiner of type: {type(reward_combiner.combination_function).__name__}")

    custom_reward_function = LLMRewardFunction(
        model_name=args.reward_model_name,  # Use model from args for scoring
        use_api=args.use_api,
        combiner_type=args.reward_combiner_type,
        objective_names=ground_truth_objectives,
        manual_weights=ground_truth_weights,
        manual_bias=args.manual_bias,
        reward_combiner=reward_combiner,  # Pass the loaded combiner if available
        device="auto",
        max_length=4096,
        dataset_type=constants.DATASET_NAMES_DICT[dataset_name],
        use_detailed_rubric=args.use_detailed_rubric,
        normalize_scores=False,  # Keep original 1-10 scores for PPO training
        cache_dir=cache_dir,
        save_dir=model_save_path,
        max_concurrent=args.max_concurrent
    )

print("Custom reward function initialized successfully\n")

# LOAD DATASET
# if dataset_name == 'openai/summarize_from_feedback':
#     raw_datasets = load_dataset(dataset_name, 'comparisons')
#     train_dataset = raw_datasets["train"]
#     # Select more samples initially to account for filtering and ensure uniqueness
#     eval_dataset_raw = raw_datasets["validation"].select(range(min(16, len(raw_datasets["validation"]))))
#     # Deduplicate based on post content to ensure unique samples
#     seen_posts = set()
#     unique_indices = []
#     for i, sample in enumerate(eval_dataset_raw):
#         post_content = sample["info"]["post"]
#         if post_content not in seen_posts:
#             seen_posts.add(post_content)
#             unique_indices.append(i)
#             if len(unique_indices) >= 32:
#                 break
#     eval_dataset = eval_dataset_raw.select(unique_indices)
#     print(f"Selected {len(eval_dataset)} unique evaluation samples from Reddit TLDR dataset")
# elif dataset_name == 'Anthropic/hh-rlhf':
#     raw_datasets = load_dataset(dataset_name)
#     train_dataset = raw_datasets["train"]
#     # Select more samples initially to account for filtering and ensure uniqueness
#     eval_dataset = raw_datasets["test"].select(range(min(16, len(raw_datasets["test"]))))

# train_dataset = prepare_dataset(train_dataset, tokenizer, dataset_name)
# eval_dataset = prepare_dataset(eval_dataset, tokenizer, dataset_name)
# train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 612)
# eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 612)
# assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token"

train_dataset, eval_dataset = load_datasets(args, tokenizer, val_dataset_size=16)

# Create the optimizer over the parameters that still require gradients.
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, ppo_model.parameters()), lr=args.learning_rate)

# --- Linear warmup scheduler ---
# The warmup length is a fraction of the PPO step budget.
total_training_steps = max_ppo_steps
warmup_steps = int(total_training_steps * args.warmup_ratio)

print(f"Total training steps: {total_training_steps}, Warmup steps: {warmup_steps}")

# Always bind `scheduler` so later code can reference it regardless of the flag.
scheduler = None
if args.use_custom_scheduler:
    if warmup_steps > 0:
        # Linear warmup: increases LR from (lr * start_factor) to lr over warmup_steps.
        scheduler = LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps)
    else:
        # With total_iters=0, LinearLR applies start_factor once and never
        # ramps back up, leaving the LR at 1% for the whole run. When there is
        # no warmup, use a no-op schedule that keeps the configured LR.
        scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1.0, total_iters=1)
    # NOTE: the cosine-annealing stage is currently disabled; only the warmup is active.
    # scheduler_cosine = CosineAnnealingWarmRestarts(optimizer, T_0=total_training_steps - warmup_steps, T_mult=1)
    # scheduler = SequentialLR(optimizer, schedulers=[scheduler_warmup, scheduler_cosine], milestones=[warmup_steps])

def collator(data):
    """Turn a list of sample dicts into a dict of per-key lists.

    The keys are taken from the first sample; every sample must provide them.
    """
    batch = {}
    for field_name in data[0]:
        batch[field_name] = [sample[field_name] for sample in data]
    return batch

# def get_gpt4o_mini_score(prompt, response, objective="thoroughness"):
#     """
#     Use GPT-4o-mini API to score a response based on an objective.
    
#     Args:
#         prompt: The original prompt/query
#         response: The response to evaluate
#         objective: The objective to score (default: "conciseness")
    
#     Returns:
#         Score between 1.0 and 10.0
#     """
#     # Initialize OpenAI client
#     # api_key = os.environ.get('OPENAI_API_KEY')
#     api_key = constants.OPENAI_API_KEY
#     if not api_key:
#         print("Warning: OPENAI_API_KEY not found. Using default score.")
#         return 5.0
    
#     client = OpenAI(api_key=api_key)
    
#     # Create scoring prompt based on objective
#     if objective == "tldr_quality":
#         formatter = PartialFormatter()
#         scoring_prompt = formatter.format(
#             constants.SCORING_PROMPT_BASE_TEMPLATE,
#             objective='conciseness',
#             objective_description=constants.SCORING_OBJECTIVE_DESCRIPTIONS['conciseness'],
#             scoring_rubric=constants.SCORING_RUBRICS_HH['conciseness'],
#             query=prompt,
#             response=response
#         )
#     elif objective == 'thoroughness':
#         formatter = PartialFormatter()
#         scoring_prompt = formatter.format(
#             constants.SCORING_PROMPT_BASE_TEMPLATE,
#             objective='thoroughness',
#             objective_description=constants.SCORING_OBJECTIVE_DESCRIPTIONS['thoroughness'],
#             scoring_rubric=constants.SCORING_RUBRICS_HH['thoroughness'],
#             query=prompt,
#             response=response
#         )
#     elif objective == "conciseness":
#         scoring_prompt = f"""Rate the conciseness of the following response on a scale from 1-10.
# 1 = extremely verbose and repetitive
# 10 = perfectly concise and to the point

# Query: {prompt}
# Response: {response}

# Provide only a numerical score from 1-10."""
#     else:
#         # Default scoring prompt
#         scoring_prompt = f"""Rate the quality of the following response on a scale from 1-10.

# Query: {prompt}
# Response: {response}

# Provide only a numerical score from 1-10."""
    
#     try:
#         api_response = client.chat.completions.create(
#             model="gpt-4o-mini",
#             # model="gpt-5",
#             messages=[
#                 {"role": "system", "content": "You are an expert evaluator. Respond only with a numerical score from 1-10."},
#                 {"role": "user", "content": scoring_prompt}
#             ],
#             temperature=0.1,
#             max_tokens=10
#             # max_completion_tokens=10
#         )
#         generated_text = api_response.choices[0].message.content.strip()
        
#         # Extract numerical score
#         numbers = re.findall(r'\b\d+(?:\.\d+)?\b', generated_text)
#         if numbers:
#             score = float(numbers[0])
#             # Clamp score to valid range
#             score = max(1.0, min(10.0, score))
#             return score
#         else:
#             print(f"Warning: Could not extract score from '{generated_text}'. Using default score of 5.0")
#             return 5.0
#     except Exception as e:
#         print(f"Error calling OpenAI API: {e}. Using default score of 5.0")
#         return 5.0

def get_score_api(queries, responses, task_list=None, sanity_check=False):
    """
    Score a batch of responses for PPO.

    Args:
        queries: List of original queries
        responses: List of responses to evaluate
        task_list: Unused; accepted for compatibility with existing callers
        sanity_check: If True, skip the reward function and score each
            response by its character length

    Returns:
        Tuple ``(scores, objective_scores)``. ``scores`` is a list with one
        scalar tensor per response. ``objective_scores`` is None in sanity
        mode; otherwise it is the third value returned by
        ``custom_reward_function.compute_reward``.
    """
    if sanity_check:
        # Length-based stand-in reward, one scalar tensor per response.
        return [torch.tensor(float(len(text))) for text in responses], None

    # compute_reward is unpacked as (normalized, denormalized, per-objective scores);
    # only the denormalized rewards are handed to the PPO trainer.
    _, raw_rewards, per_objective = custom_reward_function.compute_reward(queries, responses)
    raw_rewards = raw_rewards.float()
    per_sample = [raw_rewards[idx] for idx in range(len(raw_rewards))]
    return per_sample, per_objective


#Set up PPO trainer - use parsed arguments
learning_rate = args.learning_rate
max_ppo_epochs = args.max_ppo_epochs
mini_batch_size = args.mini_batch_size
batch_size = args.batch_size
init_kl_coef = args.init_kl_coef
max_grad_norm = args.max_grad_norm
target_kl = args.target_kl
adap_kl_ctrl = args.adap_kl_ctrl
use_score_scaling = args.use_score_scaling
use_score_norm = args.use_score_norm
score_clip = args.score_clip
early_stopping = args.early_stopping
whiten_rewards = args.whiten_rewards
gradient_accumulation_steps = args.gradient_accumulation_steps

# Normalise score_clip. The argparse default is None and a config file may
# supply the string "None"; calling float() on None raised TypeError before
# the trainer could even be built.
if score_clip is None or (isinstance(score_clip, str) and score_clip.strip().lower() in ('', 'none', 'null')):
    score_clip = None
else:
    score_clip = float(score_clip)

# Initialize wandb with custom run name
wandb.init(
    project=constants.WANDB_PROJECT_NAME,  # Uses project from constants
    group='train_ppo',
    name=wandb_run_name,  # Custom run name
    tags=['ppo', base_model_path, dataset_name],
    config=vars(args),
    # config={
    #     "model_id": model_id_path,
    #     "dataset": dataset_name,
    #     "objectives": ground_truth_objectives,
    #     "weights": ground_truth_weights,
    #     "eval_freq": eval_freq,
    #     "learning_rate": learning_rate,
    #     "batch_size": batch_size,
    #     "mini_batch_size": mini_batch_size,
    #     "max_ppo_steps": max_ppo_steps,
    #     "max_new_tokens": args.max_new_tokens,
    # },
    # The CLI default is '', so pass None to let wandb pick its default directory.
    dir=args.local_wandb_dir or None
)

config = PPOConfig(
    model_name=base_model_path,
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    log_with='wandb',
    reward_model=None,
    init_kl_coef=init_kl_coef,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size,
    max_grad_norm=max_grad_norm,
    adap_kl_ctrl=adap_kl_ctrl,
    target_kl=target_kl,
    whiten_rewards=whiten_rewards,
    early_stopping=early_stopping,
    gradient_accumulation_steps=gradient_accumulation_steps,
    use_score_scaling=use_score_scaling,
    use_score_norm=use_score_norm,
    score_clip=score_clip,
)

# min_input_length = 500
# max_input_length = 550
# dataset = build_dataset(config, dataset_name=dataset_name, input_min_text_length=min_input_length, input_max_text_length=max_input_length)
# Build the trainer; the custom LR scheduler is attached only when enabled.
ppo_trainer = PPOTrainer(
    config=config,
    model=ppo_model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=train_dataset,
    data_collator=collator,
    optimizer=optimizer,
    lr_scheduler=scheduler if args.use_custom_scheduler else None,
)

# wandb.watch(ppo_model, log="gradients", log_freq=20)

# Sampling settings shared by training rollouts and validation generation.
generation_kwargs = dict(
    # min_length=5,
    min_length=-1,
    top_k=0.0,
    top_p=1.0,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
    # eos_token_id=-1,
    max_new_tokens=args.max_new_tokens,
)

num_steps_per_epoch = len(ppo_trainer.dataloader)
print(f"Number of steps per epoch: {num_steps_per_epoch}")

eval_dataloader = DataLoader(eval_dataset, batch_size=config.batch_size, collate_fn=collator)

# Use the checkpoint schedule of the first PPO_SAVE_CHECKPOINTS_DICT key
# that occurs inside the wandb run name, if there is one.
matched_key = next(
    (candidate for candidate in constants.PPO_SAVE_CHECKPOINTS_DICT if candidate in wandb_run_name),
    None,
)
if matched_key is not None:
    ppo_save_idxs = constants.PPO_SAVE_CHECKPOINTS_DICT[matched_key]
    print(f"Found checkpoint save indices for {matched_key}: {ppo_save_idxs}")
else:
    ppo_save_idxs = None

if ppo_save_idxs is None:
    print(f"No specific checkpoint indices found for {wandb_run_name}, will save at regular eval_freq intervals")

# Validation score from the most recent evaluation, held until the next step's stats are logged.
prev_val_score = None

# Main PPO loop: a single pass over the trainer's dataloader, capped at max_ppo_steps.
# Each step generates responses, scores them, applies the EOS penalty, runs one
# PPO update and, at checkpoint steps, saves the model and runs validation.
for batch_idx, batch in tqdm(enumerate(ppo_trainer.dataloader), total=num_steps_per_epoch):
    # Check the step budget first so no banner is printed for a step that never runs.
    if batch_idx >= max_ppo_steps:
        print(f"Reached maximum PPO steps ({max_ppo_steps}). Ending training.")
        break
    print(f"\n--- PPO Training Step {batch_idx + 1}/{num_steps_per_epoch} ---")
    query_tensors = batch["input_ids"]

    avg_val_score = None

    # Generate one response per query with the current policy.
    response_tensors = list(ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs))
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
    batch['batch_idx'] = [batch_idx] * len(response_tensors)

    # Compute rewards using custom reward function
    texts = batch["response"] # list of strings
    prompts = [tokenizer.decode(q) for q in query_tensors]

    # Get rewards from custom function with ground-truth objectives
    rewards, obj_scores_list = get_score_api(prompts, texts, sanity_check=use_sanity_check)

    # Apply missing EOS penalty if responses don't end with EOS token
    if args.missing_eos_penalty > 0:
        for i, response_tensor in enumerate(response_tensors):
            # Check if the last token (excluding padding) is EOS token
            if response_tensor[-1] != tokenizer.eos_token_id:
                # Find last non-padding token
                non_pad_tokens = response_tensor[response_tensor != tokenizer.pad_token_id]
                if len(non_pad_tokens) > 0 and non_pad_tokens[-1] != tokenizer.eos_token_id:
                    rewards[i] = rewards[i] - args.missing_eos_penalty
                    print('Applied missing eos penalty to the following response:')
                    print(tokenizer.decode(response_tensor))

    # Log sample reward for monitoring
    if batch_idx % 10 == 0:
        print(f"Sample reward: {rewards[0].item():.4f}")
        print(f"Response for first sample: {texts[0]}")
        print(f"Objective scores for first sample: {obj_scores_list[0] if obj_scores_list else 'N/A'}")

    # Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

    # Add extra metrics to stats dict before log_stats to keep wandb steps in sync
    # This ensures all metrics are logged at the same step
    if ppo_trainer.accelerator.is_main_process:
        if args.use_custom_scheduler:
            stats["env/learning_rate"] = scheduler.get_last_lr()[0]
        # Add previous iteration's validation score (if any) to current stats
        if prev_val_score is not None:
            stats["env/val_reward_mean"] = prev_val_score
            prev_val_score = None  # Reset after logging

    ppo_trainer.log_stats(stats, batch, rewards, columns_to_log=['query', 'response', 'batch_idx'])

    # Decide whether this step is a checkpoint/validation step.
    if ppo_save_idxs is not None:
        # If we have specific indices, only save at those points
        should_save = batch_idx in ppo_save_idxs
    else:
        # Otherwise, save at regular eval_freq intervals. eval_freq == 0 disables
        # periodic saving instead of raising ZeroDivisionError after the first step.
        should_save = eval_freq != 0 and batch_idx % eval_freq == 0

    if should_save: # checkpointing and validation always happen together
        print(f"\n--- Running Validation at Batch Step {batch_idx} ---")

        curr_save_dir = os.path.join(model_save_path, f"checkpoint-{batch_idx}")
        os.makedirs(curr_save_dir, exist_ok=True)
        ppo_trainer.save_pretrained(curr_save_dir)
        print('Model saved to', curr_save_dir)

        val_results = []
        ppo_model.eval() # Set model to evaluation mode
        with torch.no_grad():
            for val_batch_idx, val_batch in tqdm(enumerate(eval_dataloader), desc="Validation", total=len(eval_dataloader)):
                print(f"\n--- Validation Batch {val_batch_idx + 1}/{len(eval_dataloader)} ---")
                val_query_tensors = val_batch["input_ids"]

                # Generate responses for the validation batch
                val_response_tensors = ppo_trainer.generate(
                    val_query_tensors,
                    return_prompt=False,
                    **generation_kwargs
                )

                # Unlike the training rollouts above, validation text is decoded
                # without special tokens.
                val_prompts = [tokenizer.decode(q, skip_special_tokens=True) for q in val_query_tensors]
                val_responses = [tokenizer.decode(r.squeeze(), skip_special_tokens=True) for r in val_response_tensors]

                # Score the validation responses using custom reward function
                val_rewards, val_obj_scores_list = get_score_api(val_prompts, val_responses, sanity_check=use_sanity_check)

                # Store results
                for i in range(len(val_prompts)):
                    val_results.append({
                        "prompt": val_prompts[i],
                        "response": val_responses[i],
                        "score": val_rewards[i].item(),
                        "obj_scores": val_obj_scores_list[i] if val_obj_scores_list else None
                    })

        ppo_model.train() # Set model back to training mode

        # Save validation results to a text file
        if val_results:
            os.makedirs(model_save_path, exist_ok=True)
            val_log_path = os.path.join(model_save_path, log_filename)

            avg_val_score = sum(item['score'] for item in val_results) / len(val_results)

            # Store validation score to be logged with next iteration's stats
            # This keeps wandb steps in sync (logged via stats dict in log_stats)
            prev_val_score = avg_val_score

            with open(val_log_path, 'a', encoding='utf-8') as f:
                f.write(f"--- Validation Summary for Batch Step {batch_idx} ---\n")
                f.write(f"Average Score: {avg_val_score:.4f}\n")
                if args.ground_truth_type == 'reward_model':
                    f.write(f"Reward Model: {args.reward_model_name}\n")
                else:
                    f.write("Ground-truth objectives:\n")
                    for obj, weight in ground_truth_weights.items():
                        f.write(f"  - {obj}: {weight:.2f}\n")
                f.write("="*80 + "\n\n")

                for item in val_results:
                    f.write(f"Prompt:\n{item['prompt']}\n\n")
                    f.write(f"Response:\n{item['response']}\n\n")
                    f.write(f"Score: {item['score']:.4f}\n")
                    if item['obj_scores']:
                        f.write("Objective Scores:\n")
                        for obj_name, obj_score in item['obj_scores'].items():
                            f.write(f"  - {obj_name}: {obj_score:.4f}\n")
                    f.write("-" * 80 + "\n")
            # batch_idx counts PPO steps, not epochs.
            print(f"Validation results for step {batch_idx} saved to {val_log_path}")

    # Note: Learning rate and validation metrics are now added to stats dict
    # before log_stats() to keep wandb steps in sync

print("PPO training complete.")
