# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import os
import json
import copy
import logging
import pathlib
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
from sklearn.metrics import roc_auc_score

import torch
import transformers
import numpy as np
from functools import partial

from llava.constants import IGNORE_INDEX, DEFAULT_POINT_TOKEN, DEFAULT_PT_START_TOKEN, \
    DEFAULT_PT_END_TOKEN
from torch.utils.data import Dataset
from llava.train.llava_trainer import LLaVATrainer

from llava.model import *
from llava import conversation as conversation_lib
from llava.mm_utils import tokenizer_point_token, process_pts, load_pts, occlusion, rotation

from llava.train.urdf_dataset import ReasonSegDataset
from llava.model.language_model.urdf_llm import LISAForCausalLM

# Process rank used by rank0_print; assigned inside train().
local_rank = None


def rank0_print(*args):
    """Print ``args`` only when the module-level ``local_rank`` equals 0.

    ``local_rank`` starts as ``None``, so calls made before ``train()``
    assigns it print nothing.
    """
    if local_rank != 0:
        return
    print(*args)


@dataclass
class ModelArguments:
    """Command-line options describing the model to build.

    Parsed by ``transformers.HfArgumentParser`` inside ``train()``.
    """

    # Checkpoint name or path handed to the tokenizer and model loaders.
    model_name_or_path: Optional[str] = "facebook/opt-125m"
    # Key into ``conversation_lib.conv_templates``; unknown keys fall back
    # to "vicuna_v1" in ``train()``.
    version: Optional[str] = "v0"
    # When set, ``train()`` disables gradients on ``model.model``.
    freeze_backbone: bool = False
    # When set, ``train()`` freezes everything except ``mm_projector``.
    tune_mm_mlp_adapter: bool = False
    vision_tower: Optional[str] = None
    vision_tower_path: Optional[str] = None
    mm_vision_select_layer: Optional[int] = -1  # default to the last layer
    pretrain_mm_mlp_adapter: Optional[str] = None
    mm_projector_type: Optional[str] = 'linear'
    # Wrap the point token in start/end marker tokens during collation.
    mm_use_pt_start_end: bool = False
    mm_use_pt_patch_token: bool = True
    mm_vision_select_feature: Optional[str] = "patch"
    prompt_token_num: int = 1
    with_ape: bool = True
    with_local: bool = True
    with_global: bool = True
    # Copied onto ``data_args`` and ``model.config`` by ``train()``.
    with_color: bool = True


@dataclass
class DataArguments:
    """Command-line options describing the data fed to training.

    Parsed by ``transformers.HfArgumentParser`` inside ``train()``.
    """

    data_path: str = field(
        default=None,
        metadata={"help": "Path to the training data."},
    )
    lazy_preprocess: bool = field(default=False)
    # Forced to True by ``train()`` before the datasets are built.
    is_multimodal: bool = field(default=True)
    # ``train()`` builds an evaluation dataset only when this is truthy.
    eval_data_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the evaluation data."},
    )
    point_folder: Optional[str] = None
    # Copied to ``model.config.sample_points_num`` by ``train()``.
    sample_points_num: int = 4096
    occlusion: bool = False
    # Appended to ``output_dir`` and forwarded to the dataset by ``train()``.
    predict_type: str = "seg"


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Hugging Face training arguments extended with the options ``train()`` reads.

    Several fields re-declare names that also exist on the parent class
    (for example ``optim`` and ``remove_unused_columns``) to change their
    defaults.
    """
    # Forwarded as ``cache_dir`` to the tokenizer and model ``from_pretrained`` calls.
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    remove_unused_columns: bool = field(default=False)
    # When set, ``train()`` disables gradients on the ``mm_projector`` parameters.
    freeze_mm_mlp_adapter: bool = field(default=False)
    # Not read by the code visible in this file.
    mpt_attn_impl: Optional[str] = field(default="triton")
    # Passed to the tokenizer as ``model_max_length``; the collator truncates to it.
    model_max_length: int = field(
        default=512,
        metadata={
            "help":
                "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # The three fields below only take effect when ``bits`` is 4 or 8.
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
    # LoRA settings; ``train()`` feeds r/alpha/dropout/bias into ``LoraConfig``.
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    # Not read by the code visible in this file.
    lora_weight_path: str = ""
    lora_bias: str = "none"
    # Copied to ``model.config.mm_projector_lr`` by ``train()``.
    mm_projector_lr: Optional[float] = None
    group_by_modality_length: bool = field(default=False)
    # NOTE(review): ``auto_resume`` and ``resume`` are not read in this file;
    # ``train()`` resumes whenever a ``checkpoint-*`` directory exists.
    auto_resume:bool=field(default=True)
    resume:str = ""
    # When True, ``train()`` starts a debugpy listener and waits for a client.
    # NOTE(review): the parent ``transformers.TrainingArguments`` presumably also
    # defines a ``debug`` field with a different type -- confirm this bool
    # override is compatible with the installed transformers version.
    debug:bool=field(default=False)

@dataclass
class DataCollatorForLISA(object):
    """Collate dataset instances into one padded training batch.

    Each instance must provide ``"conversations"`` (a list of turns with a
    ``"value"`` string). The returned batch always holds ``input_ids``,
    ``labels`` and ``attention_mask``; ``points``, ``rgb``, ``seg_label``,
    ``logist_label`` and ``json_path`` are added when the first instance
    carries them.
    """

    tokenizer: transformers.PreTrainedTokenizer
    data_args: DataArguments
    model_args: ModelArguments

    @staticmethod
    def _stack_or_list(values):
        """Stack ``values`` into one tensor, or return the list when they cannot be stacked."""
        try:
            return torch.stack(values)
        except (RuntimeError, TypeError, ValueError):
            # Best effort: entries that are not stackable tensors (mismatched
            # shapes, non-tensor items) are passed through as a plain list.
            return values

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, object]:
        """Tokenize, mask and pad ``instances`` and attach the optional tensors.

        Args:
            instances: Dataset items as described in the class docstring.

        Returns:
            The batch dictionary described in the class docstring.
        """
        # Work on copies so the dataset's own conversation dicts stay untouched.
        conversations = [copy.deepcopy(instance["conversations"]) for instance in instances]

        # The replacement token does not depend on the sentence; build it once.
        point_token = DEFAULT_POINT_TOKEN
        if self.model_args.mm_use_pt_start_end:
            point_token = DEFAULT_PT_START_TOKEN + DEFAULT_POINT_TOKEN + DEFAULT_PT_END_TOKEN

        for conversation in conversations:
            for sentence in conversation:
                if DEFAULT_POINT_TOKEN not in sentence['value']:
                    continue
                # Move the point token to the front of the turn, on its own line.
                text = sentence['value'].replace(DEFAULT_POINT_TOKEN, '').strip()
                text = (DEFAULT_POINT_TOKEN + '\n' + text).strip()
                sentence['value'] = text.replace(DEFAULT_POINT_TOKEN, point_token)

        input_ids_list = []
        labels_list = []
        for conversation in conversations:
            has_point = any(DEFAULT_POINT_TOKEN in turn['value'] for turn in conversation)
            # Each conversation is preprocessed on its own and padded below.
            encoded = preprocess_v1(
                [conversation],
                tokenizer=self.tokenizer,
                has_point=has_point,
            )
            input_ids_list.append(encoded['input_ids'][0])
            labels_list.append(encoded['labels'][0])

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids_list,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(
            labels_list,
            batch_first=True,
            padding_value=IGNORE_INDEX)

        max_length = self.tokenizer.model_max_length
        input_ids = input_ids[:, :max_length]
        labels = labels[:, :max_length]

        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        # The first instance decides which optional entries the batch carries.
        first = instances[0]

        # Point clouds and colors must stack; a failure here is a real error.
        for key in ('points', 'rgb'):
            if key in first and first[key] is not None:
                batch[key] = torch.stack([instance[key] for instance in instances])

        # Labels may differ in shape across instances, so fall back to a list.
        for key in ('seg_label', 'logist_label'):
            if key in first and first[key] is not None:
                batch[key] = self._stack_or_list([instance[key] for instance in instances])

        if 'json_path' in first:
            batch['json_path'] = [instance['json_path'] for instance in instances]

        return batch

def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of ``param``.

    A parameter that has a ``ds_id`` attribute is gathered through
    ``deepspeed.zero.GatheredParameters`` before it is copied. Any other
    tensor is copied directly, so DeepSpeed is only imported when it is
    actually needed.

    Args:
        param: Tensor or parameter to copy.
        ignore_status: When False, log a warning if the parameter's
            ``ds_status`` is ``ZeroParamStatus.NOT_AVAILABLE``.
        name: Label used in that warning.

    Returns:
        A detached clone of the parameter data on the CPU.
    """
    if not hasattr(param, "ds_id"):
        # Plain tensor: no DeepSpeed involvement, so do not require it.
        return param.detach().cpu().clone()

    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        # NOTE(review): the message says "!=" while the condition tests "==";
        # the text is kept as-is so existing log searches still match.
        logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
    with zero.GatheredParameters([param]):
        return param.data.detach().cpu().clone()

def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect the LoRA parameters (and optionally biases) as CPU copies.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs.
        bias: ``"none"`` keeps only names containing ``"lora_"``; ``"all"``
            also keeps every name containing ``"bias"``; ``"lora_only"``
            keeps LoRA parameters plus the ``bias`` that shares a prefix
            with one of them.

    Returns:
        Dict mapping each selected name to ``maybe_zero_3(parameter)``.

    Raises:
        NotImplementedError: If ``bias`` is not one of the three modes.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # "<prefix>lora_A..." -> "<prefix>bias"
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only the biases that belong to a module which has LoRA weights.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError(f"Unsupported bias mode: {bias!r}")
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect CPU copies of every parameter whose name lacks ``"lora_"``.

    With ``require_grad_only`` set (the default), parameters whose
    ``requires_grad`` is false are dropped as well.
    """
    non_lora = {}
    for param_name, tensor in named_params:
        if "lora_" in param_name:
            continue
        non_lora[param_name] = tensor

    collected = {}
    for param_name, tensor in non_lora.items():
        if require_grad_only and not tensor.requires_grad:
            continue
        collected[param_name] = maybe_zero_3(tensor, ignore_status=True).cpu()
    return collected


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Collect CPU copies of parameters whose name contains any of ``keys_to_match``."""
    matched = {}
    for param_name, tensor in named_params:
        for fragment in keys_to_match:
            if fragment in param_name:
                matched[param_name] = tensor
                break
    return {
        param_name: maybe_zero_3(tensor, ignore_status=True).cpu()
        for param_name, tensor in matched.items()
    }

def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Save the trainer's model to ``output_dir``.

    Three paths, tried in order:

    * ``tune_mm_mlp_adapter`` set on the trainer args: save the config and
      only the projector weights (plus the embeddings when
      ``use_pt_start_end`` is set).
    * DeepSpeed active: delegate to ``trainer.save_model``.
    * Otherwise: copy the state dict to the CPU and call ``trainer._save``
      when ``should_save`` is true.
    """
    args = trainer.args

    if getattr(args, "tune_mm_mlp_adapter", False):
        keys_to_match = ['mm_projector']
        if getattr(args, "use_pt_start_end", False):
            keys_to_match += ['embed_tokens', 'embed_in']

        weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
        trainer.model.config.save_pretrained(output_dir)

        # Only local rank 0 (or -1) writes the weights file.
        if args.local_rank not in (0, -1):
            return

        current_folder = output_dir.split('/')[-1]
        if current_folder.startswith('checkpoint-'):
            # Intermediate checkpoints go to a sibling "mm_projector" folder.
            mm_projector_folder = os.path.join(os.path.dirname(output_dir), "mm_projector")
            os.makedirs(mm_projector_folder, exist_ok=True)
            destination = os.path.join(mm_projector_folder, f'{current_folder}.bin')
        else:
            destination = os.path.join(output_dir, 'mm_projector.bin')
        torch.save(weight_to_save, destination)
        return

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if not trainer.args.should_save:
        return
    cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
    del state_dict
    trainer._save(output_dir, state_dict=cpu_state_dict)


def smart_tokenizer_and_embedding_resize(
        special_tokens_dict: Dict,
        tokenizer: transformers.PreTrainedTokenizer,
        model: transformers.PreTrainedModel,
):
    """Add special tokens and resize the model's embeddings to match.

    The rows created for the new tokens in both the input and the output
    embedding matrices are set to the mean of the previously existing rows.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens <= 0:
        return

    embedding_tables = (
        model.get_input_embeddings().weight.data,
        model.get_output_embeddings().weight.data,
    )
    for table in embedding_tables:
        # The new tokens occupy the last ``num_new_tokens`` rows.
        table[-num_new_tokens:] = table[:-num_new_tokens].mean(dim=0, keepdim=True)


def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string separately and report the non-pad token counts.

    ``input_ids`` and ``labels`` in the result are the same list object, as
    are ``input_ids_lens`` and ``labels_lens``.
    """
    token_ids = []
    lengths = []
    for text in strings:
        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        token_ids.append(encoded.input_ids[0])
        lengths.append(encoded.input_ids.ne(tokenizer.pad_token_id).sum().item())

    return {
        "input_ids": token_ids,
        "labels": token_ids,
        "input_ids_lens": lengths,
        "labels_lens": lengths,
    }


def _mask_targets(target, tokenized_lens, speakers):
    """Overwrite, in place, the header and the human turns of ``target`` with ``IGNORE_INDEX``.

    ``tokenized_lens[0]`` is the header length; the remaining entries are
    paired with ``speakers``. For a human turn the first two positions of
    the turn are left untouched.
    """
    offset = tokenized_lens[0]
    target[:offset] = IGNORE_INDEX
    for turn_len, speaker in zip(tokenized_lens[1:], speakers):
        turn_end = offset + turn_len
        if speaker == "human":
            target[offset + 2:turn_end] = IGNORE_INDEX
        offset = turn_end


def _add_speaker_and_signal(header, source, get_conversation=True):
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
    conversation = header
    for sentence in source:
        from_str = sentence["from"]
        if from_str.lower() == "human":
            from_str = conversation_lib.default_conversation.roles[0]
        elif from_str.lower() == "gpt":
            from_str = conversation_lib.default_conversation.roles[1]
        else:
            from_str = 'unknown'
        sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
                             sentence["value"] + END_SIGNAL)
        if get_conversation:
            conversation += sentence["value"]
    conversation += BEGIN_SIGNAL
    return conversation


def preprocess_multimodal(
        sources: Sequence[str],
        data_args: DataArguments
) -> Dict:
    """Normalize the point token inside every sentence of ``sources``, in place.

    Returns ``sources`` untouched when ``data_args.is_multimodal`` is false.
    Otherwise a sentence containing the point token gets that token moved to
    the front on its own line, and, when ``data_args.mm_use_pt_start_end``
    is set, every point token is wrapped in the start/end marker tokens.
    """
    if not data_args.is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            if DEFAULT_POINT_TOKEN in sentence['value']:
                stripped = sentence['value'].replace(DEFAULT_POINT_TOKEN, '').strip()
                sentence['value'] = (DEFAULT_POINT_TOKEN + '\n' + stripped).strip()

            if data_args.mm_use_pt_start_end:
                replace_token = DEFAULT_PT_START_TOKEN + DEFAULT_POINT_TOKEN + DEFAULT_PT_END_TOKEN
            else:
                replace_token = DEFAULT_POINT_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_POINT_TOKEN, replace_token)

    return sources


def preprocess_v1(
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        has_point: bool = False
) -> Dict:
    """Tokenize conversations and mask everything except the assistant replies.

    Args:
        sources: List of conversations; each is a list of turns with
            ``"from"`` (``"human"`` or ``"gpt"``) and ``"value"`` keys.
        tokenizer: Tokenizer used for the prompts.
        has_point: When True, prompts are tokenized with
            ``tokenizer_point_token`` instead of the plain tokenizer call.

    Returns:
        Dict with ``input_ids`` and ``labels``; ``labels`` is a copy of
        ``input_ids`` with the masked positions set to ``IGNORE_INDEX``.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Render every conversation through the template into one prompt string.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Drop a leading turn that does not come from the first role.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must strictly alternate between the two roles.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    if has_point:
        # torch.stack requires all prompts to tokenize to the same length.
        input_ids = torch.stack(
            [tokenizer_point_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Text that separates the instruction from the assistant reply in a round.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        # Number of non-padding tokens in this row.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        # Mask the first token (presumably BOS -- confirm for the tokenizer in use).
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, _round in enumerate(rounds):
            if _round == "":
                break

            parts = _round.split(sep)
            if len(parts) != 2:
                # Stop at a round that is not exactly instruction + reply.
                break
            parts[0] += sep

            # NOTE(review): the "- 2" presumably compensates for tokens the
            # tokenizer adds when encoding the fragment alone -- confirm.
            if has_point:
                round_len = len(tokenizer_point_token(_round, tokenizer))
                instruction_len = len(tokenizer_point_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(_round).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Mask the instruction part of this round; the reply stays as the label.
            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        # Mask everything after the last processed round (e.g. padding).
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                # The per-round lengths do not add up to the row length, so
                # the masks cannot be trusted: ignore the whole sample.
                print(sources)
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def preprocess_plain(
        sources: Sequence[str],
        tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize two-turn sources as ``<point token><answer><sep>``.

    The first turn of every source is overwritten in place with the bare
    point token. In the labels, the positions covered by that first turn
    are set to ``IGNORE_INDEX``.
    """
    separator = conversation_lib.default_conversation.sep
    conversations = []
    for source in sources:
        assert len(source) == 2
        assert DEFAULT_POINT_TOKEN in source[0]['value']
        source[0]['value'] = DEFAULT_POINT_TOKEN
        conversations.append(source[0]['value'] + source[1]['value'] + separator)

    input_ids = []
    for prompt in conversations:
        input_ids.append(tokenizer_point_token(prompt, tokenizer, return_tensors='pt'))

    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        prompt_len = len(tokenizer_point_token(source[0]['value'], tokenizer))
        target[:prompt_len] = IGNORE_INDEX

    return {"input_ids": input_ids, "labels": targets}


def preprocess(
        sources: Sequence[str],
        tokenizer: transformers.PreTrainedTokenizer,
        has_point: bool = False
) -> Dict:
    """Dispatch to the preprocessing routine matching the active template.

    A PLAIN separator style uses ``preprocess_plain``; a template whose
    version starts with ``"v1"`` uses ``preprocess_v1``. Any other template
    yields ``None``.
    """
    template = conversation_lib.default_conversation
    if template.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)
    if template.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_point=has_point)
    return None

def KLD(map1, map2, eps = 1e-12):
    """Return the KL divergence of normalized ``map2`` from normalized ``map1``.

    Both maps are divided by their own sum; ``eps`` guards the divisions
    and the logarithm against zeros.
    """
    reference = map1 / (map1.sum() + eps)
    candidate = map2 / (map2.sum() + eps)
    ratio = candidate / (reference + eps)
    return np.sum(candidate * np.log(ratio + eps))
    
def SIM(map1, map2, eps=1e-12):
    """Return the histogram-intersection similarity of two maps.

    Each map is divided by its own sum (guarded by ``eps``); the result is
    the sum of the element-wise minimum of the two normalized maps.
    """
    first = map1 / (map1.sum() + eps)
    second = map2 / (map2.sum() + eps)
    return np.sum(np.minimum(first, second))

def evaluate(seg_pred, seg_gt):
    """Compute batch-averaged AUC, average IoU and SIM for segmentation scores.

    Args:
        seg_pred: Tensor of predicted scores; indexed per sample along dim 0.
        seg_gt: Tensor of ground-truth scores with at least three dimensions
            (its sizes along dims 0 and 2 shape the per-sample result arrays).

    Returns:
        Tuple ``(AUC, IoU, SIM)``. AUC and IoU are NaN-aware means; samples
        whose ground truth has no value >= 0.5 contribute NaN to both.
    """
    pred_np = seg_pred.cpu().detach().numpy()
    gt_np = seg_gt.cpu().detach().numpy()

    num = gt_np.shape[0]
    per_sample_shape = (num, gt_np.shape[2])
    AUC_seg = np.zeros(per_sample_shape)
    IOU_seg = np.zeros(per_sample_shape)
    SIM_matrix = np.zeros(num)

    # IoU is averaged over 20 evenly spaced thresholds in [0, 1].
    IOU_thres = np.linspace(0, 1, 20)

    for b in range(num):
        scores = pred_np[b]
        SIM_matrix[b] = SIM(scores, gt_np[b])

        positives = (gt_np[b] >= 0.5).astype(int)
        if np.sum(positives) == 0:
            # No positive point: AUC and IoU are undefined for this sample.
            AUC_seg[b] = np.nan
            IOU_seg[b] = np.nan
            continue

        try:
            AUC_seg[b] = roc_auc_score(positives, scores)
        except ValueError:
            AUC_seg[b] = np.nan

        ious = []
        for thre in IOU_thres:
            predicted = (scores >= thre).astype(int)
            overlap = np.sum(predicted & positives)
            combined = np.sum(predicted | positives)
            ious.append(1. * overlap / combined)
        IOU_seg[b] = np.mean(np.array(ious))

    return np.nanmean(AUC_seg), np.nanmean(IOU_seg), np.mean(SIM_matrix)

def _find_linear_layers(model, lora_target_modules):
    """Return the sorted names of the ``torch.nn.Linear`` modules LoRA should wrap.

    A module qualifies when its name contains one of ``lora_target_modules``
    and none of the excluded substrings below.
    """
    excluded = ("visual_model", "vision_tower", "mm_projector", "text_hidden_fcs")
    lora_module_names = set()
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        if any(x in name for x in excluded):
            continue
        if any(x in name for x in lora_target_modules):
            lora_module_names.add(name)
    return sorted(lora_module_names)


def train():
    """Parse the command line, build the model and data, train, and save.

    Steps: parse ``ModelArguments``/``DataArguments``/``TrainingArguments``;
    load the tokenizer and ``LISAForCausalLM`` (optionally 4/8-bit); add the
    ``[SEG]`` token; optionally wrap the model with LoRA; initialize the
    vision modules; build the datasets and collator; run ``LLaVATrainer``
    (resuming when a ``checkpoint-*`` directory exists); save the result and
    optionally evaluate.

    Raises:
        ValueError: If ``lora_enable`` is set while ``lora_r`` is not positive.
    """
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Must be assigned before the first rank0_print call; otherwise the
    # module-level value is still None and the message below is dropped.
    local_rank = training_args.local_rank

    if training_args.debug:
        import debugpy
        debugpy.listen(("localhost", 5678))
        rank0_print("Waiting for debugger attach...")
        debugpy.wait_for_client()

    data_args.with_color = model_args.with_color
    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
    # Every predict_type gets its own sub-directory of the output dir.
    training_args.output_dir = os.path.join(training_args.output_dir, data_args.predict_type)
    rank0_print(f"Save model path : {training_args.output_dir}")

    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig
        bnb_model_from_pretrained_args.update(dict(
            device_map={"": training_args.device},
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type  # {'fp4', 'nf4'}
            )
        ))

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

    tokenizer.pad_token = tokenizer.unk_token

    # Register the segmentation token and hand its id to the model.
    tokenizer.add_tokens("[SEG]")
    seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    model_init_kwargs = {
        "seg_token_idx": seg_token_idx,
        "use_mm_start_end": model_args.mm_use_pt_start_end
    }
    bnb_model_from_pretrained_args.update(model_init_kwargs)

    model = LISAForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        torch_dtype=compute_dtype,
        **bnb_model_from_pretrained_args
    )
    model.config.use_cache = False

    if model_args.freeze_backbone:
        model.model.requires_grad_(False)

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        if training_args.lora_r <= 0:
            # Previously this case crashed with a NameError on an undefined helper.
            raise ValueError(
                f"lora_enable requires lora_r > 0, got {training_args.lora_r}"
            )

        target_modules = "q_proj,v_proj"
        lora_target_modules = _find_linear_layers(
            model, target_modules.split(",")
        )

        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        rank0_print("Adding LoRA adapters...")
        model = get_peft_model(model, lora_config)

    if model_args.version in conversation_lib.conv_templates:
        conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]

    model.get_model().initialize_vision_modules(
        model_args=model_args,
        fsdp=training_args.fsdp
    )

    vision_tower = model.get_vision_tower()
    vision_tower.load_model()

    data_args.is_multimodal = True

    model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
    if model_args.tune_mm_mlp_adapter:
        # Train only the projector: freeze everything, then re-enable it.
        model.requires_grad_(False)
        for p in model.get_model().mm_projector.parameters():
            p.requires_grad = True

    model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
    if training_args.freeze_mm_mlp_adapter:
        for p in model.get_model().mm_projector.parameters():
            p.requires_grad = False

    if training_args.bits in [4, 8]:
        model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)

    model.config.mm_use_pt_start_end = data_args.mm_use_pt_start_end = model_args.mm_use_pt_start_end
    model.config.mm_projector_lr = training_args.mm_projector_lr
    training_args.use_pt_start_end = model_args.mm_use_pt_start_end
    model.config.mm_use_pt_patch_token = model_args.mm_use_pt_patch_token
    model.config.with_color = model_args.with_color
    model.config.sample_points_num = data_args.sample_points_num
    model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)

    train_dataset = ReasonSegDataset(
        tokenizer=tokenizer,
        run_type="train",
        predict_type=data_args.predict_type,
    )
    eval_dataset = None
    if data_args.eval_data_path:
        # NOTE(review): the evaluation dataset is built with run_type="train"
        # and eval_data_path only acts as an on/off switch -- confirm intended.
        eval_dataset = ReasonSegDataset(
            tokenizer=tokenizer,
            run_type="train",
            predict_type=data_args.predict_type,
        )
        rank0_print(f"Using {len(eval_dataset)} examples for evaluation.")
    else:
        rank0_print("No evaluation dataset provided.")

    rank0_print(f"Training with {len(train_dataset)} examples.")

    data_collator = DataCollatorForLISA(
        tokenizer=tokenizer,
        data_args=data_args,
        model_args=model_args
    )

    trainer = LLaVATrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    # Resume automatically when the output directory already holds a checkpoint.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    model.config.use_cache = True

    if training_args.lora_enable:
        # Save the LoRA weights and the remaining trainable weights separately.
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), training_args.lora_bias
        )
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
            model.named_parameters()
        )
        if training_args.local_rank == 0 or training_args.local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
    else:
        safe_save_model_for_hf_trainer(trainer=trainer,
                                       output_dir=training_args.output_dir)

    if training_args.do_eval:
        metrics = trainer.evaluate()
        rank0_print("Evaluation metrics:")
        for key, value in metrics.items():
            rank0_print(f"  {key}: {value}")


# Script entry point: run training only when this file is executed directly.
if __name__ == "__main__":
    train()
