import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from torch.utils.data import Subset

from cifar.sde_lib import VPSDE
from cifar_model import NCSNpp, get_ddpmpp_config, get_model_and_dataset

# Run log (GPU id -> experiment): datasetsize:20000, 0adam: dim-8 1eve: dim-32 2adamw: dim-16 0even: dim-12 6adam: dim20 7adamw: dim24 5even: dim28
# GPUs handed to DataParallel. The model has to live on the first listed GPU,
# so device_id is derived from this list instead of being a second constant
# that must be kept in sync by hand.
devices = [4, 5, 6, 7]
device_id = devices[0]
if torch.cuda.is_available():
    device = torch.device(f'cuda:{device_id}')
else:
    # No GPU at all: run on the CPU instead of failing on the first .to(device).
    print('CUDA is not available; falling back to CPU.')
    device = torch.device('cpu')

# Model / checkpoint hyper-parameters; they are also baked into the checkpoint
# name built from `prefix` further down.
num_resblock = 4
dim = 16          # spatial side length of the probe images (dim x dim)
index = dim // 2  # row/column of the single probed pixel (near the image centre)
dataset_size = 20000
nf = 96
data_loss = []
gaussian_loss = []

# Probe batch: one random image repeated len(x_grid) times, in which only
# channel 1 of pixel (index, index) varies, sweeping x_grid over [-2, 2) in
# steps of 0.01. The noise is drawn on the CPU and then moved to `device`.
x_grid = torch.arange(-2, 2, 1e-2).to(device)
x_t = torch.randn((1, 3, dim, dim)).repeat(len(x_grid), 1, 1, 1).to(device)  # (n_probe, channels, dim, dim)
x_t[:, 1, index, index] = x_grid
save_interval = int(100 * (20000 / dataset_size)) * 5

# Experiment tag: selects both the model/dataset config and the checkpoint file.
prefix = f'dim-{dim}-ds-{dataset_size}-nr-{num_resblock}'
if nf != 128:
    prefix += f'-nf-{nf}'
# Optional suffixes for other variants (presumably other training runs):
# prefix += f'-epoch-100'
# prefix += f'-act-tanh'
# prefix += f'-model-reduced'
print(f'Experiment setting: {prefix}')
chkpt_name = f'cifar{prefix}DataParallel_lastep.pth'
score_network, cifar_dset = get_model_and_dataset(prefix, train=False)
score_network = score_network.to(device)
score_network = torch.nn.DataParallel(score_network, device_ids=devices, output_device=devices[0])
stat_dict = torch.load(f'./{chkpt_name}', map_location=device)
# strict=False tolerates key mismatches, but does so silently: if the
# checkpoint's keys do not line up with the wrapped model (e.g. a missing or
# extra 'module.' prefix), nothing is loaded and the plot below would show an
# untrained network. Report anything that was skipped.
load_result = score_network.load_state_dict(stat_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f'Warning: checkpoint {chkpt_name} only partially matched the model: '
          f'{len(load_result.missing_keys)} missing keys, '
          f'{len(load_result.unexpected_keys)} unexpected keys '
          f'(e.g. missing={load_result.missing_keys[:3]}, '
          f'unexpected={load_result.unexpected_keys[:3]})')

# Diffusion times 1.00 down to 0.81 (step 0.01), kept as an (N, 1) column so
# each element `t` in the loop below is a 1-element tensor.
t_grid = torch.arange(1, 0.8, -1e-2).to(device).unsqueeze(-1)

# density_grid[i, j]: network output at the probed pixel minus the probe value
# x_grid[j], for time t_grid[i]; each row is then shifted to zero mean.
density_grid = torch.zeros([len(t_grid), len(x_grid)]).to(device)
batch = x_t.size(0)
with torch.no_grad():  # inference only, no autograd graph
    for i, t in enumerate(t_grid):
        # Same time for every probe; the network is fed t scaled by 999.
        timesteps = t.repeat(batch) * 999
        probe_out = score_network(x_t, timesteps)[:, 1, index, index]
        row = density_grid[i]  # view into density_grid, so the writes land in place
        row.copy_(probe_out - x_grid)
        row.sub_(row.mean())

plt.figure()
# Move results to NumPy. t_grid was built as an (N, 1) column for the per-time
# loop, so flatten it: indexing the column yields 1-element arrays, which would
# otherwise be handed to xlim / imshow's extent instead of plain scalars.
x_grid = x_grid.detach().cpu().numpy()
t_grid = t_grid.detach().cpu().numpy().reshape(-1)
density_grid = density_grid.detach().cpu().numpy()
t_first, t_last = float(t_grid[0]), float(t_grid[-1])
x_first, x_last = float(x_grid[0]), float(x_grid[-1])

plt.xlim(t_first, t_last)
plt.ylim(x_first, x_last)
# density_grid rows are times and columns are probe values; transpose so time
# runs along the horizontal axis. The colour scale is clipped to +/-0.015.
plt.imshow(density_grid.T, extent=[t_first, t_last, x_first, x_last], aspect='auto', origin='lower',
           cmap='viridis', vmax=0.015, vmin=-0.015)
# plt.title(title, fontsize=15)
plt.xticks(fontsize=12, rotation=45)
plt.xlabel(r'$t$', fontsize=15)
plt.ylabel(r"$x_t$", fontsize=15)
plt.yticks(fontsize=12)
# Changes the global rc default; axes created afterwards (e.g. the colorbar
# below) pick it up.
plt.rcParams['ytick.labelsize'] = 12
plt.colorbar()
plt.tight_layout()
plt.savefig(f'error_propagation_{chkpt_name}.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()
