import wandb
import torch
from transformers import TrainerCallback
from datasets import load_dataset

# HuggingFace evaluate
import evaluate as hf_evaluate
from typing import Union

from components.preprocessor import DataLoader

import os
# HF code_eval executes model-generated code and only does so when this opt-in
# flag is set; it has to be in place before pass_at_k.compute() is called.
os.environ["HF_ALLOW_CODE_EVAL"] = "1"

# Load the code_eval metric once at import time; pass_at_1() reuses it.
pass_at_k = hf_evaluate.load("code_eval")

# Smoke test so that configuration problems (e.g. the opt-in flag above) surface
# at import time rather than during the first evaluation. The candidate
# (a*b) does not satisfy the test, so its score is 0; only the absence of an
# exception is checked and the result is discarded.
test_cases = ["assert add(2, 3)==5"]
candidates = [["def add(a,b): return a*b"]]
_ = pass_at_k.compute(
    references=test_cases,
    predictions=candidates,
    k=[1],
)


# ============================================================
#               Utility: pass@1 scoring
# ============================================================

def pass_at_1(
    references: Union[str, list[str], list[list[str]]],
    predictions: Union[str, list[str], list[list[str]]]
) -> float:
    """
    Compute pass@1 using the HF ``code_eval`` metric.

    Args:
        references: one test program per problem. Each entry is either a
            string of test code or a list of test lines (e.g. an MBPP
            ``test_list``), which is joined with newlines because code_eval
            appends the test to each candidate as plain text. A bare string
            is treated as a single problem.
        predictions: one entry per problem, either a single generated program
            (``str``) or a list of candidate programs for that problem. A bare
            string is treated as a single problem with one candidate.

    Returns:
        The pass@1 estimate reported by code_eval, as a float.

    Raises:
        ValueError: if there are no problems, or if ``references`` and
            ``predictions`` describe a different number of problems.
    """
    if isinstance(references, str):
        references = [references]
    if isinstance(predictions, str):
        predictions = [predictions]

    # Flatten per-problem test lists into a single test program string.
    references = [r if isinstance(r, str) else "\n".join(r) for r in references]
    # code_eval expects a list of candidate programs per problem.
    predictions = [[p] if isinstance(p, str) else list(p) for p in predictions]

    if not predictions:
        raise ValueError("pass_at_1 needs at least one prediction")
    if len(references) != len(predictions):
        # Rejected explicitly so that a mismatch cannot produce a score over
        # the wrong set of problems.
        raise ValueError(
            f"got {len(references)} references but {len(predictions)} predictions"
        )

    print("References", references)
    print("Predictions", predictions)

    result = pass_at_k.compute(
        references=references,
        predictions=predictions,
        k=[1]
    )
    # compute() returns (scores_by_k, per_task_results); keep only pass@1.
    return float(result[0]["pass@1"])


# ============================================================
#        Utility: Extract Python code from model text
# ============================================================

def _strip_fence_info(block: str) -> str:
    """Remove the language tag that may follow an opening Markdown fence.

    ``block`` is the text immediately after an opening fence. A tag such as
    ``python``, ``py`` or ``python3`` is recognised only when it is the sole
    token on the fence line (i.e. it is followed by a newline). As a fallback,
    a ``python`` prefix on the same line as the code is also removed.
    """
    first_line, newline, rest = block.partition("\n")
    tag = first_line.strip()
    if newline and tag[:1].isalpha() and all(ch.isalnum() or ch in "_+-." for ch in tag):
        return rest
    if block.startswith("python"):
        return block[len("python"):]
    return block


def extract_code_blocks(text: str) -> str:
    """Extract the Python source from a model completion that may contain fences.

    Rules, applied to the stripped text:

    1. Text starts with a fence: return everything up to the next fence (or
       to the end if it is never closed), without the fence's language tag.
    2. Odd number of fences: the opening fence is assumed to have been part of
       the prompt, so return everything before the first (closing) fence.
    3. Even, non-zero number of fences: return the first complete fenced
       block, without its language tag.
    4. No fences: return the text unchanged.

    Args:
        text: raw model output.

    Returns:
        The extracted code, stripped of surrounding whitespace.
    """
    fence = "```"
    text = text.strip()

    # 1. Completion opens with its own fence.
    if text.startswith(fence):
        end = text.find(fence, len(fence))
        inner = text[len(fence):] if end == -1 else text[len(fence):end]
        return _strip_fence_info(inner).strip()

    fence_count = text.count(fence)

    # 2. Odd count: the code is closed by the FIRST fence. Anything after it is
    #    prose or further fenced blocks, which would not be valid Python.
    if fence_count % 2 == 1:
        return text[:text.find(fence)].strip()

    # 3. Even, non-zero count: take the first complete fenced block.
    if fence_count:
        start = text.find(fence) + len(fence)
        end = text.find(fence, start)
        return _strip_fence_info(text[start:end]).strip()

    # 4. No fences at all: treat the whole text as code.
    return text


# ============================================================
#                     MBPP Callback
# ============================================================

class MBPPEvalCallback(TrainerCallback):
    """Trainer callback that measures MBPP pass@1 with the model's block-wise generator.

    On every ``k``-th ``on_evaluate`` event (starting with the first one) the
    callback generates one solution per problem for a fixed, shuffled subset of
    the MBPP validation split, scores them with HF ``code_eval`` via
    ``pass_at_1`` and logs the score to Weights & Biases as
    ``eval/mbpp_pass@1``.

    Args:
        tokenizer: tokenizer with a mask token and a chat template.
        dataset_args: forwarded to the project ``DataLoader`` helper, which is
            used to compute the generation block ranges.
        k: run the MBPP evaluation once every ``k`` evaluation events (an
            evaluation frequency, not the *k* of pass@k).
        max_samples: number of validation problems to evaluate.
        max_new_tokens: number of mask tokens appended to the prompt for the
            model to fill in.
        gt_mask_padding: number of mask tokens appended to the reference
            solution before computing block ranges (presumably so the ranges
            can extend past the reference length — confirm against
            ``DataLoader.split_in_blocks``).

    Raises:
        ValueError: if the tokenizer has no mask token or ``k`` is not positive.
    """

    def __init__(self, tokenizer, dataset_args, k=5, max_samples=10,
                 max_new_tokens=512, gt_mask_padding=50):
        if k < 1:
            raise ValueError("k must be a positive integer")

        # Mask token id is required to build the generation canvas.
        self.mask_token_id = tokenizer.mask_token_id
        if self.mask_token_id is None:
            raise ValueError("Tokenizer does not have a mask token. Please use a tokenizer with mask_token_id defined.")

        self.tokenizer = tokenizer
        self.dataset_args = dataset_args
        self._eval_counter = -1  # incremented before use, so the first event is 0
        self._k = k
        self.max_samples = max_samples
        self.max_new_tokens = max_new_tokens
        self.gt_mask_padding = gt_mask_padding
        self._ds = None

        # Project DataLoader helper, used here only for split_in_blocks().
        self.loader = DataLoader(dataset_args, tokenizer)

    def _load_subset(self):
        """Load (once) and cache a deterministic subset of the MBPP validation split."""
        if self._ds is None:
            ds = load_dataset("google-research-datasets/mbpp", split="validation")
            ds = ds.shuffle(seed=42).select(range(min(self.max_samples, len(ds))))
            self._ds = ds
        return self._ds

    @staticmethod
    def _build_reference(ex):
        """Return the test program for one MBPP example: setup code (if any), then the asserts."""
        tests = "\n".join(ex["test_list"])
        setup = ex.get("test_setup_code") or ""
        return f"{setup}\n{tests}" if setup.strip() else tests

    def _generate_solution(self, model, ex, device):
        """Generate the model's solution for one MBPP example and return the extracted code."""
        tokenizer = self.tokenizer
        prompt_text = ex["text"] + f" Don't write any comment or docstring and give me ONLY the function. It should pass the following test case {ex['test_list'][0]}."
        # Use the tokenizer's own mask token (validated in __init__) rather than a
        # hard-coded string, so each padding unit is the real mask token.
        ground_truth = ex["code"] + tokenizer.mask_token * self.gt_mask_padding

        # 1. Tokenize the chat prompt; the opening fence steers the answer to code.
        msgs = [{"role": "user", "content": prompt_text}]
        chat_str = tokenizer.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False) + "```python\n"
        prompt_ids = tokenizer(chat_str, add_special_tokens=False)["input_ids"]
        prompt_len = len(prompt_ids)

        # 2. Append the mask tokens the model will fill in.
        input_ids = prompt_ids + [self.mask_token_id] * self.max_new_tokens

        # 3. Block ranges are computed on the (padded) ground-truth solution.
        _, block_ranges = self.loader.split_in_blocks(ground_truth, tokenizer, prompt_len)

        input_ids_tensor = torch.tensor([input_ids], device=device)

        # 4. Generate with the model's block-wise decoder.
        print("\n\nPrompt:\n", chat_str, flush=True)
        decoded = model.dummy_generate(
            input_ids_tensor=input_ids_tensor,
            block_ranges=[block_ranges],
            prompt_len=prompt_len,
        )

        # Drop everything up to and including the prompt text, if it is echoed.
        try:
            idx = decoded.index(prompt_text)
        except ValueError:
            generated = decoded
        else:
            generated = decoded[idx + len(prompt_text):].strip()

        return extract_code_blocks(generated)

    @torch.no_grad()
    def on_evaluate(self, args, state, control, **kwargs):
        """Run the MBPP pass@1 evaluation on every ``k``-th evaluation event.

        Skipped events return immediately without touching the model. When the
        evaluation runs, the model is put in eval mode and switched back to
        train mode afterwards, even if generation or scoring raises.
        """
        self._eval_counter += 1
        if self._eval_counter % self._k != 0:
            return

        print("=============================")
        print("RUNNING EVALUATION")
        print("=============================")

        model = kwargs["model"]
        model.eval()
        try:
            device = next(model.parameters()).device

            references = []
            predictions = []
            for ex in self._load_subset():
                references.append(self._build_reference(ex))
                predictions.append([self._generate_solution(model, ex, device)])

            score = pass_at_1(references, predictions)
            print("MBPP pass@1:", score, flush=True)

            # wandb.log raises when no run is active (e.g. report_to != "wandb").
            if wandb.run is not None:
                wandb.log({"eval/mbpp_pass@1": score})
        finally:
            model.train()