from datetime import datetime
from typing import List, Optional

from datasets import load_dataset
import json
import zlib
import pickle
import base64
from enum import Enum
from datetime import datetime
from dataclasses import dataclass

from datasets import load_dataset


class Platform(Enum):
    """Judge platform a problem comes from (the record's ``platform`` field)."""

    LEETCODE = "leetcode"
    CODEFORCES = "codeforces"
    ATCODER = "atcoder"


class Difficulty(Enum):
    """Difficulty label of a problem (the record's ``difficulty`` field)."""

    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"


class TestType(Enum):
    """How a test case is fed to a solution: raw stdin or a function call."""

    STDIN = "stdin"
    FUNCTIONAL = "functional"


@dataclass
class Test:
    """A single test case: its input text, expected output and test type."""

    input: str  # raw input text as stored in the dataset
    output: str  # raw expected-output text as stored in the dataset
    testtype: TestType  # a TestType member or its string value ("stdin"/"functional")

    def __post_init__(self):
        # Coerce the raw value into the enum; TestType(member) returns the member unchanged.
        normalized = TestType(self.testtype)
        self.testtype = normalized


@dataclass
class CodeGenerationProblem:
    """One LiveCodeBench code-generation problem, decoded from a raw record.

    Several fields arrive in serialized form and are converted in place by
    ``__post_init__``:

    * ``platform`` / ``difficulty`` -> ``Platform`` / ``Difficulty`` members
    * ``contest_date`` -> ``datetime`` via ``datetime.fromisoformat``
    * ``public_test_cases`` -> ``list[Test]`` decoded from a JSON string
    * ``private_test_cases`` -> ``list[Test]`` decoded from a JSON string or,
      when that is not valid JSON, from a base64-encoded, zlib-compressed
      pickle of a JSON string
    * ``metadata`` -> ``dict`` decoded from a JSON string
    """

    question_title: str
    question_content: str
    platform: Platform
    question_id: str
    contest_id: str
    contest_date: datetime
    starter_code: str
    difficulty: Difficulty
    public_test_cases: list[Test]
    private_test_cases: list[Test]
    metadata: dict

    def __post_init__(self):
        self.platform = Platform(self.platform)
        self.difficulty = Difficulty(self.difficulty)
        self.contest_date = datetime.fromisoformat(self.contest_date)  # type: ignore

        self.public_test_cases = [
            Test(**t) for t in json.loads(self.public_test_cases)  # type: ignore
        ]

        try:
            raw_private = json.loads(self.private_test_cases)  # type: ignore
        except json.JSONDecodeError:
            # Not plain JSON: the field is base64(zlib(pickle(json_string))).
            # SECURITY NOTE(review): pickle.loads can execute arbitrary code;
            # this is only safe when the dataset source is trusted.
            blob = base64.b64decode(self.private_test_cases.encode("utf-8"))  # type: ignore
            raw_private = json.loads(pickle.loads(zlib.decompress(blob)))
        self.private_test_cases = [Test(**t) for t in raw_private]

        self.metadata = json.loads(self.metadata)  # type: ignore

    def insert_output(self, output_list: list[str], code_list: list[str]) -> dict:
        """Return a JSON-serializable summary of the problem plus model outputs.

        Args:
            output_list: raw model outputs for this problem.
            code_list: code extracted from each output.
        """
        return {
            "question_title": self.question_title,
            "question_content": self.question_content,
            "platform": self.platform.value,
            "question_id": self.question_id,
            "contest_id": self.contest_id,
            "contest_date": self.contest_date.isoformat(),
            "starter_code": self.starter_code,
            "difficulty": self.difficulty.value,
            "output_list": output_list,
            "code_list": code_list,
        }

    def insert_output_evaluation(
        self,
        output_list: list[str],
        code_list: list[str],
        graded_list: list[bool],
        **kwargs,
    ) -> dict:
        """Like ``insert_output`` but also records grading results.

        ``pass@1`` is the fraction of ``True`` entries in *graded_list*
        (0.0 when *graded_list* is empty). Any extra keyword arguments are
        copied into the result, overriding existing keys.
        """
        output = self.insert_output(output_list, code_list)
        output["graded_list"] = graded_list
        # Guard the empty case instead of raising ZeroDivisionError.
        output["pass@1"] = (
            graded_list.count(True) / len(graded_list) if graded_list else 0.0
        )
        output.update(kwargs)
        return output

    def get_evaluation_sample(self):
        """Return ``{"input_output": <JSON string>}`` over public + private tests.

        The JSON object has ``inputs`` and ``outputs`` (public tests first,
        then private) and ``fn_name`` taken from ``metadata["func_name"]``
        (``None`` when absent).
        """
        all_tests = self.public_test_cases + self.private_test_cases
        return {
            "input_output": json.dumps(
                {
                    "inputs": [t.input for t in all_tests],
                    "outputs": [t.output for t in all_tests],
                    "fn_name": self.metadata.get("func_name", None),
                }
            ),
        }

def load_code_generation_dataset(
    release_version: str = "release_v2",
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
) -> List[CodeGenerationProblem]:
    """Load LiveCodeBench code-generation problems, optionally date-filtered.

    Args:
        release_version: ``version_tag`` passed to ``load_dataset`` for the
            test split of ``livecodebench/code_generation_lite``.
        start_date: ``"YYYY-MM-DD"``; keep problems whose ``contest_date`` is
            at or after midnight of that day. Falsy means no lower bound.
        end_date: ``"YYYY-MM-DD"``; keep problems whose ``contest_date`` is at
            or before midnight of that day. Falsy means no upper bound.

    Returns:
        The remaining problems in dataset order. A one-line summary is printed.
    """
    raw_records = load_dataset(
        "livecodebench/code_generation_lite",
        split="test",
        version_tag=release_version,
        trust_remote_code=True,
    )
    selected = [CodeGenerationProblem(**record) for record in raw_records]

    if start_date:
        lower_bound = datetime.strptime(start_date, "%Y-%m-%d")
        selected = list(filter(lambda prob: prob.contest_date >= lower_bound, selected))

    if end_date:
        upper_bound = datetime.strptime(end_date, "%Y-%m-%d")
        selected = list(filter(lambda prob: prob.contest_date <= upper_bound, selected))

    print(f"Loaded {len(selected)} problems (release={release_version})")
    return selected

import re

def simplify_inputs(text: str) -> str:
    """Drop the variable name from ``Input: <name> = [ ... ]``.

    Each occurrence is rewritten to ``Input: [ ... ]``. The captured array
    stops at the first ``]``; any text after it is left untouched.
    """
    named_array = r'Input:\s*\w+\s*=\s*(\[[^\]]*\])'
    replacement = r'Input: \1'
    return re.sub(named_array, replacement, text)




import re, ast

# ---------- transform_input_block  ----------
def transform_input_block(spec: str) -> str:
    """Rewrite a narrative input spec (``key = value`` / bare values) as stdin lines.

    Values are emitted in the order they appear in *spec*:

    * 2-D array ``[[...], ...]`` -> one line per row, elements space-separated
    * 1-D array ``[...]``        -> one line, elements space-separated
    * scalar / string            -> one line (surrounding quotes removed)

    Only values are written; key names (``apple``, ``capacity``, ...) never
    appear in the result. Arrays that are not valid Python literals (e.g.
    containing ``null`` or ``true``) fall back to a comma split instead of
    raising. Example: ``"nums = [1,2,3], k = 2"`` -> ``"1 2 3\\n2\\n"``.
    """
    import ast, re

    token_re = re.compile(
        r"""
        (?P<kv_array>      \b\w+\s*=\s*\[[^\[\]]*(?:\[[^\[\]]*\][^\[\]]*)*\])   # key = [ ... ]
      | (?P<kv_scalar>     \b\w+\s*=\s*(?: "(?:\\.|[^"\\])*" | '(?:\\.|[^'\\])*' | True | False | -?\d+(?:\.\d+)? | \w+))
      | (?P<array>         \[[^\[\]]*(?:\[[^\[\]]*\][^\[\]]*)*\])               # standalone [ ... ]
      | (?P<scalar>        "(?:\\.|[^"\\])*" | '(?:\\.|[^'\\])*' | True | False | -?\d+(?:\.\d+)? | \w+)
        """,
        re.X,
    )

    def parse_array(literal: str):
        """Parse an array literal; on failure split each innermost [...] on commas."""
        try:
            return ast.literal_eval(literal)
        except (ValueError, SyntaxError, TypeError):
            rows = [
                [item.strip() for item in body.split(",") if item.strip()]
                for body in re.findall(r"\[([^\[\]]*)\]", literal)
            ]
            # If the first element is itself a list, keep the 2-D shape.
            if literal.strip()[1:].lstrip().startswith("["):
                return rows
            return rows[0] if rows else []

    lines: list[str] = []
    # finditer yields matches left to right, so output follows source order.
    for m in token_re.finditer(spec):
        kind = m.lastgroup
        text = m.group(kind)
        if kind in ("kv_array", "kv_scalar"):
            # Drop the "key =" part; only the value is emitted.
            text = text.split("=", 1)[1].lstrip()

        if kind in ("kv_array", "array"):
            arr = parse_array(text)
            if isinstance(arr, list) and arr and all(isinstance(r, list) for r in arr):
                lines.extend(" ".join(map(str, r)) for r in arr)
            else:
                lines.append(" ".join(map(str, arr)))
        else:
            if text[0] in "\"'":  # strip surrounding quotes
                text = text[1:-1]
            lines.append(text)

    return "\n".join(lines) + "\n"





def transform_spec_block(spec: str) -> str:
    """
    Turn specs like
        "ss = 3, work = ['a','b'], kk = True"
    or
        "[0,1,1,0,2]"
    into:
        3
        a b
        True

    or

        0 1 1 0 2
    (with a trailing newline).

    Only the first array literal is expanded: the first ``[[...]]`` if one
    exists, otherwise the first 1-D ``[...]``. Arrays that are not valid Python
    literals (e.g. containing ``null``) are split on commas instead of being
    dropped. Every ``key = value`` scalar outside that array is emitted too.
    """
    events = []
    masked = spec

    # 1) bracket-match the first [[...]]
    start = spec.find('[[')
    if start != -1:
        depth = 0
        end = None
        for i, ch in enumerate(spec[start:], start):
            if ch == '[':
                depth += 1
            elif ch == ']':
                depth -= 1
                if depth == 0:
                    end = i
                    break
        if end is not None:
            literal = spec[start:end + 1]
            try:
                arr2d = ast.literal_eval(literal)
            except (ValueError, SyntaxError, TypeError):
                # Not a Python literal: split each innermost [...] row on
                # commas rather than silently discarding the data.
                arr2d = [
                    [x.strip() for x in row.split(',') if x.strip()]
                    for row in re.findall(r'\[([^\[\]]*)\]', literal)
                ]
            events.append((start, 'array2d', arr2d))
            # Blank out the array (same length) so step 3 cannot match inside it.
            masked = masked[:start] + ' ' * (end + 1 - start) + masked[end + 1:]
    else:
        # 2) fallback to the first 1D [...] literal
        m1 = re.search(r'\[[^\[\]]*\]', spec)
        if m1:
            lit = m1.group(0)
            try:
                raw = ast.literal_eval(lit)
                parts = [str(x) for x in raw]
            except (ValueError, SyntaxError, TypeError):
                parts = [x.strip() for x in lit.strip('[]').split(',') if x.strip()]
            events.append((m1.start(), 'array1d', parts))
            masked = masked[:m1.start()] + ' ' * (m1.end() - m1.start()) + masked[m1.end():]

    # 3) capture any key=value scalars (quoted or up to the next space/comma)
    for m in re.finditer(
        r'(\w+)\s*=\s*'
        r'( "(?:\\.|[^"\\])*" | \'(?:\\.|[^\'\\])*\' | [^\s,]+ )',
        masked,
        flags=re.VERBOSE
    ):
        val = m.group(2)
        if (val.startswith('"') and val.endswith('"')) or (val.startswith("'") and val.endswith("'")):
            val = val[1:-1]
        events.append((m.start(), 'scalar', val))

    # 4) sort by position & 5) emit
    events.sort(key=lambda e: e[0])
    out = []
    for _, typ, data in events:
        if typ == 'scalar':
            out.append(data)
        elif typ == 'array1d':
            out.append(" ".join(data))
        else:  # array2d
            for row in data:
                out.append(" ".join(map(str, row)))
    return "\n".join(out) + "\n"


def find_matching(s: str, start: int) -> Optional[int]:
    """Return the index of the ``]`` that closes the ``[`` at ``s[start]``.

    Brackets are counted naively: brackets inside quoted strings count too.
    Returns ``None`` when the bracket is never closed.
    """
    depth = 0
    for i, ch in enumerate(s[start:], start):
        if ch == '[':
            depth += 1
        elif ch == ']':
            depth -= 1
            if depth == 0:
                return i
    return None

# ---------- 1. replace_output_block ----------
def replace_output_block(text: str) -> str:
    """Expand the value after every ``Output:`` marker onto its own lines.

    ``Output: <value>`` becomes ``Output:<original whitespace>\\n<lines>\\n``:

    * 2-D list -> one line per row, elements joined by spaces
    * 1-D list -> one line, elements joined by spaces
    * scalar   -> the token alone on a line. Integers, decimals,
      ``True``/``False``/``true``/``false`` and quoted strings are
      recognised, and surrounding quotes are stripped.

    Lists that are not Python literals (e.g. ``[true,false]``) are split on
    commas per row rather than dropped. Unrecognised values and unbalanced
    ``[`` are left untouched.
    """
    def strip_quotes(tok: str) -> str:
        return tok[1:-1] if len(tok) >= 2 and tok[0] in "\"'" and tok[-1] == tok[0] else tok

    def parse_list(literal: str) -> list:
        """Parse a list literal; on failure split each innermost [...] on commas."""
        try:
            return ast.literal_eval(literal)
        except Exception:
            rows = [
                [item.strip() for item in body.split(",") if item.strip()]
                for body in re.findall(r"\[([^\[\]]*)\]", literal)
            ]
            if literal[1:].lstrip().startswith("["):  # first element is a list -> 2-D
                return rows
            return rows[0] if rows else []

    scalar_re = re.compile(
        r'-?\d+(?:\.\d+)?|True|False|true|false|"(?:\\.|[^"\\])*"|\'(?:\\.|[^\'\\])*\''
    )

    out, last = [], 0
    for m in re.finditer(r'Output\s*:', text):
        if m.start() < last:
            # This marker lies inside a value that was already consumed.
            continue
        i = m.end()
        while i < len(text) and text[i].isspace():
            i += 1
        # Emit everything up to and including the marker and its trailing
        # whitespace, and advance `last` so none of it is emitted twice.
        out.append(text[last:i])
        last = i
        if i >= len(text):
            break

        if text[i] == '[':                              # list output
            end = find_matching(text, i)
            if end is None:
                continue                                # unbalanced: leave as-is
            arr = parse_list(text[i:end + 1])
            if isinstance(arr, list) and arr and all(isinstance(r, list) for r in arr):
                lines = [" ".join(map(str, r)) for r in arr]
            else:
                lines = [" ".join(map(str, arr))]
            out.append("\n" + "\n".join(lines) + "\n")
            last = end + 1
        else:                                           # scalar output
            m2 = scalar_re.match(text, i)
            if m2:
                out.append(f"\n{strip_quotes(m2.group(0))}\n")
                last = m2.end()
    out.append(text[last:])
    return "".join(out)




import re, ast

def find_matching2(s: str, start: int) -> Optional[int]:
    """
    Given s[start] == '[', return the index of the matching closing ']'
    by counting bracket depth (brackets inside quoted strings count too).
    Returns None if the bracket is never closed.
    """
    depth = 0
    for i in range(start, len(s)):
        ch = s[i]
        if ch == '[':
            depth += 1
        elif ch == ']':
            depth -= 1
            if depth == 0:
                return i
    return None

def transform_tokens(s: str) -> str:
    """
    Emits one line per token in the order they appear:
      - 2D array [[...],...] -> each row "x y z"
      - 1D array  [a,b,c]     -> "a b c"
      - Quoted    "foo"       -> foo
      - Bare      bar         -> bar

    Every top-level [[...]] is handled, not just the first. Arrays that are
    not valid Python literals (e.g. containing null/true) fall back to a
    comma split instead of being dropped.
    """
    def parse_rows(literal):
        """Parse a 2D literal; on failure split each innermost [...] on commas."""
        try:
            return ast.literal_eval(literal)
        except (ValueError, SyntaxError, TypeError):
            return [
                [x.strip() for x in row.split(',') if x.strip()]
                for row in re.findall(r'\[([^\[\]]*)\]', literal)
            ]

    events = []
    masked = s

    # 1) Extract every nested 2D array [[...]]. Each one is blanked out with
    #    spaces (same length, so positions stay valid) so step 2 does not
    #    re-tokenize its contents.
    start2 = masked.find('[[')
    while start2 != -1:
        end2 = find_matching2(masked, start2)
        if end2 is None:
            break
        events.append((start2, 'array2d', parse_rows(masked[start2:end2 + 1])))
        masked = masked[:start2] + ' ' * (end2 + 1 - start2) + masked[end2 + 1:]
        start2 = masked.find('[[', end2 + 1)

    # 2) Tokenize what's left into 1D arrays, quoted strings, or bare tokens:
    token_re = re.compile(r'\[[^\]]*\]|"(?:\\.|[^"\\])*"|\'(?:\\.|[^\'\\])*\'|\S+')
    for m in token_re.finditer(masked):
        tok = m.group(0)
        pos = m.start()

        if tok.startswith('[') and tok.endswith(']'):
            # 1D array
            try:
                parts = [str(x) for x in ast.literal_eval(tok)]
            except (ValueError, SyntaxError, TypeError):
                parts = [x.strip() for x in tok.strip('[]').split(',') if x.strip()]
            events.append((pos, 'array1d', parts))

        elif (tok.startswith('"') and tok.endswith('"')) or (tok.startswith("'") and tok.endswith("'")):
            # quoted string
            events.append((pos, 'scalar', tok[1:-1]))

        else:
            # bare token (number, boolean, identifier...)
            events.append((pos, 'scalar', tok))

    # 3) Sort by the original position
    events.sort(key=lambda e: e[0])

    out = []
    for _, typ, data in events:
        if typ == 'scalar':
            out.append(str(data))
        elif typ == 'array1d':
            out.append(" ".join(data))
        else:
            for row in data:
                # A non-list element inside a "2D" literal (e.g. [[1,2],3]) is
                # printed as-is instead of crashing on join().
                if isinstance(row, (list, tuple)):
                    out.append(" ".join(map(str, row)))
                else:
                    out.append(str(row))

    return "\n".join(out) + "\n"

def replace_input_block(text: str) -> str:
    """Rewrite every ``Input: ...`` section of *text* into stdin-style lines.

    The span after each ``Input:`` (case-insensitive) up to the next
    ``Output:`` or the end of *text* is passed through
    ``transform_input_block``. The ``Input:`` prefix, including any whitespace
    after the colon, is kept and followed by a newline.
    """
    input_section = re.compile(
        r'(Input\s*:\s*)(.*?)(?=\s*(?:Output\s*:|$))',
        flags=re.IGNORECASE | re.DOTALL,
    )

    def expand(match):
        head, body = match.group(1), match.group(2)
        return head + "\n" + transform_input_block(body)

    return input_section.sub(expand, text)



# ---- configuration -------------------------------------------------------
release_tag = "release_v2"  # dataset version tag, e.g. "release_v1", "release_v2"
start = None  # "YYYY-MM-DD" lower bound on contest_date, or None for no bound
end = None  # "YYYY-MM-DD" upper bound (compared to midnight of that day), or None

# ---- load ----------------------------------------------------------------
problems = load_code_generation_dataset(
    release_version=release_tag,
    start_date=start,
    end_date=end,
)
if not problems:
    print("No problems found.")

# Converted output records and the running task id assigned to each.
new_data = []
i = 0


def modify(c):
    """Normalize a test-case string for stdin use.

    Strips surrounding whitespace, turns literal backslash-n sequences into
    real newlines, and makes sure the result ends with a newline.
    """
    text = c.strip().replace("\\n", "\n")
    return text if text.endswith("\n") else text + "\n"

def _functional_to_stdin(text):
    """Convert a functional (literal-valued) test string into stdin text."""
    return modify(transform_tokens(text))


num_function = 0  # number of problems whose tests were functional
for problem in problems:
    if i % 10 == 0:
        print(i)  # progress indicator

    # Decide the test style from the first test case. Private tests are
    # preferred; public tests are the fallback so that a problem with no
    # private tests does not raise IndexError.
    reference_tests = problem.private_test_cases or problem.public_test_cases
    is_functional = (
        bool(reference_tests) and reference_tests[0].testtype.value == "functional"
    )

    if is_functional:
        # Rewrite the statement's examples and the literal test values into
        # stdin-style text.
        question = replace_output_block(
            replace_input_block(problem.question_content.strip())
        )
        convert = _functional_to_stdin
        num_function += 1
    else:
        question = problem.question_content.strip()
        convert = modify

    data_i = {
        "dataset": "LiveCodeBench",
        "question": question,
        "test_input": [convert(t.input) for t in problem.private_test_cases],
        "test_output": [convert(t.output) for t in problem.private_test_cases],
        "example_input": [convert(t.input) for t in problem.public_test_cases],
        "example_output": [convert(t.output) for t in problem.public_test_cases],
        "solutions": None,
        "difficulty": problem.difficulty.value,
        "task_id": i,
        "test_time_limit": 8,
        "exe_method": "stdin",
    }
    i += 1
    new_data.append(data_i)

with open("LiveCodeBench.json", "w", encoding="utf-8") as f:
    json.dump(new_data, f, indent=2, ensure_ascii=False)

