"""Transformer encoder and decoder for the de novo sequencing task."""
import torch
import numpy as np
from collections.abc import Callable

from depthcharge.tokenizers import Tokenizer, PeptideTokenizer
from depthcharge.encoders import PeakEncoder, FloatEncoder, PositionalEncoder
from depthcharge.transformers import SpectrumTransformerEncoder, AnalyteTransformerDecoder
from depthcharge import utils

import inspect

import os
import pickle


class PeptideDecoder(AnalyteTransformerDecoder):
    """A transformer decoder that predicts peptide sequences as residue masses.

    Instead of embedding discrete tokens, the decoder consumes the masses of
    the tokens decoded so far (plus the remaining "suffix" mass) and regresses
    the mass of the next token. Token scores are derived from the distance
    between the regressed mass and each entry of a per-token mass table that
    is built from the tokenizer.

    Parameters
    ----------
    n_tokens : PeptideTokenizer
        The tokenizer used for peptide sequences. Although the annotation
        inherited from the parent class allows an ``int``, a
        ``PeptideTokenizer`` is required here because the residue masses are
        read from it.
    d_model : int, optional
        The latent dimensionality of the decoder. Must be a power of two.
    n_head : int, optional
        The number of attention heads in each layer. ``d_model`` must be
        divisible by ``n_head``.
    dim_feedforward : int, optional
        The dimensionality of the fully connected layers in the Transformer
        layers of the model.
    n_layers : int, optional
        The number of Transformer layers.
    dropout : float, optional
        The dropout probability for all layers.
    positional_encoder : PositionalEncoder or bool, optional
        The positional encodings to use for the amino acid sequence. If
        ``True``, the default positional encoder is used. ``False`` disables
        positional encodings, typically only for ablation tests.
    padding_int : int or None, optional
        Forwarded unchanged to the parent decoder.
    max_charge : int, optional
        The maximum charge state for peptide sequences.
    """

    def __init__(
        self,
        n_tokens: int | Tokenizer,
        d_model: int = 128,
        n_head: int = 8,
        dim_feedforward: int = 1024,
        n_layers: int = 1,
        dropout: float = 0,
        positional_encoder: PositionalEncoder | bool = True,
        padding_int: int | None = None,
        max_charge: int = 10,
    ) -> None:
        """Initialize a PeptideDecoder."""

        super().__init__(
            n_tokens=n_tokens,
            d_model=d_model,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            n_layers=n_layers,
            dropout=dropout,
            positional_encoder=positional_encoder,
            padding_int=padding_int,
        )

        # HACK: this model does not use the parent's token embedding or final
        # projection. They are replaced with fixed-size placeholders (27 x 512)
        # that match the existing checkpoint so that weight loading succeeds.
        # Permanent fix (once checkpoints are regenerated):
        #   self.token_encoder = None
        #   self.final = None
        self.token_encoder = torch.nn.Embedding(
            27,
            512,
            padding_idx=0,
        )
        self.final = torch.nn.Linear(
            512,
            27,
        )

        # Precursor mass and charge each get half of the first half-width
        # slot of the precursor token (they are summed, not concatenated).
        self.charge_encoder = torch.nn.Embedding(max_charge, d_model // 2)
        self.mass_encoder = FloatEncoder(d_model // 2)

        # Encoders for the per-token mass and the remaining (suffix) mass;
        # concatenated they span d_model.
        self.tokenMassEncoder = FloatEncoder(d_model // 2)
        self.suffixMassEncoder = FloatEncoder(d_model // 2)

        # Encoders for the token/suffix mass of the precursor token itself;
        # together with the half-width mass+charge slot they span d_model.
        self.precursorMassEncoder = FloatEncoder(d_model // 4)
        self.precursorSuffixMassEncoder = FloatEncoder(d_model // 4)

        # Mass look-up: the tokenizer provides the residue masses.
        self.tokenizer = n_tokens
        assert isinstance(self.tokenizer, PeptideTokenizer)

        # Build the per-token mass table; tokens without a residue mass
        # (e.g. special tokens) default to 0.0.
        self._num_idx = len(self.tokenizer.reverse_index)
        self._pad_idx = self.tokenizer.padding_int
        self._start_idx = self.tokenizer.start_int
        self._end_idx = self.tokenizer.stop_int
        self._mass_lookup_table = [
            self.tokenizer.residues.get(a, 0.0)
            for a in self.tokenizer.reverse_index
        ]
        self._aa_idx = list(range(0, self._num_idx))

        assert (d_model & (d_model-1) == 0) and d_model != 0, f"Model Dimensions must be a power of two but were {d_model}."

        # Regression head: d_model -> 1024 -> 1024 -> 512 -> 1 (scalar mass).
        # The layer order is kept stable so state-dict keys do not change.
        finalLinears = []
        xin = d_model
        for xout in (1024, 1024, 512):
            finalLinears.append(torch.nn.Linear(xin, xout))
            finalLinears.append(torch.nn.PReLU())
            xin = xout
        finalLinears.append(torch.nn.Linear(xin, 1))  # scalar masses
        self.finalLinears = torch.nn.Sequential(*finalLinears)

        # Normalizes over the token axis of the (batch, position, token)
        # distance tensor produced in ``forward``.
        self.softmax = torch.nn.Softmax(2)

        self.final_mass_encoder = FloatEncoder(d_model)
        # Learnable pseudo-masses for the start and end tokens.
        self.start_token_mass = torch.nn.Parameter(torch.randn(1))
        self.end_token_mass = torch.nn.Parameter(torch.randn(1))


    def forward(
        self,
        token_masses: torch.Tensor | None,
        precursors: torch.Tensor,
        memory: torch.Tensor | None,
        memory_key_padding_mask: torch.Tensor | None = None,
        memory_mask: torch.Tensor | None = None,
        tgt_mask: torch.Tensor | None = None,
        **kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode a collection of sequences.

        Parameters
        ----------
        token_masses : torch.Tensor, or None
            The partial molecular sequences (of masses) for which to predict
            the next token's mass, shaped (n_sequences, n_tokens). A trailing
            singleton dimension, i.e. (n_sequences, n_tokens, 1), is also
            accepted and removed. ``None`` decodes from the precursor token
            alone.
        precursors : torch.Tensor
            The precursor information, one row per sequence. Column 0 is used
            as the precursor mass value (for the suffix masses and the mass
            embedding) and column 1 as the 1-based charge; any further
            columns are not read here.
            NOTE(review): the previous docstring described the layout as
            (mz, charge, pep_mass) of shape (n_sequences, 3) - confirm
            whether column 0 holds m/z or the peptide mass.
        memory : torch.Tensor of shape (batch_size, len_seq, d_model)
            The representations from a ``TransformerEncoder``, such as a
            ``SpectrumTransformerEncoder``.
        memory_key_padding_mask : torch.Tensor of shape (batch_size, len_seq)
            Passed to the transformer decoder. The mask that indicates which
            elements of ``memory`` are padding.
        memory_mask : torch.Tensor
            Passed to the transformer decoder. The mask for the memory
            sequence.
        tgt_mask : torch.Tensor or None
            Passed to the transformer decoder. The default is the mask
            generated by ``utils.generate_tgt_mask`` for the decoded length
            (including the precursor token).
        **kwargs : dict
            Additional data fields. Not used by this implementation.

        Returns
        -------
        scores : torch.Tensor of size (batch_size, n_tokens + 1, n_vocab)
            Softmax over the negative absolute difference between the
            predicted mass and each entry of the token mass table, where
            ``n_vocab`` is the size of the tokenizer's ``reverse_index``.
        predicted_masses : torch.Tensor of size (batch_size, n_tokens + 1, 1)
            The regressed mass for the next token at every position.

        """
        # Prepare the token masses and the remaining (suffix) masses.
        if token_masses is None:
            # Build empty sequences with the correct batch size so they can
            # be concatenated with the precursor token for any batch size.
            n_sequences = precursors.shape[0]
            token_masses = torch.zeros(n_sequences, 0, device=self.device)
            suffix_masses = torch.zeros(n_sequences, 0, device=self.device)
        else:
            # Only drop a trailing singleton *feature* axis. A 2-D
            # (n_sequences, 1) input is a one-token sequence per row and
            # must keep its token axis.
            if token_masses.dim() > 2 and token_masses.shape[-1] == 1:
                token_masses = token_masses.squeeze(-1)
            suffix_masses = precursors[:, [0]] - token_masses.cumsum(dim=-1)

        # Encode the token and suffix masses and join them feature-wise.
        encoded_token_mass = self.tokenMassEncoder(token_masses)
        encoded_suffix_mass = self.suffixMassEncoder(suffix_masses)
        encoded = torch.cat([encoded_token_mass, encoded_suffix_mass], dim=-1)

        # Precursor token: token mass = 0 and suffix mass = precursor value.
        precursors_token_mass = torch.zeros(
            precursors.size()[0], 1, device=self.device
        )
        precursors_token_mass = self.precursorMassEncoder(precursors_token_mass)
        precursors_suffix_mass = self.precursorSuffixMassEncoder(
            precursors[:, None, 0]
        )

        masses = self.mass_encoder(precursors[:, None, 0]).squeeze(1)
        # Charges are 1-based; shift to 0-based embedding indices.
        charges = self.charge_encoder(precursors[:, 1].int() - 1)
        precursors = masses + charges

        precursors = precursors[:, None, :]
        precursors = torch.cat(
            [precursors, precursors_token_mass, precursors_suffix_mass],
            dim=-1,
        )

        # Prepend the precursor token to every sequence.
        encoded = torch.cat([precursors, encoded], dim=1)

        # Padding mask: positions whose encoding sums to exactly zero are
        # treated as padding; the precursor token is never masked.
        # NOTE(review): confirm that padded positions really encode to
        # all-zeros - this depends on FloatEncoder, which is not visible here.
        tgt_key_padding_mask = encoded.sum(axis=2) == 0
        tgt_key_padding_mask[:, 0] = False

        # Feed through the model.
        encoded = self.positional_encoder(encoded)

        if tgt_mask is None:
            tgt_mask = utils.generate_tgt_mask(encoded.shape[1]).to(
                self.device
            )

        emb = self.transformer_decoder(
            tgt=encoded,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            memory_mask=memory_mask,
        )

        # Rebuild the mass table on the current device and splice in the
        # learnable start/end masses so gradients reach those parameters.
        mass_lookup_table = torch.tensor(
            self._mass_lookup_table, device=self.device
        )
        if self._start_idx is not None:
            mass_lookup_table[self._start_idx] = self.start_token_mass
        if self._end_idx is not None:
            mass_lookup_table[self._end_idx] = self.end_token_mass
        mass_lookup_table[0] = 0  # padding mass

        predicted_masses = self.finalLinears(emb)

        # The prediction is intentionally not snapped to the nearest table
        # mass; the distance to every table entry is turned into scores.
        diff = torch.abs(predicted_masses - mass_lookup_table)

        return self.softmax(-diff), predicted_masses


class SpectrumEncoder(SpectrumTransformerEncoder):
    """A Transformer encoder for input mass spectra.

    Extends ``SpectrumTransformerEncoder`` with a learnable latent vector
    that is supplied as the global token for every spectrum.

    Parameters
    ----------
    d_model : int, optional
        The latent dimensionality to represent peaks in the mass spectrum.
    n_head : int, optional
        The number of attention heads in each layer. ``d_model`` must be
        divisible by ``n_head``.
    dim_feedforward : int, optional
        The dimensionality of the fully connected layers in the Transformer
        layers of the model.
    n_layers : int, optional
        The number of Transformer layers.
    dropout : float, optional
        The dropout probability for all layers.
    peak_encoder : PeakEncoder, Callable or bool, optional
        The peak encoder setting, forwarded unchanged to the parent class.
    """

    def __init__(
        self,
        d_model: int = 128,
        n_head: int = 8,
        dim_feedforward: int = 1024,
        n_layers: int = 1,
        dropout: float = 0,
        peak_encoder: PeakEncoder | Callable | bool = True,
    ):
        """Initialize a SpectrumEncoder."""
        # The parent takes these positionally, in exactly this order.
        parent_args = (
            d_model,
            n_head,
            dim_feedforward,
            n_layers,
            dropout,
            peak_encoder,
        )
        super().__init__(*parent_args)

        # One learnable vector, shared by all spectra, shaped (1, 1, d_model).
        initial_latent = torch.randn(1, 1, d_model)
        self.latent_spectrum = torch.nn.Parameter(initial_latent)

    def global_token_hook(
        self,
        mz_array: torch.Tensor,
        intensity_array: torch.Tensor,
        *args: torch.Tensor,
        **kwargs: dict,
    ) -> torch.Tensor:
        """Return the learned latent spectrum as the global token.

        Overrides the parent hook so that every spectrum in the batch
        receives the same ``latent_spectrum`` parameter.

        Parameters
        ----------
        mz_array : torch.Tensor of shape (n_spectra, n_peaks)
            The zero-padded m/z dimension for a batch of mass spectra. Only
            its first dimension (the batch size) is used.
        intensity_array : torch.Tensor of shape (n_spectra, n_peaks)
            The zero-padded intensity dimension for a batch of mass spectra.
            Not used.
        *args : torch.Tensor
            Additional data passed with the batch. Not used.
        **kwargs : dict
            Additional data passed with the batch. Not used.

        Returns
        -------
        torch.Tensor of shape (batch_size, d_model)
            The latent spectrum, expanded (without copying) to one row per
            spectrum.

        """
        n_spectra = mz_array.shape[0]
        # Drop the leading singleton axis: (1, 1, d_model) -> (1, d_model).
        latent_row = self.latent_spectrum.squeeze(0)
        return latent_row.expand(n_spectra, -1)
