import numpy as np

from ...data.spectrum_dataset import SpectrumDataset
from ...data.spectrum_dataset import SpectrumPairDataset
from ...data.spectrum_transform import SpectrumTokenMZTransform
from ...data.spectrum_transform import SpectrumNumericalMZTransform
from ...data.spectrum_collater import SpectrumTokenMZCollater
from ...data.spectrum_collater import TanimotoPairsSpectrumTokenMZCollater
from ...data.spectrum_collater import SpectrumNumericalMZCollater
from ...data.spectrum_collater import TanimotoPairsSpectrumNumericalMZCollater
from ..base_net.peak_embedding import TokenMZEmbedding
from ..base_net.peak_embedding import NumericalMZEmbedding
from ..base_net.peak_embedding import SinusoidalMZEmbedding
from ..base_net.peak_embedding import PeakEmbedding
from ..base_net.attention_base import SelfAttentionLayer
from ..base_net.attention_base import AttentionHead
from ..base_net.attention_base import SANet
from .model import Model

class TransformerModel(Model):
    """Common base for the transformer spectrum models in this module.

    On top of the generic ``Model`` settings it keeps:

    * ``peak_limits`` -- stored as ``self._peak_limits`` for use by subclasses;
    * ``model_body_params`` -- settings for the SANet body (``embd_dim``,
      ``num_heads``, ``ff_dim``, ``dropout``, ``num_sa_layers``);
    * ``model_foot_params`` -- settings read by the subclasses'
      ``_get_peak_embedding``;
    * ``model_head_params`` -- settings for the attention head, each of
      ``num_heads`` / ``ff_dim`` / ``dropout`` falling back to the body value.

    Subclasses must implement ``_get_peak_embedding``.
    """

    def __init__(
            self, train_data_path, test_data_path, model_path, num_epochs, batch_size,
            num_batch_per_update, num_workers, augment_spectrum, prefetch_factor, optimizer,
            machine_params, wandb_args, gradient_clip_val, precision, peak_limits,
            model_body_params, model_foot_params, model_head_params):
        # The generic training/IO configuration is owned by Model.
        Model.__init__(
            self, train_data_path, test_data_path, model_path, num_epochs, batch_size,
            num_batch_per_update, num_workers, augment_spectrum, prefetch_factor, optimizer,
            machine_params, wandb_args, gradient_clip_val, precision)

        self._peak_limits = peak_limits
        self._body, self._foot, self._head = (
            model_body_params, model_foot_params, model_head_params)

    def _head_or_body(self, key):
        """Return ``key`` from the head params, else from the body params.

        The body dict is only consulted when the head does not define ``key``.
        """
        if key in self._head:
            return self._head[key]
        return self._body[key]

    def _get_attention_head(self):
        """Create the attention head configured in ``self._head``.

        Only ``query_on == 'precursor'`` is supported. A truthy
        ``include_skip_connection`` selects ``SelfAttentionLayer`` (with
        ``collapse_seq=True``); a falsy one selects ``AttentionHead``.

        Raises:
            ValueError: if ``query_on`` is anything other than ``'precursor'``.
        """
        if self._head['query_on'] != 'precursor':
            raise ValueError('Unknown attention head type')

        with_skip = self._head['include_skip_connection']
        embd_dim = self._body['embd_dim']
        num_heads = self._head_or_body('num_heads')
        ff_dim = self._head_or_body('ff_dim')
        dropout = self._head_or_body('dropout')

        if with_skip:
            return SelfAttentionLayer(
                collapse_seq=True,
                d_model=embd_dim,
                nhead=num_heads,
                ff_dim=ff_dim,
                dropout=dropout)

        return AttentionHead(
            embd_dim=embd_dim,
            num_heads=num_heads,
            ff_dim=ff_dim,
            dropout=dropout)

    def _setup_base_model(self):
        """Assemble ``self._base_net`` as an SANet from the foot, head and body settings."""
        # Sub-modules are created in the same order as before (peak embedding,
        # then attention head, then SANet); presumably this matters for
        # reproducible random initialisation -- keep it.
        peak_embedding = self._get_peak_embedding()
        attention_head = self._get_attention_head()

        body = self._body
        self._base_net = SANet(
            peak_embedding, attention_head,
            dropout=body['dropout'],
            num_sa_layers=body['num_sa_layers'],
            embd_dim=body['embd_dim'],
            num_heads=body['num_heads'],
            ff_dim=body['ff_dim'])

class TransformerTokenMZModel(TransformerModel):
    """Transformer model whose peak m/z values are embedded as discrete tokens."""

    def _get_peak_embedding(self):
        """Build a ``PeakEmbedding`` on top of a ``TokenMZEmbedding``.

        Requires ``self._spectrum_transform`` (set by
        ``_setup_spectrum_transform``): the token embedding is built from the
        training transform. Foot ``dropout`` falls back to the body value;
        ``peak_hidden_dim`` defaults to None.
        """
        mz_embedding = TokenMZEmbedding(
            spectrum_token_transform=self._spectrum_transform['train'],
            embd_dim=self._body['embd_dim'])

        peak_embedding = PeakEmbedding(
            mz_embd=mz_embedding,
            dropout=self._foot['dropout'] if 'dropout' in self._foot else self._body['dropout'],
            hidden_dim=self._foot['peak_hidden_dim'] if 'peak_hidden_dim' in self._foot else None)

        return peak_embedding

    def _setup_spectrum_transform(self):
        """Fit a ``SpectrumTokenMZTransform`` on the in-range training spectra.

        A training spectrum is kept only if its precursor m/z and at least one
        of its peak m/z values lie in the closed interval
        ``[peak_limits[0], peak_limits[1]]``. The fitted transform is shared by
        the train and test splits.

        Raises:
            ValueError: if no training spectrum satisfies the selection, since
                fitting on an empty dataset is meaningless.
        """
        mz_min, mz_max = self._peak_limits[0], self._peak_limits[1]

        def _in_peak_limits(spectrum):
            # Array-level .any() rather than Python's builtin any(), which
            # iterates the boolean mask element by element.
            mz = spectrum.peaks.mz
            has_peak_in_range = ((mz >= mz_min) & (mz <= mz_max)).any()
            return (has_peak_in_range and
                    mz_min <= spectrum.metadata['precursor_mz'] <= mz_max)

        # dtype=int keeps the index array integer-typed; np.array([]) would
        # otherwise be float64 and unusable for indexing.
        train_keep_idx = np.array(
            [i for i, spectrum in enumerate(SpectrumDataset.load(self._train_data_path))
             if _in_peak_limits(spectrum)],
            dtype=int)

        if train_keep_idx.size == 0:
            raise ValueError(
                f'No training spectra in {self._train_data_path!r} have a precursor m/z '
                f'and at least one peak within peak_limits={self._peak_limits!r}')

        # NOTE(review): the training data is loaded a second time here;
        # presumably `keep_idx` makes load() return only the selected spectra
        # -- confirm whether the first load could be reused instead.
        spectrum_transform = SpectrumTokenMZTransform(
            peak_limits=self._peak_limits,
            precision=self._foot['precision'] if 'precision' in self._foot else 1,
            normalize_peak_intensities=True,
        ).fit(SpectrumDataset.load(self._train_data_path, keep_idx=train_keep_idx))

        self._spectrum_transform = {'train': spectrum_transform, 'test': spectrum_transform}

    def _setup_spectra_collate(self):
        """Register the token-m/z collaters for paired and single-spectrum batches."""
        self._spectra_collate = {
            'pair': TanimotoPairsSpectrumTokenMZCollater(), 'single': SpectrumTokenMZCollater()}

class TransformerNumMZModel(TransformerModel):
    """Transformer model that feeds peak m/z values to a ``NumericalMZEmbedding``."""

    def _get_peak_embedding(self):
        """Create a ``PeakEmbedding`` around a ``NumericalMZEmbedding``.

        Foot ``dropout`` overrides the body value; ``mz_hidden_dim`` and
        ``peak_hidden_dim`` default to None. ``include_fractional_mz`` must be
        present in the foot params.
        """
        foot, body = self._foot, self._body

        embd_dim = body['embd_dim']
        dropout = foot['dropout'] if 'dropout' in foot else body['dropout']
        mz_hidden_dim = foot['mz_hidden_dim'] if 'mz_hidden_dim' in foot else None
        with_fraction = foot['include_fractional_mz']

        numeric_mz = NumericalMZEmbedding(
            embd_dim=embd_dim,
            dropout=dropout,
            hidden_dim=mz_hidden_dim,
            include_fractional_mz=with_fraction)

        return PeakEmbedding(
            mz_embd=numeric_mz,
            dropout=dropout,
            hidden_dim=foot['peak_hidden_dim'] if 'peak_hidden_dim' in foot else None)

    def _setup_spectrum_transform(self):
        """Use a single ``SpectrumNumericalMZTransform`` for both train and test.

        No fitting step is performed here.
        """
        transform = SpectrumNumericalMZTransform(
            peak_limits=self._peak_limits,
            normalize_peak_intensities=True,
            include_fractional_mz=self._foot['include_fractional_mz'])

        self._spectrum_transform = dict(train=transform, test=transform)

    def _setup_spectra_collate(self):
        """Register numerical-m/z collaters (default arguments) for pairs and singles."""
        self._spectra_collate = dict(
            pair=TanimotoPairsSpectrumNumericalMZCollater(),
            single=SpectrumNumericalMZCollater())

class TransformerSinMZModel(TransformerModel):
    """Transformer model that embeds peak m/z values with ``SinusoidalMZEmbedding``."""

    def _get_peak_embedding(self):
        """Create a ``PeakEmbedding`` around a ``SinusoidalMZEmbedding``.

        Every foot setting is optional. ``dropout`` falls back to the body
        value, ``drop_precursor`` to False, and everything else to None.
        """
        foot, body = self._foot, self._body

        def optional(key, default=None):
            # Foot-level value when configured, otherwise the given default.
            return foot[key] if key in foot else default

        embd_dim = body['embd_dim']
        # Resolved lazily: the body dropout is only read when the foot lacks one.
        dropout = foot['dropout'] if 'dropout' in foot else body['dropout']

        sinusoidal_mz = SinusoidalMZEmbedding(
            embd_dim=embd_dim,
            mz_log_lims=optional('mz_log_lims'),
            mz_spacing=optional('mz_spacing'),
            mz_precision=optional('mz_precision'),
            dropout=dropout,
            hidden_dim=optional('mz_hidden_dim'),
            train_frequency=optional('train_frequency'))

        return PeakEmbedding(
            mz_embd=sinusoidal_mz,
            dropout=dropout,
            hidden_dim=optional('peak_hidden_dim'),
            drop_precursor=optional('drop_precursor', False))

    def _setup_spectrum_transform(self):
        """Share one numerical m/z transform (fractional m/z disabled) across splits."""
        transform = SpectrumNumericalMZTransform(
            peak_limits=self._peak_limits,
            normalize_peak_intensities=True,
            include_fractional_mz=False)

        self._spectrum_transform = dict(train=transform, test=transform)

    def _setup_spectra_collate(self):
        """Register numerical-m/z collaters, passing ``mz_precision`` when it is an int.

        A missing or non-integer ``mz_precision`` leaves the collaters on their
        default arguments.
        """
        precision = self._foot['mz_precision'] if 'mz_precision' in self._foot else None
        collater_args = (precision,) if isinstance(precision, int) else ()

        self._spectra_collate = {
            'pair': TanimotoPairsSpectrumNumericalMZCollater(*collater_args),
            'single': SpectrumNumericalMZCollater(*collater_args)}
