import torch
import time
import yaml
from graph_reconstruction.buildingblocks import *
from graph_reconstruction.build_molecule import *
from utilities import *
from utils_normalizations import normalize_adjacency, normalize_features
          


def _batched_combinations(num_items, r, batch_size):
    """Lazily yield index combinations (with replacement) in batches.

    Yields ``torch.long`` tensors of shape ``(<=batch_size, r)`` whose rows are
    the size-``r`` combinations with replacement of ``range(num_items)``, in
    the same lexicographic order as
    ``torch.combinations(torch.arange(num_items), r=r, with_replacement=True)``.
    Unlike ``torch.combinations`` the full set is never materialised, so the
    memory footprint is bounded by ``batch_size`` and a caller can stop early.
    """
    import itertools  # local import so the block does not depend on file-level imports

    if r == 0:
        # torch.combinations(..., r=0) returns an empty tensor, i.e. no batches;
        # itertools would yield a single empty tuple instead, so bail out here.
        return

    combos = itertools.combinations_with_replacement(range(num_items), r)
    while True:
        chunk = list(itertools.islice(combos, batch_size))
        if not chunk:
            return
        yield torch.tensor(chunk, dtype=torch.long)


def deg1_bbs(batch, individual_features, gradient, model, criterion, run, gt_ams, tol1=5e-6, tol2=1e-7, gt_ls = None, dataset='zpn/clintox'):
    """Filter candidate node features and build the degree-1 building blocks.

    Steps:
      1. Score every row of ``individual_features`` with ``check_if_in_span``
         against the decomposition of ``gradient[0].t()`` and keep the rows
         whose log10 score is below a static threshold (at most 500 rows).
      2. Try every surviving row of degree 0 as a single-node graph; if its
         ``graph_gradient_dist`` is below 1e-6, return it immediately.
      3. For every degree ``i`` in 1..6 present among the survivors, combine
         each centre row of degree ``i`` with every size-``i`` combination
         (with replacement) of surviving rows, push the result through the
         first GCN layer and keep the blocks whose log10 span-check score
         against the decomposition of ``gradient[1].t()`` is below a second
         static threshold. This stage stops after a 120 s time limit.
      4. Among the kept blocks, try those whose non-centre degrees are all
         <= 1 as complete star graphs; if one has ``graph_gradient_dist``
         below 1e-6, return it.

    Args:
        batch: Mapping with a ``'smiles'`` entry; its first element is parsed
            with ``Chem.MolFromSmiles`` to build the ground-truth features.
        individual_features: Tensor of candidate node feature rows (indexed
            along dim 0).
        gradient: Indexable of gradients; entries 0 and 1 are used for the
            two span checks and the whole object is forwarded to
            ``graph_gradient_dist``.
        model: Model exposing ``model.gcn`` with ``W_list`` and ``act``.
        criterion: Forwarded to ``graph_gradient_dist``.
        run: Optional logger supporting ``run[key] = value`` and
            ``run[key].append(value)``; ``None`` disables logging.
        gt_ams: Forwarded to ``graph_gradient_dist``.
        tol1: Tolerance passed to ``get_layer_decomp`` for the first layer.
        tol2: Tolerance passed to ``get_layer_decomp`` for the second layer.
        gt_ls: Forwarded to ``graph_gradient_dist``.
        dataset: Dataset name forwarded to ``normalize_features``.

    Returns:
        ``(False, X)`` when a candidate graph matches the gradient (distance
        below 1e-6), where ``X`` is the matching feature row / padded block;
        otherwise ``(True, final_deg1)`` with all kept degree-1 blocks
        concatenated along dim 0.

    Raises:
        RuntimeError: If no degree-1 building block survived the filtering
            (previously surfaced as an opaque ``torch.cat`` error).
    """
    gradient_GCN_1 = gradient[0].t()
    gradient_GCN_2 = gradient[1].t()

    gcn_model = model.gcn.cuda()

    gt_mol = Chem.MolFromSmiles(batch['smiles'][0])

    gt_fms = normalize_features(torch.tensor(get_X(gt_mol, feature_onehot_encoding=True)).cuda(), dataset = dataset)
    gt_fms.requires_grad_(True)

    _, R_K_norm_GCN_1 = get_layer_decomp(gradient_GCN_1, tol=tol1)
    R_K_norm_GCN_1 = R_K_norm_GCN_1.cuda()

    # The span check runs on a clone so that `individual_features` itself is
    # guaranteed untouched (presumably check_if_in_span normalises its input
    # in place — confirm against its implementation).
    ind_f_copy = torch.clone(individual_features)

    print('Filtering started')

    start_filtering = time.time()

    diffs = check_if_in_span(R_K_norm_GCN_1, ind_f_copy)
    diffs = torch.log10(diffs)
    # log10(0) is -inf; map it to a finite value so it still compares below the thresholds.
    diffs = torch.where(diffs == -float('inf'), torch.tensor(-7.0), diffs)

    del ind_f_copy
    end_filtering = time.time()

    if run is not None:
        run["filtering/deg0/filtering_time"].append(end_filtering-start_filtering)

    static_deg0_threshold = -3
    if run is not None:
        run["static thresholds/deg 0 threshold"] = static_deg0_threshold

    if torch.sum(diffs<static_deg0_threshold).item()>500:
        # Too many candidates: keep only those strictly below the 501st
        # smallest score, i.e. at most 500 rows.
        if run is not None:
            run["filtering/deg 0 too many"].append(1)
        sorted_diffs, _ = torch.sort(diffs)
        first500 = sorted_diffs[500]
        individual_features_passed = individual_features[diffs<first500]
    else:
        individual_features_passed = individual_features[diffs<static_deg0_threshold]

    # Degrees of the surviving rows, shape (1, num_passed); computed once and reused below.
    passed_degrees = get_degree(individual_features_passed.unsqueeze(0))

    zero_degrees = individual_features_passed[(passed_degrees == 0).squeeze(0)]

    # A degree-0 node can only be a whole single-node graph: test it directly.
    for single_X in zero_degrees:
        A = torch.tensor([1]).unsqueeze(0).cuda()
        if graph_gradient_dist(normalize_adjacency(A), single_X.unsqueeze(0), gradient, model, criterion, gt_fms=gt_fms, gt_ams=gt_ams, gt_ls=gt_ls) < 1e-6:
            return False, single_X

    degrees_present = set(passed_degrees[0].tolist())

    all_passed_deg1_spancheck = 0
    list_passed_deg1s = []
    # Module-level flag kept as in the original; other code may read it.
    global survived
    survived = True

    start_time = time.time()
    time_limit = 120

    if run is not None:
        run["time limits/deg 1 bb time limit"] = time_limit

    killed_time = False

    static_threshold_deg1 = -2

    if run is not None:
        run["static thresholds/deg 1 threshold"] = static_threshold_deg1

    # Blocks are zero-padded to this many rows before being stored.
    # NOTE(review): presumably the maximum number of atoms per molecule — confirm.
    padded_rows = 37
    num_passed = individual_features_passed.size(0)

    with torch.no_grad():
        # Loop-invariant: depends only on gradient_GCN_2 and tol2, so compute once.
        # Assumes get_layer_decomp is deterministic and check_if_in_span does not
        # modify its first argument — TODO confirm.
        _, R_K_norm_GCN_2 = get_layer_decomp(gradient_GCN_2, tol=tol2)
        R_K_norm_GCN_2 = R_K_norm_GCN_2.cuda()
        first_layer_weight = gcn_model.W_list[0].cuda()

        for i in range(0,7):

            if i not in degrees_present:
                continue

            if killed_time:
                break

            # Centre nodes of degree i, shape (num_centres, 1, num_features).
            centre_mask = (passed_degrees == i).squeeze(0)
            centres = individual_features_passed[centre_mask].unsqueeze(1)

            for batch_indices in _batched_combinations(num_passed, r=i, batch_size=5000):
                if time.time() > start_time + time_limit:
                    print('Time has killed this one')
                    killed_time = True
                    break

                # Neighbour sets, shape (num_combinations, i, num_features).
                combination_tensors = individual_features_passed[batch_indices]

                # Cartesian product centres x neighbour sets: centres are tiled,
                # neighbour sets are repeated element-wise so the rows line up.
                centres_repeated = centres.repeat(combination_tensors.shape[0],1,1)
                combination_tensors_repeated = combination_tensors.repeat_interleave(centres.shape[0], dim=0)
                batch_deg1 = torch.cat((centres_repeated, combination_tensors_repeated), dim=1)

                AXs = get_AX_deg1(batch_deg1).squeeze(1)
                bbs_output = torch.mm(AXs, first_layer_weight).cuda()
                after_relu = gcn_model.act(bbs_output)

                deg1_diffs = check_if_in_span(R_K_norm_GCN_2, after_relu)
                deg1_diffs = torch.log10(deg1_diffs)
                deg1_diffs = torch.where(deg1_diffs == -float('inf'), torch.tensor(-7.0), deg1_diffs)

                passed_mask = deg1_diffs<static_threshold_deg1
                all_passed_deg1_spancheck += torch.sum(passed_mask)

                batch_deg1 = batch_deg1[passed_mask]

                if batch_deg1.shape[0]!=0:
                    # Pad every kept block with zero rows up to `padded_rows`;
                    # the feature width is taken from the block itself.
                    zero_rows = torch.zeros(
                        batch_deg1.shape[0],
                        padded_rows-batch_deg1.shape[1],
                        batch_deg1.shape[2],
                        device='cuda',
                    )
                    batch_deg1 = torch.cat((batch_deg1, zero_rows), dim=1)
                    del zero_rows
                    list_passed_deg1s.append(properly_structure_deg1(batch_deg1))

                torch.cuda.empty_cache()

    end_time = time.time()

    deg1_time = end_time - start_time

    if run is not None:
        run["filtering/deg 1 time for b and f"].append(deg1_time)

    print('Just generated and filtered all deg1s')

    if run is not None and survived:
        run["filtering/deg 1/all_passed"].append(all_passed_deg1_spancheck)

    if not list_passed_deg1s:
        # torch.cat([]) would raise a RuntimeError with an unhelpful message;
        # keep the exception type but say what actually happened.
        if killed_time:
            reason = "the time limit was reached before any candidate passed the span check"
        else:
            reason = "no candidate passed the span check"
        raise RuntimeError(f"deg1_bbs produced no degree-1 building blocks: {reason}")

    final_deg1 = torch.cat(list_passed_deg1s)

    del list_passed_deg1s

    print('Just going out of bbs_deg1 function')
    # Degrees of the non-centre rows of every kept block.
    degrees = get_degree(final_deg1)[:, 1:]
    # Blocks whose neighbours all have degree <= 1 have nothing left to attach,
    # so they can be tested as complete graphs right away.
    indices_0_dangs = torch.nonzero(torch.all(degrees <= 1, dim=1)).flatten().tolist()

    for ind in indices_0_dangs:
        size = int(torch.sum(degrees[ind]>0).item())+1
        X = final_deg1[ind][:size]
        # Star graph with self-loops: the centre (row/column 0) is connected to every node.
        # NOTE(review): A is created on the CPU here while the degree-0 check above
        # uses a CUDA adjacency — confirm graph_gradient_dist handles both.
        A = torch.eye(size)
        A[:,0] = 1
        A[0,:] = 1

        if graph_gradient_dist(normalize_adjacency(A), X, gradient, model, criterion, gt_fms=gt_fms, gt_ams=gt_ams, gt_ls=gt_ls) < 1e-6:
            return False, final_deg1[ind]

    return True, final_deg1