import os
import unittest

from swift.llm import (ModelType, get_default_template_type,
                       get_model_tokenizer, get_template, inference,
                       inference_stream, limit_history_length, print_example)
from swift.utils import lower_bound, seed_everything


class TestLlmUtils(unittest.TestCase):
    """Tests for ``lower_bound``, the inference helpers, ``print_example``
    and ``limit_history_length``.

    ``test_inference``, ``test_print_example`` and
    ``test_limit_history_length`` call ``get_model_tokenizer`` for concrete
    model types; presumably the corresponding weights/tokenizers must be
    available locally or downloadable in the test environment.
    """

    def test_count_startswith(self):
        """``lower_bound`` returns the index where a leading -100 run ends."""
        arr = [-100] * 1000 + list(range(1000))
        # The predicate is False on the first 1000 entries and True after,
        # so the first index satisfying it is 1000.
        self.assertEqual(
            lower_bound(0, len(arr), lambda i: arr[i] != -100), 1000)

    def test_count_endswith(self):
        """``lower_bound`` returns the index where a trailing -100 run starts."""
        arr = list(range(1000)) + [-100] * 1000
        self.assertEqual(
            lower_bound(0, len(arr), lambda i: arr[i] == -100), 1000)

    def test_inference(self):
        """Streaming, generator-based and non-streaming inference agree.

        Sampling is enabled, so the RNG is re-seeded with the same value
        before each of the three calls; all three paths must then produce
        the same generated text and the same history.
        """
        model_type = ModelType.chatglm2_6b
        model, tokenizer = get_model_tokenizer(model_type)
        template_type = get_default_template_type(model_type)
        template = get_template(template_type, tokenizer)
        model.generation_config.max_length = 128
        model.generation_config.do_sample = True
        for query in ['你好', 'hello']:
            # Path 1: inference() with stream=True.
            seed_everything(42, True)
            print('stream=True')
            gen_text_stream, history = inference(
                model, template, query, stream=True, verbose=True)
            print(f'[GEN]: {gen_text_stream}')
            print(f'[HISTORY]: {history}')
            # Path 2: drain the inference_stream() generator and keep its
            # final (text, history) pair. The names are reset every query so
            # an empty generator fails the assertions below instead of
            # raising UnboundLocalError or reusing the previous query's values.
            seed_everything(42, True)
            gen_text_stream2, history2 = None, None
            for gen_text_stream2, history2 in inference_stream(
                    model, template, query):
                pass
            print(f'[GEN]: {gen_text_stream2}')
            print(f'[HISTORY]: {history2}')
            # Path 3: inference() with stream=False.
            seed_everything(42, True)
            print('stream=False')
            gen_text, history3 = inference(
                model, template, query, stream=False, verbose=True)
            print(f'[GEN]: {gen_text}')
            print(f'[HISTORY]: {history3}')
            # assertEqual shows the mismatching values on failure, unlike
            # assertTrue(a == b == c).
            self.assertEqual(gen_text_stream, gen_text_stream2)
            self.assertEqual(gen_text_stream2, gen_text)
            self.assertEqual(history, history2)
            self.assertEqual(history2, history3)

    def test_print_example(self):
        """``safe_tokenizer_decode`` renders each run of -100 as ``[-100 * k]``.

        ``print_example`` is only exercised to check it does not raise; its
        printed output is not asserted. The expected decoded strings depend
        on the chatglm3_6b tokenizer used here.
        """
        input_ids = [1000, 2000, 3000, 4000, 5000, 6000]
        _, tokenizer = get_model_tokenizer(
            ModelType.chatglm3_6b, load_model=False)
        from swift.llm.utils.utils import safe_tokenizer_decode
        # Two masked runs interleaved with real label ids.
        labels = [-100, -100, 1000, 2000, 3000, -100, -100, 4000, 5000, 6000]
        print_example({'input_ids': input_ids, 'labels': labels}, tokenizer)
        self.assertEqual(
            safe_tokenizer_decode(tokenizer, labels),
            '[-100 * 2]before States appe[-100 * 2]innov developingishes')
        # Fully masked labels.
        labels = [-100, -100, -100]
        print_example({'input_ids': input_ids, 'labels': labels}, tokenizer)
        self.assertEqual(
            safe_tokenizer_decode(tokenizer, labels), '[-100 * 3]')
        # No masking at all.
        labels = [1000, 2000, 3000, 4000, 5000, 6000]
        print_example({'input_ids': input_ids, 'labels': labels}, tokenizer)
        self.assertEqual(
            safe_tokenizer_decode(tokenizer, labels),
            'before States appe innov developingishes')

    def test_limit_history_length(self):
        """``limit_history_length`` splits history to fit a length budget."""
        model_type = ModelType.qwen_7b_chat
        _, tokenizer = get_model_tokenizer(model_type, load_model=False)
        template_type = get_default_template_type(model_type)
        template = get_template(template_type, tokenizer)
        query = '你' * 100
        # With no history, both returned parts are empty for either budget.
        for max_length in (128, 256):
            old_history, new_history = limit_history_length(
                template, query, [], max_length)
            self.assertEqual(len(old_history), 0)
            self.assertEqual(len(new_history), 0)
        self.assertGreater(len(tokenizer.encode(query)), 0)
        # Five identical turns with a budget of 600: the split is expected to
        # put 3 turns in old_history and 2 in new_history (presumably
        # old_history is the part that no longer fits — confirm).
        history = [[query, query] for _ in range(5)]
        old_history, new_history = limit_history_length(
            template, query, history, 600)
        self.assertEqual(len(old_history), 3)
        self.assertEqual(len(new_history), 2)


if __name__ == '__main__':
    # Run all tests in this module when the file is executed directly.
    unittest.main()
