
import numpy as np
import importlib as ipb
import matplotlib.pyplot as plt
from argparse import Namespace
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import time
from argparse import Namespace
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class Dense_act(nn.Module):
    """MLP mapping an embedding vector to per-channel values shaped as 1x1 maps.

    The input passes through Linear -> Tanh -> Linear -> Tanh -> Linear
    (hidden width ``latent``). Two trailing singleton dimensions are then
    appended so the result broadcasts when added to ``(N, output_dim, H, W)``
    feature maps.
    """

    def __init__(self, input_dim, output_dim, latent=32):
        super().__init__()
        # Keep this exact layer order: it fixes the ``dense.<idx>`` keys in
        # the state_dict that saved checkpoints rely on.
        layers = [
            nn.Linear(input_dim, latent),
            nn.Tanh(),
            nn.Linear(latent, latent),
            nn.Tanh(),
            nn.Linear(latent, output_dim),
        ]
        self.dense = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.dense(x)
        # (..., output_dim) -> (..., output_dim, 1, 1)
        return projected.unsqueeze(-1).unsqueeze(-1)

class NCSNUNet_t(nn.Module):
    """Time-conditioned U-Net that maps an image and a time to one scalar.

    Encoder: four 3x3 convolutions (strides 1, 2, 2, 2, no padding) with
    ``channels[0..3]`` feature maps. Decoder: three transposed convolutions
    with skip connections from the encoder, then a final transposed
    convolution down to one channel.

    A small MLP embeds the time ``t``. Every stage except the final one adds
    a ``Dense_act`` projection of that embedding before GroupNorm + swish.
    The one-channel output map is flattened and reduced to a single value per
    sample by a linear layer (``out_fc``).

    Relative to the 32-channel baseline, this variant starts at 64 channels
    and its deepest stage goes 256 -> 512 (instead of 128 -> 256). No extra
    layer is added.
    """

    def __init__(self, channels=(64, 128, 256, 512), embed_dim=256, img_size=784):
        """Build the network.

        Args:
            channels: number of feature-map channels at each of the four
                resolutions. Every entry must be divisible by 32, because the
                GroupNorm layers below use 32 groups (``gnorm1`` uses 4).
            embed_dim: width of the time-embedding MLP.
            img_size: number of elements in the flattened one-channel output
                of ``tconv1``. For 28x28 inputs this is 784 (the default).
        """
        super().__init__()
        # Time embedding: scalar t -> embed_dim features.
        self.embed = nn.Sequential(
            nn.Linear(1, embed_dim),
            nn.Tanh(),
            nn.Linear(embed_dim, embed_dim),
            nn.Tanh(),
            nn.Linear(embed_dim, embed_dim)
        )
        # Encoding layers where the resolution decreases
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense_act(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense_act(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense_act(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense_act(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

        # Decoding layers where the resolution increases; the input channel
        # counts of tconv3/2/1 are doubled by the encoder skip connections.
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2,
                                         bias=False)
        self.tdense4 = Dense_act(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
        self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3,
                                         stride=2, bias=False, output_padding=1)
        self.tdense3 = Dense_act(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
        self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3,
                                         stride=2, bias=False, output_padding=1)
        self.tdense2 = Dense_act(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)

        # Flattened one-channel map -> one scalar per sample.
        self.out_fc = nn.Linear(img_size, 1)

        # Swish / SiLU activation, x * sigmoid(x). Using a module rather than
        # a lambda keeps the model picklable. It has no parameters, so the
        # state_dict is unchanged.
        self.act = nn.SiLU()

    def forward(self, x, t):
        """Evaluate the network.

        Args:
            x: images of shape ``(N, 1, H, W)``. This must be 28x28 with the
                default ``img_size``.
            t: time; a Python number, a 0-d tensor, or a tensor that
                broadcasts against shape ``(N,)``.

        Returns:
            Tensor of shape ``(N, 1)``.
        """
        n = x.size(0)

        # One time value per sample, embedded: (N,) -> (N, 1) -> (N, embed_dim).
        t = torch.ones(n, device=x.device) * t
        # TODO: should we take the log of t if doing fourier? note that originally it was the log stdev of marginal
        embed = self.act(self.embed(t.reshape(-1, 1)))

        # Encoding path; each stage adds the time embedding before GroupNorm.
        h1 = self.conv1(x)
        h1 = h1 + self.dense1(embed)
        h1 = self.gnorm1(h1)
        h1 = self.act(h1)
        h2 = self.conv2(h1)
        h2 = h2 + self.dense2(embed)
        h2 = self.gnorm2(h2)
        h2 = self.act(h2)
        h3 = self.conv3(h2)
        h3 = h3 + self.dense3(embed)
        h3 = self.gnorm3(h3)
        h3 = self.act(h3)
        h4 = self.conv4(h3)

        # Middle of the U-Net, incorporating information from t
        h4 = h4 + self.dense4(embed)
        h4 = self.gnorm4(h4)
        h4 = self.act(h4)  # (N, channels[3], 2, 2) for 28x28 inputs

        # Decoding path; each tconv after the first consumes a skip
        # connection from the matching encoder stage.
        h = self.tconv4(h4)
        h = h + self.tdense4(embed)
        h = self.tgnorm4(h)
        h = self.act(h)
        h = self.tconv3(torch.cat([h, h3], dim=1))
        h = h + self.tdense3(embed)
        h = self.tgnorm3(h)
        h = self.act(h)
        h = self.tconv2(torch.cat([h, h2], dim=1))
        h = h + self.tdense2(embed)
        h = self.tgnorm2(h)
        h = self.act(h)
        h = self.tconv1(torch.cat([h, h1], dim=1))

        # do we need some activations here?
        h = h.view(n, -1)
        return self.out_fc(h)

def rnet_integral(score_model, x, t, num_int_pts = 1):
    """Cumulatively integrate ``score_model(x, s)`` over s from t[0] to t[1].

    ``x`` is held fixed, so the integrand depends on s only. The method is
    therefore composite Simpson 3/8 quadrature (what the Runge-Kutta 3/8
    method reduces to when the integrand does not depend on the state):
    ``h/8 * (f(s) + 3 f(s + h/3) + 3 f(s + 2h/3) + f(s + h))`` on each of
    ``num_int_pts`` equal sub-intervals.

    Neighbouring sub-intervals share an endpoint, so its value is reused.
    This takes ``3 * num_int_pts + 1`` model calls instead of
    ``4 * num_int_pts``.

    Args:
        score_model: callable ``(x, s) -> tensor`` (the integrand).
        x: input held fixed during the integration.
        t: two-element sequence/tensor ``[t0, t1]``; t1 < t0 integrates
            backwards (negative step).
        num_int_pts: number of equal sub-intervals; must be >= 1.

    Returns:
        Tensor stacking the running integral after each sub-interval.
        ``result[-1]`` is the integral over the whole of [t0, t1], and the
        leading dimension has size ``num_int_pts``.

    Raises:
        ValueError: if ``num_int_pts < 1``.
    """
    # Reference: http://www.mymathlib.com/diffeq/runge-kutta/runge_kutta_3_8.html
    if num_int_pts < 1:
        raise ValueError(f'num_int_pts must be >= 1, got {num_int_pts}')
    h = (t[1] - t[0]) / num_int_pts
    outputs = []
    running = None
    # Integrand at the left end of the current sub-interval; carried over from
    # the previous sub-interval's right end.
    f_left = score_model(x, t[0])
    for i in range(num_int_pts):
        t_now = t[0] + i * h
        f_third = score_model(x, t_now + h/3)
        f_two_thirds = score_model(x, t_now + 2*h/3)
        f_right = score_model(x, t[0] + (i + 1) * h)
        piece = h/8 * (f_left + 3*f_third + 3*f_two_thirds + f_right)
        # Cumulative integral from t[0] to the end of this sub-interval.
        running = piece if running is None else running + piece
        outputs.append(running)
        f_left = f_right
    return torch.stack(outputs)

def cont_t_train(train_loader, time_ls, rnet, optimizer, num_int_pts = 1,
                 eval_samples = None, log_freq = 20, mem_report = None):
    '''Run one optimization pass over ``train_loader`` with a softplus loss.

    For each batch and each interval ``time_ls[i]``:
    - ``batch[i]`` is integrated forward over the interval with
      ``rnet_integral``.
    - ``batch[i+1]`` is integrated over the reversed interval.
    - The loss adds ``softplus(integral).mean()`` for both directions.
    The losses of all intervals are summed and one optimizer step is taken
    per batch.

    Args:
        train_loader: iterable of batches. Each batch is indexable with at
            least ``len(time_ls) + 1`` tensors (originally noted as raw MNIST
            with pixels in [0, 1]).
        time_ls: non-empty list of ``[t_{k-1}, t_k]`` pairs, k = 1, ..., L+1.
        rnet: continuous-time network called as ``rnet(x, t)``.
        optimizer: optimizer over ``rnet``'s parameters.
        num_int_pts: sub-intervals per time interval for ``rnet_integral``.
        eval_samples: samples for the periodic BPD estimate. Defaults to a
            module-level ``Xtest`` if one exists (as created by the
            ``__main__`` block). If neither is available, the estimate is
            skipped.
        log_freq: report progress every ``log_freq`` batches; 0/None
            disables periodic reporting.
        mem_report: zero-argument callable printing memory usage. Defaults to
            ``MConv.mem_report`` when a module-level ``MConv`` exists.

    Returns:
        Mean of the per-batch training losses.

    Raises:
        ValueError: if ``time_ls`` is empty (there would be no loss).
    '''
    if not time_ls:
        raise ValueError('time_ls must contain at least one [t_{k-1}, t_k] interval')
    # Backward compatibility: these were previously read implicitly from
    # globals created by the ``__main__`` block. Fall back to them when they
    # exist, and skip the matching diagnostic when they do not.
    if eval_samples is None:
        eval_samples = globals().get('Xtest')
    if mem_report is None:
        mem_report = getattr(globals().get('MConv'), 'mem_report', None)

    softplus = torch.nn.Softplus(beta = 1)
    # The intervals are the same for every batch, so build the forward and
    # reversed time tensors once rather than once per batch.
    t_forward = [torch.tensor(t).to(device) for t in time_ls]
    t_reverse = [torch.flip(t, [0]) for t in t_forward]

    loss_tot = []
    bpd_ls = []
    batch_num = 0
    start_b = time.time()
    for batch in train_loader:
        optimizer.zero_grad()
        loss_batch = 0
        for i, (t, t_rev) in enumerate(zip(t_forward, t_reverse)):
            # .to(device) is a no-op when the loader already yields tensors there.
            x, y = batch[i].to(device), batch[i+1].to(device)
            output_xy = rnet_integral(rnet, x, t, num_int_pts)
            output_yx = rnet_integral(rnet, y, t_rev, num_int_pts)
            loss_X = softplus(output_xy[-1]).mean()
            loss_Y = softplus(output_yx[-1]).mean()
            loss_batch += loss_X + loss_Y
        loss_batch.backward()
        optimizer.step()
        loss_tot.append(loss_batch.item())
        batch_num += 1
        ### Report intermediate results because one epoch can take a long time
        if log_freq and batch_num % log_freq == 0:
            torch.cuda.empty_cache()
            if mem_report is not None:
                mem_report()
            print(f'Finish {log_freq} out of {batch_num}/{len(train_loader)} batches:')
            print(f'Took {(time.time() - start_b)/60:.2f} minutes')
            print(f'Training loss: {loss_batch.item():.2f}')
            start_b = time.time()
            if eval_samples is not None:
                bpd_ = compute_estimated_bpd(rnet, eval_samples, time_ls,
                                             num_int_pts, full = False)
                bpd_ls.append(bpd_)
        # if batch_num % 65 == 0:
        #     fig, ax = plt.subplots(1, 2, figsize = (10, 3))
        #     xaxis = log_freq*np.arange(1, len(bpd_ls)+1)
        #     idxes = log_freq*np.arange(len(bpd_ls))
        #     ax[0].plot(xaxis, np.array(loss_tot)[idxes], '-o')
        #     ax[0].set_xlabel('Number of batches trained')
        #     ax[0].set_ylabel('Training loss')
        #     ax[1].plot(xaxis, bpd_ls, '-o')
        #     ax[1].set_xlabel('Number of batches trained')
        #     ax[1].set_ylabel('Estimated BPD')
        #     fig.tight_layout()
        #     plt.show()
        #     plt.close()
    return np.mean(loss_tot)

def save_or_load(save = False, load = True,
                 filepath = None, score_model = None,
                 optimizer = None, loss_score = None,
                 estimated_bpd = None):
    """Optionally checkpoint, then optionally restore, the training state.

    Args:
        save: if true, write the ``score_model`` and ``optimizer`` state dicts
            plus the ``loss_score`` and ``estimated_bpd`` histories to
            ``filepath``.
        load: if true and ``filepath`` exists, restore ``score_model`` and
            ``optimizer`` in place from ``filepath``.
        filepath: checkpoint path. Required when ``save`` is true; when it is
            None, loading is skipped.
        score_model: ``nn.Module`` whose state is saved/restored.
        optimizer: optimizer whose state is saved/restored.
        loss_score: loss history to save.
        estimated_bpd: BPD history to save.

    Returns:
        ``(loss_score, estimated_bpd)``. These are the histories read from the
        checkpoint when one was loaded; otherwise the histories passed in
        (empty lists when they were not given).

    Raises:
        ValueError: if ``save`` is true but ``filepath`` is None.
    """
    import os
    if save:
        if filepath is None:
            raise ValueError('filepath is required when save=True')
        save_obj = {'score_model': score_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'loss_score': loss_score,
                    'estimated_bpd': estimated_bpd}
        torch.save(save_obj, filepath)
    if load and filepath is not None and os.path.exists(filepath):
        print('### Load rnets to evaluate or resume')
        # Load onto the CPU. load_state_dict copies the tensors into the
        # existing parameters / optimizer state on their own device, so a
        # checkpoint written on a GPU can also be restored without one.
        save_obj = torch.load(filepath, map_location='cpu')
        score_model.load_state_dict(save_obj['score_model'])
        optimizer.load_state_dict(save_obj['optimizer'])
        return save_obj['loss_score'], save_obj['estimated_bpd']
    # Nothing loaded: hand back the caller's histories (fresh lists if none).
    loss_score = [] if loss_score is None else loss_score
    estimated_bpd = [] if estimated_bpd is None else estimated_bpd
    return loss_score, estimated_bpd


def logit_to_bpd(logit, num_dims = 784):
    """Convert a value in nats, summed over ``num_dims`` dimensions, to bits per dimension.

    Dividing by ``ln 2`` converts nats to bits, and dividing by ``num_dims``
    makes the value per dimension. The default of 784 is the pixel count of
    a 28x28 MNIST image.

    Args:
        logit: value in nats (float, NumPy array or tensor).
        num_dims: number of data dimensions the value is summed over.

    Returns:
        ``logit / (ln(2) * num_dims)``.
    """
    return logit / (np.log(2) * num_dims)

def compute_estimated_bpd(score_model, p0_samples, all_t,
                          num_int_pts, full = True, offset = None,
                          num_subsample = 500):
    """Estimate bits per dimension by integrating ``score_model`` over time.

    For each ``[t_{k-1}, t_k]`` in ``all_t``, ``score_model`` is integrated
    with ``rnet_integral`` on the evaluation samples, without gradients. The
    per-sample integrals are summed over intervals and averaged over samples.
    The result is converted with ``logit_to_bpd`` and added to ``offset``.

    Args:
        score_model: network called as ``score_model(x, t)``.
        p0_samples: evaluation samples, indexable along dim 0.
        all_t: non-empty list of ``[t_{k-1}, t_k]`` intervals.
        num_int_pts: sub-intervals per time interval for ``rnet_integral``.
        full: if true use every sample; otherwise use a random subset of
            ``num_subsample`` samples (fewer if fewer exist).
        offset: BPD added to the estimate. Defaults to a module-level
            ``offset`` (the ``__main__`` block defines one).
        num_subsample: subset size used when ``full`` is false.

    Returns:
        ``offset + logit_to_bpd(mean summed integral)``.

    Raises:
        ValueError: if ``all_t`` is empty, or if no ``offset`` is given and
            no module-level ``offset`` exists.
    """
    if offset is None:
        # Backward compatibility: this used to be read implicitly from the
        # global set by the ``__main__`` block.
        offset = globals().get('offset')
        if offset is None:
            raise ValueError('offset must be given when no module-level `offset` is defined')
    if not all_t:
        raise ValueError('all_t must contain at least one [t_{k-1}, t_k] interval')
    if full:
        xinput_PQ = p0_samples
    else:
        indices = torch.randperm(len(p0_samples))[:num_subsample]
        xinput_PQ = p0_samples[indices]
    # Move the evaluation data to `device`; no-op when it is already there.
    xinput_PQ = xinput_PQ.to(device)
    full_logit = 0
    for t_now in all_t:
        t_now = torch.tensor(t_now).to(device)
        with torch.no_grad():
            logit_PQ = rnet_integral(score_model, xinput_PQ,
                                     t_now, num_int_pts)[-1]
        print(f'#### Over {[t_now[0].item(), t_now[-1].item()]}')
        print(f'Logit is {logit_PQ.mean().item():.2f}')
        full_logit += logit_PQ
    rnet_bpd = logit_to_bpd(full_logit.mean().item())
    print(f'#### BPD for logit is {rnet_bpd:.2f}')
    return offset + rnet_bpd



if __name__ == '__main__':
    # Script entry point:
    # 1. Load data and pretrained pieces through the project helper module.
    # 2. Build the time-ordered training trajectories.
    # 3. Train the continuous-time network, checkpointing after every epoch.
    # NOTE: `MConv`, `Xtest` and `offset` are assigned here as module globals.
    # Functions above read them implicitly when they are not passed in, so
    # keep these names.

    # Score net initialization
    import load_interpolate_hidspace_RQNSF as MConv
    ipb.reload(MConv)  # re-execute the helper module so source edits are picked up on re-runs
    self, autoEnc, Xtest = MConv.return_self()
    # NOTE(review): presumably rescales Xtest from [-1, 1] to [0, 1] to match
    # the [0, 1] training data; confirm against MConv.return_self.
    Xtest = (Xtest + 1) / 2
    from IPython.display import clear_output
    clear_output()

    batch_size = 2000  # Large batch size could cause memory issue
    train_loader_raw, test_loader = MConv.load_raw_mnist(batch_size=batch_size)
    # How many raw batches to turn into trajectories: 1, ..., len(train_loader_raw)
    num_select = len(train_loader_raw)
    full_traj_PQ = MConv.raw_loader_to_traj(train_loader_raw, autoEnc, self, num_select)
    print(full_traj_PQ.shape)

    torch.cuda.empty_cache()
    MConv.mem_report()
    offset = 1.12  # RQ_NSF

    score_model = NCSNUNet_t().to(device)
    print(score_model)
    optimizer_score = torch.optim.Adam(score_model.parameters(), lr=1e-3)

    # Time intervals [t_{k-1}, t_k], k=1,...,L+1, from L+2 evenly spaced points on [0, 1]
    t_discretize = torch.linspace(0, 1, len(self.ls_args_CNF)+2).to(device)
    num_int_pts = 1  # How many sub-intervals each [t_{k-1}, t_k] integral is split into
    all_t = [[t0.item(), t1.item()]
             for t0, t1 in zip(t_discretize[:-1], t_discretize[1:])]
    print(np.vstack(all_t))

    # Build the training loader. Unpacking full_traj_PQ gives one tensor per
    # trajectory step; cont_t_train reads them as batch[0], batch[1], ...
    bsize = 128
    train_loader_full = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(*full_traj_PQ),
        batch_size=bsize, shuffle=True)
    torch.cuda.empty_cache()
    MConv.mem_report()

    # ### Training, with saving option
    num_epochs = 300
    load = False
    filepath = 'MNIST_score_net_RQNSF.pt'
    # Restore histories (and model/optimizer state) when resuming; else start fresh.
    loss_score, estimated_bpd = save_or_load(save = False, load = load,
                                             filepath = filepath, score_model = score_model,
                                             optimizer = optimizer_score)

    # Resume from the number of epochs already recorded.
    epoch_now = len(loss_score)
    print(f'### Start training score net at epoch {epoch_now} out of {num_epochs}')
    for enow in range(epoch_now, num_epochs):
        start = time.time()
        loss_score.append(cont_t_train(train_loader_full,
                                       time_ls = all_t,
                                       rnet = score_model,
                                       optimizer = optimizer_score,
                                       num_int_pts = num_int_pts))
        estimated_bpd.append(compute_estimated_bpd(score_model, Xtest, all_t, num_int_pts))
        print(f'Epoch {enow} took time {(time.time()-start)/60} mins')
        print(f'Estimated BPD is {estimated_bpd[-1]:.2f}')

        #### Evaluate and save
        save_or_load(save = True, load = False,
                     filepath = filepath, score_model = score_model,
                     optimizer = optimizer_score, loss_score = loss_score,
                     estimated_bpd = estimated_bpd)

