from operator import itemgetter
from collections import Counter
from itertools import groupby

import numpy as np

from sklearn.preprocessing import KBinsDiscretizer
from sklearn.preprocessing import StandardScaler

from matchms import Spectrum
from matchms import Spikes
from matchms.filtering import normalize_intensities

class ComposeSpectrumTransform(object):
    """Chain several callables into a single transform.

    Each callable receives the output of the previous one; the first one
    receives the value passed to ``__call__``.
    """

    def __init__(self, transforms):
        """Store the iterable of callables to apply, in order."""
        self._transforms = transforms

    def __call__(self, spectrum):
        """Feed ``spectrum`` through every stored transform and return the result."""
        result = spectrum
        for step in self._transforms:
            result = step(result)
        return result

class SpectrumPropertyTransform(object):
    """Pair a transformed spectrum with a scaled vector of its properties.

    The property values are looked up on each spectrum with ``spectrum.get``
    and passed through a ``StandardScaler`` that must be fitted with ``fit``
    before the transform is called.
    """

    def __init__(self, spectrum_transform):
        """Wrap ``spectrum_transform``, the callable applied to the spectrum itself."""
        self._spectrum_transform = spectrum_transform
        self._scaler = StandardScaler()
        # Names of the values read from every spectrum, in output order.
        self._properties = [
            'alogp', 'num_hb_acceptors', 'num_hb_donors', 'polar_surface_area',
            'num_rotatable_bonds', 'num_aromatic_rings', 'num_aliphatic_rings', 'num_heteroatoms',
            'frac_sp3_carbons', 'qed']

    def _property_row(self, spectrum):
        """Return the configured property values of ``spectrum`` as a list."""
        return [spectrum.get(name) for name in self._properties]

    def fit(self, dataset):
        """Fit the scaler on the properties of every spectrum in ``dataset._spectra``.

        Returns ``self`` so the call can be chained.
        """
        rows = [self._property_row(item) for item in dataset._spectra]
        self._scaler.fit(np.array(rows))
        return self

    def __call__(self, spectrum):
        """Return ``(transformed_spectrum, scaled_properties)`` for one spectrum.

        ``scaled_properties`` is the scaler output flattened to one dimension.
        """
        # The scaler works on 2-D input, so wrap the single row and flatten after.
        single_row = np.array([self._property_row(spectrum)])
        scaled = self._scaler.transform(single_row).reshape((-1,))
        transformed = self._spectrum_transform(spectrum)
        return transformed, scaled

class AugmentationSpectrumTransform(object):
    """Randomly perturb a spectrum: drop weak peaks, jitter intensities, add noise peaks.

    The input spectrum is cloned, so the caller's object is not modified.
    Intensities are normalized with ``normalize_intensities`` before and after
    the perturbation. All randomness comes from the global ``np.random`` state.
    """

    def __init__(self,
            removal_max=0.3, removal_intensity=0.2, intensity=0.4, noise_max=10,
            noise_intensity=0.01, mz_window=(10, 1000)):
        """Configure the augmentation.

        Args:
            removal_max: Upper bound of the per-call random removal threshold.
            removal_intensity: Peaks with normalized intensity at or above this
                value are never removed.
            intensity: Maximum relative change applied to each peak intensity.
            noise_max: Maximum number of noise peaks added per call (inclusive).
            noise_intensity: Upper bound for the intensity of a noise peak.
            mz_window: ``(low, high)`` m/z range noise peaks are drawn from.
        """
        self._removal_max = removal_max
        self._removal_intensity = removal_intensity
        self._intensity = intensity
        self._noise_max = noise_max
        self._noise_intensity = noise_intensity
        self._mz_window = mz_window

    def __call__(self, spectrum):
        """Return an augmented, re-normalized copy of ``spectrum``."""
        # First normalize intensities so the thresholds below are relative.
        augmented_spectrum = normalize_intensities(spectrum.clone())

        mz = augmented_spectrum.peaks.mz
        intensities = augmented_spectrum.peaks.intensities

        # Augmentation 1: peak removal. A peak survives when it is strong
        # enough, or when its own uniform draw is at least a single per-call
        # threshold drawn uniformly from [0, removal_max).
        survives_draw = (
            np.random.rand(len(intensities)) >= np.random.rand() * self._removal_max)
        keep = (intensities >= self._removal_intensity) | survives_draw
        # If nothing would survive, keep the spectrum untouched instead.
        if keep.any():
            mz, intensities = mz[keep], intensities[keep]

        # Augmentation 2: scale every intensity by a factor within
        # (1 - intensity, 1 + intensity].
        jitter = 2 * (np.random.rand(*intensities.shape) - 0.5)
        intensities = (1 - self._intensity * jitter) * intensities

        # Augmentation 3: add between 0 and noise_max low-intensity peaks at
        # uniformly random positions inside the m/z window.
        num_noise = np.random.randint(0, self._noise_max + 1)
        low, high = self._mz_window[0], self._mz_window[1]
        noise_mz = np.random.rand(num_noise) * (high - low) + low
        noise_intensities = np.random.rand(num_noise) * self._noise_intensity

        mz = np.concatenate((mz, noise_mz))
        intensities = np.concatenate((intensities, noise_intensities))

        # Peaks are stored in ascending m/z order. A stable argsort keeps the
        # tie order of the previous sorted()-based code and, unlike unpacking
        # zip(*sorted(...)), does not fail when there are no peaks at all
        # (empty input spectrum and zero noise peaks drawn).
        order = np.argsort(mz, kind='stable')
        augmented_spectrum.peaks = Spikes(mz[order], intensities[order])

        # Finally normalize intensities again after the noise was added.
        # NOTE(review): assumes normalize_intensities accepts a spectrum
        # without peaks; confirm against the installed matchms version.
        return normalize_intensities(augmented_spectrum)

class SpectrumBinningTransform(object):
    """Convert a spectrum into a fixed-length vector of binned peak intensities.

    ``fit`` discretizes all m/z values of a dataset into uniform bins and
    assigns a compact index to every bin that was actually observed.
    ``__call__`` then sums the scaled intensities of a spectrum per observed
    bin; peaks falling into bins never seen during ``fit`` are dropped.
    """

    def __init__(
            self, num_bins=10000, mz_window=(10.0, 1000.0), peak_scaling=0.5,
            allowed_missing_percentage=0.0):
        """Configure the binning.

        Args:
            num_bins: Number of uniform bins requested from the discretizer.
            mz_window: Inclusive ``(low, high)`` m/z range; peaks outside are ignored.
            peak_scaling: Exponent applied to each intensity before summing.
            allowed_missing_percentage: Stored on the instance; not read by
                any code in this class.
        """
        self._num_bins = num_bins
        self._mz_window = mz_window
        self._peak_scaling = peak_scaling
        self._allowed_missing_percentage = allowed_missing_percentage

        self._binner = KBinsDiscretizer(n_bins=self._num_bins, encode='ordinal', strategy='uniform')
        # Maps a raw discretizer bin id to its position in the output vector;
        # populated by fit().
        self._bin_map = None

    def fit(self, dataset):
        """Fit the discretizer on every in-window m/z value of ``dataset``.

        ``dataset`` must support ``len()`` and integer indexing, with each item
        exposing ``peaks.mz``. Returns ``self``.
        """
        low, high = self._mz_window[0], self._mz_window[1]
        mz = np.concatenate([dataset[i].peaks.mz.reshape((-1, 1)) for i in range(len(dataset))])
        # mz is a single column; select the rows whose value is in the window.
        in_window = ((mz >= low) & (mz <= high)).reshape(-1)
        raw_bins = self._binner.fit_transform(mz[in_window]).reshape(-1).astype(int)
        # Only bins that occur in the data get an output slot, in ascending order.
        observed_bins = np.unique(raw_bins).tolist()
        self._bin_map = {bin_id: index for index, bin_id in enumerate(observed_bins)}

        return self

    def __call__(self, spectrum):
        """Return the binned intensity vector of ``spectrum``.

        The result has one entry per bin observed during ``fit`` and is all
        zeros when the spectrum has no peak in the m/z window or in a known bin.
        """
        binned_spectrum = np.zeros(len(self._bin_map))

        mz = spectrum.peaks.mz
        in_window = (mz >= self._mz_window[0]) & (mz <= self._mz_window[1])
        mz = mz[in_window]
        # Handle the empty case explicitly instead of relying on the
        # discretizer raising ValueError, so that unrelated errors are not
        # silently turned into an all-zero vector.
        if mz.size == 0:
            return binned_spectrum

        intensities = spectrum.peaks.intensities[in_window] ** self._peak_scaling
        raw_bins = self._binner.transform(mz.reshape((-1, 1))).reshape(-1).astype(int)

        # Accumulate in peak order, skipping bins that were never seen in fit().
        for bin_id, intensity in zip(raw_bins.tolist(), intensities):
            index = self._bin_map.get(bin_id)
            if index is not None:
                binned_spectrum[index] += intensity

        return binned_spectrum

class MZTokenMissingDict(dict):
    """Dictionary whose unknown keys resolve to the value stored under ``'[unk]'``.

    The fallback applies to subscript lookups only; missing keys are not
    inserted and ``in`` / ``get`` behave like a plain ``dict``.
    """

    def __missing__(self, key):
        """Return the ``'[unk]'`` entry for any ``key`` that is absent."""
        fallback_key = '[unk]'
        return self[fallback_key]

class SpectrumTokenMZTransform(object):
    """Encode a spectrum as a fixed-length sequence of m/z token ids plus intensities.

    Every m/z value is rendered as a string token (``peak@<mz>`` with a fixed
    number of decimals). ``fit`` builds the token vocabulary from a dataset;
    ``__call__`` maps the strongest peaks of a spectrum to token ids and pads
    both outputs to ``n_max`` entries. The precursor m/z is always prepended
    as an extra peak with intensity 2.0.
    """

    def __init__(
            self, n_required=10, ratio_desired=0.5, n_max=500, precision=1,
            peak_limits=(0, 1000), normalize_peak_intensities=False):
        """Configure the tokenizer.

        Args:
            n_required: Stored on the instance; not read by any code in this class.
            ratio_desired: Accepted for interface compatibility; not used.
            n_max: Number of peaks kept per spectrum and length of the outputs.
            precision: Number of decimals of the m/z value inside a token.
            peak_limits: Inclusive ``(low, high)`` m/z range used by ``fit``.
            normalize_peak_intensities: Whether to run ``normalize_intensities``
                on the spectrum before processing.
        """
        self._n_required = n_required
        self._n_max = n_max
        self._peak_limits = peak_limits
        self._normalize_peak_intensities = normalize_peak_intensities
        self._fmt_str = '{}@{{:.{}f}}'.format('peak', precision)
        # Both are populated by fit().
        self._token2id = None
        self._vocab_size = None

    @property
    def vocab_size(self):
        """Number of entries in the vocabulary, or ``None`` before ``fit``."""
        return self._vocab_size

    @property
    def token2id(self):
        """Mapping from token string to integer id, or ``None`` before ``fit``."""
        return self._token2id

    @property
    def unk_id(self):
        """Id of the ``'[unk]'`` token."""
        return self._token2id['[unk]']

    @property
    def pad_id(self):
        """Id of the ``'[pad]'`` token."""
        return self._token2id['[pad]']

    def _ranked_peaks(self, spectrum, within_limits):
        """Return ``(mz, intensities)`` of the strongest ``n_max`` peaks, strongest first.

        The precursor m/z is included as a peak with intensity 2.0. When
        ``within_limits`` is true, m/z values outside ``peak_limits`` are
        discarded before ranking. Both arrays are empty when nothing remains.
        """
        spectrum = spectrum.clone()

        if self._normalize_peak_intensities:
            spectrum = normalize_intensities(spectrum)

        mz_arr = np.insert(spectrum.peaks.mz, 0, spectrum.metadata['precursor_mz'])
        intensity_arr = np.insert(spectrum.peaks.intensities, 0, 2.0)

        if within_limits:
            inside = (mz_arr >= self._peak_limits[0]) & (mz_arr <= self._peak_limits[1])
            mz_arr, intensity_arr = mz_arr[inside], intensity_arr[inside]

        # Stable descending sort: equal intensities keep their original order.
        order = np.argsort(-intensity_arr, kind='stable')[:self._n_max]

        return mz_arr[order], intensity_arr[order]

    def fit(self, dataset):
        """Build the token vocabulary from every spectrum in ``dataset``.

        ``dataset`` must support ``len()`` and integer indexing. Ids 0 and 1
        are reserved for ``'[unk]'`` and ``'[pad]'``; the remaining tokens are
        numbered in sorted order. A spectrum without any m/z value inside
        ``peak_limits`` contributes no tokens. Returns ``self``.
        """
        token_set = set()
        for i in range(len(dataset)):
            mz_arr, _ = self._ranked_peaks(dataset[i], within_limits=True)
            token_set.update(self._fmt_str.format(mz) for mz in mz_arr)

        self._token2id = MZTokenMissingDict(zip(
            ['[unk]', '[pad]'] + sorted(token_set), range(len(token_set) + 2)))

        self._vocab_size = len(self._token2id)

        return self

    def __call__(self, spectrum):
        """Return ``(token_ids, intensities)`` for ``spectrum``, each of length ``n_max``.

        Token ids are padded with the ``'[pad]'`` id and intensities with 0.0.
        """
        # NOTE(review): unlike fit(), peak_limits is not applied here, so
        # out-of-range peaks still occupy slots and are looked up like any
        # other token. Confirm whether this asymmetry is intended.
        mz_arr, intensity_arr = self._ranked_peaks(spectrum, within_limits=False)

        mz_arr = np.array([self._token2id[self._fmt_str.format(mz)] for mz in mz_arr])
        mz_arr = np.pad(
            mz_arr, (0, self._n_max - len(mz_arr)),
            constant_values=self._token2id['[pad]'])

        intensity_arr = np.pad(
            intensity_arr, (0, self._n_max - len(intensity_arr)),
            constant_values=0.0)

        return mz_arr, intensity_arr

class SpectrumNumericalMZTransform(object):
    """Encode a spectrum as fixed-length numeric m/z and intensity arrays.

    The precursor m/z is prepended as an extra peak with intensity 2.0, peaks
    outside ``peak_limits`` are dropped, the ``n_max`` strongest peaks are
    kept in descending intensity order, and both outputs are zero-padded to
    ``n_max`` rows.
    """

    def __init__(
            self, n_max=512, peak_limits=(0, 1000), include_fractional_mz=False,
            normalize_peak_intensities=False):
        """Configure the transform.

        Args:
            n_max: Number of peaks kept and length of the outputs.
            peak_limits: Inclusive ``(low, high)`` m/z range.
            include_fractional_mz: When true, the m/z output gets two extra
                columns holding the floor and the fractional part of each value.
            normalize_peak_intensities: Whether to run ``normalize_intensities``
                on the spectrum before processing.
        """
        self._n_max = n_max
        self._peak_limits = peak_limits
        self._include_fractional_mz = include_fractional_mz
        self._normalize_peak_intensities = normalize_peak_intensities

    def __call__(self, spectrum):
        """Return ``(mz, intensities)`` for ``spectrum``.

        ``mz`` has shape ``(n_max, 1)``, or ``(n_max, 3)`` with fractional
        columns enabled; ``intensities`` has shape ``(n_max,)``.
        """
        working = spectrum.clone()
        if self._normalize_peak_intensities:
            working = normalize_intensities(working)

        all_mz = np.insert(working.peaks.mz, 0, working.metadata['precursor_mz'])
        all_intensity = np.insert(working.peaks.intensities, 0, 2.0)

        # Keep in-range peaks only, strongest first, capped at n_max.
        candidates = [
            pair for pair in zip(all_mz, all_intensity)
            if self._peak_limits[0] <= pair[0] <= self._peak_limits[1]]
        candidates.sort(key=itemgetter(1), reverse=True)
        strongest = candidates[:self._n_max]

        if strongest:
            mz_arr, intensity_arr = zip(*strongest)
        else:
            # Nothing survived the filter; fall back to empty arrays.
            mz_arr, intensity_arr = np.array([]), np.array([])

        pad_rows = self._n_max - len(mz_arr)
        mz_arr = np.pad(mz_arr, (0, pad_rows), constant_values=0.0).reshape((-1, 1))
        if self._include_fractional_mz:
            whole_part = np.floor(mz_arr)
            mz_arr = np.hstack((mz_arr, whole_part, mz_arr - whole_part))

        intensity_arr = np.array(intensity_arr)
        intensity_arr = np.pad(
            intensity_arr, (0, self._n_max - len(intensity_arr)), constant_values=0.0)

        return mz_arr, intensity_arr
