import os
import sys

# Both lines below run before `import einops` on purpose.
# NOTE(review): presumably this steers einops towards its JAX backend and keeps
# an already-imported TensorFlow from being picked up instead — confirm.
os.environ["EINOPS_BACKEND"] = "jax"
sys.modules.pop("tensorflow", None)

import einops

# Guarded with hasattr so the module still imports when the installed einops
# does not expose `set_default_backend`.
if hasattr(einops, "set_default_backend"):
    einops.set_default_backend("jax")

import multiprocessing

import multiprocess
import numpy as np
from datasets import disable_caching, load_from_disk

# Force the "forkserver" start method in both the third-party `multiprocess`
# package and the stdlib `multiprocessing`; force=True replaces any start
# method that was already set.
# NOTE(review): presumably `datasets` spawns its `num_proc` workers through
# `multiprocess`, which is why both are configured — confirm.
multiprocess.set_start_method("forkserver", force=True)
multiprocessing.set_start_method("forkserver", force=True)


def get_keys(inputs, pad_token_id, eos_token_id):
    """Derive next-token training arrays from a batch of token ids.

    The caller's ``inputs`` is never modified: all work is done on a copy.

    Args:
        inputs: Integer token ids with the sequence along the last axis
            (at least 1-D). Anything ``np.array`` accepts is allowed.
        pad_token_id: Token id used for padding.
        eos_token_id: Token id marking end of sequence.

    Returns:
        A tuple ``(inputs, targets, token_mask, positions, num_tokens)``:

        * ``inputs`` -- copy of the ids with the last slot of every sequence
          overwritten by ``pad_token_id``.
        * ``targets`` -- the original ids rotated left by one along the last
          axis. The last slot wraps around to the first token; that slot is
          always masked out because the matching input slot is padding.
        * ``token_mask`` -- ``int32`` array, 1 where the (padded) input is
          neither ``pad_token_id`` nor ``eos_token_id``, else 0.
        * ``positions`` -- index along the last axis where ``token_mask`` is
          set, ``-1`` elsewhere.
        * ``num_tokens`` -- count of unmasked slots per sequence.
    """
    # Copy first so the in-place padding below cannot leak into the caller.
    inputs = np.array(inputs)

    # Targets must be taken before the last slot is overwritten.
    targets = np.roll(inputs, -1, axis=-1)
    inputs[..., -1] = pad_token_id

    seq_len = inputs.shape[-1]
    pos_base = np.arange(seq_len, dtype=np.int32)

    token_mask = ((inputs != pad_token_id) & (inputs != eos_token_id)).astype(np.int32)
    # pos_base broadcasts against the mask, so no explicit tiling is needed.
    positions = np.where(token_mask, pos_base, -1)
    num_tokens = np.sum(positions >= 0, axis=-1)

    return inputs, targets, token_mask, positions, num_tokens


def dict_encode_batch(batch, tokenizer, pad_token_id, eos_token_id, target_column):
    """Tokenize one batch of strings and package the training arrays.

    Args:
        batch: Mapping that holds the raw strings under ``target_column``.
        tokenizer: Object exposing ``encode_batch_fast``; each returned
            encoding must provide an ``ids`` attribute.
        pad_token_id: Forwarded to ``get_keys``.
        eos_token_id: Forwarded to ``get_keys``.
        target_column: Key in ``batch`` holding the text to encode.

    Returns:
        Dict with keys ``inputs``, ``targets``, ``positions``,
        ``num_tokens`` and ``token_mask`` holding the arrays from
        ``get_keys``.
    """
    encodings = tokenizer.encode_batch_fast(batch[target_column])
    # NOTE(review): assumes every encoding has the same length so the ids
    # stack into a rectangular array — confirm against the tokenizer setup.
    token_ids = np.asarray([encoding.ids for encoding in encodings])

    model_inputs, model_targets, mask, pos, counts = get_keys(
        token_ids, pad_token_id, eos_token_id
    )

    return dict(
        inputs=model_inputs,
        targets=model_targets,
        positions=pos,
        num_tokens=counts,
        token_mask=mask,
    )


def load_and_tokenize(
    dataset_dir,
    tokenizer,
    batch_size,
    num_processes,
    seed=None,
    target_column="smiles",
    caching=False,
    limit: int | None = None,
):
    """Load a dataset saved on disk and tokenize it in fixed-size batches.

    Args:
        dataset_dir: Path handed to ``load_from_disk``.
        tokenizer: Passed to ``dict_encode_batch``; must also expose
            ``pad_token_id`` and ``eos_token_id``.
        batch_size: Rows per batch; a trailing partial batch is dropped.
        num_processes: ``num_proc`` used for both batching and mapping.
        seed: When not ``None``, the dataset is shuffled with this seed.
        target_column: Column holding the text; removed after encoding.
        caching: When falsy, ``disable_caching()`` is called first.
        limit: Optional cap on rows, rounded down to a multiple of
            ``batch_size`` (but never below one batch) and clipped to the
            dataset length.

    Returns:
        The batched, tokenized dataset with its format set to ``"numpy"``.
    """
    if not caching:
        disable_caching()

    dataset = load_from_disk(dataset_dir)
    if seed is not None:
        dataset = dataset.shuffle(seed)

    if limit is not None:
        # Round down to whole batches; `or` falls back to a single batch
        # when the rounded value is zero.
        whole_batches = limit // batch_size
        limit = whole_batches * batch_size or batch_size
        row_count = min(limit, len(dataset))
        dataset = dataset.select(range(row_count))

    batched = dataset.batch(
        batch_size, drop_last_batch=True, num_proc=num_processes
    )
    encode_kwargs = {
        "tokenizer": tokenizer,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "target_column": target_column,
    }
    tokenized = batched.map(
        dict_encode_batch,
        remove_columns=[target_column],
        num_proc=num_processes,
        fn_kwargs=encode_kwargs,
    )
    return tokenized.with_format("numpy")
