import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from torch.utils.data import Subset

from cifar.sde_lib import VPSDE
from cifar_model import NCSNpp, get_ddpmpp_config, get_model_and_dataset

# Run notes (device id -> experiment), dataset size 20000:
#   0adam: dim-8, 1eve: dim-32, 2adamw: dim-16, 0even: dim-12,
#   6adam: dim20, 7adamw: dim24, 5even: dim28
device_id = 0
device = torch.device(f'cuda:{device_id}')  # change this if you don't have a gpu
devices = [device_id]

num_resblock = 1
dim = 8
index = dim // 2  # spatial position of the swept pixel (row == col == dim // 2)
dataset_size = 30000
nf = 128  # default network width; the nf sweep loop rebinds this name
data_loss = []  # currently unused in this script
gaussian_loss = []  # currently unused in this script

num_samples = 1000  # number of points in the 1-D probe sweep
seed = 0
# Seed the RNG so the shared background noise image (and therefore every
# printed metric) is reproducible from run to run.
torch.manual_seed(seed)
# A single Gaussian noise image of shape (1, 3, dim, dim), repeated so the batch
# has shape (num_samples, 3, dim, dim) and every sample shares the same background.
x_t = torch.randn((1, 3, dim, dim), device='cpu').repeat(num_samples, 1, 1, 1)
# Overwrite one pixel (channel 1, position [index, index]) with a linear sweep
# over [-2, 2], so sample i differs from the others only in that pixel.
x_t[:, 1, index, index] = torch.linspace(-2, 2, num_samples)
for nf in [64, 96, 128]:
    # for nf in [32, 64, 128]:
    # For each network width `nf`, load the final ('lastep') checkpoint, probe
    # the network's output at the swept pixel of `x_t`, plot (output - input)
    # against the swept value, and report how far the response is from a
    # straight line.

    # Checkpoint tags: every `save_interval` epochs plus 'lastep'. Only the
    # last entry is loaded below; the rest are kept for reference.
    save_interval = int(100 * (20000 / dataset_size)) * 5
    chkpt = [f'ep{i * save_interval}' for i in range(0, 6)]
    chkpt.append('lastep')

    prefix = f'dim-{dim}-ds-{dataset_size}-nr-{num_resblock}'
    if nf != 128:
        prefix += f'-nf-{nf}'
    # prefix += f'-epoch-100'
    prefix += '-act-tanh'
    prefix += '-model-reduced'
    print(f'Experiment setting: {prefix}')
    score_network, cifar_dset = get_model_and_dataset(prefix, train=False)

    # Everything below runs on CPU. The previous GPU -> DataParallel -> .module
    # -> CPU round trip returned the same module on CPU and only added a hard
    # CUDA requirement, so the network is placed on CPU directly.
    score_network = score_network.to('cpu')
    chkpt_name = f'cifar{prefix}DataParallel_{chkpt[-1]}.pth'
    stat_dict = torch.load(f'./{chkpt_name}', map_location='cpu')
    # Strip wrapper prefixes ('all_module.' / 'module.') from the saved keys so
    # they match the bare module's parameter names.
    stat_dict = {
        k.replace('all_module.', '').replace('module.', ''): v
        for k, v in stat_dict.items()
    }
    score_network.load_state_dict(stat_dict, strict=True)
    # Evaluation mode: disables dropout / uses running statistics, if the
    # network has such layers, so the probed curve is deterministic.
    score_network.eval()

    probe = x_t[:, 1, index, index]
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred = score_network(x_t, torch.tensor([999]))[:, 1, index, index]

    xs = probe.numpy()
    ys = pred.numpy()

    plt.figure()
    plt.plot(xs, ys - xs)
    plt.show()
    plt.close()

    # Fit a straight line to output-vs-input and measure the deviation from it.
    coeffs = np.polyfit(xs, ys, deg=1)
    residual = ys - np.polyval(coeffs, xs)

    # Previously labelled "MSE" but computed mean(sqrt(r**2)) == mean(|r|)
    # (the MAE); both are reported now, each under its correct name.
    mae = np.mean(np.abs(residual))
    mse = np.mean(residual ** 2)
    print(f'nf={nf} linear-fit MAE: {mae}')
    print(f'nf={nf} linear-fit MSE: {mse}')
    # Roughness of the residual: mean squared difference between neighbouring
    # sweep points.
    print(f'nf={nf} residual roughness: {np.mean(np.diff(residual) ** 2)}')