# -*- coding:utf-8 -*-
import importlib
from re import I
import re
import yaml
import os
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import transformers
import sys
import json
from utils import read_manifest
from tqdm import tqdm
from datetime import datetime



def _score_prediction(output, refs, is_qa):
    """Score one prediction by case-insensitive substring matching.

    Args:
        output: Decoded model continuation.
        refs: Reference answers for the sample (iterated item by item).
        is_qa: If true, the sample counts as correct when any reference
            appears in the output; otherwise the score is the fraction of
            references that appear in the output.

    Returns:
        A ``bool`` for QA-style scoring, a ``float`` otherwise. With no
        references the result is ``False`` / ``0.0`` instead of an error.
    """
    haystack = output.lower()  # lower-case once, not once per reference
    hits = [ref.lower() in haystack for ref in refs]
    if is_qa:
        return any(hits)
    if not hits:
        return 0.0
    return sum(hits) / len(hits)


def main(infer_size):
    """Run generation over the loaded samples, score them and log the results.

    Relies on module-level state prepared by the ``__main__`` section of this
    script: ``data_list``, ``tokenizer``, ``model``, ``max_new_tokens``,
    ``file_name`` and ``args``.

    For every sample the prompt is tokenized, a continuation is generated
    with sampling disabled, and the decoded continuation is scored against
    ``data["outputs"]``. A progress line with the running average is appended
    to ``file_name`` after each sample; once the loop ends, one JSON record
    per sample and the final average are appended as well.

    Args:
        infer_size: Maximum number of samples to evaluate; ``-1`` evaluates
            every sample in ``data_list``.
    """
    is_qa = "qa" in args.task
    scores = []
    records = []

    # The context manager guarantees the log file is closed even if
    # tokenization or generation raises part-way through the run.
    with open(file_name, "a") as fw:
        for i, data in enumerate(tqdm(data_list)):
            if infer_size != -1 and i >= infer_size:
                break

            inputs = tokenizer(
                data["input"], return_tensors="pt", return_token_type_ids=False
            ).to(model.device)
            prompt_length = inputs.input_ids.size()[-1]
            sample = model.generate(
                **inputs,
                repetition_penalty=1,
                do_sample=False,
                max_new_tokens=max_new_tokens,
            )
            # Keep only the newly generated tokens and collapse all runs of
            # whitespace into single spaces.
            output = tokenizer.decode(sample[0][prompt_length:])
            output = " ".join(output.split())

            ref = data["outputs"]
            print(f"----------------- sample {i} -----------------")
            print('[Model Prediction]',output)
            print('[Ground Truth]', ref)

            score = _score_prediction(output, ref, is_qa)
            print("[score]:", score)
            scores.append(score)

            current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            fw.write(f"{current_time}, step {i}, ctx len {prompt_length}, avg score {sum(scores) / len(scores)}, \n")
            # Flush per step so progress is visible in the file during long runs.
            fw.flush()

            records.append({
                "ctx_len": prompt_length,
                "pred": output,
                "needle": ref,
                "score": score,
            })

        # No sample scored (empty manifest or infer_size == 0): report NaN
        # rather than failing with ZeroDivisionError.
        avg_score = sum(scores) / len(scores) if scores else float("nan")

        for record in records:
            fw.write(json.dumps(record) + '\n')
        fw.write(f"avg_score:{avg_score}\n")

    print(f"avg:{avg_score}")


if __name__ == '__main__':
    def _parse_int_list(text):
        """Parse a CLI string such as "0,8,16" or "0-8-16" into a list of ints.

        Used instead of ``type=list``, which would split the raw argument
        string into single characters.
        """
        return [int(tok) for tok in re.split(r'[,\-\s]+', text.strip()) if tok]

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument("--data_dir", type=str, help='path to the dataset jsonl file; its name must contain the context length (4+ digits)')
    parser.add_argument("--benchmark", type=str, default='synthetic', help='Options: [synthetic]')
    # --task is dereferenced unconditionally below, so make argparse enforce it.
    parser.add_argument("--task", type=str, required=True, help='Options: tasks in benchmark')
    parser.add_argument('--pretraining_length', type=int, default=-1, help='pretraining context length passed to the chunk-attention patches')
    parser.add_argument('--window_size', type=int, default=-1, help='window size for selfextend')
    # NOTE(review): the three --detect_* help texts look copy-pasted from
    # --window_size; their real meaning is not visible in this part of the file.
    parser.add_argument('--detect_low', type=int, default=-1, help='window size for selfextend')
    parser.add_argument('--detect_mid', type=int, default=-1, help='window size for selfextend')
    parser.add_argument('--detect_high', type=int, default=-1, help='window size for selfextend')
    parser.add_argument('--group_sizes', type=str, default='', help='dash-separated group sizes for DPE-by-dim, e.g. 1-2-4')
    parser.add_argument('--group_size', type=int, default=-1, help='group size for selfextend')
    parser.add_argument('--group_size_low', type=int, default=-1, help='low dim group size for DPE')
    parser.add_argument('--group_size_mid', type=int, default=-1, help='mid dim group size for DPE')
    parser.add_argument('--group_size_high', type=int, default=-1, help='high dim group size for DPE')
    parser.add_argument('--group_size_all', type=int, default=-1, help='default all group size for DPE topk')
    parser.add_argument('--max_new_tokens', type=int, default=-1, help='override the task-specific generation length (-1 keeps the task default)')
    parser.add_argument('--topk', type=int, default=-1, help='topk for DPE')
    parser.add_argument('--ntk_scale_factor', type=int, default=1, help='RoPE scaling factor for the ntk-dynamic and yarn variants')
    parser.add_argument('--all_dims', type=_parse_int_list, default=[0,8,16,24,32,40,48,56,64], help='dimension boundaries for DPE-by-dim, comma- or dash-separated integers')
    parser.add_argument('--infer_size', type=int, default=-1)
    parser.add_argument('--low_dim', type=int, default=-1)
    parser.add_argument('--high_dim', type=int, default=-1)
    parser.add_argument('--seq_len', type=int, default=-1, help='context length during evaluation; also used for the default dataset path when --data_dir is omitted')
    parser.add_argument('--model_name', type=str, default="llama", help='model name')
    args = parser.parse_args()

    # The maximum input length is normally read from the first run of 4+ digits
    # in --data_dir. When --data_dir is omitted, fall back to --seq_len so the
    # default dataset path further down can actually be built (previously this
    # case crashed inside re.search before the fallback was reached).
    if args.data_dir is not None:
        length_match = re.search(r'\d{4,}', args.data_dir)
        if length_match is None:
            parser.error(f"cannot infer the context length from --data_dir {args.data_dir!r}: expected a number with at least 4 digits")
        test_max_length = int(length_match.group())
    elif args.seq_len != -1:
        test_max_length = args.seq_len
    else:
        parser.error("--data_dir was not given; pass --seq_len so the default dataset path can be built")
    print("the maximum input length is ", test_max_length)

    # copied from https://github.com/hsiehjackson/RULER/blob/main/scripts/data/synthetic/constants.py#L24
    # Checked in order; the first key that is a substring of the task name wins.
    task_token_budgets = (("vt", 30), ("cwe", 120), ("fwe", 50), ("qa", 32), ("niah", 128))
    for task_key, budget in task_token_budgets:
        if task_key in args.task:
            max_new_tokens = budget
            break
    else:
        raise NotImplementedError("Unsupported task")

    # An explicit --max_new_tokens overrides the per-task default.
    if args.max_new_tokens != -1:
        max_new_tokens = args.max_new_tokens

    model_path = args.model_path
    open_source_model = args.model_name

    print("*" * 10, "Data loading", "*"*10)
    if args.data_dir is None:
        args.data_dir = f"jsonl_data/{args.task}/{open_source_model}-{test_max_length}.jsonl"
        print("data dir is not specified. We load from", args.data_dir)
    curr_folder = os.path.dirname(os.path.abspath(__file__))

    # The benchmark YAML sits next to this script and lists the known tasks.
    with open(os.path.join(curr_folder, f"{args.benchmark}.yaml"), "r") as f:
        tasks_customized = yaml.safe_load(f)
    # safe_load returns None for an empty file; treat that as "task not found".
    if not tasks_customized or args.task not in tasks_customized:
        raise ValueError(f'{args.task} is not found in {args.benchmark}.yaml')

    task_file = args.data_dir
    data_list = read_manifest(task_file)
    print("*" * 10, "loading ends..", "*"*10)

    pred_save_path = f"Predictions/{args.task}/{open_source_model}/"
    print(f"Your prediction file will be saved to: {pred_save_path}.")
    os.makedirs(pred_save_path, exist_ok=True)

    # Config and tokenizer are shared by every model variant selected below.
    config = transformers.AutoConfig.from_pretrained(
        args.model_path,
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model_path,
    )
    # Each model branch below assigns this function as the ``forward`` of the
    # *ForCausalLM class it loads.
    from models.string_for_llama import causal_forward
    if args.model_name == "llama3-8b-Instruct" or args.model_name == "Llama-2-7b-chat-hf" or args.model_name == "llama3.1-8b-Instruct" or args.model_name == "Llama-3.1-70B-Instruct" or args.model_name == "Llama-3-8B-Instruct-Gradient-1048k":
        print("=====llama-baseline=====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}.jsonl")
        from transformers import LlamaForCausalLM
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
    elif args.model_name == "Qwen2.5-7B-Instruct" or args.model_name == "Qwen2-72B-Instruct":
        print("=====qwen-baseline=====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}.jsonl")
        from transformers import Qwen2ForCausalLM
        Qwen2ForCausalLM.forward = causal_forward
        model = Qwen2ForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
    elif args.model_name == "Mistral-7B-Instruct-v0.2" or args.model_name == "Mistral-7B-Instruct-v0.3":
        print("=====mistral-baseline=====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}.jsonl")
        from transformers import MistralForCausalLM
        MistralForCausalLM.forward = causal_forward
        model = MistralForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
    elif args.model_name == "Phi-3-medium-128k-instruct":
        print("=====Phi-baseline=====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}.jsonl")
        from transformers import Phi3ForCausalLM
        Phi3ForCausalLM.forward = causal_forward
        model = Phi3ForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )

    ### ntk-dynamic
    elif args.model_name == "llama3-8b-Instruct-ntk-dynamic":
        print("=====llama-ntk=====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_factor{args.ntk_scale_factor}.jsonl")

        config.rope_scaling = {}
        config.rope_scaling["factor"] = args.ntk_scale_factor
        config.rope_scaling["rope_type"] = "dynamic"

        from transformers import LlamaForCausalLM
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
    elif args.model_name == "llama3.1-8b-Instruct-ntk-dynamic" or args.model_name == "Llama-3.1-70B-Instruct-ntk-dynamic":
        print("=====llama3.1-ntk=====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_factor{args.ntk_scale_factor}.jsonl")

        from models.llama31_ntk import LlamaForCausalLM
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
    elif args.model_name == "Qwen2.5-7B-Instruct-ntk-dynamic" or args.model_name == "Qwen2-72B-Instruct-ntk-dynamic":
        print("=====qwen-ntk=====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_factor{args.ntk_scale_factor}.jsonl")

        config.rope_scaling = {}
        config.rope_scaling["factor"] = args.ntk_scale_factor
        config.rope_scaling["rope_type"] = "dynamic"

        from transformers import Qwen2ForCausalLM
        Qwen2ForCausalLM.forward = causal_forward
        model = Qwen2ForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
    elif args.model_name == "Mistral-7B-Instruct-v0.2-ntk-dynamic" or args.model_name == "Mistral-7B-Instruct-v0.3-ntk-dynamic":
        print("=====mistral-ntk=====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_factor{args.ntk_scale_factor}.jsonl")

        from models.mistral import MistralForCausalLM
        MistralForCausalLM.forward = causal_forward
        model = MistralForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )

    ### yarn
    elif args.model_name == "llama3-8b-Instruct-yarn":
        print("=====llama3-yarn=====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_factor{args.ntk_scale_factor}.jsonl")

        config.rope_scaling = {}
        config.rope_scaling["factor"] = args.ntk_scale_factor
        import math
        config.rope_scaling["attention_factor"] = 0.1 * math.log(4) + 1.0
        config.rope_scaling["rope_type"] = "yarn"

        from transformers import LlamaForCausalLM
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
    elif args.model_name == "llama3.1-8b-Instruct-yarn" or args.model_name == "Llama-3.1-70B-Instruct-yarn":
        print("=====llama3.1-yarn=====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_factor{args.ntk_scale_factor}.jsonl")

        from models.llama31_yarn import LlamaForCausalLM
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
    
    elif args.model_name == "Qwen2.5-7B-Instruct-yarn" or args.model_name == "Qwen2-72B-Instruct-yarn":
        print("===== qwen-yarn =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_factor{args.ntk_scale_factor}.jsonl")

        config.rope_scaling = {}
        config.rope_scaling["factor"] = args.ntk_scale_factor
        import math
        config.rope_scaling["attention_factor"] = 0.1 * math.log(args.ntk_scale_factor) + 1.0
        config.rope_scaling["rope_type"] = "yarn"

        from transformers import Qwen2ForCausalLM
        Qwen2ForCausalLM.forward = causal_forward
        model = Qwen2ForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
    elif args.model_name == "Mistral-7B-Instruct-v0.2-yarn" or args.model_name == "Mistral-7B-Instruct-v0.3-yarn":
        print("===== mistral-yarn =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_factor{args.ntk_scale_factor}.jsonl")

        config.rope_scaling = {}
        config.rope_scaling["factor"] = args.ntk_scale_factor
        import math
        config.rope_scaling["attention_factor"] = 0.1 * math.log(2) + 1.0
        config.rope_scaling["rope_type"] = "yarn"

        from models.mistral_yarn import MistralForCausalLM
        MistralForCausalLM.forward = causal_forward
        model = MistralForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
    

    ### chunkllama
    elif args.model_name == "llama3-8b-Instruct-chunkllama":
        print("===== llama-chunkllama =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_pl{args.pretraining_length}.jsonl")

        from models.chunkllama_attn_replace import replace_with_chunkllama
        replace_with_chunkllama(pretraining_length=args.pretraining_length)
        from transformers import LlamaForCausalLM
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
    elif args.model_name == "llama3.1-8b-Instruct-chunkllama" or args.model_name == "Llama-3.1-70B-Instruct-chunkllama":
        print("===== llama3.1-chunkllama =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_pl{args.pretraining_length}.jsonl")
        # from models.llama31_ntk.configuration_llama import LlamaConfig
        # config = LlamaConfig.from_pretrained(
        #     args.model_path,
        # )
        from models.chunkllama_attn_replace import replace_with_chunkllama3_1
        replace_with_chunkllama3_1(pretraining_length=args.pretraining_length)
        from transformers import LlamaForCausalLM
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
    elif args.model_name == "Qwen2.5-7B-Instruct-chunkllama" or args.model_name == "Qwen2-72B-Instruct-chunkllama":
        print("===== qwen-chunkllama =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_pl{args.pretraining_length}.jsonl")

        from models.chunkqwen_attn_replace import replace_with_chunkqwen
        replace_with_chunkqwen(pretraining_length=args.pretraining_length)
        from transformers import Qwen2ForCausalLM
        Qwen2ForCausalLM.forward = causal_forward
        model = Qwen2ForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
    elif args.model_name == "Mistral-7B-Instruct-v0.2-chunkllama" or args.model_name == "Mistral-7B-Instruct-v0.3-chunkllama":
        print("===== mistral-chunkllama =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_pl{args.pretraining_length}.jsonl")

        from models.chunkllama_attn_replace import replace_with_chunkmistral
        replace_with_chunkmistral(pretraining_length=args.pretraining_length)
        from transformers import MistralForCausalLM
        MistralForCausalLM.forward = causal_forward
        model = MistralForCausalLM.from_pretrained(
            args.model_path,
            config=config,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )

    ### rerope
    elif args.model_name == "llama3-8b-Instruct-rerope" or args.model_name == "llama3.1-8b-Instruct-rerope" or args.model_name == "Llama-3.1-70B-Instruct-rerope":
        print("===== llama-rerope =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}.jsonl")
        from models.llama_rerope import SelfExtend
        window_size = args.window_size
        group_size = args.group_size
        print(f'selfextend config: using group size {group_size}, window size {window_size}')
        use_flash = True
        from transformers import LlamaForCausalLM
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        print(f'using group size {group_size} using window size {window_size}')
        SelfExtend.apply(model, group_size, window_size, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"
    elif args.model_name == "Qwen2.5-7B-Instruct-rerope" or args.model_name == "Qwen2-72B-Instruct-rerope":
        print("===== qwen-rerope =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}.jsonl")
        from models.llama_rerope import SelfExtend
        window_size = args.window_size
        group_size = args.group_size
        print(f'selfextend config: using group size {group_size}, window size {window_size}')
        use_flash = True
        from transformers import Qwen2ForCausalLM
        Qwen2ForCausalLM.forward = causal_forward
        model = Qwen2ForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        print(f'using group size {group_size} using window size {window_size}')
        SelfExtend.apply(model, group_size, window_size, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"
    elif args.model_name == "Mistral-7B-Instruct-v0.2-rerope" or args.model_name == "Mistral-7B-Instruct-v0.3-rerope":
        print("===== mistral-rerope =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}.jsonl")
        from models.llama_rerope import SelfExtend
        window_size = args.window_size
        group_size = args.group_size
        print(f'selfextend config: using group size {group_size}, window size {window_size}')
        use_flash = True
        from transformers import MistralForCausalLM
        MistralForCausalLM.forward = causal_forward
        model = MistralForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        print(f'using group size {group_size} using window size {window_size}')
        SelfExtend.apply(model, group_size, window_size, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"
    
    ### selfextend
    elif args.model_name == "llama3-8b-Instruct-selfextend" or args.model_name == "llama3.1-8b-Instruct-selfextend" or args.model_name == "Llama-3.1-70B-Instruct-selfextend":
        print("===== llama-selfextend =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_g{args.group_size}.jsonl")
        from models.llama_selfextend import SelfExtend
        window_size = args.window_size
        group_size = args.group_size
        print(f'selfextend config: using group size {group_size}, window size {window_size}')
        use_flash = True
        from transformers import LlamaForCausalLM
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        print(f'using group size {group_size} using window size {window_size}')
        SelfExtend.apply(model, group_size, window_size, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"
    elif args.model_name == "Qwen2.5-7B-Instruct-selfextend" or args.model_name == "Qwen2-72B-Instruct-selfextend":
        print("===== qwen-selfextend =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_g{args.group_size}.jsonl")
        from models.llama_selfextend import SelfExtend
        window_size = args.window_size
        group_size = args.group_size
        print(f'selfextend config: using group size {group_size}, window size {window_size}')
        use_flash = True
        from transformers import Qwen2ForCausalLM
        Qwen2ForCausalLM.forward = causal_forward
        model = Qwen2ForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        print(f'using group size {group_size} using window size {window_size}')
        SelfExtend.apply(model, group_size, window_size, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"
    elif args.model_name == "Mistral-7B-Instruct-v0.2-selfextend" or args.model_name == "Mistral-7B-Instruct-v0.3-selfextend":
        print("===== mistral-selfextend =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_g{args.group_size}.jsonl")
        from models.llama_selfextend import SelfExtend
        window_size = args.window_size
        group_size = args.group_size
        print(f'selfextend config: using group size {group_size}, window size {window_size}')
        use_flash = True
        from transformers import MistralForCausalLM
        MistralForCausalLM.forward = causal_forward
        model = MistralForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        print(f'using group size {group_size} using window size {window_size}')
        SelfExtend.apply(model, group_size, window_size, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"
    
    ### DPE
    elif args.model_name == "llama3-8b-Instruct-DPE" or args.model_name == "llama3.1-8b-Instruct-DPE" or args.model_name == "Llama-3.1-70B-Instruct-DPE":
        print("===== llama-DPE =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_gl{args.group_size_low}_gm{args.group_size_mid}_gh{args.group_size_high}_low{args.low_dim}_high{args.high_dim}.jsonl")
        from models.llama_dim_selfextend import SelfExtend
        window_size = args.window_size
        group_size = {"low_dim": args.group_size_low, "mid_dim": args.group_size_mid, "high_dim": args.group_size_high}
        dim_range = {"low_dim": args.low_dim, "high_dim": args.high_dim}
        print(f'dimensional selfextend config: using group size {group_size}, dim_range{dim_range}, window size {window_size}')
        use_flash = True
        from transformers import LlamaForCausalLM
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        print(f'using group size {group_size} using window size {window_size}')
        SelfExtend.apply(model, group_size, window_size, dim_range, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"
    elif args.model_name == "Qwen2.5-7B-Instruct-DPE" or args.model_name == "Qwen2-72B-Instruct-DPE":
        print("===== qwen-DPE =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_gl{args.group_size_low}_gm{args.group_size_mid}_gh{args.group_size_high}_low{args.low_dim}_high{args.high_dim}.jsonl")
        from models.llama_dim_selfextend import SelfExtend
        window_size = args.window_size
        group_size = {"low_dim": args.group_size_low, "mid_dim": args.group_size_mid, "high_dim": args.group_size_high}
        dim_range = {"low_dim": args.low_dim, "high_dim": args.high_dim}
        print(f'dimensional selfextend config: using group size {group_size}, dim_range{dim_range}, window size {window_size}')
        use_flash = True
        from transformers import Qwen2ForCausalLM
        Qwen2ForCausalLM.forward = causal_forward
        model = Qwen2ForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        print(f'using group size {group_size} using window size {window_size}')
        SelfExtend.apply(model, group_size, window_size, dim_range, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"
    elif args.model_name == "Mistral-7B-Instruct-v0.2-DPE" or args.model_name == "Mistral-7B-Instruct-v0.3-DPE":
        print("===== mistral-DPE =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_gl{args.group_size_low}_gm{args.group_size_mid}_gh{args.group_size_high}_low{args.low_dim}_high{args.high_dim}.jsonl")
        from models.llama_dim_selfextend import SelfExtend
        window_size = args.window_size
        group_size = {"low_dim": args.group_size_low, "mid_dim": args.group_size_mid, "high_dim": args.group_size_high}
        dim_range = {"low_dim": args.low_dim, "high_dim": args.high_dim}
        print(f'dimensional selfextend config: using group size {group_size}, dim_range{dim_range}, window size {window_size}')
        use_flash = True
        from transformers import MistralForCausalLM
        MistralForCausalLM.forward = causal_forward
        model = MistralForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        print(f'using group size {group_size} using window size {window_size}')
        SelfExtend.apply(model, group_size, window_size, dim_range, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"
    
    ### DPE-by-dim
    elif args.model_name == "llama3-8b-Instruct-DPE-by-dim" or args.model_name == "llama3.1-8b-Instruct-DPE-by-dim" or args.model_name == "Llama-3.1-70B-Instruct-DPE-by-dim":
        print("===== llama-DPE-by-dim =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_g{args.group_sizes}.jsonl")
        from models.llama_dim_selfextend_by_dim import SelfExtend
        window_size = args.window_size
        group_sizes_list = [int(num) for num in args.group_sizes.split('-')]
        print(f'dimensional selfextend config: using group size {group_sizes_list}, dim_range{args.all_dims}, window size {window_size}')
        use_flash = True
        from transformers import LlamaForCausalLM
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        SelfExtend.apply(model, group_sizes_list, window_size, args.all_dims, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"    
    elif args.model_name == "Qwen2.5-7B-Instruct-DPE-by-dim" or args.model_name == "Qwen2-72B-Instruct-DPE-by-dim":
        print("===== qwen-DPE-by-dim =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_g{args.group_sizes}.jsonl")
        from models.llama_dim_selfextend_by_dim import SelfExtend
        window_size = args.window_size
        group_sizes_list = [int(num) for num in args.group_sizes.split('-')]
        print(f'dimensional selfextend config: using group size {group_sizes_list}, dim_range{args.all_dims}, window size {window_size}')
        use_flash = True
        from transformers import Qwen2ForCausalLM
        Qwen2ForCausalLM.forward = causal_forward
        model = Qwen2ForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        SelfExtend.apply(model, group_sizes_list, window_size, args.all_dims, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"    
    elif args.model_name == "Mistral-7B-Instruct-v0.2-DPE-by-dim" or args.model_name == "Mistral-7B-Instruct-v0.3-DPE-by-dim":
        print("===== mistral-DPE-by-dim =====")
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_g{args.group_sizes}.jsonl")
        from models.llama_dim_selfextend_by_dim import SelfExtend
        window_size = args.window_size
        group_sizes_list = [int(num) for num in args.group_sizes.split('-')]
        print(f'dimensional selfextend config: using group size {group_sizes_list}, dim_range{args.all_dims}, window size {window_size}')
        use_flash = True
        from transformers import MistralForCausalLM
        MistralForCausalLM.forward = causal_forward
        model = MistralForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        SelfExtend.apply(model, group_sizes_list, window_size, args.all_dims, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"    
    
    ### DPE-by-dim-topk
    elif args.model_name == "llama3-8b-Instruct-DPE-by-dim-topk":
        # Llama-3 8B with SelfExtend applied to the top-k scored dimensions.
        print("===== llama-DPE-by-dim-topk =====")
        # Output path also encodes k; main() opens this file in append mode.
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_topk{args.topk}_g{args.group_sizes}.jsonl")
        from models.llama_dim_selfextend_by_dim_topk_by_heads import SelfExtend
        window_size = args.window_size
        # args.group_sizes is a '-'-separated string of integers.
        group_sizes_list = [int(num) for num in args.group_sizes.split('-')]
        # Precomputed score tensor, loaded from a path relative to the working
        # directory. torch.load unpickles the file, so it must be trusted.
        selected_dim = torch.load(f"models/llama_DPE_topk/qk_2_norm_selected_dim.pt")
        # Average over the two leading axes, then keep the indices ([1]) of the
        # k largest scores along the last axis and move them to the GPU.
        # NOTE(review): presumably this leaves a per-head table of dimension
        # scores -- confirm against the layout of the saved tensor.
        selected_dim = (torch.topk(selected_dim.mean(dim=0).mean(dim=0), dim=-1, k=args.topk)[1]).to('cuda')
        print(f'dimensional selfextend config: using group size {group_sizes_list}, dim_range{args.all_dims}, window size {window_size}')
        use_flash = True
        from transformers import LlamaForCausalLM
        # Replace the class-level forward before the model is instantiated.
        # NOTE(review): causal_forward is not imported inside this branch; it
        # must already be bound earlier in the file -- confirm.
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        SelfExtend.apply(model, group_sizes_list, window_size, args.all_dims, selected_dim, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"    
    elif args.model_name == "Mistral-7B-Instruct-v0.2-DPE-by-dim-topk":
        # Mistral v0.2 with SelfExtend applied to the top-k scored dimensions.
        print("===== mistral-DPE-by-dim-topk =====")
        # Output path also encodes k; main() opens this file in append mode.
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_topk{args.topk}_g{args.group_sizes}.jsonl")
        # The llama_* SelfExtend module is reused here for the Mistral model.
        from models.llama_dim_selfextend_by_dim_topk_by_heads import SelfExtend
        window_size = args.window_size
        # args.group_sizes is a '-'-separated string of integers.
        group_sizes_list = [int(num) for num in args.group_sizes.split('-')]
        # Mistral-specific precomputed score tensor (path is relative to the
        # working directory). torch.load unpickles the file, so it must be trusted.
        selected_dim = torch.load(f"models/llama_DPE_topk/Mistral-7B-Instruct-v0.2_qk_2_norm_selected_dim.pt")
        # Average over the two leading axes, then keep the indices ([1]) of the
        # k largest scores along the last axis and move them to the GPU.
        # NOTE(review): axis meaning of the saved tensor is not visible here -- confirm.
        selected_dim = (torch.topk(selected_dim.mean(dim=0).mean(dim=0), dim=-1, k=args.topk)[1]).to('cuda')
        print(f'dimensional selfextend config: using group size {group_sizes_list}, dim_range{args.all_dims}, window size {window_size}')
        use_flash = True
        from transformers import MistralForCausalLM
        # Replace the class-level forward before the model is instantiated.
        # NOTE(review): causal_forward is not imported inside this branch; it
        # must already be bound earlier in the file -- confirm.
        MistralForCausalLM.forward = causal_forward
        model = MistralForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        SelfExtend.apply(model, group_sizes_list, window_size, args.all_dims, selected_dim, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"    
    elif args.model_name == "Qwen2.5-7B-Instruct-DPE-by-dim-topk":
        # Qwen2.5 7B with SelfExtend applied to the top-k scored dimensions.
        print("===== qwen-DPE-by-dim-topk =====")
        # Output path also encodes k; main() opens this file in append mode.
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_topk{args.topk}_g{args.group_sizes}.jsonl")
        # The llama_* SelfExtend module is reused here for the Qwen2 model.
        from models.llama_dim_selfextend_by_dim_topk_by_heads import SelfExtend
        window_size = args.window_size
        # args.group_sizes is a '-'-separated string of integers.
        group_sizes_list = [int(num) for num in args.group_sizes.split('-')]
        # Qwen-specific precomputed score tensor (path is relative to the
        # working directory). torch.load unpickles the file, so it must be trusted.
        selected_dim = torch.load(f"models/llama_DPE_topk/Qwen2.5-7B-Instruct_qk_2_norm_selected_dim.pt")
        # Average over the two leading axes, then keep the indices ([1]) of the
        # k largest scores along the last axis and move them to the GPU.
        # NOTE(review): axis meaning of the saved tensor is not visible here -- confirm.
        selected_dim = (torch.topk(selected_dim.mean(dim=0).mean(dim=0), dim=-1, k=args.topk)[1]).to('cuda')
        print(f'dimensional selfextend config: using group size {group_sizes_list}, dim_range{args.all_dims}, window size {window_size}')
        use_flash = True
        from transformers import Qwen2ForCausalLM
        # Replace the class-level forward before the model is instantiated.
        # NOTE(review): causal_forward is not imported inside this branch; it
        # must already be bound earlier in the file -- confirm.
        Qwen2ForCausalLM.forward = causal_forward
        model = Qwen2ForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        SelfExtend.apply(model, group_sizes_list, window_size, args.all_dims, selected_dim, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"    
    
    # DPE-by-dim-lowk
    elif args.model_name == "llama3-8b-Instruct-DPE-by-dim-lowk":
        # Llama-3 8B with SelfExtend applied to the k LOWEST-scored dimensions.
        print("===== llama-DPE-by-dim-lowk =====")
        # NOTE(review): the file name uses the same "topk{k}" pattern as the
        # by-dim-topk branch. main() appends to this file, so unless
        # pred_save_path already differs per model name, low-k and top-k runs
        # would write into the same file -- confirm.
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_topk{args.topk}_g{args.group_sizes}.jsonl")
        from models.llama_dim_selfextend_by_dim_topk_by_heads import SelfExtend
        window_size = args.window_size
        # args.group_sizes is a '-'-separated string of integers.
        group_sizes_list = [int(num) for num in args.group_sizes.split('-')]
        # Precomputed score tensor, loaded from a path relative to the working
        # directory. torch.load unpickles the file, so it must be trusted.
        selected_dim = torch.load(f"models/llama_DPE_topk/qk_2_norm_selected_dim.pt")
        # Scores are negated after averaging over the two leading axes, so
        # topk returns the indices ([1]) of the k smallest scores on the last axis.
        selected_dim = (torch.topk(-selected_dim.mean(dim=0).mean(dim=0), dim=-1, k=args.topk)[1]).to('cuda')
        print(f'dimensional selfextend config: using group size {group_sizes_list}, dim_range{args.all_dims}, window size {window_size}')
        use_flash = True
        from transformers import LlamaForCausalLM
        # Replace the class-level forward before the model is instantiated.
        # NOTE(review): causal_forward is not imported inside this branch; it
        # must already be bound earlier in the file -- confirm.
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        SelfExtend.apply(model, group_sizes_list, window_size, args.all_dims, selected_dim, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"    
    elif args.model_name == "Mistral-7B-Instruct-v0.2-DPE-by-dim-lowk":
        # Mistral v0.2 with SelfExtend applied to the k LOWEST-scored dimensions.
        print("===== mistral0.2-DPE-by-dim-lowk =====")
        # NOTE(review): the file name uses the same "topk{k}" pattern as the
        # by-dim-topk branch. main() appends to this file, so unless
        # pred_save_path already differs per model name, low-k and top-k runs
        # would write into the same file -- confirm.
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_topk{args.topk}_g{args.group_sizes}.jsonl")
        # The llama_* SelfExtend module is reused here for the Mistral model.
        from models.llama_dim_selfextend_by_dim_topk_by_heads import SelfExtend
        window_size = args.window_size
        # args.group_sizes is a '-'-separated string of integers.
        group_sizes_list = [int(num) for num in args.group_sizes.split('-')]
        # Mistral-specific precomputed score tensor (path is relative to the
        # working directory). torch.load unpickles the file, so it must be trusted.
        selected_dim = torch.load(f"models/llama_DPE_topk/Mistral-7B-Instruct-v0.2_qk_2_norm_selected_dim.pt")
        # Scores are negated after averaging over the two leading axes, so
        # topk returns the indices ([1]) of the k smallest scores on the last axis.
        selected_dim = (torch.topk(-selected_dim.mean(dim=0).mean(dim=0), dim=-1, k=args.topk)[1]).to('cuda')
        print(f'dimensional selfextend config: using group size {group_sizes_list}, dim_range{args.all_dims}, window size {window_size}')
        use_flash = True
        from transformers import MistralForCausalLM
        # Replace the class-level forward before the model is instantiated.
        # NOTE(review): causal_forward is not imported inside this branch; it
        # must already be bound earlier in the file -- confirm.
        MistralForCausalLM.forward = causal_forward
        model = MistralForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        SelfExtend.apply(model, group_sizes_list, window_size, args.all_dims, selected_dim, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"    
    elif args.model_name == "Qwen2.5-7B-Instruct-DPE-by-dim-lowk":
        # Qwen2.5 7B with SelfExtend applied to the k LOWEST-scored dimensions.
        print("===== qwen-DPE-by-dim-lowk =====")
        # NOTE(review): the file name uses the same "topk{k}" pattern as the
        # by-dim-topk branch. main() appends to this file, so unless
        # pred_save_path already differs per model name, low-k and top-k runs
        # would write into the same file -- confirm.
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_topk{args.topk}_g{args.group_sizes}.jsonl")
        # The llama_* SelfExtend module is reused here for the Qwen2 model.
        from models.llama_dim_selfextend_by_dim_topk_by_heads import SelfExtend
        window_size = args.window_size
        # args.group_sizes is a '-'-separated string of integers.
        group_sizes_list = [int(num) for num in args.group_sizes.split('-')]
        # Qwen-specific precomputed score tensor (path is relative to the
        # working directory). torch.load unpickles the file, so it must be trusted.
        selected_dim = torch.load(f"models/llama_DPE_topk/Qwen2.5-7B-Instruct_qk_2_norm_selected_dim.pt")
        # Scores are negated after averaging over the two leading axes, so
        # topk returns the indices ([1]) of the k smallest scores on the last axis.
        selected_dim = (torch.topk(-selected_dim.mean(dim=0).mean(dim=0), dim=-1, k=args.topk)[1]).to('cuda')
        print(f'dimensional selfextend config: using group size {group_sizes_list}, dim_range{args.all_dims}, window size {window_size}')
        use_flash = True
        from transformers import Qwen2ForCausalLM
        # Replace the class-level forward before the model is instantiated.
        # NOTE(review): causal_forward is not imported inside this branch; it
        # must already be bound earlier in the file -- confirm.
        Qwen2ForCausalLM.forward = causal_forward
        model = Qwen2ForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        SelfExtend.apply(model, group_sizes_list, window_size, args.all_dims, selected_dim, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"    
    
    ### topk
    elif args.model_name == "llama3.1-8b-Instruct-DPE-topk":
        # Llama-3.1 8B with the low/mid/high/all group-size SelfExtend variant
        # plus a single top-k dimension selection.
        print(f"testing {args.model_name}")
        topk = args.topk
        # Output path encodes every hyper-parameter of this variant; main()
        # opens this file in append mode.
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_gl{args.group_size_low}_gm{args.group_size_mid}_gh{args.group_size_high}_ga{args.group_size_all}_low{args.low_dim}_high{args.high_dim}_top{topk}.jsonl")
        config = transformers.AutoConfig.from_pretrained(
            args.model_path,
        )
        # Rebinds the module-level `tokenizer` that main() uses for encoding/decoding.
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            args.model_path,
        )
        # Diagnostic only: show the RoPE scaling settings stored in the checkpoint config.
        print(config.rope_scaling)
        from models.llama_dim_selfextend_topk import SelfExtend
        window_size = args.window_size
        # Group sizes keyed by dimension band, and the band boundaries.
        group_size = {"low_dim": args.group_size_low, "mid_dim": args.group_size_mid, "high_dim": args.group_size_high, "all_dim":args.group_size_all}
        dim_range = {"low_dim": args.low_dim, "high_dim": args.high_dim}
        # print(f'dimensional selfextend config: using group size {group_size}, dim_range{dim_range}, window size {window_size}')
        # selected_dim = torch.load("models/llama_DPE_topk/total_frequency_statistics_results.pt")
        # selected_dim = (torch.topk(selected_dim['total_all_mean_freq'], topk, dim=-1)[1]).to('cuda')
        # NOTE(review): this llama3.1 branch loads the un-prefixed statistics
        # file, whereas the llama3.1 "-topk-by-heads" branch loads
        # "llama3.1_qk_2_norm_selected_dim.pt" -- confirm this is the intended file.
        # torch.load unpickles the file, so it must be trusted.
        selected_dim = torch.load(f"models/llama_DPE_topk/qk_2_norm_selected_dim.pt")
        # Four successive mean(dim=0) calls average away the four leading axes
        # before topk picks the indices ([1]) of the k largest scores.
        # NOTE(review): presumably this yields one shared ranking over head
        # dimensions -- confirm against the layout of the saved tensor.
        selected_dim = (torch.topk(selected_dim.mean(dim=0).mean(dim=0).mean(dim=0).mean(dim=0), dim=-1, k=args.topk)[1]).to('cuda')
        use_flash = True
        from models.string_for_llama import causal_forward
        from transformers import LlamaForCausalLM
        # Replace the class-level forward before the model is instantiated.
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        print(f'using group size {group_size} using window size {window_size}')
        SelfExtend.apply(model, group_size, window_size, dim_range, selected_dim, enable_flash_attention=use_flash, flash_attention_impl="flash_attn")
    ## DPE-topk-by-heads
    elif args.model_name == "llama3-8b-Instruct-DPE-topk-by-heads":
        # Llama-3 8B with the low/mid/high/all group-size SelfExtend variant
        # and top-k dimension selection from the "by_heads" module.
        print(f"testing {args.model_name}")
        topk = args.topk
        # Output path encodes every hyper-parameter of this variant; main()
        # opens this file in append mode.
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_gl{args.group_size_low}_gm{args.group_size_mid}_gh{args.group_size_high}_ga{args.group_size_all}_low{args.low_dim}_high{args.high_dim}_top{topk}.jsonl")
        from models.llama_dim_selfextend_topk_by_heads import SelfExtend
        window_size = args.window_size
        # Group sizes keyed by dimension band, and the band boundaries.
        group_size = {"low_dim": args.group_size_low, "mid_dim": args.group_size_mid, "high_dim": args.group_size_high, "all_dim":args.group_size_all}
        dim_range = {"low_dim": args.low_dim, "high_dim": args.high_dim}
        print(f'dimensional selfextend config: using group size {group_size}, dim_range{dim_range}, window size {window_size}')
        # Precomputed score tensor, loaded from a path relative to the working
        # directory. torch.load unpickles the file, so it must be trusted.
        selected_dim = torch.load(f"models/llama_DPE_topk/qk_2_norm_selected_dim.pt")
        # Average over the two leading axes, then keep the indices ([1]) of the
        # k largest scores along the last axis and move them to the GPU.
        # NOTE(review): axis meaning of the saved tensor is not visible here -- confirm.
        selected_dim = (torch.topk(selected_dim.mean(dim=0).mean(dim=0), dim=-1, k=args.topk)[1]).to('cuda')
        use_flash = True
        from transformers import LlamaForCausalLM
        # Replace the class-level forward before the model is instantiated.
        # NOTE(review): causal_forward is not imported inside this branch; it
        # must already be bound earlier in the file -- confirm.
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        print(f'using group size {group_size} using window size {window_size}')
        SelfExtend.apply(model, group_size, window_size, dim_range, selected_dim, enable_flash_attention=use_flash, flash_attention_impl="flash_attn")
    elif args.model_name == "llama3.1-8b-Instruct-DPE-topk-by-heads":
        # Llama-3.1 8B with the low/mid/high/all group-size SelfExtend variant
        # and top-k dimension selection from the "by_heads" module.
        print('-----llama3.1-8b-Instruct-DPE-topk-by-heads-----')
        topk = args.topk
        # Output path encodes every hyper-parameter of this variant; main()
        # opens this file in append mode.
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_gl{args.group_size_low}_gm{args.group_size_mid}_gh{args.group_size_high}_ga{args.group_size_all}_low{args.low_dim}_high{args.high_dim}_top{topk}.jsonl")
        from models.llama_dim_selfextend_topk_by_heads import SelfExtend
        window_size = args.window_size
        # Group sizes keyed by dimension band, and the band boundaries.
        group_size = {"low_dim": args.group_size_low, "mid_dim": args.group_size_mid, "high_dim": args.group_size_high, "all_dim":args.group_size_all}
        dim_range = {"low_dim": args.low_dim, "high_dim": args.high_dim}
        print(f'dimensional selfextend config: using group size {group_size}, dim_range{dim_range}, window size {window_size}')
        # Llama-3.1-specific precomputed score tensor (path is relative to the
        # working directory). torch.load unpickles the file, so it must be trusted.
        selected_dim = torch.load(f"models/llama_DPE_topk/llama3.1_qk_2_norm_selected_dim.pt")
        # Average over the two leading axes, then keep the indices ([1]) of the
        # k largest scores along the last axis and move them to the GPU.
        # NOTE(review): axis meaning of the saved tensor is not visible here -- confirm.
        selected_dim = (torch.topk(selected_dim.mean(dim=0).mean(dim=0), dim=-1, k=args.topk)[1]).to('cuda')
        use_flash = True
        from transformers import LlamaForCausalLM
        # Replace the class-level forward before the model is instantiated.
        # NOTE(review): causal_forward is not imported inside this branch; it
        # must already be bound earlier in the file -- confirm.
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        print(f'using group size {group_size} using window size {window_size}')
        SelfExtend.apply(model, group_size, window_size, dim_range, selected_dim, enable_flash_attention=use_flash, flash_attention_impl="flash_attn")
    
    
    elif args.model_name == "llama3-8b-Instruct-dimrerope":
        # Llama-3 8B with the "dimrerope" variant (models.llama_dim_rerope_all).
        # NOTE(review): the file name encodes the detect lengths and dim range
        # but not the group sizes that are also passed to SelfExtend.apply.
        # main() appends to this file, so runs differing only in group size
        # would share it -- confirm this is intended.
        file_name = os.path.join(pred_save_path, f"seqlen{args.seq_len}_w{args.window_size}_dl{args.detect_low}_dm{args.detect_mid}_dh{args.detect_high}_low{args.low_dim}_high{args.high_dim}.jsonl")
        # `config` is not read again inside this branch (the print below is commented out).
        config = transformers.AutoConfig.from_pretrained(
            args.model_path,
        )
        # Rebinds the module-level `tokenizer` that main() uses for encoding/decoding.
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            args.model_path,
        )
        # print(config.rope_scaling)
        from models.llama_dim_rerope_all import SelfExtend
        window_size = args.window_size
        # Per-band settings, each keyed by "low_dim" / "mid_dim" / "high_dim".
        detect_length = {"low_dim": args.detect_low,"mid_dim": args.detect_mid,"high_dim": args.detect_high}
        group_size = {"low_dim": args.group_size_low, "mid_dim": args.group_size_mid, "high_dim": args.group_size_high}
        # Boundaries separating the low / mid / high dimension bands.
        dim_range = {"low_dim": args.low_dim, "high_dim": args.high_dim}
        print(f'dimensional selfextend config: using group size {group_size}, dim_range{dim_range}, window size {window_size}')
        use_flash = True
        from models.string_for_llama import causal_forward
        from transformers import LlamaForCausalLM
        # Replace the class-level forward before the model is instantiated.
        LlamaForCausalLM.forward = causal_forward
        model = LlamaForCausalLM.from_pretrained(args.model_path, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
        print(f'using group size {group_size} using window size {window_size}')
        SelfExtend.apply(model, group_size, window_size, detect_length, dim_range, enable_flash_attention=use_flash, flash_attention_impl="flash_attn") ## flash_attention_impl="triton" or "flash_attn"

    
    else:
        # No branch of the dispatch chain matched: fail fast instead of
        # continuing with an undefined `model`.
        raise ValueError(f"Invalid model name: {args.model_name}")

    # Put the selected model in evaluation mode before generation.
    model = model.eval()
    # Run the evaluation loop. main() stops after `infer_size` samples unless
    # it is -1. Whatever main() returns becomes the process exit status
    # (a None return maps to exit code 0).
    sys.exit(main(args.infer_size))


