"""
dataset.py — Dataset generator for the consistency loss experiment.

Generates Python code snippets (sorting, searching, string ops, math),
extracts ground-truth claims (time_complexity, space_complexity, correctness),
and builds (code_snippet, mismatched_explanation, ground_truth_claims) triples.
"""

import random
import re
import ast
import textwrap
from dataclasses import dataclass, field
from typing import List, Tuple, Optional

# ──────────────────────────────────────────────────────────────────────────────
# 1. Complexity label enums
# ──────────────────────────────────────────────────────────────────────────────

# Ordered complexity classes; a label's position in its list is its class index.
TIME_BINS = ["O(1)", "O(n)", "O(n^2)"]
SPACE_BINS = ["O(1)", "O(n)", "O(n^2)"]

# Reverse lookups: complexity label -> class index.
TIME2IDX = dict(zip(TIME_BINS, range(len(TIME_BINS))))
SPACE2IDX = dict(zip(SPACE_BINS, range(len(SPACE_BINS))))

# ──────────────────────────────────────────────────────────────────────────────
# 2. Template bank
# Each entry is a FunctionTemplate: (name, code, explanation, time, space, correct)
# correct=True means the function body is correct, False means buggy.
# ──────────────────────────────────────────────────────────────────────────────

@dataclass
class FunctionTemplate:
    """One entry of the template bank: a code snippet plus its ground-truth labels."""
    # Template identifier; may differ from the function name inside `code`
    # (e.g. "swap_buggy" wraps a function called `swap`).
    name: str
    code: str                    # Python source (string)
    explanation: str             # Ground-truth natural-language explanation
    time_complexity: str         # TIME_BINS member
    space_complexity: str        # SPACE_BINS member
    correct: bool                # passes basic correctness tests?
    # Not set by any template built in _make_templates; defaults to None.
    test_fn: Optional[str] = None  # name of a test helper (optional)


def _make_templates() -> List[FunctionTemplate]:
    """
    Build the fixed template bank.

    Returns a new list of FunctionTemplate objects on every call, grouped by
    (time, space) class in the order of the section comments below.

    The order matters: build_dataset samples templates by list index, so
    inserting, removing or reordering entries changes what a given seed yields.

    The declared labels are the ground truth for the whole dataset, so each
    snippet must really have the complexity it is labelled with.
    """
    templates = []

    # ── O(1) time / O(1) space ───────────────────────────────────────────────

    templates.append(FunctionTemplate(
        name="get_first",
        code=textwrap.dedent("""\
            def get_first(lst):
                if not lst:
                    return None
                return lst[0]
        """),
        explanation="Returns the first element of a list in constant time and constant space.",
        time_complexity="O(1)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="is_even",
        code=textwrap.dedent("""\
            def is_even(n):
                return n % 2 == 0
        """),
        explanation="Checks whether an integer is even using the modulo operator. Runs in O(1) time and O(1) space.",
        time_complexity="O(1)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="absolute_value",
        code=textwrap.dedent("""\
            def absolute_value(x):
                if x < 0:
                    return -x
                return x
        """),
        explanation="Computes the absolute value of a number in constant time and space.",
        time_complexity="O(1)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="swap_buggy",
        code=textwrap.dedent("""\
            def swap(a, b):
                a = b
                b = a   # bug: both become original b
                return a, b
        """),
        explanation="Attempts to swap two values but contains a logic bug, returning incorrect results.",
        time_complexity="O(1)", space_complexity="O(1)", correct=False,
    ))

    templates.append(FunctionTemplate(
        name="max_of_two",
        code=textwrap.dedent("""\
            def max_of_two(a, b):
                return a if a >= b else b
        """),
        explanation="Returns the larger of two numbers using a conditional expression. O(1) time and space.",
        time_complexity="O(1)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="celsius_to_fahrenheit",
        code=textwrap.dedent("""\
            def celsius_to_fahrenheit(c):
                return c * 9 / 5 + 32
        """),
        explanation="Converts Celsius to Fahrenheit in constant time and space.",
        time_complexity="O(1)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="is_palindrome_char",
        code=textwrap.dedent("""\
            def is_palindrome_char(ch):
                return ch == ch  # trivially true
        """),
        explanation="A trivial stub that always returns True. O(1) time and space.",
        time_complexity="O(1)", space_complexity="O(1)", correct=False,
    ))

    templates.append(FunctionTemplate(
        name="clamp",
        code=textwrap.dedent("""\
            def clamp(x, lo, hi):
                return max(lo, min(x, hi))
        """),
        explanation="Clamps a value between lo and hi bounds in constant time and space.",
        time_complexity="O(1)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="sign",
        code=textwrap.dedent("""\
            def sign(x):
                if x > 0: return 1
                if x < 0: return -1
                return 0
        """),
        explanation="Returns the sign of a number (-1, 0, or 1) in O(1) time and space.",
        time_complexity="O(1)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="dict_lookup",
        code=textwrap.dedent("""\
            def dict_lookup(d, key):
                return d.get(key, None)
        """),
        explanation="Performs O(1) average-case dictionary lookup by key.",
        time_complexity="O(1)", space_complexity="O(1)", correct=True,
    ))

    # ── O(n) time / O(1) space ───────────────────────────────────────────────

    templates.append(FunctionTemplate(
        name="linear_sum",
        code=textwrap.dedent("""\
            def linear_sum(lst):
                total = 0
                for x in lst:
                    total += x
                return total
        """),
        explanation="Computes the sum of a list with a single pass. Time complexity O(n), space O(1).",
        time_complexity="O(n)", space_complexity="O(1)", correct=True,
    ))

    # Iterates the list itself rather than lst[1:]: the slice would copy n - 1
    # elements and make the O(1) space label false.
    templates.append(FunctionTemplate(
        name="linear_max",
        code=textwrap.dedent("""\
            def linear_max(lst):
                if not lst:
                    return None
                m = lst[0]
                for x in lst:
                    if x > m:
                        m = x
                return m
        """),
        explanation="Finds the maximum element with a single linear scan. O(n) time, O(1) space.",
        time_complexity="O(n)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="linear_search",
        code=textwrap.dedent("""\
            def linear_search(lst, target):
                for i, x in enumerate(lst):
                    if x == target:
                        return i
                return -1
        """),
        explanation="Searches a list sequentially for a target value. O(n) time, O(1) space.",
        time_complexity="O(n)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="count_occurrences",
        code=textwrap.dedent("""\
            def count_occurrences(lst, val):
                count = 0
                for x in lst:
                    if x == val:
                        count += 1
                return count
        """),
        explanation="Counts occurrences of val in lst. Single-pass O(n) time, O(1) space.",
        time_complexity="O(n)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="reverse_list_buggy",
        code=textwrap.dedent("""\
            def reverse_list(lst):
                for i in range(len(lst) // 2):
                    lst[i], lst[i] = lst[i], lst[i]  # bug: wrong index
                return lst
        """),
        explanation="Attempts to reverse a list in-place but has a bug: it swaps an element with itself.",
        time_complexity="O(n)", space_complexity="O(1)", correct=False,
    ))

    templates.append(FunctionTemplate(
        name="string_contains",
        code=textwrap.dedent("""\
            def string_contains(s, ch):
                for c in s:
                    if c == ch:
                        return True
                return False
        """),
        explanation="Checks if a string contains a character via linear scan. O(n) time, O(1) space.",
        time_complexity="O(n)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="compute_mean",
        code=textwrap.dedent("""\
            def compute_mean(lst):
                if not lst:
                    return 0.0
                return sum(lst) / len(lst)
        """),
        explanation="Computes arithmetic mean. The built-in sum is O(n); overall O(n) time, O(1) space.",
        time_complexity="O(n)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="is_sorted",
        code=textwrap.dedent("""\
            def is_sorted(lst):
                for i in range(len(lst) - 1):
                    if lst[i] > lst[i + 1]:
                        return False
                return True
        """),
        explanation="Checks whether a list is sorted in ascending order. O(n) time, O(1) space.",
        time_complexity="O(n)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="factorial_loop",
        code=textwrap.dedent("""\
            def factorial(n):
                result = 1
                for i in range(2, n + 1):
                    result *= i
                return result
        """),
        explanation="Computes n! iteratively. O(n) time, O(1) space.",
        time_complexity="O(n)", space_complexity="O(1)", correct=True,
    ))

    # Labelled O(n^2) time / O(n) space: `ch + result` builds a new string on
    # every iteration, so the total copying is quadratic. The entry stays at
    # this position (under the O(n)/O(1) heading) so that template indices,
    # and therefore seeded sampling, are not shifted.
    templates.append(FunctionTemplate(
        name="string_reverse",
        code=textwrap.dedent("""\
            def string_reverse(s):
                result = ''
                for ch in s:
                    result = ch + result
                return result
        """),
        explanation="Reverses a string by prepending one character at a time. Each prepend copies the accumulated string, so O(n^2) time and O(n) space.",
        time_complexity="O(n^2)", space_complexity="O(n)", correct=True,
    ))

    # ── O(n) time / O(n) space ───────────────────────────────────────────────

    templates.append(FunctionTemplate(
        name="build_freq_map",
        code=textwrap.dedent("""\
            def build_freq_map(lst):
                freq = {}
                for x in lst:
                    freq[x] = freq.get(x, 0) + 1
                return freq
        """),
        explanation="Builds a frequency map in O(n) time and O(n) space.",
        time_complexity="O(n)", space_complexity="O(n)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="copy_list",
        code=textwrap.dedent("""\
            def copy_list(lst):
                return lst[:]
        """),
        explanation="Creates a shallow copy of a list. O(n) time and O(n) space.",
        time_complexity="O(n)", space_complexity="O(n)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="flatten_one_level",
        code=textwrap.dedent("""\
            def flatten_one_level(lst):
                result = []
                for sub in lst:
                    result.extend(sub)
                return result
        """),
        explanation="Flattens one level of nesting in a list of lists. O(n) time and O(n) space.",
        time_complexity="O(n)", space_complexity="O(n)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="unique_elements",
        code=textwrap.dedent("""\
            def unique_elements(lst):
                seen = set()
                result = []
                for x in lst:
                    if x not in seen:
                        seen.add(x)
                        result.append(x)
                return result
        """),
        explanation="Returns unique elements preserving order. O(n) time and O(n) space.",
        time_complexity="O(n)", space_complexity="O(n)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="prefix_sums",
        code=textwrap.dedent("""\
            def prefix_sums(lst):
                ps = [0] * (len(lst) + 1)
                for i, x in enumerate(lst):
                    ps[i + 1] = ps[i] + x
                return ps
        """),
        explanation="Computes prefix sum array. O(n) time and O(n) space.",
        time_complexity="O(n)", space_complexity="O(n)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="zip_two_lists",
        code=textwrap.dedent("""\
            def zip_two_lists(a, b):
                result = []
                for x, y in zip(a, b):
                    result.append((x, y))
                return result
        """),
        explanation="Zips two lists into pairs. O(n) time and O(n) space.",
        time_complexity="O(n)", space_complexity="O(n)", correct=True,
    ))

    # ── O(n²) time / O(1) space ──────────────────────────────────────────────

    templates.append(FunctionTemplate(
        name="bubble_sort",
        code=textwrap.dedent("""\
            def bubble_sort(lst):
                n = len(lst)
                for i in range(n):
                    for j in range(n - i - 1):
                        if lst[j] > lst[j + 1]:
                            lst[j], lst[j + 1] = lst[j + 1], lst[j]
                return lst
        """),
        explanation="Bubble sort: O(n^2) time with nested loops, O(1) auxiliary space.",
        time_complexity="O(n^2)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="selection_sort",
        code=textwrap.dedent("""\
            def selection_sort(lst):
                for i in range(len(lst)):
                    min_idx = i
                    for j in range(i + 1, len(lst)):
                        if lst[j] < lst[min_idx]:
                            min_idx = j
                    lst[i], lst[min_idx] = lst[min_idx], lst[i]
                return lst
        """),
        explanation="Selection sort: selects the minimum in O(n^2) nested passes, O(1) space.",
        time_complexity="O(n^2)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="insertion_sort",
        code=textwrap.dedent("""\
            def insertion_sort(lst):
                for i in range(1, len(lst)):
                    key = lst[i]
                    j = i - 1
                    while j >= 0 and lst[j] > key:
                        lst[j + 1] = lst[j]
                        j -= 1
                    lst[j + 1] = key
                return lst
        """),
        explanation="Insertion sort: O(n^2) worst-case time, O(1) space.",
        time_complexity="O(n^2)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="all_pairs_sum",
        code=textwrap.dedent("""\
            def all_pairs_sum(lst):
                total = 0
                for i in range(len(lst)):
                    for j in range(len(lst)):
                        total += lst[i] + lst[j]
                return total
        """),
        explanation="Sums all pairs from the list. Double nested loop gives O(n^2) time, O(1) space.",
        time_complexity="O(n^2)", space_complexity="O(1)", correct=True,
    ))

    # NOTE(review): both loops are fixed at range(2), so this snippet is strictly
    # constant-time; the O(n^2) label presumably treats n as the matrix
    # dimension. Confirm that this is the intended ground truth.
    templates.append(FunctionTemplate(
        name="matrix_multiply_buggy",
        code=textwrap.dedent("""\
            def matrix_multiply_2x2(A, B):
                C = [[0, 0], [0, 0]]
                for i in range(2):
                    for j in range(2):
                        C[i][j] = A[i][0] * B[0][j]  # bug: missing second term
                return C
        """),
        explanation="Attempts 2x2 matrix multiplication but is buggy (missing accumulation). O(n^2) time, O(1) auxiliary space.",
        time_complexity="O(n^2)", space_complexity="O(1)", correct=False,
    ))

    templates.append(FunctionTemplate(
        name="check_all_pairs_equal",
        code=textwrap.dedent("""\
            def check_all_pairs_equal(lst):
                for i in range(len(lst)):
                    for j in range(i + 1, len(lst)):
                        if lst[i] != lst[j]:
                            return False
                return True
        """),
        explanation="Checks if all elements are equal via pairwise comparison. O(n^2) time, O(1) space.",
        time_complexity="O(n^2)", space_complexity="O(1)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="has_duplicate_quadratic",
        code=textwrap.dedent("""\
            def has_duplicate_quadratic(lst):
                for i in range(len(lst)):
                    for j in range(i + 1, len(lst)):
                        if lst[i] == lst[j]:
                            return True
                return False
        """),
        explanation="Detects duplicates with nested loops. O(n^2) time, O(1) space.",
        time_complexity="O(n^2)", space_complexity="O(1)", correct=True,
    ))

    # ── O(n²) time / O(n) or O(n²) space ─────────────────────────────────────

    templates.append(FunctionTemplate(
        name="generate_all_pairs",
        code=textwrap.dedent("""\
            def generate_all_pairs(lst):
                pairs = []
                for i in range(len(lst)):
                    for j in range(len(lst)):
                        pairs.append((lst[i], lst[j]))
                return pairs
        """),
        explanation="Generates all ordered pairs. O(n^2) time and O(n^2) space.",
        time_complexity="O(n^2)", space_complexity="O(n^2)", correct=True,
    ))

    templates.append(FunctionTemplate(
        name="naive_convolution",
        code=textwrap.dedent("""\
            def naive_convolution(a, b):
                n = len(a) + len(b) - 1
                result = [0] * n
                for i in range(len(a)):
                    for j in range(len(b)):
                        result[i + j] += a[i] * b[j]
                return result
        """),
        explanation="Naive polynomial/array convolution. O(n^2) time, O(n) output space.",
        time_complexity="O(n^2)", space_complexity="O(n)", correct=True,
    ))

    return templates


TEMPLATES = _make_templates()


# ──────────────────────────────────────────────────────────────────────────────
# 3. Ground-truth oracle (looks up the labels declared on each template)
# ──────────────────────────────────────────────────────────────────────────────

def oracle_verify(template: FunctionTemplate, test_input=None) -> dict:
    """
    Look up the ground-truth claims for a template.

    Nothing is executed or analysed here: the labels declared on the template
    are returned as-is, together with their class indices. `test_input` is
    accepted but not used.

    Returned keys:
        time_complexity      : str  (label declared on the template)
        space_complexity     : str  (label declared on the template)
        correctness          : int  (0 or 1)
        time_complexity_idx  : int  (0=O(1), 1=O(n), 2=O(n^2))
        space_complexity_idx : int

    Raises KeyError when a declared label is missing from TIME2IDX / SPACE2IDX.
    """
    time_label = template.time_complexity
    space_label = template.space_complexity
    claims = dict(
        time_complexity=time_label,
        space_complexity=space_label,
        correctness=int(template.correct),
    )
    # Indices are added last so the key order matches the documented layout.
    claims["time_complexity_idx"] = TIME2IDX[time_label]
    claims["space_complexity_idx"] = SPACE2IDX[space_label]
    return claims


# ──────────────────────────────────────────────────────────────────────────────
# 4. Dataset construction
# ──────────────────────────────────────────────────────────────────────────────

@dataclass
class Example:
    """One dataset row: a code snippet, an explanation meant to come from a
    different template, and the ground-truth claims for the code."""
    # Position of the example in the list returned by build_dataset
    idx: int
    # Source text of the sampled template
    code_snippet: str
    # Ground-truth explanation (for reference; not given to model)
    true_explanation: str
    # Mismatched explanation (from a randomly chosen *different* template)
    mismatched_explanation: str
    # Ground-truth claims
    time_complexity: str
    space_complexity: str
    correctness: int
    # Class indices of the two complexity labels (see TIME2IDX / SPACE2IDX)
    time_complexity_idx: int
    space_complexity_idx: int
    # `name` of the template the code snippet was taken from
    template_name: str


def build_dataset(n: int = 3000, seed: int = 42) -> List[Example]:
    """
    Build n examples.
    - Sample code templates uniformly with replacement.
    - For each example, draw the explanation uniformly from the *other*
      templates, so the explanation never belongs to the example's own code.

    Args:
        n: number of examples to build; n <= 0 yields an empty list.
        seed: seed of the private random.Random used for all sampling.

    Returns:
        A list of n Example objects with idx running from 0 to n - 1.

    Raises:
        ValueError: if n > 0 and the template bank has fewer than two entries,
            because no mismatched explanation exists in that case.
    """
    rng = random.Random(seed)
    templates = TEMPLATES

    if n <= 0:
        return []
    if len(templates) < 2:
        raise ValueError(
            "build_dataset needs at least two templates to pair code with a "
            "mismatched explanation"
        )

    last = len(templates) - 1

    # All code-template draws happen before any explanation draw, so the
    # code/label sequence depends only on the seed, n and the template bank.
    indices = [rng.randint(0, last) for _ in range(n)]

    examples = []
    for k, code_idx in enumerate(indices):
        # Pick one of the `last` remaining templates: draw a slot in
        # [0, last - 1] and step over the code's own index. Unlike a
        # shuffle-and-repair pass, this cannot leave a self-match behind.
        expl_idx = rng.randint(0, last - 1)
        if expl_idx >= code_idx:
            expl_idx += 1

        t = templates[code_idx]
        t_expl = templates[expl_idx]
        claims = oracle_verify(t)
        examples.append(Example(
            idx=k,
            code_snippet=t.code,
            true_explanation=t.explanation,
            mismatched_explanation=t_expl.explanation,
            time_complexity=claims["time_complexity"],
            space_complexity=claims["space_complexity"],
            correctness=claims["correctness"],
            time_complexity_idx=claims["time_complexity_idx"],
            space_complexity_idx=claims["space_complexity_idx"],
            template_name=t.name,
        ))
    return examples


def split_dataset(examples: List[Example], val_size: int = 500, seed: int = 42):
    """Shuffle a copy of `examples` with a seeded RNG and cut it in two.

    The first `val_size` shuffled items become the validation set and the
    remainder the training set. The input list itself is not modified.
    Returns (train, val).
    """
    pool = examples[:]
    random.Random(seed).shuffle(pool)
    return pool[val_size:], pool[:val_size]


# ──────────────────────────────────────────────────────────────────────────────
# 5. Tokenizer (simple whitespace/word-level; self-contained, no external dependencies)
# ──────────────────────────────────────────────────────────────────────────────

# Markup tokens: padding / sequence boundaries / unknown word, then the
# separator and the claim delimiters.
PAD_TOKEN = "<pad>"
BOS_TOKEN = "<bos>"
EOS_TOKEN = "<eos>"
UNK_TOKEN = "<unk>"
SEP_TOKEN = "<sep>"
CLAIM_OPEN = "<claim>"
CLAIM_CLOSE = "</claim>"

# Order matters: SimpleTokenizer hands out ids in this order.
SPECIAL_TOKENS = [
    PAD_TOKEN,
    BOS_TOKEN,
    EOS_TOKEN,
    UNK_TOKEN,
    SEP_TOKEN,
    CLAIM_OPEN,
    CLAIM_CLOSE,
]

# Claim value tokens: one token per possible value of each claim field.
CLAIM_TOKENS = (
    ["time_complexity=O(1)", "time_complexity=O(n)", "time_complexity=O(n^2)"]
    + ["space_complexity=O(1)", "space_complexity=O(n)", "space_complexity=O(n^2)"]
    + ["correctness=0", "correctness=1"]
)


class SimpleTokenizer:
    """
    Whitespace tokenizer with a fixed set of reserved tokens.

    Reserved tokens (SPECIAL_TOKENS and CLAIM_TOKENS) always come out as single
    tokens, even when they touch neighbouring text such as
    "<claim>correctness=1</claim>". All remaining text is split on whitespace
    only, so punctuation stays attached to its word.

    Ids are handed out in first-seen order, reserved tokens first. Words that
    were never added to the vocabulary encode to the id of UNK_TOKEN.
    """

    def __init__(self):
        self.token2id = {}
        self.id2token = {}
        reserved = SPECIAL_TOKENS + CLAIM_TOKENS
        for tok in reserved:
            self._add(tok)
        # Compiled once, from the reserved tokens only. Building the pattern from
        # the whole (growing) vocabulary would carve ordinary words such as "a"
        # or "n" out of longer unseen words and make tokenization depend on the
        # order in which texts were added.
        self._reserved_re = self._compile_reserved(reserved)

    @staticmethod
    def _compile_reserved(tokens: List[str]):
        """Return a compiled regex matching any of `tokens`, longest first."""
        # Longest first, so a token can never lose to a shorter alternative that
        # happens to be its prefix. dict.fromkeys removes duplicates in order.
        ordered = sorted(dict.fromkeys(t for t in tokens if t), key=len, reverse=True)
        if not ordered:
            # An empty alternation would match everywhere; use a pattern that
            # never matches instead.
            return re.compile(r"(?!)")
        return re.compile("|".join(re.escape(t) for t in ordered))

    def _add(self, tok: str) -> int:
        """Register `tok` if it is new and return its id."""
        if tok not in self.token2id:
            idx = len(self.token2id)
            self.token2id[tok] = idx
            self.id2token[idx] = tok
        return self.token2id[tok]

    def _split(self, text: str) -> List[str]:
        """Split `text` into reserved tokens and whitespace-separated words."""
        parts = []
        pos = 0
        for m in self._reserved_re.finditer(text):
            # Plain text between the previous reserved token and this one.
            parts.extend(text[pos:m.start()].split())
            parts.append(m.group())
            pos = m.end()
        # Trailing plain text after the last reserved token.
        parts.extend(text[pos:].split())
        return parts

    def build_vocab(self, texts: List[str]):
        """Add every token occurring in `texts` to the vocabulary."""
        for text in texts:
            for tok in self._split(text):
                self._add(tok)

    def encode(self, text: str) -> List[int]:
        """Map `text` to token ids; unknown tokens become the UNK_TOKEN id."""
        unk_id = self.token2id[UNK_TOKEN]
        return [self.token2id.get(t, unk_id) for t in self._split(text)]

    def decode(self, ids: List[int]) -> str:
        """Join the tokens for `ids` with single spaces; unknown ids give UNK_TOKEN."""
        return ' '.join(self.id2token.get(i, UNK_TOKEN) for i in ids)

    @property
    def vocab_size(self) -> int:
        return len(self.token2id)

    @property
    def pad_id(self) -> int:
        return self.token2id[PAD_TOKEN]

    @property
    def bos_id(self) -> int:
        return self.token2id[BOS_TOKEN]

    @property
    def eos_id(self) -> int:
        return self.token2id[EOS_TOKEN]

    @property
    def sep_id(self) -> int:
        return self.token2id[SEP_TOKEN]

    @property
    def claim_open_id(self) -> int:
        return self.token2id[CLAIM_OPEN]

    @property
    def claim_close_id(self) -> int:
        return self.token2id[CLAIM_CLOSE]


def make_target_sequence(example: Example) -> str:
    """
    Render an example as one training string:

    <bos> [code] <sep> [mismatched explanation] <claim>time_complexity=X</claim>
                                                <claim>space_complexity=Y</claim>
                                                <claim>correctness=Z</claim> <eos>

    The three claim groups are concatenated without spaces between them. The
    explanation is placed before the claims, so under causal masking the
    explanation tokens cannot see the claim tokens.
    """
    claim_fields = (
        ("time_complexity", example.time_complexity),
        ("space_complexity", example.space_complexity),
        ("correctness", example.correctness),
    )
    claims_str = "".join(
        f"{CLAIM_OPEN}{key}={value}{CLAIM_CLOSE}" for key, value in claim_fields
    )
    # Surrounding whitespace of the snippet is dropped so separators stay single spaces.
    prompt = f"{BOS_TOKEN} {example.code_snippet.strip()} {SEP_TOKEN}"
    return f"{prompt} {example.mismatched_explanation} {claims_str} {EOS_TOKEN}"


def build_tokenizer(examples: List[Example]) -> SimpleTokenizer:
    """Create a SimpleTokenizer whose vocabulary covers the target sequence of every example."""
    tokenizer = SimpleTokenizer()
    tokenizer.build_vocab([make_target_sequence(ex) for ex in examples])
    return tokenizer


if __name__ == "__main__":
    # Smoke test: build a tiny dataset and print a summary of its first three rows.
    for ex in build_dataset(n=20, seed=0)[:3]:
        preview = ex.mismatched_explanation[:80]  # keep the printed line short
        print(f"--- {ex.template_name} ---")
        print(f"Time: {ex.time_complexity}  Space: {ex.space_complexity}  Correct: {ex.correctness}")
        print(f"Mismatched expl: {preview}")
        print()
