import os
import numpy as np
import torch
import networkx as nx
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import shortest_path
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
from tqdm import tqdm
import gdist


class MyDataParallel(torch.nn.DataParallel):
    """``nn.DataParallel`` wrapper that falls back to the wrapped module for attributes.

    Anything not found on the wrapper itself (its own parameters, buffers and
    sub-modules, including ``module``) is looked up on ``self.module``, so custom
    attributes of the wrapped model can be read directly from the wrapper.
    """

    def __getattr__(self, name):
        # nn.Module.__getattr__ resolves registered parameters/buffers/sub-modules.
        try:
            return super().__getattr__(name)
        except AttributeError:
            wrapped = self.module
        return getattr(wrapped, name)


def loss_func_mse(x_input, x_target, mu, logvar, kl_lambda, dataset):
    """Squared-error reconstruction loss plus a KL term, in normalized coordinates.

    The squared error is summed over dims 1 and 2 of each sample, divided by
    ``x_input.shape[-1]``, and averaged over the batch. ``dataset`` is accepted
    for signature compatibility with the other loss functions but is not used.

    Returns:
        (loss, loss_l2, loss_kld) with ``loss = loss_l2 + kl_lambda * loss_kld``;
        ``loss_kld`` is the Gaussian KL summed over all entries of ``mu``/``logvar``
        and divided by the batch size ``x_input.shape[0]``.
    """
    squared_err = (x_input - x_target) ** 2
    per_sample = squared_err.sum(dim=[1, 2]) / x_input.shape[-1]
    loss_l2 = per_sample.mean()

    batch = x_input.shape[0]
    loss_kld = -0.5 * (1 + logvar - mu ** 2 - logvar.exp()).sum() / batch

    return loss_l2 + kl_lambda * loss_kld, loss_l2, loss_kld


def loss_func_l2(x_input, x_target, mu, logvar, kl_lambda, dataset):
    """Euclidean reconstruction error in unnormalized coordinates, plus a KL term.

    Both tensors are mapped back with ``dataset.unnormalize_coordinates_torch``.
    The Euclidean norm is taken over dim 1 (presumably the xyz axis — TODO
    confirm), which gives one distance per remaining element.

    Returns:
        (loss, loss_l2, loss_kld, loss_l2_med), where
        - ``loss_l2`` is the mean distance;
        - ``loss_l2_med`` is ``torch.median`` of the distances (the lower middle
          value for an even count);
        - ``loss_kld`` is the Gaussian KL divided by the batch size;
        - ``loss = loss_l2 + kl_lambda * loss_kld``.
    """
    pred = dataset.unnormalize_coordinates_torch(x_input)
    target = dataset.unnormalize_coordinates_torch(x_target)

    point_dist = ((pred - target) ** 2).sum(dim=1) ** 0.5
    loss_l2_med = point_dist.median()
    loss_l2 = point_dist.mean()

    kld_total = -0.5 * (1 + logvar - mu ** 2 - logvar.exp()).sum()
    loss_kld = kld_total / pred.shape[0]

    return loss_l2 + kl_lambda * loss_kld, loss_l2, loss_kld, loss_l2_med


def loss_func_l1(x_input, x_target, mu, logvar, kl_lambda, dataset):
    """Smooth-L1 reconstruction loss in unnormalized coordinates, plus an optional KL term.

    Coordinates are mapped back with ``dataset.unnormalize_coordinates_torch``.
    Both sides are scaled by 2 before ``smooth_l1_loss`` (default beta=1, mean
    reduction). This equals twice a smooth-L1 with beta=0.5 on the unscaled values.

    Returns:
        (loss, loss_l1, loss_kld), where ``loss_kld`` is the Gaussian KL divided by
        the batch size. ``loss`` is ``loss_l1 + kl_lambda * loss_kld`` when
        ``kl_lambda > 0``; otherwise it is a clone of ``loss_l1``.
    """
    pred = dataset.unnormalize_coordinates_torch(x_input)
    target = dataset.unnormalize_coordinates_torch(x_target)

    loss_l1 = torch.nn.functional.smooth_l1_loss(pred * 2, target * 2)

    kld_total = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss_kld = kld_total / pred.shape[0]

    # A clone (not an alias) keeps `loss` distinct from `loss_l1` when KL is off.
    loss = loss_l1 + kl_lambda * loss_kld if kl_lambda > 0 else loss_l1.clone()
    return loss, loss_l1, loss_kld


def loss_func_cfan(x_input, x_target, mu, logvar, kl_lambda, dataset, model_jacob, mu0, mu1, mu2,
                            sigma0=None, sigma1=None, k_jacob=0.0, edge_old=None, **loss_kwargs):
    """Reconstruction loss (``loss_func_l1``) plus an optional counterfactual edge loss.

    When ``k_jacob > 0``, ``loss_edge(dataset, model_jacob, mu0, mu1, mu2,
    edge_old=edge_old)`` is added with weight ``k_jacob``. The Jacobian-norm
    penalty is currently disabled, so its slot is always None.
    ``sigma0``, ``sigma1`` and any extra ``loss_kwargs`` are accepted but unused.

    Returns:
        (loss, loss_l1, loss_kld, loss_jacob, loss_cf). ``loss_jacob`` is always
        None; ``loss_cf`` is None when ``k_jacob <= 0``.
    """
    loss, loss_l1, loss_kld = loss_func_l1(x_input, x_target, mu, logvar, kl_lambda, dataset)

    loss_jacob = None
    loss_cf = None
    if k_jacob > 0:
        loss_cf = loss_edge(dataset, model_jacob, mu0, mu1, mu2, edge_old=edge_old)
        loss = loss + k_jacob * loss_cf

    return loss, loss_l1, loss_kld, loss_jacob, loss_cf


def loss_func_jacobian_norm_fo(mesh_data, model, z_c, z_n, sigma_c=None, sigma_n=None, perturb=False, edge_old=None):
    """Counterfactual disentanglement penalty between two latent groups.

    Each sample is paired with its mirror in the batch (``torch.flip`` along dim 0).
    - First half of the batch: ``z_c`` is moved to the partner's value and ``z_n``
      is kept.
    - Second half: ``z_n`` is moved and ``z_c`` is kept.

    The decoded shapes are re-encoded, and the penalty measures how far the
    *unchanged* code moved (L1 summed over latent dims, averaged over samples):
    - first half: ``model.inference_net[1]`` applied to the decoded edge normals
      should reproduce ``z_n``;
    - second half: ``model.inference_net[0]`` applied to the decoded edge lengths
      should reproduce ``z_c``.
    The second element of each encoder's 4-tuple output is compared; in ``train``
    that slot is ``mu``.

    Args:
        mesh_data: provides ``unnormalize_coordinates_torch``,
            ``compute_edge_signal`` and, when ``edge_old`` is given,
            ``unnormalize_edgelens_torch``.
        model: provides ``generative_net`` and ``inference_net``.
        z_c, z_n: latent blocks with the batch on dim 0.
        sigma_c, sigma_n: optional scale factors for the respective latent shifts.
            Each is applied independently when not None.
        perturb: add Gaussian noise (std 0.1 times the per-dim batch std) to the
            shifts.
        edge_old: optional edge lengths of the input batch. When given, the
            first half's decoded edge lengths are compared against the partner's,
            and the second half's against its own.

    Returns:
        ``(loss_jacob, loss_cf)``: ``loss_jacob = loss_c + loss_n``; ``loss_cf`` is
        the mean absolute edge-length error, or None when ``edge_old`` is None.

    Raises:
        ValueError: if any component loss is NaN.
    """
    batch_sz = z_c.shape[0]
    half = batch_sz // 2

    # Interpolation weights towards the mirrored partner; all ones means a full swap.
    alpha_c = torch.ones(size=[batch_sz, 1], device=z_c.device)
    alpha_n = torch.ones(size=[z_n.shape[0], 1], device=z_n.device)

    eps_c = alpha_c * (torch.flip(z_c, dims=[0]) - z_c)
    eps_n = alpha_n * (torch.flip(z_n, dims=[0]) - z_n)

    # This perturbation is meant to prevent overfitting on training data for disentanglement penalty.
    if perturb:
        perturb_c = torch.normal(mean=0, std=0.1, size=z_c.shape).to(z_c.device)
        perturb_n = torch.normal(mean=0.0, std=0.1, size=z_n.shape).to(z_n.device)
        with torch.no_grad():
            perturb_c *= torch.std(z_c, dim=0)
            perturb_n *= torch.std(z_n, dim=0)
        eps_c = eps_c + perturb_c
        eps_n = eps_n + perturb_n

    if sigma_c is not None:
        eps_c = eps_c * sigma_c
    if sigma_n is not None:
        eps_n = eps_n * sigma_n

    # First half: shifted z_c with original z_n; second half: original z_c with shifted z_n.
    z_c_new = torch.cat([z_c + eps_c, z_n], dim=-1)[:half]
    z_n_new = torch.cat([z_c, z_n + eps_n], dim=-1)[half:]
    z = torch.cat([z_c_new, z_n_new], dim=0)

    pts_g = model.generative_net(z, bn_agree=False)
    pts_g = mesh_data.unnormalize_coordinates_torch(pts_g)
    edge_lens, edge_norms = mesh_data.compute_edge_signal(pts_g)
    cfan_c_normals = edge_norms[:half]
    cfan_n_edges = edge_lens[half:]

    # Previously unbound when edge_old was None, which crashed the return below.
    loss_cf = None
    if edge_old is not None:
        # Target: partner's edges for the first half (z_c swapped), own edges for the second.
        alpha_cf = alpha_c.clone()
        alpha_cf[half:] = 0.0
        cf_tar = edge_old + alpha_cf * (torch.flip(edge_old, dims=[0]) - edge_old)
        cf_tar_trans = mesh_data.unnormalize_edgelens_torch(cf_tar)
        edge_lens_trans = mesh_data.unnormalize_edgelens_torch(edge_lens)
        loss_cf = torch.mean(torch.abs(cf_tar_trans - edge_lens_trans))
        if torch.isnan(loss_cf):
            raise ValueError('The edge lengths of the decoding contains NANs.')

    _, cfan_c_n, _, _ = model.inference_net[1](cfan_c_normals, bn_agree=False)
    loss_n = torch.mean(torch.sum(torch.abs(z_n[:half] - cfan_c_n), dim=1))
    if torch.isnan(loss_n):
        raise ValueError('Error in normal disentanglement')

    _, cfan_n_c, _, _ = model.inference_net[0](cfan_n_edges, bn_agree=False)
    loss_c = torch.mean(torch.sum(torch.abs(z_c[half:] - cfan_n_c), dim=1))
    if torch.isnan(loss_c):
        print('Edge lengths', edge_lens)
        print('Old edge lengths', edge_old)
        print('Edge encoding', cfan_n_c)
        print('Old edge encoding', z_c[half:])
        raise ValueError('Error in edge length disentanglement')

    loss_jacob = loss_c + loss_n
    return loss_jacob, loss_cf


def loss_edge(mesh_data, model, z_cf, z_c, z_n, sigma_c=None, sigma_n=None, perturb=False, edge_old=None):
    """Counterfactual C-alpha consistency loss.

    Each ``z_c`` code is replaced by its mirrored batch partner's (``torch.flip``
    along dim 0) and concatenated with the unchanged ``z_n`` (with ``z_cf`` in
    front when given). The result is decoded, and the C-alpha signal of the
    decoded coordinates is compared with the partner's input signal
    ``flip(edge_old)``.

    Args:
        mesh_data: provides ``unnormalize_coordinates_torch``,
            ``compute_ca_signal``, ``unnormalize_calpha_torch`` and ``ca_bonds``.
        model: provides ``generative_net(z, bn_agree=False)``.
        z_cf: optional leading latent block (None to omit).
        z_c: latent block swapped across the batch (batch on dim 0).
        z_n: latent block kept fixed.
        sigma_c, sigma_n, perturb: unused; kept for interface compatibility.
        edge_old: C-alpha signal of the input batch (batch on dim 0). Required.

    Returns:
        Scalar tensor: the summed absolute difference between the unnormalized
        target and decoded C-alpha signals, divided by
        ``np.count_nonzero(mesh_data.ca_bonds[0]) * batch``.

    Raises:
        ValueError: if ``edge_old`` is None.
    """
    if edge_old is None:
        raise ValueError('loss_edge requires edge_old (the C-alpha signal of the input batch).')

    batch = z_c.shape[0]
    # Interpolation weight towards the partner sample; all ones means a full swap.
    alpha_c = torch.ones(size=[batch, 1], device=z_c.device)
    z_c_new = z_c + alpha_c * (torch.flip(z_c, dims=[0]) - z_c)

    parts = [z_c_new, z_n] if z_cf is None else [z_cf, z_c_new, z_n]
    z_new = torch.cat(parts, dim=-1)

    pts_g = model.generative_net(z_new, bn_agree=False)
    pts_g = mesh_data.unnormalize_coordinates_torch(pts_g)
    edge_lens = mesh_data.compute_ca_signal(pts_g)

    # Reshape the per-sample weight to (batch, 1, ..., 1) so it broadcasts
    # sample-wise over edge_old of any rank. A fixed double unsqueeze only fits a
    # rank-4 edge_old; for other ranks it adds an extra batch axis and multiplies
    # the loss by the batch size.
    alpha_b = alpha_c.reshape([batch] + [1] * (edge_old.dim() - 1))
    cf_tar = edge_old + alpha_b * (torch.flip(edge_old, dims=[0]) - edge_old)

    cf_tar_trans = mesh_data.unnormalize_calpha_torch(cf_tar)
    edge_lens_trans = mesh_data.unnormalize_calpha_torch(edge_lens)
    n_bonds = np.count_nonzero(mesh_data.ca_bonds[0])
    loss_cf = torch.sum(torch.abs(cf_tar_trans - edge_lens_trans)) / (n_bonds * batch)
    return loss_cf


def train(train_data, model, optimizer, criterion, kl_lambda, device, dataset, loss_kwargs=None, ignore_int_fine=False):
    """Run one training epoch over ``train_data`` and return averaged losses.

    Processing per batch:
    - Unpack ``(bonds, dihedrals, ca_bonds, labels, pts_target)``.
    - Feed the signals ``[bonds, ca_bonds, dihedrals]`` to
      ``model.inference_net[i]``, which returns ``(z, mu, logvar, sigma)``.
      With ``ignore_int_fine=True``, encoder 0 (bonds) is skipped.
    - Decode the concatenated ``z`` with ``model.generative_net``.
    - Score with ``criterion`` and take one optimizer step.

    Args:
        train_data: iterable of batches with a ``.dataset`` attribute (DataLoader-like).
        model: exposes ``inference_net`` (indexable) and ``generative_net``.
        optimizer: stepped once per batch.
        criterion: called as ``criterion(pts_generated, pts_target, mu, logvar,
            kl_lambda, dataset, edge_old=..., **kwargs)``. It must return
            ``(loss, loss_mse, loss_kld, loss_jacob, loss_cf)``; the last two
            may be None. ``edge_old`` is ``ca_bonds`` for ``loss_func_cfan`` and
            None otherwise.
        kl_lambda: KL weight; the encoders receive ``var_flag = kl_lambda > 0``.
        device: device the batch tensors are moved to.
        dataset: forwarded to ``criterion``.
        loss_kwargs: extra keyword arguments for ``criterion`` (None means none).
            The caller's dict is not modified. For ``loss_func_cfan``, the
            per-batch ``mu0``..``mu2`` / ``sigma0``..``sigma2`` entries are added
            to a per-batch copy.
        ignore_int_fine: skip the bonds encoder.

    Returns:
        dict with keys 'loss', 'loss_mse', 'loss_kld', 'loss_jacobian' and
        'loss_cf'. Each value is the sum of (batch loss * batch size) divided by
        ``len(train_data.dataset)``.
    """
    loss_kwargs = {} if loss_kwargs is None else loss_kwargs
    start_idx = 1 if ignore_int_fine else 0
    var_flag = kl_lambda > 0
    n_enc = len(model.inference_net)

    loss_sum = 0.0
    loss_mse_sum = 0.0
    loss_kld_sum = 0.0
    loss_jacob_sum = 0.0
    loss_cf_sum = 0.0

    model.train()
    for bonds, dihedrals, ca_bonds, labels, pts_target in train_data:
        # The bonds encoder is skipped under ignore_int_fine, so bonds need not move.
        if not ignore_int_fine:
            bonds = bonds.to(device)
        dihedrals = dihedrals.to(device)
        ca_bonds = ca_bonds.to(device)
        pts_target = pts_target.to(device)
        signal = [bonds, ca_bonds, dihedrals]

        z_list = [None] * n_enc
        mu_list = [None] * n_enc
        logvar_list = [None] * n_enc
        sigma_list = [None] * n_enc
        for i in range(start_idx, n_enc):
            z_list[i], mu_list[i], logvar_list[i], sigma_list[i] = model.inference_net[i](signal[i], var_flag)
        z = torch.cat(z_list[start_idx:], dim=-1)
        mu = torch.cat(mu_list[start_idx:], dim=-1)
        logvar = torch.cat(logvar_list[start_idx:], dim=-1)

        # Per-batch copy: avoids mutating the caller's dict and keeping graph tensors alive.
        batch_kwargs = dict(loss_kwargs)
        edge_old = None
        if criterion is loss_func_cfan:
            # mu_list[0]/sigma_list[0] stay None when the bonds encoder is skipped.
            batch_kwargs.update(mu0=mu_list[0], sigma0=sigma_list[0],
                                mu1=mu_list[1], mu2=mu_list[2],
                                sigma1=sigma_list[1], sigma2=sigma_list[2])
            edge_old = ca_bonds

        pts_generated = model.generative_net(z)

        optimizer.zero_grad()
        loss, loss_mse, loss_kld, loss_jacob, loss_cf = criterion(
            pts_generated, pts_target, mu, logvar, kl_lambda, dataset, edge_old=edge_old, **batch_kwargs)
        loss.backward()
        optimizer.step()

        batch_size = ca_bonds.size(0)
        loss_sum += loss.item() * batch_size
        loss_mse_sum += loss_mse.item() * batch_size
        loss_kld_sum += loss_kld.item() * batch_size
        if loss_jacob is not None:
            loss_jacob_sum += loss_jacob.item() * batch_size
        if loss_cf is not None:
            loss_cf_sum += loss_cf.item() * batch_size

    n_total = len(train_data.dataset)
    train_results = {'loss': loss_sum / n_total, 'loss_mse': loss_mse_sum / n_total,
                     'loss_kld': loss_kld_sum / n_total,
                     'loss_jacobian': loss_jacob_sum / n_total, 'loss_cf': loss_cf_sum / n_total}
    return train_results


def test(data, model, criterion, kl_lambda, device, dataset, disp=False, loss_kwargs=None, ignore_int_fine=False):
    """Evaluate ``model`` on ``data`` without gradients and return averaged losses.

    Batches are processed like in ``train``, except that:
    - the encoders are called with ``var_flag=False``;
    - no optimizer step is taken;
    - ``loss_func_l2`` is also evaluated to report the mean ('loss_val') and
      median ('loss_val_med') Euclidean reconstruction error.

    Args:
        data: iterable of ``(bonds, dihedrals, ca_bonds, labels, pts_target)``
            batches with a ``.dataset`` attribute.
        model: exposes ``inference_net`` (indexable) and ``generative_net``.
        criterion: same contract as in ``train``. It returns 5 values, and
            ``edge_old`` is ``ca_bonds`` for ``loss_func_cfan`` and None otherwise.
        kl_lambda: forwarded to the loss functions.
        device: device the batch tensors are moved to.
        dataset: forwarded to the loss functions.
        disp: wrap the iteration in a tqdm progress bar.
        loss_kwargs: extra keyword arguments for ``criterion`` (None means none).
            The caller's dict is not modified.
        ignore_int_fine: skip the bonds encoder (index 0).

    Returns:
        dict with keys 'loss', 'loss_mse', 'loss_kld', 'loss_val',
        'loss_val_med', 'loss_jacobian' and 'loss_cf'. Each value is the sum of
        (batch value * batch size) divided by ``len(data.dataset)``.
    """
    loss_kwargs = {} if loss_kwargs is None else loss_kwargs
    start_idx = 1 if ignore_int_fine else 0
    n_enc = len(model.inference_net)

    loss_sum = 0.0
    loss_mse_sum = 0.0
    loss_kld_sum = 0.0
    loss_jacob_sum = 0.0
    loss_cf_sum = 0.0
    loss_val_sum = 0.0
    loss_val_med_sum = 0.0

    model.eval()
    iter_data = tqdm(data, desc='Test Epoch') if disp else data
    with torch.no_grad():
        for bonds, dihedrals, ca_bonds, labels, pts_target in iter_data:
            if not ignore_int_fine:
                bonds = bonds.to(device)
            dihedrals = dihedrals.to(device)
            ca_bonds = ca_bonds.to(device)
            pts_target = pts_target.to(device)
            signal = [bonds, ca_bonds, dihedrals]

            z_list = [None] * n_enc
            mu_list = [None] * n_enc
            logvar_list = [None] * n_enc
            sigma_list = [None] * n_enc
            for i in range(start_idx, n_enc):
                z_list[i], mu_list[i], logvar_list[i], sigma_list[i] = model.inference_net[i](signal[i], var_flag=False)
            z = torch.cat(z_list[start_idx:], dim=-1)
            mu = torch.cat(mu_list[start_idx:], dim=-1)
            logvar = torch.cat(logvar_list[start_idx:], dim=-1)

            # Per-batch copy so the caller's dict is never mutated.
            batch_kwargs = dict(loss_kwargs)
            edge_old = None  # previously unbound for criteria other than loss_func_cfan
            if criterion is loss_func_cfan:
                # mu_list[0]/sigma_list[0] stay None when the bonds encoder is skipped.
                batch_kwargs.update(mu0=mu_list[0], sigma0=sigma_list[0],
                                    mu1=mu_list[1], mu2=mu_list[2],
                                    sigma1=sigma_list[1], sigma2=sigma_list[2])
                edge_old = ca_bonds

            pts_generated = model.generative_net(z)

            loss, loss_mse, loss_kld, loss_jacob, loss_cf = criterion(
                pts_generated, pts_target, mu, logvar, kl_lambda, dataset, edge_old=edge_old, **batch_kwargs)
            _, loss_val, _, loss_val_med = loss_func_l2(pts_generated, pts_target, mu, logvar, kl_lambda, dataset)

            batch_size = ca_bonds.size(0)
            loss_sum += loss.item() * batch_size
            loss_mse_sum += loss_mse.item() * batch_size
            loss_kld_sum += loss_kld.item() * batch_size
            loss_val_sum += loss_val.item() * batch_size
            loss_val_med_sum += loss_val_med.item() * batch_size
            if loss_jacob is not None:
                loss_jacob_sum += loss_jacob.item() * batch_size
            if loss_cf is not None:
                loss_cf_sum += loss_cf.item() * batch_size

    n_total = len(data.dataset)
    test_results = {'loss': loss_sum / n_total, 'loss_mse': loss_mse_sum / n_total,
                    'loss_kld': loss_kld_sum / n_total, 'loss_val': loss_val_sum / n_total,
                    'loss_val_med': loss_val_med_sum / n_total,
                    'loss_jacobian': loss_jacob_sum / n_total, 'loss_cf': loss_cf_sum / n_total}
    return test_results


def set_indices_tensors(mod, update_dict, device):
    """Overwrite the index/size buffers of ``mod`` with copies taken from ``update_dict``.

    Every buffer whose dotted name (as yielded by ``named_buffers``) ends with
    one of the suffixes below is replaced through ``.data`` with
    ``update_dict[name].clone().to(device)``. Other buffers are left untouched.
    A matching buffer that is missing from ``update_dict`` raises ``KeyError``.
    """
    index_suffixes = (
        'gather_idx', 'scatter_idx',
        'sample_idx0', 'sample_idx1', 'g_idx_pad', 's_idx_pad',
        'n_pts', 'n_pts_out', 'n_pts_total',
    )
    for buf_name, buf in mod.named_buffers():
        if buf_name.endswith(index_suffixes):
            buf.data = update_dict[buf_name].clone().to(device)


def train_dual(train_data0, train_data1, model, optimizer, criterion, kl_lambda, device, dataset0, dataset1, idx_dicts,
               loss_kwargs=None, ignore_int_fine=False):
    """Train for one epoch alternating between two datasets with different index buffers.

    Batches from the two loaders are zipped. For each pair, the model is trained
    on the dataset-0 batch and then the dataset-1 batch. Before each one,
    ``set_indices_tensors`` installs ``idx_dicts[0]`` or ``idx_dicts[1]``, and
    each batch gets its own optimizer step.

    Args:
        train_data0, train_data1: iterables of
            ``(bonds, dihedrals, ca_bonds, labels, pts_target)`` batches.
        model, optimizer, criterion, kl_lambda, device: as in ``train``.
        dataset0, dataset1: forwarded to ``criterion`` for the respective batches.
        idx_dicts: pair of buffer-update dicts for ``set_indices_tensors``.
        loss_kwargs: extra keyword arguments for ``criterion`` (None means none).
            The caller's dict is not modified.
        ignore_int_fine: skip the bonds encoder (index 0).

    Returns:
        dict with keys 'loss', 'loss_mse', 'loss_kld', 'loss_jacobian' and
        'loss_cf'. Each value is the sum of (batch loss * batch size) divided by
        the number of samples actually processed. ``zip`` stops at the shorter
        loader, so dividing by the full dataset lengths would bias the averages
        low.
    """
    loss_kwargs = {} if loss_kwargs is None else loss_kwargs
    start_idx = 1 if ignore_int_fine else 0
    var_flag = kl_lambda > 0
    n_enc = len(model.inference_net)

    def _train_batch(batch, idx_dict, dataset):
        """Install ``idx_dict``, run one forward/backward/step; return (losses, batch size)."""
        bonds, dihedrals, ca_bonds, _labels, pts_target = batch
        set_indices_tensors(model, idx_dict, device)

        bonds = bonds.to(device)
        dihedrals = dihedrals.to(device)
        ca_bonds = ca_bonds.to(device)
        pts_target = pts_target.to(device)
        signal = [bonds, ca_bonds, dihedrals]

        z_list = [None] * n_enc
        mu_list = [None] * n_enc
        logvar_list = [None] * n_enc
        sigma_list = [None] * n_enc
        for i in range(start_idx, n_enc):
            z_list[i], mu_list[i], logvar_list[i], sigma_list[i] = model.inference_net[i](signal[i], var_flag)
        z = torch.cat(z_list[start_idx:], dim=-1)
        mu = torch.cat(mu_list[start_idx:], dim=-1)
        logvar = torch.cat(logvar_list[start_idx:], dim=-1)

        batch_kwargs = dict(loss_kwargs)
        edge_old = None  # previously unbound for criteria other than loss_func_cfan
        if criterion is loss_func_cfan:
            # mu_list[0]/sigma_list[0] stay None when the bonds encoder is skipped.
            batch_kwargs.update(mu0=mu_list[0], sigma0=sigma_list[0],
                                mu1=mu_list[1], mu2=mu_list[2],
                                sigma1=sigma_list[1], sigma2=sigma_list[2])
            edge_old = ca_bonds

        pts_generated = model.generative_net(z)

        optimizer.zero_grad()
        losses = criterion(pts_generated, pts_target, mu, logvar, kl_lambda,
                           dataset, edge_old=edge_old, **batch_kwargs)
        # Step before the other dataset's indices are installed.
        losses[0].backward()
        optimizer.step()
        return losses, bonds.size(0)

    totals = {'loss': 0.0, 'loss_mse': 0.0, 'loss_kld': 0.0, 'loss_jacobian': 0.0, 'loss_cf': 0.0}
    n_seen = 0

    model.train()
    for batch0, batch1 in zip(train_data0, train_data1):
        for batch, idx_dict, dataset in ((batch0, idx_dicts[0], dataset0), (batch1, idx_dicts[1], dataset1)):
            (loss, loss_mse, loss_kld, loss_jacob, loss_cf), batch_size = _train_batch(batch, idx_dict, dataset)
            totals['loss'] += loss.item() * batch_size
            totals['loss_mse'] += loss_mse.item() * batch_size
            totals['loss_kld'] += loss_kld.item() * batch_size
            if loss_jacob is not None:
                totals['loss_jacobian'] += loss_jacob.item() * batch_size
            if loss_cf is not None:
                totals['loss_cf'] += loss_cf.item() * batch_size
            n_seen += batch_size

    denom = max(n_seen, 1)
    return {key: value / denom for key, value in totals.items()}


def test_dual(data0, data1, model, criterion, kl_lambda, device, dataset0, dataset1, idx_dicts, disp=False,
              loss_kwargs=None, ignore_int_fine=False):
    """Evaluate on two datasets with different index buffers, without gradients.

    Batches from the two loaders are zipped. For each pair, the dataset-0 batch
    is evaluated and then the dataset-1 batch, after installing ``idx_dicts[0]``
    or ``idx_dicts[1]`` with ``set_indices_tensors``. The encoders are called
    with ``var_flag=False``. ``loss_func_l2`` is also evaluated to report the
    mean ('loss_val') and median ('loss_val_med') Euclidean error.

    Args:
        data0, data1: iterables of ``(bonds, dihedrals, ca_bonds, labels, pts_target)`` batches.
        model, criterion, kl_lambda, device: as in ``test``.
        dataset0, dataset1: forwarded to the loss functions for the respective batches.
        idx_dicts: pair of buffer-update dicts for ``set_indices_tensors``.
        disp: unused; kept for interface compatibility with ``test``.
        loss_kwargs: extra keyword arguments for ``criterion`` (None means none).
            The caller's dict is not modified.
        ignore_int_fine: skip the bonds encoder (index 0).

    Returns:
        dict with keys 'loss', 'loss_mse', 'loss_kld', 'loss_val',
        'loss_val_med', 'loss_jacobian' and 'loss_cf'. Each value is the sum of
        (batch value * batch size) divided by the number of samples actually
        processed, since ``zip`` stops at the shorter loader.
    """
    loss_kwargs = {} if loss_kwargs is None else loss_kwargs
    start_idx = 1 if ignore_int_fine else 0
    n_enc = len(model.inference_net)

    def _eval_batch(batch, idx_dict, dataset):
        """Install ``idx_dict`` and evaluate one batch; return (losses, batch size)."""
        bonds, dihedrals, ca_bonds, _labels, pts_target = batch
        set_indices_tensors(model, idx_dict, device)

        bonds = bonds.to(device)
        dihedrals = dihedrals.to(device)
        ca_bonds = ca_bonds.to(device)
        pts_target = pts_target.to(device)
        signal = [bonds, ca_bonds, dihedrals]

        z_list = [None] * n_enc
        mu_list = [None] * n_enc
        logvar_list = [None] * n_enc
        sigma_list = [None] * n_enc
        for i in range(start_idx, n_enc):
            z_list[i], mu_list[i], logvar_list[i], sigma_list[i] = model.inference_net[i](signal[i], var_flag=False)
        z = torch.cat(z_list[start_idx:], dim=-1)
        mu = torch.cat(mu_list[start_idx:], dim=-1)
        logvar = torch.cat(logvar_list[start_idx:], dim=-1)

        batch_kwargs = dict(loss_kwargs)
        edge_old = None  # previously unbound for criteria other than loss_func_cfan
        if criterion is loss_func_cfan:
            # mu_list[0]/sigma_list[0] stay None when the bonds encoder is skipped.
            batch_kwargs.update(mu0=mu_list[0], sigma0=sigma_list[0],
                                mu1=mu_list[1], mu2=mu_list[2],
                                sigma1=sigma_list[1], sigma2=sigma_list[2])
            edge_old = ca_bonds

        pts_generated = model.generative_net(z)

        loss, loss_mse, loss_kld, loss_jacob, loss_cf = criterion(
            pts_generated, pts_target, mu, logvar, kl_lambda, dataset, edge_old=edge_old, **batch_kwargs)
        _, loss_val, _, loss_val_med = loss_func_l2(pts_generated, pts_target, mu, logvar, kl_lambda, dataset)
        return (loss, loss_mse, loss_kld, loss_val, loss_val_med, loss_jacob, loss_cf), bonds.size(0)

    totals = {'loss': 0.0, 'loss_mse': 0.0, 'loss_kld': 0.0, 'loss_val': 0.0,
              'loss_val_med': 0.0, 'loss_jacobian': 0.0, 'loss_cf': 0.0}
    n_seen = 0

    model.eval()
    with torch.no_grad():
        for batch0, batch1 in zip(data0, data1):
            for batch, idx_dict, dataset in ((batch0, idx_dicts[0], dataset0), (batch1, idx_dicts[1], dataset1)):
                losses, batch_size = _eval_batch(batch, idx_dict, dataset)
                loss, loss_mse, loss_kld, loss_val, loss_val_med, loss_jacob, loss_cf = losses
                totals['loss'] += loss.item() * batch_size
                totals['loss_mse'] += loss_mse.item() * batch_size
                totals['loss_kld'] += loss_kld.item() * batch_size
                totals['loss_val'] += loss_val.item() * batch_size
                totals['loss_val_med'] += loss_val_med.item() * batch_size
                if loss_jacob is not None:
                    totals['loss_jacobian'] += loss_jacob.item() * batch_size
                if loss_cf is not None:
                    totals['loss_cf'] += loss_cf.item() * batch_size
                n_seen += batch_size

    denom = max(n_seen, 1)
    return {key: value / denom for key, value in totals.items()}


def save_checkpoint(ckpt_dir, epoch, name='checkpoint', **kwargs):
    """Save ``{'epoch': epoch, **kwargs}`` with ``torch.save`` to ``<ckpt_dir>/<name>-<epoch>.pt``.

    ``epoch`` is formatted with ``%d``. Every extra keyword argument becomes a
    key of the saved dict.
    """
    target_path = os.path.join(ckpt_dir, '%s-%d.pt' % (name, epoch))
    torch.save({'epoch': epoch, **kwargs}, target_path)


def compute_distance_mat(csgraph, savedir, dataset, max_distance=None):
    """Return the all-pairs shortest-path distance matrix of ``csgraph``, cached on disk.

    The cache file is ``<savedir>/<dataset>/ref_dist.npy``. If it exists it is
    loaded. Otherwise the matrix is computed with
    ``scipy.sparse.csgraph.shortest_path`` on ``csgraph.to_dense().numpy()``
    (presumably a torch sparse adjacency — TODO confirm). For this dense input,
    zero entries are treated as missing edges. The diagonal is set to 0 and the
    result is saved; the cache directory is created if needed.

    Args:
        csgraph: weighted adjacency supporting ``.to_dense().numpy()``.
        savedir, dataset: components of the cache path.
        max_distance: unused; kept for interface compatibility.

    Returns:
        Dense numpy array of pairwise shortest-path distances.
    """
    filename = '%s/%s/ref_dist.npy' % (savedir, dataset)
    if os.path.exists(filename):
        print('Loading geodesic distance matrix.')
        return np.load(filename)

    print('Computing geodesic distance matrix.')
    dist = shortest_path(csgraph.to_dense().numpy())
    np.fill_diagonal(dist, 0.0)
    # np.save does not create missing parent directories.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    np.save(filename, dist)
    return dist


def farthest_point_sampling(pts, dist, num_pts, savedir, dataset, seed_idx=0):
    """Greedy farthest-point sampling of ``num_pts`` point indices, cached on disk.

    Starting from ``seed_idx``, each step adds the point whose distance (per
    ``dist``) to its nearest already-selected point is largest. Ties go to the
    lowest index (``np.argmax`` returns the first maximum). The result is cached
    at ``<savedir>/<dataset>/ref_idx_fps.npy``; if that file exists it is
    returned as-is, regardless of ``num_pts``.

    Args:
        pts: only ``pts.shape[0]`` (the number of candidate points) is used.
        dist: square distance matrix indexable as ``dist[i, j]``.
        num_pts: number of indices to select; must not exceed ``pts.shape[0]``.
        savedir, dataset: components of the cache path.
        seed_idx: index of the first selected point.

    Returns:
        1-D integer numpy array of the selected indices, in selection order.

    Raises:
        ValueError: if ``num_pts`` exceeds the number of candidate points.
    """
    filename = '%s/%s/ref_idx_fps.npy' % (savedir, dataset)
    if os.path.exists(filename):
        print('Loading farthest point sampling point indices.')
        return np.load(filename)

    n_total = pts.shape[0]
    if num_pts > n_total:
        raise ValueError('Cannot sample %d points from %d candidates.' % (num_pts, n_total))

    # Running distance from every point to its nearest selected point. Selected
    # points are pinned to -inf so argmax never picks them again. Each step is
    # O(N), instead of re-scanning every (candidate, selected) pair.
    nearest = np.array(dist[:n_total, seed_idx], dtype=float)
    nearest[seed_idx] = -np.inf
    idx_sampled = [seed_idx]
    for _ in tqdm(range(1, num_pts), desc='Computing farthest point sampling'):
        idx = int(np.argmax(nearest))
        idx_sampled.append(idx)
        nearest = np.minimum(nearest, dist[:n_total, idx])
        nearest[idx] = -np.inf

    idx_sampled = np.asarray(idx_sampled)
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    np.save(filename, idx_sampled)
    return idx_sampled


def create_downsamples(weighted_adj, savedir, dataset, num_levels, blocks=1, set_subpoints=True):
    """Build the nested point subsets for each resolution level.

    Level 0 holds every point (``np.arange(weighted_adj.shape[0])``).
    Farthest-point sampling (using the cached geodesic distance matrix) picks
    ``num_pts_1 = 2 ** (floor(log2(n_ref)) - 1)`` points, where:
    - with ``set_subpoints``: ``n_ref = 3690``, giving 1024 points, and
      ``ds_rate = 2.2``;
    - otherwise: ``n_ref`` is the graph size and ``ds_rate = 3``.
    Level ``i >= 1`` keeps the first ``int(len(fps) / ds_rate ** (i - 1))`` FPS
    points, sorted ascending.

    The returned list has ``num_levels * blocks`` slots. Only the first
    ``num_levels`` are filled; the rest stay None.
    """
    dist = compute_distance_mat(weighted_adj, savedir, dataset)
    samples = [None] * num_levels * blocks
    samples[0] = np.arange(weighted_adj.shape[0])

    if set_subpoints:
        # Fixed reference size so every mesh gets the same number of sampled points.
        n_ref, ds_rate = 3690, 2.2
    else:
        # The models we are transferring to contain an incredibly large number of points
        n_ref, ds_rate = weighted_adj.shape[0], 3
    num_pts_1 = int(2 ** (np.floor(np.log2(n_ref)) - 1))

    fps_order = farthest_point_sampling(weighted_adj, dist, num_pts_1, savedir, dataset)
    for level in range(1, num_levels):
        keep = int(len(fps_order) / ds_rate ** (level - 1))
        samples[level] = np.sort(fps_order[:keep])
    return samples


def get_sample_idx(sample_pts, level, enc_or_dec, blocks=1):
    """Return ``[gather, scatter]`` integer index arrays for one network level.

    For the encoder, points are gathered at ``level`` and scattered to
    ``level + 1``; the decoder goes the opposite way. Both level numbers are
    divided (integer division) by ``blocks`` before indexing ``sample_pts``, so
    consecutive levels within a block share a sample set.

    Args:
        sample_pts: sequence of numpy index arrays, one per resolution.
        level: network level.
        enc_or_dec: ``'enc'`` or ``'dec'``.
        blocks: number of network levels per resolution.

    Returns:
        ``[sample_gather, sample_scatter]`` cast with ``astype(int)``.

    Raises:
        ValueError: if ``enc_or_dec`` is neither ``'enc'`` nor ``'dec'``.
    """
    if enc_or_dec == 'enc':
        gather_level, scatter_level = level, level + 1
    elif enc_or_dec == 'dec':
        gather_level, scatter_level = level + 1, level
    else:
        # Previously fell through to an UnboundLocalError.
        raise ValueError("enc_or_dec must be 'enc' or 'dec', got %r" % (enc_or_dec,))
    sample_gather = sample_pts[gather_level // blocks]
    sample_scatter = sample_pts[scatter_level // blocks]
    return [sample_gather.astype(int), sample_scatter.astype(int)]


def determine_kernel_support(idx_source, idx_target, rad, dist, knn_flag=False, k_num=4):
    """Find, for every target point, the source points inside its kernel support.

    Two modes:
    - Radius mode (``knn_flag=False``): source ``j`` supports target ``i`` when
      ``dist[idx_target[i], idx_source[j]] < rad``.
    - k-NN mode: the ``min(k_num, len(idx_source) - 1)`` nearest sources per
      target are chosen with ``argpartition``, so their order within a target is
      unspecified.

    Returned neighbor indices are positions into ``idx_source``, not raw point ids.

    Args:
        idx_source: source point ids (row/column ids of ``dist``).
        idx_target: target point ids.
        rad: support radius (radius mode only).
        dist: square distance matrix (numpy array).
        knn_flag: use k-NN mode instead of radius mode.
        k_num: requested neighbors per target in k-NN mode.

    Returns:
        (f_idx, supp), where
        - ``f_idx`` is an integer array of shape ``(n_pairs, 2)`` with rows
          ``[source_position, target_position]``, grouped by target;
        - ``supp`` is a list with one array of source positions per target.
    """
    if not knn_flag:
        # One row of the distance matrix per target keeps peak memory at O(len(idx_source)).
        supp = [np.flatnonzero(dist[idx_t, idx_source] < rad) for idx_t in idx_target]
        counts = np.asarray([len(s) for s in supp], dtype=np.intp)
        # Building the pairs from integer arrays keeps f_idx integer even when a
        # support is empty, and an empty idx_target yields a (0, 2) array. Before,
        # np.vstack raised on an empty idx_target and an empty support promoted
        # everything to float.
        neighbors = np.concatenate(supp) if supp else np.empty(0, dtype=np.intp)
        anchor = np.repeat(np.arange(len(supp)), counts)
        f_idx = np.stack([neighbors, anchor], axis=1)
    else:
        # argpartition needs kth < number of rows.
        k_num = np.minimum(k_num, len(idx_source) - 1)
        dist_sub = dist[:, idx_target][idx_source, :]
        knn_i = dist_sub.argpartition(k_num, axis=0)
        knn_i = knn_i[:k_num]
        knn_i = np.transpose(knn_i)
        anchor = np.indices(knn_i.shape)[0].flatten()
        neighbors = knn_i.flatten()
        f_idx = np.stack([neighbors, anchor]).transpose()
        supp = [k for k in knn_i]
    return f_idx, supp


def compute_sparse_mat(row, col, val, size=None):
    """Build a float32 sparse COO matrix holding ``val[k]`` at ``(row[k], col[k])``.

    Args:
        row, col: 1-D integer tensors of equal length.
        val: 1-D tensor of values, stored as float32.
        size: optional ``(n_rows, n_cols)``. The default ``(max(row) + 1,
            max(col) + 1)`` is the smallest shape containing every entry; pass an
            explicit size when trailing rows/columns are empty.

    Returns:
        A ``torch.sparse_coo_tensor`` (uncoalesced, as given).
    """
    indices = torch.stack([row, col]).long()
    if size is None:
        # Tensor.max() reduces in one call, unlike Python's builtin max() over a tensor.
        size = (int(row.max()) + 1, int(col.max()) + 1)
    # torch.sparse.FloatTensor is a deprecated, CPU-only legacy constructor;
    # sparse_coo_tensor is the supported equivalent and also works on GPU tensors.
    return torch.sparse_coo_tensor(indices, val, torch.Size(size), dtype=torch.float32)
