"""
Simple training loop; Boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
"""

import math
import logging

from tqdm import tqdm
import numpy as np

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader

logger = logging.getLogger(__name__)

class TrainerConfig:
    """Hyper-parameters and settings consumed by ``Trainer``.

    Class attributes hold the defaults; every keyword argument passed to the
    constructor is set as an attribute of that instance, overriding the default.
    """
    # optimization parameters
    max_epochs = 10
    batch_size = 64
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1 # only applied on matmul weights
    # learning rate decay params: linear warmup followed by cosine decay to 10% of original
    lr_decay = False
    warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
    final_tokens = 260e9 # (at what point we reach 10% of original LR)
    # checkpoint settings
    ckpt_path = None
    num_workers = 0 # for DataLoader
    # intermediate greedy sampling: positions in interm_idxs may be overwritten
    # with the model's own predictions during Trainer.train
    sample_interm = False
    interm_idxs = []

    def __init__(self, **kwargs):
        # The class-level list is a shared mutable default; give each instance
        # its own copy so mutating one config's interm_idxs can't leak into others.
        if isinstance(self.interm_idxs, list):
            self.interm_idxs = list(self.interm_idxs)
        for k, v in kwargs.items():
            setattr(self, k, v)

class Trainer:
    """Generic training loop with optional periodic validation and checkpointing.

    The wrapped model must implement ``configure_optimizers(config)`` returning
    a ``torch.optim`` optimizer, and ``forward(x, y)`` must return a
    ``(logits, loss, attns)`` triple. When intermediate sampling is enabled,
    ``forward`` is also called with ``x`` alone and its first returned element
    is used as the logits.
    """

    def __init__(self, model, config):
        """Store ``model``/``config`` and move the model to GPU when one exists.

        Args:
            model: the network to train.
            config: a ``TrainerConfig``-like object holding hyper-parameters.
        """
        self.model = model
        self.config = config

        # take over whatever gpus are on the system
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)

    def save_checkpoint(self):
        """Write the unwrapped model's ``state_dict`` to ``config.ckpt_path``."""
        # DataParallel wrappers keep raw model object in .module attribute
        raw_model = self.model.module if hasattr(self.model, "module") else self.model
        logger.info("saving %s", self.config.ckpt_path)
        torch.save(raw_model.state_dict(), self.config.ckpt_path)

    def train(self, train_dataset, n_epochs, sample_interm=False,
              val_dataset=None, val_freq=0, restart_lr=False):
        """Train for ``n_epochs`` epochs, optionally validating along the way.

        Args:
            train_dataset: dataset yielding ``(x, y)`` pairs.
            n_epochs: number of passes over ``train_dataset``.
            sample_interm: if True, before each forward pass, for every index
                ``k`` in ``config.interm_idxs`` (in order) the token ``x[:, k]``
                is overwritten in place by the model's greedy (argmax)
                prediction from the prefix ``x[:, :k]``.
            val_dataset: optional dataset to evaluate on.
            val_freq: validation runs after the epochs
                ``ceil(n_epochs * i / val_freq)`` for ``i = 1..val_freq``
                (duplicates collapsed); 0 disables validation.
            restart_lr: reset the token counter that drives the LR schedule.

        Returns:
            ``(train_losses, val_losses)``: the mean training loss of every
            epoch and the mean loss of every validation pass.

        A checkpoint is written after an epoch when ``config.ckpt_path`` is set
        and either no validation has run yet (or there is no val set) or the
        latest validation loss beat the best one seen so far.
        """
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(split, epoch_id=None):
            """Run one pass over the 'train' split (or the validation split) and return its mean loss."""
            is_train = split == 'train'
            model.train(is_train)
            data = train_dataset if is_train else val_dataset
            loader = DataLoader(data, shuffle=True, pin_memory=True,
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

            losses = []
            pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
            for it, (x, y) in pbar:

                # place data on the correct device
                x = x.to(self.device)
                y = y.to(self.device)

                if sample_interm:
                    # Greedy intermediate sampling only produces integer token ids,
                    # so these forward passes never need an autograd graph.
                    with torch.no_grad():
                        for k in config.interm_idxs:
                            logits, _, _ = model(x[:, :k])
                            logits = logits[:, -1, :]
                            probs = torch.nn.functional.softmax(logits, dim=-1)
                            _, ix = torch.topk(probs, k=1, dim=-1)
                            x[:, k] = ix[:, 0]

                # forward the model
                with torch.set_grad_enabled(is_train):
                    logits, loss, attns = model(x, y)
                    loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
                    losses.append(loss.item())

                if is_train:

                    # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                    optimizer.step()

                    # decay the learning rate based on our progress
                    if config.lr_decay:
                        # number of tokens processed this step (labels >= 0, i.e. not -100);
                        # .item() keeps the counter a Python int rather than a device tensor
                        self.tokens += (y >= 0).sum().item()
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
                        else:
                            # cosine learning rate decay, floored at 10% of the base LR
                            progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                            lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate

                    # report progress
                    pbar.set_description(f"epoch {epoch_id+1}/{n_epochs}, iter {it}: train loss {loss.item():.5f}. lr {lr:e}")

            assert len(losses) == np.ceil(len(data) / config.batch_size)
            epoch_loss = np.mean(losses)
            if not is_train:
                logger.info("test loss: %f", epoch_loss)

            return epoch_loss

        best_loss = float('inf')
        if not hasattr(self, 'tokens') or restart_lr:
            self.tokens = 0 # counter used for learning rate decay
        # 1-based epoch numbers after which validation runs (empty when val_freq == 0)
        val_epochs = np.unique([np.ceil(n_epochs * i / val_freq)
                                for i in np.arange(val_freq) + 1])
        train_losses = []
        val_losses = []
        for epoch in range(n_epochs):
            train_losses.append(run_epoch('train', epoch))
            if val_dataset is not None and (epoch + 1) in val_epochs:
                val_losses.append(run_epoch('val'))

            # supports early stopping based on the test loss, or just save always if no test set is provided
            good_model = (val_dataset is None
                          or len(val_losses) == 0
                          or val_losses[-1] < best_loss)
            if self.config.ckpt_path is not None and good_model:
                best_loss = (val_losses[-1] if len(val_losses) > 0
                             else float('inf'))
                self.save_checkpoint()

        return train_losses, val_losses
