import os
import pandas as pd
from datasets import Dataset, load_dataset
from typing import Any, Dict, List, Optional, Union
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
import torch
from dataclasses import dataclass
from typing import Dict, List, Any, Optional
from transformers.utils import PaddingStrategy
import time
import shutil
from pathlib import Path
from tqdm import tqdm
import json
# Add accelerate imports
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs


# ---------------------------
# Configuration
# ---------------------------

# Hugging Face checkpoint used for both the tokenizer and the reward model.
model_name = "meta-llama/Llama-3.2-1B-Instruct"

# ---------------------------
# Load Model and Tokenizer
# ---------------------------
# The tokenizer is a module-level global: the preprocessing function, the model
# setup and the data collator in this file all use it.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
if tokenizer.pad_token is None:
    # Use the EOS token as the pad token instead of adding a new one
    tokenizer.pad_token = tokenizer.eos_token
print(f"EOS token ID: {tokenizer.eos_token_id}")
print(len(tokenizer))  # the same length is later passed to resize_token_embeddings
tokenizer.truncation_side = "left"  # when truncating, drop tokens from the start and keep the end
tokenizer.model_max_length = 4096  # same limit as the max_length used when tokenizing rows

def setup_model(model_name: str):
    """Load the sequence-classification model that is trained as the reward model.

    Args:
        model_name: Checkpoint name or path handed to ``from_pretrained``.

    Returns:
        The loaded model, configured with a single output label, with its
        ``pad_token_id`` taken from the module-level ``tokenizer`` and its token
        embeddings resized to ``len(tokenizer)``.
    """
    load_options = {
        "num_labels": 1,  # one logit per sequence
        "torch_dtype": torch.bfloat16,
        "attn_implementation": "flash_attention_2",
    }
    reward_model = AutoModelForSequenceClassification.from_pretrained(model_name, **load_options)

    # Keep the model's padding id and embedding size in sync with the tokenizer.
    vocab_size = len(tokenizer)
    reward_model.config.pad_token_id = tokenizer.pad_token_id
    reward_model.resize_token_embeddings(vocab_size)

    print("Using full model fine-tuning...")
    return reward_model

# Build the reward model once at module level; it is passed to the trainer below.
model = setup_model(model_name)

# ---------------------------
# Data Preprocessing Function
# ---------------------------
def process_data_for_reward_model(row):
    """Tokenize one dataset row and attach its reward label.

    Only the ``step_content`` text is tokenized; no prompt is prepended.
    The label is ``1 - row["biased"]``, so it is 1 for an unbiased row and 0
    for a biased one (assumes ``biased`` is a 0/1 flag supplied by the dataset).

    Args:
        row: Mapping with at least the keys ``step_content`` and ``biased``.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``label``.
    """
    encoding = tokenizer(
        row["step_content"],
        truncation=True,
        max_length=4096,
        padding=False,  # sequences are left unpadded here
    )
    return {
        "input_ids": encoding["input_ids"],
        "attention_mask": encoding["attention_mask"],
        "label": 1 - row["biased"],
    }


# ---------------------------
# Load, Concatenate, and Split Dataset
# ---------------------------
print("Loading dataset...")
# NOTE(review): the dataset identifier is an empty string here — presumably a
# placeholder that must be filled in before this script can run.
dataset = load_dataset("")
# Merge every split of the loaded dataset into one DataFrame with a fresh index.
df = pd.concat([pd.DataFrame(dataset[split]) for split in dataset.keys()], ignore_index=True)
print(f"Size of dataset: {len(df)} rows")
print("Preparing datasets...")
# Use only 50% of the dataset
full_dataset = Dataset.from_pandas(df).shuffle(seed=45)  # fixed seed for a repeatable shuffle
hf_dataset = full_dataset.select(range(len(full_dataset) // 2))  # Take only 50%
total_size = len(hf_dataset)
# 0.5% of the kept rows go to evaluation (factor 0.005, not 10%).
# NOTE(review): int() truncates, so eval_size is 0 when total_size < 200.
eval_size = int(total_size * 0.005)
# The first eval_size rows form the eval split; the remainder is the training split.
eval_dataset = hf_dataset.select(range(eval_size))
hf_dataset = hf_dataset.select(range(eval_size, total_size))

print("Tokenizing datasets...")
# Replace the raw columns with input_ids / attention_mask / label.
train_dataset = hf_dataset.map(
    process_data_for_reward_model,
    remove_columns=hf_dataset.column_names,
    num_proc=8
)
print(f"Training dataset size: {len(train_dataset)}")

# Same preprocessing for the evaluation split.
eval_dataset = eval_dataset.map(
    process_data_for_reward_model,
    remove_columns=eval_dataset.column_names,
    num_proc=8
)
print(f"Eval dataset size: {len(eval_dataset)}")

# ---------------------------
# Data Collator
# ---------------------------
@dataclass
class RewardDataCollatorWithPadding:
    """Pad a list of tokenized examples into one batch and attach float labels.

    Each incoming feature is expected to carry ``input_ids``, ``attention_mask``
    and ``label``. The token fields are padded with ``tokenizer.pad``; the labels
    are collected into a float32 tensor stored under ``"labels"``.

    Attributes:
        tokenizer: Object providing ``pad(...)``; used to pad the batch.
        padding: Padding strategy forwarded to ``tokenizer.pad`` (default ``True``).
        max_length: Optional length forwarded to ``tokenizer.pad``.
        pad_to_multiple_of: Optional multiple forwarded to ``tokenizer.pad``.
        return_tensors: Tensor type requested from ``tokenizer.pad`` (default ``"pt"``).
    """

    # The two annotations naming third-party types are quoted (forward
    # references) so they are not evaluated when the class is defined.
    tokenizer: "AutoTokenizer"
    padding: "Union[bool, str, PaddingStrategy]" = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate ``features`` into a padded batch.

        Args:
            features: Examples with ``input_ids``, ``attention_mask`` and ``label``.

        Returns:
            The padded batch returned by ``tokenizer.pad`` with an extra
            ``"labels"`` entry holding a float32 tensor of the labels.
        """
        # Labels are kept out of the padding call and re-attached afterwards.
        labels = [example["label"] for example in features]

        token_fields = [
            {
                "input_ids": example["input_ids"],
                "attention_mask": example["attention_mask"],
            }
            for example in features
        ]

        batch = self.tokenizer.pad(
            token_fields,
            # Honour the configured strategy; this used to be hard-coded to True,
            # which silently ignored the ``padding`` field.
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )

        batch["labels"] = torch.tensor(labels, dtype=torch.float32)
        return batch

# ---------------------------
# Custom Trainer for Reward Modeling
# ---------------------------
class RewardTrainer(Trainer):
    """Trainer that fits the single-logit reward model with binary cross-entropy."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the mean binary cross-entropy between logits and labels.

        Args:
            model: Model being trained; called with ``input_ids`` and
                ``attention_mask`` from ``inputs``.
            inputs: Batch dict containing ``input_ids``, ``attention_mask`` and
                ``labels``.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                only the loss.
            **kwargs: Extra arguments supplied by the training loop; accepted
                and ignored.

        Returns:
            The scalar loss tensor, or ``(loss, outputs)`` if ``return_outputs``
            is true.
        """
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )

        # Upcast to float32 and use the logits-based BCE. Applying sigmoid and
        # log separately in bfloat16 can round the probability to exactly 0 or
        # 1, after which the log term is pinned by the epsilon and the gradient
        # for that example vanishes.
        logits = outputs.logits.squeeze(-1).float()  # Shape: (batch_size,)
        labels = inputs["labels"].to(device=logits.device, dtype=logits.dtype)

        # Default reduction is the mean over the batch, matching the previous
        # negative mean log-likelihood.
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels)

        return (loss, outputs) if return_outputs else loss

# ---------------------------
# Training Arguments
# ---------------------------
# Create a unique temporary directory for this run
# NOTE(review): the name is derived from the wall clock when this module runs;
# if the script is started as several processes, each one computes its own
# value — confirm they agree on the same directory.
tmp_dir = f"/tmp/reward_model_{int(time.time())}"
os.makedirs(tmp_dir, exist_ok=True)

training_args = TrainingArguments(
    output_dir=tmp_dir,  # checkpoints go to the temporary directory created above
    num_train_epochs=2,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    weight_decay=0.01,
    bf16=True,
    evaluation_strategy="steps",
    eval_steps=100,  # evaluate every 100 steps
    save_strategy="steps",
    save_steps=5000,  # checkpoint every 5000 steps
    save_total_limit=1,
    gradient_checkpointing=True,
    logging_steps=1,  # log every step
    optim="adamw_torch",
    adam_beta1=0.9,
    adam_beta2=0.95,
    lr_scheduler_type="cosine",
    warmup_ratio=0.15,
    report_to="wandb",
    # NOTE(review): this is a second time.time() call, so the run name's
    # timestamp can differ from the one embedded in tmp_dir.
    run_name=f"correctness-prm-full-{int(time.time())}",
    # Select the best checkpoint by lowest evaluation loss.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    push_to_hub=True,
    # NOTE(review): the hub repository id is an empty string — presumably a
    # placeholder to fill in before running.
    hub_model_id=f"",
    ddp_find_unused_parameters=True,
    overwrite_output_dir=True
    # NOTE(review): a comment here used to mention using a different port for
    # distributed training, but no port is configured in these arguments.
)

# Add cleanup function
def cleanup_tmp_files():
    """Best-effort removal of this run's temporary output directory.

    Deletes the directory tree at the module-level ``tmp_dir`` if it exists.
    Any exception raised while checking or deleting is reported with a printed
    warning and otherwise swallowed, so cleanup never masks another error.
    """
    try:
        if not os.path.exists(tmp_dir):
            return
        shutil.rmtree(tmp_dir)
    except Exception as err:
        print(f"Warning: Could not clean up temporary directory: {err}")

# ---------------------------
# Initialize and Train
# ---------------------------
print("Initializing trainer...")
trainer = RewardTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # max_length matches the 4096-token limit used for the tokenizer and for
    # row tokenization; it was previously 409, which looks like a typo. The
    # value should only take effect if the collator pads to a fixed length.
    data_collator=RewardDataCollatorWithPadding(tokenizer=tokenizer, max_length=4096),
)

print("Starting training...")
try:
    trainer.train()
    # Publish the trained model, then the tokenizer, to the configured hub repo.
    trainer.push_to_hub()
    tokenizer.push_to_hub(training_args.hub_model_id)
finally:
    cleanup_tmp_files()  # Clean up temporary files even if training fails