import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

# from cifar_model import NCSNpp, models
from sde_lib import VPSDE

# GPU index used for sampling (also used as the DataParallel device below).
device_id = 4
# Use the requested GPU when CUDA is available; otherwise fall back to CPU so
# the script still runs on machines without a GPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{device_id}')
else:
    device = torch.device('cpu')

# Previously used checkpoints / model configs (kept for reference):
# ep = 'ddpmppdeep_checkpoint_19.pth'
# ep = 'cifarnewnewDataParallel_ep200.pth'
# ep = 'cifarsmalldatasetDataParallel_ep2990.pth'
# ep = 'cifarsmalldatasetsingledimDataParallel_ep2990.pth'
# model = 'smalldatasetsingledim'
model = 'smalldatasetcrop'
ep = 2990
chkpt_name = f'cifar{model}DataParallel_ep{ep}.pth'

# if  'ddpmpp' == ep.split('_')[0]:
#     # score_network = NCSNpp()
#     score_network = models['ddpmpp']
#     score_network = torch.nn.DataParallel(score_network, device_ids=[0,1,2,3])
#     stat_dict = torch.load(f'./{ep}', map_location=device)['model']
#     score_network.load_state_dict(stat_dict, strict=False)
#     score_network = score_network.to(device)
#     def score_fn(score_network):
#         def _score_fn(x, t):
#             log_mean_coeff = -0.25 * t ** 2 * (20 - 0.1) - 0.5 * t * 0.1
#             std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
#             return (-score_network(x,t * 999))/ std[:, None, None, None]
#         return _score_fn
#     score_network = score_fn(score_network)
# elif 'ddpmppdeep' == ep.split('_')[0]:
#     score_network = models['ddpmppdeep']
#     score_network = torch.nn.DataParallel(score_network, device_ids=[0, 1, 2, 3])
#     stat_dict = torch.load(f'./{ep}', map_location=device)['model']
#     score_network.load_state_dict(stat_dict, strict=False)
#     score_network = score_network.to(device)
#     def score_fn(score_network):
#         def _score_fn(x, t):
#             log_mean_coeff = -0.25 * t ** 2 * (20 - 0.1) - 0.5 * t * 0.1
#             std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
#             return (-score_network(x, t * 999)) / std[:, None, None, None]
#         return _score_fn
#     score_network = score_fn(score_network)
# else:
#     score_network = models['ddpmpp']
#     score_network = torch.nn.DataParallel(score_network, device_ids=[4])
#     stat_dict = torch.load(f'./{ep}', map_location=device )
#     score_network.load_state_dict(stat_dict)
#     score_network = score_network.to(device)
#     def score_fn(score_network):
#         def _score_fn(x, t):
#             log_mean_coeff = -0.25 * t ** 2 * (20 - 0.1) - 0.5 * t * 0.1
#             std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
#             return (-score_network(x, t * 999)) / std[:, None, None, None]
#             # if 'newnew' in ep:
#             #     log_mean_coeff = -0.25 * t ** 2 * (20 - 0.1) - 0.5 * t * 0.1
#             #     std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
#             #     return (-score_network(x,t * 999))/ std[:, None, None, None]
#             # else:
#             #     int_beta = (0.1 + 0.5 * (20 - 0.1) * t) * t  # integral of beta
#             #     var_t = -torch.expm1(-int_beta)
#             #     return (-score_network(x, t * 999)) / var_t[:, None, None, None]
#         return _score_fn
#     score_network = score_fn(score_network)

from cifar_model import get_model_and_dataset

# Build the architecture and dataset for the selected config, then restore the
# DataParallel-wrapped checkpoint onto the selected device.
score_network, cifar_dset = get_model_and_dataset(model)
score_network = torch.nn.DataParallel(score_network, device_ids=[device_id])
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
stat_dict = torch.load(f'./{chkpt_name}', map_location=device)
# strict=False tolerates key mismatches. However, a checkpoint whose keys match
# nothing (e.g. saved with a different prefix) would leave the network
# untrained and produce noise. Report mismatches instead of silently
# ignoring them.
load_result = score_network.load_state_dict(stat_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f'Warning: checkpoint {chkpt_name} does not fully match the model: '
          f'{len(load_result.missing_keys)} missing key(s), '
          f'{len(load_result.unexpected_keys)} unexpected key(s)')
score_network = score_network.to(device)
def score_fn(score_network):
    """Turn a network's output into a score for the VP SDE.

    The returned callable evaluates ``score_network`` at the rescaled time
    ``t * 999`` and divides the negated output by the marginal standard
    deviation of the VP SDE with a linear beta schedule
    (beta_min=0.1, beta_max=20). The network output is presumably a noise
    prediction; TODO confirm against the training code.

    Args:
        score_network: Callable ``(x, t_scaled) -> tensor`` shaped like ``x``.

    Returns:
        ``_score_fn(x, t)`` where ``t`` holds one time per batch element of ``x``.
    """
    beta_min = 0.1
    beta_max = 20

    def _score_fn(x, t):
        # log of the mean coefficient: -1/4 t^2 (beta_max - beta_min) - 1/2 t beta_min
        log_mean = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
        sigma = torch.sqrt(1. - torch.exp(2. * log_mean))
        net_out = score_network(x, t * 999)
        # Broadcast the per-sample std over the trailing (C, H, W) dims.
        return -net_out / sigma[:, None, None, None]

    return _score_fn
score_network = score_fn(score_network)

# Read one training example once (indexing may run dataset transforms) to get
# the image geometry used for sampling.
sample_shape = cifar_dset[0][0].size()
print(len(cifar_dset), sample_shape)
channel_num = sample_shape[0]
# Assumes square images: only the height (dim 1) is used as the side length.
image_size = sample_shape[1]

def _integrate_reverse_sde(rsde, x_t, time_pts):
    """Run Euler-Maruyama on ``rsde`` starting from ``x_t`` along ``time_pts``.

    ``rsde.sde(x, t)`` returns ``(drift, diffusion)``. Each step goes from
    ``time_pts[k]`` to ``time_pts[k + 1]``, so ``dt`` is negative for a grid that
    runs backwards in time. For that reason the noise is scaled by ``sqrt(|dt|)``.
    """
    with torch.no_grad(), tqdm(total=len(time_pts) - 1, leave=True) as pbar:
        for t, t_next in zip(time_pts[:-1], time_pts[1:]):
            dt = t_next - t
            drift, diffusion = rsde.sde(x_t, t)
            # Euler-Maruyama step.
            x_t = x_t + drift * dt + diffusion * torch.randn_like(x_t) * torch.abs(dt) ** 0.5
            pbar.update(1)
    return x_t


def generate_samples_1(score_network, num_samples, batch_size=None, *,
                       num_steps=100, probability_flow=True):
    """Sample images by integrating the reverse of a ``VPSDE`` from t=1 to t=0.

    Starts from standard Gaussian noise and applies ``num_steps - 1``
    Euler-Maruyama steps on a uniform time grid.

    Args:
        score_network: Callable ``(x, t) -> score``. ``t`` holds one time value
            per batch element.
        num_samples: Total number of images to generate.
        batch_size: Maximum number of images integrated at once. ``None``
            generates everything in a single batch.
        num_steps: Number of time grid points between 1 and 0.
        probability_flow: Forwarded to ``VPSDE.reverse``. Presumably it selects
            the deterministic probability-flow ODE.

    Returns:
        Tensor of shape ``(num_samples, channel_num, image_size, image_size)``
        on the module-level ``device``.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    shape = (channel_num, image_size, image_size)
    if num_samples <= 0:
        return torch.empty((0, *shape), device=device)
    if batch_size is None:
        batch_size = num_samples
    elif batch_size < 1:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    sde = VPSDE()
    # The solver passes a scalar t; the network expects one t per batch element.
    rsde = sde.reverse(
        lambda x, t: score_network(x, t.unsqueeze(-1).expand(x.shape[0])),
        probability_flow=probability_flow,
    )
    time_pts = torch.linspace(1, 0, num_steps, device=device)  # (num_steps,)

    batches = []
    for start in range(0, num_samples, batch_size):
        n = min(batch_size, num_samples - start)
        x_init = torch.randn((n, *shape), device=device)
        batches.append(_integrate_reverse_sde(rsde, x_init, time_pts))
    return torch.cat(batches)

num_samples = 20
samples = generate_samples_1(score_network, num_samples)
print(samples.shape)

# Convert to channels-last for matplotlib and jointly rescale to [0, 1]
# (clamp guards against division by zero when all pixels are equal).
images = samples.detach().cpu().permute(0, 2, 3, 1)
value_range = (images.max() - images.min()).clamp_min(1e-12)
images = (images - images.min()) / value_range

# imshow does not accept (H, W, 1) arrays: drop the channel axis and use a
# grayscale colormap for single-channel models.
if images.shape[-1] == 1:
    images = images.squeeze(-1)
    cmap = 'gray'
else:
    cmap = None

fig, axes = plt.subplots(3, 7, figsize=(14, 6))

for i, ax in enumerate(axes.flat):
    if i < len(images):
        ax.imshow(images[i], cmap=cmap)
    ax.axis('off')

plt.tight_layout()
plt.show()