import numpy as np

from matchms.filtering import normalize_intensities

from ...data.spectrum_dataset import SpectrumDataset
from ...data.spectrum_dataset import SpectrumPairDataset
from ...data.spectrum_transform import ComposeSpectrumTransform
from ...data.spectrum_transform import SpectrumBinningTransform
from ...data.spectrum_collater import SpectrumBinningCollater
from ...data.spectrum_collater import TanimotoPairsSpectrumBinningCollater
from ..base_net.ffnn_base import FFNNet
from .model import Model

class FFNNModel(Model):
    """Feed-forward network model trained on binned mass spectra.

    Plugs three FFNN-specific pieces into the generic ``Model`` pipeline:

    * an ``FFNNet`` base network;
    * a spectrum transform (matchms intensity normalization followed by a
      ``SpectrumBinningTransform`` fitted on the training data);
    * spectrum-binning collaters for pair and single batches.

    The ``_setup_*`` methods are hooks. Presumably they are invoked by
    ``Model`` during setup — TODO confirm against the base class.
    """

    def __init__(
            self, train_data_path, test_data_path, model_path, num_epochs, batch_size,
            num_batch_per_update, num_workers, augment_spectrum, prefetch_factor, optimizer,
            machine_params, wandb_args, gradient_clip_val, precision, layers, dropout, batch_norm,
            l1, l2):
        """Store the generic training configuration and the FFNN hyper-parameters.

        All arguments from ``train_data_path`` through ``precision`` are
        forwarded unchanged, positionally and in the same order, to the
        ``Model`` initializer.

        Args:
            layers: Forwarded to ``FFNNet`` as ``layers``.
            dropout: Forwarded to ``FFNNet`` as ``dropout``.
            batch_norm: Forwarded to ``FFNNet`` as ``batch_norm``.
            l1: Forwarded to ``FFNNet`` as ``l1``. Presumably an L1
                regularization weight; confirm in ``FFNNet``.
            l2: Forwarded to ``FFNNet`` as ``l2``. Presumably an L2
                regularization weight; confirm in ``FFNNet``.
        """
        # super() instead of an explicit Model.__init__(self, ...) call keeps the
        # MRO intact if this class is ever combined with mixins.
        super().__init__(
            train_data_path, test_data_path, model_path, num_epochs, batch_size,
            num_batch_per_update, num_workers, augment_spectrum, prefetch_factor, optimizer,
            machine_params, wandb_args, gradient_clip_val, precision)

        # FFNN hyper-parameters, read back in _setup_base_model.
        self._layers = layers
        self._dropout = dropout
        self._batch_norm = batch_norm
        self._l1 = l1
        self._l2 = l2

    def _setup_base_model(self):
        """Create the ``FFNNet`` base network from the stored hyper-parameters."""
        self._base_net = FFNNet(
            layers=self._layers, dropout=self._dropout, batch_norm=self._batch_norm, l1=self._l1,
            l2=self._l2)

    def _setup_spectrum_transform(self):
        """Build the spectrum preprocessing pipeline for the train and test splits.

        Steps:

        1. Load the training dataset from ``self._train_data_path``. This
           attribute is presumably set by ``Model.__init__``; confirm there.
        2. Fit a ``SpectrumBinningTransform`` on that dataset.
        3. Compose it after matchms ``normalize_intensities``.

        The same fitted transform instance is registered for both ``'train'``
        and ``'test'``. Test spectra are therefore transformed with the state
        fitted on the training data rather than refitted on test data.
        """
        train_dataset = SpectrumDataset.load(self._train_data_path)
        spectrum_binning_transform = SpectrumBinningTransform().fit(train_dataset)

        # Order matters: normalize intensities first, then bin.
        spectrum_transform = ComposeSpectrumTransform([
            normalize_intensities,
            spectrum_binning_transform
        ])

        self._spectrum_transform = {'train': spectrum_transform, 'test': spectrum_transform}

    def _setup_spectra_collate(self):
        """Register batch collaters.

        * ``'pair'`` uses ``TanimotoPairsSpectrumBinningCollater``.
        * ``'single'`` uses ``SpectrumBinningCollater``.
        """
        self._spectra_collate = {
            'pair': TanimotoPairsSpectrumBinningCollater(), 'single': SpectrumBinningCollater()}
