# Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import os
import pathlib

import torch
import transformers
from transformers import Trainer
from llava.training_module.load_args import ModelArguments, DataArguments, TrainingArguments, OtherArguments, VisionModuleArguments, VisionModuleArguments_with_vm_prefix
from llava.training_module.dataset import make_supervised_data_module
from llava.training_module.utils import safe_save_model_for_hf_trainer, update_mm_projector
from llava.model.builder import load_pretrained_vision_module
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from llava.model.builder import load_pretrained_model_v2, load_pretrained_model, load_pretrained_model_v3

def train():
    """Parse CLI arguments, build model and data, then train/save/evaluate.

    Steps:
      1. Parse ``ModelArguments``, ``DataArguments``, ``TrainingArguments``,
         ``OtherArguments`` and the ``vm_``-prefixed vision-module arguments
         from the command line, then convert the latter into an
         un-prefixed ``VisionModuleArguments``.
      2. Load tokenizer / model / image processor via
         ``load_pretrained_model_v3(..., "detr-v2")``.
      3. Load the aligned vision module via ``load_pretrained_vision_module``.
      4. Build the supervised data module and run ``Trainer.train``,
         resuming when ``checkpoint-*`` entries already exist in
         ``training_args.output_dir``.
      5. Save the model and, if the data module provides an eval dataset,
         run ``Trainer.evaluate``.
    """
    ## load args
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, OtherArguments, VisionModuleArguments_with_vm_prefix))
    model_args, data_args, training_args, args, vm_args = parser.parse_args_into_dataclasses()

    # Rename every vision-module field by removing its leading "vm_" so the
    # values fit the un-prefixed VisionModuleArguments dataclass. Only the
    # prefix is stripped; str.replace would also delete a "vm_" that occurs
    # later inside a field name.
    vm_prefix = 'vm_'
    vm_args = VisionModuleArguments(**{
        (k[len(vm_prefix):] if k.startswith(vm_prefix) else k): v
        for k, v in vars(vm_args).items()
    })

    ## load model
    model_base = args.model_base
    # model_path / model_name are only consumed by the alternative loaders
    # kept (commented out) below.
    model_path = model_args.model_name_or_path if model_base is None else args.model_path
    model_name = get_model_name_from_path(model_path)

    # Alternative loaders:
    # tokenizer, model0, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
    # tokenizer, model, image_processor, context_len = load_pretrained_model_v2(model_args, data_args, training_args, model_path, model_base, model_name)
    tokenizer, model, image_processor, context_len = load_pretrained_model_v3(model_args, data_args, training_args, "detr-v2")
    data_args.image_processor = image_processor

    ## load aligned vision module
    # NOTE(review): none of these return values are used later in this
    # function — confirm whether they should be wired into the model/trainer.
    aligned_vision_tokenizer, aligned_vision_model, aligned_vision_image_processor, aligned_vision_context_len = load_pretrained_vision_module(model_args, vm_args, data_args, training_args)

    ## load data
    print("loading data")
    data_module = make_supervised_data_module(tokenizer, data_args)

    # Report which layers are trainable and the overall model size.
    from llava.training_module.utils import print_trainable_layers, print_model_size
    print_trainable_layers(model)
    print_model_size(model)

    trainer = Trainer(model=model,
                      args=training_args,
                      **data_module)

    ### train
    print("start training")
    # Glob once; the same list drives both the log line and the resume decision.
    checkpoints = list(pathlib.Path(training_args.output_dir).glob("checkpoint-*"))
    print(checkpoints)

    if checkpoints:
        print("resume from checkpoint")
        # With True, Trainer resumes from the last checkpoint in output_dir.
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()
    print("finish training")

    #### save model
    safe_save_model_for_hf_trainer(trainer=trainer,
                                   output_dir=training_args.output_dir)

    #### evaluate
    # .get() tolerates a data module that omits the "eval_dataset" key.
    if data_module.get('eval_dataset') is not None:
        print("start evaluating")
        metrics = trainer.evaluate()
        print(metrics)
        print("finish evaluating")




if __name__ == "__main__":
    # Script entry point: route wandb output into the directory that holds
    # this file, then launch training.
    from llava.training_module.utils import set_wandb_dir

    set_wandb_dir(os.path.dirname(os.path.abspath(__file__)))
    train()
