"""A de novo peptide sequencing model."""

import collections
import heapq
import logging
import warnings
from typing import Any, Dict, Iterable, List, Optional, Tuple

import einops
import torch
import numpy as np
import lightning.pytorch as pl

from depthcharge.tokenizers import PeptideTokenizer
from depthcharge.encoders import FloatEncoder

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from io import BytesIO
from PIL import Image

from . import evaluate
from .. import config
from ..data import ms_io
from ..denovo.transformers import SpectrumEncoder, PeptideDecoder
from ..utils import EpochTracker

# Module-wide logger shared under the application-level "casanovo" name.
logger = logging.getLogger(name="casanovo")

class Spec2Pep(pl.LightningModule):
    """
    A Transformer model for de novo peptide sequencing.

    Use this model in conjunction with a pytorch-lightning Trainer.

    Parameters
    ----------
    dim_model : int
        The latent dimensionality used by the transformer model.
    n_head : int
        The number of attention heads in each layer. ``dim_model`` must be
        divisible by ``n_head``.
    dim_feedforward : int
        The dimensionality of the fully connected layers in the transformer
        model.
    n_layers : int
        The number of transformer layers.
    dropout : float
        The dropout probability for all layers.
    dim_intensity : Optional[int]
        The number of features to use for encoding peak intensity. The remaining
        (``dim_model - dim_intensity``) are reserved for encoding the m/z value.
        If ``None``, the intensity will be projected up to ``dim_model`` using a
        linear layer, then summed with the m/z encoding for each peak.
    max_length : int
        The maximum peptide length to decode.
    residues : Union[Dict[str, float], str]
        The amino acids and their masses. By default ("canonical") this is
        only the 20 canonical amino acids, with cysteine carbamidomethylated.
        If "massivekb", this dictionary will include the modifications found in
        MassIVE-KB. Alternatively, a dictionary can be used to specify a custom
        collection of amino acids and masses. Note that ``residues`` is not an
        explicit constructor argument; the residues used during decoding are
        those of ``tokenizer``.
    max_charge : int
        The maximum precursor charge to consider.
    precursor_mass_tol : float, optional
        The maximum allowable precursor mass tolerance (in ppm) for correct
        predictions.
    isotope_error_range : Tuple[int, int]
        Take into account the error introduced by choosing a non-monoisotopic
        peak for fragmentation by not penalizing predicted precursor m/z's that
        fit the specified isotope error:
        `abs(calc_mz - (precursor_mz - isotope * 1.00335 / precursor_charge))
        < precursor_mass_tol`
    min_peptide_len : int
        The minimum length of predicted peptides.
    n_beams : int
        Number of beams used during beam search decoding.
    top_match : int
        Number of PSMs to return for each spectrum.
    n_log : int
        The number of epochs to wait between logging messages.
    train_label_smoothing : float
        Smoothing factor when calculating the training loss.
    warmup_iters : int
        The number of iterations for the linear warm-up of the learning rate.
    cosine_schedule_period_iters : int
        The number of iterations for the cosine half period of the learning rate.
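    out_writer : Optional[ms_io.MztabWriter]
        The output writer for the prediction results.
    calculate_precision : bool
        Calculate the validation set precision during training.
        This is expensive.
    tokenizer : Optional[PeptideTokenizer]
        Tokenizer object to tokenize and detokenize peptide sequences. If
        ``None``, a default ``PeptideTokenizer()`` is created.
    epoch_tracker : Optional[EpochTracker]
        Optional epoch tracker, stored as ``self.epoch_tracker``.
    **kwargs : Dict
        Additional keyword arguments passed to the Adam optimizer. Deprecated
        hyperparameters are removed from them.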
    """

    def __init__(
        self,
        dim_model: int = 512,
        n_head: int = 8,
        dim_feedforward: int = 1024,
        n_layers: int = 9,
        dropout: float = 0.0,
        dim_intensity: Optional[int] = None,
        max_length: int = 100,
        max_charge: int = 5,
        precursor_mass_tol: float = 50,
        isotope_error_range: Tuple[int, int] = (0, 1),
        min_peptide_len: int = 6,
        n_beams: int = 1,
        top_match: int = 1,
        n_log: int = 10,
        train_label_smoothing: float = 0.01,
        warmup_iters: int = 100_000,
        cosine_schedule_period_iters: int = 600_000,
        out_writer: Optional[ms_io.MztabWriter] = None,
        calculate_precision: bool = False,
        tokenizer: Optional[PeptideTokenizer] = None,
        epoch_tracker: Optional[EpochTracker] = None,
        **kwargs: Dict,
    ):
        """
        Build the encoder/decoder, the losses, and the decoding settings.

        See the class docstring for a description of all parameters.
        """
        super().__init__()
        self.save_hyperparameters()

        self.tokenizer = (
            tokenizer if tokenizer is not None else PeptideTokenizer()
        )
        # One output slot per tokenizer token plus one extra.
        # NOTE(review): the extra slot presumably corresponds to index 0,
        # which the losses below ignore as padding — confirm.
        self.vocab_size = len(self.tokenizer) + 1

        # Build the model.
        # NOTE(review): `dim_intensity` is accepted (and recorded by
        # `save_hyperparameters`) but is not forwarded to `SpectrumEncoder`.
        self.encoder = SpectrumEncoder(
            d_model=dim_model,
            n_head=n_head,
            dim_feedforward=dim_feedforward,
            n_layers=n_layers,
            dropout=dropout,
        )
        self.decoder = PeptideDecoder(
            d_model=dim_model,
            n_tokens=self.tokenizer,
            n_head=n_head,
            dim_feedforward=dim_feedforward,
            n_layers=n_layers,
            dropout=dropout,
            max_charge=max_charge,
        )

        self.mass_encoder = FloatEncoder(dim_model)
        self.dim_model = dim_model

        self.softmax = torch.nn.Softmax(2)
        # Token index 0 is excluded from both classification losses.
        ignore_index = 0
        self.celoss = torch.nn.CrossEntropyLoss(
            ignore_index=ignore_index, label_smoothing=train_label_smoothing
        )
        self.val_celoss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
        self.mse_loss = MaskedMSELoss()
        # Optimizer settings.
        self.warmup_iters = warmup_iters
        self.cosine_schedule_period_iters = cosine_schedule_period_iters
        # `kwargs` will contain additional arguments as well as unrecognized
        # arguments, including deprecated ones. Remove the deprecated ones and
        # warn only about those the caller actually supplied (previously a
        # warning was emitted for every deprecated name unconditionally).
        for k in config._config_deprecated:
            if k in kwargs:
                del kwargs[k]
                warnings.warn(
                    f"Deprecated hyperparameter '{k}' removed from the model.",
                    DeprecationWarning,
                )
        self.opt_kwargs = kwargs

        # Data properties.
        self.max_length = max_length
        self.precursor_mass_tol = precursor_mass_tol
        self.isotope_error_range = isotope_error_range
        self.min_peptide_len = min_peptide_len
        self.n_beams = n_beams
        self.top_match = top_match

        self.stop_token = self.tokenizer.stop_int

        # Logging.
        self.calculate_precision = calculate_precision
        self.n_log = n_log
        self._history = []

        # Output writer during predicting.
        self.out_writer = out_writer

        # Epoch tracking
        self.epoch_tracker = epoch_tracker

        # Plotting counting helpers -> Check if the first batch of validation loop is running
        self.last_val = 0
        self.current_val = -1

    @property
    def device(self) -> torch.device:
        """Device on which the model's first registered parameter lives."""
        first_param = next(iter(self.parameters()))
        return first_param.device

    @property
    def n_parameters(self):
        """Total number of elements across all trainable parameters."""
        total = 0
        for param in self.parameters():
            # Frozen parameters (requires_grad=False) are not counted.
            if param.requires_grad:
                total += param.numel()
        return total

    def forward(
        self, batch: dict
    ) -> List[List[Tuple[float, np.ndarray, str]]]:
        """
        Predict peptide sequences for a batch of MS/MS spectra.

        Parameters
        ----------
        batch : dict
            A batch dictionary, unpacked with ``_process_batch`` (which reads
            ``precursor_mz``, ``precursor_charge``, ``mz_array`` and
            ``intensity_array``, plus ``seq`` / ``delta_masses`` if present).

        Returns
        -------
        pred_peptides : List[List[Tuple[float, np.ndarray, str]]]
            The result of ``beam_search_decode``: for each spectrum, a list
            with its top peptide predictions.
        """
        spectra_mzs, spectra_ints, precursor_info, _seqs, _masses = (
            self._process_batch(batch)
        )
        return self.beam_search_decode(
            spectra_mzs, spectra_ints, precursor_info
        )

    def beam_search_decode(
        self, mzs: torch.Tensor, ints: torch.Tensor, precursors: torch.Tensor
    ) -> List[List[Tuple[float, np.ndarray, str]]]:
        """
        Beam search decoding of the spectrum predictions.

        At every step the decoder receives the mass predictions it produced
        at the previous steps (``pred_masses``) and returns token scores plus
        updated mass predictions for every still-active beam.

        Parameters
        ----------
        mzs : torch.Tensor of shape (n_spectra, n_peaks)
            The m/z axis of spectra for which to predict peptide sequences.
            Axis 0 represents an MS/MS spectrum, axis 1 contains the peaks in
            the MS/MS spectrum. These should be zero-padded,
            such that all the spectra in the batch are the same length.
        ints : torch.Tensor of shape (n_spectra, n_peaks)
            The intensity axis of spectra for which to predict peptide
            sequences, aligned peak-by-peak with ``mzs``. These should be
            zero-padded, such that all the spectra in the batch are the same
            length.
        precursors : torch.Tensor of size (n_spectra, 3)
            The measured precursor mass (axis 0), precursor charge (axis 1), and
            precursor m/z (axis 2) of each MS/MS spectrum.

        Returns
        -------
        pred_peptides : List[List[Tuple]]
            For each spectrum, a list with the top peptide prediction(s) as
            yielded by ``_get_top_peptide``: tuples of (peptide score, amino
            acid scores, predicted peptide tokens, position-wise scores, mass
            predictions).
        """
        memories, mem_masks = self.encoder(mzs, ints)

        # Sizes.
        batch = mzs.shape[0]  # B
        length = self.max_length + 1  # L
        vocab = self.vocab_size  # V
        beam = self.n_beams  # S

        # Initialize scores and tokens.
        scores = torch.full(
            size=(batch, length, vocab, beam), fill_value=torch.nan
        ).type_as(mzs)

        tokens = torch.zeros(
            batch, length, beam, dtype=torch.int64, device=self.encoder.device
        )

        # Create cache for decoded beams.
        pred_cache = collections.OrderedDict((i, []) for i in range(batch))

        # Get the first prediction. `pred_masses` stores the decoder's mass
        # predictions, which are fed back as decoder input at later steps.
        pred_masses = torch.zeros(
            batch, length, 1, device=self.encoder.device
        )
        pred, pred_masses[:, :1, :] = self.decoder(
            token_masses=torch.zeros(
                batch, 0, dtype=torch.int64, device=self.encoder.device
            ),
            memory=memories,
            memory_key_padding_mask=mem_masks,
            precursors=precursors,
        )

        tokens[:, 0, :] = torch.topk(pred[:, 0, :], beam, dim=1)[1]
        scores[:, :1, :, :] = einops.repeat(pred, "B L V -> B L V S", S=beam)

        # Make all tensors the right shape for decoding.
        precursors = einops.repeat(precursors, "B L -> (B S) L", S=beam)
        mem_masks = einops.repeat(mem_masks, "B L -> (B S) L", S=beam)
        memories = einops.repeat(memories, "B L V -> (B S) L V", S=beam)
        tokens = einops.rearrange(tokens, "B L S -> (B S) L")
        scores = einops.rearrange(scores, "B L V S -> (B S) L V")
        # The mass predictions are indexed per beam below (with the
        # (B * S)-sized `finished_beams` mask and in `_cache_finished_beams`),
        # so they need the same (B S) layout; otherwise decoding with
        # n_beams > 1 fails on a mask/shape mismatch.
        pred_masses = einops.repeat(
            pred_masses, "B L V -> (B S) L V", S=beam
        )

        # The main decoding loop.
        for step in range(0, self.max_length):
            # Terminate beams exceeding the precursor m/z tolerance and track
            # all finished beams (either terminated or stop token predicted).
            (
                finished_beams,
                beam_fits_precursor,
                discarded_beams,
            ) = self._finish_beams(tokens, precursors, step)
            # Cache peptide predictions from the finished beams (but not the
            # discarded beams).
            self._cache_finished_beams(
                tokens,
                scores,
                step,
                finished_beams & ~discarded_beams,
                beam_fits_precursor,
                pred_cache,
                pred_masses,
            )

            # Stop decoding when all current beams have been finished.
            # Continue with beams that have not been finished and not discarded.
            finished_beams |= discarded_beams
            if finished_beams.all():
                break
            active_beams = ~finished_beams
            # Update the scores and mass predictions of the active beams.
            (
                scores[active_beams, : step + 2, :],
                pred_masses[active_beams, : step + 2, :],
            ) = self.decoder(
                token_masses=pred_masses[active_beams, : step + 1],
                precursors=precursors[active_beams, :],
                memory=memories[active_beams, :, :],
                memory_key_padding_mask=mem_masks[active_beams, :],
            )
            # Find the top-k beams with the highest scores and continue decoding
            # those.
            # NOTE(review): `_get_topk_beams` reorders `tokens` and `scores`
            # across beams but not `pred_masses`; with n_beams > 1 the mass
            # predictions may no longer follow their beams — confirm.
            tokens, scores = self._get_topk_beams(
                tokens, scores, finished_beams, batch, step + 1
            )

        # Return the peptide with the highest confidence score, within the
        # precursor m/z tolerance if possible.
        return list(self._get_top_peptide(pred_cache))

    def _finish_beams(
        self,
        tokens: torch.Tensor,
        precursors: torch.Tensor,
        step: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Track all beams that have been finished, either by predicting the stop
        token or because they were terminated due to exceeding the precursor
        m/z tolerance.

        Parameters
        ----------
        tokens : torch.Tensor of shape (n_spectra * n_beams, max_length)
            Predicted amino acid tokens for all beams and all spectra.
        precursors : torch.Tensor of shape (n_spectra * n_beams, 3)
            Precursor information of the spectrum each beam belongs to; the
            precursor charge (column 1) and precursor m/z (column 2) are read
            here.
        step : int
            Index of the current decoding step.

        Returns
        -------
        finished_beams : torch.Tensor of shape (n_spectra * n_beams)
            Boolean tensor indicating whether the current beams have been
            finished.
        beam_fits_precursor: torch.Tensor of shape (n_spectra * n_beams)
            Boolean tensor indicating if current beams are within precursor m/z
            tolerance.
        discarded_beams : torch.Tensor of shape (n_spectra * n_beams)
            Boolean tensor indicating whether the current beams should be
            discarded (e.g. because they were predicted to end but violate the
            minimum peptide length).

        Notes
        -----
        Side effect: ``self.tokenizer.masses`` is re-cast with ``type_as`` to
        match the precursor tensor.
        """
        # Check for tokens with a negative mass (i.e. neutral loss).
        # Collect their token indices; they are tried as hypothetical next
        # residues when deciding whether an unfinished beam is already too
        # heavy.
        aa_neg_mass_idx = []
        for aa, mass in self.tokenizer.residues.items():
            if mass < 0:
                aa_neg_mass_idx.append(self.tokenizer.index[aa])

        # Find N-terminal residues.
        # (Tokens whose name starts with "+", "-", "[+" or "[-".)
        n_term = torch.Tensor(
            [
                self.tokenizer.index[aa]
                for aa in self.tokenizer.index
                if aa.startswith(("+", "-",'[+', '[-'))
            ]
        ).to(self.decoder.device)

        beam_fits_precursor = torch.zeros(
            tokens.shape[0], dtype=torch.bool
        ).to(self.encoder.device)
        # Beams with a stop token predicted in the current step can be finished.
        finished_beams = torch.zeros(tokens.shape[0], dtype=torch.bool).to(
            self.encoder.device
        )
        ends_stop_token = tokens[:, step] == self.stop_token
        finished_beams[ends_stop_token] = True
        # Beams with a dummy token predicted in the current step can be
        # discarded.
        discarded_beams = torch.zeros(
            tokens.shape[0], dtype=torch.bool
        ).to(self.encoder.device)

        # Token index 0 is the dummy/padding token.
        discarded_beams[tokens[:, step] == 0] = True
        # Discard beams with invalid modification combinations (i.e. N-terminal
        # modifications occur multiple times or in internal positions).
        if step > 1:  # Only relevant for longer predictions.
            dim0 = torch.arange(tokens.shape[0])
            # Position of the last residue: the current step, or one earlier
            # when the current token is the stop token.
            final_pos = torch.full((ends_stop_token.shape[0],), step)
            final_pos[ends_stop_token] = step - 1
            # Multiple N-terminal modifications.
            multiple_mods = torch.isin(
                tokens[dim0, final_pos], n_term
            ) & torch.isin(tokens[dim0, final_pos - 1], n_term)
            # N-terminal modifications occur at an internal position.
            # Broadcasting trick to create a two-dimensional mask.
            mask = (final_pos - 1)[:, None] >= torch.arange(tokens.shape[1])
            internal_mods = torch.isin(
                torch.where(mask.to(self.encoder.device), tokens, 0), n_term
            ).any(dim=1)
            discarded_beams[multiple_mods | internal_mods] = True

        # Check which beams should be terminated or discarded based on the
        # predicted peptide.
        for i in range(len(finished_beams)):
            # Skip already discarded beams.
            if discarded_beams[i]:
                continue
            pred_tokens = tokens[i][: step + 1]
            peptide_len = len(pred_tokens)

            # Omit stop token.
            # (Checked at the first position for a reversed tokenizer and at
            # the last position otherwise.)
            if self.tokenizer.reverse and pred_tokens[0] == self.stop_token:
                pred_tokens = pred_tokens[1:]
                peptide_len -= 1
            elif not self.tokenizer.reverse and pred_tokens[-1] == self.stop_token:
                pred_tokens = pred_tokens[:-1]
                peptide_len -= 1
            # Discard beams that were predicted to end but don't fit the minimum
            # peptide length.
            if finished_beams[i] and peptide_len < self.min_peptide_len:
                discarded_beams[i] = True
                continue
            # Terminate the beam if it has not been finished by the model but
            # the peptide mass exceeds the precursor m/z to an extent that it
            # cannot be corrected anymore by a subsequently predicted AA with
            # negative mass.
            precursor_charge = precursors[i, 1]
            precursor_mz = precursors[i, 2]
            matches_precursor_mz = exceeds_precursor_mz = False

            # Send tokenizer masses to correct device for calculate_precursor_ions()
            # NOTE(review): this mutates the shared tokenizer on every
            # iteration; it could likely be hoisted out of the loop.
            self.tokenizer.masses = self.tokenizer.masses.type_as(precursor_mz)

            # Finished beams are evaluated as-is (aa=None); unfinished beams
            # are evaluated with each negative-mass token appended.
            for aa in [None] if finished_beams[i] else aa_neg_mass_idx:
                if aa is None:
                    calc_peptide = pred_tokens
                else:
                    calc_peptide = pred_tokens.detach().clone()
                    calc_peptide = torch.cat(
                        (calc_peptide,
                         torch.tensor([aa]).type_as(calc_peptide)
                        )
                    )
                try:

                    calc_mz = self.tokenizer.calculate_precursor_ions(
                        calc_peptide.unsqueeze(0),
                        precursor_charge.unsqueeze(0)
                    )[0]

                    # Mass error (ppm) for every allowed isotope offset.
                    delta_mass_ppm = [
                        _calc_mass_error(
                            calc_mz,
                            precursor_mz,
                            precursor_charge,
                            isotope,
                        )
                        for isotope in range(
                            self.isotope_error_range[0],
                            self.isotope_error_range[1] + 1,
                        )
                    ]
                    # Terminate the beam if the calculated m/z for the predicted
                    # peptide (without potential additional AAs with negative
                    # mass) is within the precursor m/z tolerance.
                    matches_precursor_mz = aa is None and any(
                        abs(d) < self.precursor_mass_tol
                        for d in delta_mass_ppm
                    )
                    # Terminate the beam if the calculated m/z exceeds the
                    # precursor m/z + tolerance and hasn't been corrected by a
                    # subsequently predicted AA with negative mass.
                    if matches_precursor_mz:
                        exceeds_precursor_mz = False
                    else:
                        exceeds_precursor_mz = all(
                            d > self.precursor_mass_tol for d in delta_mass_ppm
                        )
                        exceeds_precursor_mz = (
                            finished_beams[i] or aa is not None
                        ) and exceeds_precursor_mz
                    if matches_precursor_mz or exceeds_precursor_mz:
                        break
                except KeyError:
                    matches_precursor_mz = exceeds_precursor_mz = False
            # Finish beams that fit or exceed the precursor m/z.
            # Don't finish beams that don't include a stop token if they don't
            # exceed the precursor m/z tolerance yet.
            if finished_beams[i]:
                beam_fits_precursor[i] = matches_precursor_mz
            elif exceeds_precursor_mz:
                finished_beams[i] = True
                beam_fits_precursor[i] = matches_precursor_mz
        return finished_beams, beam_fits_precursor, discarded_beams

    def _cache_finished_beams(
        self,
        tokens: torch.Tensor,
        scores: torch.Tensor,
        step: int,
        beams_to_cache: torch.Tensor,
        beam_fits_precursor: torch.Tensor,
        pred_cache: Dict[
            int, List[Tuple[float, float, np.ndarray, torch.Tensor]]
        ],
        emb: torch.Tensor,  # pass predicted embeddings for writing
    ):
        """
        Cache terminated beams.

        Parameters
        ----------
        tokens : torch.Tensor of shape (n_spectra * n_beams, max_length)
            Predicted amino acid tokens for all beams and all spectra.
        scores : torch.Tensor of shape
        (n_spectra *  n_beams, max_length, n_amino_acids)
            Scores for the predicted amino acid tokens for all beams and all
            spectra.
        step : int
            Index of the current decoding step.
        beams_to_cache : torch.Tensor of shape (n_spectra * n_beams)
            Boolean tensor indicating whether the current beams are ready for
            caching.
        beam_fits_precursor: torch.Tensor of shape (n_spectra * n_beams)
            Boolean tensor indicating whether the beams are within the
            precursor m/z tolerance.
        pred_cache : Dict[
                int, List[Tuple[float, float, np.ndarray, torch.Tensor]]
        ]
            Priority queue with finished beams for each spectrum (at most
            ``n_beams`` entries), ordered by peptide score. Each entry is a
            tuple of (peptide score, random tie-breaking float, amino
            acid-level scores, predicted peptide tokens without the stop
            token, scaled position-wise scores, ``emb`` slice of the beam).
        emb : torch.Tensor
            Per-beam tensor whose rows are indexed like ``tokens``; the first
            ``step + 1`` positions of a cached beam's row are stored (as a
            copy) with its cache entry.
        """
        # Position of the predicted peptide tokens inside a cached tuple.
        peptide_pos = 3
        for i in range(len(beams_to_cache)):
            if not beams_to_cache[i]:
                continue
            # Find the starting index of the spectrum.
            spec_idx = i // self.n_beams
            # FIXME: The next 3 lines are very similar as what's done in
            #  _finish_beams. Avoid code duplication?
            pred_tokens = tokens[i][: step + 1]
            # Omit the stop token from the peptide sequence (if predicted).
            has_stop_token = pred_tokens[-1] == self.stop_token
            pred_peptide = pred_tokens[:-1] if has_stop_token else pred_tokens
            # Don't cache this peptide if it was already predicted previously.
            # Compare against the stored peptide tokens; the last tuple element
            # is the `emb` slice, so comparing with it never matched.
            if any(
                torch.equal(pred_cached[peptide_pos], pred_peptide)
                for pred_cached in pred_cache[spec_idx]
            ):
                # TODO: Add duplicate predictions with their highest score.
                continue
            # NOTE(review): the reason for the 2 / dim_model scaling is not
            # documented here.
            scaled_scores = scores[i : i + 1, : step + 1, :] * (
                2 / self.dim_model
            )
            aa_scores = scaled_scores[
                0, range(len(pred_tokens)), pred_tokens
            ].tolist()

            # Add an explicit score 0 for the missing stop token in case this
            # was not predicted (i.e. early stopping).
            if not has_stop_token:
                aa_scores.append(0)
            aa_scores = np.asarray(aa_scores)
            # Calculate the updated amino acid-level and the peptide scores.
            aa_scores, peptide_score = _aa_pep_score(
                aa_scores, beam_fits_precursor[i]
            )
            # Omit the stop token from the amino acid-level scores.
            aa_scores = aa_scores[:-1]
            # Add the prediction to the cache (minimum priority queue, maximum
            # the number of beams elements).
            if len(pred_cache[spec_idx]) < self.n_beams:
                heapadd = heapq.heappush
            else:
                heapadd = heapq.heappushpop
            heapadd(
                pred_cache[spec_idx],
                (
                    peptide_score,
                    np.random.random_sample(),
                    aa_scores,
                    torch.clone(pred_peptide),
                    scaled_scores,
                    # Copy the slice: the caller keeps writing into `emb` in
                    # place (beam_search_decode updates pred_masses), so a view
                    # could change after it has been cached.
                    emb[i : i + 1, : step + 1, :].clone(),
                ),
            )

    def _get_topk_beams(
        self,
        tokens: torch.tensor,
        scores: torch.tensor,
        finished_beams: torch.tensor,
        batch: int,
        step: int,
    ) -> Tuple[torch.tensor, torch.tensor]:
        """
        Find the top-k beams with the highest scores and continue decoding
        those.

        Stop decoding for beams that have been finished.

        Every (source beam, next token) pair is a candidate. Candidates are
        ranked by the NaN-ignoring mean, over positions ``0..step``, of the
        scores of the candidate's tokens, after multiplying by a mask that
        suppresses finished beams and the padding token.

        Parameters
        ----------
        tokens : torch.Tensor of shape (n_spectra * n_beams, max_length)
            Predicted amino acid tokens for all beams and all spectra.
        scores : torch.Tensor of shape
        (n_spectra *  n_beams, max_length, n_amino_acids)
            Scores for the predicted amino acid tokens for all beams and all
            spectra.
        finished_beams : torch.Tensor of shape (n_spectra * n_beams)
            Boolean tensor indicating whether the current beams have been
            finished (or discarded).
        batch: int
            Number of spectra in the batch.
        step : int
            Index of the next decoding step.

        Returns
        -------
        tokens : torch.Tensor of shape (n_spectra * n_beams, max_length)
            Predicted amino acid tokens for all beams and all spectra.
        scores : torch.Tensor of shape
        (n_spectra *  n_beams, max_length, n_amino_acids)
            Scores for the predicted amino acid tokens for all beams and all
            spectra.
        """
        beam = self.n_beams  # S
        vocab = self.vocab_size  # V

        # Reshape to group by spectrum (B for "batch").
        tokens = einops.rearrange(tokens, "(B S) L -> B L S", S=beam)
        scores = einops.rearrange(scores, "(B S) L V -> B L V S", S=beam)

        # Get the previous tokens and scores.
        prev_tokens = einops.repeat(
            tokens[:, :step, :], "B L S -> B L V S", V=vocab
        )
        prev_scores = torch.gather(
            scores[:, :step, :, :], dim=2, index=prev_tokens
        )
        # After the gather every V entry holds the same value (the score of
        # the token chosen at that position), so keep one and tile it over the
        # flattened (V S) candidate axis.
        prev_scores = einops.repeat(
            prev_scores[:, :, 0, :], "B L S -> B L (V S)", V=vocab
        )

        # Get the scores for all possible beams at this step.
        step_scores = torch.zeros(batch, step + 1, beam * vocab).type_as(
            scores
        )
        step_scores[:, :step, :] = prev_scores
        step_scores[:, step, :] = einops.rearrange(
            scores[:, step, :, :], "B V S -> B (V S)"
        )

        # Find all still active beams by masking out terminated beams.
        # Layout matches the (V S) flattening: column v * beam + s.
        active_mask = (
            ~finished_beams.reshape(batch, beam).repeat(1, vocab)
        ).float()
        # Mask out the index '0', i.e. padding token, by default.
        # FIXME: Set this to a very small, yet non-zero value, to only
        # get padding after stop token.
        active_mask[:, :beam] = 1e-8

        # Figure out the top K decodings.
        # NOTE(review): masking is multiplicative; if the scores can be
        # negative (e.g. log-probabilities), masked candidates become 0 and
        # may outrank active ones — confirm the score range.
        _, top_idx = torch.topk(step_scores.nanmean(dim=1) * active_mask, beam)
        # Split each flat candidate index back into (token, source beam).
        v_idx, s_idx = np.unravel_index(top_idx.cpu(), (vocab, beam))
        s_idx = einops.rearrange(s_idx, "B S -> (B S)")
        b_idx = einops.repeat(torch.arange(batch), "B -> (B S)", S=beam)

        # Record the top K decodings.
        # New beams copy the history of their source beam and append the
        # selected token.
        tokens[:, :step, :] = einops.rearrange(
            prev_tokens[b_idx, :, 0, s_idx], "(B S) L -> B L S", S=beam
        )
        tokens[:, step, :] = torch.tensor(v_idx)
        scores[:, : step + 1, :, :] = einops.rearrange(
            scores[b_idx, : step + 1, :, s_idx], "(B S) L V -> B L V S", S=beam
        )
        scores = einops.rearrange(scores, "B L V S -> (B S) L V")
        tokens = einops.rearrange(tokens, "B L S -> (B S) L")
        return tokens, scores

    def _get_top_peptide(
        self,
        pred_cache: Dict[
            int, List[Tuple[float, float, np.ndarray, torch.Tensor]]
        ],
    ) -> Iterable[List[Tuple[float, np.ndarray, str]]]:
        """
        Yield the best-scoring cached predictions for every spectrum.

        Parameters
        ----------
        pred_cache : Dict[
                int, List[Tuple[float, float, np.ndarray, torch.Tensor]]
        ]
            Priority queue with finished beams for each spectrum. Each entry
            is a 6-tuple of (peptide score, random tie-breaking float, amino
            acid-level scores, predicted tokens, position-wise scores, mass
            predictions).

        Yields
        ------
        List[Tuple]
            Per spectrum (in cache order), up to ``top_match`` entries sorted
            from best to worst, each reduced to (peptide score, amino
            acid-level scores, predicted tokens, position-wise scores, mass
            predictions). An empty list is yielded for spectra without any
            cached prediction.
        """
        for cached in pred_cache.values():
            if len(cached) == 0:
                yield []
                continue
            ranked = heapq.nlargest(self.top_match, cached)
            selected = []
            for entry in ranked:
                # Drop the tie-breaking random number from each entry.
                score, _tiebreak, aa_level, peptide, positional, masses = entry
                selected.append((score, aa_level, peptide, positional, masses))
            yield selected

    def _process_batch(self, batch):
        """
        Prepare a batch returned from AnnotatedSpectrumDataset of the latest
        depthcharge version.

        Each batch is a dict and contains these keys:
             ['peak_file', 'scan_id', 'ms_level', 'precursor_mz',
             'precursor_charge', 'mz_array', 'intensity_array',
             'seq']
        and optionally 'delta_masses'.

        Note that ``batch`` is modified in place: every value that can be
        squeezed along its first dimension is replaced by the squeezed value.

        Parameters
        ----------
        batch : dict
            The batch to unpack.

        Returns
        -------
        mzs : torch.Tensor
            The (squeezed) ``mz_array`` values.
        ints : torch.Tensor
            The (squeezed) ``intensity_array`` values.
        precursors : torch.Tensor of shape (batch_size, 3)
            A tensor with the precursor neutral mass, precursor charge, and
            precursor m/z.
        seqs : torch.Tensor or None
            The ``seq`` entry (token indices used to look up token masses), or
            ``None`` if the batch has no ``seq`` key.
        token_masses : torch.Tensor or None
            ``delta_masses`` moved to the model device if present; otherwise
            the tokenizer masses of ``seqs``; ``None`` if neither is available.
        """
        # Squeeze the first dimension of every array-like value. Values that
        # cannot be squeezed this way (objects without ``squeeze``, or NumPy
        # arrays whose first axis is not of length one) are kept as-is. Only
        # these errors are tolerated, unlike the former bare ``except:``,
        # which also swallowed KeyboardInterrupt/SystemExit.
        for k in batch.keys():
            try:
                batch[k] = batch[k].squeeze(0)
            except (AttributeError, TypeError, ValueError):
                continue

        precursor_mzs = batch["precursor_mz"]
        precursor_charges = batch["precursor_charge"]
        # Neutral mass from m/z: remove one proton (1.007276 Da) per charge.
        precursor_masses = (precursor_mzs - 1.007276) * precursor_charges
        precursors = torch.vstack(
            [precursor_masses, precursor_charges, precursor_mzs]
        ).T

        mzs, ints = batch["mz_array"], batch["intensity_array"]

        seqs = batch["seq"] if "seq" in batch else None
        token_masses = (
            batch["delta_masses"].to(self.device)
            if "delta_masses" in batch
            else None
        )

        # Without explicit masses, look up the mass of every token in `seqs`.
        if token_masses is None and seqs is not None:
            mass_tensor = self.tokenizer.masses.detach().clone().to(self.device)
            token_masses = mass_tensor[seqs]

        return mzs, ints, precursors, seqs, token_masses

    def _forward_step_enc(self, batch):
        """
        Prepare ``batch`` and encode its spectra.

        Returns
        -------
        tuple
            ``(memories, mem_masks, precursors, tokens, true_masses)``. The
            first two are the outputs of ``self.encoder``. The rest come from
            ``_process_batch``.
        """
        mz_values, intensities, prec, seq_tokens, target_masses = (
            self._process_batch(batch)
        )
        encoded, padding_mask = self.encoder(mz_values, intensities)
        return encoded, padding_mask, prec, seq_tokens, target_masses

    def _forward_step_dec(
        self,
        memories,
        mem_masks,
        precursors,
        true_masses,
    ):
        """
        Run the peptide decoder on the encoded spectra.

        The decoder receives every true token mass except the one at the last
        position (``true_masses[:, :-1]``).

        Returns
        -------
        tuple
            ``(decoded, pred_token_masses)`` as returned by ``self.decoder``.
        """
        decoder_input = true_masses[:, :-1]
        decoded, pred_token_masses = self.decoder(
            token_masses=decoder_input,
            memory=memories,
            memory_key_padding_mask=mem_masks,
            precursors=precursors,
        )
        return decoded, pred_token_masses

    def _forward_step(
        self,
        batch,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """
        The forward learning step: encode the spectra, then decode.

        Parameters
        ----------
        batch : dict
            A batch dictionary as accepted by ``_process_batch``.

        Returns
        -------
        decoded : torch.Tensor
            The first output of ``self.decoder``.
        pred_token_masses : torch.Tensor
            The predicted token masses (second output of ``self.decoder``).
        tokens : Optional[torch.Tensor]
            The tokenized peptide sequences, or ``None`` if the batch has none.
        true_masses : torch.Tensor
            The true token masses the decoder was conditioned on.
        """
        memories, mem_masks, precursors, tokens, true_masses = (
            self._forward_step_enc(batch)
        )
        decoded, pred_token_masses = self._forward_step_dec(
            memories, mem_masks, precursors, true_masses
        )
        return decoded, pred_token_masses, tokens, true_masses

    def training_step(
        self,
        batch_es: dict,
        *args,
        mode: str = "train",
    ) -> Optional[torch.Tensor]:
        """
        A single training step. ``validation_step`` also uses it.

        Parameters
        ----------
        batch_es : dict
            Either a single batch dict, recognized by its ``"peak_file"`` key,
            or a dict with exactly two entries. Each entry maps a batch-type
            name to a batch dict, or to ``None`` when that dataloader is
            already exhausted.
        mode : str
            Logging key to describe the current stage. If it contains
            ``"train"``, the loss is logged per step as ``train_MSELoss``.
            Otherwise it is logged per epoch as ``validation_dev_loss``.

        Returns
        -------
        Optional[torch.Tensor]
            The MSE between predicted and true token masses, ignoring positions
            whose true mass is 0. With several batches, this is the batch-size
            weighted mean of the individual losses. ``None`` if every batch was
            ``None``.

        Raises
        ------
        RuntimeError
            If ``batch_es`` is neither a single batch nor has exactly two
            entries.
        """
        if isinstance(batch_es, dict) and "peak_file" in batch_es:
            batch_es = {"": batch_es}  # the batch is a single batch
        elif len(batch_es) != 2:
            raise RuntimeError(
                f"The step batch was a dict but did not have exactly 2 entries! {batch_es}"
            )

        losses, batch_sizes = [], []
        pred_masses = true_masses = None
        for batch in batch_es.values():
            if batch is None:  # one dataloader may be exhausted before the other
                continue
            _, pred_masses, _, true_masses = self._forward_step(batch)

            # Predicted and true peptides need not have the same length.
            # Zero-pad whichever is shorter along the sequence dimension.
            n_pad = abs(true_masses.shape[1] - pred_masses.shape[1])
            if true_masses.shape[1] < pred_masses.shape[1]:
                true_masses = torch.nn.functional.pad(
                    true_masses, (0, n_pad), mode="constant", value=0
                )
            else:
                # Pads the second-to-last dimension, i.e. appends all-zero mass
                # encodings. Assumes pred_masses is (batch, length, 1).
                # TODO confirm against the decoder output shape.
                pred_masses = torch.nn.functional.pad(
                    pred_masses, (0, 0, 0, n_pad), mode="constant", value=0
                )

            pred_masses = pred_masses.squeeze(-1).squeeze(-1)
            mask = true_masses == 0  # a zero true mass marks padding
            mse_loss = torch.mean((pred_masses[~mask] - true_masses[~mask]) ** 2)
            batch_size = pred_masses.shape[0]
            losses.append(mse_loss)
            batch_sizes.append(batch_size)

            if "train" in mode:
                self.log(
                    "train_MSELoss",
                    mse_loss.detach(),
                    on_step=True,
                    on_epoch=False,
                    sync_dist=True,
                    batch_size=batch_size,
                )
            else:
                self.log(
                    "validation_dev_loss",
                    mse_loss.detach(),
                    on_step=False,
                    on_epoch=True,
                    sync_dist=True,
                    batch_size=batch_size,
                )

        # Plot only during validation, once per validation epoch, using the
        # last processed batch.
        if (
            self.logger is not None
            and hasattr(self.logger, "experiment")
            and "valid" in mode
            and pred_masses is not None
            and self.current_val != self.last_val
        ):
            self.last_val = self.current_val
            self._log_mass_scatter_plot(pred_masses, true_masses)

        if not losses:
            return None  # Lightning skips the optimization step on None
        if len(losses) == 1:
            return losses[0]
        total = sum(batch_sizes)
        return sum(
            loss * (size / total) for loss, size in zip(losses, batch_sizes)
        )

    def _log_mass_scatter_plot(
        self, pred_masses: torch.Tensor, true_masses: torch.Tensor
    ) -> None:
        """Log a predicted-vs-true token mass scatter plot to the experiment logger."""
        preds_flat = pred_masses.detach().cpu().numpy().flatten()
        true_flat = true_masses.detach().cpu().numpy().flatten()
        # Same padding criterion as the loss mask in training_step.
        keep = true_flat != 0

        fig, ax = plt.subplots()
        try:
            ax.scatter(true_flat[keep], preds_flat[keep], alpha=0.5)
            ax.set_xlabel("True Mass Deltas")
            ax.set_ylabel("Predicted Masses")
            ax.set_title("Predicted vs. True Mass Deltas (All Positions Combined)")
            ax.grid()
            with BytesIO() as buf:
                fig.savefig(buf, format="jpeg")
                buf.seek(0)
                image = Image.open(buf).convert("RGB")
        finally:
            plt.close(fig)  # always free the figure, even if saving fails

        image_np = np.array(image, dtype=np.uint8)
        image_tensor = torch.from_numpy(image_np).permute(2, 0, 1)  # HWC -> CHW
        self.logger.experiment.add_image(
            "Scatter/pred_vs_true_all_positions", image_tensor, self.global_step
        )

    def validation_step(
        self, batch_es: dict | Tuple[torch.Tensor, torch.Tensor, List[str]], *args
    ) -> torch.Tensor:
        """
        A single validation step.

        Parameters
        ----------
        batch_es : dict | Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[str]]
            A batch (or a dict of batches) of
            (i) m/z values of MS/MS spectra, 
            (ii) intensity values of MS/MS spectra,
            (iii) precursor information, 
            (iv) peptide sequences as torch Tensors.

        Returns
        -------
        torch.Tensor
            The loss of the validation step.
        """
        losses = []
        seen_coeff = 0.2

        if type(batch_es) == dict and "peak_file" in batch_es: # the batch is a single batch
            batch_es = {'': batch_es}
        elif len(batch_es) != 2:
            raise RuntimeError(f"The validation batch was a dict but with more than 2 entries! {batch_es}")

        for b_type, batch in batch_es.items():
            # Record the loss.
            if batch is None: # Can happen in episodic training when on Dataloader is exhausted before the other
                continue
            losses += [self.training_step(batch, mode=f"valid{b_type}")]
            # print(f"Loss for {b_type}: {losses[-1]}")

            if self.calculate_precision:
                # Calculate and log amino acid and peptide match evaluation metrics from
                # the predicted peptides.
                peptides_true = [''.join(p) for p in self.tokenizer.detokenize(batch['seq'], join=False)]
                peptides_pred = []
                for spectrum_preds in self.forward(batch):
                    for _, _, pred, _, _ in spectrum_preds:
                        peptides_pred.append(pred)
                peptides_pred = [''.join(p) for p in self.tokenizer.detokenize(peptides_pred, join=False)]
                batch_size = len(peptides_true)
                aa_precision, _, pep_precision = evaluate.aa_match_metrics(
                    *evaluate.aa_match_batch(
                        peptides_true,
                        peptides_pred,
                        self.tokenizer.residues,
                    )
                )
                
                log_args = dict(on_step=False, on_epoch=True, sync_dist=True)
                self.log(
                    f"pep_precision{b_type}",
                    pep_precision,
                    **log_args,
                    batch_size=batch_size
                )
                self.log(
                    f"aa_precision{b_type}",
                    aa_precision,
                    **log_args,
                    batch_size=batch_size
                )

        if len(losses) == 2:
            loss = losses[0]*seen_coeff + losses[1]*(1-seen_coeff)

            try:
                batch['seq']= batch['seq'].squeeze(0)
            except:
                pass
            batch_size = len(batch['seq'])
            
            self.log(
                    f"valid_MSELoss",
                    loss.detach(),
                    on_step=False,
                    on_epoch=True,
                    sync_dist=True,
                    batch_size=batch_size
                )
        else:
            loss = losses[0]
        return loss


    def predict_step(
        self, batch: Dict[str, Any], *args
    ) -> List[Tuple[Any, ...]]:
        """
        A single prediction step.

        Parameters
        ----------
        batch : Dict[str, Any]
            A batch dictionary as accepted by ``_process_batch``. It must also
            contain ``"scans"``, ``"title"`` and ``"peak_file"`` entries with
            one value per spectrum.

        Returns
        -------
        predictions : List[Tuple[Any, ...]]
            One tuple for every peptide that ``forward`` yields for a spectrum:
            ``(scan, precursor_charge, precursor_mz, peptide_tokens,
            peptide_score, aa_scores, file_name, true_seq, title, None,
            pred_masses)``. ``true_seq`` is the detokenized true peptide, or
            ``""`` when the batch has no sequences. The ``None`` slot is a
            placeholder for position-wise scores, which are not kept.
        """
        _, _, precursors, true_seqs, _ = self._process_batch(batch)
        if true_seqs is None:
            true_peptides = [""] * precursors.shape[0]
        else:
            true_peptides = [
                "".join(p)
                for p in self.tokenizer.detokenize(true_seqs, join=False)
            ]

        prec_charges = precursors[:, 1].cpu().detach().numpy()
        prec_mzs = precursors[:, 2].cpu().detach().numpy()

        predictions = []
        for (
            precursor_charge,
            precursor_mz,
            scan,
            title,
            file_name,
            true_seq,
            spectrum_preds,
        ) in zip(
            prec_charges,
            prec_mzs,
            batch["scans"],
            batch["title"],
            batch["peak_file"],
            true_peptides,
            self.forward(batch),
        ):
            for peptide_score, aa_scores, peptide, _, pred_masses in spectrum_preds:
                predictions.append(
                    (
                        scan,
                        precursor_charge,
                        precursor_mz,
                        peptide,
                        peptide_score,
                        aa_scores,
                        file_name,
                        true_seq,
                        title,
                        None,  # position-wise scores are not kept
                        pred_masses,
                    )
                )

        return predictions
    
    def on_train_epoch_start(self):
        """Advance the epoch tracker, then run the parent hook and return its result."""
        # The dataloaders use the tracked epoch for episodic training.
        tracker = self.epoch_tracker
        tracker.increase()
        parent_result = super().on_train_epoch_start()
        return parent_result

    def on_train_epoch_end(self) -> None:
        """
        Record the latest training loss and print the history.

        Reads ``train_MSELoss`` from the trainer's callback metrics, or uses
        NaN if it has not been logged. Appends it with the global step to
        ``self._history``, then calls ``_log_history``.
        """
        available = self.trainer.callback_metrics
        loss_value = (
            available["train_MSELoss"].detach().item()
            if "train_MSELoss" in available
            else np.nan
        )
        self._history.append(
            {"step": self.trainer.global_step, "train": loss_value}
        )
        self._log_history()

    def on_validation_epoch_end(self) -> None:
        """
        Log the validation metrics at the end of each epoch.

        Appends the global step and ``valid_MSELoss`` to ``self._history``. If
        ``calculate_precision`` is set, it also appends the ``aa_precision``
        and ``pep_precision`` metrics. It then prints the history via
        ``_log_history`` and increments ``self.current_val``.

        A metric that was not logged this epoch is recorded as NaN instead of
        raising ``KeyError``. For example, ``validation_step`` logs precision
        under a batch-type suffix when given two batches.
        """
        callback_metrics = self.trainer.callback_metrics

        def _metric(name: str) -> float:
            # Missing metrics become NaN.
            value = callback_metrics.get(name)
            return np.nan if value is None else value.detach().item()

        metrics = {
            "step": self.trainer.global_step,
            "valid": _metric("valid_MSELoss"),
        }
        if self.calculate_precision:
            metrics["valid_aa_precision"] = _metric("aa_precision")
            metrics["valid_pep_precision"] = _metric("pep_precision")

        self._history.append(metrics)
        self._log_history()

        self.current_val += 1

    def on_predict_batch_end(
        self,
        outputs: List[Tuple[Any, ...]],
        *args,
    ) -> None:
        """
        Write the predicted peptide sequences and amino acid scores to the
        output writer.

        Parameters
        ----------
        outputs : List[Tuple[Any, ...]]
            The flat list of prediction tuples returned by ``predict_step``:
            ``(scan, charge, precursor_mz, peptide_tokens, peptide_score,
            aa_scores, file_name, true_seq, title, pos_wise_scores,
            pred_masses)``.

        Notes
        -----
        Predictions with an empty peptide are skipped. Nothing is written when
        ``self.out_writer`` is ``None``.
        """
        if self.out_writer is None:
            return
        for (
            scan,
            charge,
            precursor_mz,
            peptide,
            peptide_score,
            aa_scores,
            file_name,
            true_seq,
            title,
            _pos_wise_scores,
            pred_masses,
        ) in outputs:
            if len(peptide) == 0:
                continue

            # Compute the precursor mass from the tokens, then detokenize.
            tokens = peptide.unsqueeze(0)
            calc_mass = self.tokenizer.calculate_precursor_ions(
                tokens,
                torch.tensor([charge]).type_as(peptide),
            )[0]
            sequence = "".join(self.tokenizer.detokenize(tokens, join=False)[0])

            # TODO: reverse aa_scores if the tokenizer reverses sequences.
            self.out_writer.psms.append(
                (
                    sequence,
                    scan,
                    peptide_score,
                    charge,
                    precursor_mz,
                    calc_mass,
                    ",".join(map("{:.5f}".format, aa_scores)),
                    file_name,
                    true_seq,
                    title,
                    pred_masses,
                )
            )

    def on_train_start(self):
        """Log the optimizer warm-up and cosine-period settings as hyperparameters."""
        settings = (
            ("hp/optimizer_warmup_iters", self.warmup_iters),
            ("hp/optimizer_cosine_schedule_period_iters", self.cosine_schedule_period_iters),
        )
        for metric_name, metric_value in settings:
            self.log(metric_name, metric_value)

    def _log_history(self) -> None:
        """
        Print the most recent history entry to the console logger.

        Nothing happens while the history is empty. A tab-separated header is
        printed when the history holds exactly one entry. The latest entry is
        printed only when its step is a multiple of ``self.n_log``. Missing
        values are shown as NaN.
        """
        history = self._history
        if not history:
            return
        if len(history) == 1:
            columns = "Step\tTrain loss\tValid loss\t"
            if self.calculate_precision:
                columns += "Peptide precision\tAA precision"
            logger.info(columns)

        latest = history[-1]  # TODO fix
        if latest["step"] % self.n_log != 0:
            return

        template = "%i\t%.6f\t%.6f"
        values = [
            latest["step"],
            latest.get("train", np.nan),
            latest.get("valid", np.nan),
        ]
        if self.calculate_precision:
            template += "\t%.6f\t%.6f"
            values.extend(
                (
                    latest.get("valid_pep_precision", np.nan),
                    latest.get("valid_aa_precision", np.nan),
                )
            )
        logger.info(template, *values)

    def configure_optimizers(
        self,
    ) -> Tuple[torch.optim.Optimizer, Dict[str, Any]]:
        """
        Create the Adam optimizer and its warm-up/cosine learning rate schedule.

        Returns
        -------
        Tuple[List[torch.optim.Optimizer], Dict[str, Any]]
            A one-element optimizer list and a Lightning scheduler config. The
            config steps a ``CosineWarmupScheduler`` after every optimizer
            step.
        """
        adam = torch.optim.Adam(self.parameters(), **self.opt_kwargs)
        scheduler_config = {
            "scheduler": CosineWarmupScheduler(
                adam, self.warmup_iters, self.cosine_schedule_period_iters
            ),
            "interval": "step",  # step the schedule per iteration, not per epoch
        }
        return [adam], scheduler_config


class MaskedMSELoss(torch.nn.Module):
    """
    Mean squared error restricted to unmasked positions.

    Positions where ``mask`` is ``True`` are excluded from the loss.
    """

    def forward(self, pred, target, mask, reduction="mean"):
        """
        Compute the squared error between ``target`` and ``pred`` where ``~mask``.

        Parameters
        ----------
        pred, target : torch.Tensor
            Predictions and targets of the same shape.
        mask : torch.Tensor
            Boolean tensor. ``True`` marks positions to ignore.
        reduction : str
            ``"mean"`` (default) averages the squared errors and ``"sum"``
            sums them. ``"none"`` returns the per-element squared errors. The
            legacy spelling ``"None"`` and ``None`` are also accepted for it.

        Returns
        -------
        torch.Tensor
            The reduced loss, or the 1-D tensor of unmasked squared errors.

        Raises
        ------
        ValueError
            If ``reduction`` is not one of the supported values.
        """
        keep = ~mask
        squared_error = (target[keep] - pred[keep]) ** 2
        if reduction == "mean":
            return squared_error.mean()
        if reduction == "sum":
            return squared_error.sum()
        if reduction is None or str(reduction).lower() == "none":
            return squared_error
        raise ValueError(
            f"Unsupported reduction {reduction!r}; expected 'mean', 'sum' or 'none'."
        )


class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """
    Learning rate scheduler with linear warm-up followed by cosine shaped decay.

    The learning rate factor at iteration ``it`` is
    ``0.5 * (1 + cos(pi * it / cosine_schedule_period_iters))``. While
    ``it <= warmup_iters``, the factor is additionally multiplied by
    ``it / warmup_iters``. With ``warmup_iters == 0`` there is no warm-up.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer object.
    warmup_iters : int
        The number of iterations for the linear warm-up of the learning rate.
    cosine_schedule_period_iters : int
        The number of iterations for the cosine half period of the learning rate.

    Raises
    ------
    ValueError
        If ``cosine_schedule_period_iters`` is not positive.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_iters: int,
        cosine_schedule_period_iters: int,
    ):
        if cosine_schedule_period_iters <= 0:
            raise ValueError(
                "cosine_schedule_period_iters must be positive, got "
                f"{cosine_schedule_period_iters}"
            )
        self.warmup_iters = warmup_iters
        self.cosine_schedule_period_iters = cosine_schedule_period_iters
        # The base class performs an initial step, which calls get_lr().
        super().__init__(optimizer)

    def get_lr(self):
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        """Return the multiplicative learning rate factor for iteration ``epoch``."""
        lr_factor = 0.5 * (
            1 + np.cos(np.pi * epoch / self.cosine_schedule_period_iters)
        )
        # Guard against 0 / 0 when no warm-up is requested.
        if self.warmup_iters > 0 and epoch <= self.warmup_iters:
            lr_factor *= epoch / self.warmup_iters
        return lr_factor


def _calc_mass_error(
    calc_mz: float, obs_mz: float, charge: int, isotope: int = 0
) -> float:
    """
    Return the ppm error between a theoretical and an observed m/z.

    The observed m/z is first shifted down by ``isotope`` C13 isotope
    spacings (1.00335 Da divided by ``charge``). The difference is then
    expressed relative to the unshifted observed m/z, in parts per million.

    Parameters
    ----------
    calc_mz : float
        The theoretical m/z.
    obs_mz : float
        The observed m/z.
    charge : int
        The charge.
    isotope : int
        Number of C13 isotopes to correct for (default: 0).

    Returns
    -------
    float
        The mass error in ppm.
    """
    isotope_shift = isotope * 1.00335 / charge
    corrected_obs_mz = obs_mz - isotope_shift
    return (calc_mz - corrected_obs_mz) / obs_mz * 10**6


def _aa_pep_score(
    aa_scores: np.ndarray, fits_precursor_mz: bool
) -> Tuple[np.ndarray, float]:
    """
    Derive amino acid and peptide confidence scores from raw amino acid scores.

    The peptide score is the mean of the raw scores. Each amino acid score is
    averaged with that mean. If the prediction does not fit the precursor m/z,
    1 is subtracted from the peptide score after the amino acid scores have
    been computed.

    Parameters
    ----------
    aa_scores : np.ndarray
        Raw amino acid level confidence scores.
    fits_precursor_mz : bool
        Whether the prediction passes the precursor m/z filter.

    Returns
    -------
    aa_scores : np.ndarray
        The amino acid scores.
    peptide_score : float
        The peptide score.
    """
    mean_score = np.mean(aa_scores)
    blended_scores = (aa_scores + mean_score) / 2
    peptide_score = mean_score if fits_precursor_mz else mean_score - 1
    return blended_scores, peptide_score

def generate_tgt_mask(sz: int) -> torch.Tensor:
    """Generate a square causal mask for a target sequence.

    Parameters
    ----------
    sz : int
        The length of the target sequence.

    Returns
    -------
    torch.Tensor
        A boolean ``(sz, sz)`` tensor that is ``True`` strictly above the
        diagonal, marking positions a query may not attend to.
    """
    allowed = torch.ones(sz, sz, dtype=torch.bool).tril()
    return ~allowed
