import wandb
import gc
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import argparse
from datasets import load_dataset
from tqdm import tqdm
import pandas as pd
import json
import os
import glob
import random
import torch.nn.functional as F
import lm_eval
import random
from lm_eval.models.huggingface import HFLM


def load_local_checkpoint(checkpoint_dir, torch_dtype=torch.float32, device_map="auto"):
    """Load a causal language model from a local checkpoint directory.

    Args:
        checkpoint_dir: Directory passed to
            ``AutoModelForCausalLM.from_pretrained``.
        torch_dtype: dtype the weights are loaded in (default ``torch.float32``).
        device_map: Device placement forwarded to ``from_pretrained``
            (default ``"auto"``).

    Returns:
        The model returned by ``AutoModelForCausalLM.from_pretrained``.
    """
    # from_pretrained restores the saved weights, not merely the config.
    print(f"Loading model from checkpoint {checkpoint_dir}...")
    return AutoModelForCausalLM.from_pretrained(
        checkpoint_dir, device_map=device_map, torch_dtype=torch_dtype
    )


# Tasks evaluated by ``external_benchmark`` when no explicit list is given.
_DEFAULT_BENCHMARK_TASKS = (
    "arc_easy",
    "lambada_openai",
    "blimp_regular_plural_subject_verb_agreement_1",
    "blimp_principle_A_domain_3",
    "blimp_distractor_agreement_relative_clause",
    "blimp_determiner_noun_agreement_with_adjective_1",
    "blimp_matrix_question_npi_licensor_present",
)

# Tasks whose wandb key uses a short alias plus the few-shot count; every other
# task is logged as ``f"{task}_accuracy"``.
_SHOT_TAGGED_METRIC_ALIASES = {
    "arc_easy": "arc",
    "lambada_openai": "lambada",
}


def _benchmark_metric_name(task, num_of_shots):
    """Return the wandb key under which *task*'s accuracy is logged."""
    alias = _SHOT_TAGGED_METRIC_ALIASES.get(task)
    if alias is not None:
        return f"{alias}_accuracy(num_of_shots={num_of_shots})"
    # NOTE(review): these keys omit num_of_shots, so runs with different shot
    # counts share the same metric name. Kept as-is for dashboard continuity.
    return f"{task}_accuracy"


def external_benchmark(
    checkpoint_dir,
    tokenizer_name,
    device="cuda:0",
    batch_size=64,
    step=None,
    num_of_shots=0,
    tasks=None,
):
    """Evaluate one checkpoint with lm-eval-harness and log accuracies to wandb.

    Args:
        checkpoint_dir: Value passed to ``HFLM(pretrained=...)``.
        tokenizer_name: Value passed to ``HFLM(tokenizer=...)``.
        device: Device string forwarded to ``HFLM``.
        batch_size: Evaluation batch size forwarded to ``HFLM``.
        step: wandb step the metrics are logged at (``None`` lets wandb pick).
        num_of_shots: Number of few-shot examples (``num_fewshot``).
        tasks: Optional iterable of lm-eval task names; defaults to
            ``_DEFAULT_BENCHMARK_TASKS``.

    Errors raised while evaluating or logging are printed and swallowed so a
    sweep over many checkpoints keeps going; errors constructing ``HFLM``
    propagate to the caller. The model is released and the CUDA cache is
    emptied whether or not evaluation succeeded.
    """
    tasks = list(_DEFAULT_BENCHMARK_TASKS if tasks is None else tasks)

    model = HFLM(
        pretrained=checkpoint_dir,
        tokenizer=tokenizer_name,
        device=device,
        batch_size=batch_size,
        trust_remote_code=True,
    )

    try:
        results = lm_eval.simple_evaluate(
            model=model,
            tasks=tasks,
            num_fewshot=num_of_shots,
        )

        metrics = {}
        for task in tasks:
            accuracy = results["results"].get(task, {}).get("acc,none")
            if accuracy is None:
                # Report and skip, rather than losing every remaining task.
                print(
                    f"Error in external benchmark: no 'acc,none' result for task {task!r}"
                )
                continue
            metrics[_benchmark_metric_name(task, num_of_shots)] = accuracy

        # A single call keeps all metrics on the same step: each wandb.log call
        # advances wandb's step counter when step is None.
        if metrics:
            wandb.log(metrics, step=step)
    except Exception as e:
        print(f"Error in external benchmark: {e}")
    finally:
        # Always free the model and cached GPU memory, including after a
        # failed evaluation, so a long checkpoint sweep does not hold on to it.
        del model
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, required=True)
    parser.add_argument(
        "--last_step",
        type=int,
        default=20000,
        help="Last checkpoint step to evaluate, inclusive (default: 20000)",
    )
    parser.add_argument(
        "--first_step",
        type=int,
        default=100,
        help="First checkpoint step to evaluate (default: 100)",
    )
    parser.add_argument(
        "--step_interval",
        type=int,
        default=100,
        help="Stride between evaluated checkpoint steps (default: 100)",
    )

    parser.add_argument(
        "--wandb_id",
        type=str,
        default=None,
        help="Wandb ID to use for logging (default: None)",
    )
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="Resume from a previous run (default: None)",
    )
    parser.add_argument(
        "--num_of_shots",
        type=int,
        default=0,
        help="Number of examples in few-shot prompts (default: 0)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        help="Device to use for evaluation (default: cuda:0)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Evaluation batch size (default: 16)",
    )

    args = parser.parse_args()
    if args.step_interval <= 0:
        parser.error("--step_interval must be a positive integer")

    # HFLM (inside external_benchmark) loads the tokenizer itself from this name.
    tokenizer_name = "EleutherAI/pythia-160m"

    wandb.init(
        project="pythia_replicate_all_benchmark",
        id=args.wandb_id,
        resume=args.resume,
        name=args.model_type,
    )

    wandb.config.update(vars(args))

    # +1 so that --last_step itself is evaluated.
    for step in range(args.first_step, args.last_step + 1, args.step_interval):
        # NOTE(review): the leading "~" is not expanded here; confirm this
        # path resolves to the intended checkpoint directory.
        checkpoint_dir = (
            f"~pythia_replicate/hf_output/{args.model_type}/step={step}"
        )

        external_benchmark(
            checkpoint_dir,
            tokenizer_name,
            device=args.device,
            batch_size=args.batch_size,
            step=step,
            num_of_shots=args.num_of_shots,
        )

    wandb.finish()
