import numpy as np

from ...data.spectrum_dataset import SpectrumDataset
from ...data.spectrum_dataset import SpectrumPairDataset
from ...data.spectrum_transform import ComposeSpectrumTransform
from ...data.spectrum_transform import SpectrumPropertyTransform
from ...data.spectrum_transform import SpectrumTokenMZTransform
from ...data.spectrum_transform import SpectrumNumericalMZTransform
from ...data.spectrum_transform import AugmentationSpectrumTransform
from ...data.spectrum_collater import SpectrumTokenMZCollater
from ...data.spectrum_collater import PropertyPredictionSpectrumTokenMZCollater
from ...data.spectrum_collater import PropertyPredictionSpectrumNumericalMZCollater
from ...data.spectrum_collater import UntypedSpectrumNumericalMZCollater
from ...data.spectrum_collater import UntypedPropertyPredictionSpectrumNumericalMZCollater
from ...data.spectrum_collater import SpectrumNumericalMZCollater
from ...data.spectrum_collater import TanimotoPairsSpectrumNumericalMZCollater
from ..base_net.peak_embedding import TokenMZEmbedding
from ..base_net.peak_embedding import NumericalMZEmbedding
from ..base_net.peak_embedding import SinusoidalMZEmbedding
from ..base_net.peak_embedding import PeakEmbedding
from ..base_net.peak_embedding import MultiLayerFeedForwardBlock
from ..base_net.attention_base import SelfAttentionLayer
from ..base_net.attention_base import AttentionHead
from ..base_net.attention_base import MeanPool
from ..base_net.attention_base import SANet
from ..base_net.property_prediction import PropertyPredictionNet
from .model import Model

class TransformerModel(Model):
    """Common scaffolding for the transformer models defined in this module.

    Keeps the peak limits and the four architecture parameter mappings
    (body, foot, head, prediction head) and assembles the base network from
    them. ``_setup_base_model`` calls ``self._get_peak_embedding()``, which
    this class does not define, so concrete subclasses must provide it.
    """

    def __init__(
            self, train_data_path, test_data_path, model_path, num_epochs, batch_size,
            num_batch_per_update, num_workers, augment_spectrum, prefetch_factor, optimizer,
            machine_params, wandb_args, gradient_clip_val, precision, peak_limits,
            model_body_params, model_foot_params, model_head_params, model_pred_head_params):
        """Initialise the ``Model`` base and remember the architecture options.

        Every argument up to and including ``precision`` is forwarded
        positionally, unchanged, to ``Model.__init__``.

        Args:
            peak_limits: Stored as ``self._peak_limits``.
            model_body_params: Mapping stored as ``self._body``. This class
                reads ``'embd_dim'``, ``'num_heads'``, ``'ff_dim'``,
                ``'dropout'`` and ``'num_sa_layers'`` from it.
            model_foot_params: Mapping stored as ``self._foot``; not read by
                this class itself.
            model_head_params: Mapping stored as ``self._head``. ``'query_on'``
                is required; ``'include_skip_connection'`` is required when
                ``'query_on'`` is ``'precursor'``; ``'num_heads'``,
                ``'ff_dim'`` and ``'dropout'`` are optional overrides of the
                body values.
            model_pred_head_params: Mapping stored as ``self._pred_head``.
                ``'hidden_dim'`` is required; ``'dropout'`` optionally
                overrides the body dropout.
        """
        Model.__init__(
            self, train_data_path, test_data_path, model_path, num_epochs, batch_size,
            num_batch_per_update, num_workers, augment_spectrum, prefetch_factor, optimizer,
            machine_params, wandb_args, gradient_clip_val, precision)

        self._peak_limits = peak_limits

        self._body, self._foot = model_body_params, model_foot_params
        self._head, self._pred_head = model_head_params, model_pred_head_params

    def _head_or_body(self, key):
        """Return ``self._head[key]`` if present, otherwise ``self._body[key]``.

        The body mapping is only consulted when the head mapping lacks
        ``key``.
        """
        params = self._head if key in self._head else self._body
        return params[key]

    def _get_attention_head(self):
        """Build the attention head selected by ``self._head['query_on']``.

        Returns:
            A ``SelfAttentionLayer`` (with ``collapse_seq=True``) for
            ``'precursor'`` with a truthy ``'include_skip_connection'``, an
            ``AttentionHead`` for ``'precursor'`` otherwise, or a
            ``MeanPool`` for ``'meanpool'``.

        Raises:
            ValueError: If ``'query_on'`` is neither ``'precursor'`` nor
                ``'meanpool'``.
        """
        query_on = self._head['query_on']

        if query_on == 'precursor':
            if self._head['include_skip_connection']:
                # This layer uses different keyword names (d_model / nhead)
                # from the two heads constructed at the end of this method.
                return SelfAttentionLayer(
                    collapse_seq=True,
                    d_model=self._body['embd_dim'],
                    nhead=self._head_or_body('num_heads'),
                    ff_dim=self._head_or_body('ff_dim'),
                    dropout=self._head_or_body('dropout'))
            head_cls = AttentionHead
        elif query_on == 'meanpool':
            head_cls = MeanPool
        else:
            raise ValueError('Unknown attention head type')

        # AttentionHead and MeanPool are constructed with identical keywords.
        return head_cls(
            embd_dim=self._body['embd_dim'],
            num_heads=self._head_or_body('num_heads'),
            ff_dim=self._head_or_body('ff_dim'),
            dropout=self._head_or_body('dropout'))

    def _setup_base_model(self):
        """Assemble the network and store it as ``self._base_net``.

        Sub-modules are created in a fixed order: peak embedding, attention
        head, prediction head, ``SANet``, and finally the wrapping
        ``PropertyPredictionNet``.
        """
        peak_embedding = self._get_peak_embedding()
        attention_head = self._get_attention_head()

        in_dim = self._body['embd_dim']
        hidden_dim = self._pred_head['hidden_dim']
        if 'dropout' in self._pred_head:
            pred_dropout = self._pred_head['dropout']
        else:
            pred_dropout = self._body['dropout']

        # NOTE(review): the output width is hard-coded to 10 — presumably the
        # number of predicted properties; confirm against the property transform.
        prediction_head = MultiLayerFeedForwardBlock(
            in_dim=in_dim,
            out_dim=10,
            hidden_dim=hidden_dim,
            dropout=pred_dropout,
            dropout_last_layer=False)

        encoder = SANet(
            peak_embedding, attention_head,
            dropout=self._body['dropout'],
            num_sa_layers=self._body['num_sa_layers'],
            embd_dim=self._body['embd_dim'],
            num_heads=self._body['num_heads'],
            ff_dim=self._body['ff_dim'])

        self._base_net = PropertyPredictionNet(encoder, prediction_head)

class TransformerTokenMZModel(TransformerModel):
    """Transformer model wired up with ``TokenMZEmbedding`` and ``SpectrumTokenMZTransform``."""

    def _get_peak_embedding(self):
        """Build a ``PeakEmbedding`` around a ``TokenMZEmbedding``.

        Reads ``self._spectrum_transform['prediction']``, so
        ``_setup_spectrum_transform`` must already have run.
        """
        def foot_or(key, default):
            # Optional foot parameter with an explicit fallback value.
            return self._foot[key] if key in self._foot else default

        token_embedding = TokenMZEmbedding(
            spectrum_token_transform=self._spectrum_transform['prediction'],
            embd_dim=self._body['embd_dim'])

        # The foot dropout wins; the body dropout is only read when it is absent.
        if 'dropout' in self._foot:
            dropout = self._foot['dropout']
        else:
            dropout = self._body['dropout']

        return PeakEmbedding(
            mz_embd=token_embedding,
            dropout=dropout,
            hidden_dim=foot_or('peak_hidden_dim', None),
            drop_precursor=foot_or('drop_precursor', False))

    def _setup_spectrum_transform(self):
        """Fit the spectrum transforms and store them in ``self._spectrum_transform``.

        The token transform is fit on the training spectra that have at
        least one peak m/z and a precursor m/z inside ``self._peak_limits``
        (bounds inclusive). The resulting mapping has the keys ``'train'``,
        ``'test'`` and ``'prediction'``.
        """
        def within_limits(spectrum):
            # True when some peak m/z and the precursor m/z both fall inside
            # the inclusive limits; the precursor is only inspected after a
            # peak has matched.
            peak_mask = ((spectrum.peaks.mz >= self._peak_limits[0]) &
                         (spectrum.peaks.mz <= self._peak_limits[1]))
            if not any(peak_mask):
                return False
            precursor_mz = spectrum.metadata['precursor_mz']
            return self._peak_limits[0] <= precursor_mz <= self._peak_limits[1]

        kept_positions = []
        for position, spectrum in enumerate(SpectrumDataset.load(self._train_data_path)):
            if within_limits(spectrum):
                kept_positions.append(position)
        # NOTE(review): when nothing passes the filter this array is empty and
        # NumPy gives it a float dtype — confirm `keep_idx` tolerates that.
        train_keep_idx = np.array(kept_positions)

        token_transform = SpectrumTokenMZTransform(
            peak_limits=self._peak_limits, normalize_peak_intensities=True)
        base_spectrum_transform = token_transform.fit(
            SpectrumDataset.load(self._train_data_path, keep_idx=train_keep_idx))

        # Augmentation, when configured, is applied ahead of the base
        # transform and only on the training path.
        if self._augment_spectrum is None:
            train_spectrum_transform = base_spectrum_transform
        else:
            augmentation = AugmentationSpectrumTransform(**self._augment_spectrum)
            train_spectrum_transform = ComposeSpectrumTransform(
                [augmentation, base_spectrum_transform])

        # Loaded without `keep_idx`: the property transforms are fit on the
        # unfiltered training set.
        train_dataset = SpectrumDataset.load(self._train_data_path)
        fitted_train = SpectrumPropertyTransform(train_spectrum_transform).fit(train_dataset)
        fitted_test = SpectrumPropertyTransform(base_spectrum_transform).fit(train_dataset)

        self._spectrum_transform = {
            'train': fitted_train,
            'test': fitted_test,
            'prediction': base_spectrum_transform}

    def _setup_spectra_collate(self):
        """Store the token-m/z collaters under ``'labeled'`` and ``'unlabeled'``."""
        labeled_collater = PropertyPredictionSpectrumTokenMZCollater()
        unlabeled_collater = SpectrumTokenMZCollater()
        self._spectra_collate = {
            'labeled': labeled_collater,
            'unlabeled': unlabeled_collater}

class TransformerNumMZModel(TransformerModel):
    """Transformer model wired up with ``NumericalMZEmbedding`` and ``SpectrumNumericalMZTransform``."""

    def _get_peak_embedding(self):
        """Build a ``PeakEmbedding`` around a ``NumericalMZEmbedding``.

        ``self._foot['include_fractional_mz']`` is required; the remaining
        foot entries read here are optional.
        """
        def foot_or(key, default):
            # Optional foot parameter with an explicit fallback value.
            return self._foot[key] if key in self._foot else default

        embd_dim = self._body['embd_dim']

        # One dropout value serves both embeddings: the foot setting if
        # given, otherwise the body setting.
        if 'dropout' in self._foot:
            dropout = self._foot['dropout']
        else:
            dropout = self._body['dropout']

        numerical_embedding = NumericalMZEmbedding(
            embd_dim=embd_dim,
            dropout=dropout,
            hidden_dim=foot_or('mz_hidden_dim', None),
            include_fractional_mz=self._foot['include_fractional_mz'])

        return PeakEmbedding(
            mz_embd=numerical_embedding,
            dropout=dropout,
            hidden_dim=foot_or('peak_hidden_dim', None),
            drop_precursor=foot_or('drop_precursor', False))

    def _setup_spectrum_transform(self):
        """Create the spectrum transforms and store them in ``self._spectrum_transform``.

        The resulting mapping has the keys ``'train'``, ``'test'`` and
        ``'prediction'``. Both property transforms are fit on the training
        dataset.
        """
        base_spectrum_transform = SpectrumNumericalMZTransform(
            peak_limits=self._peak_limits,
            normalize_peak_intensities=True,
            include_fractional_mz=self._foot['include_fractional_mz'])

        # Augmentation, when configured, is applied ahead of the base
        # transform and only on the training path.
        if self._augment_spectrum is None:
            train_spectrum_transform = base_spectrum_transform
        else:
            augmentation = AugmentationSpectrumTransform(**self._augment_spectrum)
            train_spectrum_transform = ComposeSpectrumTransform(
                [augmentation, base_spectrum_transform])

        train_dataset = SpectrumDataset.load(self._train_data_path)
        fitted_train = SpectrumPropertyTransform(train_spectrum_transform).fit(train_dataset)
        fitted_test = SpectrumPropertyTransform(base_spectrum_transform).fit(train_dataset)

        self._spectrum_transform = {
            'train': fitted_train,
            'test': fitted_test,
            'prediction': base_spectrum_transform}

    def _setup_spectra_collate(self):
        """Store the numerical-m/z collaters under ``'labeled'`` and ``'unlabeled'``."""
        labeled_collater = PropertyPredictionSpectrumNumericalMZCollater()
        unlabeled_collater = SpectrumNumericalMZCollater()
        self._spectra_collate = {
            'labeled': labeled_collater,
            'unlabeled': unlabeled_collater}

class TransformerSinMZModel(TransformerModel):
    """Transformer model wired up with ``SinusoidalMZEmbedding`` and ``SpectrumNumericalMZTransform``."""

    def _get_peak_embedding(self):
        """Build a ``PeakEmbedding`` around a ``SinusoidalMZEmbedding``.

        Every foot entry read here is optional. Missing entries become
        ``None``, except ``'dropout'`` (falls back to the body dropout) and
        ``'drop_precursor'`` (falls back to ``False``).
        """
        def foot_or(key, default=None):
            # Optional foot parameter with an explicit fallback value.
            return self._foot[key] if key in self._foot else default

        def dropout_rate():
            # The foot dropout wins; the body dropout is only read when the
            # foot mapping has no such entry.
            if 'dropout' in self._foot:
                return self._foot['dropout']
            return self._body['dropout']

        sinusoidal_embedding = SinusoidalMZEmbedding(
            embd_dim=self._body['embd_dim'],
            mz_log_lims=foot_or('mz_log_lims'),
            mz_spacing=foot_or('mz_spacing'),
            mz_precision=foot_or('mz_precision'),
            dropout=dropout_rate(),
            hidden_dim=foot_or('mz_hidden_dim'),
            train_frequency=foot_or('train_frequency'))

        return PeakEmbedding(
            mz_embd=sinusoidal_embedding,
            dropout=dropout_rate(),
            hidden_dim=foot_or('peak_hidden_dim'),
            drop_precursor=foot_or('drop_precursor', False))

    def _setup_spectrum_transform(self):
        """Create the spectrum transforms and store them in ``self._spectrum_transform``.

        The base transform is always built with
        ``include_fractional_mz=False``. The resulting mapping has the keys
        ``'train'``, ``'test'`` and ``'prediction'``; both property
        transforms are fit on the training dataset.
        """
        base_spectrum_transform = SpectrumNumericalMZTransform(
            peak_limits=self._peak_limits,
            normalize_peak_intensities=True,
            include_fractional_mz=False)

        # Augmentation, when configured, is applied ahead of the base
        # transform and only on the training path.
        if self._augment_spectrum is None:
            train_spectrum_transform = base_spectrum_transform
        else:
            augmentation = AugmentationSpectrumTransform(**self._augment_spectrum)
            train_spectrum_transform = ComposeSpectrumTransform(
                [augmentation, base_spectrum_transform])

        train_dataset = SpectrumDataset.load(self._train_data_path)
        fitted_train = SpectrumPropertyTransform(train_spectrum_transform).fit(train_dataset)
        fitted_test = SpectrumPropertyTransform(base_spectrum_transform).fit(train_dataset)

        self._spectrum_transform = {
            'train': fitted_train,
            'test': fitted_test,
            'prediction': base_spectrum_transform}

    def _setup_spectra_collate(self):
        """Choose the collaters according to ``self._foot['mz_precision']``.

        * key absent: numerical collaters built with no arguments;
        * ``None``: the ``Untyped*`` numerical collaters;
        * an ``int``: numerical collaters built with that precision.

        The result is stored in ``self._spectra_collate`` under the keys
        ``'labeled'`` and ``'unlabeled'``.

        Raises:
            ValueError: If ``'mz_precision'`` is present but is neither
                ``None`` nor an ``int``.
        """
        if 'mz_precision' not in self._foot:
            self._spectra_collate = {
                'labeled': PropertyPredictionSpectrumNumericalMZCollater(),
                'unlabeled': SpectrumNumericalMZCollater()}
            return

        mz_precision = self._foot['mz_precision']

        if mz_precision is None:
            self._spectra_collate = {
                'labeled': UntypedPropertyPredictionSpectrumNumericalMZCollater(),
                'unlabeled': UntypedSpectrumNumericalMZCollater()}
        elif isinstance(mz_precision, int):
            self._spectra_collate = {
                'labeled': PropertyPredictionSpectrumNumericalMZCollater(mz_precision),
                'unlabeled': SpectrumNumericalMZCollater(mz_precision)}
        else:
            # Name the key and the offending value so a bad configuration
            # can be diagnosed from the traceback alone.
            raise ValueError(
                "'mz_precision' must be None or an int, got {!r} ({})".format(
                    mz_precision, type(mz_precision).__name__))
