# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import torch
import numpy as np
from examples.textless_nlp.gslm.unit2speech.tacotron2.text import (
    EOS_TOK,
    SOS_TOK,
    code_to_sequence,
    text_to_sequence,
)
from examples.textless_nlp.gslm.unit2speech.tacotron2.utils import (
    load_code_dict,
)


class TacotronInputDataset:
    def __init__(self, hparams, append_str=""):
        self.is_text = getattr(hparams, "text_or_code", "text") == "text"
        if not self.is_text:
            self.code_dict = load_code_dict(hparams.code_dict)
            self.code_key = hparams.code_key
        self.add_sos = hparams.add_sos
        self.add_eos = hparams.add_eos
        self.collapse_code = hparams.collapse_code
        self.append_str = append_str

    def process_code(self, inp_str):
        inp_toks = inp_str.split()
        if self.add_sos:
            inp_toks = [SOS_TOK] + inp_toks
        if self.add_eos:
            inp_toks = inp_toks + [EOS_TOK]
        return code_to_sequence(inp_toks, self.code_dict, self.collapse_code)

    def process_text(self, inp_str):
        return text_to_sequence(inp_str, ["english_cleaners"])

    def get_tensor(self, inp_str):
        # uid, txt, inp_str = self._get_data(idx)
        inp_str = inp_str + self.append_str
        if self.is_text:
            inp_toks = self.process_text(inp_str)
        else:
            inp_toks = self.process_code(inp_str)
        return torch.from_numpy(np.array(inp_toks)).long()

    def __len__(self):
        return len(self.data)
