import json
import os
from pathlib import Path
import time
from typing import List, Tuple, Any
from tqdm import tqdm

import torch
from torch import Tensor
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig
)
from transformers.modeling_outputs import BaseModelOutputWithPast

from eval_utils import (
    check_benchmark_availability,
    dump_jsonl,
    create_prompt,
    load_data,
    get_answer,
    DATA_NAME_TO_MAX_NEW_TOKENS,
)
from compute_scores import compute_scores
from args import parse_args
from sparq import *
from modeling_llama_chunck_topk import *

# sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
def truncate_input(input: list, max_length: int, manner="middle"):
    """
    Shorten a sequence to at most `max_length` items.

    With manner="middle" the middle of the sequence is dropped and the same
    number of items (`max_length // 2`) is kept from each end, so an odd
    budget keeps `max_length - 1` items.

    Args:
        input: Sequence to truncate (anything supporting len() and slicing).
        max_length: Maximum number of items to keep; must be non-negative.
        manner: Truncation strategy. Only "middle" is implemented.

    Returns:
        `input` itself (not a copy) when it already fits, the truncated
        sequence for manner="middle", or None for any other manner.

    Raises:
        ValueError: If `max_length` is negative.
    """
    if max_length < 0:
        raise ValueError(f"max_length must be non-negative, got {max_length}")
    if len(input) <= max_length:
        return input
    if manner != "middle":
        return None

    split = max_length // 2
    # The tail start is computed explicitly: with split == 0, `input[-split:]`
    # is `input[-0:]`, i.e. the WHOLE sequence, which would blow the budget.
    return input[:split] + input[len(input) - split:]

def truncate_by_tokens(input, tok, max_tokens, manner: str = "middle"):
    """
    Encode `input` with `tok` and cut the token list down to `max_tokens`.

    The token counts before and after truncation are printed. The actual
    truncation is delegated to `truncate_input`.

    Args:
        input: Text passed to `tok.encode`.
        tok: Tokenizer-like object exposing `encode`.
        max_tokens: Upper bound on the number of tokens returned.
        manner: Truncation strategy forwarded to `truncate_input`.

    Returns:
        The (possibly truncated) list of token ids.
    """
    encoded = tok.encode(input)
    original_count = len(encoded)
    print(f"# tokens before: {original_count}")

    kept = truncate_input(encoded, max_length=max_tokens, manner=manner)
    # `truncate_input` may hand back None for an unsupported manner, in which
    # case len() fails here.
    kept_count = len(kept)  # type: ignore
    print(f"# tokens after: {kept_count}")

    # Sanity checks: truncation never grows the input and respects the budget.
    assert kept_count <= original_count
    assert kept_count <= max_tokens
    return kept

def get_pred(
    model,
    tok: AutoTokenizer,
    input_text: str,
    max_input_length: int,
    verbose: bool = False,
    generation_config: GenerationConfig = None,
) -> str:
    """
    Truncate the prompt to `max_input_length` tokens and run generation.

    Args:
        model: Object exposing `.device` and `.generate(...)`.
        tok: Tokenizer used to encode the prompt and decode the output.
        input_text: Full prompt text.
        max_input_length: Token budget for the prompt after truncation.
        verbose: When True, print the token count and the decoded first and
            last 200 prompt tokens.
        generation_config: Forwarded unchanged to `model.generate`.

    Returns:
        The decoded text of the tokens that follow the prompt in the first
        returned sequence, with special tokens skipped.
    """
    prompt_ids = truncate_by_tokens(input_text, tok, max_input_length)
    prompt_len = len(prompt_ids)

    if verbose:
        preview = 200
        print("# tokens:", prompt_len)
        print("=============== Input ===============")
        print(tok.decode(prompt_ids[:preview]))
        print("...")
        print(tok.decode(prompt_ids[-preview:]))
        print("=====================================")

    # Single-example batch placed on the model's device.
    input_ids = torch.tensor(prompt_ids).unsqueeze(0).to(model.device)
    generated = model.generate(
        input_ids=input_ids, generation_config=generation_config
    )

    # Skip the first `prompt_len` positions so only the continuation is
    # decoded (assumes the output sequence starts with the prompt).
    completion = tok.decode(generated[0, prompt_len:], skip_special_tokens=True)

    # Always echo the head and tail of the raw prompt plus the completion.
    print(input_text[:5000], input_text[-5000:])
    print("Chunked generation:", completion)
    return completion

def load_model(
    model_name: str, topk: int=-1, topk_from_layer: int=-1, topk_dims_file_path: str="", use_sparq: bool = False
):
    """
    Build the model config, then load the tokenizer and causal LM on CUDA.

    Args:
        model_name: Name or path handed to the `from_pretrained` loaders.
        topk: When not -1, stored on the config as `topk` together with
            `topk_from_layer`.
        topk_from_layer: Stored as `config.topk_from_layer` when `topk` is set.
        topk_dims_file_path: When non-empty, stored on the config under the
            same name.
        use_sparq: When True, sets fixed SparQ values on the config (this
            overwrites `config.topk` with 256 regardless of `topk`) and
            passes the loaded model through `apply_sparq`.

    Returns:
        A `(model, tokenizer)` tuple.
    """
    config = AutoConfig.from_pretrained(model_name)

    # Extra config entries applied only when the name contains "LWM".
    if "LWM" in model_name:
        config.update(
            dict(
                theta=10000000,
                max_sequence_length=131072,
                scan_attention=True,
                scan_query_chunk_size=1024,
                scan_key_chunk_size=1024,
                scan_mlp=True,
                scan_mlp_chunk_size=1024,
                scan_layers=True,
            )
        )

    if topk != -1:
        config.topk = topk
        config.topk_from_layer = topk_from_layer
    if topk_dims_file_path:
        config.topk_dims_file_path = topk_dims_file_path

    # Applied last, so these values win over anything set above.
    if use_sparq:
        sparq_settings = (
            ("topk", 256),
            ("local_window", 100),
            ("num_top_dim_in_q", 16),
        )
        for attr, value in sparq_settings:
            setattr(config, attr, value)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse the EOS token as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        config=config,
        torch_dtype="auto",
        device_map="cuda",
    )
    if use_sparq:
        model = apply_sparq(model)

    print("Model and tokenizer loaded.")
    return model, tokenizer

if __name__ == "__main__":
    # Script entry point: run one benchmark task end to end.
    #   1. Parse arguments and resolve the prediction file path.
    #   2. If predictions already exist and `args.rewrite` is not set, score
    #      them and stop.
    #   3. Otherwise load the model, generate a prediction per example
    #      (saving after each one), and score the result.
    args = parse_args()

    check_benchmark_availability(args.data_dir)
    model_name = args.model_name_or_path
    max_seq_length = args.max_seq_length
    real_model_name = model_name.split("/")[-1]
    data_name = args.task
    max_new_tokens = DATA_NAME_TO_MAX_NEW_TOKENS[data_name]

    # Output location: <output_dir>/<model>/prediction_<task>.jsonl
    result_dir = Path(args.output_dir, real_model_name)
    result_dir.mkdir(exist_ok=True, parents=True)
    output_path = result_dir / f"prediction_{data_name}.jsonl"

    # Reuse existing predictions unless a rewrite was requested. This is
    # checked before the model is loaded, and the script stops afterwards;
    # otherwise the evaluation below would rerun and overwrite the very file
    # that was just scored. A file left behind by an interrupted run is scored
    # as-is; set `args.rewrite` to regenerate it.
    if output_path.exists() and not args.rewrite:
        print(f"Output file {output_path} exists. Loading from file.")
        compute_scores(output_path, data_name, real_model_name)
        raise SystemExit(0)

    # Model
    model, tok = load_model(
        model_name,
        args.topk,
        args.topk_from_layer,
        args.topk_dims_file_path,
        args.use_sparq,
    )
    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        do_sample=False,
        temperature=0,
        pad_token_id=tok.pad_token_id,
        use_cache=False,
    )

    # Data
    examples = load_data(data_name, data_dir=args.data_dir)
    if args.num_eval_examples != -1:
        num_eval_examples = min(args.num_eval_examples, len(examples))
        examples = examples[:num_eval_examples]

    print("==== Evaluation ====")
    print(f"# examples: {len(examples)}")
    print(f"Num eval examples: {args.num_eval_examples}")
    print(f"Verbose: {args.verbose}")
    print(f"Max new tokens: {max_new_tokens}")

    # Leave room in the sequence budget for the tokens to be generated.
    max_input_length = max_seq_length - max_new_tokens

    preds = []
    # tqdm wraps the examples directly so the bar can show a total.
    for i, eg in enumerate(tqdm(examples)):
        input_text = create_prompt(eg, data_name, real_model_name, args.data_dir)
        print(f"====== Example {i} ======")
        pred = get_pred(
            model,
            tok,
            input_text,
            max_input_length=max_input_length,
            verbose=args.verbose,
            generation_config=generation_config,
        )
        ground_truth = get_answer(eg, data_name)
        print("Ground Truth", ground_truth)
        if args.verbose:
            print(pred)
        preds.append(
            {
                "id": i,
                "prediction": pred,
                "ground_truth": ground_truth,
            }
        )
        # Dump after every example so partial results survive a crash.
        dump_jsonl(preds, output_path)

    compute_scores(output_path, data_name, real_model_name)