import argparse
import copy
import datetime
import json
import logging
import os
import shutil
import tempfile
import time
from collections import defaultdict
from typing import Optional
from pathlib import Path

from lm_eval.api.model import TemplateLM

from oe_eval.components.instances import RequestInstance
from oe_eval.default_configs import MODEL_DEFAULTS, TASK_DEFAULTS
from oe_eval.models.eleuther_huggingface import HFLM_Verbose
from oe_eval.models.eleuther_vllm_causallms import VLLM_Verbose
# from oe_eval.models.litellm import LiteLLM
from oe_eval.tasks.aggregate_tasks import add_aggregate_tasks

# from oe_eval.tasks import task_adapters
from oe_eval.tasks.base_task import Task
from oe_eval.tasks.chat_templates import CHAT_TEMPLATES
from oe_eval.tasks.eleuther_task import EleutherTask
from oe_eval.launch import resolve_task_suite
from oe_eval.configs.tasks import TASK_CONFIGS

# from oe_eval.tasks.eleuther_tasks import TASK_REGISTRY as ELEUTHER_TASK_REGISTRY
from oe_eval.tasks.oe_eval_tasks import TASK_REGISTRY
from oe_eval.utilities.gsheet_writing import write_metrics_to_gsheet
from oe_eval.utilities.hf_hub_writing import upload_to_hf
from oe_eval.utilities.model_results_collation import collate_results
from oe_eval.utilities.remote_utils import cache_s3_folder, upload_directory, download_file_from_s3
from oe_eval.utilities.wandb_writing import wandb_log_metrics
from oe_eval.utils import (
    get_dict_with_defaults,
    get_recorded_inputs,
    hash_dict,
    load_json,
    load_jsonl,
    make_int,
    parse_args_string,
    remove_nested_nones,
    remove_none_values,
    save_json,
    save_jsonl,
    show_model_input,
)

from src.config import ExperimentConfig
from modules.retriever.offline_retrieval import OfflineRetrieval

# To run "hf-oldstyle" OLMo models
try:
    from hf_olmo import *  # noqa: F403
except ImportError:
    pass

# For running e.g. XXXX
try:
    from open_lm.hf import *  # noqa: F403
except ImportError:
    pass

logger = logging.getLogger()


def add_arg(parser, arg, defaults, **kwargs):
    if defaults is None:
        default = None
    else:
        default = defaults.get(arg.replace("-", "_"))
    parser.add_argument(f"--{arg}", default=default, **kwargs)


_parser = argparse.ArgumentParser()
## Model configs
add_arg(_parser, "model", MODEL_DEFAULTS, type=str, help="Name of model in HF hub")
add_arg(_parser, "revision", MODEL_DEFAULTS, type=str, help="Revision of model in HF hub")
add_arg(_parser, "trust-remote-code", MODEL_DEFAULTS, type=bool, help="Trust remote code in HF hub")
add_arg(_parser, "max-length", MODEL_DEFAULTS, type=int, help="Max length of model input")
add_arg(_parser, "model-type", MODEL_DEFAULTS, type=str, help="Model type (e.g., 'hf' or 'vllm')")
add_arg(
    _parser,
    "model-path",
    MODEL_DEFAULTS,
    type=str,
    help="Override model path for local model loading",
)
_parser.add_argument(
    "--model-args",
    type=str,
    default=None,
    help="Dict with extra model args for model instantiation",
)

_parser.add_argument(
    "--retrieval_config", 
    type=Path,
    help="Path to attack config file", 
    default=None
)

_parser.add_argument(
    "--retrieval_results_path",
    type=str,
    help="Path to attack config file",
    default=""
)

_parser.add_argument(
    "--ctx_key", 
    type=str,
    help="retrieval text", 
    default="ctxs"
)

_parser.add_argument(
    "--matching_key", 
    type=str,
    help="retrieval text", 
    default="query"
)

_parser.add_argument(
    "--retrieval_text_key", 
    type=str,
    help="retrieval text", 
    default="retrieval text"
)

_parser.add_argument(
    "--k", 
    type=int,
    help="retrieval text", 
    default=3
)

_parser.add_argument(
    "--rerank_k", 
    type=int,
    help="retrieval text", 
    default=100
)

_parser.add_argument(
    "--presort_key", 
    type=str,
    help="Preliminary scoring key to first-stage sort retrieval contexts", 
    default=None
)

_parser.add_argument(
    "--sort_key", 
    type=str,
    help="Main scoring key to sort retrieval contexts", 
    default="retrieval score"
)

_parser.add_argument(
    "--threshold", 
    type=float,
    help="Determine which documents under specific score threshold (given by sort_key) to filter out", 
    default=None
)

_parser.add_argument(
    "--sources_to_keep", 
    type=str,
    help="Determine which documents from specified sources to keep", 
    default=None
)

_parser.add_argument(
    "--sources_to_filter", 
    type=str,
    help="Determine which documents from specified sources to filter out", 
    default=None
)

_parser.add_argument(
    "--llm_only", action="store_true", help="no retrieval"
)

_parser.add_argument(
    "--no_thinking", action="store_true", help="Disable thinking when applicable"
)

_parser.add_argument(
    "--task",
    type=str,
    nargs="+",
    required=True,
    help="Task spec(s), or name(s) as in the Task registry",
)
add_arg(
    _parser,
    "limit",
    None,
    type=float,
    help="Max number (or fraction) of instances for a task",
)
add_arg(_parser, "split", None, type=str, help="split from which to pull eval docs")
add_arg(
    _parser,
    "random-subsample-seed",
    None,
    type=int,
    help="seed for doc sampling when `limit` is smaller than docs size",
)
_parser.add_argument("--num-shots", type=int, default=None, help="Number of examples in prompt")
_parser.add_argument(
    "--fewshot-seed", type=int, default=None, help="seed for fewshot example sampling"
)
_parser.add_argument(
    "--batch-size",
    type=str,
    default=1,
    help="batch size of requests for model processing (can be 'auto')",
)
_parser.add_argument(
    "--max-batch-size", type=int, default=32, help="Max batch size of to try for auto batch-size"
)
_parser.add_argument("--output-dir", type=str, default=None, help="Directory for output files")
_parser.add_argument(
    "--cached-output-dir", type=str, default=None, help="Directory for cached output files"
)
_parser.add_argument(
    "--remote-output-dir", type=str, default=None, help="Remote directory for output"
)
_parser.add_argument(
    "--num-recorded-inputs",
    type=int,
    default=1000,
    help="Number of sample model inputs in full output, for sanity checks",
)
_parser.add_argument("--gsheet", type=str, help="Name of Google Sheet for writing results")
_parser.add_argument(
    "--hf-save-dir", type=str, help="HF dataset path for saving results: <repo-name>//<directory>"
)
_parser.add_argument(
    "--wandb-run-path",
    type=str,
    help="Wandb run: <entity>/<project>/runs/<run_id> (e.g., ai2-llm/open-instruct/runs/mbz172ia)",
    default=None,
)
_parser.add_argument(
    "--save-raw-requests", type=bool, default=False, help="Save raw requests in output directory"
)
_parser.add_argument(
    "--recompute-metrics", action="store_true", help="Recompute metrics for cached predictions"
)


def convert_chat_instance(model, ins, chat_template=None, no_thinking=False):
    if not isinstance(ins.request.context, str):
        print("converting")
        if chat_template and chat_template not in CHAT_TEMPLATES:
            raise ValueError("Chat template {chat_template} not recognized!")
        elif chat_template is None and model.tokenizer.chat_template is None:
            raise ValueError("Tokenizer does not have chat_template!")
        messages = ins.request.context["messages"]
        assistant_prefix = ins.request.context.get("assistant_prefix", "")
        if chat_template:
            # We only use templates that don't rely on tokenizer
            context = CHAT_TEMPLATES[chat_template](messages, tokenizer=None)
        else:
            if no_thinking:
                print("no think!")
                context = model.tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
                )
                print(context)
            else:
                context = model.tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
        ins.request.context = context + assistant_prefix
        if not assistant_prefix and hasattr(ins.request, "continuation"):
            # Strip leading space if no assistant prefix
            ins.request.continuation = ins.request.continuation.lstrip()
    return ins


def evaluate(model, instances):
    instances_types = defaultdict(list)

    for ins in instances:
        instances_types[ins.request_type].append(ins)

    results_for_requests = []
    if len(instances_types["loglikelihood"]) > 0:
        requests_logll = [ins.request for ins in instances_types["loglikelihood"]]
        output = model.loglikelihood_verbose(requests=requests_logll)
        results_for_requests += collate_results(instances_types["loglikelihood"], output)

    if len(instances_types["generate_until"]) > 0:
        requests_generate = [ins.request for ins in instances_types["generate_until"]]
        output = model.generate_until_verbose(requests=requests_generate)
        results_for_requests += collate_results(instances_types["generate_until"], output)

    return results_for_requests


def process_eval_args(args_dict: dict) -> dict:
    """Process and validate the arguments for the evaluator"""
    model_config = {}

    ## model configs
    model_config["model"] = args_dict.pop("model")
    model_config["revision"] = args_dict.pop("revision")
    model_config["trust_remote_code"] = args_dict.pop("trust_remote_code")
    model_config["max_length"] = args_dict.pop("max_length")
    model_config["model_path"] = args_dict.pop("model_path")
    model_config["model_type"] = args_dict.pop("model_type")
    model_args_dict = parse_args_string(args_dict.pop("model_args"))
    if model_args_dict:
        keys = list(model_args_dict.keys())
        for key in keys:
            if key in model_config:
                if model_config[key] is not None and model_config[key] != MODEL_DEFAULTS.get(key):
                    raise ValueError(f"Model arg {key} is set both globally and in model_args")
                else:
                    model_config[key] = model_args_dict[key]
                    del model_args_dict[key]
        model_config.update(model_args_dict)

    # retrieval configs
    retrieval_config = args_dict.pop("retrieval_config", None)
    llm_only = args_dict.pop("llm_only")
    offline_retrieval_config = {}
    for key in ["retrieval_results_path", "retrieval_text_key", "ctx_key", "matching_key", "k", "rerank_k", "sources_to_keep", "sources_to_filter", "presort_key", "sort_key", "threshold"]:
        offline_retrieval_config[key] = args_dict.pop(key, None)

    if offline_retrieval_config["retrieval_results_path"].startswith("s3://"):
        download_file_from_s3(offline_retrieval_config["retrieval_results_path"], destination_file="/tmp/retrieval_results.jsonl")
        offline_retrieval_config["retrieval_results_path"] = "/tmp/retrieval_results.jsonl"

    if retrieval_config is not None:
        retrieval_config_instance = ExperimentConfig.load(retrieval_config, drop_extra_fields=False)
        
        for key in ["fewshot_seed", "random_subsample_seed"]:
            if key not in args_dict or args_dict[key] is None:
                args_dict[key] = retrieval_config_instance.seed
        
        retrieval_config = {}
        retrieval_config['online_retrieval_config'] = retrieval_config_instance.retrieval_config.online_retrieval
        retrieval_config['offline_retrieval_config'] = retrieval_config_instance.retrieval_config.offline_retrieval
        retrieval_config['k'] = retrieval_config_instance.retrieval_config.k
        
        
        for key in ["retrieval_results_path", "retrieval_text_key", "ctx_key", "matching_key", "sources_to_keep", "sources_to_filter", "presort_key", "sort_key", "threshold", "rerank_k"]:
            if offline_retrieval_config[key] is not None:
                # retrieval_config['offline_retrieval_config'][key] = offline_retrieval_config[key]
                setattr(retrieval_config['offline_retrieval_config'], key, offline_retrieval_config[key])
        for key in ["k"]:
            if offline_retrieval_config[key] is not None:
                # retrieval_config['offline_retrieval_config'][key] = offline_retrieval_config[key]
                retrieval_config[key] = offline_retrieval_config[key]
        retrieval_config["llm_only"] = llm_only
        print(f"Retrieval config: {retrieval_config}")

    ## task configs
    # task_config_shared: they can be set either globally (through --<arg>) or
    # individually (through --task {'<arg>':} for each task).
    task_config_shared = {}
    if "limit" in args_dict:
        args_dict["limit"] = make_int(args_dict["limit"])
    for key in ["limit", "split", "num_shots", "fewshot_seed", "random_subsample_seed"]:
        value = args_dict.pop(key, None)
        if value is not None:
            task_config_shared[key] = value
    if model_config.get("chat_model"):
        # Default to use chat format if chat_model is True
        task_config_shared["use_chat_format"] = True

    task_config_shared["no_thinking"] = args_dict.pop("no_thinking")
    
    tasks = args_dict.pop("task")
    # Borrowed from launch.py
    task_configs = []
    all_tasks = []
    task_suite_parent: dict = {}
    for task in tasks:
        all_tasks += resolve_task_suite(task, task_suite_parent)
    for task in all_tasks:
        if task.endswith(".jsonl"):
            task_configs += load_jsonl(task)
        elif task in TASK_CONFIGS:
            task_config = copy.deepcopy(TASK_CONFIGS[task])
            if "metadata" not in task_config:
                task_config["metadata"] = {}
            task_config["metadata"]["alias"] = task
            task_configs.append(task_config)
        elif task in task_suite_parent:
            logger.warning(
                f"No config found for task: {task} (from task suite {task_suite_parent[task]})"
            )
            task_configs.append({"task_name": task})
        else:
            task_config = parse_args_string(task, "task_name")
            # Allow updates to existing configured tasks
            if task_config["task_name"] in TASK_CONFIGS:
                new_task_config = task_config
                task_config = copy.deepcopy(TASK_CONFIGS[new_task_config["task_name"]])
                del new_task_config["task_name"]
                task_config.update(new_task_config)
            elif len(task_config) == 1:
                logger.warning(f"No config found for task: {task}, using raw task")
            task_configs.append(task_config)
        task_configs = [get_dict_with_defaults(task_config_shared, task_config) for task_config in task_configs]

    # task_configs = []
    
    # for task in tasks:
    #     task_config = parse_args_string(task, "task_name")
    #     task_configs.append(get_dict_with_defaults(task_config, task_config_shared))

    compute_config = {}
    compute_config["batch_size"] = args_dict.pop("batch_size")
    compute_config["max_batch_size"] = args_dict.pop("max_batch_size")
    compute_config["output_dir"] = args_dict.pop("output_dir")
    compute_config["cached_output_dir"] = args_dict.pop("cached_output_dir")
    compute_config["remote_output_dir"] = args_dict.pop("remote_output_dir")
    compute_config["num_recorded_inputs"] = args_dict.pop("num_recorded_inputs")
    compute_config["gsheet"] = args_dict.pop("gsheet")
    compute_config["hf_save_dir"] = args_dict.pop("hf_save_dir")
    compute_config["save_raw_requests"] = args_dict.pop("save_raw_requests")
    compute_config["recompute_metrics"] = args_dict.pop("recompute_metrics")
    compute_config["wandb_run_path"] = args_dict.pop("wandb_run_path")

    if compute_config["remote_output_dir"] is not None:
        # if a remote directory is provided, we set the output dir to a temporary directory
        # we can then upload the results from.
        compute_config["output_dir"] = compute_config["output_dir"] or tempfile.mkdtemp()

    if args_dict:
        raise ValueError(f"Unrecognized arguments: {args_dict}")
    eval_config = {
        "model_config": model_config,
        "tasks_config": task_configs,
        "compute_config": compute_config,
        "retrieval_config": retrieval_config
    }
    return eval_config


def load_model(model_load_config: dict) -> HFLM_Verbose:
    """Load the model"""
    # TODO:  better validation and documentation of other args
    model_path = model_load_config["model_path"]
    if model_path is not None and isinstance(model_path, str):
        model_path = [model_path]
    if model_path is not None and not isinstance(model_path, str):
        chosen = None
        for path in model_path:
            if path.startswith("beaker://"):
                raise ValueError(
                    "Beaker model paths not supported for local model loading, use launch.py"
                )
            if path.startswith("hf://"):
                chosen = path.replace("hf://", "")
                break
            if path.startswith("s3://"):
                cached = None
                try:
                    cached = cache_s3_folder(path)
                except Exception as exc:
                    logger.warning(f"Something went wrong downloading from S3 {path}: {exc}")
                if cached is not None:
                    chosen = cached["local_dir"]
                    break
            if path.startswith("weka://"):
                chosen = path.replace("weka://", "/weka-mount/")
                if not os.path.exists(chosen):
                    # TODO: Fall back to S3 interface to WEKA here
                    chosen = None
                else:
                    break

            if os.path.exists(path):
                chosen = path
                break
        if chosen is None:
            chosen = model_path[0]  # Probably won't end well, could raise error here?
        model_path = chosen

    pretrained = model_path or model_load_config["model"]
    model_type = model_load_config["model_type"]
    model_load_config_other = model_load_config.copy()
    for key in [
        "chat_model",
        "chat_template",
        "model",
        "model_path",
        "model_type",
        "metadata",
    ]:  # Keys not passed on to model loader
        model_load_config_other.pop(key, None)
    if model_type == "hf":
        model_class: TemplateLM = HFLM_Verbose
    elif model_type == "vllm":
        model_class = VLLM_Verbose
    # elif model_type == "litellm":
    #     model_class = LiteLLM
    else:
        raise ValueError(f"Model type {model_type} not recognized")
    model = model_class(
        pretrained=pretrained,
        **model_load_config_other,
    )
    return model


def load_task(task_config: dict, output_dir: Optional[str] = None) -> Task:
    """From individual task_spec (default recipe + adjustment at runtime) and global task args for experiment, construct a new Task object

    Config args from task specs can include:
    - generation_kwargs: dict, generation params to overwrite the generation task defaults: ex. repeats, do_sample, temperature...
    - context_kwargs: dict, params to overwrite the context building defaults (mostly for MultipleChoiceTaskParam): ex. prompt_style, question_prefix ...
    - metric_kwargs: dict, metric params for the task to overwrite the metric defaults: ex. k in pass@k...
    - misc. args: ex. `native_id_field` to help the Eleuther MC task adapter identify a unique `native_id` for the task dataset instances.
    """
    task_args = copy.deepcopy(task_config)
    task_name = task_args.pop("task_name")
    if task_name in TASK_REGISTRY:
        task_obj: Task = TASK_REGISTRY[task_name](task_name=task_name, task_config=task_args)
        task_obj_name = task_obj.task_config["task_name"]
        if task_obj_name != task_name:
            raise ValueError(
                f"Task name {task_name} does not match task_config task_name {task_obj_name}"
            )
    elif task_name.startswith("eleuther"):
        logger.warning(f"Using Eleuther 0.4 task adapter for task {task_name}, might be buggy!")
        task_obj = EleutherTask(task_name=task_name, task_config=task_args)
    # elif task_name in ELEUTHER_TASK_REGISTRY:
    #    # TODO: to avoid misuse, add ELEUTHER_MCTASK_REGISTRY that lists out all Eleuther tasks that could use the adapter
    #    task_obj = task_adapters.adapt_eleuther_mc_task(task_name)(
    #        task_name=task_name, task_config=task_args
    #    )
    else:
        raise ValueError(f"Task {task_name} not found in the task registry!")
    task_obj._task_hash = hash_dict(task_obj.task_config, TASK_DEFAULTS)
    task_obj._output_dir = output_dir
    return task_obj


def task_file_name(output_dir: str, task_idx: int, task_name: str, file_name: str) -> str:
    return os.path.join(output_dir, f"task-{task_idx:03d}-{task_name}-{file_name}")


def run_eval(args_dict: dict):
    eval_config = process_eval_args(args_dict)
    model_config = eval_config["model_config"]
    compute_config = eval_config["compute_config"]
    tasks_config = eval_config["tasks_config"]
    retrieval_config = eval_config["retrieval_config"]
    
    # Setup retrieval
    offline_retrieval = None
    if retrieval_config is not None and retrieval_config['offline_retrieval_config']:
        offline_retrieval = OfflineRetrieval(
                retrieval_config=retrieval_config['offline_retrieval_config'],
                k=retrieval_config['k'])

    task_objects = [
        load_task(task_config, compute_config["output_dir"]) for task_config in tasks_config
    ]
    if not task_objects:
        raise ValueError("No valid tasks constructed for the experiment!")

    model_hash = hash_dict(model_config, MODEL_DEFAULTS)

    output_dir = compute_config["output_dir"]
    cached_output_dir = compute_config["cached_output_dir"]
    recompute_metrics = compute_config["recompute_metrics"]
    have_all_predictions = False
    if compute_config["recompute_metrics"]:
        # Check if we have all predictions, then we can skip loading model
        have_all_predictions = True
        for task_idx, task in enumerate(task_objects):
            task_name = task.task_name
            if output_dir and os.path.exists(
                task_file_name(output_dir, task_idx, task_name, "predictions.jsonl")
            ):
                continue
            if cached_output_dir and os.path.exists(
                task_file_name(cached_output_dir, task_idx, task_name, "predictions.jsonl")
            ):
                continue
            have_all_predictions = False
            break

    eval_model = None
    if model_config["model"] is None and model_config["model_path"] is None:
        logger.info("No model specified, only generating requests...")
    elif have_all_predictions:
        logger.info("All predictions found, skipping model loading...")
    else:
        model_load_config = model_config.copy()
        model_load_config["batch_size"] = compute_config["batch_size"]
        model_load_config["max_batch_size"] = compute_config["max_batch_size"]
        eval_model = load_model(model_load_config)
        # logger.info(f"Loaded model config: {eval_model.model.config}")
    logger.info(f"Model loaded. Model hash: {model_hash['hash']}")

    metrics_output_file = None
    remote_output_dir = compute_config["remote_output_dir"]

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        metrics_output_file = os.path.join(compute_config["output_dir"], "metrics-all.jsonl")
    else:
        logger.warning("WARNING: No output directory specified, not output files will be written!")

    # Run tasks through the model and process results
    metrics_all = []
    beaker_env_variables = {
        k: v
        for k, v in os.environ.items()
        if "GIT_REF" in k or "GIT_BRANCH_NAME" in k or "BEAKER" in k and "token" not in k.lower()
    }

    for task_idx, task in enumerate(task_objects):
        start_time = time.time()
        task_name = task.task_name
        predictions_file = None
        cached_predictions = None
        # Move task files from cache directory if need be
        if cached_output_dir is not None and output_dir is not None:
            for file_name in ["predictions.jsonl", "metrics.json", "requests.jsonl"]:
                existing_file = task_file_name(output_dir, task_idx, task_name, file_name)
                if os.path.exists(existing_file):
                    continue
                cached_file = task_file_name(cached_output_dir, task_idx, task_name, file_name)
                if os.path.exists(cached_file):
                    shutil.copy(
                        cached_file, task_file_name(output_dir, task_idx, task_name, file_name)
                    )
        if output_dir is not None:
            predictions_file = task_file_name(output_dir, task_idx, task_name, "predictions.jsonl")
            output_file = task_file_name(output_dir, task_idx, task_name, "metrics.json")
            if (
                os.path.exists(predictions_file)
                and os.path.exists(output_file)
                and not recompute_metrics
            ):
                logger.info(f"Task num {task_idx} {task_name} already processed, skipping...")
                metrics = load_json(output_file)
                metrics["task_idx"] = task_idx
                metrics_all.append(metrics)
                continue
            if os.path.exists(predictions_file):
                cached_predictions = load_jsonl(predictions_file)
                logger.info(f"Using existing predictions for Task num {task_idx} {task_name}")

        task_config = task.task_config
        # Here task.task_config has already combined runtime task args with task-specific defaults.
        task_hash = task._task_hash
        logger.info(f"**** Processing task: {task_name} ({task_hash['non_default']}) ****")
        logger.info(f"  Task hash: {task_hash['hash']}, starting download...")
        task.download()
        logger.info("  Task finished downloading!")
        if task_name.startswith("eleuther:") and task_config.get("use_chat_format"):
            chat_template = model_config.get("chat_template")
            if chat_template:
                # We only use templates that don't rely on tokenizer
                chat_template_fn = lambda x: CHAT_TEMPLATES[chat_template](x, None)  # type: ignore # noqa: E731
            else:
                chat_template_fn = lambda x: eval_model.tokenizer.apply_chat_template(  # type: ignore # noqa: E731
                    x, tokenize=False, add_generation_prompt=True
                )  # type: ignore
            task._chat_template_fn = chat_template_fn  # type: ignore
        eval_requests_raw = None
        if output_dir is not None:
            requests_output_file = task_file_name(
                compute_config["output_dir"], task_idx, task_name, "requests.jsonl"
            )
            if os.path.exists(requests_output_file) and cached_predictions is not None:
                logger.info(f"Using existing requests from {requests_output_file}...")
                eval_requests_raw = load_jsonl(requests_output_file)
        if eval_requests_raw:
            task_instances: list[RequestInstance] = [
                RequestInstance.from_dict(ins) for ins in eval_requests_raw
            ]
            task._instances = task_instances
        else:
            # Load in offline retriever (TODO: online retriever)
            task.build_all_requests(offline_retriever=offline_retrieval, llm_only=retrieval_config is not None and retrieval_config['llm_only'])
            task_instances = task._instances if task._instances is not None else []
            eval_requests_raw = [ins.to_dict() for ins in task_instances]
        logger.info(f"Number of requests: {len(eval_requests_raw)}")
        if len(eval_requests_raw) == 0:
            logger.error("No requests found this task! Skipping model processing...")
            continue
        if (
            compute_config["save_raw_requests"]
            and output_dir is not None
            and not os.path.exists(requests_output_file)
        ):
            logger.info(f"Saving raw requests in {requests_output_file}...")
            save_jsonl(requests_output_file, eval_requests_raw)
        logger.info(f"First request raw: {eval_requests_raw[0]}")

        processing_time = time.time() - start_time  # Only used if no model is involved
        full_config = {
            "task_name": task_name,
            "task_hash": task_hash["hash"],
            "model_hash": model_hash["hash"],
            "model_config": model_config,
            "task_config": task_config,
            "compute_config": remove_none_values(compute_config),
            "processing_time": processing_time,
            "current_date": datetime.datetime.now(tz=datetime.timezone.utc).strftime(
                "%Y-%m-%d %H:%M:%S UTC"
            ),
            "num_instances": 0,  # TODO: Add this for no model case later
            "beaker_info": beaker_env_variables,
        }

        if eval_model is None and cached_predictions is None:
            logger.info("Model not specified, skipping model processing...")
            if output_dir is not None:
                logger.info(f"Saving task config in {output_file}...")
                save_json(output_file, full_config)
            continue

        # Super hacky to avoid reprocessing metrics for uncond cases
        unsafe_prediction_import = False
        if cached_predictions is not None:
            # TODO: This is pretty fragile if we change response structure, especially handling of "idx" field
            cached_predictions_dict = {ins["doc_id"]: ins for ins in cached_predictions}
            results_for_requests = []
            for request in eval_requests_raw:
                if request["doc_id"] not in cached_predictions_dict:
                    unsafe_prediction_import = True
                    continue
                doc_predictions = cached_predictions_dict[request["doc_id"]]
                predictions = doc_predictions["model_output"]
                prediction_idx = request["idx"]
                if prediction_idx >= len(predictions):
                    # We'll allow fewer cached predictions than request idx, for cases
                    # where repeats>1 generative predictions have been filtered
                    continue
                results_for_request = request.copy()
                results_for_request["model_resps"] = predictions[request["idx"]]
                results_for_requests.append(results_for_request)
        else:
            # Run Model
            # Convert chat requests to model-specific input strings:
            if task_config.get("use_chat_format") and model_config["model_type"] != "litellm":
                for ins in task_instances:
                    convert_chat_instance(eval_model, ins, model_config.get("chat_template"), task_config.get("no_thinking"))
            # Log requests associated with first instance
            first_requests = eval_requests_raw[0].copy()
            del first_requests["request"]
            first_requests["requests"] = [
                r for r in eval_requests_raw if r["doc_id"] == first_requests["doc_id"]
            ]
            logger.info("\n".join(show_model_input(first_requests)))
            # TODO: Make sure these are in the right order when mixing loglikelihood and generate_until
            logger.info(f"First inputs: {task._instances[0]}")
            results_for_requests = evaluate(model=eval_model, instances=task._instances)
            logger.info(f"First results: {results_for_requests[0]}")

        # Post-process generation results for metrics
        task.make_metrics()
        if task._metrics:
            # single metric per task first (TBI) multiple metrics per task
            metric = task._metrics[0]
        else:
            raise ValueError("No metrics found for task!")

        if unsafe_prediction_import:  # TODO: Fix/refactor this madness
            metric._scores_for_docs = cached_predictions
        else:
            metric.compute_for_docs(results_for_requests)

        metrics_raw = metric.aggregate_to_task(primary_metric=task_config.get("primary_metric"))

        processing_time = time.time() - start_time
        full_config["processing_time"] = processing_time
        full_config["num_instances"] = len(metric._scores_for_docs)

        # Collect full config for output files
        # TODO: Add beaker id to compute_config
        for key in ["context_kwargs", "generation_kwargs", "metric_kwargs"]:
            if task_config[key] == {}:
                task_config[key] = None  # return empty kwargs to None

        metrics = full_config.copy()
        metrics["metrics"] = metrics_raw
        metrics["task_idx"] = task_idx
        metrics_all.append(metrics)

        # append experiment and task args and create doc/instance level outputs
        predictions_raw = copy.deepcopy(metric._scores_for_docs)

        recorded_inputs = []
        if compute_config["num_recorded_inputs"]:
            recorded_inputs = get_recorded_inputs(
                eval_requests_raw,
                predictions_raw,
                compute_config["num_recorded_inputs"],
                log_first=False,
            )

        # Add task and model hashes to each prediction
        for prediction in predictions_raw:
            prediction["task_hash"] = task_hash["hash"]
            prediction["model_hash"] = model_hash["hash"]

        logger.info(f"First instance output: {predictions_raw[0]}")

        if output_dir is not None:
            logger.info(f"Saving predictions in {predictions_file}...")
            save_jsonl(predictions_file, predictions_raw)
            logger.info(f"Saving metrics in {output_file}...")
            save_json(output_file, metrics)
            if recorded_inputs:
                recorded_inputs_file = task_file_name(
                    compute_config["output_dir"], task_idx, task_name, "recorded-inputs.jsonl"
                )
                logger.info(f"Saving recorded inputs in {recorded_inputs_file}...")
                save_jsonl(recorded_inputs_file, recorded_inputs)

        if compute_config["gsheet"] is not None:
            logger.info(f"Writing to gsheet {compute_config['gsheet']}...")
            gsheet_res = write_metrics_to_gsheet(metrics, compute_config["gsheet"], TASK_DEFAULTS)
            logger.info(f"Data written to gsheet: {gsheet_res}")

        logger.info(f"\n\n** Task metrics: **\n{metrics}")

    if eval_model is None and not recompute_metrics:
        # We're done if no model was involved
        return

    # meta tasks processing logics - evals scores are stored within metric instances,
    # and metric instances are embedded under tasks
    # TODO: Make this more general
    aggregate_tasks = add_aggregate_tasks(metrics_all)
    for metrics in aggregate_tasks:
        logger.info(f"Aggregated task metrics: {metrics}")
        if compute_config["gsheet"] is not None:
            logger.info(f"Writing to gsheet {compute_config['gsheet']}...")
            gsheet_res = write_metrics_to_gsheet(metrics, compute_config["gsheet"], TASK_DEFAULTS)
            logger.info(f"Data written to gsheet: {gsheet_res}")

    metrics_all = aggregate_tasks + metrics_all  # Put aggregate tasks in front

    primary_score_summary = []
    for m in metrics_all:
        task_name = m["task_config"].get("metadata", {}).get("alias", m["task_name"])
        score = m["metrics"].get("primary_score", "NA")
        primary_score_summary.append(f"{task_name}: {score:.6}")
    logger.info("Summary of primary scores: \n" + "\n".join(primary_score_summary))

    if compute_config["wandb_run_path"] is not None:
        wandb_log_metrics(compute_config["wandb_run_path"], metrics_all)

    if output_dir is not None:
        logger.info(f"Saving final metrics in {metrics_output_file}...")
        save_jsonl(metrics_output_file, metrics_all)
        beaker_metrics_file = os.path.join(compute_config["output_dir"], "metrics.json")
        metrics_simple = [
            {
                "task": m["task_name"],
                **m["metrics"],
                "num_instances": m["num_instances"],
                "task_config": m["task_config"],
            }
            for m in metrics_all
        ]
        logger.info(f"Saving beaker metrics.json so in {beaker_metrics_file}...")
        # Put summary with the lowest alphabetical key, so it appears first in beaker
        save_json(
            beaker_metrics_file,
            {
                "all_primary_scores": primary_score_summary,
                "metrics": metrics_simple,
                "model_config": model_hash["non_default"],
            },
        )

        if compute_config["hf_save_dir"] is not None:
            hf_paths = compute_config["hf_save_dir"].split("//")
            filename = f'{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}-{tasks_config[0]["task_name"]}-{len(tasks_config)}-tasks.jsonl'
            target_path = os.path.join(hf_paths[1], filename)
            upload_to_hf(metrics_output_file, hf_paths[0], target_path)

    if remote_output_dir is not None:
        logger.info(f"Uploading results to remote directory {remote_output_dir}...")
        upload_directory(output_dir, remote_output_dir)


if __name__ == "__main__":
    args = _parser.parse_args()
    args_dict = vars(args)
    run_eval(args_dict)
