#########################################################
############# Script for TCR-BART No Pretrain ###########
#########################################################


from datasets import load_dataset, Dataset
from transformers import BartConfig, BartForConditionalGeneration, Trainer, TrainingArguments, BartTokenizer
from accelerate import Accelerator
import wandb
import os

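### Logging via Weights & Biases: no explicit wandb.init() is made here;
### presumably the Trainer's W&B integration starts the run (see report_to),
### and wandb.finish() at the end of the script closes it.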


# Load source and target files separately. With a single data file and no
# explicit split, `load_dataset('text', ...)` places every line under the
# "train" split -- which is why the validation files are also read from
# ["train"] below.
source_dataset = load_dataset('text', data_files='../data/pmhc_stringent_split/train_source.txt')
target_dataset = load_dataset('text', data_files='../data/pmhc_stringent_split/train_target.txt')

# Load the validation files
val_source = load_dataset('text', data_files='../data/pmhc_stringent_split/val_source.txt')
val_target = load_dataset('text', data_files='../data/pmhc_stringent_split/val_target.txt')

# Source and target rows are paired purely by line position, so each pair of
# files must have the same number of lines. Raise explicitly rather than
# `assert`, which is stripped when Python runs with -O.
if len(source_dataset['train']) != len(target_dataset['train']):
    raise ValueError(
        f"train: source has {len(source_dataset['train'])} lines but target has "
        f"{len(target_dataset['train'])}; the files must be line-aligned."
    )
if len(val_source['train']) != len(val_target['train']):
    raise ValueError(
        f"val: source has {len(val_source['train'])} lines but target has "
        f"{len(val_target['train'])}; the files must be line-aligned."
    )

# Merge source and target into one Dataset with `src_texts` / `tgt_texts`
# columns. Selecting the "text" column yields all values at once instead of
# materialising one dict per row.
dataset = Dataset.from_dict({
    'src_texts': list(source_dataset['train']['text']),
    'tgt_texts': list(target_dataset['train']['text']),
})


val_dataset = Dataset.from_dict({
    'src_texts': list(val_source['train']['text']),
    'tgt_texts': list(val_target['train']['text']),
})

# Build the BART (BPE) tokenizer from the vocab.json / merges.txt files,
# resolved relative to the current working directory, and register the
# bracketed special tokens.
_special_tokens = {
    'bos_token': '[SOS]',
    'eos_token': '[EOS]',
    'sep_token': '[SEP]',
    'cls_token': '[CLS]',
    'unk_token': '[UNK]',
    'pad_token': '[PAD]',
    'mask_token': '[MASK]',
}
tokenizer = BartTokenizer('vocab.json', 'merges.txt', **_special_tokens)

def labeled_tokenize_function(example, src_max_len: int = 52, trg_max_len: int = 24,
                              ignore_pad_token_for_loss: bool = True):
    """Tokenize a batch of source/target strings for seq2seq training.

    Meant for ``Dataset.map(..., batched=True)``: ``example["src_texts"]``
    and ``example["tgt_texts"]`` are lists of strings. Uses the module-level
    ``tokenizer``.

    Each source string is split on its first space into a leading token
    (``pep``) and the remainder (``pseudo``), which are tokenized together as
    a text pair. A source with no space gets an empty second segment.

    Args:
        example: Batch dict with ``"src_texts"`` and ``"tgt_texts"`` lists.
        src_max_len: Fixed length sources are padded/truncated to.
        trg_max_len: Fixed length targets are padded/truncated to.
        ignore_pad_token_for_loss: If True (default), label positions that
            are padding are set to -100 so the cross-entropy loss skips them.

    Returns:
        Dict of torch tensors: ``input_ids`` and ``attention_mask`` with
        shape ``(batch, src_max_len)``, and ``labels`` with shape
        ``(batch, trg_max_len)``.
    """
    # partition splits on the first space only; when there is no space the
    # tail is "", so every source still yields a (pep, pseudo) pair.
    parts = [text.partition(" ") for text in example["src_texts"]]
    pep = [head for head, _, _ in parts]
    pseudo = [tail for _, _, tail in parts]

    # padding="max_length" + truncation + max_length already produce tensors
    # of exactly max_length columns, so no further slicing is needed.
    source_tokens = tokenizer(pep, pseudo, padding="max_length", return_tensors='pt',
                              truncation=True, max_length=src_max_len)
    target_tokens = tokenizer(example["tgt_texts"], padding="max_length", return_tensors='pt',
                              truncation=True, max_length=trg_max_len)

    labels = target_tokens["input_ids"]
    if ignore_pad_token_for_loss:
        # Without this the loss also scores [PAD] positions, training the
        # model to emit padding. -100 is the ignore index of the loss, and
        # BART maps it back to pad_token_id when shifting labels into
        # decoder inputs.
        labels = labels.masked_fill(target_tokens["attention_mask"] == 0, -100)

    return {
        "input_ids": source_tokens["input_ids"],
        "attention_mask": source_tokens["attention_mask"],
        "labels": labels,
    }

# Tokenize both splits with the same function, then have each one return only
# the model inputs as torch tensors (set_format hides the raw text columns).
tokenized_dataset = dataset.map(labeled_tokenize_function, batched=True)
tokenized_val = val_dataset.map(labeled_tokenize_function, batched=True)

for tokenized_split in (tokenized_dataset, tokenized_val):
    tokenized_split.set_format("torch", columns=['input_ids', 'attention_mask', 'labels'])

### Instantiating the Model and Trainer Classes
# The special-token ids and vocab_size are hard-coded here, so they must agree
# with the ids `tokenizer` produces from vocab.json.
# NOTE(review): confirm vocab.json maps [PAD]=0, [SOS]=1, [EOS]=2, [SEP]=4.
config = BartConfig(
    vocab_size=28,
    max_position_embeddings=512,
    d_model=768,
    pad_token_id=0,
    bos_token_id=1,
    eos_token_id=2,
    sep_token_id=4,
    decoder_start_token_id=1,
    encoder_layers=6,
    decoder_layers=6,
    output_hidden_states=True,
    output_scores=True,
    output_attentions=True,
    add_cross_attention=True,
    top_k=3
)

# Fail fast with a clear message, rather than an embedding index error
# mid-training, if the tokenizer can emit ids outside the embedding table.
if len(tokenizer) > config.vocab_size:
    raise ValueError(
        f"Tokenizer has {len(tokenizer)} tokens but BartConfig.vocab_size is "
        f"{config.vocab_size}; token ids would overflow the embedding table."
    )

# Randomly initialised model: no pretrained weights are loaded.
model = BartForConditionalGeneration(config)

# Make the training Args
training_args = TrainingArguments(
    output_dir="../model_checkpoints/no_pretraining_model_25eps",
    overwrite_output_dir=True,
    num_train_epochs=25,
    per_device_train_batch_size=128,
    save_steps=1000,
    do_eval=True,
    # NOTE(review): newer transformers releases renamed this argument to
    # `eval_strategy`; switch if the installed version rejects it.
    evaluation_strategy='steps',
    eval_steps=100,
    learning_rate=3e-04,
    logging_steps=10,
    save_total_limit=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
)

# No separate Accelerator is needed. `Accelerator.prepare` only wraps models,
# optimizers, dataloaders and schedulers, and returns a Trainer unchanged.
# Trainer handles device placement and distributed setup itself; use
# `accelerate launch` or `torchrun` for multi-GPU runs.
trainer.train()
# Write the final weights and tokenizer to output_dir. Depending on the
# transformers version, step checkpoints alone may miss the last steps when
# the total is not a multiple of save_steps.
trainer.save_model()
wandb.finish()
