import os
import json
import typing

import numpy as np
import torch
import datasets
from transformers import GPT2LMHeadModel, Trainer, TrainingArguments
from transformers.data.data_collator import DataCollatorForLanguageModeling

from src.tokenizer import TokenizerWrapper

def _load_split(path: str) -> typing.Tuple[list, list]:
    """Read the train/test split file and return ``(train, test)``.

    The file is expected to be a JSON object with ``'train'`` and ``'test'``
    keys; a missing key raises ``KeyError``.
    """
    # Explicit encoding so decoding does not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as fp:
        split = json.load(fp)
    return split['train'], split['test']


def _build_tokenized_dataset(tokenizer, texts):
    """Wrap *texts* in a single-column ``'text'`` dataset and tokenize it.

    The raw ``'text'`` column is dropped after tokenization, so the returned
    dataset only carries whatever fields the tokenizer produces.
    """
    raw_dataset = datasets.Dataset.from_dict({'text': texts})
    return raw_dataset.map(
        lambda batch: tokenizer(batch['text']),
        batched=True,
        remove_columns=['text'],
    )


def main() -> None:
    """Fine-tune the pretrained GPT-2 checkpoint on the experimental split.

    Loads the tokenizer and the pretrained model, tokenizes the train/test
    texts read from ``experimental_train_test_split.json``, runs the Hugging
    Face ``Trainer`` and finally saves both model and tokenizer to the SFT
    checkpoint directory.

    NOTE: the checkpoint and split paths are relative (resolved against the
    current working directory), whereas the tokenizer model is resolved
    against the directory containing this file.
    """
    output_dir = './poge_checkpoints/sft'

    script_dir = os.path.abspath(os.path.dirname(__file__))
    tokenizer = TokenizerWrapper(
        os.path.join(script_dir, 'data/smiles_bpe_tokenizer_543.model')
    )

    model = GPT2LMHeadModel.from_pretrained('poge_checkpoints/pretrain')

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Load experimental data (PolyInfo data) from the precomputed split.
    experimental_train_data, experimental_test_data = _load_split(
        'experimental_train_test_split.json'
    )

    train_tok_dataset = _build_tokenized_dataset(tokenizer, experimental_train_data)
    val_tok_dataset = _build_tokenized_dataset(tokenizer, experimental_test_data)

    # mlm=False: collate for (causal) language modeling rather than masked LM.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=50,
        per_device_train_batch_size=384,
        per_device_eval_batch_size=384,
        # Evaluate and checkpoint once per epoch.
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_dir=f'{output_dir}/logs',
        logging_steps=100,
        learning_rate=5e-5,
        weight_decay=0.01,
        warmup_steps=50,
        # Same value as num_train_epochs.
        save_total_limit=50,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_tok_dataset,
        eval_dataset=val_tok_dataset,
        data_collator=data_collator,
    )

    trainer.train()

    # Persist the final model and its tokenizer side by side.
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)


if __name__ == '__main__':
    main()
