import torch
import time
import yaml
from torch.utils.data import DataLoader
from graph_reconstruction.buildingblocks import *
import neptune
import matplotlib.pyplot as plt
import io
from datasets import load_dataset
from graph_reconstruction.build_molecule import *
from utilities import *
from utils_normalizations import normalize_adjacency, normalize_features
from filtering_deg1_gpu import deg1_bbs
from metrics import evaluate_metrics
from rdkit.Chem import Draw
from baseline.utils import draw_atom
from graph_reconstruction.utils import get_model
from argparse import ArgumentParser
import warnings
warnings.filterwarnings("ignore")


def _build_arg_parser():
    """Return the command-line parser for this script.

    Positional ``config_path``: attack configuration YAML.
    ``--neptune``: Neptune project name; when given, results are logged to Neptune.
    ``--config_eval``: YAML describing the evaluation model.
    """
    cli = ArgumentParser()
    cli.add_argument('config_path', type=str)
    cli.add_argument('--neptune', type=str, default=None, required=False)
    cli.add_argument('--config_eval', type=str, default='graph_attack_config_eval.yaml', required=False)
    return cli


parser = _build_arg_parser()
args = parser.parse_args()

# Candidate atom-feature pool: every combination of the categorical atom
# features, each encoded through a per-feature lookup table. Index 0 of
# possible_feature_values() is not enumerated; feature 5 uses the table
# returned by type_and_mass_pairs() instead of a plain one-hot table.
lists = possible_feature_values()
tensors = []
indices = []

pairs = type_and_mass_pairs()

for j in range(1, len(lists)):
    n = len(lists[j])
    # Plain categorical features: the identity matrix is the stack of one-hot
    # rows (int64, like a tensor built from Python ints).
    tensors.append(pairs if j == 5 else torch.eye(n, dtype=torch.int64))
    indices.append(torch.arange(n))
all_combinations_indices = torch.cartesian_prod(*indices)

# Materialise the combinations in slices of `batch_size` rows to bound peak
# memory; ceil-division so no trailing empty slice is produced.
batch_size = 100000
list_all = []

num_combinations = all_combinations_indices.shape[0]
for start in range(0, num_combinations, batch_size):
    batch_indices = all_combinations_indices[start:start + batch_size]
    list_all.append(torch.cat([tensors[i][batch_indices[:, i]] for i in range(7)], dim=1))

individual_features = torch.cat(list_all, dim=0)

# Reorder the concatenated columns: [22:74] first, then [0:22], then [74:140].
# NOTE(review): presumably this matches the column layout produced by get_X -- confirm.
individual_features = torch.cat((individual_features[:, 22:74], individual_features[:, 0:22], individual_features[:, 74:140]), dim=1).cuda()

# Attack configuration (YAML at args.config_path).
with open(args.config_path, 'r') as file:
    config = yaml.safe_load(file)

model_args = config['model_args']
train_args = config['train_args']
dataset = config['data_args']['dataset']
# True: one label per molecule; False: one random label per node (see filter_bbs).
is_gc = model_args['graph_class']
use_neptune = args.neptune is not None

if 'neptune_key' in config and use_neptune:
    # NOTE(review): not read anywhere in this file; the Neptune run is
    # authenticated from the NEPTUNE_API_KEY environment variable instead.
    neptune_key = config['neptune_key']

main_name = 'ours' if train_args['federated_optimizer'] == 'FedSGD' else 'FedAvg'
run_name = f"{main_name}_{dataset.split('/')[-1]}_l{model_args['n_layers_gcn']}_d{model_args['node_embedding_dim']}_{model_args['act']}"

normalize_fts = model_args['normalize_features']
normalize_adj = model_args['normalize_adjacency']
# Rank tolerances for the GCN-1, GCN-2 and readout gradients respectively.
tol1 = config['attack_args']['tol_1']
tol2 = config['attack_args']['tol_2']
tol3 = config['attack_args']['tol_3']


if normalize_fts:
    individual_features = normalize_features(individual_features, dataset)

# Evaluation-model configuration. Loaded under its own name so that `config`
# keeps referring to the attack configuration above instead of being replaced.
with open(args.config_eval, 'r') as file:
    config_eval = yaml.safe_load(file)

model_args_eval = config_eval['model_args']


def filter_bbs(batch):
    """Recover candidate degree-2 building blocks of one molecule from its gradient.

    Simulates the client update for the single molecule in ``batch`` (keys
    ``'smiles'`` and ``'target'``), then filters candidate building blocks
    against the resulting gradient in four stages:

    1. ``deg1_bbs`` filters ``individual_features`` into degree-1 blocks
       (``X_cuda_filtered``; presumably shaped (num_blocks, 37, 140)).
    2. ``are_compatible_pos`` is evaluated for every ordered block pair and each
       neighbour position 1..6, giving ``list_tensors_pos``.
    3. Every combination of compatible blocks is assembled into a degree-2 block
       with ``features_combined_ultimate``.
    4. Blocks whose ``check_if_in_span`` distance to the readout-gradient
       decomposition is >= 1e-2 are dropped. Each survivor is then re-rooted at
       every one of its atoms and kept only if each re-rooted view is
       positionally contained (by non-zero pattern) in some surviving block.

    Side effects: sets the module globals ``gt_fms``, ``gt_ls``, ``gt_ams``,
    ``gradient``, ``gradient_GCN_1``, ``gradient_GCN_2``, ``max_rank``,
    ``survived``, ``X_cuda_filtered`` and ``list_tensors_pos``, and logs to
    ``run`` when Neptune is enabled. Reads the globals ``model``, ``criterion``,
    ``gcn_model`` and ``run`` created in ``__main__``.

    Returns:
        A 5-tuple ``(marker, bbs, probs, all_bbs, all_probs)``:

        * ``(False, X, None, None, None)``: ``deg1_bbs`` returned a falsy
          marker, or a block with no outer dangling atoms reproduced the
          gradient (``graph_gradient_dist < 1e-6``). The caller treats ``X``
          as the whole molecule.
        * ``(True, final_deg2s, final_probs, all_deg2s, all_probs)``: the
          surviving blocks with their scores, plus every block that passed the
          span test with its score (used as backups).
        * ``(None, None, None, None, None)``: no degree-2 block survived.
    """
    gt_mol = Chem.MolFromSmiles(batch['smiles'][0])
    label = batch['target']

    global gt_fms
    # NOTE(review): ground-truth features are normalized unconditionally, while the
    # candidate pool `individual_features` is normalized only when `normalize_fts`
    # is set (and the adjacency below ignores `normalize_adj`) -- confirm intent.
    gt_fms = normalize_features(torch.tensor(get_X(gt_mol, feature_onehot_encoding=True)).cuda(), dataset = dataset)
    gt_fms.requires_grad_(True)

    global gt_ls
    if is_gc:
        # Graph-level label: a single two-class one-hot vector.
        if label==1:
            gt_ls = torch.tensor([0.,1.]).cuda()
        else:
            gt_ls = torch.tensor([1.,0.]).cuda()
    else:
        # Node-level labels: a random two-class one-hot vector for every node.
        num_atoms = gt_fms.shape[0]
        gt_ls = torch.zeros((num_atoms, 2))
        true_idxs = torch.randint(0, 2, size = (num_atoms,))
        gt_ls[torch.arange(num_atoms), true_idxs] = 1.
        gt_ls = gt_ls.cuda()

    global gt_ams
    gt_ams = torch.tensor(get_A(gt_mol)).cuda()
    num_nodes = gt_ams.shape[0]

    adj_forward = normalize_adjacency(gt_ams)

    # Simulate the client update whose gradient the attacker observes.
    global gradient
    if train_args['federated_optimizer'] == 'FedSGD':
        gcn_output = model.gcn(gt_fms, adj_forward).cuda()
        logits = model.readout(gt_fms, gcn_output).cuda()
        loss = criterion(logits, gt_ls).cuda()
        loss = loss.sum() / loss.numel()  # mean of the element-wise losses
        gradient = torch.autograd.grad(loss, model.parameters())
    else:
        gradient = compute_grads_fed_avg(model, criterion, gt_fms, adj_forward, gt_ls, avg_epochs=train_args['epochs'], avg_lr=train_args['lr'])

    global gradient_GCN_1
    gradient_GCN_1 = gradient[0].t()
    global gradient_GCN_2
    gradient_GCN_2 = gradient[1].t()
    # NOTE(review): parameter 2 is treated as the first readout weight; it is used
    # untransposed for a 2-layer GCN -- confirm against the model definition.
    if model.gcn.n_layers == 2:
        gradient_readot_layer_1 = gradient[2]
    else:
        gradient_readot_layer_1 = gradient[2].t()

    # Numerical ranks of the observed gradients (computed once, reused below).
    rank_gcn_1 = torch.linalg.matrix_rank(gradient_GCN_1, tol=tol1)
    rank_gcn_2 = torch.linalg.matrix_rank(gradient_GCN_2, tol=tol2)
    rank_readout = torch.linalg.matrix_rank(gradient_readot_layer_1, tol=tol3)

    if use_neptune:
        run["ranks/rank GCN 1"].append(rank_gcn_1)
        run["ranks/rank GCN 2"].append(rank_gcn_2)
        run["ranks/rank readout"].append(rank_readout)
        run["num nodes"].append(num_nodes)

    global max_rank
    max_rank = max([rank_gcn_1, rank_gcn_2, rank_readout])
    with torch.no_grad():
        # NOTE(review): these two results are never used below.
        gt_deg1_bbs = get_gt_37_deg1(gt_fms, gt_ams)
        gt_deg2_bbs = get_gt_37_deg2(gt_fms, gt_ams)

        global survived
        survived = True
        global X_cuda_filtered
        # Stage 1: degree-1 building blocks consistent with the gradient.
        marker, X_cuda_filtered = deg1_bbs(batch, individual_features, gradient, model, criterion, run, gt_ams, tol1=tol1, tol2=tol2, gt_ls = gt_ls if not is_gc else None, dataset=dataset)

        if not marker:
            return False, X_cuda_filtered, None, None, None

        all_bbs = X_cuda_filtered.shape[0]

        # Number of occupied (not all-zero) rows per block.
        row_sums = X_cuda_filtered.sum(dim=2)
        X_cuda_sizes_filtered = (row_sums != 0).sum(dim=1)

        # Number of rows 1.. with any positive entry in feature columns 60:65.
        # NOTE(review): presumably these columns mark an atom that still needs
        # neighbours ("dangling") -- confirm against the feature layout.
        has_one_in_any_column = (X_cuda_filtered[:, 1:, 60:65] > 0).any(dim=2)
        X_dangling_bits_filtered = has_one_in_any_column.sum(dim=1)

        # Sort by dangling count so each count occupies one contiguous index range.
        # No .squeeze() on the indices: with a single block it would drop the batch dim.
        X_dangling_bits_filtered, sort_indices = X_dangling_bits_filtered.sort(dim=0)
        X_cuda_filtered = X_cuda_filtered[sort_indices, :, :]
        X_cuda_sizes_filtered = X_cuda_sizes_filtered[sort_indices]

        # Stage 2: list_tensors_pos[pos][a, b] = are_compatible_pos(block a, block b, pos)
        # for pos in 1..6 (slot 0 stays zero), computed in row chunks of mini_B.
        mini_B = 100

        global list_tensors_pos
        list_tensors_pos = [torch.zeros((all_bbs, all_bbs), dtype=torch.int64).cuda() for _ in range(7)]

        all_cols = torch.arange(all_bbs)
        val_pos_1 = None  # bound up front so the `del` below is safe with zero blocks
        for pos in range(1,7):
            for k in range((all_bbs + mini_B - 1) // mini_B):
                st = k * mini_B
                en = min((k + 1) * mini_B, all_bbs)
                i_idx, j_idx = torch.meshgrid(torch.arange(st, en), all_cols)
                i_idx = i_idx.reshape(-1)
                j_idx = j_idx.reshape(-1)

                val_pos_1 = are_compatible_pos(X_cuda_filtered[i_idx],X_cuda_sizes_filtered[i_idx],X_cuda_filtered[j_idx], X_cuda_sizes_filtered[j_idx], pos)
                list_tensors_pos[pos][i_idx,j_idx] = val_pos_1

        list_tensors_pos = torch.stack(list_tensors_pos)

        del val_pos_1

        # Stage 3: for a block with d dangling positions choose one compatible block
        # per position 1..d; combinations differing only in order are kept once.
        total_deg2 = 0

        # Slot k holds (centre, neighbours) index pairs for blocks with k dangling
        # positions; slot 0 is a placeholder.
        xs_db_list = [ [] ]
        yes_stacked_db_list = [ [] ]

        list_deg2s = []
        # Blocks with no dangling position are already complete.
        all0s = torch.where(X_dangling_bits_filtered==0)
        list_deg2s.append(X_cuda_filtered[all0s])

        for num_dangling in range(1,7):
            rows_with_count = torch.where(X_dangling_bits_filtered==num_dangling)[0]

            if rows_with_count.shape[0] == 0:
                xs_db_list.append(torch.empty(0))
                yes_stacked_db_list.append(torch.empty(0))
                continue
            # Sorted above, so these rows form one contiguous range.
            start = rows_with_count[0]
            end = rows_with_count[-1]

            # Sentinel first rows (-1) allow unconditional torch.cat; dropped below.
            xs_db_tensor = torch.full((1,1),-1).cuda()
            ys_db_stacked = torch.full((1,num_dangling), -1).cuda()

            for i in range(start, end+1):
                list_wheres = []
                all_possible = True
                for d in range(1, num_dangling+1):
                    compatible = torch.where(list_tensors_pos[d][i])[0]
                    if compatible.shape[0] == 0:
                        all_possible = False
                        break
                    list_wheres.append(compatible)

                if not all_possible:
                    continue

                # Cartesian product of the candidates, one column per position.
                grids = torch.stack(torch.meshgrid(*list_wheres), dim=0)
                # Reverse every dimension (what the deprecated N-d `.T` did), then flatten.
                tensors_meshgrided = grids.permute(*reversed(range(grids.dim()))).reshape(-1, num_dangling)

                # Keep the first row of every group that is equal after sorting, in
                # ascending row order (one pass instead of list.index per element).
                sorted_rows, _ = torch.sort(tensors_meshgrided, dim=1)
                _, inverse = torch.unique(sorted_rows, return_inverse=True, dim=0)
                first_row_of_group = {}
                for row, group in enumerate(inverse.tolist()):
                    first_row_of_group.setdefault(group, row)
                unique_indices = sorted(first_row_of_group.values())

                reduced_tensor_meshgrided = tensors_meshgrided[unique_indices]
                num_comb = reduced_tensor_meshgrided.shape[0]

                xs_db_tensor = torch.cat((xs_db_tensor, torch.full((num_comb,1), i).cuda()), dim=0)
                ys_db_stacked = torch.cat((ys_db_stacked, reduced_tensor_meshgrided), dim=0)

                total_deg2 += num_comb

            assert(ys_db_stacked.shape[1]==num_dangling)
            xs_db_list.append(xs_db_tensor[1:])
            yes_stacked_db_list.append(ys_db_stacked[1:])

        if use_neptune:
            run["filtering/deg2/total deg 2 built"].append(total_deg2)

        if total_deg2 == 0:
            if use_neptune:
                run["fails/somehow not a single deg 2"].append(1)
            return None, None, None, None, None

        # Materialise the degree-2 blocks in chunks of mini_B1 pairs.
        mini_B1 = 10000

        for num_dangling in range(1,7):
            xs_idxs = xs_db_list[num_dangling]
            ys_idxs = yes_stacked_db_list[num_dangling]
            num_pairs = xs_idxs.shape[0]
            assert(xs_idxs.shape[0] == ys_idxs.shape[0])

            for k in range((num_pairs + mini_B1 - 1) // mini_B1):
                st = k * mini_B1
                en = min((k + 1) * mini_B1, num_pairs)
                list_deg2s.append(features_combined_ultimate(xs_idxs[st:en], ys_idxs[st:en], X_cuda_filtered, list_tensors_pos))

        list_deg2s = torch.cat(list_deg2s, dim=0)

        list_deg2s = torch.nan_to_num(list_deg2s, nan=0.0)

        # Stage 4a: span test of each block's layer output against the readout gradient.
        if model.gcn.n_layers == 2:
            output_2nd_GCN = torch.cat((list_deg2s[:,0], get_AX_deg2(list_deg2s, gcn_model)[:,0]), dim=1)
        else:
            output_2nd_GCN = gcn_model.act(get_AX_deg2(list_deg2s, gcn_model)[:,0])

        estimate_B, R_K_norm = get_layer_decomp(gradient_readot_layer_1.double().cpu(), tol=tol3)

        diffs = check_if_in_span(R_K_norm.double().cpu(), output_2nd_GCN.double().cpu()).double()

        static_threshold_deg2 = 1e-2

        if use_neptune:
            run["static thresholds/deg 2 threshold"] = static_threshold_deg2

        mask_passed = torch.zeros(list_deg2s.shape[0], dtype=torch.bool)
        mask_passed[diffs<static_threshold_deg2] = True

        # Loop-invariant 37x37 template adjacency: row 0 linked to rows 1..6,
        # row r (1..6) linked to rows 5*r+2..5*r+6, plus self-loops.
        universal_adj_37 = torch.zeros((37,37)).cuda()
        for r in range(7):
            universal_adj_37[0,r] = 1
            universal_adj_37[r,0] = 1
        for r in range(1,7):
            universal_adj_37[r,5*r+2:5*r+7] = 1
            universal_adj_37[5*r+2:5*r+7,r] = 1
        for r in range(37):
            universal_adj_37[r,r] = 1

        # A block with no outer atom of degree > 1 may already be the whole
        # molecule: check its gradient directly.
        for ind in range(list_deg2s.shape[0]):
            bb = list_deg2s[ind]
            degs = get_degree(bb[7:37])
            dangling_bits = torch.sum(degs>1)
            if dangling_bits!=0:
                continue
            # Symmetric (deg+1)^-1/2 scaling of the template, zeroing empty rows.
            degrees = get_degree(bb.unsqueeze(0)) + 1
            degrees = torch.where(degrees==1, torch.tensor(0.0), degrees)
            result = torch.bmm(degrees.unsqueeze(2), degrees.unsqueeze(1))
            result = torch.pow(result, -0.5)
            result = torch.where(torch.isinf(result), torch.tensor(0.0), result)
            all_adj = result * universal_adj_37

            nonzero_rows = torch.nonzero( torch.any(all_adj.squeeze(0) > 0, dim=1) ).squeeze()
            all_adj = all_adj.squeeze(0)
            all_adj = all_adj[nonzero_rows]
            all_adj = all_adj.T
            all_adj = all_adj[nonzero_rows]

            bb_reduced = bb[nonzero_rows]

            if bb_reduced.dim()==1:
                bb_reduced = bb_reduced.unsqueeze(0)
                all_adj = all_adj.unsqueeze(0).unsqueeze(0)

            if graph_gradient_dist(all_adj, bb_reduced, gradient, model, criterion, gt_fms, gt_ams, gt_ls if not is_gc else None) < 1e-6:
                return False, bb, None, None, None

        diffs_passed = diffs[mask_passed]
        list_deg2s_passed = list_deg2s[mask_passed]

        # Non-zero pattern of every passed block (loop-invariant, computed once).
        passed_occupancy = (list_deg2s_passed > 0).int()

        final_deg2s = []
        final_probs = []
        all_deg2s = []
        all_probs = []

        # Stage 4b: re-rooting consistency. A block survives only if every re-rooted
        # view is contained in some passed block; its score sums, per view, the best
        # -diff among the blocks that contain it.
        for bb in list_deg2s_passed:
            bb_died = False

            bb_prob = 0.

            # (i) Re-root at each outer atom nb (rows 7..36, child of row cb).
            for nb in range(7,37):

                if torch.sum(bb[nb])==0:
                    break

                cb = (nb-2) // 5

                target_central = torch.cat((bb[nb].unsqueeze(0), bb[cb].unsqueeze(0), bb[0].unsqueeze(0), bb[5*cb+2:nb], bb[nb+1:5*cb+7]), dim=0)

                to_sort = torch.cat(( get_degree(target_central[2:].unsqueeze(0)).T, target_central[2:]), dim=1)
                rows_as_lists = to_sort.tolist()
                sorted_indices = sorted(range(len(rows_as_lists)), key=lambda r: rows_as_lists[r], reverse=True)
                sorted_target = to_sort[sorted_indices][:,1:]
                total = torch.cat((bb[nb].unsqueeze(0),torch.zeros((36,140), dtype = torch.float32).cuda()))

                found = False

                max_prob = float("-inf")

                # The old parent may sit in any neighbour slot j.
                for j in range(1,7):

                    total_new = torch.clone(total)
                    total_new[j] = bb[cb]
                    total_new[5*j+2:5*j+7] = sorted_target

                    subtracted = passed_occupancy - (total_new > 0).int()
                    non_negative_slices = (subtracted >= 0).all(dim=(1, 2))

                    if non_negative_slices.any():
                        found = True
                        max_prob = max(max_prob, torch.max(-diffs_passed[non_negative_slices.cpu()]))

                bb_prob+=max_prob

                if not found:
                    bb_died = True

            # (ii) Re-root at each first-shell atom nb (rows 1..6).
            for nb in range(1,7):
                if torch.sum(bb[nb])==0:
                    break
                target_central = torch.cat((bb[nb].unsqueeze(0), bb[0].unsqueeze(0), bb[5*nb+2:5*nb+7]), dim=0)
                to_sort = torch.cat(( get_degree(target_central[1:].unsqueeze(0)).T, target_central[1:]), dim=1)
                rows_as_lists = to_sort.tolist()
                sorted_indices = sorted(range(len(rows_as_lists)), key=lambda r: rows_as_lists[r], reverse=True)
                sorted_target = torch.cat((bb[nb].unsqueeze(0), to_sort[sorted_indices][:,1:]), dim=0)
                center_index = sorted_indices.index(0)+1
                assert(torch.all(bb[0] == sorted_target[center_index]))
                assert(sorted_target.shape[0]==7)
                extra = torch.zeros((30,140), dtype = torch.float32).cuda()
                total = torch.cat((sorted_target, extra), dim=0)

                # The old centre's other neighbours (first shell minus nb).
                mask_center_nb = torch.ones(6, dtype = torch.bool)
                mask_center_nb[nb-1] = False
                to_add = bb[1:7][mask_center_nb]

                # The old centre's row may be identical to other rows; try each such
                # slot. Row 0 is the new root, never a neighbour slot.
                comparison = total[:7] == total[center_index]
                rows_equal_to_specific = comparison.all(dim=1)
                equal_row_indices = torch.nonzero(rows_equal_to_specific, as_tuple=False).squeeze()
                if equal_row_indices.dim() == 0:
                    possible_center_indices = [equal_row_indices.item()]
                else:
                    possible_center_indices = equal_row_indices.tolist()
                possible_center_indices = [c for c in possible_center_indices if c >= 1]

                found = False

                max_prob = float("-inf")

                for cand_center in possible_center_indices:

                    total_copy = torch.clone(total)

                    total_copy[5*cand_center+2:5*cand_center+7] = to_add

                    # Test the completed re-rooted block; `total` alone would ignore
                    # the old centre's other neighbours.
                    subtracted = passed_occupancy - (total_copy > 0).int()

                    non_negative_slices = (subtracted >= 0).all(dim=(1, 2))
                    if non_negative_slices.any():
                        found = True
                        max_prob = max(max_prob, torch.max( - diffs_passed[non_negative_slices.cpu()]))

                bb_prob+=max_prob

                if not found:
                    bb_died = True

            if not bb_died:
                final_deg2s.append(bb)
                final_probs.append(bb_prob)

            all_deg2s.append(bb)
            all_probs.append(bb_prob)

        if len(final_deg2s)==0:
            if use_neptune:
                run["fails/somehow not a single deg 2"].append(1)
            return None, None, None, None, None

        final_deg2s = torch.stack(final_deg2s)
        all_deg2s = torch.stack(all_deg2s)

        return True, final_deg2s, final_probs, all_deg2s, all_probs
  

def build_after_filtering(filtered_bbs, bb_probs, backup_bbs, backup_bb_probs, upper_bound, max_rank):
    """Assemble a molecule from filtered degree-2 building blocks with a depth-limited DFS.

    Args:
        filtered_bbs: blocks that survived filtering, with scores ``bb_probs``.
        backup_bbs: fallback blocks, with scores ``backup_bb_probs``.
        upper_bound: passed to the builder as ``upper_bound_atoms``.
        max_rank: largest observed gradient rank; sets the DFS depth limit to
            ``max_rank // 3`` clamped into [7, 20].

    Returns:
        ``pred`` from ``build_molecule_from_bbs_DFS``. It may be ``None``; the
        caller treats that as a failed reconstruction.

    Reads the module globals ``model``, ``criterion``, ``gradient``, ``gt_ams``,
    ``gt_fms`` and ``gt_ls`` (set by ``filter_bbs``) and logs to ``run`` when
    Neptune is enabled.
    """
    bbs = convert_to_bb(filtered_bbs)
    backups = convert_to_bb(backup_bbs)

    building_mode = 'dfs'
    depth_limit = min(max(7, max_rank//3), 20)

    if use_neptune:
        run["building/building_mode"] = building_mode
        run["building/DFS max depth"] = depth_limit
        run["building/building time limit"] = 900

    accuracy2, pred, depth, dfs_steps = build_molecule_from_bbs_DFS(
        DFS_max_depth=depth_limit,
        bbs=bbs,
        backup_bbs=backups,
        backup_bb_probs=backup_bb_probs,
        deg=2,
        model=model,
        criterion=criterion,
        gt_gradient=gradient,
        bbs_prob=bb_probs,
        upper_bound_atoms=upper_bound,
        feature_onehot_encoding=True,
        gt_ams=gt_ams,
        gt_ls=None if is_gc else gt_ls,
        gt_fms=gt_fms,
    )

    if use_neptune and pred is not None:
        run["building/successfully_built"].append(accuracy2)
        if dfs_steps is not None:
            run["building/DFS depth"].append(depth)
            run["building/DFS steps"].append(dfs_steps)

    return pred
def print_metric(metrics, total_metrics, n_iter):
    """Print a table of current and running-average metrics.

    Only metric names containing ``'r2'`` or ``'f'`` are shown. Each shown value
    is added into ``total_metrics``, which is mutated in place with keys created
    on first use. The aggregate column is that running sum divided by
    ``n_iter``. Both columns are printed as percentages with two decimals.
    """
    rows = [['Metric', 'Current', 'Aggregate']]
    for name, value in metrics.items():
        if not ('r2' in name or 'f' in name):
            continue
        total_metrics.setdefault(name, 0)
        total_metrics[name] += value
        current = f'{value*100:.2f}'
        aggregate = f'{total_metrics[name]/n_iter*100:.2f}'
        rows.append([name, current, aggregate])

    from tabulate import tabulate
    print(tabulate(rows, headers='firstrow', tablefmt='fancy_grid'))

if __name__ == '__main__':
    import os  # NEPTUNE_API_KEY lookup; do not rely on a star import providing `os`

    def _log_png(key, image):
        """Append a PIL image to the Neptune field ``key`` as PNG bytes."""
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        run[key].append(neptune.types.File.from_content(buffer.getvalue(), extension='png'))

    def _save_png_via_pyplot(image, path):
        """Render ``image`` with pyplot into ``path`` on a fresh figure.

        Drawing on a new figure that is closed afterwards stops images from
        accumulating on the shared current axes across molecules.
        """
        fig = plt.figure()
        plt.imshow(image)
        fig.savefig(path, format='PNG')
        plt.close(fig)

    def _ground_truth(smiles):
        """Return ``(mol, features, adjacency)`` for ``smiles`` on the GPU.

        Features are the unnormalized one-hot matrix with ``requires_grad`` set.
        """
        mol = Chem.MolFromSmiles(smiles)
        features = torch.tensor(get_X(mol, feature_onehot_encoding=True)).cuda()
        features.requires_grad_(True)
        adjacency = torch.tensor(get_A(mol)).cuda()
        return mol, features, adjacency

    def _evaluate_without_prediction(smiles, mol_index, n_iter):
        """Score a molecule for which no reconstruction exists, then print and log it."""
        mol, features, adjacency = _ground_truth(smiles)
        metrics = evaluate_metrics(CustomArgs(), adjacency, features, None, None, eval_model)
        print_metric(metrics, total_metrics, n_iter=n_iter)
        if use_neptune:
            log_metrics(metrics, run)
            _log_png(f"gt/{mol_index}", Draw.MolToImage(mol))

    if use_neptune:
        run = neptune.init_run(
            project=args.neptune,
            api_token=os.environ['NEPTUNE_API_KEY'],
            name=run_name,
            tags=["ours"],
            dependencies="infer",
        )
        run["model_info/feature_onehot_encoding"] = True
        run["dataset"] = dataset
        run["model_info/hidden size"] = model_args['hidden_size']
        run["model_info/node emb dim"] = model_args['node_embedding_dim']
        run["model_info/readout hidden dim"] = model_args['readout_hidden_dim']
        run["model_info/graph hidden dim"] = model_args['graph_embedding_dim']
    else:
        run = None

    criterion = torch.nn.BCEWithLogitsLoss(reduction="none")

    feat_dim = 140  # width of the one-hot atom feature vector
    num_cats = 2    # two output classes (labels are two-element one-hot vectors)
    saved_model_path = None  # set to a state_dict path to attack trained weights
    torch.manual_seed(42)

    model = get_model(model_args, feat_dim, num_cats)

    if saved_model_path is not None:
        model.load_state_dict(torch.load(saved_model_path))
    model.to('cuda')

    eval_model = get_model(model_args_eval, feat_dim, num_cats)
    eval_model.to('cuda')

    dataloader_dataset = DataLoader(load_dataset(dataset).shuffle(42).with_format("torch")['train'], batch_size=1, shuffle=False)

    gcn_model = model.gcn.cuda()

    # Loop-invariant 37x37 template adjacency: row 0 linked to rows 1..6,
    # row r (1..6) linked to rows 5*r+2..5*r+6, plus self-loops. A separate
    # loop variable keeps the dataset index `i` below intact.
    universal_adj_37 = torch.zeros((37, 37)).cuda()
    for r in range(7):
        universal_adj_37[0, r] = 1
        universal_adj_37[r, 0] = 1
    for r in range(1, 7):
        universal_adj_37[r, 5*r+2:5*r+7] = 1
        universal_adj_37[5*r+2:5*r+7, r] = 1
    for r in range(37):
        universal_adj_37[r, r] = 1

    total_smiles = 0
    total_reconstructed = 0
    total_failed = 0
    total_metrics = {}
    for i, batch in enumerate(dataloader_dataset):
        if i < 44:  # hard-coded start offset: earlier molecules are skipped
            continue
        assert total_smiles == total_reconstructed + total_failed

        print(i, batch['smiles'])

        mol_start_time = time.time()
        smiles = batch['smiles'][0]

        if use_neptune:
            run["smiles"].log(smiles)

        # Number of molecules actually processed; used as the averaging count
        # (the dataset index `i` also counts the skipped molecules).
        total_smiles += 1

        build_marker, filtered_bbs, bb_probs, all_bbs, all_bb_probs = filter_bbs(batch)

        if build_marker is None:
            # No degree-2 building block survived filtering.
            if run is not None:
                run["time/molecule total time"].append(time.time() - mol_start_time)
            total_failed += 1
            _evaluate_without_prediction(smiles, i, n_iter=total_smiles)
            continue

        if build_marker:
            pred = build_after_filtering(filtered_bbs, bb_probs, all_bbs, all_bb_probs, upper_bound=3*max_rank, max_rank=max_rank)
        else:
            # filter_bbs returned a single block to be used as the whole molecule;
            # keep its occupied rows and the matching part of the template adjacency.
            if filtered_bbs.dim() == 1:
                filtered_bbs = filtered_bbs.unsqueeze(0)

            nonzero_rows = torch.nonzero(torch.any(filtered_bbs != 0, dim=1)).squeeze()

            X = filtered_bbs[nonzero_rows]
            A = universal_adj_37[nonzero_rows].T[nonzero_rows]

            pred = (A, X)

        if pred is None:
            if run is not None:
                run["fails/nothing to evaluate"].append(1)
                run["time/molecule total time"].append(time.time() - mol_start_time)
            total_failed += 1
            _evaluate_without_prediction(smiles, i, n_iter=total_smiles)
            continue

        recon_A = torch.tensor(pred[0]).cuda()
        recon_X = torch.tensor(pred[1]).cuda()
        recon_X = torch.where(recon_X <= 0, 0, 1).float()

        if recon_A.dim() == 0:
            # Single-atom reconstruction: restore the matrix dimensions.
            recon_A = recon_A.unsqueeze(0).unsqueeze(0)
            recon_X = recon_X.unsqueeze(0)

        gt_mol, gt_fms, gt_ams = _ground_truth(smiles)

        metrics = evaluate_metrics(CustomArgs(), gt_ams, gt_fms, recon_A, recon_X, eval_model)

        if use_neptune:
            log_metrics(metrics, run)

        print_metric(metrics, total_metrics, n_iter=total_smiles)

        gt_image = Draw.MolToImage(gt_mol, size=(1200, 1200))
        img_reconstruct = draw_atom(revert_one_hot_encoding_multiple(recon_X), recon_A)

        if use_neptune:
            _log_png(f"gt/{i}", gt_image)
            _log_png(f"gt/{i}", img_reconstruct)
            run["time/molecule total time"].append(time.time() - mol_start_time)
        else:
            _save_png_via_pyplot(gt_image, 'gt.png')
            _save_png_via_pyplot(img_reconstruct, 'pred.png')
        total_reconstructed += 1