import torch
from torch.utils.data import Dataset

import os
from itertools import islice
from math import inf

import logging

class ProcessedDataset(Dataset):
    """
    Data structure for a pre-processed cormorant dataset.  Extends PyTorch Dataset.

    Parameters
    ----------
    data : dict
        Dictionary of arrays containing molecular properties.  Must contain
        the keys 'atomic_numbers' and 'formal_charges' (and 'num_atoms' when
        ``use_ghost_nodes`` is set).  The dictionary is stored by reference
        and modified in place (one-hot entries are added, thermochemical
        energies are subtracted).
    included_species : tensor of scalars, optional
        Atomic species used as the classes of the one-hot encoding stored in
        ``data['atomic_numbers_one_hot']``.  If None, uses all unique values
        found in ``data['atomic_numbers']``; a leading 0 is dropped unless
        ``use_ghost_nodes`` is True (0 is presumably the padding value --
        TODO confirm against the preprocessing code).
    included_formal_charges : tensor of scalars, optional
        Formal charge values used as the classes of the one-hot encoding
        stored in ``data['formal_charges_one_hot']``.  If None, uses all
        unique values found in ``data['formal_charges']``.
    num_pts : int, optional
        Desired number of points to include in the dataset.
        Default value, -1, uses all of the datapoints.
    normalize : bool, optional
        Currently unused by this class; kept for interface compatibility.
    shuffle : bool, optional
        If true, shuffle the points in the dataset.
    subtract_thermo : bool, optional
        If True, for every key ``<target>_thermo`` in ``data`` subtracts that
        entry from ``data[<target>]`` in place.  Logs a warning if no such
        key exists.
    use_ghost_nodes : bool, optional
        If True, every entry of ``data['num_atoms']`` is overwritten with the
        size of the second dimension of ``data['atomic_numbers']``, and the
        species 0 is kept when ``included_species`` is inferred.

    """
    def __init__(self, data, included_species=None, included_formal_charges=None, num_pts=-1,
                 normalize=True, shuffle=True, subtract_thermo=True, use_ghost_nodes=False):
        """
        mainly does the following:
        loads the specified number of data points
        creates one hot representations for atomic_features (atomic_number and formal charges)
        """
        self.data = data

        # Clamp the requested number of points to what is actually available.
        num_available = len(data['atomic_numbers'])
        if num_pts < 0:
            self.num_pts = num_available
        elif num_pts > num_available:
            logging.warning('Desired number of points ({}) is greater than the number of data points ({}) available in the dataset!'.format(num_pts, num_available))
            self.num_pts = num_available
        else:
            self.num_pts = num_pts

        # If included species is not specified, infer it from the data.
        if included_species is None:
            included_species = torch.unique(self.data['atomic_numbers'], sorted=True)
            # The values are sorted, so a 0 entry can only be first.
            if not use_ghost_nodes and included_species[0] == 0:
                included_species = included_species[1:]

        # If included formal charges are not specified, infer them from the
        # data as well (previously the None default crashed on .unsqueeze).
        if included_formal_charges is None:
            included_formal_charges = torch.unique(self.data['formal_charges'], sorted=True)

        if use_ghost_nodes:
            # all molecules have the same number of nodes, which is the maximum across all dataset
            self.data['num_atoms'][:] = self.data['atomic_numbers'].size(1)

        if subtract_thermo:
            # Strip only the '_thermo' suffix so that target names which
            # themselves contain underscores are not truncated.
            suffix = '_thermo'
            thermo_targets = [key[:-len(suffix)] for key in data if key.endswith(suffix)]
            if not thermo_targets:
                logging.warning('No thermochemical targets included! Try reprocessing dataset with --force-download!')
            else:
                logging.info('Removing thermochemical energy from targets {}'.format(' '.join(thermo_targets)))
            for key in thermo_targets:
                data[key] -= data[key + suffix].to(data[key].dtype)

        self.included_species = included_species
        self.included_formal_charges = included_formal_charges

        # 'atomic_numbers_one_hot' will contain the one-hot representation of atomic numbers.
        # The two leading unsqueezes broadcast the class list against a
        # 2-dimensional 'atomic_numbers' tensor (assumed layout -- TODO confirm).
        self.data['atomic_numbers_one_hot'] = self.data['atomic_numbers'].unsqueeze(-1) == included_species.unsqueeze(0).unsqueeze(0)

        # make the formal charges one-hot encoded
        formal_charges_one_hot = self.data['formal_charges'].unsqueeze(-1) == included_formal_charges.unsqueeze(0).unsqueeze(0)
        # make sure we're covering all charges values by checking if the one-hot codes have at least one 1
        assert torch.all(torch.any(formal_charges_one_hot, -1)), 'There is an extra value of formal charge not accounted for.'
        self.data['formal_charges_one_hot'] = formal_charges_one_hot

        self.num_species = len(included_species)
        self.max_charge = max(included_species)

        self.parameters = {'num_species': self.num_species, 'max_charge': self.max_charge}

        # Get a dictionary of statistics for all properties that are one-dimensional tensors.
        self.calc_stats()

        # The permutation covers the whole dataset and is then truncated, so
        # the retained num_pts points are a random subset when shuffling.
        if shuffle:
            self.perm = torch.randperm(num_available)[:self.num_pts]
        else:
            self.perm = None

    def calc_stats(self):
        """Store (mean, std) in ``self.stats`` for every 1-D floating-point tensor in ``self.data``."""
        self.stats = {
            key: (val.mean(), val.std())
            for key, val in self.data.items()
            if isinstance(val, torch.Tensor) and val.dim() == 1 and val.is_floating_point()
        }

    def convert_units(self, units_dict):
        """
        Multiply, in place, every entry of ``self.data`` whose key appears in
        ``units_dict`` by the corresponding factor, then recompute the stats.
        """
        for key in self.data:
            if key in units_dict:
                self.data[key] *= units_dict[key]

        self.calc_stats()

    def __len__(self):
        """Return the number of points exposed by the dataset."""
        return self.num_pts

    def __getitem__(self, idx):
        """Return a dict with entry ``idx`` (after the optional shuffle) of every item in ``self.data``."""
        if self.perm is not None:
            idx = self.perm[idx]
        return {key: val[idx] for key, val in self.data.items()}
