from typing import Callable, Optional

from datasets import Dataset, load_from_disk
from transformers import AutoProcessor, PreTrainedModel, ProcessorMixin, TrainingArguments, Trainer

from src.common.dataset import DataCollator
from src.train_tools.experiment import Experiment
from src.qwen2_5.fa_model import Qwen2_5_VLForConditionalGenerationWithHeatmap


class InjectionExperiment(Experiment):
    """Training experiment for the Qwen2.5-VL heatmap model.

    Builds the model, processor, train/eval splits, ``TrainingArguments``
    and data collator from ``self.cfg``. ``self.cfg`` is not set in this
    class; presumably the ``Experiment`` base class provides it (verify).
    """

    # Set by ``prepare_dataset`` (``train_dataset`` is assigned there too;
    # it may be declared on the ``Experiment`` base — TODO confirm).
    eval_dataset: Optional[Dataset] = None
    # Set by ``prepare_for_training``.
    model: Optional[PreTrainedModel] = None
    processor: Optional[ProcessorMixin] = None
    train_args: Optional[TrainingArguments] = None
    # ``callable`` (the builtin function) is not a type; use ``Callable``.
    data_collator: Optional[Callable] = None

    def prepare_model(self) -> tuple[PreTrainedModel, ProcessorMixin]:
        """Load the heatmap model and its processor.

        Both are loaded from ``self.cfg.model.name``. The model additionally
        receives ``self.cfg.model.kwargs`` as ``from_pretrained`` keyword
        arguments.

        Returns:
            ``(model, processor)``, with the processor's tokenizer configured
            for left padding.
        """
        model = Qwen2_5_VLForConditionalGenerationWithHeatmap.from_pretrained(
            self.cfg.model.name,
            **self.cfg.model.kwargs
        )
        processor = AutoProcessor.from_pretrained(self.cfg.model.name)
        # Left padding, presumably so batched sequences end aligned for
        # generation with this decoder model — confirm intent.
        processor.tokenizer.padding_side = 'left'

        return model, processor

    def prepare_dataset(self, test_size: float = 0.2, seed: int = 42):
        """Load the dataset from disk and split it into train/eval sets.

        Args:
            test_size: Fraction (or absolute count, as accepted by
                ``Dataset.train_test_split``) routed to the eval split.
            seed: Shuffle seed for the split. It defaults to the previously
                hard-coded 42, so existing splits are reproduced.

        Sets ``self.train_dataset`` and ``self.eval_dataset``.
        """
        # Assumes ``self.cfg.dataset.path`` holds a single saved ``Dataset``.
        # A saved ``DatasetDict`` has no ``train_test_split`` — verify.
        dataset = load_from_disk(self.cfg.dataset.path)
        split_dataset = dataset.train_test_split(test_size=test_size, seed=seed)
        self.train_dataset = split_dataset["train"]
        self.eval_dataset = split_dataset["test"]

    def prepare_for_training(self):
        """Populate every component needed to build a trainer.

        Order matters: the processor produced by ``prepare_model`` is passed
        to the ``DataCollator``.
        """
        self.model, self.processor = self.prepare_model()
        self.prepare_dataset()
        self.train_args = TrainingArguments(**self.cfg.trainer)
        self.data_collator = DataCollator(self.processor, **self.cfg.dataset.kwargs)
