from rdkit.Chem.rdmolops import GetAdjacencyMatrix
from rdkit import Chem
import numpy as np
from rdkit import RDConfig
import os
import deepchem as dc
import torch
import torch.nn as nn
from collections import deque
from gcn_readout import GcnMoleculeNetv2

from typing import Tuple
import yaml

import pandas as pd
import math


def get_partial_deg2_bb(bb, center_idx, max_deg):
    """Gather the features of the 1- and 2-hop neighbours of one node.

    The output is organised as ``max_deg`` groups of ``max_deg`` slots.  Slot
    ``max_deg * i`` holds the i-th direct neighbour of ``center_idx``; the
    slots directly after it hold that neighbour's own neighbours, excluding
    the neighbour itself and the centre node.  Slots that are not filled stay
    NaN.

    All tensors are created on the device of ``bb.X`` (the original
    implementation hard-coded CUDA), so the function also works on CPU-only
    hosts.

    Args:
        bb: object exposing ``A`` (adjacency matrix) and ``X`` (node feature
            matrix) as tensors; both are expected to live on the same device.
        center_idx: index of the centre node.
        max_deg: number of groups and number of slots per group.

    Returns:
        A pair ``(features, idxs)``: ``features`` has shape
        ``(max_deg**2, bb.X.shape[1])`` and ``idxs`` is a float tensor of
        length ``1 + max_deg**2`` whose first entry is ``center_idx`` and
        whose remaining entries are the node index stored in each slot (NaN
        for empty slots).
    """
    # NOTE(review): nothing here guards against a node having more than
    # max_deg neighbours (or more than max_deg - 1 second neighbours);
    # presumably callers guarantee this -- confirm.
    device = bb.X.device
    n_slots = max_deg ** 2
    features = torch.full((n_slots, bb.X.shape[1]), float('nan'), device=device)
    former_idxs = torch.full((n_slots,), float('nan'), device=device)

    node_ids = torch.arange(len(bb.A), device=device)

    # Direct neighbours: every non-zero entry of the centre's row except the
    # centre itself (the adjacency matrix may carry self-loops).
    neighbor_idxs = torch.nonzero(
        bb.A[center_idx].bool() & (node_ids != center_idx)
    )[:, 0]
    head_slots = slice(None, max_deg * len(neighbor_idxs), max_deg)
    features[head_slots] = bb.X[neighbor_idxs]
    former_idxs[head_slots] = neighbor_idxs

    for i, idx in enumerate(neighbor_idxs):
        # Second neighbours: adjacent to `idx`, but neither `idx` nor the centre.
        second_mask = (
            bb.A[idx].bool() & (node_ids != idx) & (node_ids != center_idx)
        )
        second_neighbor_idxs = torch.nonzero(second_mask)[:, 0]
        start = max_deg * i + 1
        stop = start + len(second_neighbor_idxs)
        features[start:stop] = bb.X[second_neighbor_idxs]
        former_idxs[start:stop] = second_neighbor_idxs

    # Same dtype as `former_idxs` so the concatenation needs no type promotion.
    center = torch.tensor([center_idx], dtype=former_idxs.dtype, device=device)
    return features, torch.cat((center, former_idxs))

def properly_structure_deg1(deg1s):
    """Reorder the neighbour rows of every sample into descending order.

    Row 0 of each sample is the centre and is left in place.  The remaining
    rows are sorted in descending order using the value returned by
    ``get_degree`` as the primary key and the row's feature values (compared
    lexicographically as Python lists) as tie-breaker.

    Args:
        deg1s: tensor indexed as ``[sample, row, feature]``; the asserts
            below require 36 neighbour rows of 140 features per sample.

    Returns:
        Tensor with the centre row followed by the sorted neighbour rows.
    """
    centers = deg1s[:, 0, :]
    neighbors = deg1s[:, 1:, :]

    # Prepend the degree index as an extra leading column so it dominates
    # the list comparison used for sorting.
    keyed = torch.cat((get_degree(neighbors).unsqueeze(2), neighbors), dim=-1)

    ordered_samples = []
    for sample in keyed:
        rows = sample.tolist()
        order = sorted(range(len(rows)), key=rows.__getitem__, reverse=True)
        # Drop the helper degree column again after reordering.
        ordered_samples.append(sample[order][:, 1:])

    all_sorted = torch.stack(ordered_samples)

    assert(all_sorted.shape[1]==36)
    assert(all_sorted.shape[2]==140)
    return torch.cat((centers.unsqueeze(1), all_sorted), dim=1)

def properly_structure_deg2(deg2s):
    """Canonically order the two-hop neighbourhood rows of every sample.

    Layout expected after the centre row (row 0): six "head" rows followed
    by six groups of five rows, group ``g`` belonging to head ``g``.

    For each sample the five rows inside every group are sorted in descending
    order, then the (head, group) pairs are sorted in descending order.  The
    sort key is the row as a Python list with the ``get_degree`` value
    prepended.  The output keeps the same layout: six heads, then their
    groups in the new head order.

    Args:
        deg2s: tensor indexed as ``[sample, row, feature]``; the asserts
            below require 36 non-centre rows of 140 features per sample.

    Returns:
        CUDA tensor holding the centre row followed by the reordered rows.
    """
    centers = deg2s[:, 0, :].cuda()
    others = deg2s[:, 1:, :]
    # Extra leading column with the degree index acts as primary sort key.
    keyed = torch.cat((get_degree(others).unsqueeze(2), others), dim=-1)

    per_sample = []
    for sample in keyed:
        rows = sample.tolist()

        # Sort the members of each of the six groups (rows 6-10, 11-15, ...).
        sorted_groups = []
        for start in range(6, 36, 5):
            members = rows[start:start + 5]
            ranking = sorted(range(len(members)), key=members.__getitem__, reverse=True)
            sorted_groups.append([members[m] for m in ranking])

        # Heads followed by the already sorted groups, flattened.
        regrouped = rows[0:6] + [row for group in sorted_groups for row in group]

        # Key of a group = its head row followed by its sorted member rows.
        group_keys = [
            [regrouped[g]] + regrouped[6 + 5 * g:11 + 5 * g]
            for g in range(6)
        ]
        group_order = sorted(range(len(group_keys)), key=group_keys.__getitem__, reverse=True)

        flat = [group_keys[g][0] for g in group_order]
        for g in group_order:
            flat.extend(sorted_groups[g])

        # Strip the helper degree column.
        per_sample.append(torch.tensor(flat)[:, 1:])

    all_sorted = torch.stack(per_sample).cuda()
    assert(all_sorted.shape[1]==36)
    assert(all_sorted.shape[2]==140)
    return torch.cat((centers.unsqueeze(1), all_sorted), dim=1)

def properly_structure_deg2_grouped(deg2s, former_idxs):
    """Canonically order grouped two-hop rows and their node indices.

    Layout expected after the centre row (row 0): six consecutive groups of
    six rows, each made of one head row followed by five member rows.

    For every sample the five member rows of each group are sorted in
    descending order, then the groups themselves are sorted in descending
    order.  The sort key is the row as a Python list with the ``get_degree``
    value prepended.  ``former_idxs`` is permuted in exactly the same way so
    each row keeps its index.

    Args:
        deg2s: tensor indexed as ``[sample, row, feature]``; the asserts
            below require 36 non-centre rows of 140 features per sample.
        former_idxs: tensor indexed as ``[sample, row]`` aligned with
            ``deg2s``; column 0 belongs to the centre.

    Returns:
        ``(final, final_idxs)``: the reordered features (CUDA tensor) and
        the matching reordered indices.
    """
    centers = deg2s[:, 0, :].cuda()
    center_idxs = former_idxs[:, 0]
    others = deg2s[:, 1:, :]
    # Extra leading column with the degree index acts as primary sort key.
    keyed = torch.cat((get_degree(others).unsqueeze(2), others), dim=-1)

    feature_blocks = []
    index_blocks = []
    for sample_pos in range(deg2s.shape[0]):
        rows = keyed[sample_pos].tolist()
        row_ids = former_idxs[sample_pos, 1:].tolist()

        groups = []
        group_ids = []
        for head in range(0, 36, 6):
            members = rows[head + 1:head + 6]
            member_ids = row_ids[head + 1:head + 6]
            ranking = sorted(range(len(members)), key=members.__getitem__, reverse=True)
            groups.append([rows[head]] + [members[m] for m in ranking])
            group_ids.append([row_ids[head]] + [member_ids[m] for m in ranking])

        group_order = sorted(range(len(groups)), key=groups.__getitem__, reverse=True)
        groups = [groups[g] for g in group_order]
        group_ids = [group_ids[g] for g in group_order]

        # 6 groups x 6 rows x 141 columns -> 36 rows; drop the degree column.
        feature_blocks.append(torch.tensor(groups).reshape(36, 141)[:, 1:])
        index_blocks.append(torch.tensor(group_ids).reshape(36))

    all_sorted = torch.stack(feature_blocks).cuda()
    idx_sorted = torch.stack(index_blocks).cuda()
    assert(all_sorted.shape[1]==36)
    assert(all_sorted.shape[2]==140)
    final = torch.cat((centers.unsqueeze(1), all_sorted), dim=1)
    final_idxs = torch.cat((center_idxs.unsqueeze(1), idx_sorted), dim=1)
    return final, final_idxs


def get_degree(x):
    """Return, per row, the arg-max position inside feature columns 58-64.

    In the one-hot layout defined by ``possible_feature_values`` these seven
    columns follow the 52 atom-type and 6 charge columns, i.e. they are the
    degree slots.  Rows whose maximum is NaN are reported as ``-1``.

    Args:
        x: tensor indexed as ``[sample, row, feature]``.

    Returns:
        Integer tensor indexed as ``[sample, row]``.
    """
    values, indices = torch.max(x[:, :, 58:65], dim=2)
    indices[torch.isnan(values)] = -1
    return indices

def possible_feature_values():
    """Return the admissible values of each of the eight atom features.

    The lists are ordered like the columns written by ``get_X``: atomic
    number, formal charge, degree, chiral tag, hydrogen count, scaled mass,
    aromaticity flag and hybridization code.  Their lengths add up to 140,
    the width of the one-hot encoding.  Fresh lists are built on every call.
    """
    atomic_numbers = [
        1.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 11.0, 12.0, 13.0, 14.0, 15.0,
        16.0, 17.0, 19.0, 20.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
        30.0, 32.0, 33.0, 34.0, 35.0, 38.0, 40.0, 42.0, 43.0, 46.0, 47.0, 48.0,
        49.0, 50.0, 51.0, 53.0, 56.0, 60.0, 64.0, 66.0, 70.0, 78.0, 79.0, 80.0,
        81.0, 82.0, 83.0,
    ]
    formal_charges = [-2.0, -1.0, 0.0, 1.0, 2.0, 3.0]
    degrees = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    chiral_tags = [0.0, 1.0, 2.0]
    hydrogen_counts = [0.0, 1.0, 2.0, 3.0, 4.0, 6.0]
    scaled_masses = [
        0.010080000385642052, 0.02014101855456829, 0.06941000372171402,
        0.09012000262737274, 0.10812000185251236, 0.12010999768972397,
        0.14007000625133514, 0.15998999774456024, 0.18998000025749207,
        0.22990000247955322, 0.2430499941110611, 0.2698200047016144,
        0.2808600068092346, 0.30974000692367554, 0.31973907351493835,
        0.3206700086593628, 0.35453000664711, 0.3909800052642822,
        0.40077999234199524, 0.47867000102996826, 0.509440004825592,
        0.5199599862098694, 0.5493800044059753, 0.5584499835968018,
        0.586929976940155, 0.5893300175666809, 0.6354600191116333,
        0.6539000272750854, 0.7261000275611877, 0.7492200136184692,
        0.7896000146865845, 0.7990400195121765, 0.8762000203132629,
        0.9122400283813477, 0.9593999981880188, 0.9890625476837158,
        1.0642000436782837, 1.0786800384521484, 1.1241199970245361,
        1.1481800079345703, 1.1871099472045898, 1.2175999879837036,
        1.2290558815002441, 1.2690399885177612, 1.3090612888336182,
        1.3732800483703613, 1.4423999786376953, 1.5724999904632568,
        1.625, 1.7303999662399292, 1.9507800340652466, 1.969670057296753,
        2.0058999061584473, 2.0097081661224365, 2.043829917907715,
        2.072000026702881, 2.0897998809814453,
    ]
    aromatic_flags = [0.0, 1.0]
    hybridization_codes = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

    return [
        atomic_numbers,
        formal_charges,
        degrees,
        chiral_tags,
        hydrogen_counts,
        scaled_masses,
        aromatic_flags,
        hybridization_codes,
    ]

def type_and_mass_pairs():
    """One-hot encode every distinct (atomic number, scaled mass) pair.

    The pairs are hard-coded per dataset (ClinTox, Tox21, BBBP), merged and
    de-duplicated.  Each surviving pair becomes one row made of the one-hot
    atomic-number block followed by the one-hot mass block, using the value
    tables from ``possible_feature_values`` (entries 0 and 5).  Masses are
    matched with an absolute tolerance of 1e-4.

    Returns:
        Float tensor with one row per distinct pair.
    """
    clintox = torch.tensor([[1.0000e+00, 1.0080e-02],
        [5.0000e+00, 1.0812e-01],
        [6.0000e+00, 1.2011e-01],
        [7.0000e+00, 1.4007e-01],
        [8.0000e+00, 1.5999e-01],
        [9.0000e+00, 1.8998e-01],
        [1.3000e+01, 2.6982e-01],
        [1.4000e+01, 2.8086e-01],
        [1.5000e+01, 3.0974e-01],
        [1.5000e+01, 3.1974e-01],
        [1.6000e+01, 3.2067e-01],
        [1.7000e+01, 3.5453e-01],
        [2.0000e+01, 4.0078e-01],
        [2.2000e+01, 4.7867e-01],
        [2.4000e+01, 5.1996e-01],
        [2.5000e+01, 5.4938e-01],
        [2.6000e+01, 5.5845e-01],
        [2.9000e+01, 6.3546e-01],
        [3.0000e+01, 6.5390e-01],
        [3.3000e+01, 7.4922e-01],
        [3.4000e+01, 7.8960e-01],
        [3.5000e+01, 7.9904e-01],
        [4.3000e+01, 9.8906e-01],
        [5.3000e+01, 1.2291e+00],
        [5.3000e+01, 1.2690e+00],
        [5.3000e+01, 1.3091e+00],
        [7.8000e+01, 1.9508e+00],
        [7.9000e+01, 1.9697e+00],
        [8.0000e+01, 2.0059e+00],
        [8.1000e+01, 2.0097e+00],
        [8.3000e+01, 2.0898e+00]])

    tox21 = torch.tensor([[1.0000e+00, 1.0080e-02],
        [1.0000e+00, 2.0141e-02],
        [3.0000e+00, 6.9410e-02],
        [4.0000e+00, 9.0120e-02],
        [5.0000e+00, 1.0812e-01],
        [6.0000e+00, 1.2011e-01],
        [7.0000e+00, 1.4007e-01],
        [8.0000e+00, 1.5999e-01],
        [9.0000e+00, 1.8998e-01],
        [1.1000e+01, 2.2990e-01],
        [1.2000e+01, 2.4305e-01],
        [1.3000e+01, 2.6982e-01],
        [1.4000e+01, 2.8086e-01],
        [1.5000e+01, 3.0974e-01],
        [1.6000e+01, 3.2067e-01],
        [1.7000e+01, 3.5453e-01],
        [1.9000e+01, 3.9098e-01],
        [2.0000e+01, 4.0078e-01],
        [2.2000e+01, 4.7867e-01],
        [2.3000e+01, 5.0944e-01],
        [2.4000e+01, 5.1996e-01],
        [2.5000e+01, 5.4938e-01],
        [2.6000e+01, 5.5845e-01],
        [2.7000e+01, 5.8933e-01],
        [2.8000e+01, 5.8693e-01],
        [2.9000e+01, 6.3546e-01],
        [3.0000e+01, 6.5390e-01],
        [3.2000e+01, 7.2610e-01],
        [3.3000e+01, 7.4922e-01],
        [3.4000e+01, 7.8960e-01],
        [3.5000e+01, 7.9904e-01],
        [3.8000e+01, 8.7620e-01],
        [4.0000e+01, 9.1224e-01],
        [4.2000e+01, 9.5940e-01],
        [4.6000e+01, 1.0642e+00],
        [4.7000e+01, 1.0787e+00],
        [4.8000e+01, 1.1241e+00],
        [4.9000e+01, 1.1482e+00],
        [5.0000e+01, 1.1871e+00],
        [5.1000e+01, 1.2176e+00],
        [5.3000e+01, 1.2690e+00],
        [5.6000e+01, 1.3733e+00],
        [6.0000e+01, 1.4424e+00],
        [6.4000e+01, 1.5725e+00],
        [6.6000e+01, 1.6250e+00],
        [7.0000e+01, 1.7304e+00],
        [7.8000e+01, 1.9508e+00],
        [7.9000e+01, 1.9697e+00],
        [8.0000e+01, 2.0059e+00],
        [8.1000e+01, 2.0438e+00],
        [8.2000e+01, 2.0720e+00],
        [8.3000e+01, 2.0898e+00]])
    bbbp = torch.tensor([[1.0000e+00, 1.0080e-02],
        [5.0000e+00, 1.0812e-01],
        [6.0000e+00, 1.2011e-01],
        [7.0000e+00, 1.4007e-01],
        [8.0000e+00, 1.5999e-01],
        [9.0000e+00, 1.8998e-01],
        [1.1000e+01, 2.2990e-01],
        [1.5000e+01, 3.0974e-01],
        [1.6000e+01, 3.2067e-01],
        [1.7000e+01, 3.5453e-01],
        [2.0000e+01, 4.0078e-01],
        [3.5000e+01, 7.9904e-01],
        [5.3000e+01, 1.2690e+00]])

    # Merge the three tables and keep each (type, mass) pair once.
    unique_pairs = torch.unique(torch.cat((clintox, tox21, bbbp), dim=0), dim=0)

    feature_values = possible_feature_values()
    known_types = feature_values[0]
    known_masses = feature_values[5]

    encoded_rows = []
    for atom_type, atom_mass in unique_pairs:
        type_onehot = torch.zeros(len(known_types), dtype=torch.float)
        assert(atom_type in known_types)
        type_onehot[known_types.index(atom_type)] = 1

        # Table masses are rounded, so match against the reference masses
        # with a tolerance; -3 marks "no match" and trips the assert.
        mass_pos = next(
            (
                pos
                for pos, mass in enumerate(known_masses)
                if math.isclose(mass, atom_mass, abs_tol=0.0001)
            ),
            -3,
        )
        assert(mass_pos>=0)
        mass_onehot = torch.zeros(len(known_masses), dtype=torch.float)
        mass_onehot[mass_pos] = 1

        encoded_rows.append(torch.cat((type_onehot, mass_onehot)))

    return torch.stack(encoded_rows)

def get_A(mol: Chem.rdchem.Mol) -> torch.Tensor:
    """Return the float32 adjacency matrix of ``mol`` with self-loops added.

    The matrix comes from RDKit's ``GetAdjacencyMatrix``; the identity is
    added on top so every atom is connected to itself.
    """
    adjacency = torch.tensor(GetAdjacencyMatrix(mol), dtype=torch.float32)
    self_loops = torch.eye(adjacency.shape[0])
    return adjacency + self_loops

def get_atom_symbol(atomic_num):
    """Look up the element symbol for ``atomic_num`` in RDKit's periodic table."""
    return Chem.GetPeriodicTable().GetElementSymbol(atomic_num)

def invert_X(X):
    """Rebuild an RDKit atom from one row of raw (non one-hot) atom features.

    ``get_X`` in this module writes rows as
    ``[atomic number, formal charge, degree, chiral tag, hydrogen count,
    mass / 100, aromatic flag, hybridization code]``.  For such 8-column rows
    the aromatic flag is read from index 6 and the hybridization code from
    index 7.  Reading them from 5 and 6 (as this function used to do
    unconditionally) would interpret the mass as the aromatic flag and the
    aromatic flag as the hybridization.  Rows with fewer than 8 entries keep
    the previous layout (aromatic flag at 5, hybridization at 6).

    Args:
        X: indexable sequence (list, tensor, ...) of numeric atom features.

    Returns:
        ``Chem.Atom`` with formal charge, explicit hydrogen count,
        aromaticity and hybridization set; implicit hydrogens are disabled.
    """
    # NOTE(review): the short-row layout is kept only for backward
    # compatibility; confirm whether any caller still passes such rows.
    if len(X) >= 8:
        aromatic_pos, hybridization_pos = 6, 7
    else:
        aromatic_pos, hybridization_pos = 5, 6

    atom = Chem.Atom(get_atom_symbol(int(X[0])))
    atom.SetFormalCharge(int(X[1]))
    atom.SetNoImplicit(True)
    atom.SetNumExplicitHs(int(X[4]))
    atom.SetIsAromatic(bool(X[aromatic_pos]))

    # Inverse of the integer codes assigned in get_X.
    hybridization_types = {
        0: Chem.HybridizationType.UNSPECIFIED,
        1: Chem.HybridizationType.S,
        2: Chem.HybridizationType.SP,
        3: Chem.HybridizationType.SP2,
        4: Chem.HybridizationType.SP3,
        5: Chem.HybridizationType.SP3D,
        6: Chem.HybridizationType.SP3D2,
    }
    code = int(X[hybridization_pos])
    # Unknown codes are passed through unchanged, as before, so RDKit
    # reports them the same way it always did.
    atom.SetHybridization(hybridization_types.get(code, code))

    return atom

def get_X(mol: Chem.rdchem.Mol, feature_onehot_encoding=True) -> torch.Tensor:
    """Build the per-atom feature matrix of ``mol``.

    Each atom yields eight raw features, in this order: atomic number,
    formal charge, degree, chiral tag, total hydrogen count, atomic mass
    divided by 100, aromatic flag, hybridization code (S=1, SP=2, SP2=3,
    SP3=4, SP3D=5, SP3D2=6; any other hybridization keeps RDKit's own enum
    value).

    Args:
        mol: RDKit molecule.
        feature_onehot_encoding: when truthy (default) every raw feature is
            one-hot encoded against ``possible_feature_values`` and the
            blocks are concatenated into 140 columns; when falsy the raw
            ``(num_atoms, 8)`` matrix is returned.

    Returns:
        float32 tensor of shape ``(num_atoms, 140)`` or ``(num_atoms, 8)``.
        A molecule without atoms yields a tensor with zero rows.

    Raises:
        AssertionError: if a raw feature value is not listed in
            ``possible_feature_values``.
    """
    hybridization_codes = (
        (Chem.HybridizationType.S, 1),
        (Chem.HybridizationType.SP, 2),
        (Chem.HybridizationType.SP2, 3),
        (Chem.HybridizationType.SP3, 4),
        (Chem.HybridizationType.SP3D, 5),
        (Chem.HybridizationType.SP3D2, 6),
    )

    raw_rows = []
    for atom in mol.GetAtoms():
        chiral_tag = atom.GetChiralTag()
        if chiral_tag == Chem.ChiralType.CHI_UNSPECIFIED:
            chiral_tag = 0

        hybridization = atom.GetHybridization()
        for hybridization_type, code in hybridization_codes:
            if hybridization == hybridization_type:
                hybridization = code
                break

        raw_rows.append(torch.tensor([
            atom.GetAtomicNum(),  # Atomic number
            atom.GetFormalCharge(),  # Formal charge
            atom.GetDegree(),  # Number of bonds (not counting Hs)
            chiral_tag,
            atom.GetTotalNumHs(),  # Total number of attached Hs
            atom.GetMass()/100,  # Atomic mass, scaled
            int(atom.GetIsAromatic()),
            hybridization,
        ], dtype=torch.float32))

    # Rows are stacked directly instead of being written into a randomly
    # initialised scratch matrix, so the global RNG state is left untouched.
    ft_mat = torch.stack(raw_rows) if raw_rows else torch.zeros((0, 8))

    if not feature_onehot_encoding:
        return ft_mat

    lists = possible_feature_values()
    width = sum(len(values) for values in lists)
    assert width == 140

    ft_mat_encoded = torch.zeros((len(raw_rows), width))
    for i, x in enumerate(ft_mat):
        offset = 0
        for j, values in enumerate(lists):
            # Comparison is done on the float32 tensor element so the mass
            # matches the float32-derived reference values exactly.
            assert x[j] in values, (
                f'feature {j} of atom {i} has unsupported value {x[j].item()}'
            )
            ft_mat_encoded[i, offset + values.index(x[j])] = 1
            offset += len(values)

    return ft_mat_encoded

# Return the per-atom information reported by print_mol as a pandas DataFrame.
# TODO: also accept a plain list of feature rows from X as input.
def mol_df(mol: Chem.rdchem.Mol) -> pd.DataFrame:
    """Tabulate the atoms of ``mol``, one DataFrame row per atom.

    Columns: element symbol, formal charge, degree, chiral tag (as string),
    total hydrogen count, atomic mass, aromatic flag and hybridization (as
    string).  Rows are labelled with the atom's position in ``GetAtoms()``.
    """
    column_names = ['Atom', 'Charge', '#Bonds', 'Chirality', '#Hs', 'mass', 'arom', 'hyb']
    df = pd.DataFrame(columns=column_names)
    for row_label, atom in enumerate(mol.GetAtoms()):
        atom_record = [
            atom.GetSymbol(),
            atom.GetFormalCharge(),
            atom.GetDegree(),  # bonds to non-hydrogen atoms
            str(atom.GetChiralTag()),
            atom.GetTotalNumHs(),
            atom.GetMass(),
            atom.GetIsAromatic(),
            str(atom.GetHybridization()),
        ]
        df.loc[row_label] = atom_record
    return df

def get_Y(mol) -> torch.Tensor:
    """Return ``get_A(mol) @ get_X(mol)``.

    This is the self-loop adjacency matrix multiplied by the one-hot feature
    matrix, i.e. each row sums the features of an atom and its neighbours.
    """
    return get_A(mol) @ get_X(mol)


def tensor2d_to_tuple(tensor: torch.Tensor) -> Tuple:
    """Convert a 2-D tensor into a (hashable) tuple of row tuples."""
    return tuple(tuple(row) for row in tensor.tolist())

def check_for_small_cycle(adj):
    """Report whether the graph contains a cycle of length 3 or 4.

    Self-loops are ignored (the diagonal is zeroed on a private copy, the
    argument is never modified).  The reasoning assumes a symmetric 0/1
    adjacency matrix:

    * a triangle exists iff ``trace(A^3) > 0``;
    * ``(A^4)[i, i]`` counts the closed walks of length 4 through ``i``.
      Without a 4-cycle only back-tracking walks exist, and there are
      ``deg(i)^2 + sum_{k != i} (A^2)[i, k]`` of them.  Any surplus means two
      different 2-step paths join ``i`` to the same node, i.e. a 4-cycle.

    Args:
        adj: square adjacency matrix as a tensor or nested sequence.

    Returns:
        ``True`` if a 3- or 4-cycle is present, otherwise ``False`` (also
        for an empty matrix).
    """
    if isinstance(adj, torch.Tensor):
        # clone() avoids the copy-construct warning of torch.tensor(tensor).
        adjacency = adj.detach().clone()
    else:
        adjacency = torch.tensor(adj)

    if adjacency.numel() == 0:
        return False

    adjacency.fill_diagonal_(0)
    two_step = torch.matmul(adjacency, adjacency)

    # Cycles of length 3.
    if torch.trace(torch.matmul(two_step, adjacency)) > 0:
        return True

    # Cycles of length 4.
    degrees = two_step.diagonal()
    backtracking_walks = two_step.sum(dim=1) - degrees + degrees * degrees
    closed_four_walks = torch.matmul(two_step, two_step).diagonal()
    return bool((closed_four_walks > backtracking_walks).any())

def bfs_shortest_path_length(adj_matrix, start_node, end_node):
    """Return the number of edges on a shortest path between two nodes.

    Performs a breadth-first search over ``adj_matrix``, where a truthy entry
    ``adj_matrix[u][v]`` means there is an edge from ``u`` to ``v``.

    Args:
        adj_matrix: square adjacency structure supporting ``len`` and row
            iteration.
        start_node: index of the node the search starts from.
        end_node: index of the node to reach.

    Returns:
        The hop count, ``0`` when both nodes coincide, or ``float('inf')``
        when ``end_node`` cannot be reached.
    """
    seen = [False] * len(adj_matrix)
    seen[start_node] = True
    frontier = deque([(start_node, 0)])  # (node, hops from start_node)

    while frontier:
        node, hops = frontier.popleft()
        if node == end_node:
            return hops

        for candidate, edge in enumerate(adj_matrix[node]):
            if not edge or seen[candidate]:
                continue
            # Mark on enqueue so every node enters the queue at most once.
            seen[candidate] = True
            frontier.append((candidate, hops + 1))

    return float('inf')


def is_connected(adjacency_matrix):
    """Return whether every node can be reached from node 0.

    Runs an iterative depth-first search that follows the non-zero entries of
    each row.  Self-loops are harmless.

    Unlike the previous recursive version, a node with exactly one non-zero
    entry is expanded correctly.  Previously ``squeeze()`` collapsed that
    single neighbour to a 0-dim tensor and the search stopped there, so
    matrices without self-loops could be reported as disconnected.  Long
    chains also no longer risk hitting the recursion limit.

    Args:
        adjacency_matrix: square 2-D tensor; non-zero means "edge".

    Returns:
        ``True`` if all nodes were visited, else ``False``.

    Raises:
        IndexError: if the matrix has no nodes.
    """
    num_nodes = adjacency_matrix.size(0)
    visited = [False] * num_nodes
    visited[0] = True
    pending = [0]

    while pending:
        node = pending.pop()
        # flatten() keeps a 1-D result even when there is a single neighbour.
        neighbors = adjacency_matrix[node].nonzero(as_tuple=False).flatten().tolist()
        for neighbor in neighbors:
            if not visited[neighbor]:
                visited[neighbor] = True
                pending.append(neighbor)

    return all(visited)

def get_model(model_args,feat_dim,num_cats):
    """Instantiate ``GcnMoleculeNetv2`` from a configuration mapping.

    Args:
        model_args: mapping providing ``act`` (``'relu'``, ``'sigmoid'`` or
            ``'gelu'``) plus ``hidden_size``, ``node_embedding_dim``,
            ``dropout``, ``readout_hidden_dim``, ``graph_embedding_dim``,
            ``n_layers_gcn``, ``n_layers_readout``, ``graph_class`` and
            ``sparse_adjacency``, all forwarded to the model constructor.
        feat_dim: forwarded as the first constructor argument.
        num_cats: forwarded after the embedding sizes.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: if ``model_args['act']`` is not supported.
    """
    requested = model_args['act']
    supported = (
        ('relu', nn.ReLU),
        ('sigmoid', nn.Sigmoid),
        ('gelu', nn.GELU),
    )
    for act_name, act_class in supported:
        if requested == act_name:
            activation = act_class
            break
    else:
        raise NotImplementedError(f'No activation called {model_args["act"]} available')

    # The activation is handed over as a class, not as an instance.
    return GcnMoleculeNetv2(
        feat_dim,
        model_args['hidden_size'],
        model_args['node_embedding_dim'],
        model_args['dropout'],
        model_args['readout_hidden_dim'],
        model_args['graph_embedding_dim'],
        num_cats,
        model_args['n_layers_gcn'],
        model_args['n_layers_readout'],
        activation,
        model_args['graph_class'],
        model_args['sparse_adjacency'],
    )
    

def revert_one_hot_encoding_multiple(X):
    """Decode rows of 140-column one-hot features back to raw values.

    For every row, each of the first 140 positions holding a value greater
    than zero is replaced by the feature value that position stands for in
    the flattened ``possible_feature_values`` table.  Rows of length zero
    are skipped.

    Args:
        X: 2-D array-like with a ``shape`` attribute, one encoded atom per row.

    Returns:
        Tuple with one list of decoded values per (non-empty) row.
    """
    flat_values = [value for group in possible_feature_values() for value in group]

    decoded_rows = []
    for row_pos in range(X.shape[0]):
        row = X[row_pos]
        if row.shape[0] == 0:
            continue
        decoded_rows.append(
            [flat_values[col] for col in range(140) if row[col] > 0]
        )
    return tuple(decoded_rows)