

import requests
from tqdm import tqdm
import numpy as np


def download_file(url: str, fname: str, chunk_size=1024):
    """Stream the resource at `url` into the local file `fname`.

    Progress is shown with a tqdm bar labelled with `fname`; the bar total is
    taken from the Content-Length response header (0 when the header is absent).

    Args:
        url: Address passed to `requests.get`.
        fname: Destination path, opened in binary write mode (truncated).
        chunk_size: Number of bytes requested per `iter_content` chunk.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    # The context manager releases the streamed connection even when the
    # download is interrupted part-way through.
    with requests.get(url, stream=True) as resp:
        # Fail before opening the destination so an HTTP error page is never
        # saved to disk as if it were the requested file.
        resp.raise_for_status()
        total = int(resp.headers.get("content-length", 0))
        with open(fname, "wb") as file, tqdm(
            desc=fname,
            total=total,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in resp.iter_content(chunk_size=chunk_size):
                # file.write returns the number of bytes written; advance the
                # bar by exactly that amount.
                size = file.write(data)
                bar.update(size)


# Per-model header constants consumed by write_datafile: the values stored in
# header slots 0 ("magic") and 1 ("version"), and the numpy dtype used to
# serialize the token ids.
HEADERS_INFO = {
    "gpt-2": dict(magic=20240520, version=1, token_dtype=np.uint16),
    # These three descriptors carry identical settings; the comprehension
    # builds a separate dict per key so the entries are never aliased.
    **{
        model_name: dict(magic=20240801, version=7, token_dtype=np.uint32)
        for model_name in ("llama-3", "tinyllama", "smollm")
    },
}

def write_datafile(filename, toks, model_desc="gpt-2"):
    """Write a token sequence to `filename` as a binary header plus token payload.

    The file consists of 256 int32 header values (1024 bytes) followed by the
    tokens serialized with the dtype configured for `model_desc` in
    HEADERS_INFO. Header slot 0 holds the model's "magic" value, slot 1 its
    "version", slot 2 the token count; the remaining slots are zero.

    Args:
        filename: Destination path, opened in binary write mode.
        toks: Sequence of token ids (anything `np.asarray` accepts and that
            supports `len`).
        model_desc: Key into HEADERS_INFO selecting the header constants and
            token dtype.

    Raises:
        AssertionError: If there are 2**31 or more tokens, or `model_desc` is
            not a key of HEADERS_INFO.
        ValueError: If any token id does not fit in the model's token dtype.
    """
    assert len(toks) < 2**31, "token count too large"
    # Validate against the table itself so a newly added model cannot be
    # rejected by a stale, duplicated list of names.
    assert model_desc in HEADERS_INFO, f"unknown model descriptor {model_desc}"
    info = HEADERS_INFO[model_desc]
    token_dtype = np.dtype(info["token_dtype"])

    header = np.zeros(256, dtype=np.int32)
    header[0] = info["magic"]
    header[1] = info["version"]
    header[2] = len(toks)

    # Casting an out-of-range id (e.g. an int64 array holding 70000 written as
    # uint16) wraps around silently and corrupts the file, so check the range
    # explicitly. This is a real raise rather than an assert so that it still
    # runs under `python -O`.
    raw = np.asarray(toks)
    if raw.size:
        limits = np.iinfo(token_dtype)
        smallest, largest = int(raw.min()), int(raw.max())
        if smallest < limits.min or largest > limits.max:
            raise ValueError(
                f"token ids span [{smallest}, {largest}], which does not fit in "
                f"{token_dtype.name} required by the {model_desc} format"
            )
    toks_np = raw.astype(token_dtype)

    num_bytes = header.nbytes + (len(toks) * toks_np.itemsize)
    print(f"writing {len(toks):,} tokens to {filename} ({num_bytes:,} bytes) in the {model_desc} format")
    with open(filename, "wb") as f:
        f.write(header.tobytes())
        f.write(toks_np.tobytes())

def write_evalfile(filename, datas):
    """Write multiple-choice evaluation examples to `filename` in binary form.

    The file consists of 256 int32 header values (1024 bytes) followed by a
    flat uint16 stream. Header slots: 0 = 20240522, 1 = 1 (the same slots
    write_datafile uses for "magic" and "version"), 2 = number of examples,
    3 = size in bytes of the longest example; the rest are zero.

    Each example is laid out in the stream as:
        65535 (start marker), example size in bytes, example index, label,
        number of completions, context length, context tokens...,
        then for every completion: its length followed by its tokens.

    Args:
        filename: Destination path, opened in binary write mode.
        datas: Sequence of dicts with keys "label", "ctx_tokens" (token ids)
            and "ending_tokens" (exactly 4 sequences of token ids).

    Raises:
        AssertionError: If there are 65536 or more examples, `datas` is empty,
            an example does not have 4 completions, a label does not fit in
            uint16, a token is outside [0, 65535), or an example occupies
            65536 bytes or more.
    """
    assert len(datas) < 2**16, "too many examples?"

    header = np.zeros(256, dtype=np.int32)
    header[0] = 20240522
    header[1] = 1
    header[2] = len(datas)
    # header[3] (longest example in bytes) is filled in after the loop.

    longest_example_bytes = 0
    full_stream = []
    for idx, data in enumerate(datas):
        label = data["label"]
        # Every other field is checked before entering the uint16 stream; an
        # unchecked label would wrap around or fail opaquely when the array is
        # built at the end.
        assert 0 <= label < 2**16, "bad label"
        ending_tokens = data["ending_tokens"]
        assert len(ending_tokens) == 4, "expected 4 completions for now? can relax later"
        ctx_tokens = data["ctx_tokens"]
        # 65535 is reserved as the start-of-example marker, hence the strict bound.
        assert all(0 <= t < 2**16-1 for t in ctx_tokens), "bad context token"

        stream = [
            2**16-1,             # start-of-example marker
            0,                   # placeholder for the byte size, patched below
            idx,
            label,
            len(ending_tokens),
            len(ctx_tokens),
        ]
        stream.extend(ctx_tokens)

        for end_tokens in ending_tokens:
            assert all(0 <= t < 2**16-1 for t in end_tokens), "bad completion token"
            stream.append(len(end_tokens))
            stream.extend(end_tokens)

        # Each stream entry is serialized as a 2-byte uint16.
        nbytes = len(stream)*2
        assert nbytes < 2**16, "example too large?"
        stream[1] = nbytes
        longest_example_bytes = max(longest_example_bytes, nbytes)
        full_stream.extend(stream)

    stream_np = np.array(full_stream, dtype=np.uint16)

    # Also rejects an empty `datas`, for which no example size was recorded.
    assert 0 < longest_example_bytes < 2**16, "bad longest_example"
    header[3] = longest_example_bytes

    print(f"writing {len(datas):,} examples to {filename}")
    with open(filename, "wb") as f:
        f.write(header.tobytes())
        f.write(stream_np.tobytes())