import json
from pathlib import Path
import time
from typing import List, Tuple, Any

import torch
from torch import Tensor
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPast
from datasets import load_dataset
# import tqdm
from tqdm import tqdm


from eval_utils import (
    dump_jsonl,
    create_prompt,
    get_answer,
    load_data,
    DATA_NAME_TO_MAX_NEW_TOKENS,
)

from argparse import ArgumentParser, Namespace
from eval_utils import DATA_NAME_TO_MAX_NEW_TOKENS
from model_loader import *


# Token budget handed to the tokenizer as `max_length` in `get_pred`
# (prompts longer than this are truncated by the tokenizer).
MAX_POSITION_ID = 128 * 1024  # Determined by the model
# Token budget intended for `truncate_by_tokens`; not referenced by any live
# code in this file.
TRUNCATE_LEN = 128 * 1024


import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"



def parse_args():
    """
    Build the command-line parser for this script.

    Note that despite its name this function does NOT parse anything: it
    returns the configured `ArgumentParser` so that the caller can register
    further options on it before calling `.parse_args()`.

    Returns:
        ArgumentParser: parser with the options `--task` (required),
        `--data_dir`, `--output_dir`, `--model_path`, `--model_name`,
        `--start_idx`, `--stop_idx`, `--verbose` and `--device`.
    """
    parser = ArgumentParser()

    # (flag, keyword arguments) pairs, registered in this exact order so the
    # generated help text keeps its layout.
    option_table = [
        (
            "--task",
            {
                "type": str,
                # choices=list(DATA_NAME_TO_MAX_NEW_TOKENS.keys()) + ["all"],
                "required": True,
                "help": "Which task to use. Note that \"all\" can only be used in `compute_scores.py`.",  # noqa
            },
        ),
        (
            "--data_dir",
            {
                "type": str,
                "default": '../data',
                "help": "The directory of data.",
            },
        ),
        (
            "--output_dir",
            {
                "type": str,
                "default": "../results",
                "help": "Where to dump the prediction results.",
            },
        ),
        (
            "--model_path",
            {
                "type": str,
                "help": "The path of the model (in HuggingFace (HF) style). If specified, it will try to load the model from the specified path, else, it wll default to the official HF path.",  # noqa
            },
        ),
        (
            "--model_name",
            {
                "type": str,
                "choices": [
                    "gpt4",
                    "yarn-mistral",
                    "kimi",
                    "claude2",
                    "rwkv",
                    "yi-6b-200k",
                    "yi-34b-200k",
                    "chatglm3",
                    "llama3.1-8b-ins",
                ],
                "default": "gpt4",
                "help": "For `compute_scores.py` only, specify which model you want to compute the score for.",  # noqa
            },
        ),
        (
            "--start_idx",
            {
                "type": int,
                "default": 0,
                "help": "The index of the first example to infer on. This is used if you want to evaluate on a (contiguous) subset of the data.",  # noqa
            },
        ),
        (
            "--stop_idx",
            {
                "type": int,
                "help": "The index of the last example to infer on. This is used if you want to evaluate on a (contiguous) subset of the data. Defaults to the length of dataset.",  # noqa
            },
        ),
        ("--verbose", {"action": 'store_true'}),
        ("--device", {"type": str, "default": "cuda"}),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser


def truncate_input(input: list, max_length: int, manner="middle"):
    """
    Shorten a sequence so that it holds at most `max_length` items.

    Args:
        input: Sequence to shorten (a list of token ids in this file). Any
            sequence that supports `len`, slicing and `+` works.
        max_length: Upper bound on the length of the result. Negative values
            are treated as 0.
        manner: Truncation strategy. Only "middle" is supported: keep
            `max_length // 2` items from the head and the same number from
            the tail, dropping everything in between.

    Returns:
        `input` itself (not a copy) when it already fits. Otherwise a new
        sequence of `2 * (max_length // 2)` items, i.e. one item fewer than
        `max_length` when `max_length` is odd.

    Raises:
        ValueError: If truncation is required and `manner` is not "middle".
    """
    total = len(input)
    if total <= max_length:
        return input
    if manner != "middle":
        # Fail loudly instead of returning None, which only surfaced later as
        # an opaque TypeError in the caller.
        raise ValueError(f"Unsupported truncation manner: {manner!r}")

    keep = max(max_length, 0) // 2
    # The tail is sliced with an explicit start index on purpose:
    # `input[-keep:]` evaluates to `input[0:]` (the WHOLE sequence) when
    # `keep == 0`, which would return the input untruncated.
    return input[:keep] + input[total - keep:]


def truncate_by_tokens(input, tok, max_tokens, manner: str = "middle"):
    """
    Truncate a text to at most `max_tokens` tokens and return it as text.

    The text is encoded with `tok.encode`, the resulting token sequence is
    shortened by `truncate_input`, and what remains is decoded again with
    special tokens skipped.

    Args:
        input: Text to truncate.
        tok: Tokenizer providing `encode` and `decode`.
        max_tokens: Maximum number of tokens to keep.
        manner: Truncation strategy forwarded to `truncate_input`.

    Returns:
        The decoded, possibly truncated text.
    """
    token_ids = tok.encode(input)
    count_before = len(token_ids)

    kept_ids = truncate_input(token_ids, max_length=max_tokens, manner=manner)
    count_after = len(kept_ids)  # type: ignore

    # Sanity checks: truncation may never grow the sequence nor exceed the
    # requested budget.
    assert count_after <= count_before
    assert count_after <= max_tokens
    return tok.decode(kept_ids, skip_special_tokens=True)


def chunk_generate(
    model,
    tok,
    texts: List[str],
    max_tokens: int,
    sliding_window: int = 128 * 1024,
    chunk_size: int = 2500,
    verbose: bool = False,
) -> List[str]:
    """
    Generate continuations for `texts`, pre-filling the KV cache in chunks.

    Directly performing inference using HF transformers will result in OOM
    when using one A100 GPU. This is because the attention matrix is too large,
    so we chunk the input up and perform forward pass on each chunk to build
    up the KV cache. Note that each token still has to attend to
    all tokens in the past.

    Args:
        model: Model object; this function uses `model.device`,
            `model.model.forward(...)` and `model.generate(...)`.
        tok: Tokenizer; called on `texts` with padding, and used to decode
            the generated ids.
        texts: Prompts, tokenized together into one padded batch.
        max_tokens: Passed to `model.generate` as `max_new_tokens`.
        sliding_window: After every chunk, only the last `sliding_window`
            positions of each layer's keys/values are kept.
        chunk_size: Step between chunk boundaries during pre-fill. The
            boundaries are laid out backwards from the end, so the FIRST
            chunk is the one that may be shorter.
        verbose: If True, print each chunk's bounds and the elapsed seconds.

    Returns:
        One decoded string per row returned by `model.generate`, with the
        first token of each row dropped before decoding.

    NOTE(review): when the tokenized batch has a single position
    (`seq_len == 1`) no chunk is processed, `kv_cache` stays None and
    `kv_cache[0][0]` after the loop raises.
    NOTE(review): the cache is indexed as `kv_cache[layer][0 or 1]` with the
    sequence axis at dim 2 -- presumably the legacy tuple-of-tuples cache
    format; confirm against the installed transformers version.
    """
    with torch.no_grad():
        """
        input_ids: (b, n)
        attention_mask: (b, n)
        [
            [0, 0, .., 0, 1, 1, ..., 1]
            ...
        ]
        """
        inputs = tok(texts, return_tensors="pt", padding=True)
        inputs = inputs.to(model.device)  # type: ignore
        input_ids: Tensor = inputs.input_ids  # (b, n)
        attention_mask: Tensor = inputs.attention_mask  # (b, n)
        # Running count of attended tokens minus one, so the first attended
        # token of each row gets position 0 ...
        position_ids: Tensor = attention_mask.long().cumsum(dim=-1) - 1
        # ... and masked-out (padding) slots are overwritten with position 1.
        position_ids.masked_fill_(attention_mask == 0, value=1)
        seq_len = input_ids.shape[-1]
        #print("seq_len:", seq_len)
        kv_cache: Any = None
        # Split into chunks for pre-filling
        # Boundaries are collected from `seq_len - 1` downwards in steps of
        # `chunk_size`, then reversed; the last token is deliberately left
        # out of pre-fill and fed to `model.generate` below.
        chunk_idxs = []
        n = seq_len - 1
        while n > 0:
            chunk_idxs.append(n)
            n -= chunk_size
        chunk_idxs.append(0)
        chunk_idxs = chunk_idxs[::-1]
        chunk_lo = chunk_idxs[:-1]
        chunk_hi = chunk_idxs[1:]
        #print(f"Number of chunks: {len(chunk_lo)}, generating...")
        start_time = time.time()
        # NOTE: the loop targets reuse the names `chunk_lo` / `chunk_hi`.
        # This works because `zip(...)` captures the two lists before the
        # names are rebound to the per-chunk integer bounds.
        for chunk_i, (chunk_lo, chunk_hi) in enumerate(
            zip(chunk_lo, chunk_hi)
        ):
            if verbose:
                print(
                    f"[chunk {chunk_i}] {chunk_lo} : {chunk_hi}",
                    round(time.time() - start_time),
                )
            chunk_input_ids = input_ids[:, chunk_lo:chunk_hi]
            # The attention mask has to span the cached positions plus the
            # current chunk, so it starts `cache length` columns before
            # `chunk_lo`.
            if kv_cache is not None:
                mask_start_idx = chunk_lo - kv_cache[0][0].shape[2]
            else:
                mask_start_idx = chunk_lo
            chunk_attention_mask = attention_mask[:, mask_start_idx:chunk_hi]
            chunk_position_ids = position_ids[:, chunk_lo:chunk_hi]
            # Calls the inner `model.model` directly (not `model` itself);
            # presumably this skips the LM head during pre-fill -- TODO confirm.
            outputs: BaseModelOutputWithPast = model.model.forward(
                input_ids=chunk_input_ids,
                attention_mask=chunk_attention_mask,
                position_ids=chunk_position_ids,
                past_key_values=kv_cache,
                return_dict=True,
                use_cache=True,
            )
            kv_cache = outputs.past_key_values
            # Discard KV states on the left beyond the window
            new_cache = ()
            n_layers = len(kv_cache)
            for layer_i in range(n_layers):
                keys = kv_cache[layer_i][0][:, :, -sliding_window:]
                values = kv_cache[layer_i][1][:, :, -sliding_window:]
                new_cache += ((keys, values),)
            kv_cache = new_cache
        kv_cache_len = kv_cache[0][0].shape[2]
        # Generation starts from the single last prompt token on top of the
        # pre-filled cache; the mask covers the cached positions plus that
        # token.
        # NOTE(review): `eos_token_id` is set to the tokenizer's PAD id here
        # -- confirm this is the intended stop token.
        outputs = model.generate(
            input_ids=input_ids[:, -1:],
            attention_mask=attention_mask[:, -kv_cache_len - 1 :],
            max_new_tokens=max_tokens,
            past_key_values=kv_cache,
            eos_token_id=tok.pad_token_id,
            use_cache=True,
        )
        # `t[1:]` drops the first id of every returned row -- presumably the
        # echoed last prompt token passed as `input_ids` above.
        responses = [
            tok.decode(t[1:], skip_special_tokens=True) for t in outputs
        ]
    return responses


def get_pred(
    model,
    tok: AutoTokenizer,
    input_text: str,
    max_tokens: int,
    verbose: bool = False,
) -> str:
    """
    Run greedy generation on one prompt and return the stripped completion.

    The prompt is tokenized with tokenizer-side truncation to
    `MAX_POSITION_ID` tokens, then `model.generate` is called without
    sampling and with a single beam. Generation stops on the tokenizer's EOS
    id or on the last token id that encoding "\\n" produces.

    NOTE: neither `truncate_by_tokens` nor `chunk_generate` is used here; the
    tokenizer's own truncation and a plain `generate` call are used instead.

    Args:
        model: Model providing `device` and `generate`.
        tok: Tokenizer used for encoding the prompt and decoding the output.
        input_text: Full prompt text.
        max_tokens: Passed to `model.generate` as `max_new_tokens`.
        verbose: If True, print the prompt length and its first/last 200
            characters.

    Returns:
        The decoded newly generated tokens (prompt excluded), with special
        tokens skipped and surrounding whitespace stripped.
    """
    if verbose:
        print("# chars:", len(input_text))
        preview_lines = (
            "=============== Input ===============",
            input_text[:200],
            "...",
            input_text[-200:],
            "=====================================",
        )
        for preview_line in preview_lines:
            print(preview_line)

    encoded = tok(
        input_text,
        truncation=True,
        padding=True,
        max_length=MAX_POSITION_ID,
        return_tensors="pt",
    )
    prompt_ids = encoded['input_ids'].to(model.device)
    prompt_mask = encoded["attention_mask"].to(model.device)

    # Stop on EOS or on a newline token.
    stop_ids = [tok.eos_token_id, tok.encode("\n", add_special_tokens=False)[-1]]
    generated = model.generate(
        prompt_ids,
        attention_mask=prompt_mask,
        output_attentions=False,
        max_new_tokens=max_tokens,
        num_beams=1,
        do_sample=False,
        eos_token_id=stop_ids,
        pad_token_id=tok.eos_token_id,
    )

    # Keep only what was generated after the prompt.
    prompt_len = prompt_ids.shape[1]
    completion_ids = generated[0][prompt_len:]
    return tok.decode(completion_ids, skip_special_tokens=True).strip()


def load_model(model_path, args=None):
    """
    Load the tokenizer and the patched model for `model_path`.

    Args:
        model_path: Forwarded to `AutoTokenizer.from_pretrained` and to
            `load_model_and_apply_patches`.
        args: Options object forwarded to `load_model_and_apply_patches`.
            When omitted, the module-level `args` (set by the `__main__`
            block of this script) is used, which keeps existing
            `load_model(path)` calls working.

    Returns:
        A `(model, tok)` tuple.

    Raises:
        NameError: If `args` is not given and no module-level `args` exists
            (e.g. when this module is imported rather than run as a script).
    """
    if args is None:
        # Previously this function silently depended on a global that only
        # exists when the file is executed as a script.
        try:
            args = globals()["args"]
        except KeyError:
            raise NameError(
                "load_model() requires `args`: pass it explicitly or define "
                "a module-level `args` before calling"
            ) from None

    tok = AutoTokenizer.from_pretrained(model_path)
    # Reuse EOS as the padding token so that `padding=True` works.
    tok.pad_token = tok.eos_token
    model = load_model_and_apply_patches(model_path, args)
    return model, tok  # type: ignore


def download_infity_benchmark():
    """
    Load the InfiniteBench dataset through `datasets.load_dataset`.

    An explicit feature schema (`id`, `context`, `input`, `answer`,
    `options`) is supplied to the loader. The loaded dataset is not
    returned; the function returns None.
    """
    from datasets import load_dataset, Value, Sequence, Features

    schema = {
        "id": Value("int64"),
        "context": Value("string"),
        "input": Value("string"),
        "answer": Sequence(Value("string")),
        "options": Sequence(Value("string")),
    }
    features = Features(schema)

    print("Downloading InfiniteBench...")
    load_dataset("xinrongzhang2022/InfiniteBench", features=features)

    


if __name__ == "__main__":
    # Script entry point: run one task over a (sub)range of its examples and
    # dump the predictions as JSONL.
    #download_infity_benchmark()

    # `parse_args()` returns the parser itself; `add_args` registers further
    # options on it (e.g. the `yarn` flag read below -- presumably defined in
    # `model_loader`) before the command line is actually parsed.
    args = add_args(parse_args()).parse_args()
    model_name = args.model_name

    print(json.dumps(vars(args), indent=4))
    data_name = args.task

    # Model
    max_tokens = DATA_NAME_TO_MAX_NEW_TOKENS[data_name]
    model, tok = load_model(args.model_path)

    # Data
    # NOTE(review): `--output_dir` is parsed but not used; results always go
    # to ./results_infity/<model_name>. Left unchanged so existing result
    # locations stay stable.
    result_dir = Path(f"./results_infity/{model_name}")
    result_dir.mkdir(exist_ok=True, parents=True)
    examples = load_data(data_name, data_dir=args.data_dir)

    extension = "yarn" if args.yarn else "mrrope"

    # The index range is part of the file name only when a stop index was
    # given explicitly.
    if args.stop_idx is None:
        args.stop_idx = len(examples)
        output_path = (result_dir / f"{extension}_{data_name}.jsonl")
    else:
        output_path = (result_dir / f"{extension}_{data_name}_{args.start_idx}-{args.stop_idx}.jsonl")

    preds = []
    print("==== Evaluation ====")
    print(f"# examples: {len(examples)}")
    print(f"Start index: {args.start_idx}")
    print(f"Stop index: {args.stop_idx}")
    print(f"Verbose: {args.verbose}")
    print(f"Max tokens: {max_tokens}")

    stop_bound = min(len(examples), args.stop_idx)
    # Size the progress bar to the examples actually processed; using
    # `len(examples)` made the bar never complete on subset runs. The context
    # manager also closes the bar when the loop ends or raises.
    n_selected = max(0, stop_bound - args.start_idx)
    with tqdm(total=n_selected) as bar:
        for i in range(args.start_idx, stop_bound):
            eg = examples[i]
            input_text = create_prompt(eg, data_name, model_name, args.data_dir)
            #print(f"====== Example {i} ======")
            pred = get_pred(
                model, tok, input_text, max_tokens=max_tokens, verbose=args.verbose
            )
            if args.verbose:
                print(pred)
            preds.append(
                {
                    "id": i,
                    "prediction": pred,
                    "ground_truth": get_answer(eg, data_name),
                }
            )
            # All predictions so far are dumped after every example --
            # presumably so partial results survive an interrupted run.
            dump_jsonl(preds, output_path)

            bar.update()


