from asyncore import write
import os
import random
import re
import time
import numpy as np
import json
import torch
from lm_eval.evaluator import make_table
import collections
import itertools
import random

import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
from lm_eval.utils import positional_deprecated, run_task_tests

import numpy as np
import transformers
from lm_eval.evaluator import evaluate


def simple_evaluate(
        model,
        model_args=None,
        tasks=None,
        num_fewshot=0,
        batch_size=None,
        max_batch_size=None,
        device=None,
        no_cache=True,
        limit=None,
        bootstrap_iters=100000,
        description_dict=None,
        check_integrity=False,
        decontamination_ngrams_path=None,
        write_out=False,
        output_base_path=None,
        lm=None  ##changed by wenhua
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model, transformers.PreTrainedModel object, or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
        Must be non-empty.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        Whether or not to cache. Forced to True when `model` is a
        transformers.PreTrainedModel.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param decontamination_ngrams_path: str, optional
        Forwarded unchanged to `evaluate`.
    :param write_out: bool
        If True, write details about prompts and logits to json for all tasks
        (forwarded to `evaluate` together with `output_base_path`).
    :param output_base_path: str, optional
        Directory to which detailed eval info will be written. Defaults to present working dir.
    :param lm: LM, optional
        Pre-built LM object to reuse. When given, model construction, result
        caching and the Llama pad-token fix-up are all skipped; `model` is then
        only used to fill results["config"]["model"].
    :return
        Tuple `(results, lm)`: the results dictionary produced by `evaluate`
        (with a "config" entry added) and the LM object that was evaluated, so
        the caller can reuse it (e.g. for perplexity runs).
    """

    # Fixed seeds so that few-shot sampling is reproducible across runs.
    random.seed(1234)
    np.random.seed(1234)

    # Raised explicitly (not via `assert`) so the check survives `python -O`;
    # the AssertionError type is kept for backward compatibility.
    if not tasks:
        raise AssertionError("No tasks specified")

    if lm is None:
        if isinstance(model, str):
            if model_args is None:
                model_args = ""
            lm = lm_eval.models.get_model(model).create_from_arg_string(
                model_args,
                {
                    "batch_size": batch_size,
                    "max_batch_size": max_batch_size,
                    "device": device,
                },
            )
        elif isinstance(model, transformers.PreTrainedModel):
            lm = lm_eval.models.get_model("hf-causal")(
                pretrained=model,
                batch_size=batch_size,
                max_batch_size=max_batch_size,
            )
            # Caching is always disabled for in-memory model objects.
            no_cache = True
        else:
            assert isinstance(model, lm_eval.base.LM)
            lm = model

        if not no_cache:
            model_id = model if isinstance(model, str) else model.model.config._name_or_path
            # model_args can still be None when an LM object was passed in directly.
            args_id = (model_args or "").replace("=", "-").replace(",", "_").replace("/", "-")
            lm = lm_eval.base.CachingLM(lm, "lm_cache/" + model_id + "_" + args_id + ".db")

        # Not every LM implementation exposes a tokenizer; only patch Llama fast tokenizers.
        lm_tokenizer = getattr(lm, "tokenizer", None)
        if isinstance(lm_tokenizer, transformers.LlamaTokenizerFast):
            if lm_tokenizer.pad_token is None:
                # NOTE(review): the model's embeddings are not resized here after
                # adding a token — confirm the new pad id is never fed to the model.
                lm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            else:
                # NOTE(review): this replaces an already-configured pad token with the
                # literal '[PAD]' string — confirm this override is intended.
                lm_tokenizer.pad_token = '[PAD]'

    task_dict = lm_eval.tasks.get_task_dict(tasks)

    # Only LM objects that expose `model.config.model_type` can be identified as Llama.
    model_type = getattr(getattr(getattr(lm, "model", None), "config", None), "model_type", None)
    if isinstance(model_type, str) and "llama" in model_type:
        # For Llama models, swap in the local `lambada` module's task variants.
        # Presumably these adjust for Llama tokenization — TODO confirm against that module.
        if "lambada_openai" in task_dict:
            from lambada import LambadaOpenAI
            task_dict["lambada_openai"] = LambadaOpenAI()
        if "lambada_standard" in task_dict:
            from lambada import LambadaStandard
            task_dict["lambada_standard"] = LambadaStandard()

    if check_integrity:
        run_task_tests(task_list=tasks)

    # Only forward the write-out options when requested, so the default call
    # signature to `evaluate` is unchanged.
    write_out_kwargs = {}
    if write_out:
        write_out_kwargs = {"write_out": write_out, "output_base_path": output_base_path}

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        description_dict=description_dict,
        decontamination_ngrams_path=decontamination_ngrams_path,
        **write_out_kwargs,
    )

    # add info about the model and few shot config
    model_name = None
    if isinstance(model, str):
        model_name = model
    elif isinstance(model, transformers.PreTrainedModel):
        model_name = "pretrained=" + model.config._name_or_path
    results["config"] = {
        "model": model_name,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "batch_sizes": list(lm.batch_sizes.values())
        if hasattr(lm, "batch_sizes")
        else [],
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results, lm


def eval_model(output_dir, seqlen, model=None, tokenizer=None,
               tasks=None, eval_bs=32,
               use_accelerate=True, eval_orig_float=True, limit=None, device="cuda:0"):
    """Run lm-eval tasks, then perplexity on wikitext2 / ptb-new / c4-new.

    :param output_dir: str or None
        Directory the "hf-causal" model is loaded from. Defaults to "./tmp_signround".
        When `eval_orig_float` is False, any existing directory is DELETED and
        `model`/`tokenizer` are saved there first.
    :param seqlen: int
        Sequence length used for the perplexity loaders; also set on `lm.model.seqlen`.
    :param model: optional
        In-memory model. Required (and moved to CPU) when `eval_orig_float` is False.
        Its `config.torch_dtype` selects 'bfloat16' vs 'float16' loading.
    :param tokenizer: optional
        Tokenizer saved alongside `model`; required when `eval_orig_float` is False.
    :param tasks: list[str], optional
        lm-eval task names. Defaults to lambada_openai, hellaswag, winogrande, piqa.
    :param eval_bs: int
        Batch size passed to `simple_evaluate`.
    :param use_accelerate: bool
        Forwarded in the model arg string.
    :param eval_orig_float: bool
        If True, evaluate the checkpoint already in `output_dir` without saving `model`.
    :param limit: int or float, optional
        Per-task example limit forwarded to `simple_evaluate`.
    :param device: str
        Device for lm-eval and the perplexity evaluation.
    :raises ValueError: if `eval_orig_float` is False and `model` or `tokenizer` is None.
    """
    if tasks is None:
        tasks = ["lambada_openai", "hellaswag", "winogrande", "piqa"]

    # Validate before touching the filesystem: previously a missing model caused
    # `output_dir` to be deleted and then an AttributeError.
    if not eval_orig_float and (model is None or tokenizer is None):
        raise ValueError("model and tokenizer are required when eval_orig_float is False")

    if output_dir is None:
        output_dir = "./tmp_signround"

    if not eval_orig_float:
        if os.path.exists(output_dir):
            import shutil
            shutil.rmtree(output_dir)
        model = model.to("cpu")
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)

    dtype = 'float16'
    if hasattr(model, 'config') and model.config.torch_dtype is torch.bfloat16:
        dtype = 'bfloat16'

    # NOTE(review): the tokenizer value is wrapped in literal double quotes; standard
    # lm-eval arg-string parsing does not strip them — confirm the loader in use does.
    model_args = f'pretrained={output_dir},tokenizer="{output_dir}",dtype={dtype},use_accelerate={use_accelerate}'
    model_type = "hf-causal"
    results, lm = simple_evaluate(model=model_type,
                                  model_args=model_args,
                                  tasks=tasks,
                                  device=device,
                                  batch_size=eval_bs,
                                  limit=limit)
    print(make_table(results))

    from utils import get_loaders, eval_ppl_same_with_gptq

    lm.model.seqlen = seqlen

    # Slow tokenizer on purpose: the original author noted a bloom ppl issue with the fast one.
    ppl_tokenizer = transformers.AutoTokenizer.from_pretrained(output_dir, use_fast=False, trust_remote_code=True)

    for dataset in ('wikitext2', 'ptb-new', 'c4-new'):
        _, testloader = get_loaders(
            dataset, seed=0, seqlen=seqlen, tokenizer=ppl_tokenizer,
        )
        print(dataset, flush=True)
        ppl = eval_ppl_same_with_gptq(lm.model, testloader, device)
        print(dataset, ppl)


if __name__ == "__main__":
    # Script entry point: evaluate a local Llama-7B checkpoint and report wall time.
    # `time` is already imported at module level.
    start = time.perf_counter()  # monotonic clock, suited to measuring durations
    test_tasks = [
        # Trailing comma keeps the commented-out entries below safe to re-enable
        # (without it, Python would silently concatenate adjacent string literals).
        'winogrande', 'lambada_openai',
        # 'boolq', 'rte',
        #   'arc_easy', 'arc_challenge', 'hendrycksTest-*', 'wikitext',
    ]
    eval_model(output_dir="/models/llama-7b-hf", seqlen=2048,
               tasks=test_tasks,
               eval_bs=32, eval_orig_float=True, limit=None)

    print("cost time: ", time.perf_counter() - start)
    print("please check seqlen is right or not")
