# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import unittest

import numpy as np

from transformers import AutoTokenizer, BarkProcessor
from transformers.testing_utils import require_torch, slow


@require_torch
class BarkProcessorTest(unittest.TestCase):
    def setUp(self):
        self.checkpoint = "suno/bark-small"
        self.tmpdirname = tempfile.mkdtemp()
        self.voice_preset = "en_speaker_1"
        self.input_string = "This is a test string"
        self.speaker_embeddings_dict_path = "speaker_embeddings_path.json"
        self.speaker_embeddings_directory = "speaker_embeddings"

    def get_tokenizer(self, **kwargs):
        return AutoTokenizer.from_pretrained(self.checkpoint, **kwargs)

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def test_save_load_pretrained_default(self):
        tokenizer = self.get_tokenizer()

        processor = BarkProcessor(tokenizer=tokenizer)

        processor.save_pretrained(self.tmpdirname)
        processor = BarkProcessor.from_pretrained(self.tmpdirname)

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())

    @slow
    def test_save_load_pretrained_additional_features(self):
        processor = BarkProcessor.from_pretrained(
            pretrained_processor_name_or_path=self.checkpoint,
            speaker_embeddings_dict_path=self.speaker_embeddings_dict_path,
        )
        processor.save_pretrained(
            self.tmpdirname,
            speaker_embeddings_dict_path=self.speaker_embeddings_dict_path,
            speaker_embeddings_directory=self.speaker_embeddings_directory,
        )

        tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")

        processor = BarkProcessor.from_pretrained(
            self.tmpdirname,
            self.speaker_embeddings_dict_path,
            bos_token="(BOS)",
            eos_token="(EOS)",
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())

    def test_speaker_embeddings(self):
        processor = BarkProcessor.from_pretrained(
            pretrained_processor_name_or_path=self.checkpoint,
            speaker_embeddings_dict_path=self.speaker_embeddings_dict_path,
        )

        seq_len = 35
        nb_codebooks_coarse = 2
        nb_codebooks_total = 8

        voice_preset = {
            "semantic_prompt": np.ones(seq_len),
            "coarse_prompt": np.ones((nb_codebooks_coarse, seq_len)),
            "fine_prompt": np.ones((nb_codebooks_total, seq_len)),
        }

        # test providing already loaded voice_preset
        inputs = processor(text=self.input_string, voice_preset=voice_preset)

        processed_voice_preset = inputs["history_prompt"]
        for key in voice_preset:
            self.assertListEqual(voice_preset[key].tolist(), processed_voice_preset.get(key, np.array([])).tolist())

        # test loading voice preset from npz file
        tmpfilename = os.path.join(self.tmpdirname, "file.npz")
        np.savez(tmpfilename, **voice_preset)
        inputs = processor(text=self.input_string, voice_preset=tmpfilename)
        processed_voice_preset = inputs["history_prompt"]

        for key in voice_preset:
            self.assertListEqual(voice_preset[key].tolist(), processed_voice_preset.get(key, np.array([])).tolist())

        # test loading voice preset from the hub
        inputs = processor(text=self.input_string, voice_preset=self.voice_preset)

    def test_tokenizer(self):
        tokenizer = self.get_tokenizer()

        processor = BarkProcessor(tokenizer=tokenizer)

        encoded_processor = processor(text=self.input_string)

        encoded_tok = tokenizer(
            self.input_string,
            padding="max_length",
            max_length=256,
            add_special_tokens=False,
            return_attention_mask=True,
            return_token_type_ids=False,
        )

        for key in encoded_tok.keys():
            self.assertListEqual(encoded_tok[key], encoded_processor[key].squeeze().tolist())
