import dataclasses
from typing import Optional

import pytest
from transformers import AutoTokenizer

import fishfarm
from fishfarm import Message
from fishfarm.models.tokenization_utils import tokenize_messages


@pytest.mark.parametrize(
    "chat_template",
    [
        fishfarm.chat_templates.ALPACA_JA,
        fishfarm.chat_templates.LLAMA2,
        None,
    ],
)
def test_tokenize_and_mask(chat_template: Optional[str]) -> None:
    """Check the token ids, text and mask returned by ``tokenize_messages``.

    A six-message conversation (system, then alternating user/assistant
    turns, ending on a user turn) is tokenized with the given
    ``chat_template``. The test then verifies that:

    * token ids and mask are non-empty and of equal length;
    * the mask starts and ends with ``False`` and contains exactly two
      ``True`` runs;
    * the returned text, and the decoded token ids, equal the output of
      ``tokenizer.apply_chat_template`` with a generation prompt added;
    * tokens with a ``True`` mask decode to text containing both assistant
      messages, and the remaining tokens decode to text containing the
      system and user messages.

    ``chat_template`` is forwarded unchanged (including ``None``) to both
    ``tokenize_messages`` and ``apply_chat_template``.
    """
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Hermes-2-Pro-Llama-3-8B")

    conversation = [
        Message(role="system", content="system system system"),
        Message(role="user", content="user1 user1"),
        Message(role="assistant", content="completion1 completion1"),
        Message(role="user", content="USER2 USER2"),
        Message(role="assistant", content="COMPLETION2 COMPLETION2"),
        Message(role="user", content="user3 user3"),
    ]

    result = tokenize_messages(conversation, tokenizer, chat_template)
    token_ids = result.token_ids
    mask = result.mask

    assert len(token_ids) > 0
    assert len(token_ids) == len(mask)

    # --- Mask values -------------------------------------------------------

    # The mask must begin outside of a completion.
    assert mask[0] is False

    # Every two-element window of the mask; the final window holds a single
    # element and therefore never matches a transition pattern.
    windows = [mask[start : start + 2] for start in range(len(mask))]

    # Occurrences of [False, True]: the conversation has two assistant
    # messages, so two rising edges are expected.
    assert windows.count([False, True]) == 2

    # Occurrences of [True, False]: likewise two falling edges.
    assert windows.count([True, False]) == 2

    # The mask must end outside of a completion.
    assert mask[-1] is False

    # --- Text --------------------------------------------------------------

    expected_text = tokenizer.apply_chat_template(
        conversation=[dataclasses.asdict(message) for message in conversation],
        chat_template=chat_template,
        tokenize=False,
        add_generation_prompt=True,
    )
    assert result.text == expected_text
    assert tokenizer.decode(token_ids) == expected_text

    # --- Correspondence between texts and mask -----------------------------

    # Split the token ids by mask value, preserving their original order.
    masked_in_ids = [token_id for flag, token_id in zip(mask, token_ids) if flag]
    masked_out_ids = [token_id for flag, token_id in zip(mask, token_ids) if not flag]

    decoded_completion = tokenizer.decode(masked_in_ids)
    for assistant_content in ("completion1 completion1", "COMPLETION2 COMPLETION2"):
        assert assistant_content in decoded_completion

    decoded_other = tokenizer.decode(masked_out_ids)
    for prompt_content in (
        "system system system",
        "user1 user1",
        "USER2 USER2",
        "user3 user3",
    ):
        assert prompt_content in decoded_other
