import fire

from omegaconf import OmegaConf
from transformers import PreTrainedModel, ProcessorMixin, Trainer

from src.train_tools.callbacks import ClearMLCallback, SaveCustomWeightsCallback
from src.train_tools.initiator import init_transformer_block_weights, setup_lora, set_trainable_parameters
from src.injection_experiment import InjectionExperiment


class Training(InjectionExperiment):
    """Experiment variant that initializes the heat modules and then applies LoRA.

    On top of the base ``InjectionExperiment`` model it:
      1. runs ``init_transformer_block_weights`` on ``visual.heat_block`` and
         on ``heat_embedding``;
      2. wraps the model with LoRA using the ``cfg.lora`` section;
      3. passes the result through ``set_trainable_parameters``.
    """

    def prepare_model(self) -> tuple[PreTrainedModel, ProcessorMixin]:
        """Return the (model, processor) pair prepared for a full training run."""
        base_model, base_processor = super().prepare_model()

        # The heat-specific modules go through the project initializer before
        # LoRA wrapping; presumably they are newly added and not covered by
        # the pretrained checkpoint -- TODO confirm.
        init_transformer_block_weights(base_model.visual.heat_block)
        init_transformer_block_weights(base_model.heat_embedding)

        # OmegaConf node -> plain Python container, as expected by setup_lora.
        lora_settings = OmegaConf.to_object(self.cfg.lora)
        adapted_model = set_trainable_parameters(setup_lora(base_model, lora_settings))

        return adapted_model, base_processor

    def get_callbacks(self):
        """Return the trainer callbacks: the ClearML callback for ``self.task``
        followed by ``SaveCustomWeightsCallback`` (presumably persisting the
        custom, non-LoRA weights -- confirm against its implementation)."""
        callbacks = [ClearMLCallback(self.task)]
        callbacks.append(SaveCustomWeightsCallback())
        return callbacks

class Tuning(InjectionExperiment):
    """Experiment variant that only wraps the base model with LoRA.

    Unlike ``Training``, it neither re-initializes the heat modules nor calls
    ``set_trainable_parameters``; the model from ``InjectionExperiment`` is
    passed to ``setup_lora`` as-is.
    """

    def prepare_model(self) -> tuple[PreTrainedModel, ProcessorMixin]:
        """Return the (model, processor) pair with LoRA applied from ``cfg.lora``."""
        base_model, base_processor = super().prepare_model()
        lora_settings = OmegaConf.to_object(self.cfg.lora)
        return setup_lora(base_model, lora_settings), base_processor

    def get_callbacks(self):
        """Return the trainer callbacks: only the ClearML callback for ``self.task``."""
        clearml_callback = ClearMLCallback(self.task)
        return [clearml_callback]




def is_tuning_run(config, tune=None):
    """Decide whether a run should use the ``Tuning`` experiment instead of ``Training``.

    Args:
        config: The configuration value given to ``main``.
        tune: Explicit override. ``True``/``False`` select tuning/training
            directly; ``None`` (the default) falls back to the legacy
            heuristic ``"tune" in config``.

    Returns:
        ``True`` for a tuning run, ``False`` for a full training run.
    """
    if tune is not None:
        return bool(tune)
    # Legacy heuristic, kept for backward compatibility. For a string (e.g. a
    # config path) this is a substring match anywhere in it, so "finetune" or
    # a parent directory containing "tune" also selects tuning -- pass
    # ``tune`` explicitly when that is not what you want.
    return "tune" in config


def main(config, tune=None):
    """Build the selected experiment and run Hugging Face ``Trainer.train``.

    Args:
        config: Experiment configuration; passed unchanged to the experiment
            constructor.
        tune: Optional explicit mode override. On the command line, fire
            exposes it as ``--tune`` / ``--notune``. When omitted, the mode
            is inferred from ``config`` exactly as before.
    """
    experiment_cls = Tuning if is_tuning_run(config, tune) else Training
    experiment = experiment_cls(config)
    experiment.prepare_for_training()
    experiment.task_init()

    trainer = Trainer(
        model=experiment.model,
        processing_class=experiment.processor,
        train_dataset=experiment.train_dataset,
        eval_dataset=experiment.eval_dataset,
        data_collator=experiment.data_collator,
        args=experiment.train_args,
        callbacks=experiment.get_callbacks(),
    )

    trainer.train()


if __name__ == "__main__":
    # Expose `main` as the command-line entry point through python-fire.
    fire.Fire(component=main)
