from datasets import load_dataset
from transformers import TrainingArguments, AutoConfig, AutoTokenizer
from model.pretrain_model_esm import ProteinTextCLIPForPretrain, ProteinTextCLIPConfig
from model.pretrain_model_esm_seq_only import ProteinTextCLIPForPretrainSequenceOnly
from trainer.pretrain_trainer import CLIPPretrainTrainer
from utils import DataCollatorForProteinTextCLIPPretrain


class PretrainTask:
    """Template for a pretraining job assembled from overridable build hooks.

    The constructor stores ``run_config`` and then invokes, strictly in this
    order, ``build_task_model``, ``build_dataset``, ``build_train_args`` and
    ``build_trainer``, saving each result as ``task_model``, ``dataset``,
    ``train_args`` and ``trainer`` respectively. A later hook may therefore
    read the attributes produced by the earlier ones (e.g. ``build_trainer``
    can use ``self.task_model``). Every hook raises ``NotImplementedError``
    here, so subclasses must override all four.
    """

    def __init__(self, run_config):
        self.run_config = run_config
        # (attribute to populate, hook that produces it) -- order matters,
        # because each hook may depend on the attributes set before it.
        build_order = (
            ("task_model", "build_task_model"),
            ("dataset", "build_dataset"),
            ("train_args", "build_train_args"),
            ("trainer", "build_trainer"),
        )
        for attr_name, hook_name in build_order:
            # Resolve the hook right before calling it, like a direct call would.
            setattr(self, attr_name, getattr(self, hook_name)())

    def build_task_model(self):
        """Return the model to be trained. Must be overridden."""
        raise NotImplementedError

    def build_dataset(self):
        """Return the dataset(s) used for training. Must be overridden."""
        raise NotImplementedError

    def build_train_args(self):
        """Return the training arguments for the trainer. Must be overridden."""
        raise NotImplementedError

    def build_trainer(self):
        """Return a trainer object exposing ``train()``. Must be overridden."""
        raise NotImplementedError

    def run(self):
        """Start training via the trainer built in the constructor."""
        self.trainer.train()


class ProteinTextCLIPPretrainTask(PretrainTask):
    """Protein/text CLIP pretraining task.

    Loads the protein and text configs and (slow) tokenizers named in
    ``run_config``, builds a ``ProteinTextCLIPForPretrain`` model (or
    ``ProteinTextCLIPForPretrainSequenceOnly`` when ``run_config.sequence_only``
    is truthy), loads the training split from
    ``{data_path}/{dataset}/train.json`` and wires everything into a
    ``CLIPPretrainTrainer``.

    Required ``run_config`` attributes: ``protein_model_name``,
    ``text_model_name``, ``data_path``, ``dataset``, ``projection_dim``,
    ``sequence_only``, ``output_path``, ``batch_size``, ``num_epochs``,
    ``weight_decay``, ``fp16``, ``lr``, ``warmup_ratio``, ``deepspeed``,
    ``protein_model_fixed``, ``text_model_fixed``, ``lr_ratio``.

    Optional ``run_config`` attributes (default in parentheses):
    ``mlm_probability`` (0.0), ``logging_steps`` (20),
    ``dataloader_num_workers`` (4), ``report_to`` (["wandb"]).
    """

    def __init__(self, run_config):
        # Everything the build_* hooks read must exist before super().__init__,
        # because the base constructor calls those hooks immediately.
        self.protein_model_config = AutoConfig.from_pretrained(run_config.protein_model_name)
        self.text_model_config = AutoConfig.from_pretrained(run_config.text_model_name)
        self.protein_tokenizer = AutoTokenizer.from_pretrained(run_config.protein_model_name, use_fast=False)
        self.text_tokenizer = AutoTokenizer.from_pretrained(run_config.text_model_name, use_fast=False)
        self.pdb_h5_path = f'{self._dataset_dir(run_config)}/pdb.h5'
        super().__init__(run_config)

    @staticmethod
    def _dataset_dir(run_config):
        """Return ``{data_path}/{dataset}``, the directory of the dataset files."""
        return f'{run_config.data_path}/{run_config.dataset}'

    def build_task_model(self):
        """Build the pretraining model from the loaded protein/text configs.

        NOTE(review): only configs are passed in here; whether pretrained
        encoder weights are loaded depends on the model classes — confirm.
        """
        task_model_config = ProteinTextCLIPConfig(
            protein_model_config=self.protein_model_config,
            text_model_config=self.text_model_config,
            projection_dim=self.run_config.projection_dim,
        )
        if self.run_config.sequence_only:
            return ProteinTextCLIPForPretrainSequenceOnly(task_model_config)
        return ProteinTextCLIPForPretrain(task_model_config)

    def build_dataset(self):
        """Load ``train.json`` from the dataset directory as the 'train' split."""
        return load_dataset("json", data_files={
            'train': f'{self._dataset_dir(self.run_config)}/train.json',
        })

    def build_train_args(self):
        """Build ``TrainingArguments`` from ``run_config``.

        Evaluation is disabled and a checkpoint is saved every epoch. Logging
        frequency, dataloader worker count and reporting integrations default
        to the previously hard-coded values but can be overridden through the
        optional ``logging_steps``, ``dataloader_num_workers`` and
        ``report_to`` attributes of ``run_config``.
        """
        run_config = self.run_config
        return TrainingArguments(
            output_dir=run_config.output_path,
            do_eval=False,
            save_strategy="epoch",
            logging_strategy="steps",
            logging_steps=getattr(run_config, "logging_steps", 20),
            per_device_train_batch_size=run_config.batch_size,
            per_device_eval_batch_size=run_config.batch_size,
            num_train_epochs=run_config.num_epochs,
            weight_decay=run_config.weight_decay,
            fp16=run_config.fp16,
            push_to_hub=False,
            learning_rate=run_config.lr,
            # Previously always ["wandb"]; configurable so runs without a
            # wandb setup can pass e.g. "none".
            report_to=getattr(run_config, "report_to", ["wandb"]),
            warmup_ratio=run_config.warmup_ratio,
            # Keep all dataset columns; the custom collator consumes them.
            remove_unused_columns=False,
            dataloader_num_workers=getattr(run_config, "dataloader_num_workers", 4),
            deepspeed=run_config.deepspeed,
        )

    def build_trainer(self):
        """Create the ``CLIPPretrainTrainer`` over the 'train' split."""
        data_collator = DataCollatorForProteinTextCLIPPretrain(
            self.protein_tokenizer,
            self.text_tokenizer,
            self.pdb_h5_path,
            sequence_only=self.run_config.sequence_only,
            # Optional; a default of 0.0 presumably disables masking — confirm
            # against the collator.
            mlm_probability=getattr(self.run_config, "mlm_probability", 0.0),
        )
        return CLIPPretrainTrainer(
            model=self.task_model,
            args=self.train_args,
            data_collator=data_collator,
            train_dataset=self.dataset["train"],
            protein_model_fixed=self.run_config.protein_model_fixed,
            text_model_fixed=self.run_config.text_model_fixed,
            lr_ratio=self.run_config.lr_ratio,
        )
