"""MBPP (Mostly Basic Python Problems) dataset handler."""
import re
import signal
from typing import List, Dict
from contextlib import contextmanager
from .base import DatasetHandler


# System prompt placed first in every conversation built by this module.
SYSTEM_MESSAGE = (
    "You are a Python programming assistant. "
    "Write clean, correct Python code to solve the given problem."
)

# User prompt skeleton with two `str.format` fields: `{text}` and `{tests}`.
# Each list entry is one line of the final prompt; empty entries are blank lines.
USER_TEMPLATE = "\n".join(
    [
        "{text}",
        "",
        "Your code should pass these tests:",
        "{tests}",
        "",
        "Think through your solution in <think> </think> tags.",
        "Return your final Python code in <answer> </answer> tags, e.g.:",
        "<answer>",
        "def solution(x):",
        "    return x + 1",
        "</answer>",
    ]
)


class TimeoutError(Exception):
    """Signals that executed code exceeded its time budget.

    NOTE: this name shadows the builtin ``TimeoutError`` inside this module;
    unlike the builtin, it derives directly from ``Exception``, not ``OSError``.
    """


@contextmanager
def timeout(seconds=5):
    """Context manager that aborts the wrapped block after ``seconds``.

    A temporary SIGALRM handler is installed and a one-shot real-time timer is
    armed; when the timer fires the handler raises ``TimeoutError`` inside the
    code that is currently running. On exit (normal or exceptional) the timer
    is disarmed and the previous handler is put back.

    Args:
        seconds: Time budget. May be an int or a float (sub-second budgets are
            supported). A value of 0 disarms the timer, i.e. no timeout.

    Raises:
        TimeoutError: Inside the ``with`` body when the budget is exceeded.

    Note:
        Relies on ``signal.SIGALRM``, so it only works on Unix-like platforms
        and only when entered from the main thread (``signal.signal`` raises
        ``ValueError`` elsewhere). Any pending alarm set by the caller is
        replaced and not restored.
    """
    def handler(signum, frame):
        raise TimeoutError("Code execution timed out")

    old_handler = signal.signal(signal.SIGALRM, handler)
    try:
        # Arm the timer inside the try so that a failure while arming (e.g. a
        # bad `seconds` value) still restores the previous handler below.
        # setitimer shares the SIGALRM timer with alarm() but accepts floats.
        signal.setitimer(signal.ITIMER_REAL, seconds)
        yield
    finally:
        # Disarm first so the alarm cannot fire after the old handler is back.
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, old_handler)


def execute_code_with_tests(code: str, test_list: List[str], test_setup_code: str = "", timeout_sec: int = 5) -> bool:
    """Execute candidate code, then run each test snippet against it.

    The setup code (if any) and the candidate code are executed together in a
    fresh namespace; every entry of ``test_list`` is then executed in that same
    namespace. Each execution gets its own ``timeout_sec`` budget.

    Args:
        code: Candidate Python source. Empty, whitespace-only or ``None``
            counts as a failure.
        test_list: Python statements (typically ``assert ...``) run in order.
        test_setup_code: Optional source executed before ``code``.
        timeout_sec: Time budget applied separately to the definition step and
            to each individual test.

    Returns:
        True if the code and every test ran without raising; False on the
        first failure of any kind (syntax error, exception, failed assertion,
        timeout, or an attempt to exit the interpreter).

    Warning:
        SECURITY: ``exec`` runs the supplied source in-process with full
        interpreter privileges. This is not a sandbox; only use it on code you
        are willing to run on this machine.
    """
    if not code or not code.strip():
        return False

    # Build the program that defines the candidate's functions.
    full_code = ""
    if test_setup_code:
        full_code += test_setup_code + "\n"
    full_code += code + "\n"

    # One shared namespace: definitions from the first step are visible to tests.
    namespace = {}
    for snippet in (full_code, *test_list):
        try:
            with timeout(timeout_sec):
                exec(snippet, namespace)
        except (Exception, SystemExit):
            # SystemExit is not an Exception subclass: without catching it,
            # candidate code calling exit()/sys.exit() would terminate the
            # host process instead of being scored as a failure.
            # KeyboardInterrupt is deliberately left to propagate.
            return False

    return True


class MBPPHandler(DatasetHandler):
    """Dataset handler for MBPP: builds chat prompts and scores generated code.

    A response is scored by extracting its code and executing it against the
    task's test assertions via ``execute_code_with_tests``.
    """

    name = "mbpp"
    default_train_path = "google-research-datasets/mbpp"
    default_test_path = "google-research-datasets/mbpp"
    default_max_tokens = 2048

    def load_data(self, path: str, split: str = "train", max_samples: int = None, start_index: int = 0) -> List[Dict]:
        """Load MBPP data from HuggingFace or local disk.

        Loads from 'full' subset. When train_path == test_path,
        main code loads all and slices: train=[:train_samples], test=[train_samples:].

        Args:
            path: Either HuggingFace dataset name (e.g., 'google-research-datasets/mbpp')
                  or local disk path (e.g., 'data/mbpp_full')
            split: Accepted for interface compatibility; not used here because
                  the train, validation and test splits are always merged.
            max_samples: Upper bound on the number of returned tasks. ``None``
                  (or 0) means no limit.
            start_index: Accepted for interface compatibility; not used here
                  (per the note above, slicing is left to the caller).

        Returns:
            One dict per task with keys ``messages`` (system + user prompt),
            ``ground_truth`` (reference code, tests, setup code) and ``task_id``.
        """
        # Imported lazily so the module can be imported without `datasets`.
        from datasets import load_dataset, load_from_disk
        import os

        if os.path.isdir(path):
            # Local directory: expected to have been written via save_to_disk.
            ds = load_from_disk(path)
        else:
            ds = load_dataset(path, "full")

        # Merge splits in a fixed order so item indices are stable across runs.
        # Original author's note: train(374) + validation(90) + test(500) = 964.
        # Any other split present in the dataset is ignored.
        all_items = []
        for split_name in ("train", "validation", "test"):
            if split_name in ds:
                all_items.extend(ds[split_name])

        limit = len(all_items)
        if max_samples:
            # Clamp at 0 so a negative value yields no items rather than
            # slicing from the end of the list.
            limit = max(0, min(limit, max_samples))

        return [self._build_task(item) for item in all_items[:limit]]

    @staticmethod
    def _build_task(item: Dict) -> Dict:
        """Turn one raw MBPP row into the prompt/ground-truth record."""
        test_list = item["test_list"]
        # Only the first 3 tests are shown in the prompt; all are used for scoring.
        tests_str = "\n".join(test_list[:3])
        return {
            "messages": [
                {"role": "system", "content": SYSTEM_MESSAGE},
                {"role": "user", "content": USER_TEMPLATE.format(text=item["text"], tests=tests_str)},
            ],
            "ground_truth": {
                "code": item["code"],
                "test_list": test_list,
                "test_setup_code": item.get("test_setup_code", ""),
            },
            "task_id": item["task_id"],
        }

    @staticmethod
    def _passes_tests(code: str, ground_truth: dict) -> bool:
        """Run ``code`` against the tests stored in ``ground_truth``."""
        return execute_code_with_tests(
            code,
            ground_truth["test_list"],
            ground_truth.get("test_setup_code", ""),
        )

    @staticmethod
    def _strip_code_fence(code: str) -> str:
        """Remove a Markdown fence that wraps the whole snippet, if present.

        Only a fence spanning the entire string (opening line with an optional
        language tag, closing fence at the very end) is removed; anything else
        is returned unchanged.
        """
        wrapped = re.fullmatch(r"```[\w+-]*[ \t]*\r?\n(.*?)```", code, re.DOTALL)
        return wrapped.group(1).strip() if wrapped else code

    def compute_reward(self, response: str, ground_truth: dict) -> float:
        """Compute reward: 1.0 if all tests pass, 0.0 otherwise."""
        code = self.extract_answer(response)
        if not code:
            return 0.0
        return 1.0 if self._passes_tests(code, ground_truth) else 0.0

    def extract_answer(self, response: str) -> str:
        """Extract code from the last <answer>...</answer> pair.

        If the tagged content is itself wrapped in a Markdown code fence, the
        fence is removed so the result is executable Python rather than a
        guaranteed syntax error. When no answer tags are present, falls back
        to the last fenced python block in the response.

        Returns:
            The extracted code with surrounding whitespace stripped, or ""
            when nothing could be extracted.
        """
        tagged = re.findall(r"<answer>(.*?)</answer>", response, re.DOTALL)
        if tagged:
            return self._strip_code_fence(tagged[-1].strip())

        # Fallback: try to extract python code block
        fenced = re.findall(r"```python\n?(.*?)```", response, re.DOTALL)
        return fenced[-1].strip() if fenced else ""

    def extract_answer_for_voting(self, response: str) -> str:
        """For voting, use the extracted code as-is (no normalization to preserve syntax)."""
        return self.extract_answer(response)

    def is_answer_correct(self, response: str, ground_truth: dict) -> bool:
        """Check if answer passes all tests."""
        return self.compute_reward(response, ground_truth) > 0

    def is_voted_answer_correct(self, voted_answer: str, ground_truth: dict) -> bool:
        """Check if voted code passes all tests.

        ``voted_answer`` is treated as raw code (not a tagged response) and is
        executed directly; an empty value is never correct.
        """
        if not voted_answer:
            return False
        return self._passes_tests(voted_answer, ground_truth)

    def format_answer_for_check(self, answer: str) -> str:
        """Wrap ``answer`` in <answer> tags so it can be fed back to the checkers."""
        return f"<answer>\n{answer}\n</answer>"
