import time
import os
import gc
import sys
from datetime import datetime
import argparse
import torch.multiprocessing as mp
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.distributed as dist

from torch.optim import lr_scheduler
from torchvision import datasets, transforms, utils
from ncpn.utils import *
from ncpn.model import NCPN
from ncpn.sde_lib import *
from PIL import Image

def sample(model, obs, T, time_cond, t_scale, test=False, eps=EPS, c_dim=None, nr_logistic_mix=10):
    """Autoregressively draw a batch of 25 images from ``model``, pixel by pixel.

    The batch is laid out as 5 consecutive groups of 5 samples; every sample
    in a group shares one time value, the groups being spaced linearly from
    ``eps`` to ``T``. All tensors are created on the current CUDA device.

    Args:
        model: Network called as ``model(x, t * t_scale, c=c)`` when
            ``time_cond`` is true and as ``model(x, c=c)`` otherwise. It is
            switched to eval mode and is NOT switched back.
        obs: Indexable ``(channels, height, width)`` of one image.
        T: Largest time value used for the batch.
        time_cond: Whether the model receives the time input.
        t_scale: Multiplier applied to ``t`` before it is passed to the model.
        test: If true, stop after the first pixel row has been sampled.
        eps: Smallest time value used for the batch.
        c_dim: Number of classes; when given, sample ``k`` is conditioned on
            the one-hot vector of class ``k % c_dim``.
        nr_logistic_mix: Second argument forwarded to
            ``sample_from_discretized_mix_logistic`` (presumably the number of
            mixture components; it should match the model's setting).

    Returns:
        Tensor of shape ``(25, obs[0], obs[1], obs[2])`` with the samples.
    """
    grid_side = 5
    n_samples = grid_side * grid_side
    model.eval()

    if c_dim is not None:
        c = nn.functional.one_hot(torch.arange(n_samples) % c_dim, num_classes=c_dim).cuda().float()
    else:
        c = None

    if time_cond:
        def model_fn(x, t):
            return model(x, t * t_scale, c=c)
    else:
        def model_fn(x, t):
            return model(x, c=c)

    data = torch.zeros(n_samples, obs[0], obs[1], obs[2]).cuda()
    # One time value per group of `grid_side` consecutive samples.
    t = torch.linspace(eps, T, grid_side)[:, None].tile(1, grid_side).reshape(-1).cuda()
    with torch.no_grad():
        for i in range(obs[1]):
            for j in range(obs[2]):
                # A full forward pass per pixel; only location (i, j) of the
                # freshly drawn sample is kept.
                out = model_fn(data, t)
                out_sample = sample_from_discretized_mix_logistic(out, nr_logistic_mix)
                data[:, :, i, j] = out_sample[:, :, i, j]

            if test:
                break
    return data

def get_score_fn(model, loss_op, create_graph=True):
    """Build a function returning the gradient of the model log-density w.r.t. its input.

    Args:
        model: Callable invoked as ``model(data, t)``; its output is passed to
            ``loss_op`` as the second argument.
        loss_op: Callable ``loss_op(data, logits)`` whose summed result is
            treated as the negative log-probability.
        create_graph: If true the returned gradient stays attached to the
            autograd graph so that it can itself be differentiated.

    Returns:
        ``score_fn(data, t)`` which returns a tensor shaped like ``data``.
        ``data`` is marked as requiring grad in place.
    """
    def score_fn(data, t):
        data.requires_grad_(True)
        logits = model(data, t)
        logp = -loss_op(data, logits).sum()
        # Use autograd.grad in both modes: ``backward()`` would accumulate into
        # ``data.grad`` (corrupting the result on repeated calls with the same
        # tensor), write into the model parameters' ``.grad`` and break for
        # non-leaf inputs.
        grad = torch.autograd.grad(logp, data, create_graph=create_graph)[0]

        if not grad.isfinite().all():
            print("score is nan!")

        return grad

    return score_fn

def get_loss_fn(sde_cls, T, bs_min, bs_max, time_cond, t_scale, discrete, class_cond, eps=EPS):
    """Build the training / evaluation loss for a model applied to SDE-perturbed data.

    Args:
        sde_cls: SDE class, instantiated as ``sde_cls(bs_min, bs_max)``; its
            ``marginal_prob(x, t)`` must return ``(mean, std)``.
        T: Upper bound of the uniformly sampled perturbation time.
        bs_min: First constructor argument of ``sde_cls``.
        bs_max: Second constructor argument of ``sde_cls``.
        time_cond: Whether the (scaled) time is passed to the model.
        t_scale: Multiplier applied to ``t`` before it is passed to the model.
        discrete: Forwarded to ``discretized_mix_logistic_loss`` and selects
            how the per-sample loss is reduced.
        class_cond: Currently unused; kept for interface compatibility.
        eps: Lower bound of the sampled time, and the fixed time used when
            ``likelihood_only`` is set.

    Returns:
        ``loss_fn(x, model, likelihood_only=False, c=None)`` returning a scalar
        tensor (intended as bits per dimension — confirm against
        ``discretized_mix_logistic_loss`` / ``ll_to_bpd``).
    """
    def loss_op(real, fake):
        return discretized_mix_logistic_loss(real, fake, per_sample=True, discrete=discrete)

    sde = sde_cls(bs_min, bs_max)

    def loss_fn(x, model, likelihood_only=False, c=None):
        # Evaluate at the smallest time only, or at a random time in [eps, T).
        if likelihood_only:
            t = torch.zeros((len(x),), device=x.device) + eps
        else:
            t = torch.rand(x.shape[0], device=x.device) * (T - eps) + eps

        noise = torch.randn_like(x)
        mean, std = sde.marginal_prob(x, t)
        # std holds one value per sample; broadcast it over the 3 image axes.
        perturbed_data = mean + std[:, None, None, None] * noise
        cond_t = t * t_scale if time_cond else None

        logits = model(perturbed_data, cond1=cond_t, c=c)

        if discrete:
            loss = loss_op(perturbed_data, logits).mean() / (perturbed_data.shape[1] * np.log(2))
        else:
            loss = ll_to_bpd(-loss_op(perturbed_data, logits).sum(dim=(1, 2))).mean()

        return loss

    return loss_fn

def cleanup():
    """Tear down the default process group if one was initialised.

    Safe to call in non-distributed runs: ``destroy_process_group`` raises
    when no group exists, so the call is skipped in that case.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def main():
    """Parse the command line and spawn one ``train`` process per visible GPU.

    Sets ``MASTER_ADDR`` / ``MASTER_PORT`` for the process-group rendezvous,
    derives ``args.world_size`` and ``args.model_name`` and then blocks in
    ``mp.spawn`` until all workers finish. Exits with status 0 on Ctrl-C.
    """
    parser = argparse.ArgumentParser()

    # multiprocessing arguments
    parser.add_argument('-po', '--port', default='9040',
                        type=str)
    parser.add_argument('-di', '--distributed', action='store_true',
                    help='Distributed?')

    # data I/O
    parser.add_argument('-i', '--data_dir', type=str,
                        default='/data/datasets', help='Location for the dataset')
    parser.add_argument('-o', '--save_dir', type=str, default='/data/image_model/',
                        help='Location for parameter checkpoints and samples')
    parser.add_argument('-d', '--dataset', type=str,
                        default='cifar', help='Can be cifar|celeba64|imagenet32|imagenet64|gaussian')
    parser.add_argument('-p', '--print_every', type=int, default=200,
                        help='how many iterations between print statements')
    parser.add_argument('-si', '--save_interval', type=int, default=10,
                        help='Every how many epochs to write checkpoint/samples?')
    parser.add_argument('-r', '--load_params', type=str, default=None,
                        help='Restore training from previous model checkpoint?')

    # model
    # NOTE: the three "no_*" switches below are store_false flags, i.e. the
    # corresponding feature is ON unless the flag is given.
    parser.add_argument('-na', '--no_attention', action='store_false', dest='attention',
                    help='Do we drop the attention module?')
    parser.add_argument('-ds', '--no_discrete', action='store_false', dest='discrete',
                    help='Is the model discrete (or continuous)?')
    parser.add_argument('-tc', '--no_time_cond', action='store_false', dest='time_cond',
                    help='Is the model time conditional?')
    parser.add_argument('-cc', '--class_cond', action='store_true',
                    help='Is the model class conditional?')
    parser.add_argument('-es', '--extra_string', type=str, default='',
                        help='Extra string to disambiguate training models')
    parser.add_argument('-sc', '--sde_cls', type=str, default='vpsde',
                        help='SDE class')
    parser.add_argument('-M', '--bs_max', type=float, default=20.,
                        help='beta_max or sigma_max parameter of subVPSDE / VESDE')
    parser.add_argument('-m', '--bs_min', type=float, default=.0001,
                        help='beta_min or sigma_min parameter of subVPSDE / VESDE')
    parser.add_argument('-ts', '--t_scale', type=float,
                        default=999., help='Multiplicative scale of t input to model')
    parser.add_argument('-T', '--T', type=float,
                        default=.025, help='denoise time')
    parser.add_argument('-q', '--nr_resnet', type=int, default=5,
                        help='Number of residual blocks per stage of the model')
    parser.add_argument('-nf', '--nr_filters', type=int, default=160,
                        help='Number of filters to use across the model. Higher = larger model.')
    parser.add_argument('-nm', '--nr_logistic_mix', type=int, default=10,
                        help='Number of logistic components in the mixture. Higher = more flexible model')
    parser.add_argument('-nr', '--nr_resolutions', type=int, default=3,
                        help='Number of downscaling / upscaling resolutions in UNet')
    parser.add_argument('-dr', '--dropout', type=float,
                        default=0.0, help='Dropout')
    parser.add_argument('-l', '--lr', type=float,
                        default=2e-4, help='Base learning rate')
    parser.add_argument('-e', '--lr_decay', type=float, default=0.999998,
                        help='Learning rate decay, applied every step of the optimization')
    parser.add_argument('-em', '--ema_decay', type=float, default=0.99995,
                        help='Decay rate of the exponential moving average of the model parameters')
    parser.add_argument('-b', '--batch_size', type=int, default=10,
                        help='Batch size during training per GPU')
    parser.add_argument('-x', '--max_epochs', type=int,
                        default=5000, help='How many epochs to run in total?')
    parser.add_argument('-s', '--seed', type=int, default=1,
                        help='Random seed to use')
    args = parser.parse_args()

    print("Distributed?", args.distributed)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = args.port

    # One worker per visible device. Fall back to the CUDA device count when
    # CUDA_VISIBLE_DEVICES is not exported instead of failing with KeyError.
    # NOTE(review): a worker is spawned per device even without --distributed;
    # confirm that this is intended.
    visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible_devices is not None:
        args.world_size = len(visible_devices.split(','))
    else:
        args.world_size = max(torch.cuda.device_count(), 1)

    es = '_{}'.format(args.extra_string) if args.extra_string else ''
    args.model_name = '{}_{}_{}_cc{}_lr{}_t{}_T{}_bsM{}_{}_d{}_d{}_a{}{}'.format(
        args.dataset, args.nr_resolutions, args.nr_resnet, args.class_cond, args.lr, args.t_scale, args.T, args.bs_max,
        args.sde_cls, args.discrete, args.dropout, args.attention, es)

    print(args)

    try:
        mp.spawn(train, nprocs=args.world_size, args=(args,))
    except KeyboardInterrupt:
        print("Caught KeyboardInterrupt. Killing processes...")
        # The parent never initialises a process group itself; destroying a
        # non-existent group would raise and mask the clean exit below.
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()
        print("Processes killed. Ending master process.")
        sys.exit(0)

def _build_datasets(args, rescaling):
    """Create the train / test datasets selected by ``args.dataset``.

    Returns:
        ``(obs, c_dim, train_dataset, test_dataset)`` where ``obs`` is the
        ``(channels, height, width)`` tuple of one image and ``c_dim`` is the
        number of classes (``None`` when the model is not class conditional).

    Raises:
        ValueError: for an unknown dataset name, or when class conditioning is
            requested for a dataset for which no class count is defined here.
    """
    name = args.dataset
    if args.class_cond and name != 'cifar':
        raise ValueError(
            "class conditioning is only configured for the 'cifar' dataset, got {!r}".format(name))

    c_dim = None
    if name == 'cifar':
        obs = (3, 32, 32)
        c_dim = 10 if args.class_cond else None

        train_dataset = datasets.CIFAR10(
            args.data_dir,
            train=True,
            download=True,
            transform=transforms.Compose([transforms.ToTensor(), transforms.RandomHorizontalFlip(p=0.5), rescaling])
        )
        test_dataset = datasets.CIFAR10(
            args.data_dir,
            train=False,
            download=True,
            transform=transforms.Compose([transforms.ToTensor(), rescaling])
        )
    elif name == 'celeba64':
        obs = (3, 64, 64)
        train_dataset = datasets.CelebA(
            args.data_dir,
            split='train',
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.CenterCrop(140),
                transforms.Resize(64),
                rescaling])
        )
        test_dataset = datasets.CelebA(
            args.data_dir,
            split='test',
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.CenterCrop(140),
                transforms.Resize(64),
                rescaling])
        )
    elif name in ('imagenet32', 'imagenet64'):
        size = 32 if name == 'imagenet32' else 64
        obs = (3, size, size)
        root = '/data/datasets/ImageNet/'
        train_dataset = ImageNet(
            root,
            True,
            transforms.Compose([
                transforms.Resize(size),
                transforms.RandomHorizontalFlip(),
                rescaling,
            ]))

        test_dataset = ImageNet(
            root,
            False,
            transforms.Compose([
                transforms.Resize(size),
                rescaling,
            ]))
    elif name == 'gaussian':
        obs = (3, 8, 8)
        loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(
                args.data_dir,
                train=True,
                download=True,
                transform=transforms.Compose([transforms.ToTensor(), rescaling, transforms.Resize(8)])
            ), batch_size=1, shuffle=False)
        # The first (downscaled) CIFAR image is passed as ``loc``.
        x, _ = next(iter(loader))
        train_dataset = GaussianDataset(obs, n=50000, loc=x.numpy())
        test_dataset = GaussianDataset(obs, n=10000, loc=x.numpy())
    else:
        raise ValueError("unknown dataset {!r}".format(name))

    return obs, c_dim, train_dataset, test_dataset


def _resolve_sde_cls(name):
    """Map the ``--sde_cls`` string (case-insensitive) to its SDE class."""
    key = name.lower()
    if key == 'subvpsde':
        return subVPSDE
    if key == 'vesde':
        return VESDE
    if key == 'vpsde':
        return VPSDE
    raise ValueError("unknown SDE class {!r}; expected subvpsde, vesde or vpsde".format(name))


def _mean_likelihood_bpd(criterion, model, loader, to_cond, max_batches):
    """Average ``criterion(..., likelihood_only=True)`` over the first ``max_batches`` batches."""
    bpds = []
    with torch.no_grad():
        for idx, (images, labels) in enumerate(loader):
            images = images.cuda(non_blocking=True)
            bpds.append(criterion(images, model, likelihood_only=True, c=to_cond(labels)).item())
            if idx + 1 >= max_batches:
                break
    return np.mean(bpds)


def train(rank, args):
    """Worker entry point: train, evaluate, checkpoint and sample on GPU ``rank``.

    Args:
        rank: Index of this worker; also used as the CUDA device index.
        args: Parsed command-line namespace produced by ``main`` (including
            ``world_size`` and ``model_name``).

    Side effects (rank 0 only unless noted): writes log files, a temporary
    checkpoint every epoch, and a numbered checkpoint plus a 5x5 sample image
    every ``args.save_interval`` epochs below ``args.save_dir``. Every rank
    writes ``chkpts/test.pth`` once before training starts.
    """
    gpu = rank
    if args.distributed:
        dist.init_process_group(
            backend='nccl',
            world_size=args.world_size,
            rank=rank
        )

    torch.manual_seed(args.seed)

    # load data; images are mapped from [0, 1] to [-1, 1] and back for saving
    rescaling     = lambda x : (x - .5) * 2.
    rescaling_inv = lambda x : .5 * x  + .5
    obs, c_dim, train_dataset, test_dataset = _build_datasets(args, rescaling)

    train_loader = build_loader(
        train_dataset, args.world_size, rank, args.batch_size, distributed=args.distributed)

    test_loader = build_loader(
        test_dataset, args.world_size, rank, args.batch_size, distributed=args.distributed)

    sde_cls = _resolve_sde_cls(args.sde_cls)

    # define model and loss; ``shadow`` is the second network handed to EMA
    model = NCPN(shape=obs, nr_resnet=args.nr_resnet, nr_filters=args.nr_filters, nr_resolutions=args.nr_resolutions,
        nr_logistic_mix=args.nr_logistic_mix, dropout=args.dropout, attn=args.attention, c_dim=c_dim, time_cond=args.time_cond)
    shadow = NCPN(shape=obs, nr_resnet=args.nr_resnet, nr_filters=args.nr_filters, nr_resolutions=args.nr_resolutions,
                nr_logistic_mix=args.nr_logistic_mix, dropout=args.dropout, attn=args.attention, c_dim=c_dim, time_cond=args.time_cond)
    model = EMA(model, shadow, args.ema_decay)

    # define loss function (criterion) and optimizer
    criterion = get_loss_fn(
        sde_cls=sde_cls,
        T=args.T,
        bs_min=args.bs_min,
        bs_max=args.bs_max,
        time_cond=args.time_cond,
        t_scale=args.t_scale,
        discrete=args.discrete,
        class_cond=args.class_cond
    )
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    try:
        model, optimizer, epoch = maybe_load_chkpt(model, optimizer, args.load_params, gpu)
    except Exception as e:
        print("Could not use new checkpoint loader, reverting to old loader...")
        print(e)
        model, optimizer, epoch = maybe_load_chkpt_old(model, optimizer, args.load_params, 'cpu')

    gc.collect()

    # the learning rate is decayed after every optimisation step
    scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.lr_decay)

    # set to gpu
    torch.cuda.set_device(gpu)
    model.cuda(gpu)

    # wrap the model
    # NOTE(review): DistributedDataParallel is not imported explicitly in this
    # file; presumably it comes from one of the star imports — confirm.
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[gpu])

    # save directories
    save_tmp_dir = args.save_dir + '/chkpts/{}_tmp.pth'.format(args.model_name)
    save_dir = args.save_dir + '/{}/{}_{}'.format('{}', args.model_name, '{}')
    # Make sure the output folders exist before anything is written to them.
    for sub_dir in ('chkpts', 'logs', 'images'):
        os.makedirs(os.path.join(args.save_dir, sub_dir), exist_ok=True)
    save_model(model, optimizer, epoch, args.save_dir + '/chkpts/test.pth')

    def to_cond(labels):
        """One-hot encode a label batch on the GPU, or return None when unconditional."""
        if not args.class_cond:
            return None
        return nn.functional.one_hot(labels.cuda(), num_classes=c_dim).float()

    log_file = None
    loss_file = None

    def fprint(*msg):
        print(*msg)
        print(*msg, file=log_file)

    start_epoch = epoch
    try:
        if gpu == 0:
            print('starting training. model parameters:', sum(p.numel() for p in model.parameters() if p.requires_grad))
            log_file = open(save_dir.format('logs', start_epoch + 1) + '.txt', "a")
            loss_file = open(save_dir.format('logs', start_epoch + 1) + '_losses.txt', "a")

        time_ = time.time()
        loss = torch.zeros(()).cuda()
        # Most recent mean training loss; compared with the test loss below.
        train_loss = float('inf')
        for epoch in range(start_epoch, args.max_epochs):
            # DistributedSampler needs the epoch to reshuffle between epochs.
            if hasattr(train_loader.sampler, 'set_epoch'):
                train_loader.sampler.set_epoch(epoch)

            model.train()
            model.cuda()
            loss = torch.zeros_like(loss)
            n_accum = 0  # batches accumulated in ``loss`` since the last reset
            for batch_idx, (images, labels) in enumerate(train_loader):
                images = images.cuda(non_blocking=True)
                cond = to_cond(labels)
                # backward pass
                optimizer.zero_grad()

                loss_ = criterion(images, model, c=cond)
                loss_.backward()
                optimizer.step()

                scheduler.step()
                model.update()

                loss += loss_.item()
                n_accum += 1
                del loss_

                if (batch_idx + 1) % args.print_every == 0:
                    loss /= args.print_every
                    if args.distributed:
                        dist.all_reduce(loss)
                        loss /= args.world_size
                    train_loss = loss.item()
                    if gpu == 0:
                        # Uses its own loop variables and labels so that the
                        # outer ``batch_idx`` / batch are left untouched.
                        bpd = _mean_likelihood_bpd(criterion, model, train_loader, to_cond, 12)
                        print(bpd, file=loss_file)
                        fprint('loss: {:.4f}, bpd: {:.4f}, time: {:.4f}'.format(
                            train_loss,
                            bpd,
                            time.time() - time_))
                    loss = torch.zeros_like(loss)
                    n_accum = 0
                    time_ = time.time()

                # cap the number of samples seen per epoch
                if (batch_idx * args.batch_size) > 100000:
                    break

            sample_this_step = ((epoch + 1) % args.save_interval) == 0
            model.eval()
            if n_accum:
                # mean over the batches seen since the last report
                train_loss = loss.item() / n_accum
            loss = torch.zeros_like(loss)
            n_test_batches = 0
            with torch.no_grad():
                for images, labels in test_loader:
                    images = images.cuda()
                    loss += criterion(images, model, likelihood_only=(not sample_this_step), c=to_cond(labels))
                    n_test_batches += 1

            # average over the number of batches actually evaluated
            loss /= max(n_test_batches, 1)
            if args.distributed:
                dist.all_reduce(loss)
                loss /= args.world_size

            if gpu == 0:
                fprint('epoch: {} test loss: {:.4f}, lr: {:.4E}'.format(
                    epoch + 1, loss.item(), scheduler.get_last_lr()[0]))

                # save model every epoch to temporary file
                save_model(model, optimizer, epoch, save_tmp_dir)

                if sample_this_step:
                    save_model(model, optimizer, epoch, save_dir.format('chkpts', epoch + 1) + '.pth')
                    print('sampling...')
                    if train_loss < loss.item():
                        # NOTE(review): sample() switches the model to eval
                        # mode itself, so this may have no effect — confirm.
                        print("using train parameters")
                        model.train()
                    sample_t = sample(model, obs, args.T, args.time_cond, args.t_scale, c_dim=c_dim)
                    sample_t = rescaling_inv(sample_t)
                    utils.save_image(sample_t, save_dir.format('images', epoch + 1) + '.png',
                            nrow=5, padding=0)
                    print(args)
                    del sample_t
    finally:
        for handle in (log_file, loss_file):
            if handle is not None:
                handle.close()

def build_loader(dataset, world_size, rank, batch_size, distributed, shuffle=None):
    """Wrap ``dataset`` in a ``DataLoader``, sharded across ranks when distributed.

    Args:
        dataset: Dataset to load from.
        world_size: Number of replicas the data is split over (distributed only).
        rank: Index of the replica this loader serves (distributed only).
        batch_size: Number of samples per batch for this loader.
        distributed: If true, a ``DistributedSampler`` restricts the loader to
            this rank's shard of the data.
        shuffle: Whether to shuffle. ``None`` (default) keeps the historical
            behaviour: shuffled in distributed mode (the sampler's default),
            sequential otherwise.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``num_workers=0`` and pinned memory.
    """
    if distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True if shuffle is None else bool(shuffle),
        )
        # DataLoader rejects shuffle=True together with an explicit sampler;
        # in this mode the sampler is responsible for the ordering.
        loader_shuffle = False
    else:
        sampler = None
        loader_shuffle = bool(shuffle)

    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=loader_shuffle,
        num_workers=0,
        pin_memory=True,
        sampler=sampler,
    )

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()
