import argparse
import gc
import os

import torch
import wandb
from transformers import AutoTokenizer, AutoModelForCausalLM

from lib.repetition import random_sequence_repetition_accuracy, natural_text_repetition_accuracy

def load_local_checkpoint(checkpoint_dir):
    """Load a causal language model from a local checkpoint directory.

    Args:
        checkpoint_dir: Path to a directory in Hugging Face ``from_pretrained``
            format. A leading ``~`` / ``~user`` is expanded before loading,
            because ``from_pretrained`` receives the string as given.

    Returns:
        The model returned by ``AutoModelForCausalLM.from_pretrained``, loaded
        with ``device_map="auto"`` and ``torch.float32`` weights.
    """
    # expanduser leaves the path untouched when there is nothing to expand.
    checkpoint_dir = os.path.expanduser(checkpoint_dir)
    print(f"Loading model from checkpoint: {checkpoint_dir}")
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_dir, device_map="auto", torch_dtype=torch.float32
    )
    return model

def icl_random_one_skip_bigram_repetition_benchmark(
    model,
    tokenizer,
    num_of_samples=5000,
    seq_len=50,
    batch_size=64,
):
    """Score last-token copying on repeated, space-interleaved random sequences.

    For every sample, ``seq_len`` distinct token ids are drawn from
    ``[1, vocab_size)`` and interleaved with the tokenizer's space token, giving
    ``t1 <space> t2 <space> ... tN`` (length ``2 * seq_len - 1``). The sequence
    is concatenated with itself; the model is fed everything except the final
    token and its greedy (argmax) prediction at the last position is compared
    with that final token.

    Args:
        model: Causal LM called as ``model(input_ids=...)``; must expose
            ``.eval()``, ``.device`` and return an object with ``.logits``.
            It is switched to eval mode as a side effect.
        tokenizer: Tokenizer providing ``vocab_size`` and ``encode``.
        num_of_samples: Number of random sequences to evaluate.
        seq_len: Number of random tokens per sequence (before interleaving).
        batch_size: Number of sequences per forward pass.

    Returns:
        Fraction of samples whose final token was predicted correctly, as a
        float. ``0.0`` when ``num_of_samples`` is not positive.

    Raises:
        ValueError: If ``seq_len`` is outside ``[1, vocab_size - 1]`` or if a
            single space does not encode to exactly one token.
    """
    model.eval()

    # Nothing to evaluate: torch.stack() on an empty list would raise.
    if num_of_samples <= 0:
        return 0.0

    vocab_size = tokenizer.vocab_size
    if not 1 <= seq_len <= vocab_size - 1:
        raise ValueError(
            f"seq_len must be between 1 and vocab_size - 1 ({vocab_size - 1}), got {seq_len}"
        )

    # Special tokens (e.g. BOS) are excluded so that only the space itself is
    # counted when checking that it maps to a single token.
    space_tokens = tokenizer.encode(" ", add_special_tokens=False)
    if len(space_tokens) != 1:
        raise ValueError("Space token is not a single token")
    space_id = space_tokens[0]

    # Distinct ids within a sample; the +1 shift keeps id 0 out of the draw.
    # A drawn id may still coincide with the space id.
    random_tokens = torch.stack(
        [torch.randperm(vocab_size - 1)[:seq_len] + 1 for _ in range(num_of_samples)]
    )

    # Even positions hold the random tokens, odd positions hold the space id.
    interleaved = torch.full(
        (num_of_samples, 2 * seq_len - 1), space_id, dtype=torch.long
    )
    interleaved[:, ::2] = random_tokens
    repeated = torch.cat([interleaved, interleaved], dim=1)

    # The model sees everything but the final token and must predict it.
    all_inputs = repeated[:, :-1]
    all_targets = repeated[:, -1]

    correct_predictions = 0
    with torch.no_grad():
        for start in range(0, num_of_samples, batch_size):
            stop = start + batch_size  # slicing clamps the final partial batch
            input_ids = all_inputs[start:stop].to(model.device)
            target_ids = all_targets[start:stop].to(model.device)

            last_logits = model(input_ids=input_ids).logits[:, -1, :]
            predicted_ids = torch.argmax(last_logits, dim=-1)
            correct_predictions += (predicted_ids == target_ids).sum().item()

    return correct_predictions / num_of_samples

def repetition_tasks_benchmark(
    checkpoint_dir,
    tokenizer,
    step,
    random_repetition_seq_len=100,
    natural_text_repetition_seq_len=100,
    max_sample_size=5000,
    batch_size=64,
):
    """Run the three repetition benchmarks on one checkpoint and log them to wandb.

    The checkpoint is loaded with ``load_local_checkpoint``, evaluated with
    ``random_sequence_repetition_accuracy``, ``natural_text_repetition_accuracy``
    and ``icl_random_one_skip_bigram_repetition_benchmark``, and the three
    accuracies are logged at ``step``. The model reference is dropped and the
    CUDA cache emptied afterwards, even when a benchmark raises.

    Args:
        checkpoint_dir: Directory of the checkpoint to evaluate.
        tokenizer: Tokenizer passed to every benchmark.
        step: Training step used as the wandb ``step`` for the logged metrics.
        random_repetition_seq_len: Sequence length for the random-sequence and
            one-skip-bigram benchmarks.
        natural_text_repetition_seq_len: Sequence length for the natural-text
            benchmark.
        max_sample_size: Number of samples passed to each benchmark.
        batch_size: Batch size passed to each benchmark.

    Returns:
        None. Results are reported through ``wandb.log`` only.
    """
    # Only used to label the natural-text metric; it is not passed to the helper.
    # NOTE(review): presumably matches the dataset the helper reads - confirm.
    dataset_name = "gutenberg"

    model = load_local_checkpoint(checkpoint_dir)
    try:
        random_repetition_acc = random_sequence_repetition_accuracy(
            model,
            tokenizer,
            seq_len=random_repetition_seq_len,
            num_of_samples=max_sample_size,
            batch_size=batch_size,
        )
        natural_text_repetition_acc = natural_text_repetition_accuracy(
            model,
            tokenizer,
            seq_len=natural_text_repetition_seq_len,
            num_of_samples=max_sample_size,
            batch_size=batch_size,
        )
        one_skip_bigram_repetition_acc = icl_random_one_skip_bigram_repetition_benchmark(
            model,
            tokenizer,
            seq_len=random_repetition_seq_len,
            num_of_samples=max_sample_size,
            batch_size=batch_size,
        )

        # All metrics belong to the same step, so one log call is sufficient.
        wandb.log(
            {
                f"random_repetition_accuracy(seq_len={random_repetition_seq_len})": random_repetition_acc,
                f"{dataset_name}_repetition_accuracy(seq_len={natural_text_repetition_seq_len})": natural_text_repetition_acc,
                f"one_skip_bigram_repetition_accuracy(seq_len={random_repetition_seq_len})": one_skip_bigram_repetition_acc,
            },
            step=step,
        )
    finally:
        # Release the checkpoint before the caller loads the next one.
        del model
        gc.collect()
        torch.cuda.empty_cache()


def _parse_args():
    """Parse the command-line options of the checkpoint benchmark sweep."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, required=True)
    parser.add_argument("--last_step", type=int, default=20000)
    parser.add_argument("--first_step", type=int, default=100)
    parser.add_argument(
        "--wandb_id",
        type=str,
        default=None,
        help="Wandb ID to use for logging (default: %(default)s)",
    )
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="Resume from a previous run (default: %(default)s)",
    )

    # accuracy sometimes higher when sequence length shorter
    parser.add_argument(
        "--random_repetition_seq_len",
        type=int,
        default=50,
        help="Sequence length for random repetition benchmark (default: %(default)s)",
    )
    parser.add_argument(
        "--natural_text_repetition_seq_len",
        type=int,
        default=50,
        help="Sequence length for natural text repetition benchmark (default: %(default)s)",
    )
    parser.add_argument(
        "--max_sample_size",
        type=int,
        default=5000,
        help="Sample size for the repetition benchmarks (default: %(default)s)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        help="Batch size for the repetition benchmarks (default: %(default)s)",
    )
    return parser.parse_args()


def main():
    """Evaluate every checkpoint in the requested step range and log to wandb."""
    args = _parse_args()

    tokenizer_name = "EleutherAI/pythia-160m"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.model_max_length = 2048
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    wandb.init(
        project="pythia_replicate_all_benchmark",
        id=args.wandb_id,
        resume=args.resume,
        name=args.model_type,
    )
    wandb.config.update(vars(args))

    # Checkpoints are visited every 100 steps.
    # NOTE(review): range() excludes its end, so the checkpoint at --last_step
    # itself is never evaluated - confirm this is intended.
    step_interval = 100
    for step in range(args.first_step, args.last_step, step_interval):
        # NOTE(review): the path starts with "~pythia_replicate" and is passed on
        # verbatim - confirm it resolves to the intended directory.
        checkpoint_dir = (
            f"~pythia_replicate/hf_output/{args.model_type}/step={step}"
        )

        repetition_tasks_benchmark(
            checkpoint_dir,
            tokenizer,
            step,
            args.random_repetition_seq_len,
            args.natural_text_repetition_seq_len,
            args.max_sample_size,
            args.batch_size,
        )


if __name__ == "__main__":
    main()
