"""
Experimenting with pre-defining model results data structures, and collation methods for each type of model request. 
These are to ensure consistency of the data format going into metrics computation, and instance level predictions output.

New model classes or metric types may require us to spawn new results data structures or collation methods. 
"""

from typing import Any, List

from oe_eval.components.instances import RequestInstance


def collate_results(instances: List[RequestInstance], results: List[Any]) -> list:
    """
    Collate the model output results with the input instances they were produced for.

    For example:
    In sampling generations, when we generate num_return_sequences > 1 for each instance, each
    returned sequence is paired with the instance it was generated from.
    In multiple choice tasks, when we get output tokens and logits for each choice of an instance,
    each of them is paired with the corresponding instance.

    :param instances: list
        Input instances built from a list of construct_requests(), or the (TBI) build_all_requests().
        The request type and num_return_sequences are read from the first instance only, so all
        instances are assumed to share them.
    :param results: list
        The model output results, in the same order as the instances. For "generate_until" requests
        there must be num_return_sequences consecutive results per instance.
    :return: list
        One dict per result, holding "res_id" (the running index of the result), the fields of the
        instance's to_dict(), and "model_resps" (the result itself). Empty if there is nothing to collate.
    :raises ValueError:
        If the number of results does not match the (expanded) number of instances.
    """
    if instances and instances[0].request_type == "generate_until":
        # NOTE(review): generation_kwargs is presumably optional (None when unset) -- verify against
        # the request class. A missing value is treated as "one sequence per instance".
        generation_kwargs = instances[0].request.generation_kwargs or {}  # type: ignore
        num_return_sequences = generation_kwargs.get("num_return_sequences", 1)
        # Repeat every instance once per returned sequence so that it lines up 1:1 with results.
        expanded_instances = [ins for ins in instances for _ in range(num_return_sequences)]
    else:
        expanded_instances = instances

    # Explicit check rather than `assert`: asserts are stripped under `python -O`, after which
    # zip() would silently truncate and misalign the predictions.
    if len(expanded_instances) != len(results):
        raise ValueError(
            f"Cannot collate {len(results)} results with {len(expanded_instances)} "
            f"expected outputs ({len(instances)} instances)."
        )

    # The LeanModel generate_until does not rearrange instance orders so we can just zip them to collate
    return [
        {"res_id": res_id, **ins.to_dict(), "model_resps": res}
        for res_id, (ins, res) in enumerate(zip(expanded_instances, results))
    ]
