from datasets import load_dataset
import json
from pathlib import Path
import time
from typing import List, Tuple, Any
import sys
import gc

import torch
from torch import Tensor
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPast
from datasets import load_dataset
# import tqdm
from tqdm import tqdm

from argparse import ArgumentParser, Namespace
from model_loader import *



def string_match_part(preds, refs):
    """Score predictions by whether they contain *any* of their references.

    Each prediction earns 1 if at least one of its reference strings occurs in
    it as a case-insensitive substring, otherwise 0. The score is the mean over
    all predictions, expressed as a percentage and rounded to 2 decimals.

    Args:
        preds: Predicted strings.
        refs: One list of reference strings per prediction.

    Returns:
        Percentage score in [0, 100]; 0.0 when ``preds`` is empty.

    Raises:
        ValueError: If ``preds`` and ``refs`` have different lengths.
    """
    if len(preds) != len(refs):
        # zip() would silently drop the tail while still dividing by len(preds).
        raise ValueError(
            f"preds and refs must have the same length, got {len(preds)} and {len(refs)}"
        )
    if not preds:
        return 0.0

    hits = 0.0
    for pred, ref in zip(preds, refs):
        pred_lower = pred.lower()
        # any() over an empty reference list is False, so such a sample scores 0.
        if any(r.lower() in pred_lower for r in ref):
            hits += 1.0

    score = hits / len(preds) * 100
    return round(score, 2)

def string_match_all(preds, refs):
    """Score predictions by the fraction of their references they contain.

    Each prediction earns (number of its reference strings found in it as a
    case-insensitive substring) / (number of its references). The score is the
    mean over all predictions, expressed as a percentage and rounded to 2
    decimals.

    Args:
        preds: Predicted strings.
        refs: One list of reference strings per prediction.

    Returns:
        Percentage score in [0, 100]; 0.0 when ``preds`` is empty.

    Raises:
        ValueError: If ``preds`` and ``refs`` have different lengths.
    """
    if len(preds) != len(refs):
        # zip() would silently drop the tail while still dividing by len(preds).
        raise ValueError(
            f"preds and refs must have the same length, got {len(preds)} and {len(refs)}"
        )
    if not preds:
        return 0.0

    total = 0.0
    for pred, ref in zip(preds, refs):
        if not ref:
            # No references to match: contributes 0 instead of dividing by zero.
            continue
        pred_lower = pred.lower()
        matched = sum(1.0 for r in ref if r.lower() in pred_lower)
        total += matched / len(ref)

    score = total / len(preds) * 100
    return round(score, 2)


def evaluate_one_task(model, tokenizer, data, args, split, length):
    """Generate answers for up to ``args.samples`` examples of one task and score them.

    Args:
        model: Object exposing a Hugging Face-style ``generate``; inputs are
            moved to ``"cuda"`` when available, otherwise ``"cpu"``.
        tokenizer: Tokenizer used to encode prompts and decode generations.
        data: Mapping with ``"input"`` (prompt strings) and ``"outputs"``
            (one list of reference strings per example), indexable by position.
        args: Parsed CLI namespace; only ``args.samples`` is read here.
        split: Task name. Names containing ``"qa"`` are scored with
            ``string_match_part``; all others with ``string_match_all``.
        length: ``max_length`` passed to the tokenizer (prompts are truncated
            to this many tokens).

    Returns:
        The task score as returned by the chosen string-match function.

    Raises:
        TypeError: If an example's references are not a list.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    result = {"preds": [], "refs": []}
    numbers = min(len(data["input"]), args.samples)

    # Stop generation at EOS or at the first newline token. Computed once
    # rather than re-encoding "\n" for every example.
    newline_id = tokenizer.encode("\n", add_special_tokens=False)[-1]
    eos_ids = [tokenizer.eos_token_id, newline_id]

    # Context manager guarantees the progress bar is closed even on errors.
    with tqdm(total=numbers, disable=False) as pbar:
        for i in range(numbers):
            text = data["input"][i]
            ref = data["outputs"][i]
            # The scorers iterate over each example's references, so they must be a list.
            if not isinstance(ref, list):
                raise TypeError(f"refs should be a list, but got {type(ref)}")

            # NOTE(review): with truncation=True, over-long prompts are cut to
            # `length` tokens on the tokenizer's configured truncation side
            # (presumably "right", which would drop the question at the end of
            # the prompt) — confirm prompts fit within `length`.
            inputs = tokenizer(text, truncation=True, padding=True, max_length=length, return_tensors="pt")
            input_ids = inputs["input_ids"].to(device)
            attention_mask = inputs["attention_mask"].to(device)

            with torch.no_grad():
                output_ids = model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    output_attentions=False,
                    max_new_tokens=30,
                    num_beams=1,
                    # NOTE(review): do_sample is not passed, so whether sampling
                    # (and hence this temperature) applies depends on the
                    # model's generation config — confirm greedy decoding is intended.
                    temperature=0.7,
                    eos_token_id=eos_ids,
                    pad_token_id=tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens, skipping the prompt prefix.
            output = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True).strip()

            result["preds"].append(output)
            result["refs"].append(ref)

            # Drop every per-example tensor (including output_ids), collect
            # garbage, and only then release the CUDA cache so the freed blocks
            # can actually be returned.
            del inputs, input_ids, attention_mask, output_ids
            gc.collect()
            torch.cuda.empty_cache()

            pbar.update(1)

    if "qa" in split:
        score = string_match_part(result["preds"], result["refs"])
    else:
        score = string_match_all(result["preds"], result["refs"])
    print(f"Score on {split} with context length {length}: {score}")

    return score


def main(args):
    """Run the RULER evaluation for every (context length, task) pair.

    For each length in ``args.context`` this loads the
    ``SaylorTwift/RULER-{length}-llama-3.1-tokenizer-chat-template`` dataset and
    scores each split in ``args.tasks`` with ``evaluate_one_task``. After every
    task, all results collected so far are rewritten to
    ``log/ruler_yarn.jsonl`` (one JSON object per line), so partial results are
    on disk if a later task fails.

    Args:
        args: Parsed CLI namespace. Reads ``model``, ``context``, ``tasks`` and
            ``samples``, and is also passed to ``load_model_and_apply_patches``.

    Returns:
        The list of ``{"task", "context_length", "score"}`` result dicts.
    """
    model_name = args.model
    tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=sys.maxsize, trust_remote_code=True)
    # Prompts are tokenized with padding=True, which requires a pad token; reuse EOS.
    tokenizer.pad_token = tokenizer.eos_token

    loaded = load_model_and_apply_patches(model_name, args)

    log_path = Path("log") / "ruler_yarn.jsonl"
    # Create the log directory up front; otherwise the first write would fail
    # only after the first task had already been fully evaluated.
    log_path.parent.mkdir(parents=True, exist_ok=True)

    results = []
    for context_length in args.context:
        print(f"Evaluating RULER with context length {context_length}...")
        dataset_name = f"SaylorTwift/RULER-{context_length}-llama-3.1-tokenizer-chat-template"

        for split in args.tasks:
            print(f"=========Task: {split}==========")
            data = load_dataset(dataset_name, split=split)

            if args.samples:
                # Slicing a Dataset yields a dict of column lists.
                data = data[:args.samples]

            torch.cuda.empty_cache()
            score = evaluate_one_task(loaded, tokenizer, data, args, split, context_length)
            results.append({"task": split, "context_length": context_length, "score": score})

            # Rewrite the whole file each time so it always holds every result so far.
            with log_path.open("w", encoding="utf8") as fout:
                fout.writelines(json.dumps(line, ensure_ascii=False) + "\n" for line in results)

        print(f"Finished evaluation with context length {context_length}.")

    return results


if __name__ == "__main__":
    # Script entry point: parse CLI options (plus any extra ones registered by
    # add_args) and run the evaluation.
    default_context_lengths = [8192, 16384, 32768, 65536, 131072]
    default_tasks = [
        "qa_1",
        "vt",
        "fwe",
        "niah_single_1",
        "niah_multikey_1",
        "niah_multiquery",
        "cwe",
    ]

    cli = ArgumentParser()
    cli.add_argument("--model", type=str, required=True, help="Model name or path")
    cli.add_argument("--context", type=int, nargs="+", default=default_context_lengths)
    cli.add_argument("--tasks", type=str, nargs="+", default=default_tasks)
    cli.add_argument("--samples", type=int, default=100)

    cli_args = add_args(cli).parse_args()
    main(cli_args)



        
        