import wandb
import gc
import torch
from transformers import AutoTokenizer
import argparse
import sys
sys.path.append("~pythia_replicate")
from lib.model_setup import load_local_checkpoint
from lib.fv import fv_icl_tasks_benchmark

def fv_tasks_benchmark(
    checkpoint_dir,
    tokenizer,
    step,
    task_names=None,
    num_of_shots=5,
    max_sample_size=5000,
    batch_size=64,
):
    """Benchmark one checkpoint on the given ICL tasks and log results to wandb.

    The checkpoint is loaded with ``load_local_checkpoint``; each task is run
    through ``fv_icl_tasks_benchmark`` and its result is logged under the key
    ``"<short_name>_accuracy(num_of_shots=<k>)"`` at wandb step ``step``.

    Args:
        checkpoint_dir: Location of the checkpoint, passed to
            ``load_local_checkpoint``.
        tokenizer: Tokenizer forwarded unchanged to ``fv_icl_tasks_benchmark``.
        step: Training step of the checkpoint; used as the wandb ``step``.
        task_names: Iterable of task identifiers such as
            ``"extractive/choose_first_of_3"``. ``None`` or an empty
            collection means "nothing to evaluate".
        num_of_shots: Forwarded as ``num_of_shots`` and embedded in the
            logged metric name.
        max_sample_size: Forwarded to the benchmark as ``max_samples``.
        batch_size: Forwarded to the benchmark as ``batch_size``.

    Returns:
        None. Results are only reported through ``wandb.log``.
    """
    # Nothing to evaluate: return before paying for a checkpoint load.
    # (Previously ``task_names=None`` crashed with a TypeError after loading.)
    if not task_names:
        return

    model = load_local_checkpoint(checkpoint_dir)
    try:
        for task_name in task_names:
            icl_results = fv_icl_tasks_benchmark(
                model,
                tokenizer,
                task_name=task_name,
                num_of_shots=num_of_shots,
                max_samples=max_sample_size,
                batch_size=batch_size,
            )
            # The result is looked up by the last path component of the task
            # name. NOTE(review): assumes fv_icl_tasks_benchmark keys its
            # result mapping that way -- confirm against lib.fv.
            short_name = task_name.split("/")[-1]
            metric_key = f"{short_name}_accuracy(num_of_shots={num_of_shots})"
            wandb.log({metric_key: icl_results[short_name]}, step=step)
    finally:
        # Drop our reference and clear caches even when a task raises, so a
        # caller that catches the error and moves on to the next checkpoint
        # does not keep this model around through this function.
        del model
        gc.collect()
        torch.cuda.empty_cache()


def parse_args(argv=None):
    """Parse and validate the command-line options for the checkpoint sweep.

    Args:
        argv: Argument list to parse. ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.

    Exits through ``parser.error`` (status 2) when no task names are given or
    when ``--step_interval`` is not positive, so a misconfigured run stops
    before any wandb run is created or any checkpoint is loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        type=str,
        required=True,
        help="Model name; selects hf_output/<model_type> and names the wandb run",
    )
    parser.add_argument(
        "--last_step",
        type=int,
        default=20000,
        help="Stop before this step; the bound is exclusive (default: 20000)",
    )
    parser.add_argument(
        "--first_step",
        type=int,
        default=100,
        help="First checkpoint step to evaluate (default: 100)",
    )
    parser.add_argument(
        "--step_interval",
        type=int,
        default=100,
        help="Stride between evaluated checkpoint steps (default: 100)",
    )
    parser.add_argument(
        "--task_names",
        nargs="*",
        default=None,
        help="List of task names (e.g., --task_names extractive/choose_first_of_3)",
    )
    parser.add_argument(
        "--wandb_id",
        type=str,
        default=None,
        help="Wandb ID to use for logging (default: None)",
    )
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="Resume from a previous run (default: None)",
    )
    parser.add_argument(
        "--num_of_shots",
        type=int,
        default=5,
        help="Number of examples in few-shot prompts (default: 5)",
    )
    parser.add_argument(
        "--max_sample_size",
        type=int,
        default=5000,
        help="Sample cap forwarded to each task benchmark (default: 5000)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        help="Batch size for processing (default: 64)",
    )

    args = parser.parse_args(argv)

    # Without task names the sweep would fail (None) or load every checkpoint
    # for nothing (empty list); reject both up front.
    if not args.task_names:
        parser.error("--task_names requires at least one task name")
    # range() rejects a zero stride and yields nothing for a negative one.
    if args.step_interval < 1:
        parser.error("--step_interval must be a positive integer")

    return args


def main():
    """Evaluate every selected checkpoint of one model and log to wandb."""
    args = parse_args()

    # A single tokenizer is shared by all checkpoints of the sweep.
    tokenizer_name = "EleutherAI/pythia-160m"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.model_max_length = 2048
    if tokenizer.pad_token is None:
        # Fall back to EOS when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    wandb.init(
        project="pythia_replicate_all_benchmark",
        id=args.wandb_id,
        resume=args.resume,
        name=args.model_type,
    )
    wandb.config.update(vars(args))

    # ``last_step`` itself is not visited: range() excludes its stop value.
    for step in range(args.first_step, args.last_step, args.step_interval):
        checkpoint_dir = (
            f"~pythia_replicate/hf_output/{args.model_type}/step={step}"
        )

        fv_tasks_benchmark(
            checkpoint_dir,
            tokenizer,
            step,
            task_names=args.task_names,
            num_of_shots=args.num_of_shots,
            max_sample_size=args.max_sample_size,
            batch_size=args.batch_size,
        )


if __name__ == "__main__":
    main()
