import time
import sys
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms, utils
from ncpn.utils import *
from ncpn.sde_lib import subVPSDE, VESDE, VPSDE, EPS
from ncpn.model import NCPN
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import save_image

def plt_imgs(sample, title=''):
    """Rescale a batch with ``(sample + 1) / 2`` and display it via ``show_samples``.

    The affine map sends -1 to 0 and 1 to 1, so this presumably expects
    images stored in [-1, 1] (confirm against the training pipeline).

    Args:
        sample: Batch of images; forwarded to ``show_samples`` after rescaling.
        title: Title passed through to ``show_samples``.
    """
    show_samples((sample + 1) / 2, title)

def image_grid(x):
    """Tile a batch of channels-last images into one square grid image.

    Images are laid out row by row: image ``k`` lands in grid row
    ``k // side`` and grid column ``k % side``, where ``side`` is the
    integer square root of the batch size.

    Args:
        x: Array of shape ``(n, height, width, channels)``.

    Returns:
        Array of shape ``(side * height, side * width, channels)``. When
        ``n`` is not a perfect square, only the first ``side ** 2`` images
        are used and the remainder is dropped.
    """
    n, height, width, channels = x.shape
    side = int(np.sqrt(n))
    # Drop surplus images so the batch always factors into a side x side grid.
    tiles = x[:side * side].reshape((side, side, height, width, channels))
    # (grid_row, grid_col, h, w, c) -> (grid_row, h, grid_col, w, c) so that
    # flattening the first two and the next two axes stitches the tiles together.
    stitched = tiles.transpose((0, 2, 1, 3, 4))
    return stitched.reshape((side * height, side * width, channels))

def show_samples(x, title=''):
    """Show a batch of images as a single grid in a new matplotlib figure.

    Args:
        x: 4-D tensor whose axis 1 is moved to the end before plotting
            (i.e. channels-first input is converted to channels-last).
        title: Text placed above the grid.
    """
    # matplotlib wants channels-last numpy data on the host.
    channels_last = x.detach().permute(0, 2, 3, 1).cpu().numpy()
    grid = image_grid(channels_last)

    plt.figure(figsize=(64, 64))
    plt.axis('off')
    plt.imshow(grid)
    plt.title(title)
    plt.show()

def score(data, logits, to_idx=None):
    """Differentiate the (optionally cropped) summed ``-loss_op`` with respect to ``data``.

    Args:
        data: Input tensor; must be able to accumulate a ``.grad`` (the
            function asserts that one exists after ``backward``). Any
            existing gradient is zeroed first.
        logits: Model output passed to ``loss_op`` together with ``data``.
        to_idx: Optional pair ``(rows, cols)``; only ``logp[:, :rows, :cols]``
            contributes to the differentiated sum. Defaults to
            ``(data.shape[2], data.shape[3])``.

    Returns:
        Tuple ``(grad, logp)``, both detached. ``grad`` is a detached view of
        ``data.grad``, so a later call that zeroes ``data.grad`` also zeroes it.
    """
    log_prob = -loss_op(data, logits)

    # Clear stale gradients so this call reports only its own contribution.
    if data.grad is not None:
        data.grad.detach().zero_()

    if to_idx is None:
        rows, cols = data.shape[2], data.shape[3]
    else:
        rows, cols = to_idx[0], to_idx[1]
    log_prob[:, :rows, :cols].sum().backward()

    grad = data.grad
    assert grad is not None

    return grad.detach(), log_prob.detach()

def get_bpd(sample, eps=EPS):
    """Return the total ``loss_op`` value of ``sample`` normalised to bits per dimension.

    The module-level ``model`` is evaluated with conditioning time ``eps``
    for every batch element, without tracking gradients.

    Args:
        sample: Batch tensor fed to the model; its first axis is the batch.
        eps: Conditioning time used for all elements (defaults to ``EPS``).

    Returns:
        Scalar tensor: summed loss divided by ``sample.numel() * ln(2)``.
    """
    # Build the time vector on the sample's device instead of assuming CUDA.
    t = torch.zeros(len(sample), device=sample.device) + eps
    with torch.no_grad():
        out = model(sample, cond=t)
        # loss_op yields the loss itself (not its negation), so name it as such.
        nll = loss_op(sample, out)

    return nll.sum() / (sample.numel() * np.log(2.))

def idxs_b4(arr, i, j, agg):
    """Aggregate the entries of ``arr`` that come before position ``(i, j)`` in raster order.

    The last two axes of ``arr`` are flattened row by row, the first
    ``i * width + j`` flattened entries are kept, and ``agg`` is applied to
    that slice. An index past the end simply keeps everything.

    Args:
        arr: Tensor whose last two axes are (rows, columns).
        i: Row index of the cut-off position.
        j: Column index of the cut-off position.
        agg: Callable applied to the truncated, flattened tensor.

    Returns:
        Whatever ``agg`` returns.
    """
    # Take the row width from the tensor itself rather than from a global,
    # so the helper works for any input and before the script sets `obs`.
    width = arr.shape[-1]
    flat = arr.reshape(tuple(arr.shape[:-2]) + (-1,))
    lin_idx = i * width + j

    return agg(flat[..., :lin_idx])

def sum_func(x):
    """Sum ``x`` over its last axis."""
    return x.sum(dim=-1)


def norm_func(x):
    """Sum of squares of ``x`` over its last two axes."""
    return (x ** 2).sum(dim=(-2, -1))


def loss_op(real, fake):
    """Evaluate ``discretized_mix_logistic_loss`` with ``per_sample=True, bin_pixels=False``."""
    return discretized_mix_logistic_loss(real, fake, per_sample=True, bin_pixels=False)


def sample_op(x):
    """Sample via ``sample_from_discretized_mix_logistic`` with second argument 10."""
    return sample_from_discretized_mix_logistic(x, 10)

def get_mean_var(x, t, score, sde):
    """Combine the two outputs of ``sde.sde(x, t)`` with a score term.

    ``sde.sde`` returns a pair that this file calls ``(mean, var)``
    (presumably the SDE drift and diffusion coefficient; confirm against
    ``ncpn.sde_lib``). The first is shifted by ``var ** 2 * score``.

    Args:
        x: Current state, passed to ``sde.sde``.
        t: Time tensor, passed to ``sde.sde``.
        score: Tensor multiplied by the squared second output; must
            broadcast against a ``(batch, 1, 1, 1)`` tensor.
        sde: Object exposing ``sde(x, t)``.

    Returns:
        Tuple ``(mean - var ** 2 * score, var)``, both detached.
    """
    base, coeff = sde.sde(x, t)
    # Expand the per-sample coefficient so it broadcasts over the image axes.
    coeff_sq = coeff[:, None, None, None] ** 2
    adjusted = base - coeff_sq * score
    return adjusted.detach(), coeff.detach()

def mala_sample(shape, T, model, sde, N=100, verbose=True,
                t_scale=999., x_sampled=None, mala=False):
    """Sample images pixel by pixel at time ``T``, then step them down to time ``EPS``.

    Phase 1 (skipped when ``x_sampled`` is given): starting from zeros, every
    spatial position is filled in raster order with ``sample_op`` applied to
    the model output at time ``T``.

    Phase 2 (only when ``T > EPS``): ``int(N * T)`` update steps of size
    ``dt = -1 / N`` are applied at times spaced linearly from ``T`` to
    ``EPS``; each step uses ``score`` and ``get_mean_var``.

    Args:
        shape: ``(batch, channels, height, width)`` of the batch to create.
        T: Starting time.
        model: Callable ``model(x, t)`` returning the input for ``loss_op`` /
            ``sample_op``.
        sde: SDE object forwarded to ``get_mean_var``.
        N: Number of discretisation steps per unit of time.
        verbose: Print progress messages.
        t_scale: Factor applied to ``t`` before it is handed to the model.
        x_sampled: Optional pre-sampled batch; skips phase 1 when given.
        mala: If True, every denoising step is followed by an accept/reject
            test. Experimental: see the NOTE(review) comments in ``denoise``.

    Returns:
        Tuple ``(x_mean, x_sampled)`` on the CPU: the noise-free mean of the
        last denoising step, and the batch as it was before denoising.
    """
    dt = - 1. / N  # negative step: time runs from T down towards EPS

    def denoise(x, t, ix=None, iy=None, mala=False):
        """Apply one update step to ``x`` at time ``t``.

        Returns ``(x_next, x_mean, n_accepted)``; without MALA every
        proposal is accepted, so ``n_accepted == len(x)``.
        """
        # Default cut-off lies past the last pixel, i.e. the whole image counts.
        # (Previously shape[1]/shape[2], which are channels/height for 4-D input.)
        if ix is None:
            ix = x.shape[-2]
        if iy is None:
            iy = x.shape[-1]

        x = x.detach().requires_grad_(True)
        z = torch.randn_like(x)
        grad, logp = score(x, model(x, t * t_scale))
        mean, var = get_mean_var(x, t, grad, sde)
        with torch.no_grad():
            drift = mean * dt
            step = drift + var[:, None, None, None] * np.sqrt(-dt) * z
            x_mean = x + drift
            x_prop = x + step

        if not mala:
            return x_prop, x_mean, len(x)

        # Obtain the score and likelihood at the proposal.
        x_prop = x_prop.detach().requires_grad_(True)
        grad_prop, logp_prop = score(x_prop, model(x_prop, t * t_scale))

        with torch.no_grad():
            mean_prop, _ = get_mean_var(x_prop, t, grad_prop, sde)
            # NOTE(review): a textbook MALA reverse proposal would be centred
            # at x_prop, not x -- left unchanged pending confirmation.
            x_backstep_mean = x + mean_prop * dt

            logp_prop_trunc = idxs_b4(logp_prop, ix, iy, agg=sum_func)
            logp_trunc = idxs_b4(logp, ix, iy, agg=sum_func)
            # Fixed: the column index was the enclosing loop counter `i`.
            top_norm = idxs_b4((x - x_backstep_mean), ix, iy, norm_func)
            bot_norm = idxs_b4((x_prop - x_mean), ix, iy, norm_func)
            prob_ratio = (logp_prop_trunc - logp_trunc).clamp(max=10.).exp()
            # NOTE(review): dt is negative and the diffusion scale is not used
            # here; verify the sign and scaling of this ratio.
            transition_ratio = ((-0.25 / dt * top_norm) - (-0.25 / dt * bot_norm)).exp()
            ratio = prob_ratio * transition_ratio
            alpha = torch.minimum(torch.ones_like(ratio), ratio)

            # Fixed: `device` was undefined; rand_like already matches alpha's device.
            u = torch.rand_like(alpha)

            if_update = (u <= alpha)[:, None, None, None]
            x_next = x + step * if_update
            n_accepted = if_update.sum().int().item()

        return x_next, x_mean, n_accepted

    if x_sampled is None:
        x = torch.zeros(shape).cuda()
        t = torch.ones(len(x)).cuda() * T
        with torch.no_grad():
            for ix in range(x.shape[2]):
                for iy in range(x.shape[3]):
                    logits = model(x, t * t_scale)
                    # Keep only the newly sampled position; earlier ones stay fixed.
                    x[..., ix, iy] = sample_op(logits)[..., ix, iy]

                if verbose:
                    print('{:.1f}% done with sampling...'.format((ix + 1) / x.shape[2] * 100))
        x_sampled = x.detach().clone()
    else:
        x = x_sampled.detach().clone()

    # Fallback so the return value is defined even when no step runs
    # (e.g. int(N * T) == 0).
    x_mean = x
    if T > EPS:
        n_steps = int(N * T)
        for step_idx, t_val in enumerate(np.linspace(T, EPS, n_steps)):
            t = torch.full((len(x),), float(t_val), device=x.device)
            x, x_mean, _ = denoise(x, t, mala=mala)
            if (step_idx + 1) % 20 == 0 and verbose:
                print('{:.1f}% done with denoising'.format(
                    (step_idx + 1) / n_steps * 100))
    else:
        x_sampled = x

    return x_mean.detach().cpu(), x_sampled.detach().cpu()

# Command-line interface for the sampling script.
parser = argparse.ArgumentParser()

# Sampling control.
parser.add_argument("-m", "--mala", action="store_true", help="Use mala?")
parser.add_argument("-s", "--save_every", type=int, default=10, help="Save every n batches")
parser.add_argument("-n", "--n", type=int, default=36, help="Number of images to sample")
parser.add_argument("-b", "--batch_size", type=int, default=12, help="Batch size")
parser.add_argument("-N", "--N", type=int, default=5000, help="Number of time steps")
parser.add_argument("-t", "--t", type=float, default=0.01, help="Time")

# SDE selection and model switches (the two "no_*" flags turn a feature off).
parser.add_argument("-sde", "--sde_cls", type=str, default="vpsde", help="SDE type")
parser.add_argument("-nt", "--no_time", action="store_false", dest="time_cond",
                    help="Time conditional?")
parser.add_argument("-na", "--no_attention", action="store_false", dest="attention",
                    help="Attention?")
parser.add_argument("-bsM", "--bs_max", type=float, default=20.0, help="max beta/sigma")
parser.add_argument("-bsm", "--bs_min", type=float, default=0.0001, help="min beta/sigma")
parser.add_argument("-dr", "--dropout", type=float, default=0.25, help="dropout")

# Data size, paths and network size.
parser.add_argument("-d", "--dim", type=int, default=32, help="Width/Height of dataset")
parser.add_argument("-md", "--model_dir", type=str, default="",
                    help="Where is the model saved?")
parser.add_argument("-sd", "--save_dir", type=str, default=".",
                    help="Where to save generated images?")
parser.add_argument("-q", "--nr_resnet", type=int, default=5,
                    help="Number of residual blocks per stage of the model")
parser.add_argument("-nf", "--nr_filters", type=int, default=160,
                    help="Number of filters to use across the model. Higher = larger model.")

args = parser.parse_args()

# Per-image shape: 3 channels, args.dim x args.dim pixels.
obs = (3, args.dim, args.dim)

model = NCPN(shape=obs, nr_resnet=args.nr_resnet, nr_filters=args.nr_filters,
                nr_logistic_mix=10, time_cond=args.time_cond, attn=args.attention, dropout=args.dropout)
# The EMA shadow must have the same architecture as the online network;
# build it from the CLI values too (it used to hard-code 5 / 160, which
# only matched when the defaults were used).
shadow = NCPN(shape=obs, nr_resnet=args.nr_resnet, nr_filters=args.nr_filters,
            nr_logistic_mix=10, time_cond=args.time_cond, attn=args.attention, dropout=args.dropout)
model = EMA(model, shadow, 0.9999)

model = model.cuda()

# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
chkpt = torch.load(args.model_dir)
model.model.load_state_dict(chkpt['model_state_dict'])
model.shadow.load_state_dict(chkpt['shadow_state_dict'])
del chkpt  # release the checkpoint's tensors once the weights are copied
model.eval()
print('model parameters loaded')

# Make the output directory (including missing parents); no error if it
# already exists, and no window between an existence check and creation.
os.makedirs(args.save_dir, exist_ok=True)

# Give the sampler one calling convention, model_(x, t), whether or not the
# network takes a time argument.
if args.time_cond:
    model_ = model
else:
    def model_(x, t):
        """Call the unconditional model, ignoring ``t``."""
        return model(x)

# Resolve --sde_cls (case-insensitive) to the SDE class. Fail right away with
# a readable message rather than with a NameError further down the script.
_SDE_CLASSES = {'subvpsde': subVPSDE, 'vesde': VESDE, 'vpsde': VPSDE}
try:
    sde_cls = _SDE_CLASSES[args.sde_cls.lower()]
except KeyError:
    raise ValueError("Unknown SDE type '{}'; expected one of: {}".format(
        args.sde_cls, ', '.join(sorted(_SDE_CLASSES)))) from None

xs = []
sde = sde_cls(args.bs_min, args.bs_max, N=args.N)
# Ceiling division: just enough batches to cover args.n images. The earlier
# `n // batch_size + 1` drew a whole surplus batch whenever n was an exact
# multiple of batch_size. At least one batch is always drawn.
n_batches = max(1, -(-args.n // args.batch_size))
# NOTE(review): args.mala is not passed to mala_sample here, so the --mala
# flag currently has no effect on sampling.
for i in range(n_batches):
    x = mala_sample((args.batch_size,) + obs, args.t, model_, sde, N=args.N, t_scale=999., verbose=False)[0]
    xs.append(x)
    if i % args.save_every == 0:
        # Periodic checkpoint of everything sampled so far.
        torch.save(xs, os.path.join(args.save_dir, 'tmp.pth'))
        print("Saved checkpoint. {:.2f}% done.".format((i + 1) / n_batches * 100))

# Merge all batches and keep exactly the requested number of images.
x = torch.cat(xs, dim=0)[:args.n]

torch.save(x, '{}/images.pth'.format(args.save_dir))

# Plot the largest square subset of the samples and write it to disk.
# NOTE(review): plt_imgs opens its own figure and calls plt.show() before the
# savefig below runs; with an interactive backend the saved file may come out
# blank -- confirm with the backend in use.
plt.figure()
sqrt_n = int(np.sqrt(x.shape[0]))
plt_imgs(x[:sqrt_n * sqrt_n])
plt.savefig('{}/image.png'.format(args.save_dir))

# Also save every image individually as a .png file.
png_dir = os.path.join(args.save_dir, 'pngs')
# exist_ok avoids the check-then-create race and creates missing parents.
os.makedirs(png_dir, exist_ok=True)

for i, x_ in enumerate(x):
    # Same (v + 1) / 2 rescaling as used for the grid plot.
    save_image((x_ + 1) / 2, os.path.join(png_dir, '{}.png'.format(i)))
