

from argparse import Namespace
import torch
import matplotlib.pyplot as plt
import torchvision

# Run on the GPU whenever CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# # Helper functions

# ## Data

# ### Raw data

# In[3]:


import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np 


def subset_data(data, frac=1, dataname='MNIST', seed=1103):
    '''
    Randomly keep a fraction of the examples of a torchvision-style dataset.

    Input:
        data: dataset object exposing `.data` and `.targets`; it is modified
            in place and also returned
        frac: fraction of examples to keep (int(frac * len(data)) are kept)
        dataname: 'MNIST' when `.targets` is already indexable with the index
            array; any other value first converts `.targets` with torch.tensor
        seed: seed of the subset choice; the default reproduces the
            historical subset
    Output:
        data: the same dataset object with `.data` and `.targets` subset
    '''
    # A private RandomState avoids reseeding the global NumPy RNG as a side
    # effect. RandomState(seed).choice draws exactly the same indices as
    # np.random.seed(seed) followed by np.random.choice.
    rng = np.random.RandomState(seed)
    idx = rng.choice(len(data), int(frac * len(data)), replace=False)
    data.data = data.data[idx, :]
    if dataname == 'MNIST':
        data.targets = data.targets[idx]
    else:
        data.targets = torch.tensor(data.targets)[idx]
    return data

def load_raw_mnist(batch_size=64):
    '''
    Build two MNIST loaders, both carved out of the official *training* split.

    The first 50,000 training images form the shuffled training loader. The
    remaining 10,000 (indices 50000-59999) form a non-shuffled held-out
    loader. The official MNIST test split is not used.

    Input:
        batch_size: batch size of both loaders
    Output:
        (train_loader, test_loader): torch DataLoaders yielding (image, label)
    '''
    to_tensor = transforms.Compose([transforms.Resize(28), transforms.ToTensor()])
    mnist_train = datasets.MNIST('./data', train=True, download=True,
                                 transform=to_tensor)

    train_split = torch.utils.data.Subset(mnist_train, np.arange(50000))
    heldout_split = torch.utils.data.Subset(mnist_train, np.arange(50000, 60000))

    shared = dict(num_workers=2, drop_last=False)
    train_loader = torch.utils.data.DataLoader(train_split, batch_size,
                                               shuffle=True, **shared)
    test_loader = torch.utils.data.DataLoader(heldout_split, batch_size,
                                              shuffle=False, **shared)
    return train_loader, test_loader


# ### Augmentation + Encoder

# In[4]:


def autoEnc_pass(input, flow, encode = True):
    '''
    Map between image space and the flow's z-space with a pretrained flow.

    Input:
        input: if encode, images with every value in [0, 1]; otherwise
            z-space samples whose second dimension is 784
        flow: pretrained flow whose model is reachable as `flow.module`
            (e.g. the torch.nn.DataParallel returned by load_pretrained_flow)
        encode: if True, return the encoder (image -> noise) output;
            if False, return the decoder (noise -> image) output
    Output:
        output: encoder or decoder output, computed without gradient tracking
    Raises:
        ValueError: if `input` violates the range/shape requirement above
    '''
    with torch.no_grad():
        if encode:
            # Explicit check instead of `assert` (stripped under `python -O`).
            # Written as `not (...)` so NaN values are still rejected.
            if not (input.min() >= 0 and input.max() <= 1):
                raise ValueError('autoEnc_pass: encoder input must lie in [0, 1]')
            # Rescale to [0, 256]; presumably matches the flow's num_bits=8
            # preprocessing (see load_pretrained_flow) — confirm.
            return flow.module.transform_to_noise(input * 256)
        if input.dim() < 2 or input.shape[1] != 784:
            raise ValueError('autoEnc_pass: decoder input must have shape (N, 784)')
        return flow.module.sample(input, context=None, rescale=True)


def get_sub_batches(loader, num_select):
    '''
    Take the first `num_select` batches of `loader`, dequantized, on `device`.

    Each image batch x becomes x * 255/256 + U/256, with U elementwise
    uniform on [0, 1). Labels are dropped and replaced by the placeholder 0.
    If `num_select` is never reached (e.g. it exceeds the number of batches),
    every batch is returned.

    Output:
        list of [batch, 0] pairs
    '''
    selected = []
    for count, (images, _) in enumerate(loader, start=1):
        images = images.to(device)
        jitter = torch.rand_like(images).to(device)
        selected.append([images * 255. / 256. + jitter / 256., 0])
        if count == num_select:
            break
    return selected


# #### AutoEncoder using the pre-trained RQ-NSF flow

# In[5]:


import os
import sys
# Make the local `nsf` and `experiments` directories importable.
# NOTE(review): os.path.join with a single argument returns it unchanged, so
# these entries are relative to the current working directory, not this file.
sys.path.append(os.path.join('nsf'))
sys.path.append(os.path.join('experiments'))
from nsf.nde import distributions, flows
from nsf.experiments.images import create_transform
import torchvision
import torch
# Same GPU-if-available rule; re-evaluated here because this notebook cell
# was written to be runnable on its own.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Below from models/ncsn_flow.py
def load_pretrained_flow(ckpt_path=None):
    '''
    Build the RQ-NSF flow architecture and load its pretrained weights.

    Adapted from models/ncsn_flow.py. The architecture hyperparameters below
    are hard-coded and must match those used to train the checkpoint.

    Input:
        ckpt_path: path to a saved state_dict. Defaults to
            flow_ckpts/rq_nsf_best.pt, relative to the working directory.
    Output:
        net: the flows.Flow moved to `device` and wrapped in
            torch.nn.DataParallel (the model is reachable as `net.module`)
    '''
    name = 'rq_nsf'
    print('loading flow model: {}'.format(name))
    if ckpt_path is None:
        ckpt_path = os.path.join('flow_ckpts', 'rq_nsf_best.pt')
    print('loading model from checkpoint: {}'.format(ckpt_path))

    # Image geometry: 1 channel, 28 x 28 pixels (MNIST).
    c = 1
    h = w = 28
    spline_params = {
        "apply_unconditional_transform": False,
        "min_bin_height": 0.001,
        "min_bin_width": 0.001,
        "min_derivative": 0.001,
        "num_bins": 8,
        "tail_bound": 3.0,
    }
    distribution = distributions.StandardNormal((c * h * w,))
    # TODO (HACK): get rid of hardcoding
    transform = create_transform(
        c, h, w,
        levels=2, hidden_channels=64, steps_per_level=8, alpha=0.000001,
        num_bits=8, preprocessing="realnvp_2alpha", multi_scale=False,
        actnorm=True, coupling_layer_type="rational_quadratic_spline",
        spline_params=spline_params,
        use_resnet=False, num_res_blocks=2, resnet_batchnorm=False, dropout_prob=0.0,
    )
    net = flows.Flow(transform, distribution)

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(ckpt_path, map_location=device)
    net.load_state_dict(checkpoint)
    net = net.to(device)
    return torch.nn.DataParallel(net)


# ## Flow Model Training (by block)

# ### CNF model initialization

# In[6]:


import nets_Jacnorm as nets


def default_CNF_structure(args_net):
    '''
    Build a CNF whose velocity field is a fully connected network.

    Input:
        args_net: options forwarded to nets.FCnet (e.g. Xdim, hidden_dim_str,
            activation)
    Output:
        the nets.CNF wrapping that velocity field, moved to `device`
    '''
    velocity_field = nets.FCnet(args_net)
    return nets.CNF(velocity_field).to(device)


# ### CNF/FlowNet forward passes & loss computation

# In[7]:


def FlowNet_forward(xinput, CNF, ls_args_CNF,
                    block_now,
                    reverse = False, test = True,
                    return_full = False):
    '''
    Push `xinput` through the first `block_now` CNF blocks without gradients.

    Input:
        xinput: batch fed into the first block (or into the last of the
            selected blocks when `reverse` is True)
        CNF: continuous-time CNF model, called as
            CNF(x, args_CNF, reverse=..., test=...) -> (traj, dlogpx, _)
        ls_args_CNF: list of per-block args_CNF; only the first `block_now`
            entries are used
        block_now: number of leading blocks to apply
        reverse: if True, traverse the selected blocks last-to-first in
            reverse time (noise -> data)
        test: forwarded to CNF (per the original notes: divergence not tracked)
        return_full: if True, return the whole concatenated trajectory
    Output:
        (predz, dlogpx): the full trajectories concatenated along dim 0, with
        the state shared by consecutive blocks kept once, when return_full;
        otherwise only their last entries.
        When block_now == 0, (xinput, 0) is returned unchanged, without a
        leading time dimension even if return_full is True.
    Comment:
        Used to refresh data when a checkpoint is loaded and during data
        augmentation. Nothing is trained on this output.
    '''
    if block_now == 0:
        return xinput, 0

    schedule = ls_args_CNF[:block_now]
    if reverse:
        schedule = schedule[::-1]

    with torch.no_grad():
        traj_pieces, logdet_pieces = [], []
        state = xinput
        for k, args_CNF in enumerate(schedule):
            predz, dlogpx, _ = CNF(state, args_CNF,
                                   reverse = reverse, test = test)
            state = predz[-1]
            # Later blocks start at the previous block's last state; drop it.
            start = 0 if k == 0 else 1
            traj_pieces.append(predz[start:])
            logdet_pieces.append(dlogpx[start:])
        full_traj = torch.cat(traj_pieces, dim=0)
        full_logdet = torch.cat(logdet_pieces, dim=0)

    if return_full:
        return full_traj, full_logdet
    return full_traj[-1], full_logdet[-1]

def l2_norm_sqr(input, return_full = False):
    '''
    Half squared Euclidean norm of each sample in a batch.

    A tensor of shape (N, M1, M2, ...) is flattened to (N, M1*M2*...), and
    each row v is mapped to 0.5 * ||v||_2^2.

    Input:
        input: tensor whose first dimension indexes samples (at least 2-D)
        return_full: if True, return the per-sample values (shape (N,));
            otherwise return their mean (a 0-d tensor)
    '''
    # reshape (not view) also accepts non-contiguous inputs, e.g. permuted
    # tensors; for 2-D input it is a no-op.
    norms = 0.5 * input.reshape(input.shape[0], -1).pow(2).sum(dim=1)
    if return_full:
        return norms
    return norms.mean()

def compute_loss(xinput, CNF, args_CNF):
    '''
    Run one CNF block on `xinput` and evaluate its unweighted loss terms.

    Input:
        xinput: batch entering the current block
        CNF: current CNF model, called as CNF(x, args_CNF) ->
            (trajectory, dlogpx trajectory, Jacobian-norm trajectory)
        args_CNF: block settings; Tk_1 and Tk give the block's time interval
    Output:
        (loss, lossW2, lossV, lossDiv, lossJacnorm), where
            lossW2      = l2_norm_sqr(traj[-1] - traj[0]) / (Tk - Tk_1)
            lossV       = l2_norm_sqr(traj[-1])
            lossDiv     = mean of the final dlogpx entry
            lossJacnorm = mean of the final Jacobian-norm entry
            loss        = lossW2 + lossV + lossDiv + lossJacnorm (unweighted)
    Comment:
        Gradients are tracked; callers may reweight the terms.
    '''
    traj, logdet_traj, jac_traj = CNF(xinput, args_CNF)
    interval = args_CNF.Tk - args_CNF.Tk_1
    z_end = traj[-1]

    lossW2 = l2_norm_sqr(z_end - traj[0]) / interval
    lossV = l2_norm_sqr(z_end)
    lossDiv = logdet_traj[-1].mean()
    lossJacnorm = jac_traj[-1].mean()

    total = lossW2 + lossV + lossDiv + lossJacnorm
    return total, lossW2, lossV, lossDiv, lossJacnorm

def loss_at_initialization(loader, CNF, ls_args_CNF, block_now):
    '''
    Average the loss terms of block `block_now` (1-indexed) over `loader`.

    Each batch is encoded with the global `autoEnc`, pushed through blocks
    1..block_now-1 with FlowNet_forward, and scored with compute_loss using
    ls_args_CNF[block_now-1]. No gradients are tracked.

    Output:
        tuple of batch means of (loss, lossW2, lossV, lossDiv, lossJacnorm)
    '''
    CNF.eval() # This affects whether initial W2 large or not
    columns = [[] for _ in range(5)]
    with torch.no_grad():
        for x, _ in loader:
            z = autoEnc_pass(x.to(device), autoEnc, encode = True)
            # Push through previous intervals first
            z, _ = FlowNet_forward(z, CNF, ls_args_CNF, block_now-1)
            terms = compute_loss(z, CNF, ls_args_CNF[block_now-1])
            for column, term in zip(columns, terms):
                column.append(term.item())
    return tuple(np.mean(column) for column in columns)


# ### Training and saving

# In[8]:


def train_CNF(x, CNF, optimizer, args_CNF):
    '''
    Take one optimizer step on a CNF block.

    Input:
        x: input batch to the block
        CNF: CNF model (switched to train mode here)
        optimizer: optimizer over CNF's parameters
        args_CNF: block settings; reads lam_W2, lam_J and clip_norm, and is
            also passed to compute_loss
    Output:
        (loss, lossW2, lossV, lossDiv, lossJacnorm), where loss is the
        backpropagated objective
            lam_W2 * lossW2 + lossV + lossDiv + lam_J * lossJacnorm
    '''
    CNF.train()
    _, lossW2, lossV, lossDiv, lossJacnorm = compute_loss(x, CNF, args_CNF)
    # Replace compute_loss's unweighted sum with the weighted objective.
    objective = (args_CNF.lam_W2 * lossW2 + lossV + lossDiv
                 + args_CNF.lam_J * lossJacnorm)

    optimizer.zero_grad()
    objective.backward()
    # Clip the global gradient norm (effectively unchanged when clip_norm is inf).
    torch.nn.utils.clip_grad_norm_(CNF.parameters(), args_CNF.clip_norm)
    optimizer.step()
    return objective, lossW2, lossV, lossDiv, lossJacnorm

    
def store_loss(loss_dict, losses, block_id):
    '''
    Append one loss record to the history kept for a block.

    Input:
        loss_dict: maps block id -> list of recorded loss entries (mutated)
        losses: the record to append (e.g. a list of loss values)
        block_id: key of the block the record belongs to
    Output:
        loss_dict itself, after the append
    '''
    loss_dict.setdefault(block_id, []).append(losses)
    return loss_dict

def save_or_load(self, save = True, filepath = None,
                 load_checkpoint = False):
    '''
    Save the training state held in `self` to `filepath`, or restore it.

    Input:
        self: container with attributes CNF, optimizer, ls_args_CNF,
            loss_by_block and epoch_now
        save: if True, write a checkpoint; if False, try to load one
        filepath: checkpoint path (required)
        load_checkpoint: when loading, read the file only if this is True
    Behaviour:
        Saving stores the CNF and optimizer state dicts, ls_args_CNF,
        loss_by_block, and epoch_now + 1 (the epoch to resume from).
        Loading does nothing unless `load_checkpoint` is True and `filepath`
        exists. Otherwise it restores those entries into `self` in place.
    Raises:
        ValueError: if `filepath` is None
    '''
    if filepath is None:
        raise ValueError('Please specify a filepath')
    if save:
        dict_save = {
            'params': self.CNF.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'ls_args_CNF': self.ls_args_CNF,
            'loss_by_block': self.loss_by_block,
            'epoch': self.epoch_now + 1,
        }
        torch.save(dict_save, filepath)
        return

    import os
    if not (os.path.exists(filepath) and load_checkpoint):
        return
    # The checkpoint pickles argparse.Namespace objects (ls_args_CNF). On
    # torch >= 2.6, torch.load defaults to weights_only=True and would reject
    # them, so request full unpickling explicitly. Only load trusted files.
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
    # machine; load_state_dict then copies tensors to the parameters' device.
    try:
        dict_save = torch.load(filepath, map_location='cpu', weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the `weights_only` argument.
        dict_save = torch.load(filepath, map_location='cpu')
    self.ls_args_CNF = dict_save['ls_args_CNF']
    self.optimizer.load_state_dict(dict_save['optimizer'])
    self.CNF.load_state_dict(dict_save['params'])
    self.loss_by_block = dict_save['loss_by_block']
    self.epoch_now = dict_save['epoch']
        


# ## Visualization

# In[9]:


def display_mult_images(images, rows, cols, figsize = 0.5, show = True):
    '''
    Show a batch of images as `rows` horizontal strips of `cols` images each.

    Input:
        images: images accepted by torchvision.utils.make_grid
            (presumably a (B, C, H, W) tensor with B >= rows * cols — confirm)
        rows, cols: layout; row i shows images[i*cols:(i+1)*cols]
        figsize: inches per image along each axis (products are truncated to int)
        show: if True, display the figure and then close it
    Output:
        fig: the matplotlib Figure
    '''
    # squeeze=False keeps `ax` a 2-D array even when rows == 1. Otherwise
    # plt.subplots returns a single Axes and ax[i] / ax.ravel() fail.
    fig, ax = plt.subplots(rows, 1, figsize=(int(figsize*cols), int(figsize*rows)),
                           squeeze=False)
    for i, row_ax in enumerate(ax[:, 0]):
        grid_img_gen = torchvision.utils.make_grid(
            images[i*cols:(i+1)*cols], nrow=cols)
        row_ax.imshow(grid_img_gen.permute(1, 2, 0).detach().cpu().numpy())
    for a in ax.ravel():
        a.get_yaxis().set_visible(False)
        a.get_xaxis().set_visible(False)
        a.set_aspect('equal')
    fig.tight_layout(h_pad=0.0, w_pad=0.0)
    if show:
        plt.show()
        plt.close(fig)
    return fig


# In[10]:


# Visualize training losses
def concat_losses(loss_block):
    '''
    Join the per-block loss histories into one sequence, in key order.

    A single block's history is returned as-is (the same object). Several
    histories are joined with np.concatenate.
    '''
    histories = list(loss_block.values())
    if len(histories) == 1:
        return histories[0]
    return np.concatenate(histories)
    

def plot_losses(ls_all, args):
    '''
    Plot the recorded training losses in five panels.

    Input:
        ls_all: sequence of records; column 0 is the total loss and columns
            1-4 are the W2, V, Div and Jacobian-norm terms
        args: object with epochs, num_blocks, tot_batches (used in the title)
            and num_batches_record (x tick labels are multiplied by it)
    Output:
        fig: the matplotlib Figure
    '''
    import matplotlib.ticker as ticker  # hoisted out of the per-axis loop
    titlesize = 20
    msize = 3
    errs = np.array(ls_all)
    fig, ax = plt.subplots(1, 5, figsize=(20, 4))
    panels = [
        (ax[0], 1, r'W2: $W_2^2(f([t_{k-1}, t_k]))/h_k$'),
        (ax[1], 2, r'V: $V(X(t_k))/2$'),
        (ax[2], 3, r'Div: $-\int_{t_{k-1}}^{t_k} \nabla \cdot f(X(s),s)ds$'),
        (ax[3], 4, r'Jac: $\int_{t_{k-1}}^{t_k} ||\nabla_{X(s)} f(X(s),s)||^2_F ds$'),
        # The total combines all four terms above; "Sum of three" was stale.
        (ax[-1], 0, 'Total loss'),
    ]
    for a, col, title in panels:
        a.plot(errs[:, col].flatten(), '-o', markersize=msize, color='blue')
        a.set_title(title, fontsize=titlesize)
    fig.suptitle(
        f'Training metrics over {args.epochs} training epochs\n per block over {args.num_blocks} blocks \n each epoch has {args.tot_batches} batches', y=0.98, fontsize=titlesize)
    for a in ax.flatten():
        # Each record covers args.num_batches_record batches, so rescale the
        # x tick labels into "number of batches trained".
        a.xaxis.set_major_formatter(ticker.FuncFormatter(
            lambda y, pos: f'{y*args.num_batches_record:.0f}'))
        a.set_xlabel('Num batches trained', fontsize=titlesize)
    fig.tight_layout()
    return fig


# In[11]:


# Check inversion error and images
def check_inv_err(self, nsamples = 500, alpha = 0.9):
    '''
    Print the round-trip (invertibility) error of the trained blocks.

    Encodes self.X_test[:nsamples] with the global autoEnc, maps it forward
    through the first self.block_now blocks and back in reverse, and prints
    l2_norm_sqr of the difference (mean of 0.5 * ||X - F^{-1}(F(X))||^2).
    `alpha` is accepted but not used.
    '''
    def _through_blocks(z, backwards):
        out, _ = FlowNet_forward(z, self.CNF, self.ls_args_CNF, self.block_now,
                                 reverse = backwards, test = True,
                                 return_full = False)
        return out

    with torch.no_grad():
        x_enc = autoEnc_pass(self.X_test[:nsamples], autoEnc, encode = True)
        x_rec = _through_blocks(_through_blocks(x_enc, False), True)
        abs_err = l2_norm_sqr(x_rec - x_enc)
        print(f'--Test absolute MSE ||X-Finv(F(X))|| is {abs_err.item():.2e}')


# In[12]:


# Check forward and generation
def forward_and_gen_MNIST(self, sizes, nrow=4, ncol=15):
    '''
    Generate nrow * ncol samples and display them as a grid.

    Draws standard-normal noise of dimension `sizes`, maps it backwards
    through the first self.block_now blocks, decodes it with the global
    autoEnc, and shows the result with display_mult_images.

    Output:
        the figure returned by display_mult_images
    '''
    with torch.no_grad():
        noise = torch.randn(nrow*ncol, sizes).to(device)
        z_data, _ = FlowNet_forward(noise, self.CNF, self.ls_args_CNF, self.block_now,
                                    reverse = True, test = True,
                                    return_full = False)
        images = autoEnc_pass(z_data, autoEnc, encode = False)

        print('#### Check generation ####')
        fig_Xhat = display_mult_images(images, nrow, ncol, figsize = 4, show = self.show)
    return fig_Xhat


# In[13]:


# Plot changes over trajectory
def plot_over_traj(self, figsize = 1.5):
    '''
    Show how 30 test images evolve along the trained trajectory.

    Encodes self.X_test[:30] with the global autoEnc, runs the first
    self.block_now blocks keeping the full trajectory, keeps block_now + 1
    evenly spaced states, decodes them, and displays one row per state.

    Output:
        the figure returned by display_mult_images
    '''
    num_fig = 30
    with torch.no_grad():
        x_enc = autoEnc_pass(self.X_test[:num_fig], autoEnc, encode = True)
        traj, _ = FlowNet_forward(x_enc, self.CNF, self.ls_args_CNF,
                                  self.block_now,
                                  return_full = True)
        keep = torch.linspace(0, traj.shape[0]-1, self.block_now+1).long()
        # Stack the kept states one after another along dim 0.
        snapshots = traj[keep].flatten(0, 1)
        decoded = autoEnc_pass(snapshots, autoEnc, encode = False)
        fig_Zhat = display_mult_images(decoded, self.block_now+1, num_fig,
                                       figsize = figsize, show = self.show)
    return fig_Zhat

# Plot W2 over blocks
def plot_W2_movement(self, num_fig = 500):
    '''
    Print and plot the per-block squared W2 movement on test data.

    Encodes up to `num_fig` test images with the global autoEnc and runs the
    first self.block_now blocks, keeping the full trajectory. It picks
    block_now + 1 evenly spaced states z_0..z_K and, for each consecutive
    pair, computes
        W2(k) = 0.5 * mean_x ||z_k(x) - z_{k-1}(x)||^2.
    The curve is drawn on the current matplotlib axes; nothing is returned.
    '''
    with torch.no_grad():
        Xtest = autoEnc_pass(self.X_test[:num_fig], autoEnc, encode = True)
        Zhat_ls, _ = FlowNet_forward(Xtest, self.CNF, self.ls_args_CNF,
                                     self.block_now,
                                     return_full = True)
        ids = torch.linspace(0, Zhat_ls.shape[0]-1, self.block_now+1).long()
        Zhat_ls = Zhat_ls[ids]
        Diff_Zhat = Zhat_ls[1:] - Zhat_ls[:-1]
        # Average over the samples actually present (dim 1 of the trajectory).
        # self.X_test may hold fewer than num_fig rows, and dividing by
        # num_fig would then under-estimate the mean.
        n_used = Diff_Zhat.shape[1]
        W2_sqr = 0.5*Diff_Zhat.reshape(Diff_Zhat.shape[0], -1).pow(2).sum(dim=1)/n_used
        W2_np = W2_sqr.cpu().numpy()
        print(f'W2 =\n {W2_np}')
        plt.plot(range(1, len(W2_np)+1), W2_np, 'o-')
        plt.title(r'W2(k)=$0.5\mathbb{E}_{x\sim p_{k-1}} ||\int_{t_{k-1}}^{t_k} f(x(s), s;\theta_k)ds||^2$')


# In[51]:


def raw_loader_to_traj(train_loader_raw, autoEnc, self, num_select):
    '''
    Collect image-space trajectories of data through all JKO blocks.

    For each of the first `num_select` dequantized batches (get_sub_batches):
    encode with `autoEnc`, push through all len(self.ls_args_CNF) blocks,
    keep block_now + 1 evenly spaced states, append one fresh standard-normal
    sample of the same shape, and decode every state with `autoEnc`.

    Output:
        tensor of decoded states, stacked along dim 0 (one entry per kept
        state plus the appended noise) and concatenated over batches along dim 1
    '''
    sub_batches = get_sub_batches(train_loader_raw, num_select) # Add noise
    block_now = len(self.ls_args_CNF)
    per_batch = []
    with torch.no_grad():
        for images, _ in sub_batches:
            # Pass through RQ-NSF
            z = autoEnc_pass(images, autoEnc, encode = True)
            # Pass through JKO for trajectory
            traj, _ = FlowNet_forward(z, self.CNF, self.ls_args_CNF,
                                      block_now,
                                      return_full = True)
            keep = torch.linspace(0, traj.shape[0]-1, block_now+1).long()
            states = [traj[j] for j in keep]
            states.append(torch.randn_like(states[-1]).to(device))
            # Inverse pass through RQ-NSF
            decoded = torch.stack([autoEnc_pass(s, autoEnc, encode = False)
                                   for s in states])
            per_batch.append(decoded)
    return torch.cat(per_batch, dim=1)


# ## Miscellaneous

# In[15]:


def check_param_same(model1, model2):
    '''
    Return True iff both models have the same number of parameters and every
    corresponding pair has identical shape and values.

    The previous zip-based loop silently ignored extra parameters in the
    longer model, and it raised on shape mismatches instead of returning False.
    '''
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    if len(params1) != len(params2):
        return False
    # torch.equal checks shape and values (NaN never compares equal).
    return all(torch.equal(p1.detach(), p2.detach())
               for p1, p2 in zip(params1, params2))


# In[16]:


import GPUtil
def mem_report():
    '''
    Print free memory and utilization of each visible GPU (only when
    `device` is CUDA), followed by the percentage of system RAM available.
    '''
    if device.type == 'cuda':
        GPUs = GPUtil.getGPUs()
        for i, gpu in enumerate(GPUs):
            print('GPU {:d} ... Mem Free: {:.0f}MB / {:.0f}MB | Utilization {:3.0f}%'.format(
                i, gpu.memoryFree, gpu.memoryTotal, gpu.memoryUtil*100))
    import psutil
    # One snapshot, so `available` and `total` describe the same moment.
    vm = psutil.virtual_memory()
    available_mem = vm.available * 100 / vm.total
    print(f'Available CPU memory: {available_mem:.2f}%')


# In[17]:


def get_Mk(CNF, args_CNF, train_loader):
    '''
    Largest per-sample L2 norm of the first output of CNF(x, args_CNF),
    taken over every batch of `train_loader`, without gradients.

    NOTE(review): other call sites unpack three values from CNF(x, args_CNF),
    but this one unpacks two. Confirm which model type is meant here.

    Output:
        0 for an empty loader, otherwise the maximum norm (as returned by max)
    '''
    max_f = 0
    with torch.no_grad():
        for x, _ in train_loader:
            f_out, _ = CNF(x.to(device), args_CNF)
            per_sample = f_out.view(f_out.shape[0], -1).norm(dim = 1)
            max_f = max(max_f, per_sample.max())
    return max_f


# # Load raw data & AutoEncoder

# In[18]:


# Load raw data
batch_size = 1000
frac = 1  # NOTE(review): not read by the code visible here; presumably meant for subset_data
selected_digits = None  # NOTE(review): not read by the code visible here
# Keep only the second loader returned by load_raw_mnist (its held-out split);
# the training loader is discarded.
_, test_loader = load_raw_mnist(batch_size=batch_size)
# Load auto encoder: the pretrained RQ-NSF flow that maps images to/from z-space
autoEnc = load_pretrained_flow()
autoEnc.eval()  # no training


# # Initialization of FlowNet hyperparameters

# In[19]:


# Dimension of the flow's z-space: the RQ-NSF flow is built for 1 x 28 x 28
# images, i.e. 784-dimensional vectors.
Xdim_flow = 784 # Input dimension of the CNF (output dimension of the autoencoder)
# Define CNF model: options for the velocity-field network (passed to nets.FCnet
# by default_CNF_structure).
args_net = Namespace(
    Xdim = Xdim_flow,
    hidden_dim_str = '1024-1024-1024', # presumably three hidden layers of width 1024 — confirm in nets_Jacnorm
    activation = 'softplus',
    Mk = 1 # Not used here, was a part of CIFAR10 AE for small initialization
)

# Settings shared by every block's forward pass; copied per block during
# initialization and read by the CNF model, compute_loss and train_CNF.
common_args_CNF = Namespace(
    ##### Important ones #####
    int_mtd = 'RK4', # Presumably the ODE integration scheme
    num_e = 1, # Number of Hutchinson probe vectors (presumably for the divergence estimate)
    fix_e_ls = True, # If True, fix the probe vector e per sample during integration
    use_NeuralODE = True, # If True, use the adjoint method
    div_bf = False, # NOTE(review): meaning defined in nets_Jacnorm; presumably brute-force divergence
    ##### Regularization strength #####
    lam_W2 = 1, # Extra weight on the W2 term, in addition to its 1/h_k scaling
    lam_J = 0.01, # Jacobian-norm regularization strength
    ##### Gradient clipping #####
    clip_norm = float('inf'), # Max gradient norm for clip_grad_norm_; inf means no clipping
    ##### Placeholders #####
    netname = '',
    cond_gen = False,
    rtol = 1, # NOTE(review): listed as placeholders; presumably solver tolerances unused with RK4
    atol = 1
)
# Time-step schedule (the h schedule): block k (0-indexed) has length
# h_k = min(h0 * c**k, hmax).
h0, c, hmax = 0.5, 1.1, 0.75
# Integration steps per block: m_k = ceil(h_k / max_step), clipped to
# [min_mk, max_mk]. (Used, despite the old "placeholder" label.)
max_step = 0.25
min_mk, max_mk = 2, 3



# Training arguments
# Adam optimizer settings (used when self.optimizer is created)
CNF_train_args = Namespace(
    lr = 1e-3,
    weight_decay = 5e-4,
)

# Checkpoint path used by save_or_load
filepath = 'JKO_RQNSF_zspace.pth'
args_training = Namespace(
    num_max_blocks = 5, # Number of blocks for which args_CNF are created
    epochs = 50, 
    num_batches_switch = 10, # How often to switch between training each block
    filepath = filepath, # Path to save the model
    save_checkpoint = True, # If true, save models after training X blocks
    load_checkpoint = True, # If true, load the model from filepath
)
# NOTE(review): change_vfield and the fields epochs, num_batches_switch and
# save_checkpoint are not read by the code visible here.
change_vfield = False
# Redundant: filepath was already passed to the Namespace above.
args_training.filepath = filepath

##### Initialize CNF, args_CNF, and loss stuff #####
# Module-level container passed as `self` to the helper functions above.
self = Namespace() # Container
self.epoch_now = 0 # Assume train from epoch 0
self.CNF = default_CNF_structure(args_net)
print(self.CNF)
# One args_CNF per block: block i covers [Tk_1, Tk] with length
# h_i = min(h0 * c**i, hmax) and num_int_pts integration points.
self.ls_args_CNF = []
for i in range(args_training.num_max_blocks):
    # Independent copy of the shared settings for this block
    args_CNF_now = Namespace(**vars(common_args_CNF))
    hk = min(h0*c**i, hmax)
    args_CNF_now.Tk_1 = 0
    if i > 0:
        # Start time = sum of the lengths of all earlier blocks. The
        # comprehension's own `i` shadows the loop variable only inside the
        # comprehension; range(i) is evaluated with the outer i.
        args_CNF_now.Tk_1 = np.sum([min(h0*c**i, hmax) for i in range(i)])
    args_CNF_now.Tk = args_CNF_now.Tk_1 + hk
    # ceil(h_k / max_step), clipped to [min_mk, max_mk]
    args_CNF_now.num_int_pts = min(max_mk,max(min_mk,np.ceil(hk/max_step)))
    self.ls_args_CNF.append(args_CNF_now)

self.optimizer = torch.optim.Adam(self.CNF.parameters(), 
                            lr=CNF_train_args.lr, 
                            weight_decay=CNF_train_args.weight_decay)
self.loss_by_block = {}

##### Load from checkpoint #####
# If the checkpoint exists and loading is enabled, this overwrites
# ls_args_CNF, the CNF/optimizer states, loss_by_block and epoch_now set above.
save_or_load(self, save = False, filepath = args_training.filepath,
             load_checkpoint = args_training.load_checkpoint)
for i, a in enumerate(self.ls_args_CNF):
    print(f'##### Block {i+1}: h_k = {a.Tk - a.Tk_1}, m_k = {a.num_int_pts}')
print('Done instantiating CNF and CNF args')



def return_self():
    '''
    Expose the notebook state to external callers.

    Output:
        (self, autoEnc, Xtest), where Xtest concatenates the first 10
        dequantized batches of test_loader (see get_sub_batches), is moved
        to `device`, and is mapped affinely by x -> 2x - 1.
    '''
    first_batches = get_sub_batches(test_loader, 10) # Assume 1000 batch size
    Xtest = torch.cat([images for images, _ in first_batches], dim = 0).to(device)
    return self, autoEnc, Xtest * 2 - 1