import numpy as np
import torch
from torch.utils.data import Subset

from cifar_model import NCSNpp, get_ddpmpp_config, get_model_and_dataset


def calc_loss_data(score_network: torch.nn.Module, x: torch.Tensor, time_step=0.0) -> torch.Tensor:
    """Weighted denoising score-matching loss at a single, fixed diffusion time.

    Time convention: t = 0 is the data end, t = 1 the Gaussian end.  The
    batch is perturbed with the Gaussian kernel N(mu_t, var_t) derived from
    int_beta(t) = 0.1 * t + 0.5 * (20 - 0.1) * t**2, i.e. the integral of a
    beta that grows linearly from 0.1 to 20.

    The function depends only on its arguments: shape, dtype and device are
    taken from ``x``, so no module-level state is required.

    Args:
        score_network: Callable invoked as ``score_network(x_t, t * 999)``,
            where the second argument is a 1-D tensor with one entry per
            sample.  Its output must broadcast against ``x_t``.
        x: Batch of clean samples; the first dimension is the batch.
        time_step: Scalar time shared by every sample.  It is mapped to
            ``t = time_step * (1 - 1e-4) + 1e-4`` so that t is never
            exactly 0 for ``time_step`` in [0, 1].

    Returns:
        Scalar tensor: the mean over all elements of
        ``var_t * (score - grad_log_p) ** 2``.
    """
    batch_size = x.shape[0]

    # Same time for every sample; the affine map keeps var_t strictly positive.
    t_batch = torch.ones((batch_size,), dtype=x.dtype, device=x.device) * time_step * (1 - 1e-4) + 1e-4
    # Shape (N, 1, ..., 1): broadcasts against x whatever its trailing dims are.
    t = t_batch.reshape(batch_size, *([1] * (x.dim() - 1)))

    # Mean and variance of the perturbation kernel at time t.
    int_beta = (0.1 + 0.5 * (20 - 0.1) * t) * t  # integral of beta
    mu_t = x * torch.exp(-0.5 * int_beta)
    var_t = -torch.expm1(-int_beta)  # 1 - exp(-int_beta), accurate for small t

    # Draw the noisy sample and the score of N(mu_t, var_t) evaluated at it.
    x_t = torch.randn_like(x) * var_t ** 0.5 + mu_t
    grad_log_p = -(x_t - mu_t) / var_t

    # The factor 999 presumably maps t onto the network's timestep scale
    # (confirm against cifar_model).  The raw output is turned into a score
    # by scaling with -1/std, i.e. it is presumably a noise estimate.
    score = score_network(x_t, t_batch * 999)
    score = -score / (var_t ** 0.5)

    # Squared error weighted by lambda(t) = var_t.
    weighted_loss = var_t * (score - grad_log_p) ** 2
    return torch.mean(weighted_loss)


def calc_loss_gaussain(score_network: torch.nn.Module, x: torch.Tensor, time_step=1.0) -> torch.Tensor:
    """Batch-centred squared residual between the raw network output and x_t.

    The batch is perturbed exactly as in ``calc_loss_data`` (same int_beta,
    mean and variance), but the network output is used as-is, without the
    -1/std rescaling, and compared with the noisy sample ``x_t`` itself.
    The residual ``output - x_t`` has its mean over the batch dimension
    subtracted before squaring, so a residual that is identical for every
    sample in the batch gives a loss of zero.

    The function depends only on its arguments: shape, dtype and device are
    taken from ``x``, so no module-level state is required.

    Args:
        score_network: Callable invoked as ``score_network(x_t, t * 999)``,
            where the second argument is a 1-D tensor with one entry per
            sample.  Its output must broadcast against ``x_t``.
        x: Batch of clean samples; the first dimension is the batch.
        time_step: Scalar time shared by every sample, mapped to
            ``t = time_step * (1 - 1e-4) + 1e-4``.

    Returns:
        Scalar tensor: the mean over all elements of the squared,
        batch-centred residual.
    """
    batch_size = x.shape[0]

    # Same time for every sample, broadcastable against x as (N, 1, ..., 1).
    t_batch = torch.ones((batch_size,), dtype=x.dtype, device=x.device) * time_step * (1 - 1e-4) + 1e-4
    t = t_batch.reshape(batch_size, *([1] * (x.dim() - 1)))

    # Mean and variance of the perturbation kernel at time t.
    int_beta = (0.1 + 0.5 * (20 - 0.1) * t) * t  # integral of beta
    mu_t = x * torch.exp(-0.5 * int_beta)
    var_t = -torch.expm1(-int_beta)
    x_t = torch.randn_like(x) * var_t ** 0.5 + mu_t

    # The factor 999 presumably maps t onto the network's timestep scale
    # (confirm against cifar_model).
    score = score_network(x_t, t_batch * 999)

    # Remove the component of the residual shared by the whole batch, then
    # average the squared remainder.
    residual = score - x_t
    centred = residual - torch.mean(residual, dim=0, keepdim=True)
    return torch.mean(centred ** 2)

    # score = -score / (var_t ** 0.5)
    # # calculate the loss function
    # loss = (score - grad_log_p) - torch.mean(score - grad_log_p, dim=(1,2,3), keepdim=True)
    # loss = (loss ** 2) ** 0.5
    # weighted_loss = loss
    # return torch.mean(weighted_loss)

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Index of the CUDA device used for this run.  Author's bookkeeping note,
# kept verbatim:
#   datasetsize:20000, 0adam: dim-8 1eve: dim-32 2adamw: dim-16 0even: dim-12
#   6adam: dim20 7adamw: dim24 5even: dim28
device_id = 7
# Experiment names noted by the author:
#  dim-8-ds-5000-act-tanh-nf-{64, 96, 128, 256}-model-reduced
#  dim-8-ds-10000-act-tanh-nf-{64, 96, 128, 256}-model-reduced
#  dim-8-ds-15000-act-tanh-nf-{64, 96, 128, 256}-model-reduced
#  dim-8-ds-20000-act-tanh-nf-{64, 96, 128, 256}-model-reduced

# Single-element device list handed to torch.nn.DataParallel further down.
devices_id = [device_id]

# Target device for the model and every batch; change this if you don't
# have a gpu.
device = torch.device('cuda', device_id)

# Experiment settings.  In this file they are only used to assemble the
# experiment prefix string (and through it the checkpoint file names).
act = 'tanh'
model = 'reduced'
dim = 8
dataset_size = 30000
num_resblock = 1

# One averaged loss per `nf` setting, appended in loop order.
loss_data_all, loss_gaussian_all = [], []
def calc_loss(score_network: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Training loss: weighted denoising score matching at random times.

    Each sample gets its own time, drawn uniformly and mapped into
    [1e-4, 1).  Shape, dtype and device are taken from ``x``, so the
    function does not depend on module-level state and is defined once
    rather than on every pass of the experiment loop.

    Args:
        score_network: Callable invoked as ``score_network(x_t, t * 999)``
            with a 1-D tensor of per-sample times as second argument.
        x: Batch of clean samples; the first dimension is the batch.

    Returns:
        Scalar tensor: mean of ``var_t * (score - grad_log_p) ** 2``.
    """
    batch_size = x.shape[0]
    t_batch = torch.rand((batch_size,), dtype=x.dtype, device=x.device) * (1 - 1e-4) + 1e-4
    # Shape (N, 1, ..., 1) so that it broadcasts against x.
    t = t_batch.reshape(batch_size, *([1] * (x.dim() - 1)))

    # Mean and variance of the perturbation kernel at time t.
    int_beta = (0.1 + 0.5 * (20 - 0.1) * t) * t  # integral of beta
    mu_t = x * torch.exp(-0.5 * int_beta)
    var_t = -torch.expm1(-int_beta)
    x_t = torch.randn_like(x) * var_t ** 0.5 + mu_t
    grad_log_p = -(x_t - mu_t) / var_t

    # Raw network output is rescaled by -1/std before being compared with
    # the kernel score; the error is weighted by lambda(t) = var_t.
    score = score_network(x_t, t_batch * 999)
    score = -score / (var_t ** 0.5)
    return torch.mean(var_t * (score - grad_log_p) ** 2)


for nf in [512, 1024]:
# for nf in [64, 96, 128, 256]:
    # Experiment prefix; values equal to 128 / 'swish' / None are left out
    # of the name.
    prefix = f'dim-{dim}-ds-{dataset_size}-nr-{num_resblock}'
    if nf != 128:
        prefix += f'-nf-{nf}'
    if act != 'swish':
        prefix += f'-act-{act}'
    if model is not None:
        prefix += f'-model-{model}'

    print('training')
    print(f'Experiment setting: {prefix}')
    score_network, cifar_dset = get_model_and_dataset(prefix)
    score_network = score_network.to(device)
    score_network = torch.nn.DataParallel(score_network, device_ids=devices_id)

    # Restore the weights saved from the DataParallel-wrapped model.
    chkpt_name = f'cifar{prefix}DataParallel_lastep.pth'
    stat_dict = torch.load(f'./{chkpt_name}', map_location=device)
    score_network.load_state_dict(stat_dict, strict=True)

    # Fetch the first sample once instead of three times.
    sample_shape = cifar_dset[0][0].size()
    print(len(cifar_dset), sample_shape)
    # Kept as module-level names because other code in this file may read them.
    channel_num = sample_shape[0]
    image_size = sample_shape[1]

    opt = torch.optim.Adam(score_network.parameters(), lr=2e-4)
    # Load the optimizer state onto the CPU first: this works even when the
    # checkpoint was written from a different GPU index, and
    # Optimizer.load_state_dict then moves the per-parameter state to each
    # parameter's device.
    opt_state = torch.load(
        f'./cifarop{prefix}{score_network.__class__.__name__}_lastep.pth',
        map_location='cpu',
    )
    opt.load_state_dict(opt_state)

    dloader = torch.utils.data.DataLoader(cifar_dset, batch_size=128, shuffle=True)
    # Evaluation loader, built once per experiment rather than once per
    # training step.
    _dloader = torch.utils.data.DataLoader(cifar_dset, batch_size=100, shuffle=True)

    data_iter = iter(dloader)
    current_iteration = 0
    total_iterations = 100

    data_loss = []
    gaussian_loss = []
    while current_iteration < total_iterations:
        # Restart the training iterator when an epoch is exhausted.
        try:
            data, _ = next(data_iter)
        except StopIteration:
            print('stopped')
            data_iter = iter(dloader)
            data, _ = next(data_iter)

        data = data.to(device)
        opt.zero_grad()

        # training step
        loss = calc_loss(score_network, data)
        loss.backward()
        opt.step()
        current_iteration += 1

        # After every training step, estimate both diagnostic losses on 20
        # random evaluation batches.
        # NOTE(review): the network is left in whatever train/eval mode
        # get_model_and_dataset returned it in - confirm this is intended.
        with torch.no_grad():
            losses_data = []
            losses_gaussian = []
            for _ in range(20):
                # A fresh iterator per draw reshuffles the dataset, so each
                # batch is the first batch of a new random permutation.
                eval_batch, _ = next(iter(_dloader))
                eval_batch = eval_batch.to(device)
                loss_data = calc_loss_data(score_network, eval_batch, time_step=0.0)
                loss_gaussian = calc_loss_gaussain(score_network, eval_batch, time_step=1)
                losses_data.append(loss_data.detach().cpu().numpy())
                losses_gaussian.append(loss_gaussian.detach().cpu().numpy())
            data_loss.append(np.mean(losses_data))
            gaussian_loss.append(np.mean(losses_gaussian))

    # Average of the per-step estimates for this `nf` setting.
    print(np.mean(data_loss))
    print(np.mean(gaussian_loss))
    loss_data_all.append(np.mean(data_loss))
    loss_gaussian_all.append(np.mean(gaussian_loss))
print(loss_data_all)
print(loss_gaussian_all)
