from src.trainer.trainer import LLAMA_Trainer, LLAVA_Trainer, LLAVA_Trainer_NO_GIST, \
    LLAMA_Trainer_NO_GIST, QWen_Trainer
from src.models.llama_models import LlamaPromptTuningForClassification, llama_prompter
from src.models.llava_models import LLaVAPromptTuningForClassification, llava_prompter
from src.models.llava_models_no_gist import LLaVAPromptTuningForClassification_NOGIST
from src.models.llama_models_no_gist import LlamaPromptTuningForClassification_NOGIST
from src.models.qwenvl_models import Qwen2VLPromptTuningForClassification, qwenvl_prompter
from src.data.datasets import construct_detection_dataset, DataCollator
import torch
from transformers import TrainingArguments, AutoTokenizer, \
    LlavaNextProcessor, AutoProcessor
import os
import json
import argparse
from src.utils.utils import DictToObject
# Explicitly enable parallelism in the Hugging Face `tokenizers` library for
# this process (and any child processes it spawns).
os.environ.update(TOKENIZERS_PARALLELISM='true')
def get_args():
    """Parse the command line for the training script.

    Returns:
        argparse.Namespace with one attribute, ``config_path``: the JSON
        training config to load. Defaults to the LLaVA-Mistral no-gist config.
    """
    # Another config that has been used here:
    #   configs/train_configs/llava1.6-7b-0-no-gist.json
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--config_path",
        type=str,
        default="configs/train_configs/llava-mistral-no-gist.json",
    )
    parsed = arg_parser.parse_args()
    return parsed

def _load_model_and_processor(cfg, device_map):
    """Load the processor/tokenizer, model and prompter selected by ``cfg.model_zoo``.

    Args:
        cfg: Config object exposing ``model_zoo`` and ``pretrained_model_path``.
        device_map: ``device_map`` forwarded to ``from_pretrained``.

    Returns:
        ``(model, processor, prompter)`` tuple.

    Raises:
        ValueError: if ``cfg.model_zoo`` has no loading branch here.
    """
    if cfg.model_zoo == 'llama':
        processor = AutoTokenizer.from_pretrained(cfg.pretrained_model_path)
        model = LlamaPromptTuningForClassification.from_pretrained(
            cfg.pretrained_model_path,
            # attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            device_map=device_map,
        )
        model.add_special_token(processor)
        prompter = llama_prompter
    elif cfg.model_zoo == 'llava':
        processor = LlavaNextProcessor.from_pretrained(cfg.pretrained_model_path)
        processor.tokenizer.padding_side = "left"
        model = LLaVAPromptTuningForClassification.from_pretrained(
            cfg.pretrained_model_path,
            # attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            device_map=device_map,
        )
        model.padding_side = "left"
        model.add_special_token(processor)
        # Built after add_special_token so it sees the updated processor.
        prompter = llava_prompter(processor)
    else:
        # NOTE(review): 'llama_ng', 'llava_ng' and 'qwen2vl' have trainer
        # classes in main() and their model classes are imported at the top of
        # the file, but no loading code exists for them yet. Previously these
        # values crashed later with a NameError on `processor`/`prompter`.
        raise ValueError(
            f"Unsupported model_zoo {cfg.model_zoo!r}: only 'llama' and "
            f"'llava' have model-loading code in this script."
        )
    return model, processor, prompter


def _compute_metrics(eval_pred):
    """Return classification accuracy for a transformers ``EvalPrediction``.

    Predictions are reduced with argmax over the last axis and compared
    element-wise with ``label_ids``.
    """
    print(f'eval_pred: {eval_pred}')
    predictions, labels = eval_pred.predictions, eval_pred.label_ids
    preds = predictions.argmax(-1)  # argmax over class scores for classification
    # Plain float keeps the metrics dict trivially JSON-serializable.
    return {"accuracy": float((preds == labels).mean())}


def main(args):
    """Build model, datasets and trainer from a JSON config, train, then save.

    Args:
        args: Parsed CLI namespace; only ``args.config_path`` is read. The JSON
            file is wrapped with ``DictToObject`` and must provide
            ``model_zoo``, ``pretrained_model_path``, ``dataset_path``,
            ``log_dir``, ``logging_steps``, ``learning_rate``,
            ``train_batch_size``, ``eval_batch_size``, ``num_train_epochs`` and
            ``model_save_dir``.

    Raises:
        ValueError: if ``cfg.model_zoo`` is not supported.
    """
    # Context manager so the config file handle is closed immediately.
    with open(args.config_path, encoding="utf-8") as config_file:
        cfg = DictToObject(json.load(config_file))

    # Pin the whole model to this process's GPU; LOCAL_RANK is set by
    # distributed launchers and absent for a plain single-process run.
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
    model, processor, prompter = _load_model_and_processor(cfg, device_map)

    data_collator = DataCollator(processor, max_length=3072)
    train_dataset, test_dataset = construct_detection_dataset(
        cfg.dataset_path,
        prompter=prompter,
        image_root='datasets/VQA/coco/train2017',
    )

    # Freeze all parameters except the classification head and the special
    # tokens (per the original author's note on configure_model).
    model.configure_model()

    training_args = TrainingArguments(
        output_dir=cfg.log_dir,
        # NOTE(review): newer transformers releases rename this to
        # `eval_strategy`; kept for compatibility with the pinned version.
        evaluation_strategy="steps",
        eval_steps=cfg.logging_steps,
        learning_rate=cfg.learning_rate,
        per_device_train_batch_size=cfg.train_batch_size,
        per_device_eval_batch_size=cfg.eval_batch_size,
        dataloader_num_workers=0,
        gradient_accumulation_steps=2,
        num_train_epochs=cfg.num_train_epochs,  # can be raised to 3
        weight_decay=1e-5,
        logging_dir=cfg.log_dir,
        logging_steps=cfg.logging_steps,
        bf16=True,
        deepspeed='configs/deepspeed_configs/zero2.json'
    )

    trainer_classes = {
        'llama': LLAMA_Trainer,
        'llama_ng': LLAMA_Trainer_NO_GIST,
        'llava': LLAVA_Trainer,
        'qwen2vl': QWen_Trainer,
        'llava_ng': LLAVA_Trainer_NO_GIST,
    }
    trainer_cls = trainer_classes.get(cfg.model_zoo)
    if trainer_cls is None:
        raise ValueError(f"No trainer registered for model_zoo {cfg.model_zoo!r}")
    trainer = trainer_cls(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        # processing_class=processor,
        compute_metrics=_compute_metrics,
    )

    trainer.train()
    # trainer.evaluate()

    # Every rank of a multi-process launch reaches this point; only the global
    # main process writes, so ranks do not overwrite the same files
    # concurrently. Assumes zero2.json configures ZeRO stage 2, which does not
    # partition parameters, so rank 0 holds full weights — confirm.
    if trainer.is_world_process_zero():
        model.save_pretrained(cfg.model_save_dir)

if __name__ == "__main__":
    # Script entry point: parse the CLI, then run training.
    main(get_args())




