from omegaconf import OmegaConf
import argparse

config_parser = argparse.ArgumentParser()
config_parser.add_argument("-c", "--config_path", type=str, default="v2/config/config-llava-mistral-eval.yaml")
config_parser.add_argument("-g", "--gpus", type=str, default="0,1,2,3")
config_args = config_parser.parse_args()
config_path = config_args.config_path
config = OmegaConf.load(config_path)
import os
# Set before torch / transformers are imported below so the GPU restriction is
# already in place when CUDA is first initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = config_args.gpus

import sys
if not config.transformers_args.my_lora:
    # Move this script's directory to the END of sys.path so that modules found
    # elsewhere on the path (e.g. site-packages) take precedence over same-named
    # modules next to this file.
    # NOTE(review): presumably the script directory holds modified copies of
    # transformers/peft that should only be used when my_lora is set — confirm.
    # Absolute paths are compared and every matching entry is dropped, because a
    # plain sys.path.remove(dirname(__file__)) raises ValueError whenever the
    # entry is absent or spelled differently (relative vs absolute).
    script_dir = os.path.dirname(os.path.abspath(__file__))
    sys.path[:] = [p for p in sys.path if os.path.abspath(p) != script_dir]
    sys.path.append(script_dir)
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor
import transformers
from peft import LoraConfig, get_peft_model, PeftModel
import peft

import copy
import torch.distributed
from dataclasses import dataclass, field
import json
import evaluate

from hook_model import HookModel
import torch
import datasets
from PIL import Image
import copy
# from peft import prepare_model_for_kbit_training
from datetime import datetime
from arguments import ModelArguments, DataArguments, TrainingArguments
import numpy as np
import sys
import bitsandbytes as bnb
import re
import jieba
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge import Rouge
from nltk.stem import PorterStemmer
from tqdm import tqdm
import deepspeed
import shutil
import pandas as pd
import logging
import random
from vlmeval.dataset import ImageMCQDataset, MathVision
import string
from transformers import AutoModel

# Enable cuDNN benchmark mode globally (see the torch.backends.cudnn docs).
torch.backends.cudnn.benchmark = True

# Names of datasets stored on local disk -> their root directory. build_dataset
# passes the mapped directory to the loader method of the same name.
DATASETS_MAPPING = dict(
    slake_vqa="data/Slake1.0",
    iuxray_rg="data/iu-xray",
    ocrbench="data/ocrbench",
)

# from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live

def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
    r"""
    Prepare ``model`` (in place) before running a training.

    Steps, in order:
        1. Freeze every parameter (``requires_grad = False``).
        2. If the model reports being loaded in 8-bit or 4-bit
           (``is_loaded_in_8bit`` / ``is_loaded_in_4bit``), upcast every fp16/bf16
           parameter to fp32.
        3. Cast every submodule whose qualified name contains ``"norm"`` to fp32.
           This step runs regardless of quantization.
        4. If quantized and ``use_gradient_checkpointing`` is true, make the input
           embeddings' output require grad and enable gradient checkpointing.

    Args:
        model (`transformers.PreTrainedModel`):
            The loaded model from `transformers`.
        use_gradient_checkpointing (`bool`, defaults to `True`):
            Only honoured for 8-bit / 4-bit models.

    Returns:
        The same ``model`` object.
    """
    is_quantized = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
    half_precision = (torch.float16, torch.bfloat16)

    for param in model.parameters():
        # Freeze the base model's layers.
        param.requires_grad = False
        # Cast all non INT8/INT4 (i.e. half-precision) parameters to fp32.
        if is_quantized and param.dtype in half_precision:
            param.data = param.data.to(torch.float32)

    for qualified_name, submodule in model.named_modules():
        if "norm" in qualified_name:
            # nn.Module.to() converts in place.
            submodule.to(torch.float32)

    if not (is_quantized and use_gradient_checkpointing):
        return model

    # For backward compatibility with models lacking enable_input_require_grads.
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        def _force_output_requires_grad(module, _input, output):
            output.requires_grad_(True)

        model.get_input_embeddings().register_forward_hook(_force_output_requires_grad)
    # Enable gradient checkpointing for memory efficiency.
    model.gradient_checkpointing_enable()
    return model

class DualOutput:
    """Text stream that tees every write to the original stdout and a log file.

    ``sys.stdout`` and ``sys.stderr`` are captured at construction time; writes go
    to the captured stdout and to ``file_name`` (opened in append mode). Any
    attribute not defined here (``isatty``, ``encoding``, ``fileno``, ...) is
    forwarded to the captured stdout. This lets the object replace
    ``sys.stdout``/``sys.stderr`` even for callers that probe those attributes.
    """

    def __init__(self, file_name):
        # UTF-8 so non-ASCII output (e.g. Chinese text) can always be logged,
        # independent of the platform's locale encoding.
        self.file = open(file_name, "a", encoding="utf-8")
        self.stdout = sys.stdout
        self.stderr = sys.stderr

    def write(self, message):
        """Write ``message`` to both sinks; return the number of characters written."""
        self.stdout.write(message)
        self.file.write(message)
        return len(message)

    def flush(self):
        """Flush both sinks."""
        self.stdout.flush()
        self.file.flush()

    def close(self):
        """Close the log file only; the wrapped stdout is left open."""
        self.file.close()

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Guard our own fields so a
        # partially-initialised instance cannot recurse forever.
        if name in ("file", "stdout", "stderr"):
            raise AttributeError(name)
        return getattr(self.stdout, name)
    
def set_dual_output(output_txt_path):
    """Tee stdout/stderr into ``output_txt_path`` and route root logging there.

    Both ``sys.stdout`` and ``sys.stderr`` are replaced by one ``DualOutput`` on
    ``output_txt_path``. The root logger is then configured at DEBUG level with
    a bare ``%(message)s`` format, writing to that same object.
    ``logging.basicConfig`` does nothing if the root logger already has handlers.
    """
    tee = DualOutput(output_txt_path)
    sys.stdout = sys.stderr = tee
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(message)s",
        handlers=[logging.StreamHandler(tee)],
    )

def set_sft_trainable_parameters(model, training_args):
    """For full (non-PEFT) fine-tuning, train only the targeted submodules.

    When ``training_args.use_peft`` is truthy this is a no-op. Otherwise every
    parameter is frozen first. Each submodule whose qualified name matches
    ``training_args.lora_target_modules`` is then unfrozen. A ``str`` is treated
    as a regex that must match the whole name; any other iterable is a list of
    name suffixes.
    """
    if training_args.use_peft:
        return
    model.requires_grad_(False)
    targets = training_args.lora_target_modules

    def _is_target(module_name):
        if isinstance(targets, str):
            return re.fullmatch(targets, module_name) is not None
        return any(module_name.endswith(suffix) for suffix in targets)

    module_names = [module_name for module_name, _ in model.named_modules()]
    for module_name in module_names:
        if _is_target(module_name):
            model.get_submodule(module_name).requires_grad_(True)

def _count_trainable_and_total(module):
    """Return ``(trainable, total)`` element counts over ``module.parameters()``."""
    trainable = total = 0
    for param in module.parameters():
        numel = param.numel()
        total += numel
        if param.requires_grad:
            trainable += numel
    return trainable, total


def calc_submodule_trainable_parameters(model):
    """Summarise trainable parameters per direct child of ``model`` and overall.

    For each direct child ``name`` the result contains:
      * ``name + "_trainable_params"``: trainable element count divided by 1e9;
      * ``name + "_trainable_percent"``: trainable / total as a fraction
        (not multiplied by 100), or ``0`` when the child has no parameters.
    ``"total_trainable_params"`` and ``"total_trainable_percent"`` give the same
    two figures for the whole model.

    Each module's parameters are scanned once. Previously they were iterated up
    to four times per module.
    """
    result = {}
    for name, module in model.named_children():
        trainable, total = _count_trainable_and_total(module)
        result[name + "_trainable_params"] = trainable / 1e9
        result[name + "_trainable_percent"] = trainable / total if total != 0 else 0
    trainable, total = _count_trainable_and_total(model)
    result["total_trainable_params"] = trainable / 1e9
    result["total_trainable_percent"] = trainable / total if total != 0 else 0
    return result


class build_dataset:
    """Load, format and concatenate every dataset named in ``data_args.data_dir``.

    Names starting with ``"opencompass"`` or ``"hf"`` are loaded by calling the
    zero-argument method of the same name. Any other name is resolved to a
    directory through ``DATASETS_MAPPING``, which is passed to the method of the
    same name. Each loader returns ``(train, test)`` ``datasets.Dataset`` objects.
    Their rows carry at least ``image_path``, ``question``, ``answer``,
    ``prompt_w_answer`` and ``prompt_wo_answer`` (plus ``image`` when the source
    provides decoded images). The per-dataset splits are concatenated into
    ``formatted_train_dataset`` / ``formatted_test_dataset``.
    """

    def __init__(self, data_args, processor):
        self.data_args = data_args
        self.data_dirs = data_args.data_dir
        self.formatted_train_dataset_list = []
        self.formatted_test_dataset_list = []
        self.processor = processor
        for dataset_name in self.data_dirs:
            if dataset_name.startswith(("opencompass", "hf")):
                # Remote datasets: the loader takes no arguments.
                formatted_train_dataset, formatted_test_dataset = getattr(
                    self, dataset_name
                )()
            else:
                # Local datasets: the loader receives the on-disk directory.
                dataset_dir = DATASETS_MAPPING[dataset_name]
                formatted_train_dataset, formatted_test_dataset = getattr(
                    self, dataset_name
                )(dataset_dir)
            self.formatted_train_dataset_list.append(formatted_train_dataset)
            self.formatted_test_dataset_list.append(formatted_test_dataset)
        self.formatted_train_dataset = datasets.concatenate_datasets(
            self.formatted_train_dataset_list
        )
        self.formatted_test_dataset = datasets.concatenate_datasets(
            self.formatted_test_dataset_list
        )

    def get_dataset(self):
        """Return ``(formatted_train_dataset, formatted_test_dataset)``."""
        return self.formatted_train_dataset, self.formatted_test_dataset

    @staticmethod
    def _head(dataset, n):
        """Return the first ``n`` rows of ``dataset``, or all rows if it is shorter.

        A bare ``dataset.select(range(n))`` fails when the dataset has fewer
        than ``n`` rows.
        """
        return dataset.select(range(min(n, len(dataset))))

    def _chat_prompts(self, text, answer):
        """Render the chat template for one (image + ``text``) user turn.

        Returns ``(prompt_wo_answer, prompt_w_answer)``. The generation prompt is
        appended only to the answer-less prompt. Appending it after the
        assistant turn would leave a dangling assistant header at the end of the
        training text. That header lies past the masked prompt prefix, so the
        data_collactor labels would train on it.
        """
        user_turn = {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text}]}
        assistant_turn = {"role": "assistant", "content": [{"type": "text", "text": answer}]}
        prompt_wo_answer = self.processor.apply_chat_template([user_turn], add_generation_prompt=True)
        prompt_w_answer = self.processor.apply_chat_template(
            [user_turn, assistant_turn], add_generation_prompt=False
        )
        return prompt_wo_answer, prompt_w_answer

    def hf_visonlyqa(self):
        print("####### dataset: hf_visonlyqa #########")
        ## img_name, question, answer
        ## citation: loading from https://huggingface.co/datasets/ryokamoi/VisOnlyQA_Train
        ## citation: loading from https://huggingface.co/datasets/ryokamoi/VisOnlyQA_Eval_Synthetic
        train_dataset = datasets.load_dataset("ryokamoi/VisOnlyQA_Train", split="syntheticgeometry__triangle")
        eval_dataset = datasets.load_dataset("ryokamoi/VisOnlyQA_Eval_Synthetic", split="syntheticgeometry__triangle")
        train_dataset = train_dataset.map(lambda x: {"img_name": x["image_path"], "question": x["prompt_no_reasoning"], "image": x["decoded_image"]})
        eval_dataset = eval_dataset.map(lambda x: {"img_name": x["image_path"], "question": x["prompt_no_reasoning"], "image": x["decoded_image"]})
        # eval_dataset = eval_dataset.map(lambda x: {"img_name": x["image_path"], "image": x["decoded_image"], "question": "Answer the following question by \"True\" or \"False\":" + x["question"]})
        formatted_train_dataset = train_dataset.map(lambda x: self.processing_func(x, "visonlyqa"))
        formatted_eval_dataset = eval_dataset.map(lambda x: self.processing_func(x, "visonlyqa"))
        return formatted_train_dataset, formatted_eval_dataset

    def hf_pathvqa(self):
        print("####### dataset: hf_pathvqa #########")
        ## img_name, question, answer
        ## citation: loading from https://huggingface.co/datasets/flaviagiammarino/path-vqa
        train_dataset = datasets.load_dataset("flaviagiammarino/path-vqa", split="train[:10000]")
        eval_dataset = datasets.load_dataset("flaviagiammarino/path-vqa", split="test")
        train_dataset = train_dataset.map(lambda x: {"img_name": "None"})
        eval_dataset = eval_dataset.map(lambda x: {"img_name": "None"})
        # eval_dataset = eval_dataset.map(lambda x: {"img_name": "None", "question": "Answer the following question by several words briefly. Example: question: where are liver stem cells (oval cells) located? answer: in the canals of hering\n The question is: " + x["question"]})
        formatted_train_dataset = train_dataset.map(lambda x: self.processing_func(x, "pathvqa"))
        formatted_eval_dataset = eval_dataset.map(lambda x: self.processing_func(x, "pathvqa"))
        return formatted_train_dataset, formatted_eval_dataset

    def hf_slakevqa(self):
        print("####### dataset: hf_slakevqa #########")
        ## img_name, question, answer
        ## citation: loading from https://huggingface.co/datasets/mdwiratathya/SLAKE-vqa-english
        train_dataset = datasets.load_dataset("mdwiratathya/SLAKE-vqa-english", split="train")
        eval_dataset = datasets.load_dataset("mdwiratathya/SLAKE-vqa-english", split="test")
        train_dataset = train_dataset.map(lambda x: {"img_name": "None"})
        eval_dataset = eval_dataset.map(lambda x: {"img_name": "None"})
        # eval_dataset = eval_dataset.map(lambda x: {"img_name": "None", "question": "Answer the following question by several words briefly. Example: question: What modality is used to take this image? answer: MRI\n The question is: " + x["question"]})
        formatted_train_dataset = train_dataset.map(lambda x: self.processing_func(x, "slakevqa"))
        formatted_eval_dataset = eval_dataset.map(lambda x: self.processing_func(x, "slakevqa"))
        return formatted_train_dataset, formatted_eval_dataset

    def hf_textvqa(self):
        print("####### dataset: hf_textvqa #########")
        ## img_name, question, answer
        ## citation: loading from https://huggingface.co/datasets/Multimodal-Fatima/TextVQA_train
        train_dataset = datasets.load_dataset("Multimodal-Fatima/TextVQA_train", split="train[:10000]")
        train_dataset = train_dataset.map(lambda x: {"img_name": x["image_id"], "answer": x["answers"][0]})
        formatted_train_dataset = train_dataset.map(lambda x: self.processing_func(x, "textvqa"))
        return formatted_train_dataset, self._head(formatted_train_dataset, 16)

    def hf_a_okvqa(self):
        def a_okvqa_processing_func(x):
            idx_mapping_dict = {1: "A", 2: "B", 3: "C", 4: "D"}
            # Newline after "Options:" so the first choice starts on its own
            # line, matching the layout of mmbench_processing_func.
            question = f'Question: {x["question"]}\nOptions:\n'
            for i, choice in enumerate(x['choices']):
                question += f'{idx_mapping_dict[i+1]}. {choice}\n'
            question += 'Please select the correct answer from the options above.'
            img_name = x["question_id"]
            answer = idx_mapping_dict[x["correct_choice_idx"]+1]
            return {"img_name": img_name, "question": question, "answer": answer}
        print("####### dataset: hf_a_okvqa #########")
        ## img_name, question, answer
        ## citation: loading from https://huggingface.co/datasets/HuggingFaceM4/A-OKVQA
        train_dataset = datasets.load_dataset("HuggingFaceM4/A-OKVQA", split="train[:10000]")
        train_dataset = train_dataset.map(a_okvqa_processing_func)
        formatted_train_dataset = train_dataset.map(lambda x: self.processing_func(x, "a_okvqa"))
        return formatted_train_dataset, self._head(formatted_train_dataset, 16)

    # def hf_chartqa(self):
    #     print("####### dataset: hf_chartvqa #########")
    #     ## img_name, question, answer
    #     ## citation: loading from https://huggingface.co/datasets/HuggingFaceM4/ChartQA
    #     train_dataset = datasets.load_dataset("HuggingFaceM4/ChartQA", split="train[:10000]")
    #     train_dataset = train_dataset.map(lambda x: {"img_name": "none", "answer": x["label"][0], "question": x["query"]})
    #     formatted_train_dataset = train_dataset.map(lambda x: self.processing_func(x, "chartvqa"))
    #     return formatted_train_dataset, formatted_train_dataset.select(range(16))

    # def hf_docvqa(self):
    #     print("####### dataset: hf_docvqa #########")
    #     ## img_name, question, answer
    #     ## citation: loading from https://huggingface.co/datasets/lmms-lab/DocVQA/viewer/DocVQA?views%5B%5D=docvqa_validation
    #     train_dataset = datasets.load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation[:5000]")
    #     train_dataset = train_dataset.map(lambda x: {"img_name": str(x["docId"]), "answer": x["answers"][0]})
    #     formatted_train_dataset = train_dataset.map(lambda x: self.processing_func(x, "docvqa"))
    #     return formatted_train_dataset, formatted_train_dataset.select(range(16))

    def hf_scienceqa(self):
        def sciencevqa_processing_func(x):
            idx_mapping_dict = {0: "A", 1: "B", 2: "C", 3: "D", 4: "E", 5: "F"}
            # Newline after "Options:" so the first choice starts on its own
            # line, matching the layout of mmbench_processing_func.
            question = f'Hint: {x["hint"]}\nQuestion: {x["question"]}\nOptions:\n'
            for i, choice in enumerate(x['choices']):
                question += f'{idx_mapping_dict[i]}. {choice}\n'
            question += 'Please select the correct answer from the options above.'
            img_name = "none"
            answer = idx_mapping_dict[x["answer"]]
            return {"img_name": img_name, "question": question, "answer": answer}
        print("####### dataset: hf_sciencevqa #########")
        ## img_name, question, answer
        ## citation: loading from https://huggingface.co/datasets/derek-thomas/ScienceQA
        train_dataset = datasets.load_dataset("derek-thomas/ScienceQA")["train"].filter(lambda x: x["image"] is not None)
        train_dataset = train_dataset.map(sciencevqa_processing_func)
        formatted_train_dataset = train_dataset.map(lambda x: self.processing_func(x, "sciencevqa"))
        return formatted_train_dataset, self._head(formatted_train_dataset, 16)

    def hf_ocrvqa(self):
        def ocrvqa_processing_func(x):
            # Batched map: flatten each image's (questions, answers) lists into
            # one row per question, repeating the image.
            new_img_names = []
            new_images = []
            new_questions = []
            new_answers = []
            for image, questions, answers in zip(x["image"], x["questions"], x["answers"]):
                for q, a in zip(questions, answers):
                    new_img_names.append("none")
                    new_images.append(image)
                    new_questions.append(q)
                    new_answers.append(a)
            return {"img_name": new_img_names, "image": new_images, "question": new_questions, "answer": new_answers}
        print("####### dataset: hf_ocrvqa #########")
        ## img_name, question, answer
        ## citation: loading from https://huggingface.co/datasets/yobro4619/OCR-VQA_sample
        ## citation: loading from https://huggingface.co/datasets/howard-hou/OCR-VQA
        train_dataset = datasets.load_dataset("yobro4619/OCR-VQA_sample", split="train")
        train_dataset = train_dataset.map(ocrvqa_processing_func, batched=True, remove_columns=train_dataset.column_names)
        formatted_train_dataset = train_dataset.map(lambda x: self.processing_func(x, "ocrvqa"))
        return formatted_train_dataset, self._head(formatted_train_dataset, 16)

    def opencompass_mmbench(self):
        print("####### dataset: opencompass_mmbench #########")
        ## img_name, question, answer
        opencompass_data = ImageMCQDataset("MMBench_DEV_EN_V11")
        # Materialise every image under opencompass_data.img_root.
        for i in range(len(opencompass_data.data)):
            opencompass_data.dump_image(opencompass_data.data.iloc[i])
        opencompass_dataset = datasets.Dataset.from_pandas(opencompass_data.data)
        formatted_train_dataset = opencompass_dataset.map(
            lambda x: self.mmbench_processing_func(x, opencompass_data.img_root),
            remove_columns=opencompass_dataset.column_names,
        )
        return formatted_train_dataset, self._head(formatted_train_dataset, 16)

    def ocrbench(self, data_dir):
        print("####### dataset: ocrbench #########")
        train_dataset = datasets.load_dataset(
            "json",
            data_files=os.path.join(data_dir, "ocrbench_train.json"),
        )["train"].filter(lambda example: example["type"] != "Chinese")
        train_dataset = train_dataset.map(lambda x: {"answer": x["answers"], "img_name": x["image_path"]})
        formatted_train_dataset = train_dataset.map(lambda x: self.processing_func(x, os.path.join(data_dir, "data")))
        return self._head(formatted_train_dataset, 5000), self._head(formatted_train_dataset, 16)

    def opencompass_mathvision(self):
        print("####### dataset: opencompass_mathvision #########")
        ## img_name, question, answer
        opencompass_data = MathVision('MathVision')
        # Materialise every image under opencompass_data.img_root.
        for i in range(len(opencompass_data.data)):
            opencompass_data.dump_image(opencompass_data.data.iloc[i])
        opencompass_dataset = datasets.Dataset.from_pandas(opencompass_data.data)
        formatted_train_dataset = opencompass_dataset.map(
            lambda x: self.mathvision_processing_func(x, opencompass_data.img_root),
            remove_columns=opencompass_dataset.column_names,
        )
        return formatted_train_dataset, self._head(formatted_train_dataset, 16)

    def hf_coco_caption(self):
        print("####### dataset: hf_coco_caption #########")
        ## citation: loading from https://huggingface.co/datasets/astro21/coco-caption-train-split-10k
        # Split slicing ("train[:10000]") is only understood through ``split=``.
        # Indexing the returned DatasetDict with that string raises KeyError.
        train_dataset = datasets.load_dataset("astro21/coco-caption-train-split-10k", split="train[:10000]")
        train_dataset = train_dataset.map(lambda x: {"answer": x["caption"], "img_name": str(x["image_id"]), "question": 'Please describe this image in general. Directly provide the description, do not include prefix like "This image depicts". '})
        formatted_train_dataset = train_dataset.map(lambda x: self.processing_func(x, "coco"))
        return formatted_train_dataset, self._head(formatted_train_dataset, 16)

    def slake_vqa(self, data_dir):
        print("####### dataset: slake_vqa #########")
        train_dataset = datasets.load_dataset(
            "json",
            data_files=os.path.join(data_dir, "train.json"),
        )["train"].filter(lambda example: example["q_lang"] == "en")
        test_dataset = datasets.load_dataset(
            "json",
            data_files=os.path.join(data_dir, "test.json"),
        )["train"].filter(lambda example: example["q_lang"] == "en")
        train_dataset = train_dataset.map(lambda x: {"question": "You should answer the following question accurately only in one word or phrase. The question is: " + x["question"]})
        test_dataset = test_dataset.map(lambda x: {"question": "You should answer the following question accurately only in one word or phrase. The question is: " + x["question"]})
        formatted_train_dataset = train_dataset.map(
            lambda x: self.processing_func(x, os.path.join(data_dir, "imgs")),
            remove_columns=train_dataset.column_names,
        )
        formatted_test_dataset = test_dataset.map(
            lambda x: self.processing_func(x, os.path.join(data_dir, "imgs")),
            remove_columns=test_dataset.column_names,
        )
        return formatted_train_dataset, formatted_test_dataset

    def iuxray_rg(self, data_dir):
        print("####### dataset: iuxray_rg #########")
        if not os.path.exists(os.path.join(data_dir, "train.csv")):
            # One-off preprocessing: keep frontal projections whose report has
            # findings, attach the first frontal image filename per uid, and
            # split 80/20 in file order.
            print("processing raw iu-xray files and save train.csv and test.csv......")
            raw_projection_df = pd.read_csv(os.path.join(data_dir, "indiana_projections.csv"))
            raw_reports_df = pd.read_csv(os.path.join(data_dir, "indiana_reports.csv"))
            filtered_projection_df = raw_projection_df[raw_projection_df["projection"] == "Frontal"]
            filtered_reports_df = raw_reports_df[raw_reports_df["uid"].isin(filtered_projection_df["uid"])]
            filtered_reports_df = filtered_reports_df.dropna(subset=["findings"])
            # .copy() so adding "img_name" below writes to an independent frame
            # rather than a view of filtered_reports_df.
            new_df = filtered_reports_df[["uid", "MeSH", "Problems", "findings", "impression"]].copy()
            first_projection_df = filtered_projection_df.loc[~filtered_projection_df["uid"].duplicated(keep="first")]
            first_projection_df = first_projection_df.set_index("uid")
            new_df["img_name"] = new_df["uid"].map(first_projection_df["filename"])
            split_at = int(.8 * len(new_df))
            train_df = new_df.iloc[:split_at].reset_index(drop=True)
            test_df = new_df.iloc[split_at:].reset_index(drop=True)
            train_df.to_csv(os.path.join(data_dir, "train.csv"))
            test_df.to_csv(os.path.join(data_dir, "test.csv"))
        train_df = pd.read_csv(os.path.join(data_dir, "train.csv"))[["findings", "img_name"]]
        test_df = pd.read_csv(os.path.join(data_dir, "test.csv"))[["findings", "img_name"]]
        train_df["question"] = "You are now a x-ray report assistant. Generate the report for the x-ray image I provided."
        test_df["question"] = "You are now a x-ray report assistant. Generate the report for the x-ray image I provided."
        train_df.columns = ["answer", "img_name", "question"]
        test_df.columns = ["answer", "img_name", "question"]
        train_dataset = datasets.Dataset.from_pandas(train_df)
        test_dataset = datasets.Dataset.from_pandas(test_df)
        formatted_train_dataset = train_dataset.map(
            lambda x: self.processing_func(x, os.path.join(data_dir, "images/images_normalized")),
            remove_columns=train_dataset.column_names,
        )
        formatted_test_dataset = test_dataset.map(
            lambda x: self.processing_func(x, os.path.join(data_dir, "images/images_normalized")),
            remove_columns=test_dataset.column_names,
        )
        return formatted_train_dataset, formatted_test_dataset

    def processing_func(self, example, data_dir):
        """Format a generic ``{img_name, question, answer[, image]}`` row.

        ``image_path`` is ``data_dir/img_name``. ``image`` is copied through
        when the row already holds a decoded image.
        """
        image_path = os.path.join(data_dir, example["img_name"])
        question = example["question"]
        answer = example["answer"]
        prompt_wo_answer, prompt_w_answer = self._chat_prompts(question, answer)
        output = {
            "image_path": image_path,
            "question": question,
            "answer": answer,
            "prompt_w_answer": prompt_w_answer,
            "prompt_wo_answer": prompt_wo_answer,
        }
        if "image" in example:
            output["image"] = example["image"]
        return output

    def mathvision_processing_func(self, example, image_dir):
        """Format a MathVision row; its image is ``image_dir/<index>.jpg``."""
        image_path = os.path.join(image_dir, f"{example['index']}.jpg")
        question = example['question']
        answer = example['answer']
        prompt_wo_answer, prompt_w_answer = self._chat_prompts(question, answer)
        return {
            "image_path": image_path,
            "question": question,
            "answer": answer,
            "prompt_w_answer": prompt_w_answer,
            "prompt_wo_answer": prompt_wo_answer,
        }

    def mmbench_processing_func(self, example, image_dir):
        """Format an MMBench multiple-choice row into a hint/question/options prompt.

        Options are the non-NaN single-uppercase-letter columns. ``question`` in
        the output is the raw question; the rendered prompts use the full text.
        """
        image_path = os.path.join(image_dir, f"{example['index']}.jpg")
        question = example['question']
        answer = example['answer']
        options = {
            cand: example[cand]
            for cand in string.ascii_uppercase
            if cand in example and not pd.isna(example[cand])
        }
        options_prompt = 'Options:\n'
        for key, item in options.items():
            options_prompt += f'{key}. {item}\n'
        hint = example['hint'] if ('hint' in example and not pd.isna(example['hint'])) else None
        prompt = ''
        if hint is not None:
            prompt += f'Hint: {hint}\n'
        prompt += f'Question: {question}\n'
        if len(options):
            prompt += options_prompt
            prompt += 'Please select the correct answer from the options above. \n'
        prompt_wo_answer, prompt_w_answer = self._chat_prompts(prompt, answer)
        return {
            "image_path": image_path,
            "question": question,
            "answer": answer,
            "prompt_w_answer": prompt_w_answer,
            "prompt_wo_answer": prompt_wo_answer,
        }
    


class data_collactor:
    """Collate formatted rows into a right-padded training batch with ``labels``.

    Label positions covering the prompt prefix and all padding are set to
    ``ignore_index`` (-100), so only the tokens after the prompt contribute to
    the loss.
    """

    def __init__(self, processor, max_length=None):
        self.processor = processor
        # NOTE(review): not read by __call__; batches are padded to the longest sample.
        self.max_length = max_length
        self.ignore_index = -100

    def __call__(self, batch):
        image_path = [sample["image_path"] for sample in batch]
        prompt_w_answer = [sample["prompt_w_answer"] for sample in batch]
        prompt_wo_answer = [sample["prompt_wo_answer"] for sample in batch]

        if "image" in batch[0]:
            image_list = [sample["image"] for sample in batch]
        else:
            image_list = [Image.open(path).convert("RGB") for path in image_path]

        training_inputs = self.processor(
            images=image_list,
            text=prompt_w_answer,
            padding=True,
            return_tensors="pt",
            padding_side="right",
        )
        prompt_wo_answer_tokens = self.processor.tokenizer(
            prompt_wo_answer,
            padding=False,
            padding_side="right",
        ).input_ids

        image_token_id = self.processor.tokenizer.image_token_id
        labels = training_inputs.input_ids.clone()
        for i, prompt_tokens in enumerate(prompt_wo_answer_tokens):
            # The processor output contains n_image_tokens image tokens, while the
            # separately tokenised prompt presumably holds a single placeholder.
            # Extend the masked prefix by the (n_image_tokens - 1) extra positions.
            n_image_tokens = int((training_inputs.input_ids[i] == image_token_id).sum())
            labels[i, : len(prompt_tokens) + n_image_tokens - 1] = self.ignore_index

        # Right padding must not be trained on: mask every padded position.
        attention_mask = training_inputs.get("attention_mask")
        if attention_mask is not None:
            labels[attention_mask == 0] = self.ignore_index

        training_inputs["labels"] = labels
        return training_inputs
    

class eval_data_collactor:
    """Collate formatted rows into left-padded generation inputs for evaluation.

    ``max_length`` of ``None`` or ``-1`` pads to the longest prompt without
    truncation. Any other value pads and truncates to exactly ``max_length``.
    """

    def __init__(self, processor, max_length=None):
        self.processor = processor
        self.max_length = max_length
        self.ignore_index = -100

    def __call__(self, batch):
        image_path = [sample["image_path"] for sample in batch]
        question = [sample["question"] for sample in batch]
        answer = [sample["answer"] for sample in batch]
        prompt_wo_answer = [sample["prompt_wo_answer"] for sample in batch]

        if "image" in batch[0]:
            image_list = [sample["image"] for sample in batch]
        else:
            image_list = []
            for path in image_path:
                image = Image.open(path).convert("RGB")
                width, height = image.size  # PIL reports (width, height)
                # Images with a side below 334 px are resized to a fixed
                # 512x512 (aspect ratio is not preserved).
                if width < 334 or height < 334:
                    image = image.resize((512, 512))
                image_list.append(image)

        # Previously -1 disabled truncation but still requested
        # padding="max_length" with max_length=-1, i.e. no padding at all, which
        # yields ragged tensors. Treat -1 exactly like None.
        fixed_length = self.max_length not in (None, -1)
        gen_inputs = self.processor(
            text=prompt_wo_answer,
            images=image_list,
            padding="max_length" if fixed_length else "longest",
            max_length=self.max_length if fixed_length else None,
            truncation=fixed_length,
            return_tensors="pt",
            padding_side="left",
        )
        return {
            "questions": question,
            "answers": answer,
            "image_path": image_path,
            "prompt_wo_answer_tokens": gen_inputs.input_ids,
            "gen_inputs": gen_inputs,
        }


def compute_bleu1_score(prediction, reference):
    """Return the smoothed unigram BLEU (weights ``(1, 0, 0, 0)``) of ``prediction`` vs ``reference``.

    Both strings pass through ``tokenize_text`` then ``stem_words`` and are split
    on whitespace. ``str.split()`` with no argument is used so newlines and tabs
    separate tokens. It also keeps runs of spaces (e.g. around the space tokens
    jieba emits for Chinese text) from producing empty-string tokens, which would
    match each other and inflate the score.
    """
    prediction_tokens = stem_words(tokenize_text(prediction)).split()
    reference_tokens = stem_words(tokenize_text(reference)).split()
    smooth_fn = SmoothingFunction().method1
    return sentence_bleu(
        [reference_tokens], prediction_tokens, smoothing_function=smooth_fn, weights=(1, 0, 0, 0)
    )


def is_chinese(text):
    """Return True when ``text`` contains a character in U+4E00..U+9FFF (CJK Unified Ideographs)."""
    match = re.search(r'[\u4e00-\u9fff]', text)
    return match is not None


def tokenize_text(text):
    """Normalise ``text`` for metric computation.

    Text containing Chinese characters is segmented with jieba and the segments
    are joined by single spaces. Any other text is only lower-cased.
    """
    if not is_chinese(text):
        return text.lower()
    segments = jieba.cut(text)
    return " ".join(segments)

def stem_words(sentence):
    """Porter-stem every single-space-separated word of ``sentence``; re-join with single spaces."""
    porter = PorterStemmer()
    return ' '.join(map(porter.stem, sentence.split(' ')))

def compute_rouge_score(prediction, reference):
    """Return ROUGE-1/2/L of ``prediction`` vs ``reference`` after tokenising and stemming.

    Result keys are ``rouge1``, ``rouge2`` and ``rougeL``; each value is the
    per-metric dict returned by ``Rouge.get_scores``. The ``rouge`` package raises
    ValueError when the hypothesis or reference is empty, e.g. for an empty model
    generation. Such pairs are scored as all-zero so one bad sample does not
    abort the whole evaluation.
    """
    rouge = Rouge()
    prediction = stem_words(tokenize_text(prediction))
    reference = stem_words(tokenize_text(reference))
    try:
        scores = rouge.get_scores([prediction], [reference])[0]
    except ValueError:
        # Same r/p/f layout that Rouge.get_scores returns per metric.
        scores = {metric: {"r": 0.0, "p": 0.0, "f": 0.0} for metric in ("rouge-1", "rouge-2", "rouge-l")}
    return {"rouge1": scores["rouge-1"], "rouge2": scores["rouge-2"], "rougeL": scores["rouge-l"]}


def preprocess_logits_for_metrics(batch_output, labels):
    """Reduce model output to predicted token ids (argmax over the last dimension).

    ``batch_output`` may be the logits tensor itself or a tuple whose first
    element is the logits. Previously a bare tensor raised UnboundLocalError.
    ``labels`` is unused; it is kept for the ``(logits, labels)`` hook signature.
    """
    logits = batch_output[0] if isinstance(batch_output, tuple) else batch_output
    return logits.argmax(dim=-1)


class my_trainer(transformers.Trainer):
    """``transformers.Trainer`` subclass with custom checkpointing and generative evaluation.

    * ``_save_checkpoint`` additionally stores the ``state_dict`` of every module
      named in ``args.requires_grad_list``, plus the list itself, inside the
      checkpoint folder.
    * ``evaluate`` generates answers with ``model.generate``, gathers them across
      ranks, and scores them with BLEU-1, ROUGE and BERTScore. It writes
      per-sample and averaged results to ``<output_dir>/eval_results.json``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _save_checkpoint(self, model, trial):
        """Save a checkpoint, including weights of the extra modules in ``args.requires_grad_list``.

        Follows the structure of ``transformers.Trainer._save_checkpoint``.
        After ``save_model``, each dotted module path in
        ``self.args.requires_grad_list`` is resolved with ``get_attr``. Its
        ``state_dict`` is written to ``<checkpoint>/<module_name>.pth``. The list
        itself goes to ``require_grads_list.json``, which ``train()`` reads back
        when resuming from ``peft_model_path``.
        """
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
        from transformers.trainer_callback import ExportableState

        TRAINER_STATE_NAME = "trainer_state.json"
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)

        ######changed######
        os.makedirs(output_dir, exist_ok=True)
        self.save_model(output_dir, _internal_call=True)
        # Only the saving process writes the extra files, so ranks don't race on the same paths.
        if self.args.requires_grad_list not in [None, []] and self.args.should_save:
            # With DeepSpeed the real module sits under `.module` (the passed model is presumably the engine).
            base_model = model.module if getattr(self, "deepspeed", None) is not None else model
            for module_name in self.args.requires_grad_list:
                torch.save(
                    get_attr(base_model, module_name).state_dict(),
                    os.path.join(output_dir, module_name + ".pth"),
                )
            with open(os.path.join(output_dir, "require_grads_list.json"), "w") as file:
                json.dump(self.args.requires_grad_list, file)
        #####changed#######

        if not self.args.save_only_model:
            # Save optimizer and scheduler
            self._save_optimizer_and_scheduler(output_dir)
            # Save RNG state
            self._save_rng_state(output_dir)

        # Save the Trainer state
        if self.args.should_save:
            # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
            for cb in [
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ]:
                cb_name = cb.__class__.__name__
                cb_state = cb.state()
                if isinstance(self.state.stateful_callbacks[cb_name], list):
                    self.state.stateful_callbacks[cb_name].append(cb_state)
                else:
                    self.state.stateful_callbacks[cb_name] = cb_state
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            # Solely rely on numerical checkpoint id for rotation.
            # mtime is not reliable especially on some fuse fs in cloud environments.
            self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Generate answers for ``eval_dataset`` and score them with BLEU-1, ROUGE and BERTScore.

        Each batch from the eval dataloader must provide these keys:
        ``prompt_wo_answer_tokens``, ``gen_inputs``, ``answers``, ``questions``
        and ``image_path``. The decoded prompt text is removed from each
        generated sequence to isolate the answer.

        Under ``torch.distributed``, each rank's outputs are exchanged through
        temporary JSON files in ``args.output_dir``. This assumes a filesystem
        shared by all ranks.

        ``ignore_keys`` and ``metric_key_prefix`` are accepted for signature
        compatibility with ``transformers.Trainer.evaluate`` and are unused.

        The main process writes per-sample results and averages to
        ``<output_dir>/eval_results.json``. If a metric raises for a sample,
        that metric is recorded as ``"caculation error"`` and contributes 0 to
        the average.

        Returns:
            ``{"bleu": ..., "rouge": ..., "bertscore": ...}`` with averaged scores.
        """
        decoded_gen_preds = []
        decoded_gen_labels = []
        decoded_questions = []
        decoded_image_paths = []
        with torch.no_grad():
            self.model.eval()
            eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
            dataloader = self.get_eval_dataloader(eval_dataset)
            for tokens in tqdm(dataloader, desc="Evaluation"):
                prompt_wo_answer_tokens = tokens["prompt_wo_answer_tokens"]
                answers = tokens["answers"]
                questions = tokens["questions"]
                gen_inputs = tokens["gen_inputs"]
                image_path = tokens["image_path"]
                gen_outputs = self.model.generate(
                    **gen_inputs, pad_token_id=self.processing_class.eos_token_id, max_new_tokens=100
                )
                batch_decoded_gen_preds = self.processing_class.batch_decode(gen_outputs, skip_special_tokens=True)
                batch_decoded_prompt_wo_answer = self.processing_class.batch_decode(
                    prompt_wo_answer_tokens, skip_special_tokens=True
                )
                # Decoded outputs still contain the prompt text; remove it to keep only the answer.
                batch_decoded_gen_preds = [
                    pred.replace(prompt, "")
                    for pred, prompt in zip(batch_decoded_gen_preds, batch_decoded_prompt_wo_answer)
                ]
                decoded_gen_preds.extend(batch_decoded_gen_preds)
                decoded_gen_labels.extend(answers)
                decoded_questions.extend(questions)
                decoded_image_paths.extend(image_path)

        bertscore_metric = evaluate.load("bertscore")
        aggregated_bleu = {"bleu-1": 0}
        aggregated_rouge = {"rouge1": {"f": 0, "p": 0, "r": 0}, "rouge2": {"f": 0, "p": 0, "r": 0}, "rougeL": {"f": 0, "p": 0, "r": 0}}
        aggregated_bertscore = {"precision": 0, "recall": 0, "f1": 0}
        # Only one process writes eval_results.json; concurrent writes to one path can corrupt it.
        is_main_process = self.is_world_process_zero()

        if torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
            temp_paths = [
                os.path.join(self.args.output_dir, f"temp_eval_results_{r}.json") for r in range(world_size)
            ]
            with open(temp_paths[rank], "w", encoding="utf-8") as file:
                json.dump([decoded_gen_preds, decoded_gen_labels, decoded_questions, decoded_image_paths], file, ensure_ascii=False)
            # Every rank must finish writing its shard before anyone reads.
            torch.distributed.barrier()
            decoded_gen_preds = []
            decoded_gen_labels = []
            decoded_questions = []
            decoded_image_paths = []
            for path in temp_paths:
                with open(path, "r", encoding="utf-8") as file:
                    json_data = json.load(file)
                decoded_gen_preds.extend(json_data[0])
                decoded_gen_labels.extend(json_data[1])
                decoded_questions.extend(json_data[2])
                decoded_image_paths.extend(json_data[3])
            # Every rank must finish reading before rank 0 deletes the shards.
            torch.distributed.barrier()
            if rank == 0:
                for path in temp_paths:
                    os.remove(path)

        # First pass: record raw and normalised texts and write them immediately,
        # so predictions are preserved even if metric computation fails later.
        output_json = {}
        for i, (decoded_image_path, decoded_question, decoded_gen_pred, decoded_gen_label) in enumerate(
            zip(decoded_image_paths, decoded_questions, decoded_gen_preds, decoded_gen_labels)
        ):
            decoded_gen_pred = decoded_gen_pred.strip()
            decoded_gen_label = decoded_gen_label.strip()
            output_json[i] = {
                "image_path": decoded_image_path,
                "question": decoded_question,
                "pred": decoded_gen_pred,
                "label": decoded_gen_label,
                "processed_pred": stem_words(tokenize_text(decoded_gen_pred)),
                "processed_label": stem_words(tokenize_text(decoded_gen_label)),
            }
        if is_main_process:
            with open(os.path.join(self.args.output_dir, "eval_results.json"), "w", encoding="utf-8") as file:
                json.dump(output_json, file, ensure_ascii=False)

        # Second pass: per-sample metrics, also accumulated for the averages.
        for i, sample in output_json.items():
            pred = sample["pred"]
            label = sample["label"]
            metrics = {}
            try:
                bleu_results = compute_bleu1_score(prediction=pred, reference=label)
                aggregated_bleu["bleu-1"] += bleu_results
                metrics["bleu"] = bleu_results
            except Exception:
                metrics["bleu"] = "caculation error"
            try:
                rouge_results = compute_rouge_score(prediction=pred, reference=label)
                for key1 in aggregated_rouge:
                    for key2 in aggregated_rouge[key1]:
                        aggregated_rouge[key1][key2] += rouge_results[key1][key2]
                metrics["rouge"] = rouge_results
            except Exception:
                metrics["rouge"] = "caculation error"
            try:
                # Computed once per sample (it was previously computed twice).
                bertscore_results = bertscore_metric.compute(predictions=[pred], references=[label], lang="en")
                bertscore_results["precision"] = bertscore_results["precision"][0]
                bertscore_results["recall"] = bertscore_results["recall"][0]
                bertscore_results["f1"] = bertscore_results["f1"][0]
                bertscore_results.pop("hashcode", None)
                for key in aggregated_bertscore:
                    aggregated_bertscore[key] += bertscore_results[key]
                metrics["bertscore"] = bertscore_results
            except Exception:
                metrics["bertscore"] = "caculation error"

            # Replacing the value of an existing key is safe while iterating over items().
            output_json[i] = {
                "question": sample["question"],
                "image_path": sample["image_path"],
                "pred": pred,
                "label": label,
                "processed_pred": sample["processed_pred"],
                "processed_label": sample["processed_label"],
                **metrics,
            }

        # max(..., 1) avoids a ZeroDivisionError on an empty eval set (all sums are 0 then).
        num_samples = max(len(decoded_gen_preds), 1)
        aggregated_bleu["bleu-1"] /= num_samples
        for key1 in aggregated_rouge:
            for key2 in aggregated_rouge[key1]:
                aggregated_rouge[key1][key2] /= num_samples
        for key in aggregated_bertscore:
            aggregated_bertscore[key] /= num_samples
        output_json = {"BLEU": aggregated_bleu, "ROUGE": aggregated_rouge, "bertscore": aggregated_bertscore, **output_json}
        if is_main_process:
            with open(os.path.join(self.args.output_dir, "eval_results.json"), "w", encoding="utf-8") as file:
                json.dump(output_json, file, ensure_ascii=False)

        return {
            "bleu": aggregated_bleu,
            "rouge": aggregated_rouge,
            "bertscore": aggregated_bertscore,
        }
    

class compute_metrics:
    """Token-level recall metric usable as ``transformers.Trainer(compute_metrics=...)``.

    For each sample, only answer positions are compared, i.e. positions where
    the label is not ``ignore_index``. The comparison is shifted by one: the
    prediction at answer position ``k`` is matched against the label at answer
    position ``k + 1``.

    Per-sample recall is ``|set(label) & set(pred)| / |set(label)|``. The
    result is the mean over samples.
    """

    def __init__(self, processor, output_dir):
        self.processor = processor
        self.output_dir = output_dir
        self.ignore_index = -100

    def __call__(self, eval_pred):
        """Compute ``{"recall": float}`` from ``(predictions, labels)``.

        ``predictions`` may be token ids with the same shape as ``labels``, as
        returned by ``preprocess_logits_for_metrics``. They may also be raw
        logits with one extra trailing dimension, in which case they are
        arg-maxed here.
        """
        predictions, labels = eval_pred
        masked_preds = []
        masked_labels = []
        # Previously `for i in len(predictions)` raised TypeError.
        for pred_row, label_row in zip(predictions, labels):
            pred_row = np.asarray(pred_row)
            label_row = np.asarray(label_row)
            if pred_row.ndim == label_row.ndim + 1:
                pred_row = pred_row.argmax(axis=-1)
            answer_mask = label_row != self.ignore_index
            # The prediction at position t targets the token at t + 1, hence the shift.
            masked_preds.append(pred_row[answer_mask][:-1].tolist())
            masked_labels.append(label_row[answer_mask][1:].tolist())
        recall = self.compute_recall(masked_preds, masked_labels)
        return {"recall": recall}

    def compute_recall(self, predictions, labels):
        """Mean unique-token recall over paired ``predictions`` / ``labels`` sequences.

        A sample with no label tokens scores 0. An empty input returns 0.0.
        """
        recall_list = []
        for prediction, label in zip(predictions, labels):
            relevant = set(label)
            # Numerator and denominator both count unique tokens, so a perfect prediction scores 1.
            recall = len(relevant & set(prediction)) / len(relevant) if relevant else 0
            recall_list.append(recall)
        if not recall_list:
            return 0.0
        return sum(recall_list) / len(recall_list)


def get_attr(obj, attr_str):
    """Resolve a dotted attribute path on *obj*.

    ``get_attr(model, "a.b.c")`` is equivalent to ``model.a.b.c``. Raises
    ``AttributeError`` if any segment (including an empty one) is missing.
    """
    head, sep, rest = attr_str.partition(".")
    target = getattr(obj, head)
    # Recurse whenever a separator was present, even if `rest` is empty, so "a." fails like "a" + "".
    return get_attr(target, rest) if sep else target

def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU, current CUDA device, all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

def _build_peft_config(training_args, model, train_dataset):
    """Return the PEFT config selected by ``training_args.peft_type``.

    With ``training_args.my_lora`` set, configs come from the ``peft`` module
    imported at file level, presumably the project's bundled fork. This covers
    hydra-lora, r-lora, the ablation variants and my_lora. Otherwise the stock
    PEFT configs are used.

    Raises:
        NotImplementedError: if ``peft_type`` is not known for the selected branch.
    """
    lora_target_modules = training_args.lora_target_modules
    if training_args.my_lora:
        lora_cluster = {"ablation-masked-lora": peft.AblationMaskedLoraConfig, "ablation-neighbor-lora": peft.AblationNeighborLoraConfig, "ablation-div-A-neighbor-lora": peft.AblationDivANeighborLoraConfig, "my_lora": LoraConfig}
        if training_args.peft_type == "hydra-lora":
            print("peft type: hydra-lora")
            return peft.HydraLoraConfig(
                r=training_args.lora_rank,
                lora_alpha=training_args.lora_rank,
                target_modules=lora_target_modules,
                lora_nums=2,
                lora_dropout=0,
                task_type="CAUSAL_LM",
            )
        if training_args.peft_type == "r-lora":
            print("peft type: r-lora")
            return peft.RLoraConfig(
                r=training_args.lora_rank,
                lora_alpha=training_args.lora_rank,
                target_modules=lora_target_modules,
                lora_nums=2,
                lora_dropout=0,
                task_type="CAUSAL_LM",
            )
        if training_args.peft_type in lora_cluster:
            print(f"peft type: {training_args.peft_type}")
            config_cls = lora_cluster[training_args.peft_type]
            if training_args.peft_type == "my_lora":
                # The extra neighbor/rotation hyper-parameters are only passed to my_lora.
                return config_cls(
                    r=training_args.lora_rank,
                    lora_alpha=training_args.lora_rank,
                    target_modules=lora_target_modules,
                    lora_dropout=0,
                    task_type="CAUSAL_LM",
                    neighbor_gap=training_args.neighbor_gap,
                    r_scaling=training_args.r_scaling,
                    householder_dim=training_args.householder_dim,
                    rotation_angle=training_args.rotation_angle,
                )
            return config_cls(
                r=training_args.lora_rank,
                target_modules=lora_target_modules,
                lora_alpha=training_args.lora_rank,
                lora_dropout=0,
                task_type="CAUSAL_LM",
            )
        raise NotImplementedError(f"can not find peft type {training_args.peft_type}")

    print(f"peft type: {training_args.peft_type}")
    lora_cluster = {"lora": peft.LoraConfig, "ada-lora": peft.AdaLoraConfig, "vb-lora": peft.VBLoRAConfig, "vera": peft.VeraConfig, "dora": peft.LoraConfig}
    p_tuning_cluster = {"prefix-tuning": peft.PrefixTuningConfig, "prompt-tuning": peft.PromptTuningConfig}
    if training_args.peft_type in lora_cluster:
        config_cls = lora_cluster[training_args.peft_type]
        if training_args.peft_type == "dora":
            return config_cls(
                r=training_args.lora_rank,
                target_modules=lora_target_modules,
                task_type="CAUSAL_LM",
                use_dora=True,
            )
        if training_args.peft_type == "ada-lora":
            # int(): num_train_epochs may be a float, which would make total_step a float.
            # NOTE(review): ignores gradient accumulation and world size -- confirm this estimate is intended.
            return config_cls(
                init_r=training_args.lora_rank,
                target_r=training_args.lora_rank,
                total_step=int(training_args.num_train_epochs * len(train_dataset) // training_args.per_device_train_batch_size),
                target_modules=lora_target_modules,
                task_type="CAUSAL_LM",
            )
        return config_cls(
            r=training_args.lora_rank,
            target_modules=lora_target_modules,
            task_type="CAUSAL_LM",
        )
    if training_args.peft_type in p_tuning_cluster:
        return p_tuning_cluster[training_args.peft_type](
            task_type="CAUSAL_LM",
            inference_mode=False,
            num_virtual_tokens=training_args.p_tuning_token,
            num_layers=model.language_model.config.num_hidden_layers,
            token_dim=model.language_model.config.hidden_size,
            num_attention_heads=model.language_model.config.num_attention_heads,
        )
    if training_args.peft_type == "ia3":
        return peft.IA3Config(
            task_type="CAUSAL_LM",
            target_modules=lora_target_modules,
            feedforward_modules=["mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"],
        )
    raise NotImplementedError(f"can not find peft type {training_args.peft_type}")


def train():
    """Build data and model from the global ``config`` and run training and/or evaluation.

    Steps:
      1. Parse ``config.transformers_args`` into model/data/training argument
         dataclasses and seed all RNGs.
      2. Load the processor and the train/eval datasets.
      3. Derive the run's output directory, using a timestamp shared across
         ranks. Call ``set_dual_output`` on rank 0, then snapshot the arguments,
         source files and config.
      4. Load the base model (optionally a ``HookModel``). Restore and freeze any
         extra modules saved next to ``peft_model_path``. Then wrap the model
         with a new PEFT adapter, load a saved one, or load a fully fine-tuned
         checkpoint.
      5. Run ``trainer.train()`` if ``do_train`` and ``trainer.evaluate()`` if ``do_eval``.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_dict(OmegaConf.to_container(config.transformers_args))
    set_seed(training_args.seed)
    if model_args.peft_model_path is not None:
        model_args.peft_model_path = model_args.peft_model_path.strip()
    processor = LlavaNextProcessor.from_pretrained(model_args.vlm_model_path)
    train_dataset, eval_dataset = build_dataset(data_args, processor).get_dataset()

    # Broadcast rank 0's timestamp so every rank resolves the same output directory.
    if torch.distributed.is_initialized():
        start_timestamp = [datetime.now().strftime("%Y%m%d-%H:%M:%S")]
        torch.distributed.broadcast_object_list(start_timestamp, src=0)
        start_timestamp = start_timestamp[0]
        print("timestamp is unified.")
    else:
        start_timestamp = datetime.now().strftime("%Y%m%d-%H:%M:%S")

    if training_args.do_train:
        my_lora_str = "-my_lora-" if training_args.my_lora else "-"
        # Encode the experiment configuration into the run directory name.
        training_args.output_dir = os.path.join(
            training_args.output_dir,
            training_args.exp_name + "-" + str(data_args.data_dir) + "-" + f"rank{training_args.lora_rank}" + "-" + str(training_args.lora_target_modules).replace("\\", "") + "-" + str(training_args.requires_grad_list) + my_lora_str + start_timestamp,
        )
    elif model_args.peft_model_path is not None:
        # Evaluating a saved checkpoint: place results next to it.
        training_args.output_dir = os.path.join(
            os.path.dirname(model_args.peft_model_path),
            training_args.exp_name + "-" + start_timestamp,
        )
    else:
        training_args.output_dir = os.path.join(
            training_args.output_dir,
            training_args.exp_name + "-" + str(data_args.data_dir) + "-" + start_timestamp,
        )
    os.makedirs(training_args.output_dir, exist_ok=True)
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        set_dual_output(os.path.join(training_args.output_dir, "output.txt"))
    training_args.logging_dir = os.path.join(training_args.output_dir, "log")

    with open(os.path.join(training_args.output_dir, "all_args_setting.json"), "w") as file:
        args_dict = {
            "model_args": vars(model_args),
            "data_args": vars(data_args),
            "training_args": training_args.to_dict(),
        }
        json.dump(args_dict, file)

    # Snapshot the code and config used for this run.
    source_dir = os.path.dirname(__file__)
    snapshot_dir = os.path.join(training_args.output_dir, "training_files")
    os.makedirs(snapshot_dir, exist_ok=True)
    shutil.copy2(os.path.join(source_dir, "train.py"), os.path.join(snapshot_dir, "train.py"))
    shutil.copy2(os.path.join(source_dir, "arguments.py"), os.path.join(snapshot_dir, "arguments.py"))
    shutil.copy2(config_path, os.path.join(snapshot_dir, "config.yaml"))
    tuners_dir = os.path.join(source_dir, "peft/tuners/")
    for filename in os.listdir(tuners_dir):
        if filename.endswith(".py"):
            shutil.copy2(os.path.join(tuners_dir, filename), os.path.join(snapshot_dir, filename))

    if model_args.hook_model == True:
        model = HookModel.from_pretrained(
            model_args.vlm_model_path,
            # torch_dtype=torch.float16,
        )
        hidden_size = model.config.text_config.hidden_size
        vocab_size = model.config.text_config.vocab_size
        model.init_mlp([hidden_size, hidden_size//5, vocab_size])
    else:
        model = LlavaNextForConditionalGeneration.from_pretrained(
            model_args.vlm_model_path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )

    requires_grad_list = training_args.requires_grad_list
    # Restore extra modules saved by my_trainer._save_checkpoint and freeze them.
    # os.path.isfile (rather than os.listdir) does not crash when the path is not a local directory.
    if model_args.peft_model_path is not None:
        saved_list_path = os.path.join(model_args.peft_model_path, "require_grads_list.json")
        if os.path.isfile(saved_list_path):
            with open(saved_list_path, "r") as file:
                load_module_list = json.load(file)
            for module_name in load_module_list:
                module = get_attr(model, module_name)
                module.load_state_dict(
                    torch.load(
                        os.path.join(model_args.peft_model_path, module_name + ".pth"),
                        weights_only=True,
                    )
                )
                module.requires_grad_(False)

    if model_args.quant_type not in [None, -1] and training_args.use_peft is True and training_args.do_train:
        # Imported here because the file-level import is commented out; this branch used to raise NameError.
        from peft import prepare_model_for_kbit_training

        model = prepare_model_for_kbit_training(
            model, use_gradient_checkpointing=training_args.gradient_checkpointing
        )

    if training_args.use_peft is True and model_args.peft_model_path is None and training_args.do_train:
        peft_config = _build_peft_config(training_args, model, train_dataset)
        model = get_peft_model(model, peft_config)
        if training_args.peft_type == "r-lora" and training_args.my_lora:
            from reinit import reinit_lora
            print("Reinit lora layer")
            reinit_lora(model, model_args)
    elif training_args.use_peft is True and model_args.peft_model_path is not None:
        model = PeftModel.from_pretrained(
            model,
            model_args.peft_model_path,
            torch_dtype=model.dtype,
            is_trainable=bool(training_args.do_train),
        )
    elif training_args.use_peft is False and model_args.peft_model_path is not None:
        # Fully fine-tuned checkpoint: replaces the model loaded above.
        model = LlavaNextForConditionalGeneration.from_pretrained(model_args.peft_model_path)
    if requires_grad_list not in [None, []] and training_args.do_train:
        for module_name in requires_grad_list:
            get_attr(model, module_name).requires_grad_(True)

    set_sft_trainable_parameters(model, training_args)
    if training_args.my_lora and training_args.use_peft:
        print(calc_submodule_trainable_parameters(model.base_model.model))
    elif training_args.use_peft and training_args.peft_type in ["prefix-tuning", "prompt-tuning"]:
        print(calc_submodule_trainable_parameters(model.base_model))
    else:
        print(calc_submodule_trainable_parameters(model))

    model.cuda()
    # Cast floating-point parameters to the model's reported dtype
    # (presumably to unify adapter weights created in a different precision).
    model.to(model.dtype)
    trainer = my_trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=processor.tokenizer,
        data_collator=data_collactor(processor),
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        compute_metrics=compute_metrics(processor, training_args.output_dir),
    )
    if training_args.do_train:
        trainer.train()
        peak_allocated = torch.cuda.max_memory_allocated(device=None) / 1024**2
        print(f"peak training GPU memory: {peak_allocated:.2f}MB")
    if training_args.do_eval:
        # Switch to the evaluation collator whose batch keys my_trainer.evaluate reads.
        trainer.data_collator = eval_data_collactor(processor)
        trainer.evaluate()


if __name__ == "__main__":
    # For memory profiling, wrap train() in torch.profiler.profile(
    #     activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True)
    # and print prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10).
    train()
    # Release the distributed process group if one was initialised.
    distributed_active = torch.distributed.is_initialized()
    if distributed_active:
        torch.distributed.destroy_process_group()

