
from datasets import load_dataset
from src.data_utils import insert_special_token


def prepare_xnli_data(tokenizer, lang_id, config=None):
    """Load and tokenize the XNLI dataset for one language.

    Each premise/hypothesis pair is tokenized as a text pair, truncated and
    padded to ``config['seq_len']``, and the raw ``premise``/``hypothesis``
    columns are dropped. If ``config['language_to_script']`` maps ``lang_id``
    to a token, that token's id is inserted into every example via
    ``insert_special_token``.

    Args:
        tokenizer: Tokenizer called on text pairs and providing
            ``convert_tokens_to_ids``.
        lang_id: XNLI config name passed to ``load_dataset``; also the key
            looked up in ``config['language_to_script']``.
        config: Mapping with at least ``'seq_len'`` and
            ``'language_to_script'``. Required despite the ``None`` default.

    Returns:
        Tuple ``(train, validation, test)`` of encoded datasets. ``test`` is
        ``None`` when the loaded dataset has no ``'test'`` split.

    Raises:
        ValueError: if ``config`` is ``None``.
    """
    # Fail before downloading anything; the old default of None could never work.
    if config is None:
        raise ValueError(
            "prepare_xnli_data requires a config with 'seq_len' and 'language_to_script'"
        )
    seq_len = config['seq_len']
    script_token = config['language_to_script'].get(lang_id)
    # Only resolve an id when a script token is configured, instead of relying
    # on how a given tokenizer treats convert_tokens_to_ids(None).
    script_id = (
        tokenizer.convert_tokens_to_ids(script_token)
        if script_token is not None
        else None
    )

    dataset = load_dataset("xnli", lang_id)

    def tokenize_pairs(examples):
        # Batched: examples["premise"] / examples["hypothesis"] are lists of strings.
        return tokenizer(
            examples["premise"],
            examples["hypothesis"],
            truncation=True,
            padding="max_length",
            max_length=seq_len,
        )

    def add_script_id(examples, script_id=None):
        return insert_special_token(examples, script_id)

    def encode(split):
        """Tokenize one split and, if configured, insert the script token id."""
        encoded = split.map(tokenize_pairs, batched=True, remove_columns=["premise", "hypothesis"])
        if script_id is not None:
            # NOTE(review): insert_special_token is defined elsewhere; confirm it
            # keeps sequences within seq_len after padding to max_length.
            encoded = encoded.map(add_script_id, fn_kwargs={"script_id": script_id})
        return encoded

    encoded_train = encode(dataset['train'])
    encoded_validation = encode(dataset['validation'])
    encoded_test = encode(dataset['test']) if 'test' in dataset else None
    return encoded_train, encoded_validation, encoded_test

def prepare_arc_easy_data(tokenizer, lang_id, config=None):
    """Load and tokenize the ``ARC-Easy`` configuration of ``ai2_arc``.

    Only the ``question`` text is tokenized (truncated and padded to
    ``config['seq_len']``). No columns are removed, so the original fields
    remain alongside the tokenizer outputs.

    Args:
        tokenizer: Tokenizer callable on a batch of strings.
        lang_id: Not used; accepted so the signature matches the other
            ``prepare_*_data`` helpers.
        config: Mapping with at least ``'seq_len'``. Required despite the
            ``None`` default.

    Returns:
        Tuple ``(train, validation, test)`` of encoded datasets. ``test`` is
        ``None`` when the loaded dataset has no ``'test'`` split.

    Raises:
        ValueError: if ``config`` is ``None``.
    """
    # Fail before downloading anything; the old default of None could never work.
    if config is None:
        raise ValueError("prepare_arc_easy_data requires a config with 'seq_len'")
    seq_len = config['seq_len']

    dataset = load_dataset("ai2_arc", "ARC-Easy")

    def tokenize_questions(examples):
        # NOTE(review): answer choices are not tokenized here; confirm the
        # downstream consumer handles them from the retained columns.
        return tokenizer(examples["question"], truncation=True, padding="max_length", max_length=seq_len)

    encoded_train = dataset['train'].map(tokenize_questions, batched=True)
    encoded_validation = dataset['validation'].map(tokenize_questions, batched=True)
    encoded_test = dataset['test'].map(tokenize_questions, batched=True) if 'test' in dataset else None
    return encoded_train, encoded_validation, encoded_test



