import os
import torch
import random
import numpy as np
import json
import gzip
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForCausalLM
from human_eval.evaluation import evaluate_functional_correctness
from SLICER_opti import *
from SLICER_config import SCConfig

# Expose only physical GPU 2 to this process. The variable is honoured only
# if CUDA has not been initialised yet, so it must be set early.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Seed every RNG the script relies on (Python, NumPy, torch CPU, torch CUDA)
# so that repeated runs are reproducible.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# Load the Meta-Llama-3-8B-Instruct model and its tokenizer from the Hugging Face Hub.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",      # Let Accelerate place the weights on the visible device(s), offloading if they don't fit
    torch_dtype=torch.float32  # Full FP32 precision (NOT FP16): ~4 bytes/parameter, roughly 32 GB for 8B parameters
)
model.eval()  # Evaluation mode (e.g. no dropout); autograd is disabled separately via torch.no_grad() at call sites

# Load HumanEval problems from the gzip-compressed JSON Lines file.
# JSON Lines is UTF-8 by definition, so decode explicitly rather than relying
# on the platform's locale encoding.
problems = []
with gzip.open("./data/human-eval/data/HumanEval.jsonl.gz", "rt", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline) like human-eval's own reader does.
            continue
        obj = json.loads(line)
        problems.append({
            "task_id": obj["task_id"],                   # Unique ID for each problem
            "prompt": obj["prompt"].strip(),             # Problem prompt; .strip() also drops leading/trailing whitespace
            "canonical_solution": obj["canonical_solution"].strip(),  # Reference solution (never show it to the model)
        })

# SLICER (split computing + quantization) settings. The generation loop hands
# this object to custom_generate_text, which passes it to the model for its
# single SLICER-enabled forward pass.
_slicer_settings = dict(
    split_layer=20,   # Layer index at which the model is split (edge/cloud)
    s=0.9,            # NOTE(review): semantics defined by SCConfig — confirm
    Q=[8, 8, 8],      # Per-entry quantization bit widths (all 8-bit here)
    Q_n=[8, 8, 8],    # NOTE(review): second per-entry bit-width list — confirm meaning in SCConfig
    lambd=0.0,        # NOTE(review): presumably a regularization weight — confirm
    delta=0.0,        # NOTE(review): semantics defined by SCConfig — confirm
    use_ABSQ=True,    # Enables SCConfig's ABSQ quantization mode
)
sc_config = SCConfig(**_slicer_settings)
print("SLICER Config:", sc_config)

def build_prompt(task_prompt: str, func_signature: str) -> str:
    """
    Build a Llama-3 Instruct chat prompt: a system turn, a user turn, and an
    open assistant turn for the model to complete.

    Follows the documented Llama-3 template. Every ``<|end_header_id|>`` is
    followed by two newlines, and each message's content is trimmed right up
    to its ``<|eot_id|>``.

    - task_prompt: The programming problem description.
    - func_signature: The Python function header (may be empty).

    The returned string already begins with ``<|begin_of_text|>``. Tokenize
    it with ``add_special_tokens=False`` to avoid a duplicated BOS token.
    """
    system_msg = "You are a helpful assistant that writes Python functions to solve programming problems."
    user_msg = f"# {task_prompt.strip()}\n{func_signature.strip()}".strip()
    return (
        "<|begin_of_text|>"
        f"<|start_header_id|>system<|end_header_id|>\n\n{system_msg}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

def extract_function_only(text: str) -> str:
    """
    Pull the Python function source out of a model response.

    If the response contains a fenced code block, only the first block is
    kept. A ```python fence is preferred over a bare ``` fence. The result is
    then cut to start at the first ``def `` (when present) and stripped of
    surrounding whitespace.
    """
    for fence in ("```python", "```"):
        if fence in text:
            _, _, after_fence = text.partition(fence)
            text = after_fence.split("```", 1)[0]
            break
    start = text.find("def ")
    snippet = text if start == -1 else text[start:]
    return snippet.strip()

def _sample_next_token(logits, temperature, top_p):
    """Choose the next token id for each row of ``logits`` (batch, vocab).

    - ``temperature == 0.0``: greedy argmax.
    - Otherwise: sample from ``softmax(logits / temperature)``. When
      ``top_p < 1``, sampling is restricted to the nucleus, i.e. the smallest
      set of most-likely tokens whose mass reaches ``top_p``. A ``top_p`` of
      ``None`` or ``>= 1.0`` samples from the full distribution.

    Returns a (batch, 1) tensor of token ids.
    """
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
    if top_p is None or top_p >= 1.0:
        return torch.multinomial(probs, num_samples=1)
    sorted_probs, sorted_ids = torch.sort(probs, dim=-1, descending=True)
    # Mass of all strictly-more-likely tokens: once it reaches top_p the
    # nucleus is complete and every remaining token is dropped.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    drop = mass_before >= top_p
    drop[..., 0] = False  # always keep the single most likely token
    sorted_probs = sorted_probs.masked_fill(drop, 0.0)
    # multinomial treats its input as non-negative weights; no renormalisation needed.
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_ids, -1, choice)


def _collect_stop_token_ids(model, tokenizer, extra_ids=None):
    """Return the set of token ids that end generation.

    Combines ``tokenizer.eos_token_id``, ``model.generation_config.eos_token_id``
    (an int or a list, when present) and any caller-supplied ``extra_ids``.
    Llama-3-Instruct ends its turn with ``<|eot_id|>``. Its generation config
    lists that token as an EOS id, but depending on the tokenizer revision it
    may differ from ``tokenizer.eos_token_id``.
    """
    gen_cfg = getattr(model, "generation_config", None)
    stop_ids = set()
    for source in (tokenizer.eos_token_id, getattr(gen_cfg, "eos_token_id", None), extra_ids):
        if source is None:
            continue
        if isinstance(source, (list, tuple, set, frozenset)):
            stop_ids.update(int(tok) for tok in source)
        else:
            stop_ids.add(int(source))
    return stop_ids


def custom_generate_text(
    model, tokenizer, sc_config, input_ids, w_bar=100,
    max_new_tokens=256, temperature=0.2, top_p=0.95, device=None,
    stop_token_ids=None,
):
    """
    Generate a completion that includes exactly one SLICER-enabled forward pass.

    1. Warm-up: KV-cached decoding without SLICER while the total sequence
       (prompt + generated) is shorter than ``w_bar`` tokens.
    2. SLICER: one uncached forward pass over the whole sequence with
       ``sc_config`` passed to the model. Its output is returned.
    3. Continuation: KV-cached decoding without SLICER.

    Generation stops after ``max_new_tokens`` new tokens, or as soon as a stop
    token is produced.

    Arguments:
      - model, tokenizer: Hugging Face causal LM / tokenizer. The model's
        forward must accept an ``sc_config`` keyword (presumably patched in
        by SLICER_opti — TODO confirm).
      - sc_config: SLICER configuration forwarded to the model in phase 2.
      - input_ids: tokenized prompt, shape (1, seq_len); batch size must be 1.
      - w_bar: total-length threshold (prompt included) at which SLICER runs.
        If the prompt already has >= w_bar tokens, phase 1 is skipped.
      - max_new_tokens: cap on generated tokens across all phases.
      - temperature: 0.0 for greedy decoding, otherwise the sampling temperature.
      - top_p: nucleus-sampling mass used when sampling. Ignored for greedy
        decoding; None or >= 1.0 disables it.
      - device: if given, ``input_ids`` is moved there first.
      - stop_token_ids: extra token ids that end generation, in addition to
        ``tokenizer.eos_token_id`` and ``model.generation_config.eos_token_id``.
    Returns:
      - generated_text: decoded completion, special tokens skipped (str)
      - out_slicer: model output of the SLICER pass, or None if it did not run
      - applied: True if the SLICER pass ran
    """
    if device is not None:
        input_ids = input_ids.to(device)
    stop_ids = _collect_stop_token_ids(model, tokenizer, stop_token_ids)
    generated = input_ids.clone()   # prompt + tokens produced so far
    input_len = input_ids.shape[1]  # prompt length
    new_tokens = 0

    def decode_slice():
        # Decode only the tokens produced after the prompt.
        return tokenizer.decode(generated[0, input_len:], skip_special_tokens=True)

    def append_next(out):
        """Pick a token from ``out``'s last position, append it; True if it is a stop token."""
        nonlocal generated, new_tokens
        nxt = _sample_next_token(out.logits[:, -1, :], temperature, top_p)
        generated = torch.cat([generated, nxt], dim=-1)
        new_tokens += 1
        return nxt.item() in stop_ids  # .item() relies on batch size 1

    def run_cached(within_limit):
        """KV-cached decoding without SLICER while ``within_limit()`` holds and budget remains.

        Starts from an empty cache, so the first step encodes the whole
        sequence. Returns True if a stop token was produced.
        """
        past_key_values = None
        while new_tokens < max_new_tokens and within_limit():
            cur = generated if past_key_values is None else generated[:, -1:]
            with torch.no_grad():
                out = model(
                    input_ids=cur,
                    past_key_values=past_key_values,
                    use_cache=True,
                    return_dict=True,
                    output_hidden_states=False,
                )
            past_key_values = out.past_key_values
            if append_next(out):
                return True
        return False

    # Phase 1: warm-up without SLICER until the sequence reaches w_bar tokens.
    if run_cached(lambda: generated.size(1) < w_bar):
        return decode_slice(), None, False

    # Phase 2: a single full forward pass with SLICER enabled (no KV cache).
    out_slicer = None
    if new_tokens < max_new_tokens:
        with torch.no_grad():
            out_slicer = model(
                input_ids=generated,
                past_key_values=None,
                use_cache=False,
                return_dict=True,
                output_hidden_states=True,
                sc_config=sc_config,
            )
        if append_next(out_slicer):
            return decode_slice(), out_slicer, True
    applied = out_slicer is not None

    # Phase 3: plain decoding. The SLICER pass kept no cache (use_cache=False),
    # so this phase restarts by re-encoding the full sequence.
    run_cached(lambda: True)
    return decode_slice(), out_slicer, applied

samples = []  # One {"task_id", "completion"} record per problem, in dataset order
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Iterate over all HumanEval problems and generate one solution for each.
for task in tqdm(problems, desc="Generating"):
    task_id = task["task_id"]
    task_prompt = task["prompt"]  # Problem prompt (already stripped when loaded)

    # Use the last top-level `def` line of the prompt as the function header.
    # Do NOT derive it from `canonical_solution`: that field is the reference
    # solution, so its first line would leak part of the answer into the
    # prompt (and possibly into the saved completion).
    def_lines = [ln.rstrip() for ln in task_prompt.splitlines() if ln.startswith("def ")]
    func_signature = def_lines[-1] if def_lines else ""

    prompt = build_prompt(task_prompt, func_signature)  # System/user/assistant format
    # The prompt string already begins with <|begin_of_text|>; stop the
    # tokenizer from prepending a second BOS token.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)

    # Generate completion using custom SLICER-based generation
    completion, _, _ = custom_generate_text(
        model=model,
        tokenizer=tokenizer,
        sc_config=sc_config,
        input_ids=inputs["input_ids"],
        w_bar=200,             # SLICER fires once prompt + output reaches 200 tokens (can be tuned)
        max_new_tokens=512,    # Limit completion length
        temperature=0.0,       # Greedy decoding (deterministic)
        top_p=1.0,             # Not used with greedy, but included for completeness
        device=device
    )

    # Keep only the function code. If the model answered with a bare body
    # (no `def`), re-attach the header so the completion is a full definition.
    source = completion if "def " in completion else func_signature + "\n" + completion
    clean_completion = extract_function_only(source)
    samples.append({
        "task_id": task_id,
        "completion": clean_completion
    })

# Persist the completions as JSON Lines: one JSON object per line.
save_path = "llama3_humaneval.jsonl"
with open(save_path, "w") as f:
    f.writelines(json.dumps(sample) + "\n" for sample in samples)

print(f"✅ Save complete: {save_path}")

# Score the saved file with HumanEval's functional-correctness harness,
# computing pass@1 only.
results = evaluate_functional_correctness(
    sample_file=save_path,
    problem_file="./data/human-eval/data/HumanEval.jsonl.gz",
    n_workers=4,
    k=[1],
)

print(results)
pass_at_1 = results["pass@1"]
print(f"🎯 Pass@1 accuracy: {pass_at_1 * 100:.2f}%")
