# -*- coding:utf-8 -*-


from ast import arg
import math
from sqlite3 import paramstyle
from tkinter.tix import Tree
from typing import List, Optional, Tuple

import argparse
from tqdm import tqdm
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from insert_needle import get_input_ctx_multi, get_config
import re


def get_output(fail_cnt, ctx_len):
    """Generate one completion for a needle prompt and grade it.

    Relies on the script-level globals ``tokenizer``, ``model``, ``args``,
    ``last_words``, ``needles``, ``expected_answer`` and ``max_new_tokens``.

    Args:
        fail_cnt: Mutable list of per-answer miss counters; the entry at
            index ``k`` is incremented in place when the ``k``-th expected
            answer is not found in the generated text.
        ctx_len: Context length forwarded to ``get_input_ctx_multi``.

    Returns:
        A tuple ``(output, score, prompt_length)``: the whitespace-normalized
        generated text, the percentage (0-100) of expected answers found in
        it, and the number of prompt tokens fed to the model.
    """
    prompt = get_input_ctx_multi(
        tokenizer=tokenizer,
        ctx_len=ctx_len,
        last_words=last_words,
        needles=needles,
        meta_prompt=args.meta,
    )

    encoded = tokenizer(prompt, return_tensors="pt", return_token_type_ids=False)
    encoded = encoded.to(model.device)
    prompt_length = encoded.input_ids.size()[-1]

    generated = model.generate(**encoded, do_sample=False, max_new_tokens=max_new_tokens)
    # Keep only the newly generated tokens, then collapse whitespace runs.
    completion = tokenizer.decode(generated[0][prompt_length:])
    output = " ".join(completion.split())

    # Case-insensitive substring match of every expected answer.
    haystack = output.lower()
    hits = 0
    for idx, answer in enumerate(expected_answer):
        if answer.lower() not in haystack:
            fail_cnt[idx] += 1
            continue
        hits += 1

    score = (hits / len(expected_answer)) * 100
    return output, score, prompt_length


def main():
    """Run the retrieval evaluation and append the results to ``dup.jsonl``.

    Relies on the script-level globals ``pred_save_path``, ``args``,
    ``max_length``, ``max_new_tokens`` and ``expected_answer``.

    For each of ``args.num_test`` samples this calls ``get_output``, prints
    the running statistics to stdout and mirrors them into the results file.
    After the loop, one JSON line per sample is appended, followed by the
    overall average score.

    Returns:
        None, so ``sys.exit(main())`` exits with status 0.
    """
    file_name = os.path.join(pred_save_path, "dup.jsonl")
    # One miss counter per expected answer; get_output indexes this list by
    # the answer's position, so it must be as long as expected_answer.
    fail_cnt = [0] * len(expected_answer)
    scores = []
    save_ds = []
    # Leave room for the generated tokens inside the model's length budget.
    max_len = max_length - max_new_tokens
    more_than_2 = []
    more_than_4 = []

    # The context manager guarantees the file is closed even if generation
    # raises part-way through the run.
    with open(file_name, "a", encoding="utf-8") as fw:
        fw.write("--------------------- <New RUN> -----------------------\n")

        for i in range(args.num_test):
            output, score, prompt_length = get_output(fail_cnt, max_len)

            print(f"----------------- sample {i} -----------------")
            print('document len', prompt_length)
            print(output)
            print("score:", score)

            # The ">=2" / ">=4" labels match a four-answer eval set: a score
            # of at least 50 means half of the answers were found, 100 means
            # all of them were.
            more_than_2.append(1.0 if score >= 50 else 0.0)
            more_than_4.append(1.0 if score >= 100 else 0.0)
            scores.append(score)

            progress = (
                f"step {i}, ctx len {prompt_length}, avg score {sum(scores) / len(scores)}",
                f">=2 retrieved {sum(more_than_2) / len(more_than_2)}",
                f">=4 retrieved {sum(more_than_4) / len(more_than_4)}",
                f"fail cnt {fail_cnt}",
            )
            # Each progress line goes to stdout and to the results file.
            for line in progress:
                print(line)
                print(line, file=fw)
            print("-" * 20)
            fw.flush()

            save_ds.append({
                "ctx_len": prompt_length,
                "pred": output,
                "needle": expected_answer,
                "score": score,
            })

        for save_d in save_ds:
            fw.write(json.dumps(save_d) + '\n')
        # With --num_test 0 there is nothing to average; skip the summary
        # line instead of dividing by zero.
        if scores:
            fw.write(f"avg:{sum(scores) / len(scores)}\n")


if __name__ == "__main__":
    # Script entry point: parse the CLI options, apply the monkey patches
    # from string_monkey_patch, load the tokenizer and model, then run the
    # evaluation. The names assigned here are module globals that
    # get_output() and main() read.
    max_new_tokens = 128
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_test', default=500, type=int)
    parser.add_argument('--max_length', default=2048, type=int)
    parser.add_argument('--eval_set', default="number-4")
    # NOTE(security): type=eval executes the raw command-line string as
    # Python code. It is kept so existing invocations (e.g. --meta False)
    # behave the same, but it must never be fed untrusted input; consider
    # replacing it with an explicit boolean parser.
    parser.add_argument('--meta', default=True, type=eval)
    parser.add_argument('--S', default=0.33, type=float)
    parser.add_argument('--W', default=128, type=int)
    parser.add_argument('--model_path', type=str, default="")
    args = parser.parse_args()
    # An empty path previously crashed with IndexError on model_path[-1];
    # report it as a usage error instead.
    if not args.model_path:
        parser.error("--model_path is required")
    last_words, needles, expected_answer = get_config(args.eval_set, args.model_path)
    model_path = args.model_path
    max_length = args.max_length
    half_len = max_length // 2  # not referenced elsewhere in this script

    # Drop a single trailing slash so the last path component names the model.
    if model_path.endswith("/"):
        model_path = model_path[:-1]
    open_source_model = model_path.split("/")[-1]

    pred_save_path = f"results/{open_source_model}/{args.eval_set}"
    # NOTE(review): the message mentions confirmation, but no input is read;
    # the run proceeds immediately.
    print(f"Your prediction file will be saved to: {pred_save_path}  , press enter to confirm...")
    os.makedirs(pred_save_path, exist_ok=True)

    c1_len = int(max_length * args.S)
    c2_begin = args.W
    from string_monkey_patch import replace_with_string, replace_rope_init

    replace_rope_init()
    # The parser defines --W, not --c2_begin, so args.c2_begin raised
    # AttributeError; pass the value taken from --W above.
    replace_with_string(args.max_length, c1_len=c1_len, max_position=max_length, c2_begin=c2_begin, f=8)

    # Checkpoints whose path contains "mistral" are loaded with the slow
    # tokenizer; everything else uses the fast one.
    use_fast_tokenizer = "mistral" not in model_path
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast_tokenizer, trust_remote_code=True)
    config = AutoConfig.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, config=config, attn_implementation="flash_attention_2",
                                                 device_map="auto",
                                                 trust_remote_code=True, torch_dtype=torch.bfloat16)
    model = model.eval()

    sys.exit(main())
