import numpy as np
import torch
import torchvision
from torch.utils.data import Subset

# from layers import default_init as default_initializer
# quit()
from cifar_model import NCSNpp, get_ddpmpp_config, get_model_and_dataset

# datasetsize:20000, 0adam: dim-8 1eve: dim-32 2adamw: dim-16 0even: dim-12 6adam: dim20 7adamw: dim24 5even: dim28
# datasetsize:10000, 1adam: dim-8 6eve: dim-32 1fi: dim-12 7adamw: dim-28 0even:dim-16 2fe: dim-20 3gi: dim-24
# datasetsize:40000, 7adam: dim-32 5adamw: dim-8 5even:dim-12 4fi:dim-16 6fe:dim-20 3gi:dim-24 4adamw:dim-28
# dim-16-nr-4: 5eve: ds-30000 6even: ds-50000 5eve: ds-10000 6even:ds-40000
# dim-16-ds-20000-nr-4: 7fi:nf64 7fe:nf32
# dim-16-ds-10000-nr-4: 1adam:nf32 2adamw:nf64
# dim-16-ds-40000-nr-4: 2adam:nf32 3adamw:nf64
# dim-16-ds-30000-nr-4: 4adam:nf32 5adamw:nf64
# dim-16-ds-50000-nr-4: 7adam:nf32 5adamw:nf64
# dim-16-ds-10000-nr-4: 0adam:nf96
# dim-16-ds-20000-nr-4: 1adamw:nf96
# dim-16-ds-30000-nr-4: 4eve:nf96
# dim-16-ds-40000-nr-4: 7even:nf96

# dim-16-ds-20000: 4adam: nr-{2,6}
# dim-16-ds-10000: 5adamw: nr-{2,6}
# dim-16-ds-40000: 6eve: nr-{2,6}
# 0adamw: dim-8-ds-30000
# 3adam: dim-16-ds-20000
# 0adam: dim-8-ds-10000-epoch-10-nf{32, 64, 128}
# 1adamw: dim-16-ds-20000-epoch-10-nf{32, 64, 128}
# 2eve: dim-16-ds-30000-epoch-10-nf{32, 64, 128}
# 7adam: dim-16-ds-10000-tanh-nf{32, 64, 128}
# 6adamw: dim-16-ds-20000-tanh-nf{32, 64, 128}
# 5eve: dim-16-ds-30000-tanh-nf{32, 64, 128}
# 5adam, eve: dim-16-ds-30000-tanhnf-8-nr-{41}
# 7adamw: dim-8-ds-10000-act-tanh-nf-{128, 256, 512}-model-reduced
# 6even: dim-8-ds-20000-act-tanh-nf-{128, 256, 512}-model-reduced
# 4fi: dim-8-ds-20000-act-tanh-nf-{128, 256, 512}-model-reduced
# 5adam: dim-8-ds-10000-act-tanh-nf-1024-model-reduced
# 6adamw: dim-8-ds-20000-act-tanh-nf-1024-model-reduced
# 7eve: dim-8-ds-30000-act-tanh-nf-1024-model-reduced
# 5even: dim-8-ds-10000-act-tanh-nf-64-model-reduced
# 6fi: dim-8-ds-20000-act-tanh-nf-64-model-reduced
# 7fe: dim-8-ds-30000-act-tanh-nf-64-model-reduced
# 5gi: dim-8-ds-10000-act-tanh-nf-96-model-reduced
# 6fi: dim-8-ds-20000-act-tanh-nf-96-model-reduced
# 7fe: dim-8-ds-30000-act-tanh-nf-96-model-reduced
# 5even: dim-8-ds-10000-act-tanh-nf-32-model-reduced
# 6fi: dim-8-ds-5000-act-tanh-nf-{32, 64, 96, 128, 256}-model-reduced
# 7fe: dim-8-ds-15000-act-tanh-nf-{32, 64, 96, 128, 256}-model-reduced
# 3adam: dim-8-ds-10000-act-tanh-nf-{64, 96, 128, 256}-model-reduced
# 5adamw: dim-8-ds-20000-act-tanh-nf-{64, 96, 128, 256}-model-reduced
# 6eve: dim-8-ds-30000-act-tanh-nf-{64, 96, 128, 256}-model-reduced
# 7even: dim-8-ds-15000-act-tanh-nf-{64, 96, 128, 256}-model-reduced
# 4adam: dim-8-ds-10000-act-tanh-nf-{512, 1024}-model-reduced
# 5adamw: dim-8-ds-20000-act-tanh-nf-{512, 1024}-model-reduced
# 6eve: dim-8-ds-30000-act-tanh-nf-{512, 1024}-model-reduced
# 7even: dim-8-ds-15000-act-tanh-nf-{512, 1024}-model-reduced

# GPU selection for this run. `devices_id` is the single-element list that is
# later passed as `device_ids` to torch.nn.DataParallel.
device_id = 7
device = torch.device('cuda', device_id)  # change this if you don't have a gpu
devices_id = [device_id]

# prefix = 'smalldataset'
# prefix = 'smalldatasetsingledim'
# prefix = 'smalldatasetcrop'
# prefix = 'crop'
# prefix = 'singledim'

# num_resblock = 1
# for num_resblock in [2, 6]:
# nf = 128
for nf in [192, 512, 1024]:
    # ---- experiment configuration: one full training run per value of nf ----
    num_resblock = 1
    dim = 8
    dataset_size = 15000
    act = 'tanh'
    model = 'reduced'

    # Epoch count and checkpoint interval are both scaled by
    # 20000 / dataset_size.
    default_epoch = int(3000 * (20000 / dataset_size))
    epoch = default_epoch
    # alternative: epoch = 100
    save_interval = int(100 * (20000 / dataset_size)) * 5
    # alternative: save_interval = int(500 * (20000 / dataset_size))

    # Build the experiment tag. A setting is only appended when it differs
    # from its baseline (nf=128, the default epoch count, act='swish',
    # model=None). The tag is handed to get_model_and_dataset and reused in
    # the checkpoint file names.
    # NOTE(review): presumably get_model_and_dataset parses this tag to
    # configure the model and dataset -- confirm against cifar_model.
    prefix = f'dim-{dim}-ds-{dataset_size}-nr-{num_resblock}'
    if nf != 128:
        prefix += f'-nf-{nf}'
    if epoch != default_epoch:
        prefix += f'-epoch-{epoch}'
    if act != 'swish':
        prefix += f'-act-{act}'
    if model is not None:
        prefix += f'-model-{model}'

    # alternative: epoch = int(epoch / ( 20000 / dataset_size))
    print(f'Experiment setting: {prefix}')
    score_network, cifar_dset = get_model_and_dataset(prefix)

    # show a sample
    import matplotlib.pyplot as plt
    # cifar_dset[i] is a tuple whose first element is the image tensor; its
    # first axis is the channel axis. Fetch the sample once instead of
    # re-indexing the dataset for every use.
    sample_image = cifar_dset[0][0]
    print(len(cifar_dset), sample_image.size())
    channel_num = sample_image.size()[0]
    image_size = sample_image.size()[1]

    # move the channel axis last for imshow
    plt.imshow(sample_image.permute(1, 2, 0))
    # plt.colorbar()
    plt.show()
    plt.close()

    def calc_loss(score_network: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Compute the weighted denoising score-matching loss for one batch.

        For every sample a time ``t`` is drawn uniformly from [1e-4, 1), ``x``
        is perturbed to ``x_t`` with the Gaussian kernel defined by the
        integral of ``beta(s) = 0.1 + (20 - 0.1) * s``, and the network output
        is compared with the score of that kernel.

        All shapes and the device are derived from ``x``, so the function does
        not rely on any variable of the enclosing scope and also accepts
        non-square inputs.

        Args:
            score_network: Callable invoked as ``score_network(x_t, t * 999)``
                where the second argument has shape ``(batch_size,)``. Its
                output must be broadcastable against ``x``.
            x: Batch of training data with the batch axis first, e.g.
                ``(batch_size, channels, height, width)``.

        Returns:
            Scalar tensor: the mean over all elements of
            ``var_t * (score - grad_log_p) ** 2``.
        """
        # sample one time per example, uniformly in [1e-4, 1)
        t_sample = torch.rand((x.shape[0], 1), dtype=x.dtype, device=x.device) * (1 - 1e-4) + 1e-4
        # give t one singleton axis per non-batch axis of x, then view it at
        # the full shape of x so the element-wise math below lines up
        t = t_sample.reshape(-1, *([1] * (x.dim() - 1))).expand_as(x)

        # calculate the terms for the posterior log distribution
        int_beta = (0.1 + 0.5 * (20 - 0.1) * t) * t  # integral of beta
        mu_t = x * torch.exp(-0.5 * int_beta)
        var_t = -torch.expm1(-int_beta)  # 1 - exp(-int_beta)
        x_t = torch.randn_like(x) * var_t ** 0.5 + mu_t
        grad_log_p = -(x_t - mu_t) / var_t  # same shape as x
        assert x_t.size() == x.size()

        # calculate the score function
        # The network is fed t scaled by 999.
        # NOTE(review): presumably this matches the time range the network's
        # time embedding expects -- confirm against the model definition.
        score = score_network(x_t, t_sample[:, 0] * 999)
        # The network output is negated and divided by the standard deviation
        # to turn it into a score estimate.
        score = -score / (var_t ** 0.5)

        # calculate the loss function
        # With the weight lmbda_t = var_t, this is algebraically the squared
        # error between the raw network output and the Gaussian noise drawn
        # above.
        loss = (score - grad_log_p) ** 2
        lmbda_t = var_t
        weighted_loss = lmbda_t * loss
        return torch.mean(weighted_loss)
        # unweighted alternative: return torch.mean(loss)

    # start the training loop
    import time

    dloader = torch.utils.data.DataLoader(cifar_dset, batch_size=128, shuffle=True)

    # Move the model to its device and wrap it BEFORE building the optimizer,
    # so the optimizer is constructed over the parameters in their final
    # location (the order recommended by the torch.optim documentation).
    score_network = score_network.to(device)
    score_network = torch.nn.DataParallel(score_network, device_ids=devices_id)
    opt = torch.optim.Adam(score_network.parameters(), lr=2e-4)
    # alternative learning rates tried: lr=2e-3, lr=5e-3

    # To resume from a checkpoint:
    # stat_dict = torch.load(f'./cifarnewDataParallel_ep560.pth', map_location=device )
    # score_network.load_state_dict(stat_dict)

    # Checkpoint file names embed the experiment prefix followed by the class
    # name of the (wrapped) network object.
    ckpt_tag = f'{prefix}{score_network.__class__.__name__}'

    t0 = time.time()
    for i_epoch in range(epoch):
        total_loss = 0
        for data, _ in dloader:  # we don't need the data class
            data = data.to(device)
            opt.zero_grad()

            # training step
            loss = calc_loss(score_network, data)
            loss.backward()
            opt.step()

            # running stats: weight each batch mean by its size so the epoch
            # figure printed below is a per-sample average
            total_loss = total_loss + loss.detach().item() * data.shape[0]

        # print the training stats
        print(f"{i_epoch} ({time.time() - t0}s): {total_loss / len(cifar_dset)}")

        # periodic checkpoint of both model and optimizer state (this also
        # fires at epoch 0)
        if i_epoch % save_interval == 0:
            print("Saving model...")
            torch.save(score_network.state_dict(), f'./cifar{ckpt_tag}_ep{i_epoch}.pth')
            torch.save(opt.state_dict(), f'./cifarop{ckpt_tag}_ep{i_epoch}.pth')

    # final checkpoint once all epochs have finished
    torch.save(score_network.state_dict(), f'./cifar{ckpt_tag}_lastep.pth')
    torch.save(opt.state_dict(), f'./cifarop{ckpt_tag}_lastep.pth')
