import argparse
import os

from transformers import AutoConfig

from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration


if __name__ == "__main__":
    # Build a tiny, randomly-initialised Parler-TTS checkpoint: a small 4-layer
    # decoder combined with the pretrained T5-small text encoder and the
    # Encodec 24 kHz audio model. Two artifacts are written:
    #   <save_directory>/decoder     - the standalone ParlerTTSForCausalLM decoder
    #   <save_directory>/tiny-model  - the composed ParlerTTSForConditionalGeneration
    parser = argparse.ArgumentParser()
    parser.add_argument("save_directory", type=str, help="Directory where to save the model and the decoder.")
    args = parser.parse_args()

    text_model = "google-t5/t5-small"
    encodec_version = "facebook/encodec_24khz"

    # Only the configs are fetched here; they supply the codebook size and the
    # text vocab size used below.
    t5 = AutoConfig.from_pretrained(text_model)
    encodec = AutoConfig.from_pretrained(encodec_version)

    encodec_vocab_size = encodec.codebook_size
    num_codebooks = 8
    print("num_codebooks", num_codebooks)

    # Special decoder token ids, placed right after the Encodec codebook ids
    # (assumed to occupy [0, codebook_size)). Each id is defined once so that
    # the decoder config, the generation config and the model config cannot
    # drift apart.
    pad_token_id = encodec_vocab_size
    eos_token_id = encodec_vocab_size  # EOS deliberately shares the pad id
    # NOTE(review): bos_token_id == vocab_size below, i.e. one past the last
    # regular id. This relies on the decoder's token embedding reserving an
    # extra slot beyond vocab_size; confirm against ParlerTTSDecoder.
    bos_token_id = encodec_vocab_size + 1  # also used as decoder_start_token_id

    # Default generation length in codec steps: 30 * frame_rate, i.e. ~30 s of
    # audio assuming frame_rate is expressed in frames per second.
    max_duration_in_seconds = 30

    decoder_dir = os.path.join(args.save_directory, "decoder")
    model_dir = os.path.join(args.save_directory, "tiny-model")

    decoder_config = ParlerTTSDecoderConfig(
        vocab_size=encodec_vocab_size + 1,
        max_position_embeddings=2048,
        num_hidden_layers=4,
        ffn_dim=512,
        num_attention_heads=8,
        layerdrop=0.0,
        use_cache=True,
        activation_function="gelu",
        hidden_size=512,
        dropout=0.0,
        attention_dropout=0.0,
        activation_dropout=0.0,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        bos_token_id=bos_token_id,
        num_codebooks=num_codebooks,
    )

    decoder = ParlerTTSForCausalLM(decoder_config)

    # The decoder is saved to disk first because from_sub_models_pretrained
    # below takes it as a path, like the other two sub-models.
    decoder.save_pretrained(decoder_dir)

    model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
        text_encoder_pretrained_model_name_or_path=text_model,
        audio_encoder_pretrained_model_name_or_path=encodec_version,
        decoder_pretrained_model_name_or_path=decoder_dir,
        vocab_size=t5.vocab_size,
    )

    # Set the appropriate bos/pad/eos token ids for generation.
    model.generation_config.decoder_start_token_id = bos_token_id
    model.generation_config.pad_token_id = pad_token_id
    model.generation_config.eos_token_id = eos_token_id

    # Other default generation parameters.
    model.generation_config.max_length = int(max_duration_in_seconds * model.audio_encoder.config.frame_rate)
    model.generation_config.do_sample = True

    # Mirror the special ids on the top-level model config.
    model.config.pad_token_id = pad_token_id
    model.config.decoder_start_token_id = bos_token_id

    model.save_pretrained(model_dir)
