import numpy as np
import torch
import os

from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm


def write_logs(writer, logs, global_step):
    """Forward every entry of ``logs`` to ``writer`` at ``global_step``.

    A key whose first ``/``-separated component is exactly ``img`` is written
    with ``writer.add_image``. Every other key is written with
    ``writer.add_scalar``. Entries are emitted in the dict's iteration order.
    """
    for tag, value in logs.items():
        prefix = tag.split('/')[0]
        log_fn = writer.add_image if prefix == 'img' else writer.add_scalar
        log_fn(tag, value, global_step)


def _run_validation(model, val_dataloader, epoch):
    """Run one full validation pass and return ``model.validation_epoch_end(outputs)``.

    Gradient tracking is disabled and the model is put in eval mode for the
    pass. This assumes ``model`` is a ``torch.nn.Module``. The model's previous
    train/eval mode is restored afterwards, even if validation raises.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            val_outputs = []
            for val_batch in tqdm(val_dataloader):
                val_batch['epoch'] = epoch
                val_outputs.append(model.validation_step(val_batch))
            return model.validation_epoch_end(val_outputs)
    finally:
        model.train(was_training)


def train(model, model_cfg, **trainer_cfg):
    """Train ``model`` with a minimal Lightning-style loop, validating and checkpointing.

    ``model`` must provide ``prepare_data``, ``train_dataloader``,
    ``val_dataloader``, ``configure_optimizers``, ``training_step``,
    ``validation_step``, ``validation_epoch_end``, ``state_dict`` and
    ``load_state_dict``. ``model_cfg`` is stored verbatim in every checkpoint.

    Required ``trainer_cfg`` keys: ``resume``, ``model_dir``, ``log_dir``,
    ``epochs``, ``updates_per_epoch``, ``updates_per_logger``,
    ``epochs_per_save`` and ``save_best_model``.

    An "epoch" is ``updates_per_epoch`` optimizer steps, independent of the
    dataloader's length. At the end of each epoch:

    - the latest state is written to ``model_dir/checkpoint.pt``;
    - validation runs and is logged at step ``epoch``;
    - every ``epochs_per_save`` epochs a copy is written to ``checkpoint_{epoch}.pt``;
    - if ``save_best_model`` is set, the state with the lowest ``avg_val_loss``
      so far is written to ``checkpoint_best.pt``.

    Training stops as soon as ``epochs`` epochs have completed.

    Raises:
        ValueError: if ``train_dataloader`` yields no batches. Otherwise the
            loop could never complete an epoch.
    """
    model_dir = trainer_cfg['model_dir']
    updates_per_epoch = trainer_cfg['updates_per_epoch']
    max_epochs = trainer_cfg['epochs']

    def save_checkpoint(checkpoint, filename):
        # Write ``checkpoint`` to ``model_dir/filename``.
        torch.save(checkpoint, os.path.join(model_dir, filename))

    model.prepare_data()
    train_dataloader = model.train_dataloader()
    val_dataloader = model.val_dataloader()
    optimizer = model.configure_optimizers()

    if trainer_cfg['resume']:
        # map_location='cpu' allows resuming a checkpoint written on another
        # device. load_state_dict copies tensors onto the existing parameters'
        # devices.
        checkpoint = torch.load(os.path.join(model_dir, 'checkpoint.pt'), map_location='cpu')
        epoch = checkpoint['epoch']
        n_update = checkpoint['n_update']
        best_val_loss = checkpoint['best_val_loss']
        optimizer.load_state_dict(checkpoint['optimizer'])
        model.load_state_dict(checkpoint['model'])
    else:
        # Use inf rather than a finite sentinel so the first validation loss
        # always counts as the best.
        epoch, n_update, best_val_loss = 0, 0, float('inf')

    os.makedirs(model_dir, exist_ok=True)
    writer = SummaryWriter(trainer_cfg['log_dir'])
    pbar = None
    try:
        while epoch < max_epochs:
            n_batches = 0
            for batch in train_dataloader:
                n_batches += 1
                # Create the bar lazily so no stray bar is left open after the
                # final epoch.
                if pbar is None:
                    pbar = tqdm(total=updates_per_epoch, desc='Epoch %d' % epoch, ncols=80)

                batch['epoch'] = epoch
                output = model.training_step(batch)

                optimizer.zero_grad()
                output['loss'].backward()
                optimizer.step()

                if n_update % trainer_cfg['updates_per_logger'] == 0:
                    write_logs(writer, output['log'], n_update)
                n_update += 1
                pbar.update()

                if n_update % updates_per_epoch != 0:
                    continue

                # ---- end of epoch ----
                pbar.close()
                pbar = None
                epoch += 1
                checkpoint = {
                    'epoch': epoch,
                    'n_update': n_update,
                    'best_val_loss': best_val_loss,
                    'optimizer': optimizer.state_dict(),
                    'model': model.state_dict(),
                    'model_cfg': model_cfg,
                }
                # Save before validating so a crash during validation does not
                # lose this epoch's training.
                save_checkpoint(checkpoint, 'checkpoint.pt')

                val_output = _run_validation(model, val_dataloader, epoch)
                write_logs(writer, val_output['log'], epoch)

                is_best = (trainer_cfg['save_best_model']
                           and val_output['avg_val_loss'] < best_val_loss)
                if is_best:
                    best_val_loss = val_output['avg_val_loss']
                    # Record the new best in the saved state so a resumed run
                    # compares against it rather than a stale value.
                    checkpoint['best_val_loss'] = best_val_loss
                    save_checkpoint(checkpoint, 'checkpoint.pt')
                if epoch % trainer_cfg['epochs_per_save'] == 0:
                    save_checkpoint(checkpoint, f'checkpoint_{epoch}.pt')
                if is_best:
                    save_checkpoint(checkpoint, 'checkpoint_best.pt')

                # Stop mid-pass once the target epoch count is reached;
                # otherwise the dataloader pass would keep training past it.
                if epoch >= max_epochs:
                    break

            if n_batches == 0:
                raise ValueError('train_dataloader yielded no batches')
    finally:
        if pbar is not None:
            pbar.close()
        # Flush pending TensorBoard events.
        writer.close()
