# WANDB_MODE=disabled PYTHONPATH=. srun -p mllm_safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=16 --time=30000 accelerate launch --config_file scripts/accelerate_configs/single_gpu.yaml src/train_internvl.py --model_name_or_path "/mnt/lustrenew/mllm_safety-shared/models/huggingface/OpenGVLab/InternVL3-8B/"
# WANDB_MODE=disabled PYTHONPATH=. srun -p mllm_safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=16 --time=30000 accelerate launch --config_file scripts/accelerate_configs/single_gpu.yaml src/train_internvl.py --model_name_or_path "/mnt/lustrenew/mllm_safety-shared/models/huggingface/OpenGVLab/InternVL3-1B/"
import sys
import json
import pathlib
from dataclasses import dataclass
from copy import deepcopy
import numpy as np
import torch
from transformers import Trainer
import transformers
import trl
import accelerate

sys.path.append('/mnt/lustrenew/mllm_safety-shared/tmp/lingjie/InternVL/internvl_chat/')
from internvl.train.constants import (
    BOX_END_TOKEN, BOX_START_TOKEN, IMG_CONTEXT_TOKEN, IMG_END_TOKEN,
    IMG_START_TOKEN, QUAD_END_TOKEN, QUAD_START_TOKEN, REF_END_TOKEN, REF_START_TOKEN
)
from internvl.train.dataset import preprocess_internvl2_5, build_transform, dynamic_preprocess
from src import utils


# Fill value for padded `labels` positions in concat_pad_data_collator.
# Presumably chosen to match the loss's ignore index (-100 is the default
# `ignore_index` of torch's cross-entropy) so padding adds nothing to the loss — confirm.
IGNORE_INDEX = -100

@dataclass
class ScriptArguments:
    """Data-loading and InternVL image/text preprocessing options for this script.

    Parsed from the command line / config file by trl.TrlParser in the
    ``__main__`` block and read as the global ``script_args`` by ``preprocess``.
    """
    # YAML data config handed to utils.parse_data_config.
    data_config_path: str = "data/animals/config_image.yaml"
    # Overrides applied on top of the data config (comma-separated `key=value`, per the example).
    data_overwrite_args: str = "" # e.g. --data_overwrite_args "data.train[0].images_dirs[0]=/new/path/to/images,..."
    # NOTE(review): not referenced elsewhere in this script (dataset.map is called without num_proc).
    num_proc: int = 8
    # NOTE(review): not referenced elsewhere in this script.
    mask_prompt: bool = False
    # intern specific
    # Image size given to build_transform / dynamic_preprocess; if it differs from the
    # checkpoint's vision image_size, the ViT position embeddings are resized to it.
    force_image_size : int = 448
    # Image tokens per tile = (force_image_size // patch_size) ** 2 * down_sample_ratio ** 2.
    # Forwarded to from_pretrained and asserted to match model.downsample_ratio.
    down_sample_ratio : float = 0.5
    # Whether to pad images to square before the transform.
    # NOTE(review): confirm `preprocess` forwards this to build_transform.
    pad2square : bool = False
    # Conversation template name passed to preprocess_internvl2_5.
    conv_style : str = "internvl2_5"
    # If True, split each image into tiles with dynamic_preprocess.
    dynamic_image_size : bool = False
    # Forwarded to dynamic_preprocess (only used when dynamic_image_size is True).
    use_thumbnail : bool = True
    # Tile-count bounds forwarded to dynamic_preprocess as min_num / max_num.
    min_dynamic_patch : int = 1
    max_dynamic_patch : int = 12
    # Normalization preset name passed to build_transform.
    normalize_type : str = 'imagenet'
    # Forwarded to preprocess_internvl2_5 and copied onto training_args.
    use_packed_ds : bool = False
    # Dataset name forwarded to preprocess_internvl2_5.
    ds_name : str = "v-oocr"
    # Assigned to model.select_layer.
    vision_select_layer : int = -1 # Selected layer for vit


@dataclass
class SFTConfig(trl.SFTConfig):
    """trl.SFTConfig with defaults chosen for this InternVL fine-tuning script.

    Only the default values differ from the parent class; all fields can still
    be overridden through trl.TrlParser in the ``__main__`` block.
    """
    output_dir: str = "models/tmp"
    report_to: str = "wandb"
    overwrite_output_dir: bool = True
    seed: int = 42
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-5
    lr_scheduler_type: str = "cosine"
    bf16: bool = True
    num_train_epochs: float = 20
    logging_steps: float = 1
    eval_strategy: str = "epoch"
    # No model checkpoints by default; the rank-eval callback in __main__ still
    # writes <output_dir>/checkpoint-<step>/eval/log.json on every evaluation.
    save_strategy: str = "no" # "epoch"
    save_only_model: bool = True
    # Evaluate before the first training step as well (triggers the rank-eval callback).
    eval_on_start: bool = True
    # intern specific
    # Also forwarded to preprocess_internvl2_5 by `preprocess`.
    group_by_length: bool = True


def preprocess(data_item):
    """Convert one raw dataset row into InternVL model inputs.

    Builds a two-turn human/gpt conversation from ``data_item['prompt']`` and
    ``data_item['prompt_response']``, makes sure the human turn carries an
    ``<image>`` placeholder, turns ``data_item['image']`` into a stack of
    transformed patches (one per tile), and tokenizes the conversation with
    ``preprocess_internvl2_5``.

    Reads module-level globals assigned in the ``__main__`` block:
    ``script_args``, ``training_args``, ``tokenizer`` and ``patch_size``.

    Args:
        data_item: one dataset row (dict-like) with ``'prompt'``,
            ``'prompt_response'`` and ``'image'`` keys. A ``'conversations'``
            entry is written into it.

    Returns:
        dict with ``input_ids``, ``labels``, ``attention_mask`` and
        ``position_ids`` (first row of the tokenizer output), ``pixel_values``
        (stacked patches, shape ``(num_patches, ...)``) and ``image_flags``
        (a long tensor of ``num_patches`` ones).
    """
    # Remove the first literal `{image_prefix}` marker from the prompt; the
    # target is `prompt_response` with the prompt text (first occurrence) removed.
    conversation = [
        {'from': 'human', 'value': data_item['prompt'].replace(r'{image_prefix}', '', 1)},
        {'from': 'gpt', 'value': data_item['prompt_response'].replace(data_item['prompt'], '', 1).strip()},
    ]
    data_item['conversations'] = conversation

    # Honor the configured pad2square flag (it used to be hard-coded to False,
    # silently ignoring ScriptArguments.pad2square).
    transform = build_transform(
        is_train=True,
        input_size=script_args.force_image_size,
        pad2square=script_args.pad2square,
        normalize_type=script_args.normalize_type,
    )

    # Ensure the human turn has an <image> placeholder; presumably
    # preprocess_internvl2_5 expands it into the image-token span sized below.
    if "<image>" not in data_item['conversations'][0]['value']:
        data_item['conversations'][0]['value'] = '<image>\n' + data_item['conversations'][0]['value']

    image = data_item['image']
    if script_args.dynamic_image_size:
        # Split into a variable number of tiles bounded by min/max_dynamic_patch.
        images = dynamic_preprocess(image, min_num=script_args.min_dynamic_patch, max_num=script_args.max_dynamic_patch,
                                    image_size=script_args.force_image_size, use_thumbnail=script_args.use_thumbnail)
    else:
        images = [image]
    pixel_values = torch.stack([transform(tile) for tile in images])
    num_patches = pixel_values.size(0)

    # Tokens per tile: (image_size // patch_size)^2 ViT patches scaled by down_sample_ratio^2.
    num_image_token = int((script_args.force_image_size // patch_size) ** 2 * (script_args.down_sample_ratio ** 2))

    # Hand the tokenizer helper its own copy of the conversation so the stored
    # `conversations` entry is not altered by it.
    ret = preprocess_internvl2_5(script_args.conv_style, [deepcopy(data_item['conversations'])],
                                 tokenizer, [num_image_token * num_patches],
                                 group_by_length=training_args.group_by_length,
                                 use_packed_ds=script_args.use_packed_ds, ds_name=script_args.ds_name)

    # Positions count attended tokens only (0, 1, 2, ...); masked positions get 1.
    position_ids = ret['attention_mask'].long().cumsum(-1) - 1
    position_ids.masked_fill_(ret['attention_mask'] == 0, 1)
    return dict(
        input_ids=ret['input_ids'][0],
        labels=ret['labels'][0],
        attention_mask=ret['attention_mask'][0],
        position_ids=position_ids[0],
        pixel_values=pixel_values,
        image_flags=torch.tensor([1] * num_patches, dtype=torch.long),
    )

def _right_pad(values, length, fill_value, dtype):
    """Return a 1-D ``dtype`` tensor of size ``length``: ``values`` then ``fill_value``.

    Raises (from the slice assignment) if ``values`` is longer than ``length``.
    """
    padded = torch.full((length,), fill_value, dtype=dtype)
    padded[:values.shape[0]] = values
    return padded


def concat_pad_data_collator(features, max_item_length=None, pad_id=0):
    """Right-pad variable-length samples to a common length and collate a batch.

    Token-level fields are padded to ``max_item_length`` (default: the longest
    ``input_ids`` in the batch) and stacked: ``input_ids`` and ``position_ids``
    with ``pad_id``, ``labels`` with ``IGNORE_INDEX`` and, if present,
    ``loss_weight`` with ``pad_id`` as float. ``pixel_values`` and
    ``image_flags`` are concatenated along dim 0 because each sample may carry a
    different number of image patches. Every other non-None, non-str field of
    the first sample is stacked.

    ``attention_mask`` is the sample's own mask padded with False, or, when the
    sample has none, True over the sample's unpadded length. It is deliberately
    not derived from ``input_ids != pad_id``: that would also mask genuine
    tokens whose id happens to equal ``pad_id`` (id 0 is a normal token in many
    vocabularies).

    Args:
        features: list of per-sample dicts; the dicts themselves are not modified.
        max_item_length: target length; falsy means "longest in batch". A sample
            longer than this raises.
        pad_id: fill value for ``input_ids``, ``position_ids`` and ``loss_weight``.

    Returns:
        dict of batched tensors.
    """
    # Shallow-copy each sample so padding does not mutate the caller's dicts.
    features = [dict(feat) for feat in features]
    first = features[0]
    batch = {}

    max_item_length = max_item_length or max(feat['input_ids'].shape[0] for feat in features)
    for feat in features:
        seq_len = feat['input_ids'].shape[0]
        attention_mask = torch.zeros(max_item_length, dtype=torch.bool)
        if feat.get('attention_mask') is not None:
            given_mask = torch.as_tensor(feat['attention_mask']).bool()
            attention_mask[:given_mask.shape[0]] = given_mask
        else:
            attention_mask[:seq_len] = True

        feat['input_ids'] = _right_pad(feat['input_ids'], max_item_length, pad_id, torch.long)
        feat['labels'] = _right_pad(feat['labels'], max_item_length, IGNORE_INDEX, torch.long)
        feat['attention_mask'] = attention_mask

        if 'position_ids' in feat:
            feat['position_ids'] = _right_pad(feat['position_ids'], max_item_length, pad_id, torch.long)

        if 'loss_weight' in feat:
            feat['loss_weight'] = _right_pad(feat['loss_weight'], max_item_length, pad_id, torch.float)

    # Special handling for labels.
    # Ensure that tensor is created with the correct type
    # (it should be automatically the case, but let's make sure of it.)
    if 'label' in first and first['label'] is not None:
        label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
        dtype = torch.long if isinstance(label, int) else torch.float
        batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
    elif 'label_ids' in first and first['label_ids'] is not None:
        if isinstance(first['label_ids'], torch.Tensor):
            batch['labels'] = torch.stack([f['label_ids'] for f in features])
        else:
            dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
            batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)

    # Handling of all other possible keys.
    # Again, we use the first element to figure out which key/values are not None for this model.
    for k, v in first.items():
        if k in ('pixel_values', 'image_flags'):
            # Patch counts differ per sample, so concatenate along dim 0. as_tensor
            # also accepts numpy arrays (the old numpy path handed torch.concat a
            # single ndarray, which raises TypeError).
            batch[k] = torch.cat([torch.as_tensor(f[k]) for f in features])
        elif k not in ('label', 'label_ids') and v is not None and not isinstance(v, str):
            if isinstance(v, torch.Tensor):
                batch[k] = torch.stack([f[k] for f in features])
            elif isinstance(v, np.ndarray):
                batch[k] = torch.tensor(np.stack([f[k] for f in features]))
            else:
                batch[k] = torch.tensor([f[k] for f in features])
    return batch

if __name__ == "__main__":
    # Parse CLI / config-file arguments into the three argument dataclasses.
    parser = trl.TrlParser((ScriptArguments, SFTConfig, trl.ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    transformers.set_seed(training_args.seed)
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    # Keep all dataset columns: concat_pad_data_collator reads pixel_values and
    # image_flags directly.
    training_args.remove_unused_columns = False
    # NOTE(review): `dataset_kwargs` is a trl.SFTTrainer option, but a plain
    # transformers.Trainer is built below, so this is presumably unused — confirm.
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
    # NOTE(review): not read elsewhere in this script; presumably for InternVL-side code.
    training_args.use_packed_ds = script_args.use_packed_ds
    ################
    # Model, Tokenizer
    ################
    # NOTE: script_args, training_args, tokenizer and patch_size are assigned at
    # module scope in this block and read as globals by `preprocess`.
    
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, add_eos_token=False, trust_remote_code=True
    )
    tokenizer.tokenizer_path = model_args.model_name_or_path
    # NOTE(review): presumably bounds sequence length inside preprocess_internvl2_5 — confirm.
    tokenizer.model_max_length = training_args.max_length
    # Register InternVL's image / quad / ref / box marker tokens; num_new_tokens
    # counts only those the tokenizer did not already have.
    token_list = [IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN,
                  QUAD_START_TOKEN, QUAD_END_TOKEN, REF_START_TOKEN,
                  REF_END_TOKEN, BOX_START_TOKEN, BOX_END_TOKEN]
    num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True)
    img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
    
    # "auto" / None pass through unchanged; any other string names a torch dtype attribute.
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )

    quantization_config = trl.get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=trl.get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        # Forwarded to from_pretrained; the assert after loading checks it took effect.
        downsample_ratio = script_args.down_sample_ratio
    )
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=True, **model_kwargs
    ) 

    # Presumably lets the model find image-context positions in input_ids — confirm in InternVL's forward.
    model.img_context_token_id = img_context_token_id
    assert model.downsample_ratio == script_args.down_sample_ratio
    # model.downsample_ratio = script_args.down_sample_ratio
    patch_size = model.config.vision_config.patch_size
    print(f'model.config.vision_config.image_size: {model.config.vision_config.image_size}')
    # Adapt the ViT position embeddings when training at a resolution other than the checkpoint's.
    if model.config.vision_config.image_size != script_args.force_image_size:
        print(f'Resizing position embedding from '
                    f'{model.config.vision_config.image_size} '
                    f'to {script_args.force_image_size}...')
        model.vision_model.resize_pos_embeddings(old_size=model.config.vision_config.image_size,
                                                 new_size=script_args.force_image_size,
                                                 patch_size=patch_size)
        model.config.vision_config.image_size = script_args.force_image_size
    model.config.force_image_size = script_args.force_image_size
    # Same image-token formula as used in `preprocess`.
    model.num_image_token = int((script_args.force_image_size // patch_size) ** 2 * (script_args.down_sample_ratio ** 2))
    model.select_layer = script_args.vision_select_layer
    
    # Grow the LM vocabulary for newly added tokens and set their output-embedding
    # rows to the mean of the pre-existing rows. NOTE: the matching input-embedding
    # rows keep whatever resize_token_embeddings initialised them to.
    if num_new_tokens > 0:
        model.language_model.resize_token_embeddings(len(tokenizer))
        output_embeddings = model.language_model.get_output_embeddings().weight.data
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings[-num_new_tokens:] = output_embeddings_avg

        model.config.llm_config.vocab_size = len(tokenizer)
        model.language_model.config.vocab_size = len(tokenizer)
    
    model.language_model.config.use_cache = False
    # Vision-encoder gradient checkpointing is forced on, independent of
    # training_args.gradient_checkpointing.
    model.vision_model.gradient_checkpointing = True
    model.vision_model.encoder.gradient_checkpointing = True
    
    ################
    # Dataset: tokenized per row by `preprocess`; batches are padded by concat_pad_data_collator
    ################

    # The local main process runs this first; presumably the other processes then
    # reuse the `datasets` cache it produced.
    with accelerate.PartialState().local_main_process_first():
        data_config = utils.parse_data_config(script_args.data_config_path, script_args.data_overwrite_args)
        train_configs, eval_configs = utils.parse_train_and_eval_config(data_config)
        dataset = utils.get_train_dataset(train_configs)

        # NOTE(review): single-process map; script_args.num_proc is not passed here.
        dataset = dataset.map(preprocess)
        # Expose only the model inputs, as torch tensors.
        dataset.set_format(
            type='torch',
            columns=['input_ids','labels','attention_mask','position_ids','pixel_values','image_flags']
        )

    class RankEvalCallback(transformers.trainer_callback.TrainerCallback):
        """Run rank evaluation after each Trainer evaluation and save the results.

        When "rank" is in data_config["eval_mode"], calls eval_rank for every
        non-None eval config. The main process then writes those results plus the
        latest log entry to <output_dir>/checkpoint-<step>/eval/log.json.
        """

        def on_evaluate(
            self, 
            args: transformers.TrainingArguments, 
            state: transformers.trainer_callback.TrainerState, 
            control: transformers.trainer_callback.TrainerControl, 
            **kwargs
        ):
            # from src.eval_fast import eval_rank
            from src.eval.eval_rank_internvl import eval_rank
            model = kwargs["model"]
            eval_mode = data_config["eval_mode"]
            results = {}
            if "rank" in eval_mode:
                results["rank"] = {}

                # eval_rank runs on every process; only the main process writes results below.
                for eval_idx, eval_config in enumerate(eval_configs):
                    if eval_config is None:
                        continue
                    partial_results = eval_rank(
                        model=model, 
                        tokenizer=tokenizer, 
                        data_config=eval_config,
                        per_device_eval_batch_size=args.per_device_eval_batch_size,
                    )
                    # Prefix each template's result with its eval-config index.
                    for template_key, template_result in partial_results.items():
                        new_key = f"eval-{eval_idx}.{template_key}"
                        results["rank"][new_key] = template_result
        
            if accelerate.PartialState().is_main_process:
                # Ensure latest eval log contains eval_loss
                latest_log = state.log_history[-1]
                assert "eval_loss" in latest_log

                # Prepare result dict
                results["log_history"] = latest_log

                # Construct save path
                checkpoint_dir = f"checkpoint-{int(latest_log['step'])}"
                results_path = pathlib.Path(args.output_dir) / checkpoint_dir / "eval" / "log.json"
                results_path.parent.mkdir(parents=True, exist_ok=True)

                # Save results
                with open(results_path, "w", encoding="utf-8") as f:
                    json.dump(results, f, ensure_ascii=False, indent=4)

    # NOTE: the training dataset doubles as eval_dataset, so eval_loss is measured on
    # training data. The callback is passed as a class; Trainer instantiates it.
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=concat_pad_data_collator,
        train_dataset=dataset,
        eval_dataset=dataset,
        tokenizer=tokenizer,
        callbacks=[RankEvalCallback]
    )

    trainer.train()

    # Record the resolved training arguments next to the outputs (main process only).
    if accelerate.PartialState().is_main_process:
        with open(pathlib.Path(training_args.output_dir) / "training_args.json", "w", encoding="utf-8") as f:
            json.dump(training_args.to_dict(), f, ensure_ascii=False, indent=4)
