import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

import torch
import transformers
from transformers import AutoTokenizer, AutoFeatureExtractor
from transformers.utils import logging

from src.arguments import parse_args
from src.model.alignchat_model import AlignChatForEndToEndResponse
from src.dataset.speech_response_dataset import SpeechResponseDataset
from src.dataset.datacollator import DataCollatorForSpeechResponse
from src.trainer.alignchat_trainer import AlignChatResponseTrainer
from src.utils.utils import MODEL_DIR, QWEN2_START_TOKEN, QWEN2_END_TOKEN, get_text_model_path

logger = logging.get_logger(__name__)


def _configure_logging(local_rank):
    """Configure transformers logging so that only local rank 0 emits INFO output.

    Args:
        local_rank: this process's local rank (0 when not launched distributed).
    """
    transformers.utils.logging.set_verbosity_info()
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()
    if local_rank != 0:  # output only on rank 0
        transformers.utils.logging.set_verbosity_error()


def _configure_trainable_parameters(model, args):
    """Freeze / unfreeze submodules of ``model`` according to ``args``.

    - The text model is always frozen and switched to eval mode.
    - The audio encoder is frozen when ``args.freeze_encoder`` is set.
    - The audio decoder is frozen when ``args.freeze_decoder`` is set.
    - The newly added projection layers (``a2t_proj``, ``t2a_proj``,
      ``modality_projector``) are always re-enabled for training, even if the
      decoder was frozen just before.

    Args:
        model: the ``AlignChatForEndToEndResponse`` instance to configure.
        args: parsed arguments providing ``freeze_encoder`` / ``freeze_decoder``.
    """
    # freeze text model; Module.requires_grad_ recurses into every parameter,
    # so no extra per-parameter loop is needed.
    model.text_model.eval()
    model.text_model.requires_grad_(False)

    # freeze encoder
    if args.freeze_encoder:
        model.audio_model.freeze_encoder()

    # freeze decoder
    if args.freeze_decoder:
        model.audio_model.model.decoder.requires_grad_(False)

    # make sure new layers are trainable (must run after the decoder freeze,
    # since a2t_proj / t2a_proj live inside the decoder)
    model.audio_model.model.decoder.a2t_proj.requires_grad_(True)
    model.audio_model.model.decoder.t2a_proj.requires_grad_(True)
    model.audio_model.modality_projector.requires_grad_(True)


def main():
    """Fine-tune AlignChat for end-to-end speech response generation.

    Reads the configuration via ``parse_args``, builds the AlignChat model on
    top of a frozen text model, prepares the speech-response dataset and
    collator, and runs ``AlignChatResponseTrainer`` (optionally resuming from
    ``args.resume_from_checkpoint``).
    """
    if 'LOCAL_RANK' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_rank)
    else:
        local_rank = 0

    _configure_logging(local_rank)

    args = parse_args()
    device = args.device

    audio_model_type = args.audio_model_type
    audio_model_path = args.audio_model_path

    # the text backbone type is the last '-'-separated component of the audio model type
    text_model_type = audio_model_type.split('-')[-1]
    text_model_path = get_text_model_path(text_model_type)

    # load model
    # NOTE(review): flash_attention_2 generally requires fp16/bf16 weights; when
    # args.bf16 is unset torch_dtype is None — confirm that path is supported.
    torch_dtype = torch.bfloat16 if args.bf16 else None
    model = AlignChatForEndToEndResponse(
        audio_model_path=audio_model_path,
        audio_model_type=audio_model_type,
        text_model_path=text_model_path,
        device_map=device,
        torch_dtype=torch_dtype,
        attn_implementation='flash_attention_2',
    )

    model.audio_model.train()
    # The output directory may not exist yet at this point (nothing above
    # creates it); create it in case _set_logging_file opens the file eagerly.
    os.makedirs(args.output_dir, exist_ok=True)
    model._set_logging_file(os.path.join(args.output_dir, 'multi_loss.log'))

    # NOTE(review): if the trainer calls model.train() during training, the
    # text model is put back into train mode (dropout active) unless
    # AlignChatForEndToEndResponse overrides train(); confirm.
    _configure_trainable_parameters(model, args)

    # load processor: feature extractor and tokenizer share one directory
    processor_path = os.path.join(MODEL_DIR, f'alignchat/{audio_model_type}')
    extractor = AutoFeatureExtractor.from_pretrained(processor_path)
    tokenizer = AutoTokenizer.from_pretrained(processor_path, add_eos_token=False, use_fast=False)

    # make sure inputs require gradient
    if args.gradient_checkpointing:
        # Presumably needed because the embeddings feeding the checkpointed
        # layers may be frozen; without a grad-requiring input, checkpointed
        # segments would not backpropagate.
        def make_inputs_require_grad(module, inputs, output):
            output.requires_grad_(True)

        model.embed_tokens.register_forward_hook(make_inputs_require_grad)

    # create dataset
    train_dataset = SpeechResponseDataset(
        dataset_paths=args.dataset_paths,
        ratios=args.dataset_ratios,
        start_token=QWEN2_START_TOKEN,
        end_token=QWEN2_END_TOKEN,
        extractor=extractor,
        tokenizer=tokenizer,
        split=args.dataset_split,
    )
    data_collator = DataCollatorForSpeechResponse(extractor=extractor, tokenizer=tokenizer)

    # create trainer
    trainer = AlignChatResponseTrainer(
        args=args,
        model=model,
        train_dataset=train_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    trainer.train(args.resume_from_checkpoint)
    logger.info("-*-" * 25 + "\nTraining completed! Congratulations!")


if __name__ == "__main__":
    main()
