
import numpy as np
import torch
import torchvision
from torch.utils.data import Subset

from cifar_model import get_model_and_dataset

# Index of the CUDA device this run is pinned to.
# Run-log note kept verbatim from the original (meaning not recoverable from this file):
#   datasetsize:20000, 0adam: dim-8 1eve: dim-32 2adamw: dim-16 0even: dim-12 6adam: dim20 7adamw: dim24 5even: dim28
device_id = 0
devices = [device_id]  # list form of the same index, used as DataParallel device_ids
device = torch.device('cuda', device_id)  # change this if you don't have a gpu



def calc_loss_data(score_network: torch.nn.Module, x: torch.Tensor, time_step=0.0) -> torch.Tensor:
    """Return the variance-weighted denoising score-matching loss at one fixed time.

    Every sample in ``x`` is perturbed to the same diffusion time ``t`` (t=0 is
    data, t=1 is Gaussian noise) using mean ``x * exp(-int_beta / 2)`` and
    variance ``1 - exp(-int_beta)`` with
    ``int_beta = (0.1 + 0.5 * (20 - 0.1) * t) * t``.  The network output is
    turned into a score via ``-output / std`` and compared with the analytic
    score of the perturbation kernel.

    The time tensor is derived from ``x`` itself (batch size, dtype, device)
    and broadcast against it, so the function does not rely on module-level
    ``channel_num`` / ``image_size`` / ``device`` and works for any sample
    shape, including non-square images.

    Args:
        score_network: callable invoked as ``score_network(x_t, timesteps)``;
            ``timesteps`` has shape ``(batch_size,)`` and holds ``t * 999``.
        x: clean data with the batch on dim 0.
        time_step: value in [0, 1]; mapped to ``t`` in [1e-4, 1] so that the
            variance never becomes exactly zero.

    Returns:
        Scalar tensor: the mean over all elements of
        ``var_t * (score - grad_log_p) ** 2``.
    """
    # One time value per sample, kept away from t=0 where var_t would vanish.
    t_batch = torch.ones(x.shape[0], dtype=x.dtype, device=x.device) * time_step * (1 - 1e-4) + 1e-4
    # Shape (N, 1, ..., 1): broadcasts over every non-batch dim of x.
    t = t_batch.view(-1, *([1] * (x.dim() - 1)))

    # calculate the terms for the posterior log distribution
    int_beta = (0.1 + 0.5 * (20 - 0.1) * t) * t  # integral of beta
    mu_t = x * torch.exp(-0.5 * int_beta)
    var_t = -torch.expm1(-int_beta)  # 1 - exp(-int_beta), computed stably
    x_t = torch.randn_like(x) * var_t ** 0.5 + mu_t
    grad_log_p = -(x_t - mu_t) / var_t  # analytic score of N(mu_t, var_t) at x_t
    assert x_t.size() == x.size()

    # The network receives t rescaled by 999 (presumably a 1000-step discrete
    # timestep convention -- confirm against the model definition).
    score = score_network(x_t, t_batch * 999)
    # Network output is treated as predicted noise; -noise / std is the score.
    score = -score / (var_t ** 0.5)

    # Weighting by var_t cancels the 1 / var_t blow-up of the squared error.
    weighted_loss = var_t * (score - grad_log_p) ** 2
    return torch.mean(weighted_loss)


def calc_loss_gaussain(score_network: torch.nn.Module, x: torch.Tensor, time_step=1.0) -> torch.Tensor:
    """Return the batch-centred squared residual between network output and noisy input.

    ``x`` is perturbed to diffusion time ``t`` exactly as in the data loss
    (mean ``x * exp(-int_beta / 2)``, variance ``1 - exp(-int_beta)``).  The
    residual ``score_network(x_t, t * 999) - x_t`` has its mean over the batch
    dimension subtracted, is squared, and averaged over all elements.  A
    network whose output differs from ``x_t`` only by a batch-independent
    offset therefore yields zero.

    The time tensor is derived from ``x`` (batch size, dtype, device) and
    broadcast against it, so the function does not rely on module-level
    ``channel_num`` / ``image_size`` / ``device``.

    NOTE: the function name keeps its original spelling ("gaussain") because
    callers reference it by that name.

    Args:
        score_network: callable invoked as ``score_network(x_t, timesteps)``;
            ``timesteps`` has shape ``(batch_size,)`` and holds ``t * 999``.
        x: clean data with the batch on dim 0.
        time_step: value in [0, 1]; mapped to ``t`` in [1e-4, 1].

    Returns:
        Scalar tensor: mean of the squared, batch-centred residual.
    """
    # One time value per sample, kept away from t=0.
    t_batch = torch.ones(x.shape[0], dtype=x.dtype, device=x.device) * time_step * (1 - 1e-4) + 1e-4
    # Shape (N, 1, ..., 1): broadcasts over every non-batch dim of x.
    t = t_batch.view(-1, *([1] * (x.dim() - 1)))

    # calculate the terms for the posterior log distribution
    int_beta = (0.1 + 0.5 * (20 - 0.1) * t) * t  # integral of beta
    mu_t = x * torch.exp(-0.5 * int_beta)
    var_t = -torch.expm1(-int_beta)  # 1 - exp(-int_beta), computed stably
    x_t = torch.randn_like(x) * var_t ** 0.5 + mu_t
    assert x_t.size() == x.size()

    # Raw network output is used here (no -1 / std rescaling, unlike the data loss).
    score = score_network(x_t, t_batch * 999)

    # Remove the per-element batch mean so only the spread across samples is penalised.
    residual = score - x_t
    centred = residual - torch.mean(residual, dim=0, keepdim=True)
    return torch.mean(centred ** 2)

    # score = -score / (var_t ** 0.5)
    # # calculate the loss function
    # loss = (score - grad_log_p) - torch.mean(score - grad_log_p, dim=(1,2,3), keepdim=True)
    # loss = (loss ** 2) ** 0.5
    # weighted_loss = loss
    # return torch.mean(weighted_loss)


# ---------------------------------------------------------------------------
# Evaluation sweep: for each swept setting, load the final checkpoint and
# average both losses over a fixed number of mini-batches.
# ---------------------------------------------------------------------------
num_resblock = 1
dim = 8
dataset_size = 5000
nf = 128
data_loss = []      # one mean calc_loss_data value per swept setting
gaussian_loss = []  # one mean calc_loss_gaussain value per swept setting
num_eval_batches = 20  # mini-batches averaged per checkpoint

# Not used in this section; kept (hoisted out of the loop) in case plotting
# code elsewhere expects `plt` to be defined.
import matplotlib.pyplot as plt

# Sweeps used previously (swap the loop header to reproduce them):
#   for dataset_size in [10000, 20000, 40000]
#   for dim in [8, 16, 24, 32]
#   for nf in [32, 64, 128] | [4, 8, 12] | [32, 64, 96, 128, 256] | [64, 128] | [32] | [128]
for nf in [64, 96, 128, 256]:
    # Checkpoint tags: six evenly spaced epochs plus the final one.  Computed
    # inside the loop because the interval depends on dataset_size, which some
    # of the sweeps above vary.
    save_interval = int(100 * (20000 / dataset_size))*5
    chkpt = [f'ep{i * save_interval}' for i in range(0, 6)]
    chkpt.append('lastep')

    # Experiment tag; nf is only spelled out when it differs from 128.
    prefix = f'dim-{dim}-ds-{dataset_size}-nr-{num_resblock}'
    if nf != 128:
        prefix += f'-nf-{nf}'
    # prefix += f'-epoch-100'
    prefix += '-act-tanh'
    prefix += '-model-reduced'
    print(f'Experiment setting: {prefix}')
    score_network, cifar_dset = get_model_and_dataset(prefix, train=False)

    # Indexing: dataset item -> (image, label) tuple -> image tensor.
    sample = cifar_dset[0][0]
    print(len(cifar_dset), sample.size())
    # Module-level names: the loss functions may read these as globals.
    channel_num = sample.size()[0]
    image_size = sample.size()[1]

    dloader = torch.utils.data.DataLoader(cifar_dset, batch_size=100, shuffle=True)
    # Wrap before loading: the checkpoint was saved from a DataParallel model.
    score_network = score_network.to(device)
    score_network = torch.nn.DataParallel(score_network, device_ids=devices)

    chkpt_name = f'cifar{prefix}DataParallel_{chkpt[-1]}.pth'
    stat_dict = torch.load(f'./{chkpt_name}', map_location=device)
    # load_state_dict copies into the existing parameters, so the model stays
    # on `device` and needs no second .to(device).
    score_network.load_state_dict(stat_dict, strict=True)
    # NOTE(review): the network stays in whatever train/eval mode
    # get_model_and_dataset returned; call score_network.eval() here if it
    # contains dropout or batch-norm layers -- confirm.

    losses_data = []
    losses_gaussian = []
    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        # Iterate the loader once and stop after num_eval_batches, instead of
        # rebuilding the iterator for every batch.  `range` goes first so zip
        # stops without fetching an extra batch.
        for _, (data, _labels) in zip(range(num_eval_batches), dloader):
            data = data.to(device)
            loss_data = calc_loss_data(score_network, data, time_step=0.0)
            # loss_gaussian = calc_loss_data(score_network, data, time_step=0.7)
            loss_gaussian = calc_loss_gaussain(score_network, data, time_step=1)
            losses_data.append(loss_data.cpu().numpy())
            losses_gaussian.append(loss_gaussian.cpu().numpy())

    data_loss.append(np.mean(losses_data))
    gaussian_loss.append(np.mean(losses_gaussian))
print(data_loss)
print(gaussian_loss)



