import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments
from datasets import load_dataset
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model, PeftModel
import os

# Weights & Biases configuration, supplied through plain environment variables.
# NOTE(review): these are presumably consumed when a wandb run is started by the
# Trainer's reporting integration during training; WANDB__SERVICE_WAIT is
# assumed to be wandb's service start-up timeout in seconds — confirm.
os.environ.update({
    "WANDB__SERVICE_WAIT": "3000",
    "WANDB_PROJECT": "SFT_BioT5",
})

# Module-level Accelerator instance.
# NOTE(review): transformers' Trainer builds its own Accelerator internally, so
# this instance may be redundant — confirm nothing else relies on it before
# removing.
accelerator = Accelerator()


def load_json_dataset(file_path):
    """Load a JSON dataset from disk with the Hugging Face ``datasets`` JSON builder.

    Args:
        file_path: Path (or any value accepted by ``data_files``) of the JSON data.

    Returns:
        Whatever ``load_dataset`` returns for these arguments. Because no
        ``split`` is requested, the records are grouped by split, with a
        single ``data_files`` entry landing under ``'train'``.
    """
    loaded = load_dataset(path='json', data_files=file_path)
    return loaded


def preprocess_data(examples, tokenizer, max_input_length=512, max_target_length=2560):
    """Tokenize a batch of instruction/output pairs for seq2seq training.

    Written for ``Dataset.map(..., batched=True)``: ``examples`` maps column
    names to lists holding one entry per record.

    Args:
        examples: Batch with an ``'instruction'`` column (encoder input text)
            and an ``'output'`` column (target text).
        tokenizer: Tokenizer called on a list of strings; it must expose
            ``pad_token_id`` and return a mapping with ``'input_ids'``.
        max_input_length: Length every input is truncated/padded to.
        max_target_length: Length every target is truncated/padded to.

    Returns:
        The tokenizer's encoding of the instructions, plus a ``'labels'``
        entry: the target token ids with every padding id replaced by -100
        (presumably so padded positions are ignored by the model's loss,
        following the Hugging Face convention).
    """
    pad_id = tokenizer.pad_token_id

    # Encoder side: instructions, padded/truncated to a fixed width.
    encoded = tokenizer(
        examples['instruction'],
        max_length=max_input_length,
        padding='max_length',
        truncation=True,
    )

    # Decoder side: targets, padded/truncated to a fixed width.
    # NOTE(review): padding every target to `max_target_length` (2560 by
    # default) is memory-hungry; dynamic padding via a seq2seq data collator
    # would be cheaper if the target lengths vary a lot.
    target_ids = tokenizer(
        examples['output'],
        max_length=max_target_length,
        padding='max_length',
        truncation=True,
    )['input_ids']

    encoded['labels'] = [
        [-100 if token == pad_id else token for token in row]
        for row in target_ids
    ]
    return encoded


# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
file_path = 'SFT_dataset.json'
model_name = 'QizhiPei/biot5-plus-base-chebi20'
output_dir = './SFT_model'
batch_size = 2  # per device: 2 x (the number of GPUs = 4) = 8 per step
num_epochs = 80


# ---------------------------------------------------------------------------
# Model, tokenizer and data
# ---------------------------------------------------------------------------
# Load the tokenizer from `model_name` rather than repeating the checkpoint
# string, so tokenizer and model cannot drift apart when the checkpoint changes.
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
dataset = load_json_dataset(file_path)
train_split = dataset['train']
tokenized_datasets = train_split.map(
    preprocess_data,
    batched=True,
    num_proc=8,
    # Pass the tokenizer through the documented `fn_kwargs` hook instead of a
    # lambda closing over the module-level global.
    fn_kwargs={'tokenizer': tokenizer},
    # Only the tokenized fields are fed to the model; drop the raw text
    # columns rather than carrying them through the cache and the Trainer.
    remove_columns=train_split.column_names,
)


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
# No evaluation is run: `eval_dataset` is None and the evaluation strategy is
# left at its default ("no"). The old `evaluation_strategy` keyword is not
# passed: it was deprecated (renamed `eval_strategy`) and later removed from
# transformers, where passing it raises a TypeError.
training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=5e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    warmup_ratio=0.05,
    weight_decay=0,
    save_total_limit=3,
    lr_scheduler_type='cosine',
    logging_dir='./logs',
    logging_steps=10,
)

# The Trainer sets up its own Accelerator / distributed state, so it is used
# directly. `Accelerator.prepare()` hands back objects it does not recognise
# (such as a Trainer) unchanged, so wrapping the Trainer in it achieved nothing.
# NOTE(review): newer transformers releases prefer `processing_class=` over
# `tokenizer=`; switch once the pinned version supports it.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    eval_dataset=None,
    tokenizer=tokenizer,
)
trainer.train()

# `save_model` decides internally which process writes, so every rank calls it.
trainer.save_model(output_dir)
# `save_pretrained` is not rank-aware; restrict it to the main process so that
# several ranks do not write the same files concurrently.
if trainer.is_world_process_zero():
    tokenizer.save_pretrained(output_dir)
