import os
import yaml
import torch
import concurrent.futures
from tqdm import tqdm
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from datasets import load_dataset
from baselines import load_model_and_tokenizer
import logging

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# NOTE(review): no explicit data directory is configured here; presumably the
# `dolma/dolma.py` loading script locates the data itself — confirm.

# Load the "train" split through the local Dolma loading script.
dataset = load_dataset("dolma/dolma.py", split="train")
# Keep only (text, source) pairs, in dataset order. Reading the two columns in
# bulk avoids materialising a full per-row dict (every column of every
# document) just to pick out two fields.
# NOTE(review): the whole corpus is still held in memory as a Python list of
# tuples; for large Dolma snapshots consider streaming instead.
dataset = list(zip(dataset['text'], dataset['source']))
logging.info('Loaded texts')

# Model configuration setup
model_name = "llama2_7b_fast"
config_file = "configs/model_configs/models.yaml"

# The model config is plain data; `safe_load` only builds standard YAML types
# (mappings, sequences, scalars), which is all this file needs.
with open(config_file, encoding='utf-8') as file:
    model_configs = yaml.safe_load(file)

model_entry = model_configs[model_name]
# Use the configured GPU count, falling back to every visible CUDA device.
num_gpus = model_entry.get('num_gpus', torch.cuda.device_count())
model_config = model_entry['model']
model_config['num_gpus'] = num_gpus

# NOTE(review): only `tokenizer` is used in this script; `model` is loaded but
# never referenced. If `load_model_and_tokenizer` (or a sibling helper) can load
# just the tokenizer, that would save GPU memory and startup time.
model, tokenizer = load_model_and_tokenizer(**model_config)
logging.info('Model and tokenizer loaded')

# Function to tokenize a single document
def tokenize_function(text):
    """Tokenize one document with the module-level ``tokenizer``.

    Called from the worker processes started by ``process_data_in_batches``,
    so it relies on the global ``tokenizer`` being available in the worker
    (e.g. inherited via fork).

    Args:
        text: Raw document text.

    Returns:
        The result of ``tokenizer.tokenize(text)``, or ``[]`` if tokenization
        raised. Failures are logged (with traceback) and do not abort the run.
    """
    try:
        return tokenizer.tokenize(text)
    except Exception as e:
        # Best-effort: one bad document must not kill the whole batch, but keep
        # the traceback so the failure can actually be diagnosed.
        logging.exception("Error tokenizing text: %s", e)
        return []

# Saving tokenized outputs efficiently
def save_tokenized_data(token_lists, sources_list, file_name='tokenized_data_all.parquet', compression='snappy'):
    """Append one batch of tokenized documents to a Parquet dataset.

    Despite its name, ``file_name`` is used as the *root directory* of a
    Parquet dataset (``pq.write_to_dataset``). Each call writes a new part file
    inside it (pyarrow's default part-file names are unique), so repeated calls
    accumulate batches instead of overwriting each other.

    Args:
        token_lists: Per-document token lists, aligned index-for-index with
            ``sources_list``.
        sources_list: Source label for each document.
        file_name: Dataset root directory to write into.
        compression: Parquet compression codec passed to the writer.

    Errors are logged with a traceback and swallowed (best-effort save).
    """
    if not token_lists and not sources_list:
        # Nothing to write; also avoids a zero-row part file whose columns
        # would be inferred as null-typed.
        return
    try:
        # Build the Arrow table directly: going through a pandas DataFrame
        # only adds a copy and yields the same columns.
        table = pa.table({'tokens': token_lists, 'source': sources_list})
        pq.write_to_dataset(table, root_path=file_name, compression=compression)
    except Exception as e:
        logging.exception("Failed to save data: %s", e)

# TODO: Hugging Face `datasets` can batch and parallelise this itself via
# `Dataset.map(batched=True, num_proc=...)`; consider using it instead.
def process_data_in_batches(batch_size=1000, max_workers=16, chunksize=None):
    """Tokenize the module-level ``dataset`` in batches and save each batch.

    Each batch of ``(text, source)`` pairs is tokenized in a process pool via
    ``tokenize_function`` and written as one Parquet part file by
    ``save_tokenized_data``.

    Args:
        batch_size: Documents per batch (one part file per batch).
        max_workers: Number of tokenizer worker processes.
        chunksize: Documents handed to a worker per IPC round trip. Defaults
            to roughly four chunks per worker for a full batch; the executor's
            own default of 1 pays one round trip per document.
    """
    import multiprocessing

    # Workers run `tokenize_function`, which reads the module-level `tokenizer`.
    # With "fork" they inherit it. With "spawn"/"forkserver" (the default on
    # macOS, and on Linux from Python 3.14) every worker re-imports this script,
    # which reloads the whole dataset and model. Request fork explicitly where
    # the platform supports it.
    mp_context = (
        multiprocessing.get_context("fork")
        if "fork" in multiprocessing.get_all_start_methods()
        else None
    )
    if chunksize is None:
        chunksize = max(1, batch_size // (max_workers * 4))

    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers, mp_context=mp_context) as executor:
        for start in tqdm(range(0, len(dataset), batch_size), desc="Processing batches"):
            batch = dataset[start:start + batch_size]
            batch_texts = [text for text, _ in batch]
            batch_sources = [source for _, source in batch]
            # Executor.map yields exactly one result per input, in input order
            # (or raises), so tokens always stay aligned with their sources.
            token_lists = list(executor.map(tokenize_function, batch_texts, chunksize=chunksize))
            save_tokenized_data(token_lists, batch_sources)

# Run the pipeline only when executed as a script. Under the "spawn" and
# "forkserver" start methods, worker processes re-import the main module (as
# `__mp_main__`), and an unguarded call would try to start a new process pool
# inside each worker.
if __name__ == "__main__":
    process_data_in_batches()
