import os
import os.path as osp
import pathlib
from typing import Any, Sequence


from didigress.metrics.properties import qed, drd2, penalized_logp
import torch
import torch.nn.functional as F
from rdkit import Chem, RDLogger
from tqdm import tqdm
import numpy as np
import pandas as pd
from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
from hydra.utils import get_original_cwd

from didigress.datasets.dataset_utils import mol_to_torch_geometric, remove_hydrogens, StatisticsMolecule
from didigress.datasets.dataset_utils import load_pickle, save_pickle
from didigress.datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule, AbstractAdaptiveDataModule
from didigress.metrics.metrics_utils import compute_all_statistics, atom_type_counts, edge_counts, charge_counts
from didigress.utils import PlaceHolder

from didigress.datasets.abstract_dataset import AbstractDatasetInfos

from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit.Chem import rdMolDescriptors, Crippen

from didigress.utils import rstrip1, clean_mol, graph2mol

def files_exist(files) -> bool:
    """Return True when ``files`` is non-empty and every path in it exists.

    An empty collection deliberately yields ``False`` so that callers treat
    "no files declared" as "not yet produced" and regenerate them.
    """
    if len(files) == 0:
        return False
    return all(osp.exists(path) for path in files)

def to_list(value: Any) -> Sequence:
    """Wrap ``value`` in a list unless it already is a non-string sequence.

    Strings are sequences too, but here they are treated as a single item.
    Sequences are returned as-is (not copied).
    """
    is_sequence = isinstance(value, Sequence)
    if not is_sequence or isinstance(value, str):
        return [value]
    return value

class RemoveYTransform:
    """Transform that replaces ``data.guidance`` with an empty ``(1, 0)`` float tensor.

    The sample object is modified in place and returned.
    """

    def __call__(self, data):
        empty_guidance = torch.zeros((1, 0), dtype=torch.float)
        data.guidance = empty_guidance
        return data

class FeatureExtractorTransform:
    """Keep only the requested property columns of ``data.guidance``.

    ``data.guidance`` is expected to hold the six property columns written by
    ``ZINC250KDataset.process``: penalized logP, QED, molecular weight, SAS,
    logP and DRD2, in that order. The selected columns are stacked
    horizontally in this fixed column order, not in the order of
    ``guidance_target``. Molecular weight is divided by 100.

    Args:
        guidance_target: container of target names, tested with ``in``.
    """

    # (target name, source column, divisor or None), in output order.
    _COLUMNS = (
        ('penalizedlogp', 0, None),
        ('qed', 1, None),
        ('mw', 2, 100),
        ('sas', 3, None),
        ('logp', 4, None),
        ('drd2', 5, None),
    )

    def __init__(self, guidance_target):
        self.guidance_target = guidance_target

    def __call__(self, data):
        source = data.guidance
        selected = []
        for name, column, divisor in self._COLUMNS:
            if name not in self.guidance_target:
                continue
            picked = source[..., column:column + 1]
            selected.append(picked if divisor is None else picked / divisor)

        data.guidance = torch.hstack(tuple(selected))
        return data
    
class DummyDatasetInfos():
    """Minimal stand-in for the dataset-infos object used while processing.

    ``ZINC250KDataset.process`` needs to rebuild molecules from graphs
    (``graph2mol``) before the real ``ZINC250Kinfos`` exists. This class
    provides only the ``to_one_hot`` method that path uses.

    Args:
        cfg: experiment config; ``cfg.features.use_charges`` is read in ``to_one_hot``.
        atom_decoder: atom vocabulary; only its length is used.
    """

    def __init__(self, cfg, atom_decoder):
        self.cfg = cfg
        self.num_node_types = len(atom_decoder)

    def to_one_hot(self, X,  E, node_mask, charges = None):
        """One-hot encode node types, bond types (5 classes) and optionally charges, then mask.

        Charges are shifted by +1 before being one-hot encoded into 5 classes.
        They are encoded only when ``cfg.features.use_charges`` is set;
        otherwise ``charges`` is passed through unchanged.

        Returns:
            ``(X, E, charges)`` after ``PlaceHolder.mask(node_mask)``.
        """
        node_onehot = F.one_hot(X, num_classes=self.num_node_types).float()
        edge_onehot = F.one_hot(E, num_classes=5).float()

        if self.cfg.features.use_charges:
            charges = F.one_hot(charges + 1, num_classes=5).float()

        masked = PlaceHolder(X=node_onehot, E=edge_onehot, y=None, charges=charges, pos=None).mask(node_mask)
        return masked.X, masked.E, masked.charges


# Atom vocabulary: symbol -> class index, with hydrogen at index 0. Removing
# hydrogens therefore shifts every other index down by one. The charged
# pseudo-atoms 'N+1' and 'O-1' come last; they are dropped from the vocabulary
# when features.use_charges is set or features.charges_policy is not
# 'dictionary'. Treat this table as read-only and copy it before modifying.
full_atom_encoder = {
    symbol: index
    for index, symbol in enumerate(
        ['H', 'C', 'N', 'O', 'F', 'B', 'Br', 'Cl', 'I', 'P', 'S', 'N+1', 'O-1']
    )
}

class ZINC250KDataset(InMemoryDataset):
    """ZINC250K molecules as a PyG ``InMemoryDataset`` with property guidance.

    ``download`` fetches the raw ZINC250K CSV and writes a reproducible
    80/10/10 train/val/test split into ``raw_dir``. ``process`` turns one split
    into filtered molecular graphs plus summary statistics and caches them in
    ``processed_dir``. After construction the instance holds the collated
    graphs, ``statistics`` (a ``StatisticsMolecule``) and ``smiles`` (the set
    of SMILES kept for this split).

    Args:
        split: ``'train'`` or ``'val'``; any other value selects the test split.
        root: dataset root directory (PyG creates ``raw/`` and ``processed/`` below it).
        remove_h: whether hydrogens are removed from graphs and from the atom vocabulary.
        cfg: experiment config. Reads ``features.use_charges``,
            ``features.charges_policy``, ``guidance.experiment_type`` and
            ``guidance.guidance_target``.
        transform, pre_transform, pre_filter: standard PyG dataset hooks.
    """

    raw_url = "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv"

    def __init__(self, split, root, remove_h: bool, cfg, transform=None, pre_transform=None, pre_filter=None):
        self.split = split
        if self.split == 'train':
            self.file_idx = 0
        elif self.split == 'val':
            self.file_idx = 1
        else:
            self.file_idx = 2
        self.remove_h = remove_h

        self.cfg = cfg
        self.use_charges = self.cfg.features.use_charges
        self.charges_policy = cfg.features.charges_policy

        # Build the vocabulary on a private copy. Calling popitem() on the shared
        # module-level full_atom_encoder would strip two more *real* atom types
        # on every later instantiation (e.g. for the val and test splits) and
        # corrupt every other user of that dict.
        atom_encoder_with_h = dict(full_atom_encoder)
        if self.use_charges or self.charges_policy != 'dictionary':
            # Charges are handled separately, so the charged pseudo-atoms go.
            atom_encoder_with_h.pop('N+1', None)
            atom_encoder_with_h.pop('O-1', None)
        # Vocabulary including hydrogen; this is what molecules are encoded with
        # in process(), before remove_hydrogens is applied.
        self.atom_encoder_with_h = atom_encoder_with_h

        if remove_h:
            self.atom_encoder = {k: v - 1 for k, v in atom_encoder_with_h.items() if k != 'H'}
        else:
            self.atom_encoder = dict(atom_encoder_with_h)

        # Dataset infos has this as well, but process() needs it first.
        self.atom_decoder = list(self.atom_encoder.keys())

        super().__init__(root, transform, pre_transform, pre_filter)

        # Optimization experiments on the test split load one of the
        # property-specific subsets (processed_paths[7..9]) instead of the full split.
        load_index = 0
        if self.file_idx == 2 and self.cfg.guidance.experiment_type == 'optimization':
            subset_index = {'logp': 7, 'qed': 8, 'drd2': 9}
            load_index = subset_index.get(self.cfg.guidance.guidance_target[0], 0)
        self.data, self.slices = torch.load(self.processed_paths[load_index])

        valencies = load_pickle(self.processed_paths[5])
        charge_types = torch.from_numpy(np.load(self.processed_paths[4], allow_pickle=True)).float()

        self.statistics = StatisticsMolecule(num_nodes=load_pickle(self.processed_paths[1]),
                                             node_types=torch.from_numpy(np.load(self.processed_paths[2])).float(),
                                             edge_types=torch.from_numpy(np.load(self.processed_paths[3])).float(),
                                             cfg=self.cfg,
                                             charge_types=charge_types,
                                             valencies=valencies,
                                             bond_lengths=None,
                                             bond_angles=None)
        self.smiles = load_pickle(self.processed_paths[6])

    @property
    def raw_file_names(self):
        return ['250k_rndm_zinc_drugs_clean_3.csv']

    @property
    def split_file_name(self):
        return ['train.csv', 'val.csv', 'test.csv']

    @property
    def split_paths(self):
        r"""The absolute filepaths that must be present in order to skip
        splitting."""
        files = to_list(self.split_file_name)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def processed_file_names(self):
        """Cache files for this split.

        Indices 0-6 hold the graphs, node counts, node-type, edge-type and charge
        statistics, valencies and kept SMILES. The test split additionally has
        indices 7-9 for the logP / QED / DRD2 optimization subsets.
        """
        h = 'noh' if self.remove_h else 'h'
        prefix = self.split if self.split in ('train', 'val') else 'test'
        names = [f'{prefix}_{h}.pt', f'{prefix}_n_{h}.pickle', f'{prefix}_node_types_{h}.npy',
                 f'{prefix}_edge_types_{h}.npy', f'{prefix}_charges_{h}.npy',
                 f'{prefix}_valency_{h}.pickle', f'{prefix}_smiles.pickle']
        if prefix == 'test':
            names += [f'test_logp_{h}.pt', f'test_qed_{h}.pt', f'test_drd2_{h}.pt']
        return names

    def download(self):
        """Download the raw ZINC250K CSV and write the train/val/test splits.

        Rows are shuffled with ``DataFrame.sample(frac=1, random_state=42)`` and
        cut 80% / 10% / 10% into ``train.csv``, ``val.csv`` and ``test.csv``
        in ``raw_dir``. If the split files already exist, the split step is skipped.

        Raises:
            ImportError: if RDKit is not installed.
        """
        try:
            import rdkit  # noqa: F401
        except ImportError:
            print("Failed to download the dataset...")
            # Previously execution continued and crashed on an undefined name.
            raise
        file_path = download_url(self.raw_url, self.raw_dir)

        if files_exist(self.split_paths):
            return

        dataset = pd.read_csv(file_path)

        n_samples = len(dataset)
        n_train = int(0.8 * n_samples)
        n_test = int(0.1 * n_samples)
        n_val = n_samples - (n_train + n_test)

        # Shuffle once, then slice. iloc replaces np.split on a DataFrame,
        # which is deprecated in recent pandas; the rows and index are the same.
        shuffled = dataset.sample(frac=1, random_state=42)
        train = shuffled.iloc[:n_train]
        val = shuffled.iloc[n_train:n_train + n_val]
        test = shuffled.iloc[n_train + n_val:]

        train.to_csv(os.path.join(self.raw_dir, 'train.csv'))
        val.to_csv(os.path.join(self.raw_dir, 'val.csv'))
        test.to_csv(os.path.join(self.raw_dir, 'test.csv'))

    def _raw_file_path(self, file_name):
        """Return ``file_name`` joined onto the directory that holds the raw CSV."""
        return osp.join(pathlib.Path(self.raw_paths[0]).parent, file_name)

    def _load_train_smiles(self):
        """Return the set of SMILES kept when the training split was processed.

        ``process`` writes them to ``new_train.smiles`` next to the raw CSV, so
        that file is read first. The repository location used by earlier
        versions of this code is kept as a fallback.

        Raises:
            FileNotFoundError: if neither file exists.
        """
        root_dir = pathlib.Path(os.path.realpath(__file__)).parents[2]
        candidates = [self._raw_file_path('new_train.smiles'),
                      str(root_dir) + "/data/zinc250k/raw/new_train.smiles"]
        for candidate in candidates:
            if osp.exists(candidate):
                with open(candidate, "r") as f:
                    # A set gives O(1) membership tests in the processing loop.
                    return set(f.read().split("\n"))
        raise FileNotFoundError(f"Training SMILES not found; looked in: {candidates}")

    @staticmethod
    def _properties_match(mol, reconstructed_mol, original_logp, original_mw):
        """Return True if the round-tripped molecule keeps the original's properties.

        QED and Crippen logP must agree within 1e-5, and exact molecular weight
        must agree within 4.
        """
        qed_shift = abs(qed(mol) - qed(reconstructed_mol))
        logp_shift = abs(original_logp - Crippen.MolLogP(reconstructed_mol))
        mw_shift = abs(original_mw - rdMolDescriptors.CalcExactMolWt(reconstructed_mol))
        mismatch = qed_shift > 1e-5 or logp_shift > 1e-5 or mw_shift > 4
        return not mismatch

    def process(self):
        """Convert this split's SMILES into filtered graphs and cache them.

        For every SMILES in the split CSV:

        1. clean it with ``clean_mol`` and build a graph with
           ``mol_to_torch_geometric`` (optionally removing hydrogens and
           applying ``pre_filter`` / ``pre_transform``);
        2. attach a ``(1, 6)`` ``guidance`` row
           ``[penalized logP, QED, exact MW, SAS placeholder (-1), Crippen logP, DRD2]``
           and per-molecule node / edge / charge count statistics;
        3. rebuild a molecule from the graph and keep the sample only if the
           round-trip SMILES equals the original, QED / logP / MW agree, the
           molecule is a single fragment, and (val/test only) it does not also
           occur in the processed training set.

        Graphs, dataset statistics and kept SMILES go to ``processed_paths[0..6]``.
        The kept SMILES are also written to ``new_<split>.smiles`` next to the
        raw CSV. The test split additionally saves the optimization subsets.
        """
        RDLogger.DisableLog('rdApp.*')

        # Val/test molecules are dropped if the training split already has them.
        train_smiles = self._load_train_smiles() if self.file_idx != 0 else set()

        path = self.split_paths[self.file_idx]
        print("path: ", path)
        target_df = pd.read_csv(path)
        smiles_list = target_df['smiles'].values

        data_list = []
        smiles_kept = []

        unmatches = 0
        exceptions = 0
        string_unmatches = 0
        n_already_present = 0
        num_errors = 0

        print("Initial size: ", len(smiles_list))

        uncharge = self.cfg.features.charges_policy == "partial"
        # The real dataset infos does not exist yet. graph2mol only reaches
        # to_one_hot (via to_dense), which this stand-in provides; it does not
        # depend on the molecule, so it is built once outside the loop.
        dummy_dataset_infos = DummyDatasetInfos(cfg=self.cfg, atom_decoder=self.atom_decoder)

        for original_smile in tqdm(smiles_list):
            # Strip a trailing newline, if any.
            smiles_2D = rstrip1(original_smile, "\n")

            mol = clean_mol(smiles_2D, uncharge)
            if mol is None:
                # Chem.MolToSmiles(None) would raise; count the molecule and skip it.
                num_errors += 1
                continue
            smiles_2D = Chem.MolToSmiles(mol)

            data = mol_to_torch_geometric(mol=mol, atom_encoder=self.atom_encoder_with_h,
                                          smiles=smiles_2D, cfg=self.cfg)
            # None for an empty graph, or when an atom carries a non-neutral
            # charge that the configuration does not keep.
            if data is None:
                continue

            if self.remove_h:
                data = remove_hydrogens(data=data, cfg=self.cfg)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            estimated_plogp = penalized_logp(mol)
            estimated_qed = qed(smiles_2D)
            estimated_mw = rdMolDescriptors.CalcExactMolWt(mol)
            estimated_sas = -1  # SA score is not computed; placeholder value.
            estimated_logp = Crippen.MolLogP(mol)
            estimated_drd2 = drd2(smiles_2D)

            guidance = torch.zeros((1, 6))
            guidance[0, 0] = estimated_plogp
            guidance[0, 1] = estimated_qed
            guidance[0, 2] = estimated_mw
            guidance[0, 3] = estimated_sas
            guidance[0, 4] = estimated_logp
            guidance[0, 5] = estimated_drd2
            data.guidance = guidance
            data.n_nodes = data.x.size(0)

            node_stats = atom_type_counts(data_list=[data], num_classes=len(self.atom_encoder), verbose=False)
            edge_stats = edge_counts(data_list=[data], verbose=False)
            if np.isnan(edge_stats).all():
                print("found a NaN edge tensor")
                edge_stats = np.zeros_like(edge_stats)
                edge_stats[0] = 1
            charge_types = charge_counts(data_list=[data], num_classes=len(self.atom_encoder),
                                         charges_dic={-2: 0, -1: 1, 0: 2, 1: 3, 2: 4}, verbose=False)

            data.node_stats = torch.from_numpy(node_stats).unsqueeze(0)
            data.edge_stats = torch.from_numpy(edge_stats).unsqueeze(0)
            data.charge_types = torch.from_numpy(charge_types).unsqueeze(0)

            # Sanity check: rebuild the molecule from the graph and compare it
            # with the original before accepting the sample.
            reconstructed_mol = graph2mol(data, dataset_infos=dummy_dataset_infos,
                                          atom_decoder=self.atom_decoder)
            try:
                reconstructed_mol = clean_mol(reconstructed_mol, uncharge)
                reconstructed_smiles = Chem.MolToSmiles(reconstructed_mol)
            except Exception:
                # A bare `except:` here also swallowed KeyboardInterrupt.
                exceptions += 1
                continue

            if reconstructed_smiles in train_smiles:
                n_already_present += 1
                continue

            if smiles_2D != reconstructed_smiles:
                string_unmatches += 1
                num_errors += 1
                continue

            # DRD2 is not re-compared: drd2() takes the SMILES string, and both
            # strings are identical at this point (assuming drd2 is deterministic).
            if not self._properties_match(mol, reconstructed_mol, estimated_logp, estimated_mw):
                unmatches += 1
                num_errors += 1
                continue

            try:
                mol_frags = Chem.rdmolops.GetMolFrags(reconstructed_mol, asMols=True, sanitizeFrags=True)
            except Chem.rdchem.AtomValenceException:
                print("Valence error in GetmolFrags")
                continue
            except Chem.rdchem.KekulizeException:
                print("Can't kekulize molecule")
                continue
            if len(mol_frags) == 1:
                data_list.append(data)
                smiles_kept.append(reconstructed_smiles)

        print("data_list size:", len(data_list))
        print("Total unmatches: ", unmatches)
        print("exceptions =", exceptions)
        print("string_unmatches: ", string_unmatches)
        print("removed because already present in the training set: ", n_already_present)

        smiles_save_path = self._raw_file_path(f'new_{self.split}.smiles')
        print(smiles_save_path)
        with open(smiles_save_path, 'w') as f:
            f.writelines('%s\n' % s for s in smiles_kept)
        print(f"Number of molecules kept: {len(smiles_kept)} / {len(smiles_list)}")

        statistics = compute_all_statistics(data_list, self.atom_encoder, charges_dic={-1: 0, 0: 1, 1: 2},
                                            cfg=self.cfg, use_3d=False)

        save_pickle(statistics.num_nodes, self.processed_paths[1])
        np.save(self.processed_paths[2], statistics.node_types)
        np.save(self.processed_paths[3], statistics.edge_types)
        np.save(self.processed_paths[4], statistics.charge_types)
        save_pickle(statistics.valencies, self.processed_paths[5])
        print("Number of molecules that could not be mapped to smiles: ", num_errors)
        save_pickle(set(smiles_kept), self.processed_paths[6])
        # Saved once, to the graph file. An earlier extra save went to
        # processed_paths[file_idx], which for val/test is a statistics file
        # that was overwritten above.
        torch.save(self.collate(data_list), self.processed_paths[0])

        if self.file_idx == 2:
            self._save_optimization_subsets(data_list)

    def _save_optimization_subsets(self, data_list, test_size=800):
        """Save the three test subsets used by property-optimization experiments.

        * ``processed_paths[7]``: the ``test_size`` molecules with the lowest logP;
        * ``processed_paths[8]``: the first ``test_size`` molecules with 0.7 <= QED <= 0.8;
        * ``processed_paths[9]``: the first ``test_size`` molecules with DRD2 < 0.05.

        Each subset's SMILES are also written next to the raw CSV.
        """
        if data_list:
            extracted_guidance = torch.cat([d.guidance for d in data_list], dim=0)
        else:
            extracted_guidance = torch.zeros((0, 6))

        logp_order = torch.argsort(extracted_guidance[:, 4:5], dim=0).squeeze(-1).tolist()
        data_list_logp = [data_list[j] for j in logp_order[:test_size]]

        qed_data = extracted_guidance[:, 1]
        qed_indexes = torch.nonzero((qed_data >= 0.7) & (qed_data <= 0.8)).flatten().tolist()
        data_list_qed = [data_list[j] for j in qed_indexes[:test_size]]

        drd2_indexes = torch.nonzero(extracted_guidance[:, 5] < 0.05).flatten().tolist()
        data_list_drd2 = [data_list[j] for j in drd2_indexes[:test_size]]

        torch.save(self.collate(data_list_logp), self.processed_paths[7])
        torch.save(self.collate(data_list_qed), self.processed_paths[8])
        torch.save(self.collate(data_list_drd2), self.processed_paths[9])

        for subset, file_name in ((data_list_logp, 'logp_smiles.smiles'),
                                  (data_list_qed, 'qed_smiles.smiles'),
                                  (data_list_drd2, 'drd2_smiles.smiles')):
            smiles_save_path = self._raw_file_path(file_name)
            print(smiles_save_path)
            with open(smiles_save_path, 'w') as f:
                f.writelines(f'{d.smiles}\n' for d in subset)


class ZINC250KDataModule(AbstractDataModule):
    """Bundle the ZINC250K train/val/test splits into one data module.

    The dataset root is ``<parent of the original working directory>/<cfg.dataset.datadir>``.
    Splits are built in train -> val -> test order, and all of them share one
    guidance transform. ``FeatureExtractorTransform`` is used when any of
    logp/qed/mw/drd2 is a guidance target; otherwise ``RemoveYTransform`` is used.
    """

    def __init__(self, cfg):
        self.datadir = cfg.dataset.datadir
        root_path = os.path.join(pathlib.Path(get_original_cwd()).parents[0], self.datadir)

        guidance_target = cfg.guidance.guidance_target
        wants_properties = not {'logp', 'qed', 'mw', 'drd2'}.isdisjoint(guidance_target)
        if wants_properties:
            transform = FeatureExtractorTransform(cfg.guidance.guidance_target)
        else:
            transform = RemoveYTransform()

        # Insertion order matters: the train split must be processed first.
        datasets = {
            split_name: ZINC250KDataset(split=split_name, root=root_path, remove_h=cfg.dataset.remove_h,
                                        cfg=cfg, transform=transform)
            for split_name in ('train', 'val', 'test')
        }
        self.statistics = {split_name: ds.statistics for split_name, ds in datasets.items()}
        self.remove_h = cfg.dataset.remove_h
        super().__init__(cfg, train_dataset=datasets['train'], val_dataset=datasets['val'],
                         test_dataset=datasets['test'])


class ZINC250Kinfos(AbstractDatasetInfos):
    """Dataset-level metadata for ZINC250K: atom vocabulary, statistics and model I/O sizes.

    Args:
        datamodule: object exposing a ``statistics`` dict (train/val/test),
            forwarded to ``complete_infos``.
        cfg: experiment config. Reads ``dataset.remove_h``,
            ``features.use_charges``, ``features.charges_policy``,
            ``features.use_ins_del`` and the ``guidance`` section.
    """

    def __init__(self, datamodule, cfg):
        self.remove_h = cfg.dataset.remove_h
        self.statistics = datamodule.statistics
        self.name = 'ZINC250K'

        # Start from a private copy of the module-level vocabulary and drop the
        # charged pseudo-atoms under the same condition ZINC250KDataset uses,
        # so the node-type count matches the encoded data.
        atom_encoder = dict(full_atom_encoder)
        if cfg.features.use_charges or cfg.features.charges_policy != 'dictionary':
            atom_encoder.pop('N+1', None)
            atom_encoder.pop('O-1', None)
        self.atom_encoder = atom_encoder

        self.collapse_charges = torch.Tensor([-2, -1, 0, 1, 2]).int()
        if self.remove_h:
            self.atom_encoder = {k: v - 1 for k, v in self.atom_encoder.items() if k != 'H'}
        super().complete_infos(datamodule.statistics, self.atom_encoder, cfg)

        use_conditional = cfg.guidance.p_uncond >= 0
        # Extra input channels contributed by guidance, per medium.
        medium_extra = {'X': 0, 'E': 0, 'y': 0, 'p': 0}
        # Width of each supported guidance target.
        guidance_sizes = {'logp': 1, 'mw': 1, 'qed': 1, 'drd2': 1}
        self.guidance_dims = 0

        if use_conditional:
            guidance_mediums = cfg.guidance.guidance_medium
            for target in cfg.guidance.guidance_target:
                self.guidance_dims += guidance_sizes[target]
                for medium in guidance_mediums:
                    medium_extra[medium] += guidance_sizes[target]

        print("self.guidance_dims =", self.guidance_dims)
        print("medium_extra =", medium_extra)

        self.use_ins_del = cfg.features.use_ins_del
        # Two extra categories per output when insertion/deletion is enabled.
        IN_ins_del_sz = 2 if self.use_ins_del else 0

        X_sz_out = self.num_node_types + IN_ins_del_sz
        E_sz_out = 5 + IN_ins_del_sz
        y_sz_out = 0
        c_sz_out = 5 + IN_ins_del_sz
        p_sz_out = 3 + IN_ins_del_sz

        X_sz = X_sz_out + medium_extra['X']
        E_sz = E_sz_out + medium_extra['E']
        y_sz = y_sz_out + medium_extra['y'] + 1  # The "+1" is the timestep t
        c_sz = c_sz_out
        p_sz = p_sz_out + medium_extra['p']

        # TODO: add INS/DEL categories
        self.input_dims = PlaceHolder(X=X_sz, charges=c_sz, E=E_sz, y=y_sz, pos=p_sz, guidance=self.guidance_dims)
        self.output_dims = PlaceHolder(X=X_sz_out, charges=c_sz_out, E=E_sz_out, y=0, pos=p_sz_out, guidance=0)

        self.cfg = cfg
        self.charges_dic = {-2: 0, -1: 1, 0: 2, 1: 3, 2: 4}

    def to_one_hot(self, X,  E, node_mask, charges = None):
        """One-hot encode node types, bond types (5 classes) and optionally charges, then mask.

        NOTE(review): charges are shifted by +1 and encoded into 5 classes,
        which fits charges in [-1, 3]. ``charges_dic`` and ``collapse_charges``
        above span [-2, 2], so confirm which convention the data uses.

        Returns:
            ``(X, E, charges)`` after ``PlaceHolder.mask(node_mask)``.
        """
        X = F.one_hot(X, num_classes=self.num_node_types).float()
        E = F.one_hot(E, num_classes=5).float()

        if self.cfg.features.use_charges:
            charges = F.one_hot(charges + 1, num_classes=5).float()

        placeholder = PlaceHolder(X=X, E=E, y=None, charges=charges, pos=None)
        pl = placeholder.mask(node_mask)
        return pl.X, pl.E, pl.charges

    def one_hot_charges(self, charges):
        """One-hot encode ``charges + 1`` into 5 float classes (same offset as ``to_one_hot``)."""
        return F.one_hot((charges + 1).long(), num_classes=5).float()
    


"""
import hydra
import omegaconf
from omegaconf import DictConfig
@hydra.main(version_base='1.3', config_path='../../configs', config_name='config')
def main(cfg: DictConfig):
    #ds = [ZINC250KDataset(s, os.path.join(os.path.abspath(__file__), "../../data/zinc250k")) for s in ["train", "val", "test"]]
    print(cfg)

    cfg.dataset.name                        = 'zinc250k' 
    cfg.dataset.datadir                     = 'data/zinc250k/zinc250k_pyg/'
    cfg.dataset.remove_h                    =  True
    cfg.dataset.random_subset               = None
    cfg.dataset.pin_memory                  = False
    cfg.dataset.filter                      = True
    cfg.guidance.build_with_partial_charges = "new_method"

    datamodule = ZINC250KDataModule(cfg)
    dataset_infos = ZINC250Kinfos(datamodule, cfg, recompute_statistics = True)

if __name__ == '__main__':
    main()
"""