import os
from copy import deepcopy
from typing import Optional, Union, Dict

import torch
import torch.nn.functional as F
from torch_geometric.utils import to_dense_adj, to_dense_batch, remove_self_loops
from torchmetrics import Metric, MeanSquaredError, MeanAbsoluteError,MetricCollection,KLDivergence
import pytorch_lightning as pl
from omegaconf import OmegaConf, open_dict
import wandb

from didigress.analysis.rdkit_functions import Molecule


# from dgd.ggg_utils_deps import approx_small_symeig, our_small_symeig,extract_canonical_k_eigenfeat
# from dgd.ggg_utils_deps import  ensure_tensor, get_laplacian, asserts_enabled


class NoSyncMetricCollection(MetricCollection):
    """``MetricCollection`` subclass that forwards every argument unchanged.

    NOTE(review): the original intent (per its comment) was to disable
    distributed syncing because it breaks DDP sub-batching, but this
    constructor sets no sync flags itself. Presumably syncing is controlled
    by the member metrics (e.g. the NoSync* metric classes); confirm.
    """

    def __init__(self, *args, **kwargs):
        super(NoSyncMetricCollection, self).__init__(*args, **kwargs)


class NoSyncMetric(Metric):
    """torchmetrics ``Metric`` base with all distributed syncing switched off.

    Syncing is disabled both on ``compute`` and on every ``forward`` step,
    because it interferes with DDP sub-batching.
    """

    def __init__(self):
        no_sync_flags = dict(sync_on_compute=False, dist_sync_on_step=False)
        super().__init__(**no_sync_flags)


class NoSyncKL(KLDivergence):
    """``KLDivergence`` metric with distributed syncing disabled.

    Both compute-time and per-step syncing are turned off because they
    interfere with DDP sub-batching.
    """

    def __init__(self):
        no_sync_flags = dict(sync_on_compute=False, dist_sync_on_step=False)
        super().__init__(**no_sync_flags)


class NoSyncMSE(MeanSquaredError):
    """``MeanSquaredError`` metric with distributed syncing disabled.

    Both compute-time and per-step syncing are turned off because they
    interfere with DDP sub-batching.
    """

    def __init__(self):
        no_sync_flags = dict(sync_on_compute=False, dist_sync_on_step=False)
        super().__init__(**no_sync_flags)


class NoSyncMAE(MeanAbsoluteError):
    """``MeanAbsoluteError`` metric with distributed syncing disabled.

    Both compute-time and per-step syncing are turned off because they
    interfere with DDP sub-batching.
    """

    def __init__(self):
        # disabling syncs since it messes up DDP sub-batching
        super().__init__(sync_on_compute=False, dist_sync_on_step=False)

# Folders
def create_folders(args):
    """Create the ``graphs``/``chains`` output folders for the current run.

    Two stages: first ``graphs`` and ``chains``, then their
    ``<args.general.name>`` subfolders. Creation is best-effort. An
    ``OSError`` in a stage is swallowed, and it only skips the remaining
    folders of that stage.
    """
    def _make_dirs(*paths):
        try:
            for path in paths:
                os.makedirs(path, exist_ok=True)
        except OSError:
            pass  # deliberately best-effort

    _make_dirs('graphs', 'chains')
    _make_dirs('graphs/' + args.general.name, 'chains/' + args.general.name)


def to_dense(data, dataset_info, device=None):
    """Convert a sparse PyG batch into a masked, dense ``PlaceHolder``.

    Args:
        data: batched PyG data. Reads ``x``, ``edge_index``, ``edge_attr``,
            ``batch``, ``guidance``, ``node_stats``, ``edge_stats``,
            ``charge_types`` and ``n_nodes``, plus ``charges`` / ``pos`` when
            the corresponding feature flags are enabled.
        dataset_info: provides ``cfg.features`` (``use_charges``,
            ``use_3d``, ``use_ins_del``) and ``to_one_hot`` for X/E/charges.
        device: optional device for the dense tensors. Note that
            ``guidance``/``node_stats``/``edge_stats``/``charge_types``/
            ``n_nodes`` are passed through without being moved.

    Returns:
        The ``PlaceHolder`` after ``mask()`` has zeroed padded entries.
    """
    use_charges = dataset_info.cfg.features.use_charges
    use_3d      = dataset_info.cfg.features.use_3d
    use_ins_del = dataset_info.cfg.features.use_ins_del

    X, node_mask = to_dense_batch(x=data.x, batch=data.batch)
    max_num_nodes = X.size(1)

    edge_index, edge_attr = remove_self_loops(data.edge_index, data.edge_attr)
    E = to_dense_adj(edge_index=edge_index, batch=data.batch, edge_attr=edge_attr, max_num_nodes=max_num_nodes)

    charges = None
    if use_charges:
        charges, _ = to_dense_batch(x=data.charges, batch=data.batch)

    pos = None
    if use_3d:
        pos, _ = to_dense_batch(x=data.pos, batch=data.batch)
        pos = pos.float()
        # Input coordinates are expected to be already centred per graph.
        assert pos.mean(dim=1).abs().max() < 1e-3

    X, E, charges = dataset_info.to_one_hot(X, E=E, node_mask=node_mask, charges=charges)
    y = X.new_zeros((X.shape[0], 0))  # do NOT move above X, charges, E

    insert_time = None
    delt_mask = None
    if use_ins_del:
        # int(data.batch.max()) + 1 should be the batch size
        batch_sz = int(data.batch.max()) + 1
        insert_time = torch.zeros((batch_sz, max_num_nodes, 1), device=X.device)
        delt_mask = torch.zeros((batch_sz, max_num_nodes, 1), device=X.device)

    if device is not None:
        X = X.to(device)
        E = E.to(device)
        y = y.to(device)
        # charges must follow X/E; previously it was left on its original device.
        if charges is not None:
            charges = charges.to(device)
        if pos is not None:
            pos = pos.to(device)
        node_mask = node_mask.to(device)
        if use_ins_del:
            insert_time = insert_time.to(device)
            delt_mask   = delt_mask.to(device)

    data = PlaceHolder(X=X, E=E, y=y, charges=charges, pos=pos, node_mask=node_mask,
                       guidance=data.guidance, node_stats=data.node_stats,
                       edge_stats=data.edge_stats, charge_types=data.charge_types,
                       insert_time=insert_time, delt_mask=delt_mask, n_nodes=data.n_nodes)
    return data.mask()


class PlaceHolder:
    """Container for the dense tensors (and metadata) describing a batch of graphs.

    Every field may be ``None``. Below, ``bs`` is the batch size and ``n``
    the padded number of nodes. Shapes noted on fields are the ones implied
    by how this class indexes them.
    """

    def __init__(self, X, E, y, charges=None, pos=None, t_int=None, t=None, node_mask=None, guidance=None,
                 node_stats=None, edge_stats=None, charge_types=None, insert_time=None, delt_mask=None,
                 n_nodes=None):
        self.X          = X          # node features, (bs, n, dx)
        self.E          = E          # edge features, (bs, n, n, de)
        self.y          = y          # graph-level features, (bs, dy)

        self.charges    = charges    # per-node charge features, (bs, n, dc)
        self.pos        = pos        # node coordinates, (bs, n, 3)

        self.t_int      = t_int
        self.t          = t
        self.node_mask  = node_mask  # bool, (bs, n)

        self.guidance   = guidance

        self.node_stats   = node_stats
        self.edge_stats   = edge_stats
        self.charge_types = charge_types
        self.insert_time  = insert_time  # (bs, n, 1)
        self.delt_mask    = delt_mask    # bool, (bs, n, 1); True where a DELt node sits
        self.n_nodes      = n_nodes

    def device_as(self, x: torch.Tensor):
        """Move every tensor field to ``x.device`` (dtypes are unchanged).

        Fields that are ``None`` (or not tensors) are left as they are.
        Returns ``self`` so calls can be chained.
        """
        def _move(value):
            return value.to(x.device) if isinstance(value, torch.Tensor) else value

        # node_mask, t_int and t were previously left behind on the old device.
        for name in ("X", "E", "y", "charges", "pos", "guidance", "node_mask", "t_int", "t",
                     "node_stats", "edge_stats", "charge_types", "insert_time", "delt_mask",
                     "n_nodes"):
            setattr(self, name, _move(getattr(self, name)))

        return self

    def mask(self, node_mask=None, assert_E=True):
        """Zero out every entry that belongs to a masked-out node, in place.

        Args:
            node_mask: bool (bs, n) mask to apply; defaults to ``self.node_mask``.
                ``self.node_mask`` becomes its intersection with this mask.
            assert_E: check that ``E`` is symmetric after masking (skipped
                when ``E`` is None).

        Returns:
            ``self``.
        """
        if node_mask is None:
            assert self.node_mask is not None
            node_mask = self.node_mask
        bs, n = node_mask.shape
        x_mask = node_mask.unsqueeze(-1)          # bs, n, 1
        e_mask1 = x_mask.unsqueeze(2)             # bs, n, 1, 1
        e_mask2 = x_mask.unsqueeze(1)             # bs, 1, n, 1
        diag_mask = ~torch.eye(n, dtype=torch.bool,
                               device=node_mask.device).unsqueeze(0).expand(bs, -1, -1).unsqueeze(-1)  # bs, n, n, 1
        if self.X is not None:
            self.X = self.X * x_mask
        if self.charges is not None:
            self.charges = self.charges * x_mask
        if self.E is not None:
            # Also removes self-loops via diag_mask.
            self.E = self.E * e_mask1 * e_mask2 * diag_mask
        if self.pos is not None:
            self.pos = self.pos * x_mask
            # NOTE(review): this mean runs over all n slots (padding included),
            # and the subtraction also shifts padded slots off zero. It is only
            # a no-op centring if the real nodes are already centred; compare
            # remove_mean_with_mask, which uses a masked mean. Confirm intent.
            self.pos = self.pos - self.pos.mean(dim=1, keepdim=True)
        if self.insert_time is not None:
            self.insert_time = self.insert_time * x_mask
        if self.delt_mask is not None:
            self.delt_mask = self.delt_mask * x_mask
            self.delt_mask = self.delt_mask.bool()
        if self.node_mask is not None:
            self.node_mask = self.node_mask.unsqueeze(-1) * x_mask
            self.node_mask = self.node_mask.squeeze(-1).bool()

        # There are situations where the matrix is temporarily asymmetric.
        # For instance, when we are masking the DELt it usually happens when
        # we haven't yet sampled the categories of each node => the matrix
        # is not symmetric yet and this is expected.
        if assert_E and self.E is not None:
            assert torch.allclose(self.E, torch.transpose(self.E, 1, 2)), f"has nan: {torch.isnan(self.E).any()}"
        return self

    def collapse(self, collapse_charges=None, use_charges=False):
        """Return a copy with one-hot features replaced by argmax class indices.

        Padded node entries of X are set to -1, padded edge entries of E to -1
        and, when ``use_charges``, padded charges to 1000. With
        ``use_charges`` the charge index is mapped through ``collapse_charges``;
        otherwise the copy's ``charges`` is set to None.
        """
        copy = self.copy()
        copy.X = torch.argmax(self.X, dim=-1)
        if use_charges:
            copy.charges = collapse_charges.to(self.charges.device)[torch.argmax(self.charges, dim=-1)]
        else:
            copy.charges = None
        copy.E = torch.argmax(self.E, dim=-1)
        x_mask = self.node_mask.unsqueeze(-1)  # bs, n, 1
        e_mask1 = x_mask.unsqueeze(2)  # bs, n, 1, 1
        e_mask2 = x_mask.unsqueeze(1)  # bs, 1, n, 1
        copy.X[self.node_mask == 0] = - 1
        if use_charges:
            copy.charges[self.node_mask == 0] = 1000
        copy.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = - 1
        return copy

    def __repr__(self):
        def _describe(value):
            # Tensors are summarised by their shape; anything else is shown verbatim.
            return value.shape if isinstance(value, torch.Tensor) else value

        parts = [
            f"X: {_describe(self.X)}",
            f"E: {_describe(self.E)}",
            f"y: {_describe(self.y)}",
            f"node_mask: {_describe(self.node_mask)}",
        ]
        optional_fields = (
            ("charges", self.charges),
            ("pos", self.pos),
            ("guidance", self.guidance),
            ("node stats", self.node_stats),
            ("edge stats", self.edge_stats),
            ("charge_types", self.charge_types),
            ("insert times", self.insert_time),
            ("delt_mask", self.delt_mask),
            ("n_nodes", self.n_nodes),
            ("t_int", self.t_int),
            ("t", self.t),
        )
        parts.extend(f"{label}: {_describe(value)}"
                     for label, value in optional_fields if value is not None)
        return " -- ".join(parts)

    def copy(self):
        """Return a deep copy; every tensor field is detached and cloned."""
        return PlaceHolder(X=self.X.detach().clone() if self.X is not None else None,
            charges=self.charges.detach().clone() if self.charges is not None else None,
            E=self.E.detach().clone() if self.E is not None else None,
            y=self.y.detach().clone() if self.y is not None else None,
            pos=self.pos.detach().clone() if self.pos is not None else None,
            t_int=self.t_int.detach().clone() if self.t_int is not None else None,
            t=self.t.detach().clone() if self.t is not None else None,
            node_mask=self.node_mask.detach().clone() if self.node_mask is not None else None,
            guidance=self.guidance.detach().clone() if self.guidance is not None else None,
            node_stats=self.node_stats.detach().clone() if self.node_stats is not None else None,
            edge_stats=self.edge_stats.detach().clone() if self.edge_stats is not None else None,
            charge_types=self.charge_types.detach().clone() if self.charge_types is not None else None,
            insert_time=self.insert_time.detach().clone() if self.insert_time is not None else None,
            delt_mask=self.delt_mask.detach().clone() if self.delt_mask is not None else None,
            n_nodes=self.n_nodes.detach().clone() if self.n_nodes is not None else None)

    def duplicate(self, n : int):
        """Return a new PlaceHolder holding ``n`` copies of this batch-of-one.

        Fields are repeated along the batch dimension; ``t`` is cloned, not repeated.
        """
        assert self.X.size(0) == 1, "duplicate: you can duplicate only if batch size == 1"
        X = self.X.repeat((n,1,1)) if self.X is not None else None
        E = self.E.repeat((n,1,1,1)) if self.E is not None else None
        y = self.y.repeat((n,1)) if self.y is not None else None
        c = self.charges.repeat((n,1,1)) if self.charges is not None else None
        p = self.pos.repeat((n,1,1)) if self.pos is not None else None
        g = self.guidance.repeat((n,1)) if self.guidance is not None else None

        node_mask = self.node_mask.repeat((n,1)) if self.node_mask is not None else None
        delt_mask = self.delt_mask.repeat((n,1,1)) if self.delt_mask is not None else None

        node_stats = self.node_stats.repeat((n,1)) if self.node_stats is not None else None
        edge_stats = self.edge_stats.repeat((n,1)) if self.edge_stats is not None else None
        charge_types= self.charge_types.repeat((n,1,1)) if self.charge_types is not None else None
        insert_time = self.insert_time.repeat((n,1,1))  if self.insert_time is not None else None

        n_nodes = self.n_nodes.repeat((n))  if self.n_nodes is not None else None

        t_int = self.t_int.repeat((n,1)) if self.t_int is not None else None

        to_return = PlaceHolder(X=X, E=E, y=y, charges=c, pos=p, guidance=g,
            t_int=t_int,
            t=self.t.detach().clone() if self.t is not None else None,
            node_mask=node_mask,
            node_stats=node_stats,
            edge_stats=edge_stats,
            charge_types=charge_types,
            insert_time=insert_time,
            delt_mask=delt_mask,
            n_nodes=n_nodes)

        return to_return

    # This method takes a mask and removes the elements that are either
    # padding or True mask. Useful for removing, for instance, the DELt(s)
    def remove_elements(self, mask, only_mask=False):
        """Delete the nodes flagged by ``mask`` (plus padding unless ``only_mask``), in place.

        Surviving nodes are moved to the front of each graph (stable order).
        The trailing node slots that are deleted in *every* graph are then
        sliced off X, E, insert_time, node_mask, charges, delt_mask and pos.
        At least one node slot is always kept. Requires ``self.insert_time``
        to be set, because it is gathered unconditionally.

        Args:
            mask: bool tensor (bs, n) or (bs, n, 1); True marks nodes to delete.
            only_mask: if False, padded nodes (``~self.node_mask``) are deleted too.
        """
        delete_mask = mask.clone()

        if delete_mask.dim() == 3 and delete_mask.size(-1) == 1:
            delete_mask = delete_mask.squeeze(-1)

        assert self.node_mask.shape == delete_mask.shape, \
            f"self.node_mask.shape={self.node_mask.shape}," \
            f"delete_mask.shape={delete_mask.shape},"

        if only_mask == False:
            delete_mask = delete_mask | ~self.node_mask

        delete_mask_int = delete_mask.int()
        # not super efficient since the delt mask is immediately
        # unsqueezed again inside mask() but whatever.
        self.mask(node_mask=~delete_mask, assert_E=False)

        # Number of trailing slots that are deleted in every graph (the minimum
        # per-graph deletion count); only those can be sliced off.
        pad_dim = torch.sum(delete_mask_int, dim=-1).min().item()

        # Stable sort puts kept nodes (0) before deleted ones (1), preserving order.
        ordered_mask, ordered_idx   = torch.sort(delete_mask_int, dim=1, stable=True)
        ordered_idx                 = ordered_idx.unsqueeze(-1)

        N, X_classes = self.X.size(1), self.X.size(-1)

        # We perform gather on the packed data. It's quicker
        gathered_tuple      = (self.X, self.insert_time, self.node_mask.unsqueeze(-1))

        if self.delt_mask is not None:
            gathered_tuple  = (*gathered_tuple, self.delt_mask)
        if self.charges is not None:
            gathered_tuple  = (*gathered_tuple, self.charges)
            c_classes       = self.charges.size(-1)
        if self.pos is not None:
            gathered_tuple  = (*gathered_tuple, self.pos)

        gathered            = torch.cat(gathered_tuple, dim=-1)
        gathered_idx        = ordered_idx.expand(-1,-1,gathered.size(-1))
        gathered            = torch.gather(gathered, dim=1, index=gathered_idx)

        # Unpack the columns in the same order they were packed above.
        self.X              = gathered[..., :X_classes]
        self.insert_time    = gathered[..., X_classes].unsqueeze(-1)
        self.node_mask      = gathered[..., X_classes+1].squeeze(-1).bool()

        s                   = X_classes+2
        if self.delt_mask is not None:
            self.delt_mask  = gathered[..., s].unsqueeze(-1).bool()
            s              += 1
        if self.charges is not None:
            self.charges    = gathered[..., s:s+c_classes]
            s              += c_classes
        if self.pos is not None:
            self.pos        = gathered[..., s:s+3]
            s              += 3

        # Ordering the edge matrix is more involved: permute rows, then columns.
        E_classes = self.E.size(-1)
        idx1 = ordered_idx.unsqueeze(-1).expand(-1,-1, N, E_classes)
        idx2 = ordered_idx.unsqueeze( 1).expand(-1, N,-1, E_classes)

        self.E = torch.gather(self.E, dim=1, index=idx1)
        self.E = torch.gather(self.E, dim=2, index=idx2)

        # In the remote case where pad_dim == max_n_nodes (eg: all graphs made
        # entirely by DELt and pad), it means that if we were to remove all nodes
        # from it, it would leave us with an empty graph. This is bad because
        # 1) we'd need to remove them and that's a pain
        # 2) they are still useful for training in case we give the model an
        #    empty graph and we want to start from it
        # > in that case, we reduce the padding size by one, to ensure that we
        #   keep at least 1 node
        max_n_nodes = self.X.size(-2)
        if pad_dim == max_n_nodes:
            pad_dim -= 1
        if pad_dim > 0:
            self.X              = self.X[:, :-pad_dim,            :]
            self.E              = self.E[:, :-pad_dim, :-pad_dim, :]
            self.insert_time    = self.insert_time[:, :-pad_dim, :]
            self.node_mask      = self.node_mask[:, :-pad_dim]

            if self.charges is not None:
                self.charges    = self.charges[:,:-pad_dim, :]
            if self.delt_mask is not None:
                self.delt_mask  = self.delt_mask[:, :-pad_dim, :]
            # pos was reordered above, so it must be trimmed too to stay aligned with X.
            if self.pos is not None:
                self.pos        = self.pos[:, :-pad_dim, :]

        # NOTE(review): self.n_nodes is not updated here.

    def insert_delt(self, n_elements):
        """Append ``n_elements[i]`` DELt nodes to graph ``i``, then drop the padding.

        DELt nodes/edges/charges are encoded with the last class set to 1.
        Their slots are marked in both ``node_mask`` and ``delt_mask``, and
        they get ``insert_time`` 0.

        NOTE(review): ``pos`` is not extended here. If ``self.pos`` is set,
        the ``torch.cat`` inside ``remove_elements`` will fail on mismatched
        node counts.
        """
        n_graphs = n_elements.size(0)

        assert n_graphs == self.X.size(0)

        max_n_elements = n_elements.max().item()
        subset_delt_mask = torch.arange(max_n_elements, device=self.X.device)[None, :] < n_elements[:, None]

        delt_X = torch.zeros((self.X.size(0), max_n_elements, self.X.size(-1)), device=self.X.device)
        delt_X[:,:,-1] = 1
        # No need to set the "non-DELt" to zero since self.remove_elements deals with it
        self.X = torch.cat((self.X, delt_X), dim=-2)

        delt_E_p1 = torch.zeros((self.E.size(0), self.E.size(1), max_n_elements, self.E.size(-1)),
                                device=self.E.device)
        delt_E_p1[:,:,:,-1] = 1
        delt_E_p2 = torch.zeros((self.E.size(0), max_n_elements, self.E.size(2)+max_n_elements, self.E.size(-1)),
                                device=self.E.device)
        delt_E_p2[:,:,:,-1] = 1
        # Extend columns first, then rows (the new rows span the widened width).
        self.E = torch.cat((self.E, delt_E_p1), dim=-2)
        self.E = torch.cat((self.E, delt_E_p2), dim=-3)

        # No need to check whether these masks exist or not since this method is literally "insert_delt"
        self.node_mask = torch.cat((self.node_mask, subset_delt_mask), dim=1)

        # useless but it keeps mask() happy
        # TODO: can't we just put if(mask != None) in mask()?
        subset_delt_mask = subset_delt_mask.unsqueeze(-1)  # we don't need it 2d anymore
        self.delt_mask = torch.cat((self.delt_mask, subset_delt_mask), dim=1)
        self.insert_time = torch.cat((self.insert_time, torch.zeros_like(subset_delt_mask)), dim=1)

        if self.charges is not None:
            delt_c = torch.zeros((self.charges.size(0), max_n_elements, self.charges.size(-1)),
                                  device=self.charges.device)
            delt_c[:,:,-1] = 1
            self.charges = torch.cat((self.charges, delt_c), dim=-2)
        # TODO: pos

        # Removes pad (keeps the DELt nodes, which are in node_mask)
        self.remove_elements(~self.node_mask, only_mask=True)

def setup_wandb(cfg):
    """Start a Weights & Biases run described by ``cfg`` and return ``cfg`` unchanged.

    The run is named ``cfg.general.name`` and logged to the project
    ``MolDiffusion_<dataset name>``. The fully resolved config is the run
    config, and ``cfg.general.wandb`` is the run mode. Afterwards every
    ``*.txt`` file is registered with ``wandb.save``.
    """
    run_config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    run_name = cfg.general.name
    project_name = f'MolDiffusion_{cfg.dataset["name"]}'
    wandb.init(
        name=run_name,
        project=project_name,
        config=run_config,
        settings=wandb.Settings(_disable_stats=True),
        reinit=True,
        mode=cfg.general.wandb,
    )
    wandb.save('*.txt')
    return cfg


def remove_mean_with_mask(x, node_mask):
    """Subtract, per graph, the mean of its real nodes from those nodes.

    Args:
        x: tensor of shape (bs, n, d). Padded entries must already be zero
            (this is asserted).
        node_mask: bool tensor of shape (bs, n), True for real nodes.

    Returns:
        ``x`` with each graph's masked mean removed from its real nodes.
        Padded entries stay zero. A graph with no real nodes is returned
        unchanged (previously a 0/0 turned it into NaN).
    """
    assert node_mask.dtype == torch.bool, f"Wrong type {node_mask.dtype}"
    node_mask = node_mask.unsqueeze(-1)
    masked_abs_sum = (x * (~node_mask)).abs().sum().item()
    assert masked_abs_sum < 1e-5, f'Error {masked_abs_sum} too high'
    # Clamp to 1 so empty graphs (all-zero x) give mean 0 instead of NaN.
    N = node_mask.sum(1, keepdim=True).clamp(min=1)

    mean = torch.sum(x, dim=1, keepdim=True) / N
    x = x - mean * node_mask
    return x

###########################################################

import signal
from contextlib import contextmanager

class TimeoutException(Exception):
    """Raised by the ``time_limit`` context manager when its time budget expires."""

@contextmanager
def time_limit(seconds):
    """Raise ``TimeoutException`` if the ``with`` body runs longer than ``seconds``.

    Implemented with ``SIGALRM``, so it works only on Unix and only from the
    main thread. ``seconds`` must be an int, as ``signal.alarm`` requires.
    On exit the alarm is cancelled and the previously installed ``SIGALRM``
    handler is restored; earlier versions left this module's handler
    installed globally.
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)
        # signal.signal returns None when the old handler was not installed
        # from Python, and None cannot be re-installed.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)



def update_config_with_new_keys(cfg, saved_cfg):
    """Back-fill ``cfg`` with the keys it lacks from ``saved_cfg``.

    For each of the ``general``, ``train`` and ``model`` sections, every key
    present in the saved section but absent from the current one is copied
    over. Keys already in ``cfg`` are never overwritten. Every section of
    ``cfg`` is left in struct mode; previously ``general`` was only switched
    to struct mode when its saved counterpart was non-empty.

    Returns:
        ``cfg``, modified in place.
    """
    section_names = ("general", "train", "model")
    # Fetch all saved sections first, so a missing one fails before cfg is touched.
    saved_sections = [getattr(saved_cfg, name) for name in section_names]

    for name, saved_section in zip(section_names, saved_sections):
        section = getattr(cfg, name)
        OmegaConf.set_struct(section, True)
        # open_dict temporarily lifts struct mode so new keys can be added.
        with open_dict(section):
            for key, val in saved_section.items():
                # Use .keys() rather than `in section`: OmegaConf's `in` reports
                # MISSING ('???') values as absent, and they would get overwritten.
                if key not in section.keys():
                    setattr(section, key, val)
    return cfg

###############################################################################
# FreeGress stuff
def rstrip1(s, c):
    """Remove at most one trailing occurrence of ``c`` from ``s``.

    Unlike ``str.rstrip``, this strips a single element only. An empty
    ``s`` is returned as-is; previously it raised ``IndexError``.
    """
    return s[:-1] if s and s[-1] == c else s

from rdkit import Chem
from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit.Chem import rdMolDescriptors, Crippen, AllChem

from didigress.metrics.properties import qed, penalized_logp
from torch_geometric.data import Batch
import torch.nn.functional as F
from torch_geometric.data import Data

import numpy as np
from collections import Counter

def clean_mol(mol, uncharge=False):
    """Strip stereochemistry (and optionally neutralize charges), then sanitize.

    Args:
        mol: an RDKit ``Mol`` or a SMILES string. A ``Mol`` argument is
            modified in place by ``Chem.RemoveStereochemistry``.
        uncharge: if True, run RDKit's ``Uncharger`` before sanitizing.

    Returns:
        The cleaned molecule.

    NOTE(review): an unparsable SMILES makes ``MolFromSmiles`` return None,
    which is not handled here.
    """
    molecule = Chem.MolFromSmiles(mol) if isinstance(mol, str) else mol

    Chem.RemoveStereochemistry(molecule)
    if uncharge:
        molecule = rdMolStandardize.Uncharger().uncharge(molecule)
    Chem.SanitizeMol(molecule)

    return molecule

def graph2mol(data, atom_decoder, dataset_infos):
    """Rebuild an RDKit molecule from a single (unbatched) PyG graph.

    Args:
        data: one PyG graph; it is wrapped into a one-element batch.
        atom_decoder: passed to ``Molecule`` to map node classes to atoms.
        dataset_infos: provides ``cfg.features`` flags (``use_3d``,
            ``use_charges``, ``charges_policy``) and is forwarded to ``to_dense``.

    Returns:
        The ``rdkit_mol`` attribute of the reconstructed ``Molecule``.
    """
    cfg = dataset_infos.cfg
    batch = Batch.from_data_list([data])

    # Unpack the one-graph batch built above to recover dense node/edge class indices.
    dense = to_dense(batch, dataset_info=dataset_infos).collapse()

    assert dense.X.size(0) == 1
    node_types = dense.X[0]
    edge_types = dense.E[0]
    # NOTE(review): collapse() is called with its defaults (use_charges=False),
    # which sets dense.charges to None. With cfg.features.use_charges enabled,
    # dense.charges[0] below would then fail; confirm the intended call.
    positions = dense.pos[0] if cfg.features.use_3d else None
    charges = dense.charges[0] if cfg.features.use_charges else None

    # These are the same routines we used to compute mu/HOMO on QM9, so they
    # can be trusted. They are also widely used, since they come from a much
    # cited paper whose code snippet many projects reuse.
    molecule = Molecule(node_types=node_types, edge_types=edge_types,
                        positions=positions, charges=charges, atom_decoder=atom_decoder,
                        use_charges=cfg.features.use_charges, use_3d=cfg.features.use_3d,
                        charges_policy=cfg.features.charges_policy)
    return molecule.rdkit_mol

def mol2graph(mol, types, bonds, i, original_smiles = None, estimate_guidance = True, build_with_charges = False):
    """Convert an RDKit molecule into a PyG ``Data`` graph used for training.

    Args:
        mol: RDKit molecule.
        types: maps an atom symbol (or a charged symbol such as ``"N+1"``)
            to its class index.
        bonds: maps an RDKit bond type to an index. Edge labels are
            ``index + 1``, so edge class 0 is left for "no bond".
        i: stored as ``idx`` on the returned graph.
        original_smiles: stored as ``smiles`` and passed to ``qed``.
        estimate_guidance: if True, attach a (1, 5) ``guidance`` tensor of
            [penalized logP, QED, exact mol. weight, SAS placeholder (-1), logP].
        build_with_charges: if True, charged atoms are looked up as
            ``<symbol><signed charge>``. The molecule is rejected (None) when
            that charged symbol is not in ``types``.

    Returns:
        A ``Data`` object, or None for rejected molecules and molecules
        without bonds.
    """
    num_atoms = mol.GetNumAtoms()

    type_idx = []
    for atom in mol.GetAtoms():
        symbol = atom.GetSymbol()
        formal_charge = atom.GetFormalCharge() if build_with_charges else 0

        if formal_charge != 0:
            # Charged variants are stored as e.g. "N+1" / "O-1". str() of a
            # negative int already carries "-", but "+" has to be added.
            charged_symbol = f"{symbol}{'+' if formal_charge > 0 else ''}{formal_charge}"
            if charged_symbol not in types:
                # We do not track this charged variant: drop the molecule.
                return None
            symbol = charged_symbol

        type_idx.append(types[symbol])

    rows, cols, edge_labels = [], [], []
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Undirected bond -> two directed edges with the same label.
        rows.extend((begin, end))
        cols.extend((end, begin))
        label = bonds[bond.GetBondType()] + 1
        edge_labels.extend((label, label))

    if not rows:
        print("Number of rows = 0")
        return None

    x = F.one_hot(torch.tensor(type_idx), num_classes=len(types)).to(torch.float)
    y = torch.zeros(size=(1, 0), dtype=torch.float)
    edge_index = torch.tensor([rows, cols], dtype=torch.long)
    edge_attr = F.one_hot(torch.tensor(edge_labels, dtype=torch.long),
                          num_classes=len(bonds) + 1).float()

    # Sort edges by (source, target).
    order = (edge_index[0] * num_atoms + edge_index[1]).argsort()
    edge_index, edge_attr = edge_index[:, order], edge_attr[order]

    # Property estimates: plogp, qed, mw, sas, logp
    guidance = None
    if estimate_guidance:
        estimates = (
            penalized_logp(mol),
            qed(original_smiles),
            rdMolDescriptors.CalcExactMolWt(mol),
            -1,  # SAS placeholder (calculateScore(mol) is disabled)
            Crippen.MolLogP(mol),
        )
        guidance = torch.zeros((1, len(estimates)))
        for column, value in enumerate(estimates):
            guidance[0, column] = value

    # This is the actual object used during training. Further downstream it
    # is saved in a PyTorch-friendly format so that it can be reused many times.
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                y=y, idx=i, guidance=guidance, smiles = original_smiles)

def real_s_to_scaled(t, s, scale, gap):
    """Linearly map ``s`` from ``[0, t]`` onto ``[scale*gap, scale - scale*gap]``.

    Uses the usual affine rescaling
    ``new = new_min + (s - old_min) * (new_max - new_min) / (old_max - old_min)``.
    Inverse of ``scaled_to_real_s``. Works on scalars or tensors.
    """
    src_lo, src_hi = 0, t
    dst_lo = scale * gap
    dst_hi = scale - scale * gap

    stretched = (s - src_lo) * (dst_hi - dst_lo)
    return dst_lo + stretched / (src_hi - src_lo)

def scaled_to_real_s(t, s, scale, gap):
    """Linearly map ``s`` from ``[scale*gap, scale - scale*gap]`` back onto ``[0, t]``.

    Uses the usual affine rescaling
    ``new = new_min + (s - old_min) * (new_max - new_min) / (old_max - old_min)``.
    Inverse of ``real_s_to_scaled``. Works on scalars or tensors.
    """
    src_lo = scale * gap
    src_hi = scale - scale * gap
    dst_lo, dst_hi = 0, t

    stretched = (s - src_lo) * (dst_hi - dst_lo)
    return dst_lo + stretched / (src_hi - src_lo)

def get_atom_counts(smiles: str, predefined_atoms: set) -> list:
    """Count occurrences of selected element symbols in a SMILES string.

    Args:
        smiles: SMILES string, parsed with ``Chem.MolFromSmiles``.
        predefined_atoms: element symbols to report. A ``set``/``frozenset``
            is iterated in sorted order. Its native iteration order depends on
            string-hash randomization and so changed between interpreter runs,
            which made the returned vector layout non-reproducible. Ordered
            containers keep their order; repeated symbols are reported once.

    Returns:
        A list of counts, one per distinct symbol (0 for absent symbols).
    """
    mol = Chem.MolFromSmiles(smiles)

    atom_counts = Counter(atom.GetSymbol() for atom in mol.GetAtoms())

    if isinstance(predefined_atoms, (set, frozenset)):
        predefined_atoms = sorted(predefined_atoms)

    # dict.fromkeys de-duplicates while keeping first-occurrence order.
    return [atom_counts.get(atom, 0) for atom in dict.fromkeys(predefined_atoms)]

def get_mol_fingerprint(mol):
    """Return the ECFP4 (Morgan, radius 2) bit vector of ``mol`` as a tensor.

    The RDKit bit vector is turned into its '0'/'1' string. The string's
    ASCII bytes are read as uint8, and subtracting ``ord('0')`` turns them
    into 0/1 values. The result is wrapped with ``torch.from_numpy``.
    """
    bit_vector = AllChem.GetMorganFingerprintAsBitVect(mol, 2)  # ecfp4
    ascii_bits = bit_vector.ToBitString().encode('utf-8')
    bits = np.frombuffer(ascii_bits, dtype=np.uint8) - ord('0')
    return torch.from_numpy(bits)