import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

import torch
import transformers
from transformers import AutoTokenizer, AutoFeatureExtractor
from transformers.utils import logging

from src.arguments import parse_args
from src.dataset.datacollator import DataCollatorForSpeechRecognition
from src.dataset.speech_recognition_dataset import SpeechRecognitionDataset
from src.model.audio_model.modeling_alignchat import AlignChatForConditionalGeneration
from src.trainer.alignchat_trainer import AlignChatRecognitionTrainer
from src.utils.utils import MODEL_DIR, QWEN2_START_TOKEN, QWEN2_END_TOKEN

# Module-level logger from transformers' logging utilities; its verbosity is
# configured per rank at the start of main().
logger = logging.get_logger(__name__)


def _init_distributed_logging():
    """Pin this process to its local GPU (when launched distributed) and configure logging.

    Returns:
        The ``LOCAL_RANK`` of this process, or 0 when the variable is unset.
    """
    if 'LOCAL_RANK' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_rank)
    else:
        local_rank = 0

    transformers.utils.logging.set_verbosity_info()
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()
    if local_rank != 0:
        # Only local rank 0 keeps info-level output; other ranks report errors only.
        # NOTE(review): on multi-node runs, local rank 0 of *each* node still logs.
        transformers.utils.logging.set_verbosity_error()
    return local_rank


def _load_frozen_module(path, device):
    """Load an object saved with ``torch.save`` onto *device* and freeze its parameters.

    Args:
        path: File produced by ``torch.save``.
        device: Target passed to ``torch.load`` as ``map_location``.

    Returns:
        The loaded object, with ``requires_grad_(False)`` applied.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects. That is needed
    # to restore whole module objects, so only point this at trusted checkpoint files.
    loaded = torch.load(path, map_location=device, weights_only=False)
    loaded.requires_grad_(False)
    return loaded


def _configure_trainable_parameters(audio_model, args):
    """Apply the encoder/decoder freeze flags, then force the AlignChat-specific layers trainable."""
    if args.freeze_encoder:
        audio_model.freeze_encoder()

    if args.freeze_decoder:
        # Equivalent to setting requires_grad=False on every decoder parameter.
        audio_model.model.decoder.requires_grad_(False)

    # Must run after the decoder freeze: a2t_proj/t2a_proj live inside the decoder
    # and would otherwise stay frozen.
    audio_model.model.decoder.a2t_proj.requires_grad_(True)
    audio_model.model.decoder.t2a_proj.requires_grad_(True)
    audio_model.modality_projector.requires_grad_(True)


def _make_inputs_require_grad(_module, _inputs, output):
    """Forward hook that marks a frozen module's output as requiring grad.

    The embedding table is frozen, so its output would not require grad. With
    (reentrant) gradient checkpointing, a checkpointed segment whose inputs do not
    require grad yields no gradients for the trainable layers inside it.
    """
    output.requires_grad_(True)


def main():
    """Fine-tune AlignChat for speech recognition.

    Parses arguments with ``parse_args``. Loads the audio model from
    ``args.audio_model_path`` and loads the frozen ``embed_tokens`` / ``proj_out``
    objects, feature extractor and tokenizer from
    ``MODEL_DIR/alignchat/<args.audio_model_type>``. Applies the freeze flags, builds
    the speech-recognition dataset and collator, then runs
    ``AlignChatRecognitionTrainer.train`` (optionally resuming from
    ``args.resume_from_checkpoint``).
    """
    _init_distributed_logging()

    args = parse_args()
    device = args.device

    # The model is told to write a loss log inside output_dir below, so make sure the
    # directory exists first. exist_ok=True keeps this safe when all ranks run it.
    os.makedirs(args.output_dir, exist_ok=True)

    # load audio model
    audio_model = AlignChatForConditionalGeneration.from_pretrained(args.audio_model_path, device_map=device)
    audio_model = audio_model.train()
    # NOTE(review): every rank sets the same log path; confirm _set_logging_file
    # (or the model's logging) writes from a single rank only.
    audio_model._set_logging_file(os.path.join(args.output_dir, 'multi_loss.log'))

    # Directory holding the auxiliary components for this audio model type.
    component_dir = os.path.join(MODEL_DIR, f'alignchat/{args.audio_model_type}')

    # load frozen embeddings / output projection
    embed_tokens = _load_frozen_module(os.path.join(component_dir, 'embed_tokens.pt'), device)
    proj_out = _load_frozen_module(os.path.join(component_dir, 'proj_out.pt'), device)

    # load processor
    extractor = AutoFeatureExtractor.from_pretrained(component_dir)
    tokenizer = AutoTokenizer.from_pretrained(component_dir, add_eos_token=False, use_fast=False)

    _configure_trainable_parameters(audio_model, args)

    if args.gradient_checkpointing:
        embed_tokens.register_forward_hook(_make_inputs_require_grad)

    # create dataset
    train_dataset = SpeechRecognitionDataset(
        dataset_paths=args.dataset_paths,
        ratios=args.dataset_ratios,
        start_token=QWEN2_START_TOKEN,
        end_token=QWEN2_END_TOKEN,
        extractor=extractor,
        tokenizer=tokenizer,
        split=args.dataset_split,
    )
    data_collator = DataCollatorForSpeechRecognition(extractor=extractor, tokenizer=tokenizer)

    # create trainer
    trainer = AlignChatRecognitionTrainer(
        args=args,
        model=audio_model,
        train_dataset=train_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        embed_tokens=embed_tokens,
        proj_out=proj_out,
    )

    trainer.train(args.resume_from_checkpoint)
    logger.info("-*-" * 25 + "\nTraining completed! Congratulations!")


# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
