import numpy as np

from torch.utils.data import DataLoader

from matchms.filtering import normalize_intensities

from ...data.spectrum_dataset import SpectrumDataset
from ...data.spectrum_dataset import SpectrumPairDataset
from ...data.spectrum_transform import ComposeSpectrumTransform
from ...data.spectrum_transform import SpectrumPropertyTransform
from ...data.spectrum_transform import SpectrumBinningTransform
from ...data.spectrum_transform import AugmentationSpectrumTransform
from ...data.spectrum_collater import SpectrumBinningCollater
from ...data.spectrum_collater import PropertyPredictionSpectrumBinningCollater
from ..base_net.property_prediction import PropertyPredictionNet
from ..base_net.ffnn_base import FFNNet
from ..base_net.peak_embedding import MultiLayerFeedForwardBlock

from .model import Model

class FFNNModel(Model):
    """Property-prediction model built from a feed-forward network over binned spectra.

    Combines an ``FFNNet`` backbone with a ``MultiLayerFeedForwardBlock`` head
    (wrapped in ``PropertyPredictionNet``), the spectrum transforms used for
    training, testing and prediction, and the collate functions used by the
    data loaders. Generic training configuration is handled by ``Model``.
    """

    def __init__(
            self, train_data_path, test_data_path, model_path, num_epochs, batch_size,
            num_batch_per_update, num_workers, augment_spectrum, prefetch_factor, optimizer,
            machine_params, wandb_args, gradient_clip_val, precision, layers, dropout, batch_norm,
            l1, l2):
        """Store the generic training configuration and the FFNN hyperparameters.

        All arguments up to and including ``precision`` are forwarded unchanged
        to ``Model.__init__``.

        Args:
            layers: Layer specification passed to ``FFNNet``. Its last entry is
                also used as the prediction head's input and hidden dimension.
            dropout: Dropout setting passed to ``FFNNet``. It is indexed, so it
                must be a sequence; its first entry is used for the head.
            batch_norm: Passed through to ``FFNNet``.
            l1: Passed through to ``FFNNet`` (presumably an L1 penalty — confirm).
            l2: Passed through to ``FFNNet`` (presumably an L2 penalty — confirm).
        """
        Model.__init__(
            self, train_data_path, test_data_path, model_path, num_epochs, batch_size,
            num_batch_per_update, num_workers, augment_spectrum, prefetch_factor, optimizer,
            machine_params, wandb_args, gradient_clip_val, precision)

        self._layers = layers
        self._dropout = dropout
        self._batch_norm = batch_norm
        self._l1 = l1
        self._l2 = l2

    def get_data_loader(
            self, dataset, batch_size=None, shuffle=None, num_workers=None, prefetch_factor=None,
            drop_last=None):
        """Build a ``DataLoader`` over ``dataset``.

        Arguments left as ``None`` fall back to the instance configuration
        (``batch_size``, ``num_workers``, ``prefetch_factor``) or to ``False``
        (``shuffle``, ``drop_last``).

        The collate function is chosen from the first item. A bare
        ``np.ndarray`` selects the ``'unlabeled'`` collater; anything else
        selects the ``'labeled'`` one. The dataset must therefore be indexable
        and non-empty, because ``dataset[0]`` is read.

        Returns:
            A configured ``torch.utils.data.DataLoader``.
        """
        if isinstance(dataset[0], np.ndarray):
            collate_fn = self._spectra_collate['unlabeled']
        else:
            collate_fn = self._spectra_collate['labeled']

        num_workers = self._num_workers if num_workers is None else num_workers
        prefetch_factor = self._prefetch_factor if prefetch_factor is None else prefetch_factor

        # DataLoader raises ValueError when prefetch_factor is given for
        # single-process loading (num_workers == 0), so only forward it when
        # worker processes are actually used.
        loader_kwargs = {}
        if num_workers > 0:
            loader_kwargs['prefetch_factor'] = prefetch_factor

        return DataLoader(
            dataset,
            batch_size=self._batch_size if batch_size is None else batch_size,
            shuffle=False if shuffle is None else shuffle,
            collate_fn=collate_fn,
            num_workers=num_workers,
            drop_last=False if drop_last is None else drop_last,
            **loader_kwargs)

    def _setup_base_model(self):
        """Build ``self._base_net``: an ``FFNNet`` backbone plus a feed-forward head."""
        base_model = FFNNet(
            layers=self._layers, dropout=self._dropout, batch_norm=self._batch_norm,
            l1=self._l1, l2=self._l2)

        # The head's input/hidden size comes from the last entry of ``layers``
        # (presumably the backbone's output width — confirm against FFNNet).
        last_layer_dim = self._layers[-1]
        prediction_head = MultiLayerFeedForwardBlock(
            in_dim=last_layer_dim,
            # NOTE(review): hard-coded output size; presumably the number of
            # predicted properties — confirm it matches the labeled data.
            out_dim=10,
            hidden_dim=last_layer_dim,
            dropout=self._dropout[0],
            dropout_last_layer=False)

        self._base_net = PropertyPredictionNet(base_model, prediction_head)

    def _setup_spectrum_transform(self):
        """Fit and store the spectrum transforms in ``self._spectrum_transform``.

        The binning transform and both property transforms are fitted on the
        training data.

        Stored transforms (keys of ``self._spectrum_transform``):
            'train': augmentation (only when ``augment_spectrum`` is set), then
                intensity normalization and binning, wrapped in
                ``SpectrumPropertyTransform``.
            'test': intensity normalization and binning, wrapped in
                ``SpectrumPropertyTransform``.
            'prediction': intensity normalization and binning only, with no
                property transform.
        """
        # Load the training data once and reuse it for every fit below.
        # NOTE(review): assumes fitting SpectrumBinningTransform does not
        # mutate the dataset — confirm.
        train_dataset = SpectrumDataset.load(self._train_data_path)
        spectrum_binning_transform = SpectrumBinningTransform().fit(train_dataset)

        base_spectrum_transform = ComposeSpectrumTransform([
            normalize_intensities,
            spectrum_binning_transform,
        ])

        if self._augment_spectrum is not None:
            train_spectrum_transform = ComposeSpectrumTransform([
                AugmentationSpectrumTransform(**self._augment_spectrum),
                base_spectrum_transform,
            ])
        else:
            train_spectrum_transform = base_spectrum_transform

        train_transform = SpectrumPropertyTransform(train_spectrum_transform).fit(train_dataset)
        test_transform = SpectrumPropertyTransform(base_spectrum_transform).fit(train_dataset)

        self._spectrum_transform = {
            'train': train_transform,
            'test': test_transform,
            'prediction': base_spectrum_transform,
        }

    def _setup_spectra_collate(self):
        """Register the collate functions that ``get_data_loader`` chooses between.

        ``'unlabeled'`` is used when a dataset's items are bare ``np.ndarray``
        objects. ``'labeled'`` is used for all other items.
        """
        self._spectra_collate = {
            'labeled': PropertyPredictionSpectrumBinningCollater(),
            'unlabeled': SpectrumBinningCollater()}
