import argparse
import itertools
import os
import sys
import time
from types import SimpleNamespace

import horovod.torch as hvd
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.utils.data.distributed as torch_dist
from torchvision.utils import save_image

from code import arch_v2
from code.data import get_dataset
from code.utils import Logger, seed_everything, EMA


def get_wn_params(model):
    """Collect the weight-norm scale parameters of ``model``.

    A parameter is selected when its name, as reported by
    ``model.named_parameters()``, ends with ``'weight_g'``.

    Returns:
        A list with the matching parameters, in ``named_parameters()`` order.
    """
    scales = []
    for name, param in model.named_parameters():
        if name.endswith('weight_g'):
            scales.append(param)
    return scales


def _load_datasets(args):
    """Build the training and evaluation datasets for ``args.task``.

    Args:
        args: Namespace providing ``task`` and ``data_root``.

    Returns:
        Tuple ``(dataset_tr, dataset_val)`` as returned by ``get_dataset``.

    Raises:
        ValueError: If ``args.task`` matches none of the known task prefixes.
    """
    task = args.task
    if task.startswith('imagenet'):
        dataset_tr = get_dataset(task, split='train', data_root=args.data_root)
        dataset_val = get_dataset(task, split='valid', data_root=args.data_root)
    elif task.startswith('oord_imagenet'):
        dataset_tr = get_dataset(task, split='train', data_root=args.data_root, mmap=True)
        dataset_val = get_dataset(task, split='valid', data_root=args.data_root)
    elif task.startswith('cifar10'):
        dataset_tr = get_dataset('cifar10', split='train', data_root=args.data_root)
        dataset_val = get_dataset('cifar10', split='test', data_root=args.data_root)
    elif task.startswith('celebahq'):
        # Drop any '_<suffix>' (e.g. '_5bit'); without an underscore this is the task itself.
        dataset_name = task.split('_')[0]
        dataset_tr = get_dataset(dataset_name, split='train', data_root=args.data_root)
        dataset_val = get_dataset(dataset_name, split='valid', data_root=args.data_root)
    elif task == 'mnist':
        dataset_tr = get_dataset(task, split='train', data_root=args.data_root)
        dataset_val = get_dataset(task, split='test', data_root=args.data_root)
    else:
        raise ValueError(f'Invalid task {task}')
    return dataset_tr, dataset_val


def _dequantize(x, n_bits):
    """Reduce ``x`` to ``n_bits`` bits, add uniform noise and rescale.

    Args:
        x: Float tensor of pixel values (assumed to be 8-bit values in
            [0, 255] -- TODO confirm against the dataset implementation).
        n_bits: Target number of bits per value.

    Returns:
        Tensor ``(x + u) / 2 ** n_bits`` with ``u`` drawn by ``torch.rand_like``;
        an assertion checks the result lies in [0, 1].
    """
    if n_bits < 8:
        x = torch.floor(x / 2 ** (8 - n_bits))
    x = (x + torch.rand_like(x)) / (2 ** n_bits)
    assert (x.min() >= 0).all() and (x.max() <= 1).all()
    return x


def main(args):
    """Train a RealNVP model with Horovod, validating with EMA weights each epoch.

    Args:
        args: Parsed command-line namespace. Attributes read here: ``task``,
            ``seed``, ``ckpt``, ``root_dir``, ``num_workers``, ``data_root``,
            ``data_limit`` and ``ema_decay``.

    Raises:
        FileNotFoundError: If ``args.ckpt`` is given but is not an existing file.
        KeyError: If ``args.task`` has no entry in the model table.
        ValueError: If ``args.task`` has no matching dataset.
        SystemExit: If no checkpoint is given and ``args.root_dir`` is non-empty.
    """
    # Horovod setup
    hvd.init()
    torch.cuda.set_device(hvd.local_rank())

    # Create model
    model_cls = {
        'mnist': arch_v2.RealNVP_MNIST,
        'cifar10': arch_v2.RealNVP_CIFAR10,
        'cifar10_5bit': arch_v2.RealNVP_CIFAR10_5bit,
        'celebahq64_5bit': arch_v2.RealNVP_CelebAHQ64_5bit,
    }[args.task]
    if hvd.rank() == 0:
        seed_everything(args.seed)
    model = model_cls()
    hp = model.hp

    # Setup output directory and load model
    dd = None
    if args.ckpt is not None:
        if not os.path.isfile(args.ckpt):
            raise FileNotFoundError(f'Checkpoint file not found: {args.ckpt}')
        # Map the checkpoint onto this process's own GPU; without map_location
        # every rank would deserialize onto the device the checkpoint was saved from.
        dd = torch.load(args.ckpt, map_location=f'cuda:{hvd.local_rank()}')
        start_epoch = dd['epoch'] + 1
        total_steps = dd['total_steps']
        if hvd.rank() == 0:
            print(f'Resuming training from ckpt: {args.ckpt}')
            # Only rank 0 loads the weights; they are broadcast to the others below.
            model.load_state_dict(dd['model_state_dict'])
    else:
        dir_is_fresh = not os.path.exists(args.root_dir) or len(os.listdir(args.root_dir)) == 0
        if not dir_is_fresh:
            if hvd.rank() == 0:
                print('Specified directory not empty! Terminating.')
            # Every rank must stop here: a rank that kept going would have no
            # start_epoch/total_steps and would wait on collectives with a dead root.
            sys.exit(0)
        start_epoch = 1
        total_steps = 0
        if hvd.rank() == 0:
            print(f'Training in directory {args.root_dir}')
            os.makedirs(args.root_dir, exist_ok=True)
    if hvd.rank() == 0:
        print(hp)
    model.cuda()
    wn_params = get_wn_params(model)

    # Load data
    loader_kwargs = {'num_workers': args.num_workers, 'pin_memory': True}
    # When supported, use 'forkserver' to spawn dataloader workers instead of 'fork' to prevent
    # issues with Infiniband implementations that are not fork-safe
    if (loader_kwargs.get('num_workers', 0) > 0 and hasattr(mp, '_supports_context') and
            mp._supports_context and 'forkserver' in mp.get_all_start_methods()):
        loader_kwargs['multiprocessing_context'] = 'forkserver'
        if hvd.rank() == 0:
            print(f'Using multiprocessing context: {loader_kwargs["multiprocessing_context"]}')

    dataset_tr, dataset_val = _load_datasets(args)

    if args.data_limit > 0:
        dataset_tr = dataset_tr[:args.data_limit]
        dataset_val = dataset_val[:args.data_limit]

    batch_size = hp.full_batch_size // hvd.size()
    sampler_tr = torch_dist.DistributedSampler(dataset_tr, num_replicas=hvd.size(), rank=hvd.rank(), shuffle=True)
    loader_tr = torch.utils.data.DataLoader(dataset_tr, batch_size=batch_size, sampler=sampler_tr, **loader_kwargs)
    sampler_val = torch_dist.DistributedSampler(dataset_val, num_replicas=hvd.size(), rank=hvd.rank(), shuffle=False)
    loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size, sampler=sampler_val, **loader_kwargs)

    # Create optimizer
    current_lr = hp.learning_rate
    optimizer = torch.optim.Adam(model.parameters(), lr=current_lr)
    optimizer = hvd.DistributedOptimizer(optimizer,
                                         named_parameters=list(model.named_parameters()))
    if args.ckpt is not None:
        # Restore the learning rate on every rank: the LR-decay code below writes
        # this local value into the optimizer, so ranks must agree on it.
        current_lr = dd['learning_rate']
        if hvd.rank() == 0:
            print('Loading optimizer weights from ckpt...')
            optimizer.load_state_dict(dd['optimizer_state_dict'])

    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Setup EMA
    ema = EMA(model, args.ema_decay)
    if args.ckpt is not None:
        if hvd.rank() == 0:
            print('Loading EMA weights from ckpt...')
        ema.load_state_dict(dd['ema_state_dict'])

    # Misc
    # Fixed latent noise, reused for every sample dump; each rank holds its share of 64.
    sample_noise = model.sample_prior(64 // hvd.size(), temp=0.7).cuda()

    @torch.no_grad()
    def dump_samples(epoch, step):
        """Generate samples from the fixed noise on all ranks; rank 0 saves the grid."""
        model.eval()
        samples = model(sample_noise, inverse=True)[0].clamp(0, 1)
        # allgather is a collective, so every rank must take part in it.
        samples = hvd.allgather(samples)
        # Only rank 0 writes: all ranks hold the same gathered tensor, and only
        # rank 0 created root_dir.
        if hvd.rank() == 0:
            save_image(samples.cpu(),
                       os.path.join(args.root_dir,
                                    f'samples_ep={epoch:03d}_step={step:07d}.png'),
                       nrow=8, pad_value=1, range=(0, 1))

    # Print some info
    params_grad = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params_all = sum(p.numel() for p in model.parameters())
    if hvd.rank() == 0:
        print(f'  >>> Trainable/total params: {params_grad} / {params_all}')
        print(f'  >>> Horovod local_size / size = {hvd.local_size()} / {hvd.size()}')
        print(f'  >>> Per-GPU / Total batch size = {batch_size} / {hp.full_batch_size}')
        print('Starting training!\n')

    start_time = time.time()
    stats = SimpleNamespace(
        hparams                 = vars(hp),
        args                    = vars(args),

        loss                    = [],
        nll                     = [],
        bpd                     = [],
        wn_l2                   = [],
        grad_norm               = [],

        steps_per_sec           = [],
        total_time              = [],
        epoch                   = [],
        learning_rate           = [],

        val_nll                 = [],
        val_bpd                 = [],
        val_steps_per_sec       = [],
    )

    if hvd.rank() == 0:
        logger = Logger(args.root_dir)
        hp.dump(os.path.join(args.root_dir, 'hparams.json'))

    local_steps = 0
    # Constant log-determinant term for rescaling by 2 ** n_bits, summed over all dimensions.
    quantization_logdet = -np.prod(hp.image_shape) * np.log(2 ** hp.n_bits)
    # A non-positive max_epoch means "train until interrupted".
    if hp.max_epoch > 0:
        epochs = range(start_epoch, hp.max_epoch + 1)
    else:
        epochs = itertools.count(start_epoch)

    for epoch in epochs:
        model.train()
        sampler_tr.set_epoch(epoch)

        # LR decay
        if epoch in hp.learning_rate_decay:
            current_lr *= 0.5
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr
            if hvd.rank() == 0:
                print(f'Reduced learning rate by half at epoch {epoch}')

        for batch_idx, x in enumerate(loader_tr):
            total_steps += 1
            local_steps += 1

            x = x.float().to('cuda', non_blocking=True)

            if hvd.rank() == 0 and total_steps == 1:
                save_image(x / 255., os.path.join(args.root_dir, 'train_batch.png'), nrow=8, range=(0, 1))

            # Adjust number of bits and dequantize
            x = _dequantize(x, hp.n_bits)

            nll = -(model.log_prob(x)[0] + quantization_logdet).mean()

            # Weight norm regularization
            wn_l2 = 0.
            for p in wn_params:
                wn_l2 += p.norm()
            wn_l2 = hp.l2_coeff * wn_l2
            loss = nll + wn_l2

            optimizer.zero_grad()
            loss.backward()
            # Finish the gradient allreduce first so clipping sees the averaged
            # gradients, then step without synchronizing a second time.
            optimizer.synchronize()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.clip_grad_norm)
            with optimizer.skip_synchronize():
                optimizer.step()

            # EMA
            ema(model, total_steps)

            # Monitoring
            total_time = time.time() - start_time
            nll_value = nll.item()
            bpd = nll_value / np.log(2) / np.prod(hp.image_shape)
            stats.loss.append(loss.item())
            stats.nll.append(nll_value)
            stats.bpd.append(float(bpd))
            # float() also covers the plain-float case when there are no weight_g params.
            stats.wn_l2.append(float(wn_l2))
            # Store a Python float so the stats list (and checkpoints) hold no tensors.
            stats.grad_norm.append(float(grad_norm))
            stats.steps_per_sec.append(local_steps / total_time)
            stats.total_time.append(total_time)
            stats.epoch.append(epoch)
            stats.learning_rate.append(current_lr)

            if hvd.rank() == 0 and (total_steps % hp.print_freq == 0 or batch_idx == len(loader_tr) - 1):
                print(f'\rep {epoch:03d} step {batch_idx+1:06d}/{len(loader_tr):06d} '
                      f'total_steps {total_steps:07d} ',
                      f'loss {stats.loss[-1]:.3f} ',
                      f'nll {stats.nll[-1]:.3f} ',
                      f'bpd {stats.bpd[-1]:.3f} ',
                      f'wn_l2 {stats.wn_l2[-1]:.3f} ',
                      f'grad_norm {stats.grad_norm[-1]:.3f} ',
                      f'time {stats.total_time[-1]:.2f} sec ',
                      f'steps/sec {stats.steps_per_sec[-1]:.2f} ', end='', flush=True)

            if hvd.rank() == 0 and total_steps % hp.log_freq == 0:
                logger.log_scalars({
                    'train/loss': stats.loss[-1],
                    'train/nll': stats.nll[-1],
                    'train/bpd': stats.bpd[-1],
                    'train/wn_l2': stats.wn_l2[-1],
                    'train/grad_norm': stats.grad_norm[-1],
                    'misc/epoch': stats.epoch[-1],
                    'misc/learning_rate': stats.learning_rate[-1],
                    'misc/total_time': stats.total_time[-1],
                    'misc/steps_per_sec': stats.steps_per_sec[-1],
                }, total_steps)

        if hvd.rank() == 0:
            # Terminate the '\r'-style progress line.
            print()

        if hvd.rank() == 0 and epoch % hp.ckpt_freq == 0:
            dump_dict = {
                'stats': vars(stats),
                'hparams': vars(hp),
                'args': vars(args),
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'ema_state_dict': ema.state_dict(),

                'epoch': epoch,
                'total_steps': total_steps,
                'learning_rate': current_lr,
            }
            torch.save(dump_dict, os.path.join(args.root_dir, f'ckpt_ep={epoch:03d}_step={total_steps:06d}.pt'))
            print(f'[CHECKPOINT] Saved model at epoch {epoch:03d} total_steps {total_steps:06d}', flush=True)

        ##### VALIDATION #####
        total_nll = 0.0
        total_count = 0
        model.eval()

        # Evaluate (and sample) with the EMA weights; restored after validation.
        ema.assign(model)

        if epoch % hp.sample_freq == 0:
            dump_samples(epoch, total_steps)

        val_start_time = time.time()
        with torch.no_grad():
            for batch_idx, x in enumerate(loader_val):
                if hvd.rank() == 0 and batch_idx == 0:
                    save_image(x / 255., os.path.join(args.root_dir, 'test_batch.png'), nrow=8, range=(0, 1))

                x = x.float().to('cuda', non_blocking=True)
                x = _dequantize(x, hp.n_bits)

                nll = -(model.log_prob(x)[0] + quantization_logdet)
                # Gather per-example NLLs from all ranks so each rank computes the global mean.
                nll = hvd.allgather(nll)
                total_nll += nll.sum().item()
                total_count += len(nll)
            assert total_count == len(dataset_val)
            stats.val_nll.append(total_nll / total_count)
            stats.val_bpd.append(stats.val_nll[-1] / np.log(2) / np.prod(hp.image_shape))
            stats.val_steps_per_sec.append(len(loader_val) / (time.time()-val_start_time))

        ema.restore(model)

        if hvd.rank() == 0:
            print(f'[VALIDATION] ep {epoch:03d} nll {stats.val_nll[-1]:.4f} '
                  f'bpd {stats.val_bpd[-1]:.4f} total_cnt {total_count} '
                  f'steps/sec {stats.val_steps_per_sec[-1]:.2f}', flush=True)
            logger.log_scalars({
                'val_ema/nll': stats.val_nll[-1],
                'val_ema/bpd': stats.val_bpd[-1],
                'val_ema/steps_per_sec': stats.val_steps_per_sec[-1], }, total_steps)


if __name__ == '__main__':
    # Command-line interface: a positional task name, a mandatory output
    # directory, and a handful of optional settings with defaults.
    parser = argparse.ArgumentParser()
    parser.add_argument('task', type=str)
    parser.add_argument('--root_dir', type=str, required=True)

    # (flag, type, default) for every optional setting, in help-listing order.
    optional_settings = (
        ('--ckpt', str, None),
        ('--data_root', str, None),
        ('--ema_decay', float, 0.9995),
        ('--data_limit', int, -1),
        ('--seed', int, 1234),
        ('--num_workers', int, 8),
    )
    for flag, value_type, default_value in optional_settings:
        parser.add_argument(flag, type=value_type, default=default_value)

    args = parser.parse_args()
    main(args)

