#!/usr/bin/env python
"""
Prepare token-level binary shards and write a metadata.json file.
"""
import argparse
import json
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List, Literal
from collections import OrderedDict
import itertools

import numpy as np
from datasets import Dataset, concatenate_datasets, load_dataset
from dotenv import load_dotenv
from huggingface_hub import HfApi, hf_hub_download
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from transformers.utils import logging

# Only let transformers emit log records at level 40 (== logging.ERROR) or
# above, hiding the library's warnings during tokenizer loading/encoding.
logging.set_verbosity(40)

# Load environment variables (e.g. HF_TOKEN, read later via os.getenv) from a
# .env file located two directories above the folder containing this script.
load_dotenv(Path(__file__).parent.parent.parent / ".env")

def memmap_write(
    fname: Path,
    arr: List[List[int]],
    dtype: np.dtype = np.uint16,
) -> None:
    """
    Concatenate sequences and write them to a flat memory-mapped file.

    The sequences in ``arr`` are written back to back with no separators or
    header, so the file holds ``sum(len(a) for a in arr)`` items of ``dtype``.
    Any existing file at ``fname`` is overwritten.

    Args:
        fname: Path to output file
        arr: Sequences of (token) ids to write, in order
        dtype: NumPy data type for the memory-mapped array

    Raises:
        ValueError: if ``dtype`` is an integer type that cannot represent
            every value in ``arr``. The partially written file is removed.
    """
    total = sum(len(a) for a in arr)

    if total == 0:
        # A zero-length region cannot be memory-mapped; an empty file is the
        # correct on-disk representation of "no tokens".
        Path(fname).write_bytes(b"")
        return

    target = np.dtype(dtype)
    # Range limits are only meaningful for integer targets (e.g. uint16 ids).
    bounds = np.iinfo(target) if target.kind in "iu" else None

    mmap = np.memmap(fname, dtype=target, mode="w+", shape=(total,))
    idx = 0
    try:
        for a in tqdm(arr, desc="writing", total=len(arr)):
            chunk = np.asarray(a)
            n = len(chunk)
            if bounds is not None and n:
                lo, hi = int(chunk.min()), int(chunk.max())
                # Without this check out-of-range ids would be wrapped or
                # rejected by numpy deep inside the write instead of reported.
                if lo < bounds.min or hi > bounds.max:
                    raise ValueError(
                        f"values in [{lo}, {hi}] do not fit in {target} "
                        f"(allowed [{bounds.min}, {bounds.max}]) while writing {fname}"
                    )
            mmap[idx : idx + n] = chunk
            idx += n
        mmap.flush()
    except BaseException:
        # Release the mapping before deleting so a half-written file never
        # masquerades as a valid shard.
        del mmap
        Path(fname).unlink(missing_ok=True)
        raise


def prep(
    num_proc: int,
    tokenizer: AutoTokenizer,
    max_length: int,
    length_strategy: Literal["truncate", "drop", "none"],
    sample_pct: float = 1.0,
    max_docs: int = 15_000_000,
) -> Dict[str, Dict[str, Dataset]]:
    """Load FineWeb-Edu, split it into train/test and tokenise it.

    The first ``max_docs`` documents of the streamed ``sample-100BT`` config
    are used (only the first ``int(max_docs * sample_pct)`` when
    ``sample_pct < 1``). They are split 98/2 into train/test with seed 42,
    and each example gets an ``ids`` column (token ids terminated by EOS) and
    a ``len`` column.

    Args:
        num_proc: Number of worker processes for ``map``/``filter``.
        tokenizer: Tokenizer used to encode ``text``; must define ``eos_token_id``.
        max_length: Maximum sequence length; only used by "truncate" and "drop".
        length_strategy: "truncate" cuts sequences to ``max_length`` tokens and
            forces the last one to EOS; "drop" removes longer examples;
            "none" keeps everything.
        sample_pct: Fraction of the first ``max_docs`` documents to keep;
            values >= 1 keep all of them.
        max_docs: Number of documents taken from the head of the stream.

    Returns:
        ``{"fineweb": {"train": Dataset, "test": Dataset}}``.

    Raises:
        ValueError: on an unknown strategy, a non-positive ``max_length`` with
            "truncate"/"drop", a tokenizer without EOS, or a sample that
            selects no documents.
    """
    if length_strategy not in ("truncate", "drop", "none"):
        raise ValueError(f"Unknown length_strategy: {length_strategy!r}")
    if length_strategy in ("truncate", "drop") and max_length < 1:
        # With the CLI default of -1, "truncate" would slice ids[:-1] and
        # "drop" would discard every example.
        raise ValueError(
            f"length_strategy={length_strategy!r} requires max_length >= 1, got {max_length}"
        )
    eos_id = tokenizer.eos_token_id
    if eos_id is None:
        raise ValueError("Tokenizer has no eos_token_id; cannot terminate sequences")

    num_docs = max_docs if sample_pct >= 1.0 else int(max_docs * sample_pct)
    if num_docs < 1:
        raise ValueError(f"sample_pct={sample_pct} selects no documents out of {max_docs}")

    dset_name = "HuggingFaceFW/fineweb-edu"
    stream = load_dataset(dset_name, split="train", name="sample-100BT", streaming=True)
    # Sampling keeps a head-of-stream prefix, so only materialise that prefix
    # instead of streaming all max_docs documents and then discarding most.
    ds = Dataset.from_generator(lambda: itertools.islice(stream, num_docs))
    print("ds: ", ds)

    if sample_pct < 1.0:
        print(f"Sampling {sample_pct*100}% of data: {num_docs} examples")

    ds = ds.select_columns(["text"])

    splits = ds.train_test_split(test_size=0.02, seed=42)

    # Tag each row with its split so both can be tokenised in one parallel map.
    train = splits["train"].map(lambda ex: {"split": "train"}, num_proc=num_proc)
    test = splits["test"].map(lambda ex: {"split": "test"}, num_proc=num_proc)

    ds = concatenate_datasets([train, test])

    print("Dataset columns:", ds.column_names)

    # --------------------------------------------------------- #
    # 1. tokenisation                                           #
    # --------------------------------------------------------- #

    def tok_fn(ex: Dict[str, Any]) -> Dict[str, Any]:
        """Encode one document and terminate it with EOS."""
        ids = tokenizer.encode(ex["text"], add_special_tokens=False)
        ids.append(eos_id)

        if length_strategy == "truncate":
            ids = ids[:max_length]
            # Truncated sequences still end with an EOS boundary.
            ids[-1] = eos_id

        return {"ids": ids, "len": len(ids)}

    ds = ds.map(tok_fn, num_proc=num_proc)

    if length_strategy == "drop":
        ds = ds.filter(lambda ex: ex["len"] <= max_length, num_proc=num_proc)

    # --------------------------------------------------------- #
    # 2. split into train/test                                  #
    # --------------------------------------------------------- #

    data = OrderedDict()
    data["fineweb"] = {
        "train": ds.filter(lambda ex: ex["split"] == "train", num_proc=num_proc),
        "test": ds.filter(lambda ex: ex["split"] == "test", num_proc=num_proc),
    }

    return data


def _clear_split_files(out_dir: Path, label: str, split: str) -> None:
    """Delete every ``.bin`` file an earlier run wrote for ``label``/``split``."""
    stale = [out_dir / f"{label}_{split}.bin", *out_dir.glob(f"{label}_{split}_*.bin")]
    for path in stale:
        path.unlink(missing_ok=True)


def _write_split(
    subset: Dataset,
    out_dir: Path,
    label: str,
    split: str,
    shard_size: int,
    tokenizer: AutoTokenizer,
) -> Dict[str, Any]:
    """Write one split's ``ids`` to one or more ``.bin`` files and return its metadata."""
    num_examples = len(subset)
    print(f"Processing {split}: {num_examples} examples")

    # Remove shards/single files from previous runs first. Otherwise a run
    # producing fewer shards leaves stale files that a later "*.bin" upload
    # would publish alongside the new ones.
    _clear_split_files(out_dir, label, split)

    sharded = num_examples > shard_size
    if sharded:
        num_shards = (num_examples + shard_size - 1) // shard_size  # ceiling division
        print(f"Sharding {split} into {num_shards} shards ({num_examples} examples)")
    else:
        num_shards = 1

    shard_files: List[str] = []
    total_tokens = 0
    part = subset
    for shard_idx in range(num_shards):
        if sharded:
            start = shard_idx * shard_size
            end = min(start + shard_size, num_examples)
            print(f"Selecting shard {shard_idx+1}/{num_shards}: indices {start}-{end}")
            part = subset.select(range(start, end))
            out_path = out_dir / f"{label}_{split}_{shard_idx:03d}.bin"
            print(f"Writing shard {shard_idx+1}/{num_shards}: {out_path.name}")
        else:
            out_path = out_dir / f"{label}_{split}.bin"
            print(f"Writing {split}: {out_path.name} ({num_examples} examples)")

        memmap_write(out_path, part["ids"], np.uint16)
        # Count per part so the whole split's "len" column is never loaded at once.
        total_tokens += int(np.sum(part["len"]))
        shard_files.append(out_path.name)

        if sharded:
            print(f"Completed shard {shard_idx+1}/{num_shards}")

    # ``part`` is the last written part, so its last row is the split's last example.
    example_text = (
        tokenizer.decode(part[-1]["ids"], skip_special_tokens=False) if num_examples else ""
    )

    return {
        "total_tokens": total_tokens,
        "example": example_text,
        "num_shards": len(shard_files),
        "shard_files": shard_files,
    }


def write(
    datasets: Dict[str, Dict[str, Dataset]],
    out_dir: Path,
    max_length: int,
    tokenizer_name: str,
    tokenizer: AutoTokenizer,
    length_strategy: Literal["truncate", "drop", "none"],
    shard_size: int = 500_000,  # Number of examples per shard
) -> None:
    """Write tokenised splits to uint16 ``.bin`` files and dump ``metadata.json``.

    For each ``label -> {"train": Dataset, "test": Dataset}`` entry, the
    ``ids`` column of every split is written to ``out_dir``:

    * a split with more than ``shard_size`` examples becomes
      ``{label}_{split}_{NNN}.bin`` shards;
    * a smaller split becomes a single ``{label}_{split}.bin`` file.

    Per-split token counts, file names and the decoded last example, plus
    global statistics, are stored in ``out_dir / "metadata.json"``.

    Args:
        datasets: Mapping label -> {"train": Dataset, "test": Dataset}; each
            dataset needs ``ids`` and ``len`` columns.
        out_dir: Existing output directory.
        max_length: Recorded in metadata.
        tokenizer_name: Recorded in metadata.
        tokenizer: Used to decode the example text and report ``vocab_size``.
        length_strategy: Recorded in metadata.
        shard_size: Maximum number of examples per shard file.

    Raises:
        ValueError: if ``shard_size < 1`` or the tokenizer's ids do not fit
            in the uint16 on-disk format.
    """
    if shard_size < 1:
        raise ValueError(f"shard_size must be >= 1, got {shard_size}")

    vocab_size = len(tokenizer)
    # Ids are 0..vocab_size-1 and are stored as uint16.
    if vocab_size > np.iinfo(np.uint16).max + 1:
        raise ValueError(
            f"Tokenizer {tokenizer_name!r} has {vocab_size} tokens; ids do not fit in uint16"
        )

    meta: Dict[str, Any] = {}
    totals = {"train": 0, "test": 0}

    for label, splits in datasets.items():

        print("label", label)
        print("splits", splits)

        meta[label] = {}
        for split in ("train", "test"):
            split_meta = _write_split(
                splits[split], out_dir, label, split, shard_size, tokenizer
            )
            meta[label][split] = split_meta
            totals[split] += split_meta["total_tokens"]

    # ---------- global statistics ----------
    meta["all"] = {
        "total_tokens_train": totals["train"],
        "total_tokens_test": totals["test"],
        "tokenizer": tokenizer_name,
        "vocab_size": vocab_size,
        "max_length": max_length,
        "labels": sorted(datasets),
        "length_strategy": length_strategy,
        "dtype": "uint16",
    }

    # Explicit UTF-8: ensure_ascii=False writes raw non-ASCII example text,
    # which the platform default encoding may not be able to represent.
    with open(out_dir / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False, default=str)


# --------------------------------------------------------------------------- #
# main preparation sequence                                                   #
# --------------------------------------------------------------------------- #

def _download_bins(out_dir: Path, repo_id: str, subfolder: str, hf_token: str | None) -> None:
    """Copy the ``.bin`` files and ``metadata.json`` under ``subfolder`` of a dataset repo into ``out_dir``."""
    print(f"Downloading .bin files from {repo_id}/{subfolder}...")
    api = HfApi(token=hf_token)
    # Match "subfolder/" so sibling folders sharing the prefix
    # (e.g. "fineweb_45M_old") are not pulled in.
    prefix = subfolder.rstrip("/") + "/"

    try:
        repo_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset", token=hf_token)
        in_folder = [f for f in repo_files if f.startswith(prefix)]
        bin_files = [f for f in in_folder if f.endswith(".bin")]
        metadata_files = [f for f in in_folder if Path(f).name == "metadata.json"]
        all_files = bin_files + metadata_files

        if not all_files:
            print(f"No .bin or metadata.json files found in {repo_id}/{subfolder}")
            return

        for file_path in tqdm(all_files, desc="Downloading files"):
            # The layout in out_dir is flat: only the file name is kept.
            local_path = out_dir / Path(file_path).name
            cached_path = hf_hub_download(
                repo_id=repo_id,
                filename=file_path,
                repo_type="dataset",
                token=hf_token,
            )
            shutil.copy2(cached_path, local_path)
            print(f"Downloaded {file_path} to {local_path}")

        print("Download complete!")
    except Exception as e:
        print(f"Error downloading files: {e}")
        raise


def _upload_bins(out_dir: Path, repo_id: str, subfolder: str, hf_token: str | None) -> None:
    """Upload every ``*.bin`` file and ``metadata.json`` in ``out_dir`` to ``repo_id/subfolder``."""
    print(f"Uploading .bin files to {repo_id}/{subfolder}...")
    api = HfApi(token=hf_token)

    try:
        api.create_repo(repo_id=repo_id, token=hf_token, exist_ok=True, repo_type="dataset")
        print(f"Dataset repository {repo_id} ready")
    except Exception as e:
        # Best effort: the uploads below still fail loudly if the repo is unusable.
        print(f"Note: Could not create/verify repo (it may already exist): {e}")

    files_to_upload = sorted(out_dir.glob("*.bin"))
    metadata_file = out_dir / "metadata.json"
    if metadata_file.exists():
        files_to_upload.append(metadata_file)

    if not files_to_upload:
        print(f"No .bin or metadata.json files found in {out_dir} to upload")
        return

    for file_path in tqdm(files_to_upload, desc="Uploading files"):
        try:
            api.upload_file(
                path_or_fileobj=str(file_path),
                path_in_repo=f"{subfolder}/{file_path.name}",
                repo_id=repo_id,
                repo_type="dataset",
                token=hf_token,
            )
            print(f"Uploaded {file_path.name} to {repo_id}/{subfolder}")
        except Exception as e:
            print(f"Error uploading {file_path.name}: {e}")
            raise

    print("Upload complete!")


def run(
    out_dir: Path | None,
    num_proc: int,
    max_length: int,
    length_strategy: str,
    tokenizer_name: str,
    download_bins: bool,
    upload_bins: bool,
    sample_pct: float,
    shard_size: int,
    repo_id: str = "erol-AE/GR-MoE",
    subfolder: str = "fineweb_45M",
) -> None:
    """Either download prebuilt shards, or build them (and optionally upload them).

    Args:
        out_dir: Output directory (``Path`` or ``str``); defaults to
            ``<script dir>/fineweb``. It is created if missing.
        num_proc: Worker processes used by ``prep``.
        max_length: Length limit for the "truncate"/"drop" strategies.
        length_strategy: One of "truncate", "drop", "none".
        tokenizer_name: Hugging Face tokenizer to load.
        download_bins: If true, only download ``.bin`` files and
            ``metadata.json`` from ``repo_id/subfolder`` and return.
        upload_bins: If true, upload the written files to ``repo_id/subfolder``.
        sample_pct: Fraction of the source documents to use.
        shard_size: Examples per shard file.
        repo_id: Hugging Face dataset repository used for download/upload.
        subfolder: Folder inside ``repo_id`` holding the files.
    """
    if out_dir is None:
        out_dir = Path(__file__).parent / "fineweb"
    else:
        # The CLI passes --out_dir as a plain string; normalise it so that
        # .mkdir() and the "/" joins below work.
        out_dir = Path(out_dir)

    print("out_dir:", out_dir)

    out_dir.mkdir(parents=True, exist_ok=True)

    hf_token = os.getenv("HF_TOKEN")

    if download_bins:
        _download_bins(out_dir, repo_id, subfolder, hf_token)
        return

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    data = prep(
        num_proc=num_proc,
        tokenizer=tokenizer,
        max_length=max_length,
        length_strategy=length_strategy,
        sample_pct=sample_pct,
    )

    write(
        datasets=data,
        out_dir=out_dir,
        max_length=max_length,
        tokenizer_name=tokenizer_name,
        tokenizer=tokenizer,
        length_strategy=length_strategy,
        shard_size=shard_size,
    )

    print("Done - binary shards + metadata.json written to", out_dir)

    if upload_bins:
        _upload_bins(out_dir, repo_id, subfolder, hf_token)


# --------------------------------------------------------------------------- #
# CLI                                                                         #
# --------------------------------------------------------------------------- #

if __name__ == "__main__":

    # description= (not the first positional arg, which is ``prog``) so usage
    # lines show the real script name.
    ap = argparse.ArgumentParser(description="Prepare fineweb dataset")
    # type=Path so run() receives a Path rather than a str.
    ap.add_argument("--out_dir", type=Path, default=None, help="directory to write .bin files")
    ap.add_argument("--num_proc", type=int, default=32)
    ap.add_argument(
        "--max_length",
        type=int,
        default=-1,
        help="max tokens per example; must be >= 1 for --length_strategy truncate/drop",
    )
    ap.add_argument(
        "--length_strategy",
        type=str,
        default="none",
        choices=["truncate", "drop", "none"],
        help="how to handle examples longer than --max_length",
    )
    ap.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neo-125M")
    ap.add_argument("--download_bins", action="store_true")
    ap.add_argument("--upload_bins", action="store_true")
    ap.add_argument("--sample", type=float, default=1.0, help="Fraction of data to use (0.0-1.0), e.g., 0.01 for 1%%")
    ap.add_argument("--shard_size", type=int, default=500_000, help="Number of examples per shard file (default: 500k)")
    args = ap.parse_args()

    run(
        out_dir=args.out_dir,
        num_proc=args.num_proc,
        max_length=args.max_length,
        length_strategy=args.length_strategy,
        tokenizer_name=args.tokenizer,
        download_bins=args.download_bins,
        upload_bins=args.upload_bins,
        sample_pct=args.sample,
        shard_size=args.shard_size,
    )