import copy
from tqdm import tqdm


def fill_none_with_empty_string(example):
    """Return a new dict mirroring ``example`` with ``None`` values replaced by ``""``.

    Only values that are exactly ``None`` are replaced; other falsy values
    (``0``, ``False``, empty containers) are kept unchanged.
    """
    cleaned = {}
    for field, entry in example.items():
        cleaned[field] = "" if entry is None else entry
    return cleaned


def preprocess_function(examples, tokenizer, max_input_length, messages_template):
    """Render one example into a chat prompt and tokenize it.

    Args:
        examples: Mapping of field names to values for a single example. Its
            items are substituted into the ``{placeholder}`` fields of the
            template via ``str.format``.
        tokenizer: Callable tokenizer that also provides
            ``apply_chat_template``.
        max_input_length: Passed to the tokenizer as ``max_length``; longer
            prompts are truncated.
        messages_template: Sequence of message dicts with ``"role"`` and
            ``"content"`` keys. Only the entries at index 1 and 2 are
            formatted. The template itself is never modified.

    Returns:
        The tokenizer output, extended with ``"text"`` (the rendered prompt)
        and ``"messages"`` (the formatted messages), and with ``"input_ids"``
        and ``"attention_mask"`` squeezed along their first axis.
    """
    # Work on a deep copy so the caller's template can be reused across examples.
    messages = copy.deepcopy(messages_template)
    for index in (1, 2):
        messages[index]["content"] = messages[index]["content"].format(**examples)

    try:
        prompt = tokenizer.apply_chat_template(messages, tokenize=False)
    except Exception:
        # Best-effort fallback: if the chat template cannot be applied for
        # any reason, fall back to a plain "role: content" rendering.
        prompt = "\n".join(
            f"{message['role']}: {message['content']}" for message in messages
        )

    model_inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_input_length,
    )
    model_inputs["text"] = prompt
    # Squeeze only axis 0 (assumed to be the batch axis of size 1 for a single
    # prompt — TODO confirm for custom tokenizers). A bare squeeze() would also
    # collapse a length-1 sequence into a 0-d tensor.
    model_inputs["input_ids"] = model_inputs["input_ids"].squeeze(0)
    model_inputs["attention_mask"] = model_inputs["attention_mask"].squeeze(0)
    model_inputs["messages"] = messages
    return model_inputs


def preprocess_dataset(
    dataset, tokenizer, max_input_length, messages_template, process_messages=True
):
    """Clean a dataset and, optionally, render and tokenize every example.

    ``None`` values are always replaced with empty strings first. When
    ``process_messages`` is true, each example is additionally passed through
    ``preprocess_function`` one at a time, and the resulting dataset is
    formatted so that ``input_ids`` and ``attention_mask`` use the ``"pt"``
    format while all other columns are still returned.

    Returns:
        The mapped dataset (cleaned only, or cleaned and tokenized).
    """
    cleaned = dataset.map(fill_none_with_empty_string)
    if not process_messages:
        return cleaned

    def _tokenize_example(example):
        return preprocess_function(
            example, tokenizer, max_input_length, messages_template
        )

    tokenized = cleaned.map(_tokenize_example, batched=False)
    tokenized.set_format(
        "pt",
        columns=["input_ids", "attention_mask"],
        output_all_columns=True,
    )
    return tokenized


def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """Estimate the average number of characters per token in the dataset.

    Args:
        dataset: Iterable of examples, each exposing a ``"text"`` entry.
        tokenizer: Tokenizer with an ``is_fast`` attribute. Fast tokenizers
            are called directly and their ``.tokens()`` are counted; other
            tokenizers are counted through ``tokenize``.
        nb_examples: Maximum number of examples to sample from the start of
            the dataset.

    Returns:
        Total characters divided by total tokens over the sampled examples.

    Raises:
        ZeroDivisionError: If no tokens were counted (empty dataset,
            ``nb_examples`` of 0, or only empty texts).
    """
    # The tokenizer kind does not change between examples, so pick the
    # counting strategy once instead of on every iteration.
    if tokenizer.is_fast:
        def count_tokens(text):
            return len(tokenizer(text).tokens())
    else:
        def count_tokens(text):
            return len(tokenizer.tokenize(text))

    total_characters, total_tokens = 0, 0
    # zip against range() caps the sample at nb_examples without pulling an
    # extra element from the dataset once the cap is reached.
    for _, example in tqdm(zip(range(nb_examples), dataset), total=nb_examples):
        text = example["text"]
        total_characters += len(text)
        total_tokens += count_tokens(text)

    if total_tokens == 0:
        # Same exception type as the bare division, but with an actionable message.
        raise ZeroDivisionError(
            "cannot estimate characters per token: no tokens were counted "
            "(empty dataset, nb_examples=0, or only empty texts)"
        )
    return total_characters / total_tokens
