# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import json
import sys
import os
import numpy as np
from pathlib import Path
from typing import Dict, List, Literal, Optional
from datetime import datetime
import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
from lm_eval import base, evaluator, tasks
from lm_eval.base import BaseLM

from lit_llama.model import LLaMA, LLaMAConfig, Block
# Support running without installing as a package.
# NOTE(review): the `lit_llama.model` import above executes before this path tweak takes effect.
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

# from lit_gpt.gen import generate
from lit_llama.generate import generate
from lit_gpt import Tokenizer
# from lit_gpt.model_eval import GPT
from lit_gpt.utils import (
    get_default_supported_precision,
    gptq_quantization,
    load_checkpoint,
)
from lit_llama.utils import quantization
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from lightning.fabric.strategies import FSDPStrategy
from functools import partial

# Module-level switch read by `run_eval_harness`: when True an `FSDPStrategy`
# that auto-wraps `Block` is built; when False Fabric uses strategy 'auto'.
fsdp = False

class EvalHarnessBase(BaseLM):
    """Adapter that exposes a Fabric-managed LLaMA model through lm-eval's ``BaseLM`` API.

    Credits:
    https://github.com/EleutherAI/gpt-neox/blob/main/eval_tasks/eval_adapter.py
    """

    def __init__(
        self, fabric: L.Fabric, model: LLaMA, tokenizer: Tokenizer, batch_size: int
    ):
        """Store the collaborators used by the harness callbacks.

        Args:
            fabric: Fabric instance; supplies device, rank and world size.
            model: The model that is called for scoring and generation.
            tokenizer: Tokenizer used for encoding/decoding and EOS lookup.
            batch_size: Batch size per device (see the ``batch_size`` property).
        """
        super().__init__()
        self.fabric = fabric
        self.model = model
        self.tokenizer = tokenizer
        self.batch_size_per_gpu = batch_size
        # with fabric.init_tensor():
        #     model.set_kv_cache(batch_size=batch_size)

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config=None):
        """Build an instance from a ``"key=value,key=value"`` string.

        Args:
            arg_string: Comma-separated ``key=value`` pairs. Empty segments are
                ignored; only the first ``=`` of a pair separates key from value.
            additional_config: Optional mapping of extra keyword arguments.
                ``None`` is treated as an empty mapping.

        Raises:
            ValueError: If a non-empty segment contains no ``=``.
        """
        kwargs = {}
        for el in arg_string.split(","):
            if not el:
                continue
            key, sep, value = el.partition("=")
            if not sep:
                raise ValueError(f"Expected 'key=value' but got {el!r}")
            kwargs[key] = value
        # `**None` raises TypeError, so fall back to an empty mapping.
        return cls(**kwargs, **(additional_config or {}))

    @property
    def eot_token_id(self):
        """Token id used as end-of-text; this is the tokenizer's ``eos_id``."""
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_id

    @property
    def max_length(self):
        """Maximum sequence length, taken from ``model.config.block_size``."""
        return self.model.config.block_size

    @property
    def vocab_size(self):
        """Vocabulary size as reported by the tokenizer."""
        return self.tokenizer.vocab_size

    @property
    def max_gen_toks(self):
        """Fixed generation budget (256)."""
        return 256

    @property
    def batch_size(self):
        """Global batch size: per-device batch size times ``fabric.world_size``."""
        return self.batch_size_per_gpu * self.fabric.world_size

    @property
    def device(self):
        """Device reported by Fabric."""
        return self.fabric.device

    def tok_encode(self, string: str) -> List[int]:
        """Encode ``string`` to a list of token ids without BOS/EOS markers."""
        return self.tokenizer.encode(string, bos=False, eos=False).tolist()

    def tok_decode(self, tokens: List[int]) -> str:
        """Decode a list of token ids back into text."""
        t = torch.tensor(tokens)
        return self.tokenizer.decode(t)

    @torch.inference_mode()
    def _model_call(self, inps):
        """Run a forward pass of the model on ``inps`` and return its output."""
        return self.model(inps)

    @torch.inference_mode()
    def _model_generate(self, context, max_length, eos_token_id) -> torch.Tensor:
        """Generate a continuation for a single context row.

        Returns the output of ``generate`` with a leading batch dimension added.
        """
        # this only supports batch size 1
        assert context.shape[0] == 1
        out = generate(self.model, context[0], max_length, eos_id=eos_token_id)
        # Reset every block's KV cache after generation.
        # NOTE(review): presumably so cached state does not leak into the next
        # request; assumes the model exposes `transformer.h[*].attn.kv_cache`.
        for block in self.model.transformer.h:
            block.attn.kv_cache.reset_parameters()
        return out.unsqueeze(0)

    @torch.inference_mode()
    def run_eval(
        self,
        eval_tasks: List[str],
        num_fewshot: int,
        limit: Optional[int],
        bootstrap_iters: int,
        no_cache: bool,
    ) -> Dict:
        """Resolve task patterns, run the lm-eval evaluator and return its results.

        Args:
            eval_tasks: Task names or fnmatch-style patterns matched against
                ``tasks.ALL_TASKS``.
            num_fewshot: Forwarded to ``evaluator.evaluate``.
            limit: Forwarded to ``evaluator.evaluate``.
            bootstrap_iters: Forwarded to ``evaluator.evaluate``.
            no_cache: When False, the model is wrapped in ``base.CachingLM``
                backed by ``lm_cache/lit-gpt.db``.

        Returns:
            The evaluator's result dict with an added ``"config"`` entry.
        """
        import fnmatch

        def pattern_match(patterns, source_list):
            # Collect every registry entry that matches at least one pattern.
            task_names = set()
            for pattern in patterns:
                task_names.update(fnmatch.filter(source_list, pattern))
            # Sort so the order is deterministic: set iteration order of strings
            # depends on the per-process hash seed and can differ across ranks.
            return sorted(task_names)

        eval_tasks = pattern_match(eval_tasks, tasks.ALL_TASKS)
        print(f"Found tasks: {eval_tasks}")

        # **HACK INCOMING**:
        # first get task dict on local main rank
        # the tasks are downloaded *as they are initialized*, and the downloads don't like multithreading.
        # so we download them once on the local main rank, wait, and then initialize them on all other ranks, which *should* load from the cache.
        if self.fabric.local_rank == 0:
            tasks.get_task_dict(eval_tasks)
        # torch barrier
        self.fabric.barrier()
        # Build the task dict once here and reuse it for the evaluation below.
        task_dict = tasks.get_task_dict(eval_tasks)

        lm = self
        if not no_cache:
            lm = base.CachingLM(lm, "lm_cache/lit-gpt.db")

        results = evaluator.evaluate(
            lm=lm,
            task_dict=task_dict,
            num_fewshot=num_fewshot,
            limit=limit,
            bootstrap_iters=bootstrap_iters,
        )
        # NOTE(review): the model label is a fixed string, not derived from the loaded config.
        results["config"] = dict(
            model="llama2-0.5B",
            batch_size=self.batch_size,
            device=str(self.device),
            num_fewshot=num_fewshot,
            limit=limit,
            bootstrap_iters=bootstrap_iters,
            no_cache=no_cache,
        )
        return results

def calculate_acc_mean_with_norm(results):
    """Average per-task accuracy, preferring ``acc_norm`` over ``acc``.

    Args:
        results: Mapping of task name to that task's metric dict.

    Returns:
        The mean (via ``np.mean``) of one value per task: ``acc_norm`` when
        present, otherwise ``acc``. Tasks reporting neither metric are skipped.
        Returns NaN when no task contributes a value.
    """
    acc_values = []

    for result in results.values():
        if "acc_norm" in result:
            acc_values.append(result["acc_norm"])
        elif "acc" in result:
            acc_values.append(result["acc"])
        # Tasks without an accuracy metric are left out instead of raising KeyError.

    if not acc_values:
        # Avoid np.mean([]) and its "Mean of empty slice" RuntimeWarning.
        return float("nan")

    return np.mean(acc_values)

@torch.inference_mode()
def run_eval_harness(
    checkpoint_dir: Path,
    tokenizer_dir: Path,
    model_name: str,
    precision: Optional[str] = None,
    quantize: Optional[
        Literal[
            "bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"
        ]
    ] = None,
    eval_tasks: List[str] = ["arc_challenge", "piqa", "hellaswag", "hendrycksTest-*"],
    save_filepath: Optional[Path] = None,
    num_fewshot: int = 0,
    limit: Optional[int] = None,
    bootstrap_iters: int = 100000,
    no_cache: bool = True,
    batch_size: int = 1,
):
    """Load a LLaMA checkpoint with Fabric and evaluate it with the lm-eval harness.

    Args:
        checkpoint_dir: Path handed to ``fabric.load`` to restore the ``"model"`` entry.
        tokenizer_dir: Path handed to ``Tokenizer``.
        model_name: Model name; a leading ``"llama-"`` is stripped and the rest
            is passed to ``LLaMAConfig.from_name``.
        precision: Precision string. Only used to pick the dtype for the
            bitsandbytes plugin (see the NOTE in the body).
        quantize: Quantization mode. ``bnb.*`` values build a
            ``BitsandbytesPrecision`` plugin (see the NOTE in the body).
        eval_tasks: Task names or fnmatch patterns to evaluate.
        save_filepath: If given, the results are written there as JSON and the
            mean accuracy is printed; otherwise results are only printed.
        num_fewshot: Forwarded to the evaluator.
        limit: Forwarded to the evaluator.
        bootstrap_iters: Forwarded to the evaluator.
        no_cache: When False, lm-eval's caching wrapper is used.
        batch_size: Per-device batch size.

    Raises:
        ValueError: If a ``bnb.*`` quantization is combined with a mixed precision.
    """
    if precision is None:
        precision = get_default_supported_precision(training=False)

    plugins = None
    if quantize is not None and quantize.startswith("bnb."):
        if "mixed" in precision:
            raise ValueError("Quantization and mixed precision is not supported.")
        dtype = {
            "16-true": torch.float16,
            "bf16-true": torch.bfloat16,
            "32-true": torch.float32,
        }[precision]
        plugins = BitsandbytesPrecision(quantize[4:], dtype)
        precision = None

    if fsdp:
        auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
        strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block)
    else:
        strategy = 'auto'
    # NOTE(review): `precision` and `plugins` computed above are not passed to
    # Fabric; the run always uses "bf16-mixed" on a single CUDA device. Left
    # unchanged because switching would alter evaluation numerics - confirm intent.
    fabric = L.Fabric(
        devices=1,
        num_nodes=1,
        strategy=strategy,
        accelerator="cuda",
        precision="bf16-mixed",
    )
    fabric.launch()
    print(tokenizer_dir)
    tokenizer = Tokenizer(tokenizer_dir)

    size = model_name.removeprefix("llama-")
    config = LLaMAConfig.from_name(size)
    # Fixed overrides applied on top of the named config.
    config.vocab_size = 50254
    config.padded_vocab_size = 50304
    config.block_size = 1024
    fabric.print(f"config of llama2 {size}: {config}")
    fabric.print(f"Loading model with {config.__dict__}")

    print(
        f"Loading model {str(checkpoint_dir)!r} with {config.__dict__}",
        file=sys.stderr,
    )
    # The caller's `quantize` is no longer overwritten here: the old
    # `quantize = "gptq.int4"` assignment was never read afterwards.
    with fabric.device:
        model = LLaMA(config)

    model.eval()
    model = fabric.setup(model)

    print("***** ckpt_dir", checkpoint_dir)
    # Restore only the model weights from the Fabric checkpoint.
    fabric.load(checkpoint_dir, {"model": model})

    eval_harness = EvalHarnessBase(fabric, model, tokenizer, batch_size)

    results = eval_harness.run_eval(
        eval_tasks, num_fewshot, limit, bootstrap_iters, no_cache
    )
    if save_filepath is None:
        print(results)
    else:
        print(f"Saving results to {str(save_filepath)!r}")
        data = json.dumps(results)
        print(results)
        # Create the target directory first so a finished evaluation is not
        # lost to a missing folder.
        save_path = Path(save_filepath)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        with open(save_path, "w", encoding="utf-8") as fw:
            fw.write(data)

        acc_mean_with_norm = calculate_acc_mean_with_norm(results['results'])
        print(f'Acc mean: {acc_mean_with_norm}')

# Script entry point: expose `run_eval_harness` as a command-line interface.
if __name__ == "__main__":
    from jsonargparse import CLI

    # Set PyTorch's float32 matmul precision mode to "high" before running.
    torch.set_float32_matmul_precision("high")
    # as_positional=False: function parameters are not exposed as positional CLI arguments.
    CLI(run_eval_harness, as_positional=False)
