from torch.utils.data import DataLoader, Dataset
import random
import numpy as np
import itertools
import re
import math
import torch
import heapq
import datasets
from transformers import AutoTokenizer
import json
import os
import gc
from datasets import load_dataset


class AddDataset(Dataset):
    """Random addition problems written least-significant digit first.

    Every operand is a string of 1..``max_len`` random decimal digits (leading
    zeros are possible).  Operands and their sum are rendered with one space
    between digits and then reversed, so the units digit comes first.
    """

    def __init__(self, length, max_len):
        """Generate ``length`` problems whose operands have at most ``max_len`` digits."""
        self.length = length
        self.max_len = max_len
        self.data = []
        self.labels = []

        for _ in range(length):
            # RNG order matters for reproducibility under a fixed seed:
            # both operand lengths are drawn first, then the digits of each operand.
            digits_a = random.randint(1, max_len)
            digits_b = random.randint(1, max_len)
            first = ''.join([str(random.randint(0, 9)) for _ in range(digits_a)])
            second = ''.join([str(random.randint(0, 9)) for _ in range(digits_b)])

            total = str(int(first) + int(second))

            # Space-separate the digits, then reverse the whole string.
            left = ' '.join(first)[::-1]
            right = ' '.join(second)[::-1]
            self.data.append(f"{left} + {right}")
            self.labels.append(' '.join(total)[::-1])

    def __len__(self):
        """Return the number of problems requested at construction."""
        return self.length

    def __getitem__(self, idx):
        """Return the instruction/input/output record for problem ``idx``."""
        return {
            'instruction': "Compute addition.",
            "input": self.data[idx],
            "output": str(self.labels[idx]),
        }

class ParityDatasetPositionLabeled(Dataset):
    """Parity task where every bit is prefixed with its 0-based position.

    An input looks like ``"0: 1 1: 0 2: 1 "`` (note the trailing space); the
    label is the number of ones modulo 2.
    """

    def __init__(self, length, seq_len):
        """Generate ``length`` sequences of 1..``seq_len`` random bits."""
        self.length = length
        self.seq_len = seq_len
        self.data = []
        self.labels = []

        for _ in range(length):
            # The sequence length comes from `random`, the bits from `torch`.
            n_bits = random.randint(1, seq_len)
            bits = [str(torch.randint(2, size=()).item()) for _ in range(n_bits)]

            self.data.append(''.join(f"{pos}: {bit} " for pos, bit in enumerate(bits)))
            self.labels.append(sum(int(bit) for bit in bits) % 2)

    def __len__(self):
        """Return the number of sequences requested at construction."""
        return self.length

    def __getitem__(self, idx):
        """Return the instruction/input/output record for sequence ``idx``."""
        return {
            'instruction': "Output the parity of this sequence.",
            "input": self.data[idx],
            "output": str(self.labels[idx]),
        }

class MultistepArithmeticTask:
    """A simple text-based arithmetic task of variable difficulty."""

    def __init__(
        self,
        num_trials=100,
        verbose=False,
        operations=("+", "-"),
        numbers=tuple(range(-9, 10)),
        with_spaces=True,
        depth_level_list=((),),
        lengths=(2, 3),
    ):
        """Defines a simple arithmetic task of varying difficulty levels.

        Args:
            num_trials: Number of examples to evaluate with each configuration of
                depth_levels and length. Only stored on the instance here.
            verbose: Print each example during evaluation. Only stored on the
                instance here.
            operations: Operations from which one samples randomly, for example ['+', '-', '*'].
            numbers: Numbers from which one samples randomly to appear in the question string.
            with_spaces: Whether the input string contains spaces between numbers and operations.
            depth_level_list: How many parenthesis are in each level of the tree-like structure.
                For example [2,3] means the string contains two main parenthesis, each with three subparenthesis.
            lengths: number of numbers to be operated on in the inner-most parenthesis level.
        Questions are of the form: '(2 * 4 - 2)' and '((1 + 1) + (3 + 2))'.
        """
        self.num_trials = num_trials
        self.verbose = verbose
        # Every sequence argument is copied so that instances never share (or
        # mutate) the caller's list or a default value.
        self.operations = list(operations)
        self.numbers = [str(x) for x in numbers]
        self.with_spaces = with_spaces
        self.depth_level_list = [list(levels) for levels in depth_level_list]
        self.lengths = list(lengths)

    def generate_content(self, length, rng=None):
        """Return one innermost group such as ``"(3 + -2 - 7)"``.

        Args:
            length: Number of operands in the group; values below 1 still
                produce a single operand.
            rng: Optional object exposing ``choice(sequence)`` (for example a
                ``random.Random`` instance).  The module-level ``random`` is
                used when omitted.
        """
        picker = random if rng is None else rng
        space = " " if self.with_spaces else ""
        # Draw order is operand, operator, operand, ..., operand.
        pieces = []
        for _ in range(length - 1):
            pieces.append(picker.choice(self.numbers))
            pieces.append(picker.choice(self.operations))
        pieces.append(picker.choice(self.numbers))
        return "(" + space.join(pieces) + ")"

    def generate_subparenthesis(self, subparenthesis, length, rng=None):
        """Recursively build a nested expression.

        Args:
            subparenthesis: Remaining branching factors; the first entry is the
                number of groups at this level.  An empty sequence yields an
                innermost group from ``generate_content``.
            length: Number of operands in every innermost group.
            rng: Optional random source, see ``generate_content``.
        """
        if not subparenthesis:
            return self.generate_content(length, rng)

        picker = random if rng is None else rng
        space = " " if self.with_spaces else ""
        deeper = subparenthesis[1:]
        pieces = []
        for _ in range(subparenthesis[0] - 1):
            pieces.append(self.generate_subparenthesis(deeper, length, rng))
            pieces.append(picker.choice(self.operations))
        pieces.append(self.generate_subparenthesis(deeper, length, rng))
        return "(" + space.join(pieces) + ")"

    def generate_string(self, depth_levels, length, rng=None):
        """Return a full expression for the given nesting and operand count."""
        return self.generate_subparenthesis(depth_levels, length, rng)

class Multiple_Arithmetic_Dataset(Dataset):
    """Addition/subtraction expressions labelled with their value modulo 10.

    Expressions come from a default-configured ``MultistepArithmeticTask``;
    each input is the expression followed by ``" = "``.
    """

    def __init__(self, trials):
        """Generate ``trials`` random expressions and their labels."""
        generator = MultistepArithmeticTask()
        self.data = []
        self.labels = []
        for _ in range(trials):
            nesting = random.choice(generator.depth_level_list)
            operand_count = random.choice(generator.lengths)
            expression = generator.generate_string(depth_levels=nesting, length=operand_count)
            self.data.append(expression + " = ")
            # `eval` only ever sees text produced by the generator above.
            residue = eval(expression) % 10
            self.labels.append(" ".join(str(residue)))

    def __len__(self):
        """Return the number of generated expressions."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the instruction/input/output record for expression ``idx``."""
        return {
            'instruction': "Evaluate the expression in modulo 10.",
            "input": self.data[idx],
            "output": str(self.labels[idx]),
        }

def generate_instructions(max_steps_per_instruction, max_total_instructions):
    """Build a random list of navigation commands.

    Commands are drawn until ``max_total_instructions`` "Take N steps" commands
    have been emitted; turn commands do not count toward that total, so the
    returned list may be longer and always ends with a "Take" command (or is
    empty when the total is not positive).

    Args:
        max_steps_per_instruction: Upper bound (inclusive) for N in "Take N steps".
        max_total_instructions: Number of "Take" commands to emit.
    """
    turns = ["Turn right", "Turn left", "Turn around"]
    sequence = []
    moves_emitted = 0
    while moves_emitted < max_total_instructions:
        # A step count is drawn on every iteration, even if a turn gets picked.
        step_count = random.randint(1, max_steps_per_instruction)
        candidates = ["Take {} steps".format(step_count)] + turns
        picked = random.choices(candidates, weights=[0.52, 0.16, 0.16, 0.16], k=1)[0]
        if picked.startswith("Take"):
            moves_emitted += 1
        sequence.append(picked)
    return sequence

def follow_instructions(instructions):
    """Walk the commands on a grid and report whether both coordinates end up even.

    The walker starts at (0, 0) facing up.  "Take N steps" moves N cells in
    the current heading; "Turn right" / "Turn left" / "Turn around" rotate the
    heading.  Any other command is ignored.

    Returns:
        True when the final x and y are both even, otherwise False.
    """
    # (dx, dy) per heading index -- 0: up, 1: right, 2: down, 3: left
    unit_moves = ((0, 1), (1, 0), (0, -1), (-1, 0))
    quarter_turns = {"Turn right": 1, "Turn left": -1, "Turn around": 2}

    pos_x = pos_y = 0
    heading = 0
    for command in instructions:
        if command.startswith("Take"):
            distance = int(command.split()[1])
            dx, dy = unit_moves[heading]
            pos_x += dx * distance
            pos_y += dy * distance
        elif command in quarter_turns:
            heading = (heading + quarter_turns[command]) % 4
    return pos_x % 2 == 0 and pos_y % 2 == 0


class NavigateDataset(Dataset):
    """Label-balanced navigation examples.

    For every count of "Take" commands from 2 to ``max_total_instructions`` an
    equal number of True- and False-labelled command sequences is collected;
    the pool is then shuffled and truncated to ``length`` examples.

    NOTE(review): the prompt asks whether the walker returns to the starting
    point, while the label comes from ``follow_instructions``, which reports
    whether both final coordinates are even -- confirm this is intended.
    """

    def __init__(self, length, max_steps_per_instruction, max_total_instructions):
        """Generate the examples.

        Args:
            length: Number of examples kept after shuffling.
            max_steps_per_instruction: Largest step count in a "Take" command.
            max_total_instructions: Largest number of "Take" commands per example.
        """
        self.length = length
        self.data = []
        self.labels = []

        # Per-label quota for each of the (max_total_instructions - 1) sizes.
        quota = math.ceil(length / (max_total_instructions - 1) / 2)

        pool = []
        for move_count in range(2, max_total_instructions + 1):
            hits = []
            misses = []
            # Rejection sampling: draws whose label bucket is full are dropped.
            while len(hits) < quota or len(misses) < quota:
                commands = generate_instructions(max_steps_per_instruction, move_count)
                prompt = " ".join(commands)
                outcome = follow_instructions(commands)

                bucket = hits if outcome else misses
                if len(bucket) < quota:
                    bucket.append({"Q": prompt, "A": outcome})

            pool.extend(hits + misses)

        random.shuffle(pool)
        for sample in pool[:length]:
            self.data.append(sample["Q"])
            self.labels.append(str(sample["A"]))

    def __len__(self):
        """Return the number of examples requested at construction."""
        return self.length

    def __getitem__(self, idx):
        """Return the instruction/input/output record for example ``idx``."""
        return {
            'instruction': "If you follow these instructions, do you return to the starting point?",
            "input": self.data[idx],
            "output": str(self.labels[idx]),
        }

class Multiplication_Dataset(Dataset):
    """Products of numbers in 0..10, labelled with their value modulo 11.

    ``trials`` expressions are generated for every (depth, length) combination
    exposed by the underlying ``MultistepArithmeticTask``.
    """

    def __init__(self, trials):
        """Generate ``trials`` expressions per task configuration."""
        task = MultistepArithmeticTask(operations=["*"], numbers=list(range(0, 11)))
        self.data = []
        self.labels = []

        configurations = itertools.product(task.depth_level_list, task.lengths)
        for nesting, operand_count in configurations:
            for _ in range(trials):
                expression = task.generate_string(depth_levels=nesting, length=operand_count)
                self.data.append(expression + " = ")
                # The residue can be 10, which is rendered digit by digit as "1 0".
                residue = eval(expression) % 11
                self.labels.append(" ".join(str(residue)))

    def __len__(self):
        """Return the number of generated expressions."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the instruction/input/output record for expression ``idx``."""
        return {
            'instruction': "Evaluate the expression (python version) in modulo 11.",
            "input": self.data[idx],
            "output": str(self.labels[idx]),
        }




class RandomPruferExpressionDataset(Dataset):
    """
    Arithmetic expressions derived from random labeled trees.

    For each example:
      - a node count n is drawn from [min_n, max_n];
      - a random labeled tree on {0, ..., n-1} is produced by decoding a
        random 0-based Prüfer code;
      - every node receives a numeric label (from `numbers`), a sign in
        {+, -} and an operator in {+, *} (the operator only matters for
        nodes that have at least one child);
      - the tree is rooted at node 0 and rendered as an expression string;
      - the expression is evaluated modulo 10 to obtain the label.
    """

    def __init__(
        self,
        size=1000,
        min_n=1,
        max_n=6,
        numbers=range(1, 10)
    ):
        """
        Args:
            size (int): number of expressions (trees) to generate
            min_n (int): minimum number of nodes
            max_n (int): maximum number of nodes
            numbers (iterable): possible numeric labels for each node
        """
        super().__init__()
        self.size = size
        self.min_n = min_n
        self.max_n = max_n
        self.numbers = list(numbers)

        self.data = []
        self.labels = []

        for _ in range(self.size):
            expression = self._sample_expression()

            # Evaluate in Python, modulo 10; any evaluation failure is
            # recorded as a label of 0.
            try:
                residue = eval(expression) % 10
            except Exception:
                residue = 0

            self.data.append(expression + " = ")
            self.labels.append(str(residue))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {
            "instruction": "Evaluate this expression modulo 10.",
            "input": self.data[idx],
            "output": self.labels[idx],
        }

    # ----------------------------------------------------------------
    # Sampling helpers
    # ----------------------------------------------------------------
    def _sample_node(self):
        """Draw the value, sign and operator of one node (in that RNG order)."""
        return {
            "value": random.choice(self.numbers),
            "sign": random.choice(["+", "-"]),
            "operator": random.choice(["+", "*"]),
        }

    def _sample_expression(self):
        """Draw one random tree and return its expression string."""
        n = random.randint(self.min_n, self.max_n)

        # A single node has no Prüfer code; its expression is just the
        # signed value.
        if n == 1:
            return self._leaf_expr(self._sample_node())

        # The code has length n-2 and is drawn before any node attribute.
        prufer = [random.randint(0, n - 1) for _ in range(n - 2)]
        adjacency = self._edges_to_adjacency(self._decode_prufer(prufer, n), n)
        node_info = {node_id: self._sample_node() for node_id in range(n)}

        return self._build_expression_string(
            current=0,
            parent=None,
            adjacency=adjacency,
            node_info=node_info,
        )

    # ----------------------------------------------------------------
    # Prüfer code -> edges -> adjacency (0-based)
    # ----------------------------------------------------------------
    def _decode_prufer(self, prufer, n):
        """
        Turn a 0-based Prüfer code of length n-2 into the n-1 edges of the
        corresponding tree on {0, ..., n-1}.
        """
        # How many more times each vertex still appears in the code.
        pending = [0] * n
        for label in prufer:
            pending[label] += 1

        # Vertices absent from the code are the current leaves.  The list is
        # built in ascending order, so heapify leaves it unchanged.
        leaves = [v for v in range(n) if pending[v] == 0]
        heapq.heapify(leaves)

        edges = []
        for label in prufer:
            smallest_leaf = heapq.heappop(leaves)
            edges.append((smallest_leaf, label))
            pending[label] -= 1
            if pending[label] == 0:
                heapq.heappush(leaves, label)

        # Exactly two vertices remain; they form the final edge.
        first = heapq.heappop(leaves)
        second = heapq.heappop(leaves)
        edges.append((first, second))

        return edges

    def _edges_to_adjacency(self, edges, n):
        """
        Return {vertex: [neighbors...]} for the undirected edge list.
        """
        neighbors = {vertex: [] for vertex in range(n)}
        for left, right in edges:
            neighbors[left].append(right)
            neighbors[right].append(left)
        return neighbors

    # ----------------------------------------------------------------
    # Expression rendering
    # ----------------------------------------------------------------
    def _build_expression_string(self, current, parent, adjacency, node_info):
        """
        Depth-first rendering of the subtree rooted at `current`.  `parent`
        is excluded from the neighbors so the walk never goes back up.
        """
        info = node_info[current]
        operands = [
            self._build_expression_string(child, current, adjacency, node_info)
            for child in adjacency[current]
            if child != parent
        ]

        # No children: render as a signed leaf.
        if not operands:
            return self._leaf_expr(info)

        # The node's own value is one more operand; + and * are commutative,
        # so the operand order is randomized.
        operands.append(str(info["value"]))
        random.shuffle(operands)

        body = f" {info['operator']} ".join(operands)
        prefix = "-" if info["sign"] == "-" else ""
        return f"{prefix}({body})"

    def _leaf_expr(self, node_data):
        """
        Render a leaf, e.g. "-5" or "7".
        """
        prefix = "-" if node_data["sign"] == "-" else ""
        return prefix + str(node_data["value"])


def _evaluate_one_innermost(expr_str, mod = 10):
    """
    Finds all innermost bracketed sub-expressions in 'expr_str', collects them,
    and randomly chooses exactly one to evaluate and replace.
    If none exist, returns 'expr_str' unchanged.
    """
    stack = []
    innermost_positions = []

    for i, ch in enumerate(expr_str):
        if ch == '(':
            stack.append(i)
        elif ch == ')':
            start = stack.pop()
            sub = expr_str[start+1 : i]
            # If 'sub' doesn't contain '(' or ')', it's innermost
            if '(' not in sub and ')' not in sub:
                innermost_positions.append((start, i))

    # Randomly pick one innermost bracket
    if not innermost_positions:
        return expr_str
    start, end = random.choice(innermost_positions)
    sub_expr = expr_str[start+1 : end]

    val = eval(sub_expr)%mod
    val_str = str(val)

    # Replace that bracket with its evaluated value
    new_expr_str = expr_str[:start] + val_str + expr_str[end+1:]
    return new_expr_str

class RandomPruferExpressionDatasetCurriculum(Dataset):
    """Single-step variant of ``RandomPruferExpressionDataset``.

    The inputs are the same expression prompts; each output is that prompt
    (including its trailing " = ") with one randomly chosen innermost bracket
    replaced by its value modulo 10.
    """

    def __init__(
        self,
        size=1000,
        min_n=1,
        max_n=6,
        numbers=range(1, 10)
    ):
        """
        Args:
            size (int): number of expressions to generate
            min_n (int): minimum number of nodes
            max_n (int): maximum number of nodes
            numbers (iterable): possible numeric labels for each node
        """
        self.rdpf = RandomPruferExpressionDataset(size, min_n, max_n, numbers)
        # Overwrite the final-answer labels of the wrapped dataset in place.
        for position, prompt in enumerate(self.rdpf.data):
            self.rdpf.labels[position] = _evaluate_one_innermost(prompt)

    def __len__(self):
        return len(self.rdpf)

    def __getitem__(self, idx):
        wrapped = self.rdpf
        return {
            "instruction": "Evaluate this expression modulo 10, try expanding one bracket.",
            "input": wrapped.data[idx],
            "output": wrapped.labels[idx],
        }


class ParityDataset(Dataset):
    """Parity task over space-separated random bit sequences.

    An input looks like ``"1 0 1 "`` (note the trailing space); the label is
    the number of ones modulo 2.
    """

    def __init__(self, length, seq_len):
        """Generate ``length`` sequences of 1..``seq_len`` random bits."""
        self.length = length
        self.seq_len = seq_len
        self.data = []
        self.labels = []

        for _ in range(length):
            # The sequence length comes from `random`, the bits from `torch`.
            n_bits = random.randint(1, seq_len)
            bits = [str(torch.randint(2, size=()).item()) for _ in range(n_bits)]

            self.data.append(''.join(bit + " " for bit in bits))
            self.labels.append(sum(int(bit) for bit in bits) % 2)

    def __len__(self):
        """Return the number of sequences requested at construction."""
        return self.length

    def __getitem__(self, idx):
        """Return the instruction/input/output record for sequence ``idx``."""
        return {
            'instruction': "Output the parity of this sequence.",
            "input": self.data[idx],
            "output": str(self.labels[idx]),
        }


class QAPairDataset(Dataset):
    """
    PyTorch-compatible dataset for lines that look like:
        <question text> || <any COT> #### <answer>

    Only the question and the final answer are kept; the chain of thought
    between "||" and the last "####" is discarded.
    """

    def __init__(self, filepath: str):
        """Parse ``filepath`` line by line, skipping lines that do not fit the format."""
        self._data = []                     # question texts (everything before the first "||")
        self.label = []                     # stripped answers (everything after the last "####")
        with open(filepath, encoding="utf-8") as f:
            for line in map(str.strip, f):
                if not line:
                    continue                # blank line

                # Question is whatever precedes the first "||".
                question, separator, remainder = line.partition("||")
                if not separator:
                    continue                # no "||" at all

                # Answer is whatever follows the last "####".
                _, marker, answer = remainder.rpartition("####")
                if not marker:
                    continue                # no answer marker

                self._data.append(question)
                self.label.append(answer.strip())

    # ---------- required by torch.utils.data.Dataset ----------
    def __len__(self):
        """Number of examples in the dataset."""
        return len(self._data)

    def __getitem__(self, idx):
        """Return the instruction/input/output record at position `idx`."""
        return {
            'instruction': "Answer the math question.",
            "input": self._data[idx],
            "output": self.label[idx],
        }

class QAPairDatasetWithCoT(Dataset):
    """
    Dataset for lines of the form ``<question text> || <rest>``.

    The question is the text before the first "||"; the output is everything
    after it, kept verbatim (no answer-marker handling, no extra stripping).
    """

    def __init__(self, filepath: str):
        """Parse ``filepath`` line by line, skipping blank lines and lines without "||"."""
        self._data = []                     # question texts
        self.label = []                     # text following the first "||"
        with open(filepath, encoding="utf-8") as f:
            for line in map(str.strip, f):
                if not line:
                    continue                # blank line

                question, separator, remainder = line.partition("||")
                if not separator:
                    continue                # no "||" at all

                self._data.append(question)
                self.label.append(remainder)

    # ---------- required by torch.utils.data.Dataset ----------
    def __len__(self):
        """Number of examples in the dataset."""
        return len(self._data)

    def __getitem__(self, idx):
        """Return the instruction/input/output record at position `idx`."""
        return {
            'instruction': "Answer the math question.",
            "input": self._data[idx],
            "output": self.label[idx],
        }

class ParityDatasetSimplified(Dataset):
    """Parity task with the short instruction "Decide parity.".

    Inputs are space-separated random bits with a trailing space, e.g.
    ``"0 1 1 "``; the label is the number of ones modulo 2.
    """

    def __init__(self, length, seq_len):
        """Generate ``length`` sequences of 1..``seq_len`` random bits."""
        self.length = length
        self.seq_len = seq_len
        self.data = []
        self.labels = []

        for _ in range(length):
            # The sequence length comes from `random`, the bits from `torch`.
            bit_count = random.randint(1, seq_len)
            drawn = []
            for _ in range(bit_count):
                drawn.append(str(torch.randint(2, size=()).item()))

            self.data.append(''.join(symbol + " " for symbol in drawn))
            self.labels.append(sum(map(int, drawn)) % 2)

    def __len__(self):
        """Return the number of sequences requested at construction."""
        return self.length

    def __getitem__(self, idx):
        """Return the instruction/input/output record for sequence ``idx``."""
        return {
            'instruction': "Decide parity.",
            "input": self.data[idx],
            "output": str(self.labels[idx]),
        }


class ParityDatasetFixedLength(Dataset):
    """Parity task in which every sequence has exactly ``seq_len`` bits.

    Inputs are space-separated random bits with a trailing space; the label is
    the number of ones modulo 2.
    """

    def __init__(self, length, seq_len):
        """Generate ``length`` sequences of ``seq_len`` random bits each."""
        self.length = length
        self.seq_len = seq_len
        self.data = []
        self.labels = []

        for _ in range(length):
            bits = [str(torch.randint(2, size=()).item()) for _ in range(seq_len)]
            self.data.append(''.join(bit + " " for bit in bits))
            self.labels.append(sum(int(bit) for bit in bits) % 2)

    def __len__(self):
        """Return the number of sequences requested at construction."""
        return self.length

    def __getitem__(self, idx):
        """Return the instruction/input/output record for sequence ``idx``."""
        return {
            'instruction': "Output the parity of this sequence.",
            "input": self.data[idx],
            "output": str(self.labels[idx]),
        }


class NuminaDataset(Dataset):
    """
    A dataset for the Numina Math dataset with filtering by tokenizer length.
    Filters examples where problem + solution <= 400 tokens, then samples:
    - 384,000 for training
    - 1,000 for testing  
    - 1,000 for validation
    All samples are drawn from the combined original train + test sets.
    Creates all splits at once for efficiency.
    """
    
    def __init__(self, 
                 split="train", 
                 tokenizer_name="meta-llama/Llama-3.2-1B", 
                 max_tokens=400, 
                 train_size=384000, 
                 test_size=1000, 
                 val_size=1000,
                 cache_dir="./numina_cache",
                 mode="problem_solution", # options: problem_solution, problem_cot_answer
                 generator_config=None,
                 cot_cache_dir="./numina_cot_cache",
                 # --- NEW SHARDING PARAMETERS ---
                 shard_size=6000,
                 shard_start_idx=0,
                 shard_end_idx=None):
        """
        Args:
            ... (all previous args) ...
            cot_cache_dir: Directory to cache *generated CoT* data.
            shard_size (int): Number of items per cache file (shard).
            shard_start_idx (int): The first shard index to load (inclusive).
            shard_end_idx (int): The last shard index to load (inclusive).
                                 If None, loads all shards from start_idx to the end.
        """
        self.split = split
        self.tokenizer_name = tokenizer_name
        self.max_tokens = max_tokens
        self.train_size = train_size
        self.test_size = test_size
        self.val_size = val_size
        self.cache_dir = cache_dir
        
        self.mode = mode
        self.generator_config = generator_config
        self.cot_cache_dir = cot_cache_dir
        
        # --- Store new sharding params ---
        self.shard_size = shard_size
        self.shard_start_idx = shard_start_idx
        self.shard_end_idx = shard_end_idx
        
        # Create cache directories
        os.makedirs(cache_dir, exist_ok=True)
        if self.mode == "problem_cot_answer":
            os.makedirs(cot_cache_dir, exist_ok=True)
            
        # Load or create processed data
        self.data, self.labels = self._load_or_create_data()
        
    def _load_or_create_data(self):
        """Load cached data or create and cache new data based on mode."""
        
        # 1. Load or create the BASE data (problem -> full_solution)
        # This is always needed, as CoT generation depends on it.
        base_cache_file = os.path.join(self.cache_dir, "numina_all_splits.json")
        
        if os.path.exists(base_cache_file):
            print(f"Loading BASE cached data from {base_cache_file}")
            with open(base_cache_file, 'r') as f:
                all_splits_data = json.load(f)
        else:
            print(f"Creating new BASE data for all splits...")
            all_splits_data = self._create_all_splits()
            
            # Cache all splits data
            with open(base_cache_file, 'w') as f:
                json.dump(all_splits_data, f)
            print(f"Cached all splits data to {base_cache_file}")
            
        # Get the base data for the current split
        base_problems = all_splits_data[self.split]['data']
        base_solutions = all_splits_data[self.split]['labels']

        # 2. Check the mode and return data
        if self.mode == "problem_solution":
            # --- Original behavior ---
            print("Mode 'problem_solution': Returning base problems and solutions.")
            return base_problems, base_solutions
            
        elif self.mode == "problem_cot_answer":
            # --- New behavior: Load or generate CoT ---
            print("Mode 'problem_cot_answer': Loading/Generating CoT and final answers from shards.")
            if self.generator_config is None:
                raise ValueError("generator_config must be provided for mode 'problem_cot_answer'")
            
            # Get model/iter IDs for file naming
            try:
                model_id = self.generator_config.get("model_name", "unknown_model").split('/')[-1]
                iter_id = self.generator_config.get("iteration", "unknown_iter")
            except Exception as e:
                raise ValueError(f"Invalid generator_config. Missing 'model_name' or 'iteration'? Error: {e}")

            # --- NEW SHARDING LOGIC ---
            
            # Accumulators for all data from all shards
            all_cot_data = []
            all_final_answers = []

            # Determine total number of shards
            total_examples = len(base_problems)
            if total_examples == 0:
                print("Warning: Base data is empty.")
                return [], []
            
            total_shards = math.ceil(total_examples / self.shard_size)
            if total_shards == 0 and total_examples > 0:
                total_shards = 1 # Handle case with fewer than shard_size examples

            # Determine the range of shards to process
            start_shard = self.shard_start_idx
            end_shard_inclusive = self.shard_end_idx
            if end_shard_inclusive is None:
                end_shard_inclusive = total_shards - 1
            
            # Cap the end shard at the actual max
            end_shard_inclusive = min(end_shard_inclusive, total_shards - 1)

            if start_shard > end_shard_inclusive:
                print(f"Warning: shard_start_idx ({start_shard}) is greater than end shard ({end_shard_inclusive}). No data will be loaded.")
                return [], []

            print(f"Processing shards from {start_shard} to {end_shard_inclusive} (out of {total_shards} total shards).")

            # Loop through the *specified range* of shards
            for shard_i in range(start_shard, end_shard_inclusive + 1):
                
                # Define the cache file for this specific shard
                shard_filename = f"cot_data_{self.split}_{model_id}_{iter_id}.shard_{shard_i}.json"
                cot_cache_file = os.path.join(self.cot_cache_dir, shard_filename)
                
                if os.path.exists(cot_cache_file):
                    # --- 1. Load from cache ---
                    print(f"Loading cached CoT data from {cot_cache_file} (Shard {shard_i})")
                    try:
                        with open(cot_cache_file, 'r') as f:
                            cached_data = json.load(f)
                        if not cached_data['data'] or not cached_data['labels']:
                            print(f"Warning: No effective data in cache file {cot_cache_file}. Will regenerate.")
                            os.remove(cot_cache_file) # Delete corrupted file
                        else: # Load data from cache
                            all_cot_data.extend(cached_data['data'])
                            all_final_answers.extend(cached_data['labels'])
                    except json.JSONDecodeError:
                        print(f"Warning: Corrupted cache file {cot_cache_file}. Will regenerate.")
                        os.remove(cot_cache_file) # Delete corrupted file
                    except Exception as e:
                        print(f"Error loading {cot_cache_file}: {e}. Will regenerate.")
                        if os.path.exists(cot_cache_file):
                            os.remove(cot_cache_file)
                
                if not os.path.exists(cot_cache_file):
                    # --- 2. Generate and cache ---
                    print(f"CoT cache not found for shard {shard_i}. Generating {cot_cache_file}...")
                    
                    # Get the slice of base data for *this shard*
                    start_idx = shard_i * self.shard_size
                    end_idx = min((shard_i + 1) * self.shard_size, total_examples)
                    
                    shard_problems = base_problems[start_idx:end_idx]
                    shard_solutions = base_solutions[start_idx:end_idx]

                    if not shard_problems:
                        print(f"No problems to process for shard {shard_i}. Skipping.")
                        continue

                    # Call the generator function. It will process this slice
                    # and save it to 'cot_cache_file'.
                    generated_data, generated_labels = self._generate_and_cache_cot(
                        shard_problems, 
                        shard_solutions, 
                        cot_cache_file
                    )
                    
                    # Add the newly generated data to our lists
                    all_cot_data.extend(generated_data)
                    all_final_answers.extend(generated_labels)
            
            print(f"Finished processing. Loaded {len(all_cot_data)} total examples from shards {start_shard} to {end_shard_inclusive}.")
            return all_cot_data, all_final_answers
            
        else:
            raise ValueError(f"Unknown mode: {self.mode}")
    
    def _create_all_splits(self):
        """Create all splits at once by filtering and sampling from Numina dataset.

        Loads ``AI-MO/NuminaMath-CoT``, pools its "train" and "test" splits,
        keeps only the examples whose concatenated problem + solution encodes
        to at most ``self.max_tokens`` tokens with the tokenizer named by
        ``self.tokenizer_name``, shuffles the survivors with the global
        ``random`` state, and cuts consecutive slices of ``self.train_size``,
        ``self.test_size`` and ``self.val_size`` examples (in that order).

        Returns:
            dict: ``{"train" | "test" | "val": {"data": [problem, ...],
            "labels": [solution, ...]}}``. A split is shorter than requested
            (possibly empty) when too few examples survive the filter; a
            warning is printed in that case.
        """
        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)

        print("Loading Numina dataset...")
        ds = datasets.load_dataset("AI-MO/NuminaMath-CoT")

        # Pool the original train and test splits. Iterating the split yields
        # each row once, instead of looking the same row up by index twice.
        all_data = []
        for split_name in ["train", "test"]:
            for row in ds[split_name]:
                all_data.append({
                    "problem": row["problem"],
                    "solution": row["solution"],
                    "original_split": split_name,
                })

        print(f"Total examples before filtering: {len(all_data)}")

        # Keep only examples that fit in the token budget.
        filtered_data = [
            item for item in all_data
            if len(tokenizer.encode(item["problem"] + item["solution"])) <= self.max_tokens
        ]

        print(f"Examples after filtering (≤{self.max_tokens} tokens): {len(filtered_data)}")

        requested = self.train_size + self.test_size + self.val_size
        if len(filtered_data) < requested:
            # Slicing below never fails; it just returns fewer items, so make
            # the shortfall visible instead of silently producing small splits.
            print(
                f"Warning: only {len(filtered_data)} examples available after filtering, "
                f"but {requested} were requested; some splits will be smaller than configured."
            )

        # Shuffle, then take consecutive non-overlapping slices:
        # train first, then test, then validation.
        random.shuffle(filtered_data)
        train_end = self.train_size
        test_end = train_end + self.test_size
        val_end = test_end + self.val_size

        sampled = {
            "train": filtered_data[:train_end],
            "test": filtered_data[train_end:test_end],
            "val": filtered_data[test_end:val_end],
        }

        print("Sampled examples:")
        print(f"  Train: {len(sampled['train'])}")
        print(f"  Test: {len(sampled['test'])}")
        print(f"  Validation: {len(sampled['val'])}")

        # Problems become the inputs ("data"), reference solutions the "labels".
        all_splits_data = {
            split_name: {
                "data": [item["problem"] for item in items],
                "labels": [item["solution"] for item in items],
            }
            for split_name, items in sampled.items()
        }

        return all_splits_data
    
    def _generate_and_cache_cot(self, base_problems, base_solutions, cot_cache_file):
        """
        Generate a chain-of-thought (CoT) for each problem and cache the results.

        Loads the generator model described by ``self.generator_config``,
        greedily decodes up to ``max_new_tokens`` tokens per problem in
        left-padded batches, pairs every generated text with the ground-truth
        answer read from the last LaTeX ``boxed{...}`` expression of the
        matching reference solution, and writes both lists to
        ``cot_cache_file`` as JSON (``{"data": [...], "labels": [...]}``).

        Keys read from ``self.generator_config``: ``model_name`` and
        ``checkpoint_folder`` (required); ``iteration`` (default 18000),
        ``device`` (default "cuda" when available, else "cpu"),
        ``max_input_len`` (default 1000), ``max_new_tokens`` (default 600) and
        ``batch_size`` (default 8).

        Args:
            base_problems: Sequence of problem texts to generate CoT for.
            base_solutions: Reference solutions, index-aligned with
                ``base_problems``.
            cot_cache_file: Path of the JSON cache file to write.

        Returns:
            (list, list): (generated_cot_data, final_answers). The two lists
            always have equal length. A batch that raises during generation or
            post-processing is reported and skipped as a whole, so the lists
            can be shorter than ``base_problems``.

        Raises:
            ImportError: If the generator module could not be imported.
            FileNotFoundError: If ``save_configs.json`` or the checkpoint file
                is missing from ``checkpoint_folder``.
        """

        # 1. Project-local generator code, imported at call time.
        import get_model_llama as get_model
        from get_model_llama import LinkedListCache
        if get_model is None or LinkedListCache is None:
            raise ImportError("Failed to import 'get_model_llama'. "
                              "Cannot generate CoT data.")

        # 2. Load Generator Config
        config = self.generator_config
        model_name = config["model_name"]
        checkpoint_folder = config["checkpoint_folder"]
        iteration = config.get("iteration", 18000)
        device = config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        max_input_len = config.get("max_input_len", 1000)
        max_new_tokens = config.get("max_new_tokens", 600)
        batch_size = config.get("batch_size", 8)

        # 3. Load Model and Tokenizer
        print(f"Loading generator model: {model_name} (Iteration {iteration})")

        config_path = os.path.join(checkpoint_folder, "save_configs.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Generator config file not found: {config_path}")
        with open(config_path, "r") as f:
            model_config = json.load(f)
        bridges = model_config["bridges"]

        # Build the "improved" model, seed it with the weights of the plain
        # model (strict=False: only matching keys are copied), then overlay
        # the trained checkpoint on top.
        model, tokenizer = get_model.get_model(improved=True, dataType=torch.float32, bridges=bridges, r=140, model_name=model_name)
        prepared_model, _ = get_model.get_model(improved=False, dataType=torch.float32, model_name=model_name)
        llamamodel = prepared_model.model
        model.model.load_state_dict(llamamodel.state_dict(), strict=False)
        model.lm_head = prepared_model.lm_head
        del prepared_model
        del llamamodel  # Free memory
        gc.collect()

        checkpoint_path = os.path.join(checkpoint_folder, f"pretrain_improved_{iteration}.pth")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
        model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=False)

        if tokenizer.pad_token is None:
            # Batching needs a pad token; reuse EOS when none is defined.
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = model.config.eos_token_id

        # Left padding keeps the last prompt token of every row in the final
        # column, which is the position the next token is predicted from.
        tokenizer.padding_side = "left"

        model.to(device)
        model.eval()

        print(f"Generator model loaded to {device}. Starting batch generation (batch_size={batch_size}).")

        # 4. Batch Generation Loop
        generated_cot_data = []  # This will be our new 'data'
        final_answers = []       # This will be our new 'labels'

        few_shot_prompt = "Question:\n"
        total_examples = len(base_problems)

        # Create all input prompts first
        all_input_texts_for_model = [
            few_shot_prompt + problem_text + "\nAnswer:\n"
            for problem_text in base_problems
        ]

        for i in range(0, total_examples, batch_size):
            # Current batch of prompts and the indices they came from
            batch_prompts = all_input_texts_for_model[i : i + batch_size]
            batch_indices = range(i, min(i + batch_size, total_examples))

            # 4a. Tokenize the batch
            inputs = tokenizer(
                batch_prompts,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=max_input_len,
                add_special_tokens=False
            ).to(device)

            current_context_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']

            # 4b. Manual batched decoding state
            past_key_values = None
            current_batch_size = current_context_ids.shape[0]

            # Generated token ids for each item in the batch
            generated_ids_list = [[] for _ in range(current_batch_size)]

            # Which sequences have already emitted EOS
            finished_sequences = torch.zeros(current_batch_size, dtype=torch.bool, device=device)

            try:
                with torch.no_grad():
                    for step in range(max_new_tokens):
                        if past_key_values is not None:
                            # A cache exists: feed only the newest token
                            model_input_ids = current_context_ids[:, -1:]
                        else:
                            # No cache yet: feed the whole context and start
                            # a fresh custom cache
                            model_input_ids = current_context_ids
                            past_key_values = LinkedListCache()

                        outputs = model(
                            input_ids=model_input_ids,
                            attention_mask=attention_mask,
                            past_key_values=past_key_values,
                            use_cache=True
                        )

                        # Logits at the last position, shape [batch_size, vocab_size]
                        next_token_logits = outputs.logits[:, -1, :]

                        # Greedy decoding: always take the highest-scoring token.
                        next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)  # [batch_size, 1]

                        not_finished = ~finished_sequences

                        # Record the new token for every still-running
                        # sequence. One tolist() per tensor replaces a
                        # separate .item() call for every row.
                        step_tokens = next_token_id.squeeze(-1).tolist()
                        still_running = not_finished.tolist()
                        for j in range(current_batch_size):
                            if still_running[j]:
                                generated_ids_list[j].append(step_tokens[j])

                        # Mark running sequences that just produced EOS
                        newly_finished = (next_token_id.squeeze(-1) == tokenizer.eos_token_id) & not_finished
                        finished_sequences.logical_or_(newly_finished)

                        # Stop once every sequence in the batch is finished
                        if finished_sequences.all():
                            break

                        # Extend the context and the mask by the new token and
                        # carry the KV cache into the next step
                        current_context_ids = torch.cat([current_context_ids, next_token_id], dim=-1)
                        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token_id)], dim=-1)
                        past_key_values = outputs.past_key_values

                # 4c. Post-process this batch. Results are staged locally and
                # committed together, so a failure part-way through can never
                # leave the data and label lists with different lengths.
                batch_cot_data = []
                batch_answers = []
                for j in range(current_batch_size):
                    generated_cot_only = tokenizer.decode(generated_ids_list[j], skip_special_tokens=True)

                    original_problem_idx = batch_indices[j]
                    original_problem_text = base_problems[original_problem_idx]

                    # Stored text: prompt prefix + problem + generated reasoning
                    full_solution_text = few_shot_prompt + original_problem_text + "\nReasoning:\n" + generated_cot_only
                    batch_cot_data.append(full_solution_text)

                    # Ground truth comes from the reference solution
                    solution_text = base_solutions[original_problem_idx]
                    batch_answers.append(self._last_boxed_content(solution_text))

                generated_cot_data.extend(batch_cot_data)
                final_answers.extend(batch_answers)

            except Exception as e:
                # Best effort: report the failure and move on to the next batch.
                print(f"Error during manual batch generation for batch starting at {i}: {e}")
                import traceback
                traceback.print_exc()
                continue

            print(f"Generated CoT for {min(i + batch_size, total_examples)}/{total_examples} examples...")

        # 5. Save to Cache and Return
        print(f"Finished generation. Saving {len(generated_cot_data)} examples to {cot_cache_file}")

        # Release the model before returning to the caller
        del model
        del tokenizer
        gc.collect()
        # str() also covers "cuda:0"-style strings and torch.device objects
        if str(device).startswith("cuda"):
            torch.cuda.empty_cache()

        cache_to_save = {
            "data": generated_cot_data,
            "labels": final_answers
        }

        # Make sure the target directory exists so the finished generation
        # run is not lost to a missing folder.
        cache_dir = os.path.dirname(cot_cache_file)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

        with open(cot_cache_file, 'w') as f:
            json.dump(cache_to_save, f, indent=2)

        return generated_cot_data, final_answers

    @staticmethod
    def _last_boxed_content(solution_text):
        r"""Return the stripped content of the last ``\boxed{...}`` in the text.

        Braces are matched by depth, so nested groups such as
        ``\boxed{\frac{1}{2}}`` are returned whole and text following the
        closing brace is never included. Returns ``"Unknown"`` when there is
        no ``\boxed{`` or its braces are never closed.
        """
        marker = "\\boxed{"
        marker_pos = solution_text.rfind(marker)
        if marker_pos == -1:
            return "Unknown"

        content_start = marker_pos + len(marker)
        depth = 1
        for pos in range(content_start, len(solution_text)):
            char = solution_text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return solution_text[content_start:pos].strip()
        return "Unknown"
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        # --- Modified to handle both modes ---
        if self.mode == "problem_solution":
            # Original behavior
            return {
                'instruction': "Solve this math problem step by step.",
                'input': self.data[idx],  # This is the problem
                'output': self.labels[idx] # This is the full solution
            }
        elif self.mode == "problem_cot_answer":
            # New behavior
            return {
                'instruction': "Predict final answer.",
                'input': self.data[idx],   # This is the generated CoT
                'output': self.labels[idx]  # This is the final answer
            }

class CommonGen(Dataset):
    """CommonGen concept-to-sentence examples, stored in a shuffled order.

    For every row of the chosen split the dataset keeps:

    * ``data``: the concepts joined by "; ", followed by a newline and
      "Answer:".
    * ``labels``: the row's ``target`` sentence.
    * ``items``: the concepts joined by "," for the "train" split, or the
      raw ``concepts`` value for any other split.
    * ``idx_problem``: ``concept_set_idx`` converted to a string.
    """

    def __init__(self, dataset_name="allenai/common_gen", split="train"):
        """Load ``dataset_name`` and materialise ``split`` in a random row order.

        The order is drawn with ``random.shuffle``, so it follows the global
        ``random`` state.
        """
        dataset = load_dataset(dataset_name)
        self.data, self.labels = [], []
        self.items, self.idx_problem = [], []

        rows = dataset[split]
        visit_order = list(range(len(rows)))
        random.shuffle(visit_order)

        is_train = split == 'train'
        for row_idx in visit_order:
            row = rows[row_idx]
            concepts = row['concepts']
            # Only the train split stores the concepts as one joined string.
            self.items.append(",".join(concepts) if is_train else concepts)
            self.data.append("; ".join(concepts) + "\nAnswer:")
            self.labels.append(row['target'])
            self.idx_problem.append(str(row['concept_set_idx']))

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict with prompt, target and metadata."""
        return dict(
            instruction="Generate a sentence that contains all the concepts.",
            input=self.data[idx],
            output=self.labels[idx],
            item=self.items[idx],
            idx_problem=self.idx_problem[idx],
        )