import collections
import itertools
import random

import efficiency_benchmark.dependencies.lm_eval.base
import efficiency_benchmark.dependencies.lm_eval.metrics
import efficiency_benchmark.dependencies.lm_eval.models
import efficiency_benchmark.dependencies.lm_eval.tasks
import numpy as np
from efficiency_benchmark.dependencies.lm_eval.utils import (
    positional_deprecated,
    run_task_tests,
)


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=None,
    num_fewshot=0,
    batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
    decontamination_ngrams_path=None,
):
    """Build (or accept) a language model, evaluate it on ``tasks`` and return the results.

    The global ``random`` and ``numpy`` RNGs are seeded with 1234 first so runs
    are reproducible.

    :param model: either a model name resolved through
        ``lm_eval.models.get_model`` or an already-constructed ``lm_eval.base.LM``.
    :param model_args: argument string passed to ``create_from_arg_string`` when
        ``model`` is a name (``None`` is treated as ``""``).
    :param tasks: non-empty collection of task names passed to
        ``lm_eval.tasks.get_task_dict``.
    :param num_fewshot: number of few-shot examples per context.
    :param batch_size: forwarded to model construction when ``model`` is a name.
    :param device: forwarded to model construction when ``model`` is a name.
    :param no_cache: if False, wrap the model in ``CachingLM`` backed by
        ``lm_cache/<model>_<sanitized model_args>.db``. Caching requires a model
        name; for an ``LM`` instance a warning is printed and caching is skipped.
    :param limit: maximum number of documents per task (None = all).
    :param bootstrap_iters: bootstrap iterations used for stderr estimates.
    :param description_dict: optional mapping of task name to description.
    :param check_integrity: if True, run ``run_task_tests`` on ``tasks`` first.
    :param decontamination_ngrams_path: optional path enabling decontaminated metrics.
    :return: the dict returned by ``evaluate`` with an added ``"config"`` entry
        recording the arguments of this call.
    :raises AssertionError: if no tasks are given, or ``model`` is neither a
        string nor an ``LM`` instance.
    """
    random.seed(1234)
    np.random.seed(1234)

    assert tasks, "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = efficiency_benchmark.dependencies.lm_eval.models.get_model(model).create_from_arg_string(
            model_args, {"batch_size": batch_size, "device": device}
        )
    else:
        assert isinstance(model, efficiency_benchmark.dependencies.lm_eval.base.LM)
        lm = model

    if not no_cache:
        if isinstance(model, str):
            # Characters that are awkward in file names are replaced so the
            # model arguments can be part of the cache file name.
            sanitized_args = model_args.replace("=", "-").replace(",", "_").replace("/", "-")
            lm = efficiency_benchmark.dependencies.lm_eval.base.CachingLM(
                lm,
                "lm_cache/" + model + "_" + sanitized_args + ".db",
            )
        else:
            # An LM instance has no stable name to key a cache file on;
            # building the path from the object itself would raise TypeError.
            print(
                "WARNING: no_cache=False but model is an LM instance; "
                "caching is only supported when model is given by name, so it is disabled"
            )

    task_dict = efficiency_benchmark.dependencies.lm_eval.tasks.get_task_dict(tasks)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        description_dict=description_dict,
        decontamination_ngrams_path=decontamination_ngrams_path,
    )

    # Record the configuration alongside the scores.
    results["config"] = {
        "model": model,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results


# Suffix appended to a metric name for the variant of that metric computed
# only over documents not reported as overlapping training data (see ``evaluate``).
decontaminate_suffix: str = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    provide_description=None,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    decontamination_ngrams_path=None,
):
    """Evaluate ``lm`` on every task in ``task_dict`` and aggregate the metrics.

    Tasks without test or validation docs are skipped; test docs are used in
    preference to validation docs. Each task's docs are shuffled with a fixed
    seed (42) and truncated to ``limit``. Requests from all tasks are grouped by
    request type so ``lm`` is called once per type. Responses are then routed
    back to their documents, scored with ``task.process_results`` and
    aggregated with ``task.aggregation()``.

    :param lm: object providing one method per request type. Each method is
        called with the list of ``req.args`` for that type and must return one
        response per request, in the same order.
    :param task_dict: mapping of task name to task object.
    :param provide_description: deprecated; must be falsy (use ``description_dict``).
    :param num_fewshot: number of few-shot examples placed in each context.
    :param limit: maximum number of docs per task, or None for all docs.
    :param bootstrap_iters: iterations passed to ``stderr_for_metric``; capped
        at 1000 for bleu/chrf/ter (including their decontaminated variants).
    :param description_dict: optional mapping of task name to a description
        passed to ``task.fewshot_context``.
    :param decontamination_ngrams_path: if not None, train/test overlap is
        computed via ``get_train_overlap``. For tasks present in that result,
        every metric is additionally reported as ``<metric>_decontaminate``,
        computed over the non-overlapping docs only.
    :return: ``{"results": {task: {metric: value, "<metric>_stderr": value}},
        "versions": {task: task.VERSION}}``
    """
    # ``provide_description`` is deprecated: only a falsy value is accepted,
    # and passing it explicitly still triggers a warning.
    assert not provide_description
    if provide_description is not None:
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    decontaminate = decontamination_ngrams_path is not None

    task_dict_items = [
        (name, task) for name, task in task_dict.items() if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    # request type -> list of requests, plus a parallel list recording where
    # each request came from: (index within doc, task name, doc, doc id).
    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    # Replaced by the result of get_train_overlap when decontaminating.
    overlaps = collections.defaultdict(list)

    # (task name, doc id) -> doc, used again when scoring responses.
    docs = {}

    # (task name, split) -> decontamination queries for that split's docs.
    docs_for_decontamination = collections.defaultdict(list)

    # --- Build requests for every document of every task. ---
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION

        # Prefer test docs; fall back to validation docs.
        if task.has_test_docs():
            task_doc_func = task.test_docs
            task_set = "test"
        elif task.has_validation_docs():
            task_set = "val"
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # Fixed-seed shuffle so that ``limit`` selects the same subset on every
        # run; the same RNG is also handed to fewshot_context below.
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)

        description = description_dict[task_name] if description_dict and task_name in description_dict else ""

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            if decontaminate and task.should_decontaminate():
                docs_for_decontamination[(task_name, task_set)].append(task.doc_to_decontamination_query(doc))

            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description)
            reqs = task.construct_requests(doc, ctx)
            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.request_type].append(req)
                # ``i`` remembers the position within construct_requests'
                # output so responses can be put back in that order later.
                requests_origin[req.request_type].append((i, task_name, doc, doc_id))

    # --- Optional train/test overlap detection. ---
    if decontaminate:
        from efficiency_benchmark.dependencies.lm_eval.decontamination.decontaminate import (
            get_train_overlap,
        )

        print("Finding train/test overlap, please wait...")
        overlaps = get_train_overlap(docs_for_decontamination, decontamination_ngrams_path, limit)

    # --- Run all requests, one LM call per request type. ---
    # (task name, doc id) -> list of (index within doc, response)
    process_res_queue = collections.defaultdict(list)

    for reqtype, reqs in requests.items():
        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
        # Some requests only want one element of a multi-part response.
        resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))

    # --- Score each document. ---
    vals = collections.defaultdict(list)

    for (task_name, doc_id), indexed_resps in process_res_queue.items():
        # Restore the order in which construct_requests produced the requests.
        indexed_resps.sort(key=lambda x: x[0])
        doc_resps = [resp for _, resp in indexed_resps]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        # A doc counts toward the decontaminated metrics only when its task
        # has overlap data and the doc id is not listed among the overlaps.
        keep_decontaminated = decontaminate and task_name in overlaps and doc_id not in overlaps[task_name]

        metrics = task.process_results(doc, doc_resps)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)
            if keep_decontaminated:
                vals[(task_name, metric + decontaminate_suffix)].append(value)

    # --- Aggregate per-document values into per-task metrics. ---
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]

        # Decontaminated variants reuse the aggregation of their base metric;
        # strip only the trailing suffix, not every occurrence.
        if metric.endswith(decontaminate_suffix):
            real_metric = metric[: -len(decontaminate_suffix)]
        else:
            real_metric = metric

        aggregate = task.aggregation()[real_metric]
        results[task_name][metric] = aggregate(items)

        # bleu/chrf/ter are expensive to recompute, so bootstrap iterations are
        # capped for them (checked on the base name so the decontaminated
        # variants get the same cap).
        stderr = efficiency_benchmark.dependencies.lm_eval.metrics.stderr_for_metric(
            metric=aggregate,
            bootstrap_iters=min(bootstrap_iters, 1000) if real_metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
        )

        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)

    return {"results": dict(results), "versions": dict(versions)}


def make_table(result_dict):
    """Render evaluation results as a Markdown table string.

    :param result_dict: dict with a ``"results"`` entry (task name -> {metric:
        value, "<metric>_stderr": value}) and a ``"versions"`` entry (task name
        -> version), the shape produced by ``evaluate`` in this module.
    :return: Markdown text with one row per non-stderr metric. Values and
        stderrs are formatted to 4 decimals. The task name and version appear
        only on a task's first row, and stderr columns are blank for metrics
        without a stderr.
    """
    # Imported lazily so pytablewriter is only required when a table is built.
    from pytablewriter import MarkdownTableWriter

    md_writer = MarkdownTableWriter()
    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    rows = []
    for task_name, metrics in result_dict["results"].items():
        task_label = task_name
        version_label = result_dict["versions"][task_name]
        for metric, value in metrics.items():
            # Stderr entries are shown next to their metric, not as own rows.
            if metric.endswith("_stderr"):
                continue

            stderr_key = metric + "_stderr"
            if stderr_key in metrics:
                rows.append([task_label, version_label, metric, f"{value:.4f}", "±", f"{metrics[stderr_key]:.4f}"])
            else:
                rows.append([task_label, version_label, metric, f"{value:.4f}", "", ""])

            # Only the first row of each task carries its name and version.
            task_label = ""
            version_label = ""

    md_writer.value_matrix = rows
    return md_writer.dumps()
