from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from datasets import load_dataset
import argparse
from tqdm import tqdm
import numpy as np
import wandb
def calculate_log_loss_batch(model, tokenizer, batch_questions, batch_choices, append_options=False):
    """
    Score every answer choice of every question with the model's summed log loss.

    Each question is turned into the prompt ``question[\\n + choices]\\nAnswer: ``
    and every choice is appended to it as a separate sequence. The prompt
    tokens are masked with -100 so only the choice tokens contribute to the
    loss.

    Args:
        model: Causal LM; called as ``model(input_ids=...)`` and expected to
            return an object with a ``logits`` attribute. ``model.config`` is
            read for ``max_position_embeddings``.
        tokenizer: Callable tokenizer returning ``{"input_ids": tensor}``.
            If it has no pad token id, it is set to the EOS token id
            (this mutates the tokenizer).
        batch_questions: Sequence of question strings.
        batch_choices: Sequence of lists of choice strings, aligned with
            ``batch_questions``.
        append_options: If True, the choices are listed in the prompt,
            one per line, before ``Answer: ``.

    Returns:
        ``(log_losses, token_counts)``: two flat lists with one entry per
        (question, choice) pair, ordered question by question and choice by
        choice. ``log_losses`` holds the summed cross-entropy over the scored
        choice tokens, ``token_counts`` the number of scored tokens.
        Both are empty if there is nothing to score.
    """
    # Send inputs to wherever the model actually lives; fall back to the
    # "CUDA if available" rule when the model exposes no parameters.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    max_len = getattr(model.config, "max_position_embeddings", None)

    input_ids = []
    labels = []
    for question, choices in zip(batch_questions, batch_choices):
        full_question = question
        if append_options:
            full_question += "\n" + "\n".join(choices)
        full_question += "\nAnswer: "
        question_tokens = tokenizer(
            full_question, return_tensors="pt", truncation=True, add_special_tokens=False
        )["input_ids"][0]
        masked_question = torch.full_like(question_tokens, -100)

        for choice in choices:
            choice_tokens = tokenizer(
                choice, return_tensors="pt", truncation=True, add_special_tokens=False
            )["input_ids"][0]
            input_id = torch.cat((question_tokens, choice_tokens))
            label = torch.cat((masked_question, choice_tokens))
            if max_len:
                # Truncate from the LEFT: the choice tokens sit at the end of
                # the sequence, so cutting on the right would drop exactly the
                # tokens being scored (and yield a token count of zero).
                input_id = input_id[-max_len:]
                label = label[-max_len:]
            input_ids.append(input_id)
            labels.append(label)

    if not input_ids:
        # pad_sequence cannot handle an empty list of sequences.
        return [], []

    # Pad and collate inputs and labels for batching
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
    ).to(device)
    # Padding positions get -100 so they are excluded from the loss as well.
    labels = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=-100
    ).to(device)

    with torch.no_grad():
        # Only the logits are needed; the per-token loss is computed below,
        # so the labels are not passed to the forward call.
        logits = model(input_ids=input_ids).logits
        # Position t predicts token t + 1.
        shifted_logits = logits[:, :-1, :].contiguous()
        shifted_labels = labels[:, 1:].contiguous()
        # Targets equal to -100 match cross_entropy's default ignore_index
        # and therefore contribute a loss of zero.
        loss_per_token = torch.nn.functional.cross_entropy(
            shifted_logits.view(-1, shifted_logits.size(-1)),
            shifted_labels.view(-1),
            reduction='none'
        ).view(shifted_labels.size())  # Shape: (num_sequences, seq_len - 1)

    # One entry per sequence, in the order the sequences were built, so the
    # result is correct even if questions have different numbers of choices.
    log_losses = loss_per_token.sum(dim=1).tolist()
    token_counts = (shifted_labels != -100).sum(dim=1).tolist()

    return log_losses, token_counts

@torch.no_grad()
def evaluate_perplexity_batch(dataset, model_name, tokenizer_name, batch_size=8, append_options=False):
    """
    Evaluate multiple-choice accuracy by comparing per-token log loss of each choice.

    The model and tokenizer are loaded from ``model_name`` / ``tokenizer_name``,
    the dataset is scored in batches of ``batch_size`` questions via
    ``calculate_log_loss_batch``, and for each question the choice with the
    lowest average loss per token is taken as the prediction. The mean
    accuracy is logged to wandb under the key ``"accuracy"``.

    Args:
        dataset: Sized iterable of rows with the keys ``"question"``,
            ``"choices"`` and ``"answer"`` (index of the correct choice).
        model_name: Name or path passed to ``AutoModelForCausalLM.from_pretrained``.
        tokenizer_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
        batch_size: Number of questions scored per forward pass.
        append_options: Forwarded to ``calculate_log_loss_batch``.

    Returns:
        A list with one dict per question containing the question, its
        choices, the answer index, the per-choice summed losses, token counts
        and average losses, the loss / token count of the correct choice, and
        ``"pred"`` (whether the lowest-loss choice is the correct one).
        All values are plain Python types so the list can be dumped as JSON.
    """
    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    model.eval()

    # Move model to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    results = []
    batch_questions = []
    batch_choices = []
    batch_answers = []
    last_index = len(dataset) - 1
    # Wrapping the dataset (not the enumerate object) lets tqdm show a total.
    for idx, row in enumerate(tqdm(dataset)):
        batch_questions.append(row["question"])
        batch_choices.append(row["choices"])
        batch_answers.append(row["answer"])

        # Flush when the batch is full or the dataset is exhausted
        if len(batch_questions) == batch_size or idx == last_index:
            log_losses, token_counts = calculate_log_loss_batch(
                model, tokenizer, batch_questions, batch_choices, append_options
            )

            # The scores come back as one flat list; walk it with a running
            # offset so questions with different numbers of choices still
            # pick up their own entries.
            offset = 0
            for question, choices, answer_index in zip(batch_questions, batch_choices, batch_answers):
                end = offset + len(choices)
                curr_log_losses = log_losses[offset:end]
                curr_token_counts = token_counts[offset:end]
                offset = end

                # Normalise by length so longer choices are not penalised
                # merely for having more tokens.
                curr_avg_losses = np.divide(curr_log_losses, curr_token_counts)
                is_correct = bool(np.argmin(curr_avg_losses) == answer_index)
                results.append({
                    "question": question,
                    "choices": choices,
                    "answer_index": answer_index,
                    "log_losses": curr_log_losses,
                    "token_counts": curr_token_counts,
                    # NumPy arrays / bools are not JSON serializable, so
                    # store native Python equivalents.
                    "avg_losses": curr_avg_losses.tolist(),
                    "correct_log_loss": curr_log_losses[answer_index],
                    "correct_token_count": curr_token_counts[answer_index],
                    "pred": is_correct
                })

            # Reset batch
            batch_questions = []
            batch_choices = []
            batch_answers = []

    avg_pred = np.mean([result["pred"] for result in results])
    wandb.log({"accuracy": avg_pred})
    return results

if __name__ == "__main__":
    import json
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="gpt2")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument("--tasks", type=str, default="gpqa-all")
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--append_options", action="store_true")
    parser.add_argument("--wandb_project", type=str, default="benchmarking")
    parser.add_argument("--wandb_tags", type=str, default=None)
    args = parser.parse_args()

    # The tokenizer defaults to the one shipped with the model.
    if args.tokenizer is None:
        args.tokenizer = args.model_name_or_path

    # Task name -> (hub dataset path, config name or None, split to evaluate)
    task_sources = {
        "gpqa-all": ("zekeZZ/gpqa_all", 'gpqa-all', "train"),
        "wmdp-bio": ("cais/wmdp", 'wmdp-bio', "test"),
        "wmdp-chem": ("cais/wmdp", 'wmdp-chem', "test"),
        "wmdp-cyber": ("cais/wmdp", 'wmdp-cyber', "test"),
        "tofu": ("zekeZZ/tofu_wiki_qa_shuffled", None, "train"),
    }
    if args.tasks not in task_sources:
        # The option is called --tasks; there is no args.dataset attribute.
        raise ValueError(f"Dataset {args.tasks} not supported")
    dataset_path, dataset_config, dataset_split = task_sources[args.tasks]
    dataset = load_dataset(dataset_path, dataset_config)[dataset_split]

    config = {
        "model_name_or_path": args.model_name_or_path,
        "tasks": args.tasks,
        "batch_size": args.batch_size,
        "append_options": args.append_options,
    }
    # --wandb_tags is a comma-separated list; fall back to a default tag.
    if args.wandb_tags is None:
        tags = ['perp_mcq_eval']
    else:
        tags = args.wandb_tags.split(',')
    wandb.init(project=args.wandb_project, config=config, tags=tags)
    results = evaluate_perplexity_batch(
        dataset,
        args.model_name_or_path,
        args.tokenizer,
        batch_size=args.batch_size,
        append_options=args.append_options,
    )

    def _json_default(obj):
        """Convert NumPy arrays and scalars in the results to JSON-friendly values."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    save_dir = Path('results/perp_eval')
    save_dir.mkdir(parents=True, exist_ok=True)
    # Model ids may contain '/', which must not end up in the file name.
    model_name = args.model_name_or_path.replace('/', '_')
    # NOTE: despite the .jsonl suffix the file holds a single JSON array.
    with open(save_dir / f'{model_name}_{args.tasks}.jsonl', 'w', encoding='utf-8') as f:
        json.dump(results, f, default=_json_default)

    # Print results
    avg_pred = np.mean([result["pred"] for result in results])
    print(f"Average prediction accuracy: {avg_pred}")
