
import os
from lm_eval.base import CacheHook
from lm_eval.models.gpt2 import GPT2LM
from lm_eval import tasks, evaluator, utils

import numpy as np
import math, os, datetime

import torch
from torch.nn import functional as F

# Run identifiers to evaluate. Each one is appended to RUN_MODEL_NAME
# further down to form a model file name.
RUN_TABLE = [
    1652,
]
RUN_MODEL_NAME = "/mnt/ssd-1/BlinkDL_dont_delete/B/TRAIN_100M/out/all-"

import sys

# Task list handed to the evaluation harness. Alternatives that have been
# used here: 'winogrande', 'arc_easy', 'arc_challenge', 'piqa', 'hellaswag'.
eval_tasks = ["boolq"]

# Device selection: flip USE_CUDA to False to run on the CPU instead.
USE_CUDA = True
if USE_CUDA:
    RUN_DEVICE = "cuda"
else:
    RUN_DEVICE = "cpu"
# NOTE: RUN_DEVICE must be set in src/model.py as well!

from tqdm import tqdm
import torch
import torch.nn.functional as F

class EvalHarnessAdapter(GPT2LM):
    """Adapter that lets the lm_eval harness score the module-level ``model``.

    The instance holds no model of its own. Scoring reads the module-level
    globals ``model``, ``gpt_inputs``, ``RUN_DEVICE``, ``logitBuf`` and
    ``correctBuf``, which the script must define before ``run_eval`` is called.
    """

    def __init__(self):
        # Intentionally does not call GPT2LM.__init__ (presumably to skip the
        # parent's own model setup -- confirm against lm_eval before changing).
        pass

    def greedy_until(self, requests):
        """Generation is not supported by this adapter."""
        raise NotImplementedError()

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Score the continuation of every request with the global ``model``.

        Args:
            requests: Sequence of items where ``item[1]`` holds the context
                token ids and ``item[2]`` the continuation token ids (both
                lists of ints). ``item[0]`` is not read here.
            disable_tqdm: Accepted for interface compatibility; unused.

        Returns:
            A list with one ``(log_likelihood, is_greedy)`` tuple per request:
            the summed log-probability of the continuation tokens and whether
            every continuation token was the model's top-1 prediction.

        Note:
            The context is assumed to be non-empty. With an empty context the
            first scored position would be index ``-1`` of the model output.
        """
        res = []

        for n, request in enumerate(requests):
            context, continuation = request[1], request[2]
            src = context + continuation
            q_len = len(context)

            # The score depends on where the continuation starts, not only on
            # the concatenated tokens, so the split point is part of the key.
            cache_key = f'{q_len}:{src}'
            if cache_key in logitBuf:
                logit = logitBuf[cache_key]
                correct = correctBuf[cache_key]
            else:
                logit = 0
                correct = True
                with torch.no_grad():
                    # Reuses the module-level gpt_inputs mapping as the
                    # keyword-argument container for the model call.
                    gpt_inputs['input_ids'] = torch.tensor([src], device=RUN_DEVICE)
                    gpt_inputs['attention_mask'] = torch.ones_like(gpt_inputs['input_ids'], device=RUN_DEVICE)
                    outputs = model(**gpt_inputs).logits[0]

                    # outputs[i] is the prediction for token i + 1, so the
                    # continuation is scored from position q_len - 1 onwards.
                    for i in range(q_len - 1, len(src) - 1):
                        step_logits = outputs[i]
                        dst = src[i + 1]
                        # log_softmax in float32 avoids math.log(0) when a
                        # low-precision softmax underflows to zero.
                        logit += F.log_softmax(step_logits.float(), dim=-1)[dst].item()
                        if step_logits.argmax(dim=-1).item() != dst:
                            correct = False
                logitBuf[cache_key] = logit
                correctBuf[cache_key] = correct

            res.append((logit, correct))

            # Lightweight progress indicator: one tick per 100 requests.
            if n % 100 == 0:
                print(f'{n//100}/{len(requests)//100}', end = ' ', flush=True)
        return res

    @torch.no_grad()
    def run_eval(self, eval_tasks=None, num_fewshot=0, bootstrap_iters=2):
        """Run the lm_eval evaluator on ``eval_tasks`` using this adapter.

        Args:
            eval_tasks: Task names passed to ``tasks.get_task_dict``.
            num_fewshot: Number of few-shot examples forwarded to the evaluator.
            bootstrap_iters: Bootstrap iteration count forwarded to the evaluator.

        Returns:
            Whatever ``evaluator.evaluate`` returns for the requested tasks.
        """
        results = evaluator.evaluate(
            lm=self,
            task_dict=tasks.get_task_dict(eval_tasks),
            provide_description=False,
            num_fewshot=num_fewshot,
            limit=None,
            bootstrap_iters=bootstrap_iters,
        )
        return results

# Driver: for every run id, load the tokenizer and model, attach quantization,
# and run the evaluation harness on `eval_tasks`.
RWKV_ID = ''
for RUN_NUM in RUN_TABLE:
    # RWKV_ID and RWKV_FILENAME are assigned but not read anywhere in this part
    # of the script; they are kept in case other code depends on them.
    RWKV_ID = RUN_NUM
    # Per-run score caches used by EvalHarnessAdapter._loglikelihood_tokens.
    logitBuf = {}
    correctBuf = {}

    RWKV_FILENAME = RUN_MODEL_NAME + str(RUN_NUM)

    from transformers import AutoTokenizer

    model_name = "meta-llama/Llama-2-7b-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    from spike_quant.sparsity_utils import SrLlamaForCausalLM

    # Place the model on RUN_DEVICE (not a hard-coded .cuda()) so it always
    # lives on the same device as the input tensors built during scoring.
    model = SrLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(RUN_DEVICE)

    # NOTE: this import rebinds the module-level name `utils`, which the
    # imports at the top of the file had bound to lm_eval's utils.
    import main_utils,utils
    args = utils.parser_gen()
    args.rotate = True

    # Bit widths for weights, activations, and the v / k settings.
    args.w_bits = 4

    args.a_bits = 4
    args.v_bits = 4
    args.k_bits = 4

    args.w_clip = True
    args.w_rtn = True

    # zero will be automatically adjusted if SPIKE_ON is enabled
    args.a_asym = True
    args.k_asym = True
    args.v_asym = True

    model.act_wrapper()
    model.eval()

    # Quantization is only attached when the requested width is below 16 bits.
    if args.w_bits<16:
        main_utils.add_weight_quantization(model,args)
    if args.a_bits<16:
        main_utils.add_input_quantization(model,args)

    # Template mapping whose 'input_ids' / 'attention_mask' entries are
    # overwritten for every scored request by the adapter.
    gpt_inputs = tokenizer("This is a test", return_tensors="pt")

    print("Running evaluation harness...")
    adapter = EvalHarnessAdapter()
    adapter.tokenizer=tokenizer
    results = adapter.run_eval(
        eval_tasks=eval_tasks,
        bootstrap_iters=10000,
    )
    print(model_name)
    print(results)
  
# Append the last run's results to a log file, if a log path is available.
# `log_path` is only non-None when something earlier in the module namespace
# defined it; without it there is nowhere to write, so the step is skipped
# instead of crashing in os.path.dirname(None).
log_path = vars().get("log_path", None)
if log_path is None:
    print("log_path is not set; skipping result log.")
else:
    # A bare file name has an empty dirname, which os.makedirs rejects.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    with open(log_path, "a") as f:
        f.write(model_name + "\n")
        f.write(str(results) + "\n")
        f.write(f"args {args}\n")