
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from dataclasses import dataclass
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer, AutoProcessor,
                          TrainingArguments, Trainer, default_data_collator)
from torch.utils.data import Dataset
import json

def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Args:
        path: Location of a UTF-8 encoded file holding one JSON document
            per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.
        Lines that are empty or whitespace-only are skipped.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]

class VQADataset(Dataset):
    """Map-style dataset that turns VQA records into causal-LM examples.

    Every record is a dict. ``"question"`` (default ``""``) is wrapped in a
    ``USER: ... / ASSISTANT:`` prompt, and the first entry of ``"answers"``
    becomes the supervised continuation. Prompt positions are masked with
    ``-100`` in ``labels`` so only the answer tokens contribute to the loss.

    Args:
        data: Sequence of record dicts.
        tokenizer: Callable tokenizer returning ``"input_ids"`` and
            ``"attention_mask"`` lists. If it exposes ``eos_token_id``, that
            id is appended to every answer.
        max_len: Upper bound on the total sequence length (prompt + answer).
        max_answer_len: Upper bound on answer tokens, excluding the EOS.
    """

    def __init__(self, data, tokenizer, max_len=2048, max_answer_len=64):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.max_answer_len = max_answer_len

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    @staticmethod
    def _first_answer(answers):
        """Pick the reference answer from a list, a bare string, or nothing."""
        if isinstance(answers, str):
            # Indexing a string would silently keep only its first character.
            return answers
        if not answers:
            # Missing key, None, or an empty list: train on an empty answer.
            return ""
        return answers[0]

    def __getitem__(self, idx):
        """Build one training example.

        Returns:
            A dict of 1-D ``torch.long`` tensors of equal length:
            ``input_ids``, ``attention_mask`` and ``labels``.
        """
        item = self.data[idx]
        question = item.get("question", "")
        answer = self._first_answer(item.get("answers"))

        answer_ids = list(
            self.tokenizer(
                answer,
                truncation=True,
                max_length=self.max_answer_len,
                add_special_tokens=False,
            )["input_ids"]
        )
        # Supervise an end-of-sequence token so the model learns to stop;
        # this also guarantees at least one non-masked label per example.
        eos_id = getattr(self.tokenizer, "eos_token_id", None)
        if eos_id is not None:
            answer_ids.append(eos_id)

        # Reserve room for the answer so prompt + answer stays within
        # max_len (as long as max_len exceeds the answer length).
        prompt_budget = max(self.max_len - len(answer_ids), 1)
        prompt = self.tokenizer(
            f"USER: {question}\nASSISTANT:",
            truncation=True,
            max_length=prompt_budget,
        )
        prompt_ids = list(prompt["input_ids"])

        input_ids = prompt_ids + answer_ids
        attention_mask = list(prompt["attention_mask"]) + [1] * len(answer_ids)
        # -100 is the index ignored by the cross-entropy loss.
        labels = [-100] * len(prompt_ids) + answer_ids

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

@dataclass
class InternVLConfig:
    """Settings consumed by ``InternVLTrainer``.

    The fields are forwarded to model loading and to ``TrainingArguments``.
    """

    # Checkpoint identifier handed to ``from_pretrained``.
    model_name: str = "OpenGVLab/InternVL2_5-8B"
    # Directory for checkpoints and the final saved model.
    output_dir: str = "internvl-output"
    # Optimizer learning rate.
    lr: float = 1e-6
    # Per-device training batch size.
    batch_size: int = 2
    # Number of gradient accumulation steps.
    grad_accum: int = 4
    # Number of training epochs.
    epochs: int = 3
    # Use bfloat16 when True; the model is loaded in float16 otherwise.
    bf16: bool = True

class InternVLTrainer:
    """Fine-tune a causal LM checkpoint on question/answer JSONL data.

    Construction loads the train/validation records, the processor (and its
    tokenizer) and the model. ``train`` runs a Hugging Face ``Trainer`` and
    saves the result to ``cfg.output_dir``.

    Args:
        cfg: Configuration object providing ``model_name``, ``output_dir``,
            ``lr``, ``batch_size``, ``grad_accum``, ``epochs`` and ``bf16``.
        train_path: Path to the training JSONL file.
        val_path: Path to the validation JSONL file.
    """

    def __init__(self, cfg, train_path, val_path):
        self.cfg = cfg
        self.train_data = load_jsonl(train_path)
        self.val_data = load_jsonl(val_path)
        self.processor = AutoProcessor.from_pretrained(cfg.model_name, trust_remote_code=True)
        # NOTE(review): for checkpoints that ship no processor config,
        # AutoProcessor may hand back the tokenizer itself - confirm. Falling
        # back to the object avoids an AttributeError in that case.
        self.tokenizer = getattr(self.processor, "tokenizer", self.processor)
        # trust_remote_code matches the processor load above; the same
        # checkpoint is being loaded, so both need the same permission.
        self.model = AutoModelForCausalLM.from_pretrained(
            cfg.model_name,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if cfg.bf16 else torch.float16,
        )

    def make_datasets(self):
        """Return ``(train_dataset, val_dataset)`` built from the loaded records."""
        return (VQADataset(self.train_data, self.tokenizer),
                VQADataset(self.val_data, self.tokenizer))

    def _collate(self, features):
        """Right-pad variable-length examples into one batch.

        ``input_ids`` are padded with the tokenizer's pad id (falling back to
        its EOS id, then 0), ``attention_mask`` with 0 and ``labels`` with
        -100 so padded positions are ignored by the loss.

        Args:
            features: List of dicts with 1-D ``input_ids``,
                ``attention_mask`` and ``labels`` sequences.

        Returns:
            Dict of ``(batch, longest_length)`` tensors under the same keys.
        """
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)
        if pad_id is None:
            # Padded positions are masked out, so the exact id is harmless.
            pad_id = 0
        fill_values = {"input_ids": pad_id, "attention_mask": 0, "labels": -100}
        return {
            key: torch.nn.utils.rnn.pad_sequence(
                [torch.as_tensor(feature[key]) for feature in features],
                batch_first=True,
                padding_value=fill,
            )
            for key, fill in fill_values.items()
        }

    def train(self):
        """Run training with per-epoch evaluation and checkpointing, then save."""
        train_ds, val_ds = self.make_datasets()
        # NOTE(review): newer transformers releases rename
        # ``evaluation_strategy`` to ``eval_strategy`` and ``Trainer``'s
        # ``tokenizer`` to ``processing_class`` - confirm against the
        # installed version.
        args = TrainingArguments(output_dir=self.cfg.output_dir, learning_rate=self.cfg.lr,
                                 per_device_train_batch_size=self.cfg.batch_size,
                                 gradient_accumulation_steps=self.cfg.grad_accum,
                                 num_train_epochs=self.cfg.epochs, logging_steps=20,
                                 save_strategy="epoch", evaluation_strategy="epoch",
                                 bf16=self.cfg.bf16, report_to=[])
        # The dataset yields sequences of different lengths, which a plain
        # stacking collator cannot batch; pad them to a common length instead.
        trainer = Trainer(model=self.model, args=args, train_dataset=train_ds,
                          eval_dataset=val_ds, tokenizer=self.tokenizer,
                          data_collator=self._collate)
        trainer.train()
        trainer.save_model(self.cfg.output_dir)
