from transformers import AutoTokenizer
from pathlib import Path

def load_tokenized_data(
    ctx_len: int,
    tokenizer: AutoTokenizer,
    dataset_repo: str,
    dataset_split: str,
    dataset_name: str = "",
    dataset_row: str = "text",
    seed: int = 22,
    cache: str = "cache",
):
    """
    Load a Hugging Face dataset, tokenize it, and return the shuffled tokens.

    Args:
        ctx_len: Passed as ``max_length`` to
            ``transformer_lens.utils.tokenize_and_concatenate``.
        tokenizer: Tokenizer handed to ``tokenize_and_concatenate``.
        dataset_repo: First argument to ``datasets.load_dataset``.
        dataset_split: Passed as ``split`` to ``load_dataset``.
        dataset_name: Passed as ``name`` to ``load_dataset``.
        dataset_row: Name of the column to tokenize (``column_name``).
        seed: Seed given to the dataset's ``shuffle`` call.
        cache: Passed as ``cache_dir`` to ``load_dataset``.

    Returns:
        The ``"tokens"`` column of the tokenized dataset after shuffling.
    """
    from datasets import load_dataset, load_from_disk
    from transformer_lens import utils

    print(dataset_repo, dataset_name, dataset_split)

    # NOTE: an on-disk cache of the tokenized dataset (keyed by tokenizer
    # name, dataset repo and split, via save_to_disk / load_from_disk) used
    # to live here and is currently disabled; every call re-tokenizes.

    # NOTE(review): trust_remote_code=True lets the dataset repository's own
    # loading code run locally -- only point this at repositories you trust.
    raw_dataset = load_dataset(
        dataset_repo,
        name=dataset_name,
        split=dataset_split,
        cache_dir=cache,
        trust_remote_code=True,
    )

    tokenized = utils.tokenize_and_concatenate(
        raw_dataset,
        tokenizer,
        max_length=ctx_len,
        column_name=dataset_row,
    )

    shuffled = tokenized.shuffle(seed)
    return shuffled["tokens"]


def load_filter(path: str, device: str = "cuda:0"):
    """
    Load a JSON mapping from disk and convert each value to a tensor.

    Args:
        path: Path to a JSON file whose top-level value is an object.
        device: Device on which every tensor is created.

    Returns:
        A dict with the same keys as the JSON object, where each value has
        been wrapped with ``torch.tensor(value, device=device)``.
    """
    import json

    import torch

    # JSON is UTF-8 by specification; do not rely on the locale's default
    # encoding, which differs between platforms.
    with open(path, encoding="utf-8") as f:
        raw_filter = json.load(f)

    return {
        key: torch.tensor(value, device=device)
        for key, value in raw_filter.items()
    }




def load_tokenizer(model):
    """
    Loads tokenizer to the default NNsight configuration.

    The tokenizer is created with ``padding_side="left"`` and its padding
    token is set to its end-of-sequence token.

    Args:
        model: Identifier forwarded to ``AutoTokenizer.from_pretrained``.

    Returns:
        The configured tokenizer.
    """

    tokenizer = AutoTokenizer.from_pretrained(model, padding_side="left")
    # Use the public special-token properties rather than the private
    # ``_pad_token`` / ``_eos_token`` attributes, which are implementation
    # details and are not guaranteed to exist across transformers releases.
    tokenizer.pad_token = tokenizer.eos_token

    return tokenizer
