from collections.abc import Iterable

try:
    from graph_reconstruction.utils import *
except ImportError:
    from utils import *
    
try:
    from graph_reconstruction.feature import *
except ImportError:
    from feature import *

from graph_reconstruction.extended_feature import *

try:
    from graph_reconstruction.buildingblocks import *
except ImportError:
    from buildingblocks import *

from graph_reconstruction.extended_buildingblocks import *
    
from rdkit import Chem
import torch

from graph_reconstruction.join_2_ext_bbs_better import *

from collections import deque

from utils_normalizations import normalize_adjacency
from utilities import graph_gradient_dist

import time

def revert_one_hot_encoding(X):
    """Decode a concatenated one-hot feature vector into a tuple of values.

    The per-feature value lists returned by ``possible_feature_values()`` are
    concatenated into one flat lookup table (expected to hold exactly 140
    entries). Every position ``p`` with ``X[p] == 1`` contributes the table
    entry at ``p`` to the result, in increasing position order.

    Args:
        X: Indexable vector with at least 140 entries.
            NOTE(review): presumably one "hot" entry per feature group --
            confirm against the encoder.

    Returns:
        Tuple of the feature values selected by the hot positions of ``X``.

    Raises:
        AssertionError: If the flattened lookup table does not have 140 entries.
    """
    # sum(..., []) concatenates the per-feature lists into one flat list.
    flat_values = sum(possible_feature_values(), [])
    assert len(flat_values) == 140

    return tuple(
        value
        for position, value in enumerate(flat_values)
        if X[position] == 1
    )

def revert_one_hot_encoding_old(x):
    lists = []
    lists += [1.0, 5.0, 6.0, 7.0, 8.0, 9.0, 13.0, 14.0, 15.0, 16.0, 17.0, 20.0, 22.0, 24.0, 25.0, 26.0, 29.0, 30.0, 33.0, 34.0, 35.0, 43.0, 53.0, 78.0, 79.0, 80.0, 81.0, 83.0]
    lists += [-2.0, -1.0, 0.0, 1.0, 2.0, 3.0]
    lists += [0.0, 1.0, 2.0, 3.0, 4.0, 6.0]
    lists += [0.0, 1.0, 2.0]
    lists += [0.0, 1.0, 2.0, 3.0]
    lists += [0.010080000385642052, 0.10812000185251236, 0.12010999768972397, 0.14007000625133514, 0.15998999774456024, 0.18998000025749207, 0.2698200047016144, 0.2808600068092346, 0.30974000692367554, 0.31973907351493835, 0.3206700086593628, 0.35453000664711, 0.40077999234199524, 0.47867000102996826, 0.5199599862098694, 0.5493800044059753, 0.5584499835968018, 0.6354600191116333, 0.6539000272750854, 0.7492200136184692, 0.7896000146865845, 0.7990400195121765, 0.9890625476837158, 1.2290558815002441, 1.2690399885177612, 1.3090612888336182, 1.9507800340652466, 1.969670057296753, 2.0058999061584473, 2.0097081661224365, 2.0897998809814453]
    lists += [0.0, 1.0]
    lists += [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    list_x_reverted = []
    assert(len(x)==140)
    assert(len(lists)==140)
    for ind in range(len(x)):
        if x[ind]==1:
            list_x_reverted.append(lists[ind])
    return tuple(list_x_reverted)

def check_tensors_within_tolerance(list1, list2, tolerance=1e-4):
    """Tell whether two lists of tensors agree element-wise within a tolerance.

    Tensors are compared pairwise in order; the scan stops at the first pair
    that contains an element whose absolute difference exceeds ``tolerance``.

    Args:
        list1: Sequence of tensors.
        list2: Sequence of tensors, same length as ``list1``.
        tolerance: Largest absolute element-wise difference still accepted.

    Returns:
        True if every pair is within tolerance (also for two empty lists),
        False otherwise.

    Raises:
        ValueError: If the lists differ in length, or if a pair of
            corresponding tensors differs in shape (raised when that pair is
            reached).
    """
    if len(list1) != len(list2):
        raise ValueError("The two lists must have the same length")

    for left, right in zip(list1, list2):
        if left.shape != right.shape:
            raise ValueError("All corresponding tensors must have the same shape")

        # Phrased as "all <= tolerance" so that NaN differences count as a
        # mismatch (NaN <= tolerance is False).
        within = (left - right).abs().le(tolerance)
        if not within.all():
            return False

    return True

def l2_dist_grad(adv_grad, true_grad):
    """Return the mean per-tensor L2 distance between two gradient lists.

    Tensors are paired in order with ``zip`` (so extra tensors in the longer
    list are ignored). For each pair the Euclidean norm of the difference is
    computed; the list of norms is printed and its mean is returned.

    Args:
        adv_grad: Iterable of tensors.
        true_grad: Iterable of tensors with shapes compatible with ``adv_grad``.

    Returns:
        Mean of the per-pair L2 distances as a Python float.
    """
    per_pair = [
        # sqrt(sum((a - b)^2)) for one pair of tensors
        (candidate - reference).pow(2).sum().sqrt().item()
        for candidate, reference in zip(adv_grad, true_grad)
    ]

    print("L2 Distances:", per_pair)
    return torch.tensor(per_pair).mean().item()


def build_molecule_from_bbs_DFS(DFS_max_depth, bbs, backup_bbs, deg, model, criterion, gt_gradient, bbs_prob=None, backup_bb_probs = None, mol=None, upper_bound_atoms = None, feature_onehot_encoding=None, time_limit_per_branch=900, time_limit_global=1800, grad_limit=500, gt_fms=None, gt_ams=None, gt_ls=None):
    """Depth-first search for a molecule graph whose gradient matches ``gt_gradient``.

    Partial graphs are grown from building blocks. Each candidate that is
    finished (or cut off by a limit) is scored with ``graph_gradient_dist``;
    a score below ``1e-6`` is returned immediately as the solution, otherwise
    the best-scoring candidate seen so far is remembered.

    The search runs in up to two passes: first joining with ``bbs``, then with
    ``backup_bbs``. In both passes the beginnings are the last three entries of
    ``bbs`` (i.e. the three most probable ones when ``bbs_prob`` is given).

    Args:
        DFS_max_depth: Candidates deeper than this are scored and not expanded.
        bbs: Building blocks. Replaced by
            ``NewBuildingBlocks.from_molecule(mol, deg)`` when ``bbs_prob`` is None.
        backup_bbs: Building blocks used for joining in the second pass.
        deg: Building-block degree; only read when ``bbs_prob`` is None.
        model, criterion, gt_gradient, gt_fms, gt_ams, gt_ls: Forwarded
            unchanged to ``graph_gradient_dist``.
        bbs_prob: Optional scores used to sort ``bbs`` in ascending order.
        backup_bb_probs: Optional scores used to sort ``backup_bbs`` in
            ascending order (only applied when ``bbs_prob`` is given). When it
            is None the backups keep their given order.
        mol: Molecule the blocks are extracted from when ``bbs_prob`` is None.
        upper_bound_atoms: If not None, candidates with more rows in ``A``
            than this are scored and not expanded.
        feature_onehot_encoding: Currently has no effect; kept for backward
            compatibility (see the NOTE(review) in the body).
        time_limit_per_branch: Seconds, measured from the start of each pass
            and shared by all beginnings of that pass, after which candidates
            are scored instead of expanded.
        time_limit_global: Seconds since the call started after which the best
            result so far is returned (checked after each expansion).
        grad_limit: Gradient-evaluation budget given to each starting
            candidate; it is divided evenly among the children at every
            expansion, and a candidate is only scored while its share is >= 1.

    Returns:
        ``(distance, (A, X), depth, DFS_steps)`` when a candidate with distance
        below ``1e-6`` is found, otherwise
        ``(best_distance, best_(A, X)_or_None, None, None)``.
    """
    total_start_time = time.time()

    best_recon = None
    best_recon_dist = float("inf")

    def gradient_distance(candidate):
        # Distance between the candidate graph's gradient and the target one.
        return graph_gradient_dist(
            normalize_adjacency(torch.tensor(candidate.A)),
            torch.tensor(candidate.X),
            gt_gradient, model, criterion, gt_fms, gt_ams, gt_ls,
        )

    # `is None` (not `== None`): an ndarray/tensor of scores would otherwise be
    # compared element-wise and fail in the `if`.
    if bbs_prob is None:
        bbs = NewBuildingBlocks.from_molecule(mol, deg)
    else:
        # Ascending sort on the score only; blocks themselves are never compared.
        bbs = [block for _, block in sorted(zip(bbs_prob, bbs), key=lambda pair: pair[0])]
        if backup_bb_probs is not None:
            backup_bbs = [
                block
                for _, block in sorted(zip(backup_bb_probs, backup_bbs), key=lambda pair: pair[0])
            ]

    for use_backups in (False, True):

        if use_backups:
            print('Previous building failed, now using backups')

        print(r'/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/')

        # Blocks used for joining in this pass; beginnings always come from `bbs`.
        pool = backup_bbs if use_backups else bbs

        start_time = time.time()
        for k, beg in enumerate(bbs[::-1][:3]):

            DFS_steps = 1

            if len(beg.connections) == 0:
                continue

            first = beg.connections[0]

            print(f'We start building beginning {k+1} from atom {first}', flush=True)

            possible_beginnings_ext = join_deg_2_bb_list_with_cycles_gpu(beg, first, pool, 6)
            possible_beginnings_ext.reverse()
            possible_beginnings_ext = possible_beginnings_ext[:6]

            # Every starting candidate is tried so the correct solution is not missed.
            for start in possible_beginnings_ext:

                # Explicit DFS stack of (candidate, depth, gradient budget).
                # Children are pushed in order and popped from the end, so the
                # last child produced is explored first.
                stack = [(start, 2, grad_limit)]

                while stack:
                    top, depth, grad_limit_iter = stack.pop()
                    DFS_steps += 1

                    if (
                        len(top.connections) == 0
                        or time.time() - start_time > time_limit_per_branch
                        or depth > DFS_max_depth
                        or (upper_bound_atoms is not None and len(top.A) > upper_bound_atoms)
                    ):
                        # No dangling connections means a complete candidate;
                        # the other conditions are cut-offs. Either way the
                        # candidate is scored (budget permitting), not expanded.
                        if len(top.connections) == 0:
                            print(f'DFS Found one with {len(top.A)}, depth is {depth}', flush=True)
                        if grad_limit_iter >= 1:
                            l2_diff = gradient_distance(top)

                            if l2_diff < 1e-6:
                                return l2_diff, (top.A, top.X), depth, DFS_steps

                            if l2_diff < best_recon_dist:
                                best_recon_dist = l2_diff
                                best_recon = (top.A, top.X)

                        continue

                    # NOTE(review): the previous version looped over all atoms
                    # here comparing sum(top.A[i]) - 1 with the degree stored
                    # in top.X[i] (selected via `feature_onehot_encoding`), but
                    # its `continue` only affected that inner loop, so nothing
                    # was ever pruned. The no-op loop was dropped; confirm
                    # whether degree-based pruning is wanted before adding it.

                    # Which dangling end is extended does not matter; use the first.
                    anchor = top.connections[0]

                    res = join_ext_bb_list_with_cycles_gpu(top, anchor, pool, 6, cycles_time_limit=60)

                    if len(res) == 0:
                        # Dead end: remember it if it is the closest so far.
                        if grad_limit_iter >= 1:
                            l2_diff = gradient_distance(top)

                            if l2_diff < best_recon_dist:
                                best_recon_dist = l2_diff
                                best_recon = (top.A, top.X)
                        continue

                    # Budget is split over all returned entries, None ones included.
                    child_budget = grad_limit_iter / len(res)
                    for child in res:
                        if child is not None:
                            stack.append((child, depth + 1, child_budget))

                    if time.time() > total_start_time + time_limit_global:
                        return best_recon_dist, best_recon, None, None

    print(f'We (DFS) did not reach the true solution, instead found 1 with distance {best_recon_dist}', flush=True)
    return best_recon_dist, best_recon, None, None