import os
import random
import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from human_eval.data import read_problems, write_jsonl
from human_eval.evaluation import evaluate_functional_correctness
from SLICER_config import SCConfig
from SLICER_opti import *

# CUDA device configuration.
# Default to GPU index 1, but respect a CUDA_VISIBLE_DEVICES value that is
# already set (e.g. by a job scheduler or the shell) instead of clobbering it.
# This only takes effect if CUDA has not been initialised yet.
# NOTE(review): assumes none of the imports above initialise CUDA at import
# time — confirm for SLICER_opti.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

# Fixed seed for Python `random`, NumPy and torch (CPU and all visible CUDA
# devices) so sampling is repeatable across runs; some GPU kernels may still
# be nondeterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"=== SEED={SEED} set ===")

# Utility: extract the function definition from LLM output.
def extract_function_only(text: str) -> str:
    """Return the function definition contained in a model response.

    Everything before the first ``"def "`` is dropped (chat preamble, an
    opening markdown fence), and everything from the next triple-backtick
    fence onward is dropped (closing fence and trailing prose).

    If the response contains no ``"def "`` but starts with a fenced block,
    the opening fence line (e.g. the one carrying a ``python`` language tag)
    is removed. The fenced code is then kept instead of collapsing to an
    empty string. The result is whitespace-stripped.

    Args:
        text: Raw decoded model output.

    Returns:
        The extracted code, or ``""`` if nothing remains.
    """
    start = text.find("def ")
    if start != -1:
        text = text[start:]
    else:
        stripped = text.lstrip()
        if stripped.startswith("```"):
            # Drop the opening fence line (and its optional language tag);
            # otherwise the split below would keep only the empty text
            # that precedes the fence.
            _, _, text = stripped[3:].partition("\n")
    if "```" in text:
        text = text.split("```", 1)[0]
    return text.strip()

# SLICER-enabled sequence generation: mixes plain autoregressive decoding
# with a single SLICER intervention step.
def _eos_token_ids(model, tokenizer) -> set:
    """Collect stop-token ids from the tokenizer and, if present, the model's generation config."""
    ids = set()
    gen_cfg = getattr(model, "generation_config", None)
    for candidate in (
        getattr(tokenizer, "eos_token_id", None),
        getattr(gen_cfg, "eos_token_id", None),
    ):
        if candidate is None:
            continue
        if isinstance(candidate, (list, tuple, set)):
            ids.update(int(c) for c in candidate)
        else:
            ids.add(int(candidate))
    return ids


def _sample_next_token(logits: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor:
    """Pick the next token id (shape ``(batch, 1)``) from last-position logits ``(batch, vocab)``.

    ``temperature <= 0`` selects greedily. Otherwise the function samples from
    the temperature-scaled softmax, restricted to the smallest set of tokens
    whose cumulative probability reaches ``top_p`` (nucleus sampling).
    """
    if temperature <= 0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    # Softmax in float32 for numerical stability with half-precision logits.
    probs = torch.softmax(logits.float() / temperature, dim=-1)
    if top_p is None or top_p >= 1.0:
        return torch.multinomial(probs, num_samples=1)
    sorted_probs, sorted_ids = torch.sort(probs, dim=-1, descending=True)
    # Drop a token once the probability mass *before* it already exceeds
    # top_p; the most likely token is therefore always kept.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    sorted_probs = sorted_probs.masked_fill(mass_before > top_p, 0.0)
    # multinomial accepts unnormalised non-negative weights.
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_ids, -1, choice)


def _plain_decode(model, generated, budget, eos_ids, temperature, top_p):
    """Sample up to ``budget`` tokens with a KV cache started from scratch.

    Returns:
        ``(generated, n_new, hit_eos)``: the extended id tensor, the number of
        tokens appended, and whether decoding stopped on a stop token.
    """
    past_key_values = None
    produced = 0
    while produced < budget:
        # First step encodes the whole sequence; later steps feed only the
        # newest token and reuse the cache.
        cur = generated if past_key_values is None else generated[:, -1:]
        with torch.no_grad():
            out = model(
                input_ids=cur,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
                output_hidden_states=False,
            )
        past_key_values = out.past_key_values
        nxt = _sample_next_token(out.logits[:, -1, :], temperature, top_p)
        generated = torch.cat([generated, nxt], dim=-1)
        produced += 1
        if nxt.item() in eos_ids:
            return generated, produced, True
    return generated, produced, False


def custom_generate_text(
    model,
    tokenizer,
    sc_config,
    input_ids: torch.Tensor,
    w_bar: int = 200,
    max_new_tokens: int = 256,
    temperature: float = 0.2,
    top_p: float = 0.95,
    device: torch.device = None,
):
    """Generate text with a hybrid method.

    1. Plain decoding for the first ``w_bar`` *generated* tokens (the prompt
       is not counted).
    2. A single SLICER step: a full forward pass without a cache, with
       ``output_hidden_states=True``, passing ``sc_config`` to the model.
       NOTE(review): presumably ``sc_config`` is consumed by a model forward
       patched via ``SLICER_opti``; confirm.
    3. Plain decoding again, with a fresh cache, until ``max_new_tokens``
       tokens have been generated in total.

    Decoding stops early on any stop token from ``tokenizer.eos_token_id``
    or ``model.generation_config.eos_token_id`` (when present). Only batch
    size 1 is supported, because the code calls ``.item()`` on the sampled
    token and decodes ``generated[0]``.

    Args:
        model: Causal LM called as ``model(input_ids=..., ...)``.
        tokenizer: Tokenizer providing ``decode`` and ``eos_token_id``.
        sc_config: SLICER configuration forwarded to the SLICER step.
        input_ids: Prompt token ids, ``(1, prompt_len)``.
        w_bar: Number of generated tokens before SLICER is applied.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Softmax temperature; ``<= 0`` means greedy decoding.
        top_p: Nucleus-sampling threshold; ``>= 1`` (or ``None``) disables it.
        device: Device to move ``input_ids`` to; ``None`` leaves it in place.

    Returns:
        ``(text, out_slicer, applied)``:
        - ``text`` is the decoded generated tokens (prompt excluded, special
          tokens skipped).
        - ``out_slicer`` is the raw model output of the SLICER step, or
          ``None``.
        - ``applied`` tells whether the SLICER step ran.
    """
    if device is not None:
        input_ids = input_ids.to(device)
    generated = input_ids.clone()
    prompt_len = input_ids.shape[1]
    eos_ids = _eos_token_ids(model, tokenizer)

    def decode_new_tokens():
        # Only decode the generated part, not the prompt.
        return tokenizer.decode(generated[0, prompt_len:], skip_special_tokens=True)

    # (1) Plain decoding for the first w_bar generated tokens.
    generated, new_tokens, hit_eos = _plain_decode(
        model, generated, min(w_bar, max_new_tokens), eos_ids, temperature, top_p
    )
    if hit_eos:
        return decode_new_tokens(), None, False

    # (2) One SLICER step over the full sequence.
    out_slicer = None
    applied = False
    if new_tokens < max_new_tokens:
        with torch.no_grad():
            out_slicer = model(
                input_ids=generated,
                past_key_values=None,
                use_cache=False,
                return_dict=True,
                output_hidden_states=True,
                sc_config=sc_config,
            )
        applied = True
        nxt = _sample_next_token(out_slicer.logits[:, -1, :], temperature, top_p)
        generated = torch.cat([generated, nxt], dim=-1)
        new_tokens += 1
        if nxt.item() in eos_ids:
            return decode_new_tokens(), out_slicer, applied

    # (3) Resume plain decoding. The SLICER step ran with use_cache=False,
    # so no cache covers its token; _plain_decode starts with a fresh one.
    generated, _, _ = _plain_decode(
        model, generated, max_new_tokens - new_tokens, eos_ids, temperature, top_p
    )
    return decode_new_tokens(), out_slicer, applied

# Main function: HumanEval code generation with SLICER, then pass@1 scoring.
def main():
    """Run SLICER-assisted HumanEval generation and print the pass@1 score.

    Pipeline:
      1. Load the HumanEval problems.
      2. Load the Qwen2.5-Coder-7B-Instruct tokenizer and model.
      3. Build the SLICER configuration.
      4. Generate ``num_per_problem`` completions per task with
         ``custom_generate_text`` and keep only the function definition.
      5. Write the completions to ``Pass@1_samples.jsonl``.
      6. Score them with ``evaluate_functional_correctness`` (k=1).
    """
    # (1) Load HumanEval problems (OpenAI benchmark)
    problems = read_problems()
    print(f"Total HumanEval tasks: {len(problems)}")
    first_task_id = next(iter(problems))
    print("\nSample prompt:\n", problems[first_task_id]["prompt"])

    # (2) Load tokenizer and model (Qwen2.5-Coder-7B-Instruct)
    model_name = "Qwen/Qwen2.5-Coder-7B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        use_fast=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype="auto",
        trust_remote_code=True
    )
    model.eval()
    print("=== Qwen2.5-Coder-7B-Instruct model loaded ===")

    # (3) Prepare SLICER config
    sc_config = SCConfig(
        split_layer=20,   # Layer where SLICER is injected
        s=0.7,            # Compression ratio (custom)
        lambd=0.0,
        delta=0.0,
        Q=[8, 8, 8],      # Quantization bits (example)
        Q_n=[8, 8, 8]
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # (4) Loop: Generate code for each problem with SLICER
    samples = []
    num_per_problem = 1
    MAX_NEW_TOKENS = 512
    w_bar = 50  # Apply SLICER after 50 generated tokens
    applied_count = 0

    for task_id, task in tqdm(problems.items(), desc="Tasks"):
        # The chat-templated prompt is the same for every sample of a task,
        # so tokenize it once.
        messages = [{"role": "user", "content": task["prompt"]}]
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(device)

        for _ in range(num_per_problem):
            # Generate output with SLICER hybrid method
            completion, out_slicer, applied = custom_generate_text(
                model=model,
                tokenizer=tokenizer,
                sc_config=sc_config,
                input_ids=inputs,
                w_bar=w_bar,
                max_new_tokens=MAX_NEW_TOKENS,
                temperature=0.2,
                top_p=0.95,
                device=device,
            )
            # The SLICER output carries full-sequence logits and all hidden
            # states; release it before the next generation call.
            del out_slicer
            applied_count += int(applied)

            # Only keep function definition (strip markdown, etc.)
            clean_completion = extract_function_only(completion)
            samples.append({
                "task_id": task_id,
                "completion": clean_completion
            })

    print(f"Generated {len(samples)} completions. ({num_per_problem} per problem)")
    print(f"SLICER applied in {applied_count}/{len(samples)} generations")

    # (5) Save completions to JSONL file for later evaluation
    jsonl_file = "Pass@1_samples.jsonl"
    write_jsonl(jsonl_file, samples)
    print(f"Samples written to {jsonl_file}")

    # (6) Evaluate functional correctness with HumanEval official metric
    results = evaluate_functional_correctness(jsonl_file, k=[1], n_workers=4)
    print(results)
    print(f"Pass@1 accuracy: {results['pass@1'] * 100:.2f}%")

if __name__ == "__main__":
    # Script entry point: run generation + evaluation only when this file is
    # executed directly, not when it is imported as a module.
    main()
