import pickle

from rdkit import Chem
import torch
from torch_geometric.data import Data
from torch_geometric.utils import subgraph


def _charged_atom_key(atom_symbol, atom_charge):
    """Build the encoder key of a non-neutral atom: ``<symbol><signed charge>``.

    Examples: ``("N", 1) -> "N+1"`` and ``("O", -1) -> "O-1"``.
    """
    # str() of a negative int already starts with "-", but a positive int
    # loses its "+", so that sign has to be added explicitly.
    sign = "+" if atom_charge > 0 else ""
    return atom_symbol + sign + str(atom_charge)


def mol_to_torch_geometric(mol, atom_encoder, smiles, cfg):
    """Convert an RDKit molecule into a torch_geometric ``Data`` graph.

    Args:
        mol: RDKit molecule. When ``cfg.features.use_3d`` is set, its first
            conformer supplies the atom positions.
        atom_encoder: mapping from atom key (element symbol, or
            ``<symbol><signed charge>`` under the "dictionary" policy) to the
            atom type index stored in ``Data.x``.
        smiles: stored unchanged on the returned ``Data``.
        cfg: configuration object; ``cfg.features.use_3d`` and
            ``cfg.features.charges_policy`` are read.

    Returns:
        A ``Data`` object with ``x``, ``edge_index``, ``edge_attr``, ``pos``
        (``None`` unless 3D is enabled), ``charges`` and ``smiles``, or
        ``None`` when the "dictionary" policy meets a charged atom whose key
        is not in ``atom_encoder``.

    Raises:
        KeyError: if an atom key that is looked up is missing from
            ``atom_encoder``.
    """
    adj = torch.from_numpy(Chem.rdmolops.GetAdjacencyMatrix(mol, useBO=True))
    # Transpose first, *then* make contiguous, so the (2, num_edges) index
    # tensor handed to Data is laid out contiguously in memory.
    edge_index = adj.nonzero().T.contiguous()
    bond_types = adj[edge_index[0], edge_index[1]]
    # Bond order 1.5 (how RDKit reports aromatic bonds) gets its own class 4
    # so that the cast to long below does not truncate it to 1.
    bond_types[bond_types == 1.5] = 4
    edge_attr = bond_types.long()

    # Unfortunately datasets such as zinc just do not have a 3d option
    pos = None
    if cfg.features.use_3d:
        pos = torch.tensor(mol.GetConformers()[0].GetPositions()).float()
        # Center the coordinates on their mean.
        pos = pos - torch.mean(pos, dim=0, keepdim=True)

    charges_policy = cfg.features.charges_policy
    atom_types = []
    all_charges = []
    for atom in mol.GetAtoms():
        atom_symbol = atom.GetSymbol()
        atom_charge = atom.GetFormalCharge()

        if charges_policy in ("no", "partial"):
            # We can append the charge even if we are using the partial charges:
            # at training time, we will just ignore them / set them to zero.
            # TODO: check if implicit Hs should be kept
            all_charges.append(atom_charge)
        elif charges_policy == "dictionary":
            # The charge is folded into the atom type, so the explicit
            # charge entry stays neutral.
            all_charges.append(0)
            if atom_charge != 0:
                atom_symbol = _charged_atom_key(atom_symbol, atom_charge)
                # Only charged variants listed in the encoder are tracked;
                # molecules containing any other charged atom are discarded.
                if atom_symbol not in atom_encoder:
                    return None
        # Any other policy appends no charge at all (behavior unchanged).

        atom_types.append(atom_encoder[atom_symbol])

    # Build integer tensors directly instead of round-tripping through float32.
    atom_types = torch.tensor(atom_types, dtype=torch.long)
    all_charges = torch.tensor(all_charges, dtype=torch.long)

    data = Data(x=atom_types, edge_index=edge_index, edge_attr=edge_attr, pos=pos, charges=all_charges,
                smiles=smiles, guidance=None, node_stats=None, edge_stats=None, n_nodes=None)
    return data


def remove_hydrogens(data: Data, cfg):
    """Return a copy of ``data`` restricted to the nodes whose type index is > 0.

    Nodes with ``data.x == 0`` are dropped (presumably index 0 encodes
    hydrogen -- confirm against the atom encoder), edges are restricted to the
    kept nodes and relabeled, and the remaining type indices are shifted down
    by one.

    Args:
        data: graph with ``x``, ``edge_index``, ``edge_attr``, ``smiles`` and,
            optionally, ``pos`` and ``charges``.
        cfg: configuration object; only ``cfg.features.use_3d`` is read.

    Returns:
        A new ``Data`` object. ``pos`` is re-centered on the kept nodes when
        3D is enabled and ``None`` otherwise; ``charges`` is ``None`` when the
        input carries no charges.
    """
    to_keep = data.x > 0
    new_edge_index, new_edge_attr = subgraph(to_keep, data.edge_index, data.edge_attr, relabel_nodes=True,
                                             num_nodes=len(to_keep))
    new_pos = None
    if cfg.features.use_3d:
        kept_pos = data.pos[to_keep]
        # Re-center, since removing nodes moves the mean position.
        new_pos = kept_pos - torch.mean(kept_pos, dim=0)

    # Identity check instead of `!= None` (which would go through the tensor's
    # comparison operator); getattr also covers a Data object that has no
    # `charges` attribute at all.
    charges = getattr(data, "charges", None)
    new_charges = charges[to_keep] if charges is not None else None

    return Data(x=data.x[to_keep] - 1,         # Shift onehot encoding to match atom decoder
                pos=new_pos,
                charges=new_charges,
                edge_index=new_edge_index,
                edge_attr=new_edge_attr,
                smiles=data.smiles)


def save_pickle(array, path):
    """Serialize ``array`` with :mod:`pickle` and write it to ``path``.

    The file is opened in binary write mode, so existing content is replaced.
    """
    with open(path, 'wb') as sink:
        pickle.Pickler(sink).dump(array)


def load_pickle(path):
    """Read the file at ``path`` and return the object unpickled from it.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    from a trusted source.
    """
    with open(path, 'rb') as source:
        loaded = pickle.Unpickler(source).load()
    return loaded


class Statistics:
    """Plain container for dataset statistics.

    Stores the given ``num_nodes``, ``node_types``, ``edge_types`` and ``cfg``
    unchanged as attributes of the same names, and prints ``num_nodes`` on
    construction.
    """

    def __init__(self, num_nodes, node_types, edge_types, cfg):
        self.cfg, self.num_nodes = cfg, num_nodes
        print("NUM NODES IN STATISTICS", num_nodes)
        self.node_types, self.edge_types = node_types, edge_types

class StatisticsMolecule(Statistics):
    """Statistics container extended with molecule-specific entries.

    On top of what ``Statistics`` stores, keeps ``charge_types``,
    ``valencies``, ``bond_lengths`` and ``bond_angles`` unchanged as
    attributes of the same names.
    """

    def __init__(self, num_nodes, node_types, edge_types, cfg,
                 charge_types, valencies, bond_lengths, bond_angles):
        super().__init__(num_nodes=num_nodes, node_types=node_types,
                         edge_types=edge_types, cfg=cfg)

        self.valencies, self.charge_types = valencies, charge_types
        self.bond_lengths, self.bond_angles = bond_lengths, bond_angles