import os
import datasets
import pyarrow as pa
import pyarrow.parquet as pq
import json
import argparse
from transformers import AutoTokenizer, SchedulerType

# Hugging Face model id whose tokenizer (and its chat template) this script uses.
# Alternative backbone: "meta-llama/Llama-3.2-3B-Instruct"
BACKBONE = "Qwen/Qwen3-8B"

# Forwarded as `enable_thinking` to the chat template in add_chat_template.
# Flip to True to turn it on.
THINKING = False

# Loaded once, at import time, and shared by add_chat_template.
tokenizer = AutoTokenizer.from_pretrained(BACKBONE)

def _to_large(field: pa.Field) -> pa.Field:
    """Return a copy of *field* whose variable-width types use 64-bit offsets.

    ``string`` -> ``large_string``, ``binary`` -> ``large_binary`` and
    ``list`` / ``large_list`` -> ``large_list``. The conversion recurses through
    list elements and struct children, so nested values are widened too.
    The name, nullability and metadata of *field* and of every nested
    element/child field are preserved. All other types are returned unchanged.
    """
    t = field.type
    if pa.types.is_string(t):
        return field.with_type(pa.large_string())
    if pa.types.is_binary(t):
        return field.with_type(pa.large_binary())
    if pa.types.is_list(t) or pa.types.is_large_list(t):
        # Widen the element *field* (not just its type) so the element's name,
        # nullability and metadata survive. large_list inputs are revisited too,
        # because their string/binary children may still use 32-bit offsets.
        return field.with_type(pa.large_list(_to_large(t.value_field)))
    if pa.types.is_struct(t):
        return field.with_type(pa.struct([_to_large(child) for child in t]))
    return field

def _large_schema(schema: pa.Schema) -> pa.Schema:
    """Build a schema whose fields are those of *schema* passed through ``_to_large``.

    Only the fields are carried over; schema-level metadata is not copied.
    """
    widened = [_to_large(column_field) for column_field in schema]
    return pa.schema(widened)

def write_rowgrouped_large(ds, path: str, rows_per_group: int = 32):
    """Write *ds* to a zstd-compressed Parquet file at *path* in small row groups.

    The backing Arrow table is first cast to the schema from ``_large_schema``
    (64-bit-offset string/binary/list types), so no column can overflow 32-bit
    offsets. It is then written ``rows_per_group`` rows per ``write_table`` call.

    Args:
        ds: a ``datasets.Dataset``. If it carries an indices mapping (e.g. after
            ``select``/``shuffle``), the mapping is materialised first, so the
            file holds the rows the dataset actually exposes.
        path: destination Parquet file.
        rows_per_group: rows per ``write_table`` call; must be positive.

    Raises:
        ValueError: if ``rows_per_group`` is not positive.
    """
    if rows_per_group <= 0:
        raise ValueError(f"rows_per_group must be positive, got {rows_per_group}")
    # ``ds.data.table`` is the raw backing table and ignores any indices
    # mapping; flatten it so the file matches what iterating ``ds`` yields.
    if getattr(ds, "_indices", None) is not None:
        ds = ds.flatten_indices()
    tbl: pa.Table = ds.data.table
    tbl = tbl.cast(_large_schema(tbl.schema))  # avoid 32-bit offset overflow
    # DO NOT combine_chunks() here — we want smaller arrays per row group
    n = len(tbl)
    # Open the writer before the loop so that an empty dataset still produces
    # a valid zero-row Parquet file carrying the right schema.
    with pq.ParquetWriter(path, tbl.schema, compression="zstd") as writer:
        for start in range(0, n, rows_per_group):
            writer.write_table(tbl.slice(start, min(rows_per_group, n - start)))


def add_chat_template(ex):
    """Render ``ex["prompt"]`` through the backbone tokenizer's chat template.

    ``ex["prompt"]`` may be either:
      * a list/tuple of ``{"role": ..., "content": ...}`` message dicts (the
        ``"prompt"`` layout produced by ``make_map_fn``), used as-is; or
      * any other value (normally a plain string), which is wrapped as the
        content of a single user message.

    Returns:
        ``{"prompt": <rendered string>}``, rendered with ``tokenize=False``,
        ``add_generation_prompt=True`` and ``enable_thinking=THINKING``.
    """
    prompt = ex["prompt"]
    if isinstance(prompt, (list, tuple)):
        messages = list(prompt)
    else:
        messages = [{"role": "user", "content": prompt}]
    full = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=THINKING,
    )
    return {"prompt": full}



def make_map_fn(split, source=None, our_dataset=False, use_system_prompt=True):
    """Build the ``(example, idx)`` function handed to ``Dataset.map(with_indices=True)``.

    Args:
        split: split name recorded in ``extra_info["split"]``. For the
            ``our_dataset`` layout, an ``achievement_prior`` field is consumed
            only when this is ``"train"``; otherwise the prior is reported as 0.
        source: fixed ``data_source`` value. When ``None``, it is popped from
            each example's ``"source"`` field instead.
        our_dataset: when true, examples must also carry ``tests``,
            ``description``, ``kind``, ``dataset`` and ``elo``. The reward
            style comes from ``kind``, ``data_source`` from ``dataset``, and
            code problems use ``tests`` as ground truth. When false, every
            example becomes a rule-scored math problem.
        use_system_prompt: accepted for interface compatibility; not used.

    The returned function removes every field it consumes from ``example``
    (via ``pop``) and returns the new row as a dict.
    """

    def process_fn(example, idx):
        data_source = example.pop("source") if source is None else source
        question = example.pop("prompt")
        solution = example.pop("answer")
        global_id = example.pop("idx")
        user_turn = [{"role": "user", "content": question}]

        if not our_dataset:
            # Generic math data: rule-based reward, row index scoped by source.
            return {
                "data_source": data_source,
                "prompt": user_turn,
                "ability": "math",
                "reward_model": {"style": "rule", "ground_truth": solution},
                "extra_info": {
                    "split": split,
                    "index": f"{data_source}-{idx}",
                    "problem": question,
                },
            }

        tests = example.pop("tests")
        description = example.pop("description")
        reward_style = example.pop("kind")
        take_prior = split == "train" and "achievement_prior" in example
        achievement_prior = example.pop("achievement_prior") if take_prior else 0
        data_source = example.pop("dataset")
        elo = example.pop("elo")
        ground_truth = tests if reward_style == "code" else solution
        return {
            "data_source": data_source,
            "prompt": user_turn,
            "ability": reward_style,
            "reward_model": {"style": reward_style, "ground_truth": ground_truth},
            "extra_info": {
                "split": split,
                "index": f"{global_id}",
                "description": description,
                "problem": question,
                "elo": elo,
                "achievement_prior": achievement_prior,
            },
        }

    return process_fn

def _load_json_split(data_source, filename):
    """Load ``<data_source>/<filename>`` as a JSON dataset and return its only split.

    When ``data_files`` is a single path, ``load_dataset`` places every record
    in the ``"train"`` split, so that split is requested whatever role the
    file plays. Load errors (e.g. a missing file) propagate unchanged.
    """
    return datasets.load_dataset(
        "json", data_files=os.path.join(data_source, filename), split="train"
    )


def _offset_indices(fn, offset):
    """Wrap an ``(example, idx)`` map function so that it receives ``idx + offset``."""
    def shifted(example, idx):
        return fn(example, idx + offset)
    return shifted


def _map_in_shards(dataset, split, data_source, our_dataset, use_system_prompt,
                   num_proc, num_shards=4):
    """Apply ``make_map_fn(split, ...)`` to *dataset* shard by shard and concatenate.

    Shards are contiguous and each shard's indices are shifted by its start
    row. The mapped function therefore sees the row's index in the whole
    *dataset*, not its position inside the shard, and index-based ids stay
    unique across shards.
    """
    # Never create an empty shard: Dataset.map can return an empty shard
    # without the newly produced columns, which breaks concatenation with the
    # mapped shards (e.g. "prompt" is a string there, a message list here).
    num_shards = max(1, min(num_shards, len(dataset)))
    map_fn = make_map_fn(split, data_source, our_dataset=our_dataset,
                         use_system_prompt=use_system_prompt)
    processed_shards = []
    offset = 0
    for i in range(num_shards):
        shard = dataset.shard(num_shards=num_shards, index=i, contiguous=True)
        processed_shards.append(
            shard.map(function=_offset_indices(map_fn, offset),
                      with_indices=True, num_proc=num_proc)
        )
        offset += len(shard)
    return datasets.concatenate_datasets(processed_shards)


def run_proprocessing(data_source, num_proc=4):
    """Convert ``<data_source>/train.json`` and ``test.json`` into Parquet files.

    Each file is loaded as a JSON dataset and mapped into the training row
    layout by ``make_map_fn``, in 4 contiguous shards with ``num_proc``
    workers each. The results are written next to the inputs as
    ``train.parquet`` / ``test.parquet`` via ``write_rowgrouped_large``.

    Args:
        data_source: directory holding ``train.json`` and ``test.json``. If
            the path contains ``"OURS_verifiable-corpus"``, rows are mapped
            with ``our_dataset=True``.
        num_proc: worker processes for each ``Dataset.map`` call.
    """
    use_system_prompt = False

    our_dataset = "OURS_verifiable-corpus" in data_source
    print(data_source)
    train_dataset = _load_json_split(data_source, "train.json")
    test_dataset = _load_json_split(data_source, "test.json")
    print("Map Datasets")
    train_ds = _map_in_shards(train_dataset, "train", data_source, our_dataset,
                              use_system_prompt, num_proc)
    print(train_ds)
    test_ds = _map_in_shards(test_dataset, "test", data_source, our_dataset,
                             use_system_prompt, num_proc)
    print(test_ds)

    out_train = os.path.join(data_source, "train.parquet")
    out_test = os.path.join(data_source, "test.parquet")
    write_rowgrouped_large(train_ds, out_train)
    write_rowgrouped_large(test_ds, out_test)

if __name__ == "__main__":
    # CLI entry point: convert <data_source>/{train,test}.json into Parquet.
    parser = argparse.ArgumentParser(
        description="Convert <data_source>/train.json and test.json into row-grouped "
                    "train.parquet and test.parquet files in the same directory."
    )
    parser.add_argument(
        "--data_source", type=str, required=True,
        help="Local directory containing train.json and test.json; the Parquet files are written there."
    )
    parser.add_argument(
        "--num_proc", type=int, default=4,
        help="Worker processes for each Dataset.map call (default: 4)."
    )
    args = parser.parse_args()
    run_proprocessing(data_source=args.data_source, num_proc=args.num_proc)
