"""redo tokenizer before starting v2"""

import logging
import os

import transformers

from datasets import Dataset
from transformers import Trainer, TrainingArguments

_LOG = logging.getLogger("vqt2g_logger")


def make_hf_dataset(tokenized_dataset):
    """Turn token sequences into the column dict consumed by ``Dataset.from_dict``.

    Each sequence is stored unchanged as both ``input_ids`` and ``labels``
    (the very same sequence object appears in both columns). Each sequence also
    gets an all-ones ``attention_mask`` of matching length, so no position is
    masked out.

    Args:
      tokenized_dataset: Iterable of token-id sequences. It is consumed exactly
        once, so a generator works too.

    Returns:
      Dict with the keys ``input_ids``, ``attention_mask`` and ``labels``, in
      that order. Each key maps to its own list with one entry per sequence.
    """
    sequences = list(tokenized_dataset)
    return {
        "input_ids": sequences,
        "attention_mask": [[1] * len(seq) for seq in sequences],
        # Shallow copy: a separate list object holding the same sequences.
        "labels": list(sequences),
    }


def train_transformer(
    tokenized_train_dataset,
    tokenized_test_dataset,
    output_dir,
    vocab_size,
    model_max_length,
    epochs=100,
    batch_size=32,
    eval_steps=250,
    learning_rate=5e-4,
    checkpoint=None,
    model_embedding_size=64,
    model_num_layers=4,
    model_num_heads=4,
    max_checkpoints=10,
):
    """Train a GPT2 style transformer with the Hugging Face ``Trainer``.

    Args:
      tokenized_train_dataset: Iterable of token-id sequences used for
        training. Each sequence serves as both ``input_ids`` and ``labels``.
      tokenized_test_dataset: Iterable of token-id sequences used for
        evaluation, in the same format.
      output_dir: Directory for checkpoints and the final model. Logs go to
        ``<output_dir>/logs``. ``overwrite_output_dir`` is enabled.
      vocab_size: Vocabulary size of a freshly created model. Unused when
        ``checkpoint`` is given.
      model_max_length: Maximum sequence length (``n_positions``/``n_ctx``) of
        a freshly created model. Unused when ``checkpoint`` is given.
      epochs: Number of training epochs. (Default value = 100)
      batch_size: Per-device batch size for both training and evaluation.
        (Default value = 32)
      eval_steps: Evaluate and log every ``eval_steps`` steps. Checkpoints are
        saved every ``5 * eval_steps`` steps. (Default value = 250)
      learning_rate: Learning rate passed to ``TrainingArguments``.
        (Default value = 5e-4)
      checkpoint: Model path/name passed to
        ``GPT2LMHeadModel.from_pretrained``. If None, a new model is built
        from scratch. (Default value = None)
      model_embedding_size: ``n_embd`` of a freshly created model. Unused when
        ``checkpoint`` is given. (Default value = 64)
      model_num_layers: ``n_layer`` of a freshly created model. Unused when
        ``checkpoint`` is given. (Default value = 4)
      model_num_heads: ``n_head`` of a freshly created model. Unused when
        ``checkpoint`` is given. (Default value = 4)
      max_checkpoints: Maximum number of checkpoints kept on disk
        (``save_total_limit``). (Default value = 10)

    Returns:
      None. The final model is written to ``output_dir`` via
      ``trainer.save_model()``.
    """
    import inspect

    # Turn the raw train/test token sequences into Dataset objects.
    train_dataset = Dataset.from_dict(make_hf_dataset(tokenized_train_dataset))
    eval_dataset = Dataset.from_dict(make_hf_dataset(tokenized_test_dataset))

    if checkpoint is None:
        # Train from scratch with a small custom GPT2 config. Token id 0 is
        # used as both BOS and EOS. Dropout of 0.2 is applied to the
        # residual, embedding and attention layers.
        config = transformers.GPT2Config(
            vocab_size=vocab_size,
            n_positions=model_max_length,
            n_ctx=model_max_length,
            n_embd=model_embedding_size,
            n_layer=model_num_layers,
            n_head=model_num_heads,
            bos_token_id=0,
            eos_token_id=0,
            resid_pdrop=0.2,
            embd_pdrop=0.2,
            attn_pdrop=0.2,
        )
        model = transformers.GPT2LMHeadModel(config=config)
    else:
        # Only the model weights are loaded. Optimizer/scheduler state and the
        # global step are NOT restored, because trainer.train() below is
        # called without resume_from_checkpoint.
        model = transformers.GPT2LMHeadModel.from_pretrained(checkpoint)

    # `evaluation_strategy` was renamed to `eval_strategy` in newer
    # transformers releases, and the old name was later removed. Use
    # whichever keyword the installed TrainingArguments accepts.
    accepted_args = inspect.signature(TrainingArguments).parameters
    eval_strategy_key = (
        "eval_strategy" if "eval_strategy" in accepted_args else "evaluation_strategy"
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        logging_dir=os.path.join(output_dir, "logs"),
        num_train_epochs=epochs,
        learning_rate=learning_rate,  # Trainer's own default would be 5e-5
        overwrite_output_dir=True,
        do_train=True,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        save_total_limit=max_checkpoints,
        # load_best_model_at_end=True,
        eval_steps=eval_steps,
        # Checkpoint less often than we evaluate to limit disk usage.
        save_steps=eval_steps * 5,
        max_steps=-1,
        # dataloader_drop_last=True,
        disable_tqdm=False,
        log_level="info",
        logging_steps=eval_steps,
        **{eval_strategy_key: "steps"},
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    _LOG.info("Starting training")
    trainer.train()
    trainer.save_model()
    _LOG.info("Finished training model %s", output_dir)
