import torch
import numpy as np

import logging
import os

from torch.utils.data import DataLoader
from qm9.data.dataset_class import ProcessedDataset
from qm9.data.prepare import prepare_dataset


def initialize_datasets(args, datadir, dataset, subset=None, splits=None,
                        force_download=False, subtract_thermo=False,
                        remove_h=False):
    """
    Initialize datasets.

    Parameters
    ----------
    args : argparse.Namespace (or any object with attribute access)
        Input arguments detailing the cormorant calculation. The attributes
        ``num_train``, ``num_test``, ``num_valid``, ``use_vocab_data`` and
        ``use_ghost_nodes`` are read; ``num_train``, ``num_valid`` and
        ``num_test`` are overwritten with the actual split sizes.
    datadir : str
        Path to the directory where the data and calculations are, or will be, stored.
    dataset : str
        String specification of the dataset.  If it is not already downloaded, must currently be "qm9" or "md17".
        For "qm9_first_half" / "qm9_second_half" only a fixed half (permutation
        seeded with 42) of the training split is kept.
    subset : str, optional
        Which subset of a dataset to use.  Action is dependent on the dataset given.
        Must be specified if the dataset has subsets (i.e. MD17).  Otherwise ignored (i.e. GDB9).
    splits : str, optional
        Forwarded unchanged to ``prepare_dataset``. TODO: delete this argument.
    force_download : bool, optional
        If true, forces a fresh download of the dataset.
    subtract_thermo : bool, optional
        If True, subtracts the thermochemical energy of the atoms from each molecule in GDB9.
        Does nothing for other datasets.
    remove_h : bool, optional
        If True, remove hydrogens from the dataset. Hydrogen-free copies of the
        split files are cached as ``*_without_h.npz`` next to the originals and
        reused on later calls.

    Returns
    -------
    args : same object as the input ``args``
        With ``num_train``, ``num_valid`` and ``num_test`` updated.
    datasets : dict
        Maps each split name returned by ``prepare_dataset`` ("train", "valid",
        "test", and e.g. "vocab" when ``args.use_vocab_data`` is set) to a
        ``ProcessedDataset``.
    num_species : int
        Number of unique atomic species in the dataset.
    max_charge : torch.Tensor
        Largest atomic number for the dataset.
    """
    # Set the number of points based upon the arguments
    num_pts = {'train': args.num_train,
               'test': args.num_test, 'valid': args.num_valid}

    # Download and process dataset. Returns datafiles.
    datafiles = prepare_dataset(
        datadir, dataset, subset, splits, force_download=force_download, use_vocab_data=args.use_vocab_data)

    # Reuse cached hydrogen-free files if all of them exist; otherwise strip
    # hydrogens below and write the cache.
    process_remove_h = False
    if remove_h:
        datafiles_without_h = {split: datafile.replace('.npz', '_without_h.npz')
                               for split, datafile in datafiles.items()}
        process_remove_h = not all(os.path.exists(path) for path in datafiles_without_h.values())
        if not process_remove_h:
            datafiles = datafiles_without_h

    # Load downloaded/processed datasets
    datasets = {}
    for split, datafile in datafiles.items():
        with np.load(datafile) as f:
            datasets[split] = {key: torch.from_numpy(val) for key, val in f.items()}

    # Basic error checking: Check the training/test/validation splits have the same set of keys.
    keys = [list(data.keys()) for data in datasets.values()]
    assert all(key == keys[0] for key in keys), 'Datasets must have same set of keys!'

    # Strip hydrogens BEFORE any of the augmentations below (vocab merge,
    # ghost-node padding, half-split). The cached *_without_h.npz files must
    # hold plain per-split data; otherwise reloading them would re-apply the
    # augmentations on top of already-augmented data.
    if process_remove_h:
        for split, data in datasets.items():
            _remove_hydrogens(data)
            logging.info('Saving data without hydrogens:')
            np.savez_compressed(datafiles_without_h[split], **data)
            logging.info('Processing/saving complete!')

    if args.use_vocab_data:
        # add vocab data twice because it's too little
        for _ in range(2):
            datasets['train'] = {key: combine(datasets['train'][key], datasets['vocab'][key])
                                 for key in datasets['train']}

    if args.use_ghost_nodes:
        print('Making the number of nodes consistent across splits. Necessary for VAE evaluation on validation dataset')
        max_n_nodes = max(datasets[split]['positions'].size(1) for split in datafiles)
        print(f'Max number of nodes across splits is {max_n_nodes}')
        for split in datafiles:
            data = datasets[split]
            n_nodes_to_pad = max_n_nodes - data['positions'].size(1)
            if n_nodes_to_pad <= 0:
                continue
            n_samples = data['positions'].size(0)
            # new_zeros keeps each tensor's dtype/device; float32 torch.zeros
            # would make torch.cat promote integer columns to float in padded
            # splits only, leaving splits with inconsistent dtypes.
            data['positions'] = torch.cat(
                (data['positions'], data['positions'].new_zeros((n_samples, n_nodes_to_pad, 3))), dim=1)
            for key in ('atomic_numbers', 'formal_charges'):
                data[key] = torch.cat(
                    (data[key], data[key].new_zeros((n_samples, n_nodes_to_pad))), dim=1)

    if dataset in ['qm9_second_half', 'qm9_first_half']:
        # NOTE: this seeds numpy's *global* RNG (kept as-is so existing runs
        # remain reproducible).
        np.random.seed(42)
        n_train = len(datasets['train']['num_atoms'])
        fixed_perm = np.random.permutation(n_train)
        half = n_train // 2
        sliced_perm = fixed_perm[half:] if dataset == 'qm9_second_half' else fixed_perm[:half]
        for key in datasets['train']:
            datasets['train'][key] = datasets['train'][key][sliced_perm]

    # Get a list of all species across the entire dataset
    all_species = _get_species(datasets, ignore_check=False, use_ghost_nodes=args.use_ghost_nodes)
    # Some rare formal charges (e.g. in guacamol) only appear in the training
    # split, but every split must build the same-dimensional one-hot vector.
    all_formal_charges = torch.cat([data['formal_charges'].unique()
                                    for data in datasets.values()]).unique(sorted=True)

    # Now initialize MolecularDataset based upon loaded data
    datasets = {split: ProcessedDataset(data, num_pts=num_pts.get(split, -1),
                                        included_species=all_species,
                                        included_formal_charges=all_formal_charges,
                                        subtract_thermo=subtract_thermo,
                                        use_ghost_nodes=args.use_ghost_nodes)
                for split, data in datasets.items()}

    # Check that all datasets have the same included species:
    assert(len(set(tuple(data.included_species.tolist()) for data in datasets.values())) ==
           1), 'All datasets must have same included_species! {}'.format({key: data.included_species for key, data in datasets.items()})

    # These parameters are necessary to initialize the network
    num_species = datasets['train'].num_species
    max_charge = datasets['train'].max_charge

    # Now, update the number of training/test/validation sets in args
    args.num_train = datasets['train'].num_pts
    args.num_valid = datasets['valid'].num_pts
    args.num_test = datasets['test'].num_pts

    return args, datasets, num_species, max_charge


def _remove_hydrogens(data):
    """
    Strip hydrogen atoms from one split's data dict, in place.

    Heavy atoms (atomic number > 1) are packed to the front of each molecule
    and re-centred on their mean position. 'num_atoms' is recomputed, and
    every bond touching a hydrogen is dropped from 'adj_list'. All other keys
    are left untouched; tensor shapes are preserved.
    """
    pos = data['positions']
    atomic_numbers = data['atomic_numbers']
    formal_charges = data['formal_charges']
    adj_lists = data['adj_list']

    # Check that atomic_numbers corresponds to real atoms
    assert torch.sum(data['num_atoms'] != torch.sum(atomic_numbers > 0, dim=1)) == 0

    heavy = atomic_numbers > 1  # mask of non-hydrogen atoms (0 = padding)
    new_positions = torch.zeros_like(pos)
    new_atomic_numbers = torch.zeros_like(atomic_numbers)
    new_formal_charges = torch.zeros_like(formal_charges)
    new_adj_lists = torch.zeros_like(adj_lists)
    for i in range(pos.shape[0]):
        m = heavy[i]
        n = int(m.sum())
        p = pos[i][m]
        new_positions[i, :n, :] = p - torch.mean(p, dim=0)  # centre the kept atoms
        new_atomic_numbers[i, :n] = atomic_numbers[i][m]
        new_formal_charges[i, :n] = formal_charges[i][m]

        # adj_list rows are (atom_i_idx, atom_j_idx, bond_type); hydrogens have
        # the highest indices by construction, so zero every row referencing
        # an index >= n, then drop the all-zero rows. Clone so the input
        # tensor is not mutated.
        adj = adj_lists[i].clone()
        adj[adj[:, 0] > n - 1] = 0
        adj[adj[:, 1] > n - 1] = 0
        adj = adj[adj.sum(1) != 0]
        new_adj_lists[i][:len(adj)] = adj

    data['positions'] = new_positions
    data['atomic_numbers'] = new_atomic_numbers
    data['num_atoms'] = torch.sum(new_atomic_numbers > 0, dim=1)
    data['formal_charges'] = new_formal_charges
    data['adj_list'] = new_adj_lists


def _get_species(datasets, ignore_check=False, use_ghost_nodes=False):
    """
    Collect the sorted set of atomic species present across every split.

    Also checks that every split contains every species found in the dataset
    as a whole.

    Parameters
    ----------
    datasets : dict
        Maps split name to a dict of tensors; each must contain 'atomic_numbers'.
    ignore_check : bool
        If True, a mismatch between per-split and overall species is only
        logged as an error instead of raising. Not recommended: results become
        non-sensical.
    use_ghost_nodes : bool
        If True, the padding value 0 is kept as a species (ghost nodes);
        otherwise it is discarded.

    Returns
    -------
    all_species : torch.Tensor
        Sorted 1-D tensor of the unique species labels (integers expected).

    Raises
    ------
    ValueError
        If the per-split species differ and ``ignore_check`` is False.
    """
    def drop_padding(species):
        # A leading 0 marks padded, non-existent atoms; ghost nodes keep it.
        if use_ghost_nodes or species[0] != 0:
            return species
        return species[1:]

    overall = drop_padding(
        torch.cat([split_data['atomic_numbers'].unique()
                   for split_data in datasets.values()]).unique(sorted=True))

    per_split = {name: drop_padding(split_data['atomic_numbers'].unique(sorted=True))
                 for name, split_data in datasets.items()}

    expected = overall.tolist()
    if any(species.tolist() != expected for species in per_split.values()):
        if not ignore_check:
            raise ValueError(
                'Not all datasets have the same number of species!')
        logging.error(
            'The number of species is not the same in all datasets!')

    return overall

def combine(data1, data2):
    """
    Concatenate two batches of per-molecule tensors along the sample axis.

    1-D tensors are simply concatenated. Higher-rank tensors have ``data2``
    zero-padded along dimension 1 up to ``data1.size(1)`` first. The padding
    uses ``data2``'s dtype and device, so integer tensors (e.g. atomic numbers)
    are not promoted to float by ``torch.cat``.

    Parameters
    ----------
    data1 : torch.Tensor
        Leading tensor; its size along dim 1 must be >= that of ``data2``.
    data2 : torch.Tensor
        Tensor of the same rank as ``data1``, appended after it.

    Returns
    -------
    torch.Tensor
        Tensor with ``data1.size(0) + data2.size(0)`` entries along dim 0.
    """
    assert data1.dim() == data2.dim()
    if data1.dim() == 1:
        return torch.cat([data1, data2], dim=0)
    n_padding = data1.size(1) - data2.size(1)
    assert n_padding >= 0, "wrong order: data2 is wider than data1 along dim 1"
    if n_padding > 0:
        # Trailing dims (e.g. xyz for positions) are carried over unchanged.
        padding = data2.new_zeros((data2.size(0), n_padding, *data2.shape[2:]))
        data2 = torch.cat([data2, padding], dim=1)
    return torch.cat([data1, data2], dim=0)
