import os
import json
import typing

import numpy as np
import torch
import datasets
from transformers import GPT2LMHeadModel, GPT2Config, Trainer, TrainingArguments
from transformers.data.data_collator import DataCollatorForLanguageModeling

from src.tokenizer import TokenizerWrapper

if __name__ == '__main__':
    # Pretrain a small GPT-2 causal language model on polymer SMILES strings:
    # synthetic corpora plus an oversampled 90% split of experimental data.
    # Checkpoints, the final model and the tokenizer go to ./poge_checkpoints/pretrain.
    tokenizer = TokenizerWrapper(
        os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data/smiles_bpe_tokenizer_543.model')
    )

    # NOTE(review): vocab_size (546) differs from the 543 in the tokenizer model's
    # filename; presumably the extra ids are special tokens added by
    # TokenizerWrapper — confirm against src.tokenizer.
    model = GPT2LMHeadModel(
        config=GPT2Config(
            vocab_size=546,
            n_embd=256,
            n_layer=6,
            n_head=8,
            n_positions=256,  # maximum sequence length the model supports
            bos_token_id=tokenizer.tokenizer.bos_id(),
            eos_token_id=tokenizer.tokenizer.eos_id(),
        )
    )

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Load synthetic data (PI1M and PolyTAO corpora)
    synthetic_data = []  # TODO: placeholder for synthetic data

    # Load experimental data (PolyInfo data)
    experimental_data = []  # TODO: placeholder for experimental data

    # Hold out 10% of the experimental records. The global numpy RNG is unseeded,
    # so the split changes between runs; it is persisted below for reuse.
    all_experimental_idxs = np.arange(len(experimental_data))
    experimental_train_idxs = np.random.choice(
        all_experimental_idxs,
        int(0.9 * len(experimental_data)),
        replace=False,
    )
    # Test indices are every index not drawn for training (sorted ascending).
    experimental_test_idxs = np.setdiff1d(all_experimental_idxs, experimental_train_idxs)

    # Sanity checks: the two splits partition the experimental data exactly.
    assert experimental_train_idxs.shape[0] + experimental_test_idxs.shape[0] == len(experimental_data)
    assert np.intersect1d(experimental_train_idxs, experimental_test_idxs).shape[0] == 0

    # Persist the experimental split so later evaluation can use the same held-out
    # records. NOTE(review): this path is relative to the current working
    # directory, unlike the tokenizer path, which is relative to this script.
    with open('experimental_train_test_split.json', 'w', encoding='utf-8') as fp:
        json.dump(
            {
                'train': [experimental_data[idx] for idx in experimental_train_idxs],
                'test': [experimental_data[idx] for idx in experimental_test_idxs],
            },
            fp,  # json.dump requires the target file object as its second argument
        )

    # Training corpus: all synthetic data plus the experimental training split
    # repeated 5x to upweight it relative to the synthetic corpora.
    dataset = datasets.Dataset.from_dict(
        {
            'text': synthetic_data + [experimental_data[idx] for idx in experimental_train_idxs] * 5
        }
    )

    # Tokenize in batches and drop the raw text column, keeping only the tokenizer outputs.
    tokenized_data = dataset.map(lambda x: tokenizer(x['text']), batched=True, remove_columns=["text"])

    # NOTE(review): this split happens after the 5x oversampling above, so copies of
    # the same experimental record can appear in both train and eval, which
    # likely inflates eval metrics. The split is also unseeded.
    split_dataset = tokenized_data.train_test_split(test_size=0.1)

    # mlm=False selects causal (next-token) language-modeling labels.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    # Evaluate and checkpoint once per epoch; with 20 epochs and
    # save_total_limit=20, every per-epoch checkpoint is retained.
    training_args = TrainingArguments(
        output_dir='./poge_checkpoints/pretrain',
        overwrite_output_dir=True,
        num_train_epochs=20,
        per_device_train_batch_size=384,
        per_device_eval_batch_size=384,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_dir='./poge_checkpoints/pretrain/logs',
        logging_steps=100,
        learning_rate=5e-5,
        weight_decay=0.01,
        warmup_steps=500,
        save_total_limit=20,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=split_dataset["train"],
        eval_dataset=split_dataset["test"],
        data_collator=data_collator,
    )

    trainer.train()

    # Save the final model and tokenizer side by side so they can be reloaded together.
    trainer.save_model('./poge_checkpoints/pretrain')
    tokenizer.save_pretrained('./poge_checkpoints/pretrain')
