import argparse
import json
import os
import time
import torch
import numpy as np

from fastchat.llm_judge.common import load_questions
from fastchat.model import get_conversation_template
from fastchat.utils import str_to_torch_dtype
from tqdm import tqdm

from eagle.model.ea_model import EaModel
from eagle.model.kv_cache import initialize_past_key_values
from eagle.model.utils import *


def generate_linear_tree_buffers(k, device="cuda"):
    """Build EAGLE tree buffers describing a single linear draft chain.

    The tree choices are the successive prefixes ``[0]``, ``[0, 0]``, ...,
    ``[0] * k``, so the resulting "tree" is one path of depth ``k``
    (presumably always following the first-ranked draft candidate — confirm
    against ``generate_tree_buffers``).

    Args:
        k: Depth of the draft chain.
        device: Device the buffers are created on.

    Returns:
        Whatever ``generate_tree_buffers`` returns for those choices.
    """
    chain = []
    for depth in range(1, k + 1):
        chain.append([0] * depth)
    return generate_tree_buffers(chain, device=device)


@torch.inference_mode()
def vanilla_speculative_forward(
    inputs,
    model,
    tokenizer,
    max_new_tokens,
    k=5,
    logits_processor=None,
    max_steps=512
):
    """Generate a continuation using linear-chain (vanilla) speculative decoding.

    The EAGLE draft head proposes a single chain of ``k`` tokens per step
    (buffers from ``generate_linear_tree_buffers``), which the base model
    then verifies.

    Args:
        inputs: Tokenizer output exposing ``input_ids``; batch size must be 1.
        model: EAGLE model exposing ``base_model`` and ``ea_layer``.
        tokenizer: Tokenizer whose ``eos_token_id`` ends generation.
        max_new_tokens: Stop once more than this many tokens were generated.
        k: Length of the linear draft chain.
        logits_processor: Logits processor for sampling, or ``None``.
        max_steps: Maximum number of draft/verify iterations.

    Returns:
        Tuple ``(input_ids, new_token, steps, accept_length_list)``:
        ``input_ids`` is prompt plus generated ids (tokens accepted after an
        EOS are NOT trimmed from it), ``new_token`` is the generated-token
        count excluding tokens after the first EOS, ``steps`` is the number of
        iterations run (0 if ``max_steps <= 0``), and ``accept_length_list``
        holds the number of tokens appended per iteration.
    """
    input_ids = inputs.input_ids
    assert input_ids.shape[0] == 1, "Only support batch size 1"
    input_ids = input_ids.clone()

    model.ea_layer.reset_kv()
    accept_length_list = []

    tree_buffers = generate_linear_tree_buffers(
        k, device=model.base_model.model.layers[-1].self_attn.q_proj.weight.device
    )
    # Copy of the retrieve indices on the lm_head's device, which may differ
    # from the last decoder layer's device when the model is sharded.
    tree_buffers["retrieve_indices_head"] = tree_buffers["retrieve_indices"].to(
        model.base_model.lm_head.weight.device
    )

    # Reuse the KV cache stored on the model by an earlier call (resetting its
    # lengths to zero); otherwise allocate one and store it on the model.
    if hasattr(model, "past_key_values"):
        past_key_values = model.past_key_values
        past_key_values_data = model.past_key_values_data
        current_length_data = model.current_length_data
        current_length_data.zero_()
    else:
        (
            past_key_values,
            past_key_values_data,
            current_length_data,
        ) = initialize_past_key_values(model.base_model)
        model.past_key_values = past_key_values
        model.past_key_values_data = past_key_values_data
        model.current_length_data = current_length_data

    input_len = input_ids.shape[1]
    cur_length = input_len
    eos_token_id = tokenizer.eos_token_id

    reset_tree_mode(model)
    tree_logits, logits, hidden_state, sample_token = initialize_tree(
        input_ids, model, tree_buffers["tree_attn_mask"], past_key_values, logits_processor
    )
    new_token = 0
    steps = 0  # stays 0 when the loop below does not run (max_steps <= 0)

    for steps in range(1, max_steps + 1):
        candidates, cart_candidates_prob, tree_candidates = generate_candidates(
            tree_logits,
            tree_buffers["tree_indices"],
            tree_buffers["retrieve_indices"],
            sample_token,
            logits_processor
        )

        logits, hidden_state_new, outputs = tree_decoding(
            model,
            tree_candidates,
            past_key_values,
            tree_buffers["tree_position_ids"],
            input_ids,
            tree_buffers["retrieve_indices_head"],
        )

        best_candidate, accept_length, sample_p = evaluate_posterior(
            logits, candidates, logits_processor, cart_candidates_prob,
            tree_logits[2], tree_buffers["p_indices"],
            tree_candidates, tree_buffers["b_indices"]
        )

        input_ids, tree_logits, new_token, hidden_state, sample_token = update_inference_inputs(
            input_ids,
            candidates,
            best_candidate,
            accept_length,
            tree_buffers["retrieve_indices"],
            logits_processor,
            logits,
            tree_logits,
            new_token,
            past_key_values_data,
            current_length_data,
            model,
            hidden_state,
            hidden_state_new,
            sample_p
        )

        # Tokens appended by this step. Earlier steps contained no EOS (the
        # loop would already have stopped), so only these need scanning.
        step_tokens = input_ids[0, cur_length:].tolist()
        accept_length_tree = len(step_tokens)
        cur_length += accept_length_tree
        accept_length_list.append(accept_length_tree)

        if eos_token_id in step_tokens:
            # Everything accepted after the FIRST EOS is not real output;
            # discount it from the counters (input_ids is left untrimmed).
            invalid_len = len(step_tokens) - step_tokens.index(eos_token_id) - 1
            if invalid_len > 0:
                accept_length_list[-1] -= invalid_len
                new_token -= invalid_len
            break
        if new_token > max_new_tokens:
            break
        # Hard cap on total sequence length (prompt + generated).
        # NOTE(review): presumably chosen to fit the preallocated KV cache — confirm.
        if input_ids.shape[1] > 1960:
            break

    return input_ids, new_token, steps, accept_length_list


def run_eval(
    model,
    tokenizer,
    forward_func,
    question_file,
    question_begin,
    question_end,
    max_new_tokens,
    num_choices,
    num_gpus_per_model,
    num_gpus_total,
    **kwargs
):
    """Load a slice of questions and write generated answers to a sibling JSONL file.

    The output path is ``<question_file without extension>_vanilla_sd.jsonl``.

    Args:
        model: Model passed through to ``get_model_answers``.
        tokenizer: Tokenizer passed through to ``get_model_answers``.
        forward_func: Generation function passed through to ``get_model_answers``.
        question_file: Path of the question file to read.
        question_begin: Start index forwarded to ``load_questions``.
        question_end: End index forwarded to ``load_questions``.
        max_new_tokens: Generation budget per turn.
        num_choices: Number of seeded samples per question.
        num_gpus_per_model: Accepted for interface compatibility; unused here.
        num_gpus_total: Accepted for interface compatibility; unused here.
        **kwargs: Extra keyword arguments forwarded to ``forward_func``.
    """
    print("Entering run_eval for Vanilla Speculative Decoding")

    questions = load_questions(question_file, question_begin, question_end)
    print(f"Loaded {len(questions)} questions from {question_file}")

    # Derive the output path from the file stem. A plain ``str.replace`` of
    # ".jsonl" leaves the path unchanged for any other extension (answers
    # would then be appended to the question file itself) and also rewrites
    # ".jsonl" occurring in directory names.
    stem, _ext = os.path.splitext(question_file)
    output_file = stem + "_vanilla_sd.jsonl"

    get_model_answers(
        model,
        tokenizer,
        forward_func,
        questions,
        output_file,
        max_new_tokens,
        num_choices,
        **kwargs
    )


@torch.inference_mode()
def get_model_answers(
    model,
    tokenizer,
    forward_func,
    questions,
    output_file,
    max_new_tokens,
    num_choices,
    **kwargs
):
    """Run ``forward_func`` on every question and append answers to ``output_file``.

    Each question's ``"conversations"`` is a sequence of ``(role, text)``
    pairs. Every non-``"ASSISTANT"`` turn is sent to the model through the
    vicuna conversation template. The generated answers of the first choice
    replace the ``"ASSISTANT"`` turns in the record written out. One JSON
    object per question is appended. The file is opened in append mode, so
    re-runs add lines rather than overwrite. A summary of accept lengths is
    printed at the end.

    Args:
        model: Model passed to ``forward_func``; switched to eval mode.
        tokenizer: Tokenizer used to encode prompts and decode outputs.
        forward_func: Callable ``(inputs, model, tokenizer, max_new_tokens,
            **kwargs)`` returning ``(output_ids, new_token, steps,
            accept_lengths)``.
        questions: Question dicts with ``"id"`` and ``"conversations"`` keys.
        output_file: JSONL path results are appended to.
        max_new_tokens: Generation budget forwarded to ``forward_func``.
        num_choices: Number of seeded samples per question; only the first
            one is written into the output record.
        **kwargs: Extra keyword arguments forwarded to ``forward_func``.
    """
    model.eval()

    if questions:
        # Best-effort warm-up on the first question so one-time setup costs are
        # not attributed to the timed runs below; failures are reported only.
        warm_conv = get_conversation_template("vicuna")
        for role, qs in questions[0]["conversations"]:
            if role == "ASSISTANT":
                continue
            warm_conv.append_message(role, qs)
            warm_conv.append_message("ASSISTANT", None)
            warm_conv.stop_str = "</s>"
            prompt = warm_conv.get_prompt()
            inputs = tokenizer([prompt], return_tensors="pt", max_length=2000, truncation=True).to("cuda")
            try:
                warm_ids = forward_func(
                    inputs,
                    model,
                    tokenizer,
                    max_new_tokens=50,
                    **kwargs
                )[0]
                # Fill in the reply so later warm-up turns see a well-formed
                # history instead of an empty "ASSISTANT:" slot.
                warm_conv.messages[-1][-1] = _decode_output(
                    warm_ids[0][inputs.input_ids.shape[1]:], warm_conv, tokenizer
                )
            except Exception as e:
                print(f"Warm up error: {e}")

        # Create the output directory once instead of once per question.
        os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)

    # Accept lengths across all questions/choices, used for the final summary.
    accept_lengths_tree = []

    for question in tqdm(questions):
        choices = []

        for i in range(num_choices):
            cur_accept_lengths_tree = []
            torch.manual_seed(i)
            conv = get_conversation_template("vicuna")
            turns = []
            steps = []
            new_tokens = []
            wall_time = []

            for role, qs in question["conversations"]:
                if role == "ASSISTANT":
                    continue

                conv.append_message(role, qs)
                conv.append_message("ASSISTANT", None)
                conv.stop_str = "</s>"
                prompt = conv.get_prompt()

                inputs = tokenizer([prompt], return_tensors="pt", max_length=2000, truncation=True).to("cuda")
                prompt_len = inputs.input_ids.shape[1]

                try:
                    # Synchronize so the timer measures completed GPU work.
                    torch.cuda.synchronize()
                    start_time = time.perf_counter()

                    output_ids, new_token, step, accept_length_tree = forward_func(
                        inputs,
                        model,
                        tokenizer,
                        max_new_tokens,
                        **kwargs
                    )

                    torch.cuda.synchronize()
                    total_time = time.perf_counter() - start_time

                    accept_lengths_tree.extend(accept_length_tree)
                    output = _decode_output(output_ids[0][prompt_len:], conv, tokenizer)

                except RuntimeError as e:
                    import traceback
                    print(f"ERROR question ID: {question['id']}")
                    print(f"Error details: {str(e)}")
                    traceback.print_exc()
                    output = "ERROR"
                    step = 0
                    new_token = 0
                    total_time = 0
                    accept_length_tree = []

                if role == "USER":
                    turns.append(output)
                    steps.append(int(step))
                    new_tokens.append(int(new_token))
                    wall_time.append(total_time)
                    cur_accept_lengths_tree.extend(accept_length_tree)
                    # Record the reply so the next turn's prompt includes it.
                    conv.messages[-1][-1] = output

            choices.append({
                "index": i,
                "turns": turns,
                "decoding_steps": steps,
                "new_tokens": new_tokens,
                "wall_time": wall_time,
                "accept_lengths": cur_accept_lengths_tree
            })

        new_question = _build_answer_record(question, choices)
        with open(output_file, "a", encoding="utf-8") as fout:
            fout.write(json.dumps(new_question, ensure_ascii=False) + "\n")

    mean_accept = np.mean(accept_lengths_tree) if accept_lengths_tree else 0
    print(f"\n# Total accept steps: {len(accept_lengths_tree)}")
    print(f"# Accept lengths stats: min={min(accept_lengths_tree) if accept_lengths_tree else 0}, "
          f"max={max(accept_lengths_tree) if accept_lengths_tree else 0}")
    print(f"# Mean accepted tokens per step: {mean_accept:.2f}")
    print(f"# Results saved to: {output_file}")


def _decode_output(output_ids, conv, tokenizer):
    """Turn generated token ids into cleaned answer text.

    Truncates at the first id in ``conv.stop_token_ids`` (if any), decodes,
    cuts at the first ``conv.stop_str`` occurrence, removes every special
    token string and trims surrounding whitespace.
    """
    if conv.stop_token_ids:
        stop_positions = [
            pos for pos, tok in enumerate(output_ids)
            if tok in conv.stop_token_ids
        ]
        if stop_positions:
            output_ids = output_ids[:stop_positions[0]]

    output = tokenizer.decode(output_ids, spaces_between_special_tokens=False)

    if conv.stop_str:
        stop_pos = output.find(conv.stop_str)
        # ``>= 0`` so that a stop string at the very start also truncates
        # (``find`` returns -1 when absent).
        if stop_pos >= 0:
            output = output[:stop_pos]

    for special_token in tokenizer.special_tokens_map.values():
        if isinstance(special_token, list):
            for special_tok in special_token:
                output = output.replace(special_tok, "")
        else:
            output = output.replace(special_token, "")
    output = output.strip()

    if conv.name == "xgen" and output.startswith("Assistant:"):
        output = output.replace("Assistant:", "", 1).strip()
    return output


def _build_answer_record(question, choices):
    """Return a shallow copy of ``question`` annotated with the first choice's answers.

    If there is no choice or it produced no turns, the copy is returned
    unchanged. Otherwise ``"ASSISTANT"`` turns are replaced in order by the
    generated answers, and ``prompt``, ``new_tokens``, ``wall_time`` and
    ``accept_lengths`` are added.
    """
    new_question = question.copy()
    if not (choices and choices[0]["turns"]):
        return new_question

    first = choices[0]
    model_responses = first["turns"]
    updated_conversations = []
    assistant_turn_idx = 0

    for role, content in question["conversations"]:
        if role == "USER":
            updated_conversations.append([role, content])
        elif role == "ASSISTANT":
            if assistant_turn_idx < len(model_responses):
                updated_conversations.append([role, model_responses[assistant_turn_idx]])
                assistant_turn_idx += 1
            else:
                updated_conversations.append([role, content])
        # NOTE(review): turns with any other role are dropped from the record — confirm intended.

    new_question["conversations"] = updated_conversations

    conv = get_conversation_template("vicuna")
    conv.messages = []
    for role, content in updated_conversations:
        conv.append_message(role, content)

    new_question["prompt"] = conv.get_prompt()
    new_question["new_tokens"] = first["new_tokens"]
    new_question["wall_time"] = first["wall_time"]
    new_question["accept_lengths"] = first["accept_lengths"]
    return new_question


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Vanilla Speculative Decoding Inference")

    parser.add_argument(
        "--base-model-path",
        type=str,
        required=True,
        help="Path or id of the base (target) model.",
    )
    parser.add_argument(
        "--draft-model-path",
        type=str,
        required=True,
        help="Path or id of the EAGLE draft model.",
    )
    parser.add_argument(
        "--question-file",
        type=str,
        required=True,
        help="Question file; answers go to <stem>_vanilla_sd.jsonl.",
    )
    parser.add_argument(
        "--question-begin",
        type=int,
        default=None,
        help="Index of the first question to evaluate.",
    )
    parser.add_argument(
        "--question-end",
        type=int,
        default=None,
        help="End index of the questions to evaluate.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=1024,
        help="Stop a turn once more than this many tokens were generated.",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=512,
        help="Maximum draft/verify iterations per turn.",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=5,
        help="Length of the linear draft chain proposed per step.",
    )
    parser.add_argument(
        "--num-choices",
        type=int,
        default=1,
        help="Number of seeded samples per question.",
    )
    parser.add_argument(
        "--num-gpus-per-model",
        type=int,
        default=1,
        help="Accepted for compatibility; currently unused.",
    )
    parser.add_argument(
        "--num-gpus-total",
        type=int,
        default=1,
        help="Accepted for compatibility; currently unused.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Sampling temperature; values <= 1e-5 use no logits processor.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=["float32", "float64", "float16", "bfloat16"],
        help="Torch dtype used to load the models.",
    )

    args = parser.parse_args()

    # Reject values that would otherwise fail deep inside decoding.
    if args.k < 1:
        parser.error("--k must be a positive integer")
    if args.max_steps < 1:
        parser.error("--max-steps must be a positive integer")

    print("=" * 80)
    print("Vanilla Speculative Decoding Configuration:")
    print(f"  Base Model (Target): {args.base_model_path}")
    print(f"  EA Model (Draft): {args.draft_model_path}")
    print(f"  k (linear path length): {args.k}")
    print(f"  Temperature: {args.temperature}")
    print("=" * 80)

    ea_config_path = os.path.join(args.draft_model_path, "config.json")
    source_config_path = "./config/config.json"

    # Only patch an existing local draft-model directory. For a path that does
    # not exist locally (e.g. a hub repo id), creating it would leave a stray
    # directory with a config.json in the working directory.
    if (
        os.path.isdir(args.draft_model_path)
        and not os.path.exists(ea_config_path)
        and os.path.exists(source_config_path)
    ):
        print(f"Config file not found at {ea_config_path}, copying from {source_config_path}")
        import shutil
        try:
            shutil.copy2(source_config_path, ea_config_path)
            print(f"Successfully copied config.json to {ea_config_path}")
        except Exception as e:
            print(f"Warning: Failed to copy config.json: {e}")

    print("\nLoading EAGLE model...")
    model = EaModel.from_pretrained(
        base_model_path=args.base_model_path,
        ea_model_path=args.draft_model_path,
        torch_dtype=str_to_torch_dtype(args.dtype),
        low_cpu_mem_usage=True,
        device_map="auto"
    )

    tokenizer = model.get_tokenizer()

    # A (near-)zero temperature means no logits processor is used.
    if args.temperature > 1e-5:
        logits_processor = prepare_logits_processor(temperature=args.temperature)
    else:
        logits_processor = None

    run_eval(
        model=model,
        tokenizer=tokenizer,
        forward_func=vanilla_speculative_forward,
        question_file=args.question_file,
        question_begin=args.question_begin,
        question_end=args.question_end,
        max_new_tokens=args.max_new_tokens,
        num_choices=args.num_choices,
        num_gpus_per_model=args.num_gpus_per_model,
        num_gpus_total=args.num_gpus_total,
        k=args.k,
        logits_processor=logits_processor,
        max_steps=args.max_steps
    )
