
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from InternVL_Trainer import VQADataset, load_jsonl, InternVLConfig
from transformers import Trainer, TrainingArguments, default_data_collator

class DataBuilder:
    """Read a JSONL file once and expose it as a ``VQADataset`` on demand."""

    def __init__(self, path, tokenizer):
        """
        Args:
            path: location of the JSONL file handed to ``load_jsonl``.
            tokenizer: tokenizer forwarded unchanged to ``VQADataset``.
        """
        # Rows are read eagerly here; build() only wraps the already-loaded rows.
        records = load_jsonl(path)
        self.data = records
        self.tokenizer = tokenizer

    def build(self):
        """Construct and return a new ``VQADataset`` over the loaded rows."""
        dataset = VQADataset(self.data, self.tokenizer)
        return dataset

class ModelBuilder:
    """Load the processor, its tokenizer, and the causal-LM model for ``cfg.model_name``."""

    def __init__(self, cfg):
        """
        Args:
            cfg: config object; only ``cfg.model_name`` is read here
                (presumably an ``InternVLConfig`` — confirm against callers).
        """
        from transformers import AutoModelForCausalLM, AutoProcessor
        self.cfg = cfg
        # The processor is loaded with trust_remote_code=True, which implies the
        # checkpoint relies on custom code. Load the model with the same flag so
        # both halves of the checkpoint are resolved consistently; previously
        # only the processor opted in.
        self.processor = AutoProcessor.from_pretrained(cfg.model_name, trust_remote_code=True)
        self.tokenizer = self.processor.tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(
            cfg.model_name,
            device_map="auto",
            trust_remote_code=True,
        )

class ThreePartTrainer:
    """Build the model, train/validation datasets and a HF ``Trainer``, then train."""

    def __init__(self, cfg, train_path, val_path):
        """
        Args:
            cfg: config object read for ``model_name``, ``output_dir``, ``lr``,
                ``batch_size``, ``grad_accum`` and ``epochs``.
            train_path: JSONL file with the training examples.
            val_path: JSONL file with the validation examples.
        """
        self.cfg = cfg
        self.model_builder = ModelBuilder(cfg)
        self.tokenizer = self.model_builder.tokenizer
        train_ds = DataBuilder(train_path, self.tokenizer).build()
        val_ds = DataBuilder(val_path, self.tokenizer).build()
        args = TrainingArguments(
            output_dir=cfg.output_dir,
            learning_rate=cfg.lr,
            per_device_train_batch_size=cfg.batch_size,
            gradient_accumulation_steps=cfg.grad_accum,
            num_train_epochs=cfg.epochs,
            report_to=[],  # disable all logging integrations (wandb, tensorboard, ...)
        )
        # NOTE(review): no evaluation strategy is configured, so eval_dataset is
        # not evaluated during train(); it is used only if evaluate() is called.
        self.trainer = Trainer(
            model=self.model_builder.model,
            args=args,
            train_dataset=train_ds,
            eval_dataset=val_ds,
            # NOTE(review): `tokenizer=` is deprecated in newer transformers in
            # favour of `processing_class=`; kept for compatibility with older versions.
            tokenizer=self.tokenizer,
            data_collator=default_data_collator,
        )

    def train(self):
        """Run training, then save the model, tokenizer and processor to ``cfg.output_dir``."""
        self.trainer.train()
        self.trainer.save_model(self.cfg.output_dir)
        # save_model writes the model and the tokenizer given to Trainer, but not
        # the processor's image-processing config. Without it, reloading the
        # output directory via AutoProcessor would likely fail. Write it once,
        # from the main process only.
        if self.trainer.is_world_process_zero():
            self.model_builder.processor.save_pretrained(self.cfg.output_dir)
