import argparse
import os
from pathlib import Path
import torch
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from tqdm import tqdm
import time


def log(msg):
    """Print *msg* prefixed by a ``[YYYY-mm-dd HH:MM:SS]`` local timestamp.

    Output is flushed immediately so progress is visible when stdout is piped
    (e.g. into a job log).
    """
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    print(f"[{stamp}] {msg}", flush=True)


def load_model_and_tokenizer(base_model_path, checkpoints_root_dir, checkpoint):
    """Load the tokenizer and an eval-mode causal LM for one training checkpoint.

    Args:
        base_model_path: path or hub id of the base model (also used for the tokenizer).
        checkpoints_root_dir: directory containing ``checkpoint-<N>`` LoRA adapter dirs.
        checkpoint: checkpoint number; ``0`` selects the plain base model, any other
            value loads ``<checkpoints_root_dir>/checkpoint-<checkpoint>`` as a LoRA
            adapter and merges it into the base weights.

    Returns:
        ``(tokenizer, model)``. The tokenizer pads on the left and reuses its EOS
        token as the pad token; the model is loaded in float16 with
        ``device_map="auto"`` and put in eval mode.
    """
    log(f"Loading tokenizer and model for checkpoint {checkpoint}...")
    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
    tokenizer.padding_side = 'left'
    tokenizer.pad_token = tokenizer.eos_token

    # Options shared by both loading paths.
    common_kwargs = dict(torch_dtype=torch.float16, device_map="auto", trust_remote_code=True)

    if checkpoint != 0:
        adapter_dir = Path(checkpoints_root_dir) / f"checkpoint-{checkpoint}"
        log(f"Loading LoRA adapter from {adapter_dir}")
        base = AutoModelForCausalLM.from_pretrained(base_model_path, return_dict=True, **common_kwargs)
        # Fold the adapter into the base weights so inference runs on a plain model.
        model = PeftModel.from_pretrained(base, adapter_dir, adapter_name="default").merge_and_unload()
    else:
        model = AutoModelForCausalLM.from_pretrained(base_model_path, **common_kwargs)

    model.eval()
    log("Model loaded and set to eval mode.")
    return tokenizer, model



def extract_prompts_winners_losers(batch, dataset_name):
    """Split a batched slice of a preference dataset into (prompts, winners, losers).

    Args:
        batch: a column-oriented batch (dict of lists), as returned by slicing a
            ``datasets.Dataset`` (``dataset[i:j]``).
        dataset_name: hub id of the dataset; selects how columns are interpreted.

    Returns:
        Three aligned lists of strings: prompts, chosen responses, rejected responses.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported datasets.
    """

    def first_user_content(dialogue):
        # The prompt is the first message whose role is 'user'. Indexing [0] would
        # return a system message when one precedes the user turn. If no message is
        # tagged 'user', fall back to the first message (the previous behaviour).
        for message in dialogue:
            if message.get('role') == 'user':
                return message['content']
        return dialogue[0]['content']

    if dataset_name.startswith("trl-lib/ultrafeedback_binarized"):
        # 'chosen' and 'rejected' each hold, per example, a list of chat messages
        # (dicts with 'role' and 'content'). The final message is the assistant's
        # response; both columns share the same prompt, so take it from 'chosen'.
        prompts = [first_user_content(dialogue) for dialogue in batch['chosen']]
        winners = [dialogue[-1]['content'] for dialogue in batch['chosen']]
        losers = [dialogue[-1]['content'] for dialogue in batch['rejected']]
    elif dataset_name.startswith("trl-lib/tldr-preference"):
        # Plain-text columns: 'prompt', 'chosen', 'rejected'.
        prompts = batch['prompt']
        winners = batch['chosen']
        losers = batch['rejected']
    else:
        raise ValueError(f"Unsupported dataset {dataset_name}")
    return prompts, winners, losers


def compute_logprobs_batch(model, tokenizer, prompts, responses):
    """Return the summed log-probability of each response conditioned on its prompt.

    Each example is scored as the text ``prompt + response + eos_token``. Only tokens
    after the prompt contribute to the sum: the response tokens plus the trailing EOS.

    Args:
        model: causal LM whose forward returns an object with ``.logits``; assumed
            shape (batch, seq_len, vocab), as for HF causal LMs.
        tokenizer: tokenizer used to build the batch. Padding may be on either side;
            real tokens are located through the attention mask.
        prompts: list of prompt strings.
        responses: list of response strings, aligned with ``prompts``.

    Returns:
        list[float]: one summed log-probability per (prompt, response) pair.
    """
    input_texts = [p + r + tokenizer.eos_token for p, r in zip(prompts, responses)]
    enc = tokenizer(input_texts, return_tensors="pt", padding=True).to(model.device)
    input_ids = enc["input_ids"]
    attention_mask = enc.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    # NOTE(review): with left padding, models that derive positions from arange
    # rather than from attention_mask see shifted positions for shorter rows.
    # Confirm this is acceptable for the model family in use.
    with torch.no_grad():
        logits = model(**enc).logits

    results = []
    for j, prompt in enumerate(prompts):
        # Locate the real (non-pad) tokens of row j. Padding can precede them
        # (left padding) or follow them (right padding), so the sequence does not
        # necessarily start at index 0.
        real_positions = attention_mask[j].nonzero(as_tuple=True)[0]
        first = real_positions[0].item()
        last = real_positions[-1].item()

        # Assumes tokenizing the prompt alone yields a prefix of the tokenized full
        # text (special tokens such as BOS are added to both). A token merge across
        # the prompt/response boundary would shift the split by one token.
        prompt_len = len(tokenizer(prompt)["input_ids"])
        # The very first real token has no context to be predicted from.
        start = first + max(prompt_len, 1)

        target_ids = input_ids[j, start:last + 1]
        assert target_ids.numel() > 0, (
            f"No response tokens for example {j}: prompt covers the whole sequence"
        )

        # Logits at position t predict token t + 1, hence the one-step shift.
        # Upcast to float32 before log_softmax and summation; fp16 sums over
        # hundreds of tokens are coarse. Slicing per example also avoids
        # materializing a full batch x seq x vocab log-softmax tensor.
        target_logprobs = torch.nn.functional.log_softmax(
            logits[j, start - 1:last, :].float(), dim=-1
        )
        token_logprobs = target_logprobs.gather(1, target_ids.unsqueeze(-1)).squeeze(-1)
        results.append(token_logprobs.sum().item())

    return results


def save_checkpoint_csv(output_dir, checkpoint, winner_logprobs, loser_logprobs):
    """Write one checkpoint's winner and loser log-probabilities to two CSV files.

    The files are ``winner_logprobs_<checkpoint>.csv`` and
    ``loser_logprobs_<checkpoint>.csv`` inside *output_dir* (a ``Path``). Each file
    holds a single column named ``str(checkpoint)``, written with pandas' default
    integer index.
    """
    column = str(checkpoint)
    written = {}
    for role, values in (("winner", winner_logprobs), ("loser", loser_logprobs)):
        target = output_dir / f"{role}_logprobs_{checkpoint}.csv"
        pd.DataFrame({column: values}).to_csv(target)
        written[role] = target

    log(f"Saved results to {written['winner']} and {written['loser']}")


def main():
    """CLI entry point: score chosen/rejected responses of a preference dataset.

    Loads the base model (optionally with a merged LoRA checkpoint), computes the
    summed response log-probabilities for every example, and writes
    ``winner_logprobs_<checkpoint>.csv`` / ``loser_logprobs_<checkpoint>.csv`` under
    ``<output-dir>/<name of checkpoints-root-dir>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoints-root-dir", type=str, required=True)
    parser.add_argument("--checkpoint", type=int, required=True)
    parser.add_argument("--base-model", type=str, required=True)
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--dataset-name", type=str, default="trl-lib/tldr-preference")
    parser.add_argument("--output-dir", type=str, default="logprobs")
    parser.add_argument("--subset-size", type=int, default=None, help="Optional number of samples to subsample with seed 0")
    parser.add_argument("--batch-size", type=int, default=8, help="Batch size for processing")
    args = parser.parse_args()

    # Reject invalid sizes up front instead of failing after the model is loaded.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    if args.subset_size is not None and args.subset_size < 0:
        parser.error("--subset-size must be non-negative")

    log("Starting logprob computation script...")
    checkpoint_root_name = Path(args.checkpoints_root_dir).resolve().name
    full_output_dir = Path(args.output_dir) / checkpoint_root_name
    full_output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer, model = load_model_and_tokenizer(args.base_model, args.checkpoints_root_dir, args.checkpoint)
    dataset = load_dataset(args.dataset_name, split=args.split)
    log(f"Loaded dataset split '{args.split}' with {len(dataset)} examples")

    if args.subset_size is not None:
        # select() raises on out-of-range indices, so never ask for more rows than exist.
        n_keep = min(args.subset_size, len(dataset))
        if n_keep < args.subset_size:
            log(f"Requested subset of {args.subset_size} exceeds dataset size; using all {n_keep} examples")
        dataset = dataset.shuffle(seed=0).select(range(n_keep))
        log(f"Subsampled to {n_keep} examples")

    winner_logprobs = []
    loser_logprobs = []

    for i in tqdm(range(0, len(dataset), args.batch_size), desc="Computing logprobs in batches"):
        batch = dataset[i: i + args.batch_size]

        prompts, winners, losers = extract_prompts_winners_losers(batch, args.dataset_name)

        winner_logprobs.extend(compute_logprobs_batch(model, tokenizer, prompts, winners))
        loser_logprobs.extend(compute_logprobs_batch(model, tokenizer, prompts, losers))

        processed = i + len(prompts)
        # Log once each time the running count crosses a multiple of 1000
        # (works for any batch size, including ones larger than 1000).
        if processed // 1000 > i // 1000:
            log(f"Processed {processed} examples...")

    save_checkpoint_csv(full_output_dir, args.checkpoint, winner_logprobs, loser_logprobs)
    log("All done.")


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
