from typing import Dict, List, Optional
import numpy as np
import os, re, sympy
import datasets
from importlib import import_module
from .base import *


# Prompt template for a single problem; `{question}` is the placeholder for the
# problem description. Passed to BaseDataLoader as `question_format`.
QUESTION_FORMAT = """
Q: Write python code to solve the following coding problem that obeys the constraints and passes the example test cases. The output code needs to read from and write to standard IO. Please wrap your code answer using ```:
{question}
A:"""
# Template for an answer; passed to BaseDataLoader as `answer_format`.
ANSWER_FORMAT = "{answer}"
# Separator string passed to BaseDataLoader as `sep`.
SEP = ""
# Few-shot question/answer examples passed to BaseDataLoader; both are empty here.
QUESTION_EXAMPLES = []
ANSWER_EXAMPLES = []

# Value of the dataset's per-solution "language" field that get_python_solutions
# treats as Python 3.
PYTHON3_LANGUAGE_ID = 3

# Substrings whose presence in a problem description marks it as containing an
# image (see has_image_tags).
IMAGE_TAGS = ["<image>", "[Image]"]

def get_test_cases(item):
    """
    Merge all test cases of a problem into a single input/output mapping.

    Copied from large language monkey codebase.

    Args:
        item: Mapping with "public_tests", "private_tests" and
            "generated_tests" entries, each holding "input" and "output"
            sequences.

    Returns:
        A dict with keys "input" and "output", each the concatenation of the
        public, private and generated values (in that order).
    """

    def _concat(field):
        # Same order as the dataset groups: public, then private, then generated.
        return (
            item["public_tests"][field]
            + item["private_tests"][field]
            + item["generated_tests"][field]
        )

    merged_inputs = _concat("input")
    merged_outputs = _concat("output")
    return {"input": merged_inputs, "output": merged_outputs}

def get_python_solutions(
    item,
    filter_non_ascii: bool = True,
    incorrect_solutions: bool = False,
):
    """
    Collect the Python 3 solutions attached to a problem.

    Copied from large language monkey codebase.

    Args:
        item: Mapping holding "solutions" and/or "incorrect_solutions"
            entries, each with parallel "solution" and "language" sequences.
        filter_non_ascii: When true, solutions containing non-ASCII
            characters are skipped.
        incorrect_solutions: When true, read from "incorrect_solutions"
            instead of "solutions".

    Returns:
        List of solution strings whose language id equals
        PYTHON3_LANGUAGE_ID, in their original order.
    """
    solution_key = "incorrect_solutions" if incorrect_solutions else "solutions"
    candidates = zip(
        item[solution_key]["solution"],
        item[solution_key]["language"],
    )

    # The ASCII check runs before the language check, and only when requested.
    return [
        source
        for source, language in candidates
        if (not filter_non_ascii or source.isascii())
        and language == PYTHON3_LANGUAGE_ID
    ]

def has_image_tags(description):
    """
    Report whether a problem description contains any image marker.

    Copied from large language monkey codebase.

    Args:
        description: Problem description to inspect.

    Returns:
        True if at least one entry of IMAGE_TAGS occurs in `description`,
        otherwise False.
    """
    return any(marker in description for marker in IMAGE_TAGS)

def is_valid_python(snippet):
    """
    Check whether `snippet` compiles as Python source.

    The snippet is only compiled, never executed.

    Args:
        snippet: Source text to check.

    Returns:
        True if `compile(..., "exec")` accepts the snippet, False if it is
        rejected as invalid source.
    """
    try:
        compile(snippet, "<string>", "exec")
    except (SyntaxError, ValueError):
        # compile() reports source containing null bytes as ValueError rather
        # than SyntaxError on some Python versions; either way the snippet is
        # not usable code.
        return False
    return True

# Language tag that may directly follow an opening ``` fence. It is matched as
# a whole word so identifiers such as `python_version` are left untouched, and
# `python3` is removed entirely rather than leaving a stray `3`. The short
# `py` / `py3` forms are only accepted when they stand alone on the first line.
_FENCE_LANGUAGE_TAG = re.compile(
    r"python3?(?=\s|$)|py3?[ \t]*(?=\r?\n|$)", re.IGNORECASE
)


def extract_first_code(output_string: str) -> Optional[str]:
    """
    Extract the first code snippet from a model response.

    Args:
        output_string: Raw response text.

    Returns:
        The stripped content of the first ```...``` fenced block, with a
        leading Python language tag (`python`, `python3`, `py`, `py3`, any
        case) removed. If there is no complete fenced block, the whole
        stripped response is returned when it compiles as Python; otherwise
        None.
    """
    trimmed = output_string.strip()

    # First occurrence of content between a pair of triple-backtick fences.
    code_match = re.search(r"```(.*?)```", trimmed, re.DOTALL)

    if code_match:
        code = code_match.group(1).strip()

        # The block is often ```python ... ``` rather than ``` ... ```;
        # drop the language tag so only the code remains.
        tag = _FENCE_LANGUAGE_TAG.match(code)
        if tag:
            code = code[tag.end():].strip()

        return code

    # No fenced block: accept the raw response only if it is valid Python.
    if is_valid_python(trimmed):
        return trimmed

    return None

# @TaskRegistry.register("code_contests")
class DataLoader(BaseDataLoader):
    """
    Data loader for the code-contests task.

    Initializes BaseDataLoader with this module's prompt templates, then
    removes every problem whose description contains an image tag and keeps
    the matching full dataset records alongside the remaining questions.
    """

    def __init__(self, cfg):
        """
        Build the loader and filter out problems that contain image tags.

        Args:
            cfg: Configuration object. This method reads
                `cfg.task.data.name`, `cfg.task.data.subset`,
                `cfg.task.data.split` and `cfg.task.max_samples`.
        """
        super().__init__(
            cfg,
            question_format=QUESTION_FORMAT,
            answer_format=ANSWER_FORMAT,
            sep=SEP,
            question_examples=QUESTION_EXAMPLES,
            answer_examples=ANSWER_EXAMPLES,
        )

        # Load the full data points (used for test cases).
        dataset = datasets.load_dataset(
            cfg.task.data.name,
            cfg.task.data.subset,
            split=cfg.task.data.split,
            trust_remote_code=True,
        )

        # Keep only questions without image tags, together with their records.
        # NOTE(review): assumes the base class has populated `self.questions`
        # in the same order as `dataset`; zip() stops at the shorter of the
        # two, so a length mismatch is not detected here.
        non_image_questions = []
        non_image_data = []
        for question, dp in zip(self.questions, dataset):
            assert question == dp['description'], (
                "question text does not match the dataset record's 'description'"
            )
            if not has_image_tags(question):
                non_image_questions.append(question)
                non_image_data.append(dp)
        self.questions = non_image_questions
        self.non_image_data = non_image_data

        # A non-positive max_samples means "use every remaining question".
        num_questions = len(self.questions)
        max_samples = cfg.task.max_samples
        self.num = num_questions if max_samples <= 0 else min(max_samples, num_questions)
        self.idxs = self.get_idxs()

    def parse_responses(self, responses, generation_idx: int):
        """
        Flatten generation results into one record per sampled output.

        Args:
            responses: Iterable of response objects, each exposing `.outputs`;
                every output must provide `.text`, `.cumulative_logprob` and
                `.token_ids`. (Presumably vLLM request outputs; confirm
                against the caller.)
            generation_idx: Index stored as 'prompt_idx' in every record.

        Returns:
            List of dicts with keys 'prompt_idx', 'response', 'sum_logprob'
            and 'avg_logprob'. 'avg_logprob' is None when the output has no
            tokens or no cumulative log-probability, instead of raising.
        """
        parsed = []
        for response in responses:
            for output in response.outputs:
                sum_logprob = output.cumulative_logprob
                num_tokens = len(output.token_ids)
                # Guard against division by zero and a missing log-probability.
                if sum_logprob is None or num_tokens == 0:
                    avg_logprob = None
                else:
                    avg_logprob = sum_logprob / num_tokens
                parsed.append(
                    {
                        'prompt_idx': generation_idx,
                        'response': output.text,
                        'sum_logprob': sum_logprob,
                        'avg_logprob': avg_logprob,
                        # 'test_cases': get_test_cases(self.non_image_data[idx]),
                    }
                )
        return parsed

# Potentially useful code for collecting few-shot examples: keeps only items that have Python solutions and whose descriptions contain no image tags.
# few_shot_items_with_solutions = []
# for i, data in enumerate(few_shot_dataset):
#     python_solutions = get_python_solutions(data)
#     data["python_solutions"] = python_solutions
#     if len(python_solutions) > 0 and not has_image_tags(data["description"]):
#         few_shot_items_with_solutions.append(data)

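# Potentially useful code for building the test set: skips problems whose descriptions contain image tags and attaches randomly sampled few-shot examples to each remaining problem.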
# no_image_test_dataset = []
# for i, data in enumerate(test_dataset):
#     if has_image_tags(data["description"]):
#         continue
#     few_shot_items = random.sample(
#         few_shot_items_with_solutions, config.num_few_shot
#     )
#     data["few_shot_items"] = few_shot_items
#     no_image_test_dataset.append(data)
