"""Data loaders for the de novo sequencing task."""
import os
from typing import Optional, Iterable, Tuple
from ..utils import EpochTracker
from pathlib import Path
import lightning.pytorch as pl
from lightning.pytorch.utilities import CombinedLoader
import numpy as np
import torch
from torch.utils.data import DataLoader
import tempfile
import pyarrow as pa
from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe


from depthcharge.tokenizers import PeptideTokenizer
from depthcharge.data import (
                                AnnotatedSpectrumDataset,
                                CustomField,
                                SpectrumDataset,
                                preprocessing
)
from .datasets import SplitAwareAnnotatedSpectrumDataset, SplitAwareSpectrumDataset, SimulatedSpectraContainer, SimulatedSpectrumDataset


class DeNovoDataModule(pl.LightningDataModule):
    """
    Data loader to prepare MS/MS spectra for a Spec2Pep predictor.

    Parameters
    ----------
    train_paths : Iterable[str], optional
        Spectrum file paths, or a lance dataset path, for model training.
    valid_paths : Iterable[str], optional
        Spectrum file paths, or a lance dataset path, for validation.
    test_paths : str, optional
        Spectrum file paths, or a lance dataset path, for evaluation or
        inference. Iterated over like the other ``*_paths`` arguments.
    train_indices, valid_indices, test_indices : str, optional
        Path to a pickled numpy array with the indices of the lance dataset
        that belong to the respective split. ``None`` uses the full dataset.
    train_batch_size : int
        The batch size to use for training. When ``proportion_real_data`` is
        below 1, it is split between real and simulated spectra.
    eval_batch_size : int
        The batch size to use for validation, testing and inference.
    n_peaks : Optional[int]
        The number of top-n most intense peaks to keep in each spectrum. `None`
        retains all peaks.
    min_mz : float
        The minimum m/z to include. Defaults to 50 m/z; raise it to about
        140 m/z to exclude TMT and iTRAQ reporter ions.
    max_mz : float
        The maximum m/z to include.
    min_intensity : float
        Remove peaks whose intensity is below `min_intensity` percentage of the
        base peak intensity.
    remove_precursor_tol : float
        Remove peaks within the given mass tolerance in Dalton around the
        precursor mass.
    n_workers : int, optional
        Stored as ``self.n_workers`` (defaults to the number of available CPU
        cores). Note: ``_make_loader`` currently loads data in the main
        process and does not use this value.
    random_state : Optional[int]
        Currently ignored; kept for backward compatibility.
    max_charge : int
        Maximum precursor charge. ``valid_charge = 1..max_charge`` is passed
        to the spectrum parser, so PSMs with a higher charge are removed.
    tokenizer : Optional[PeptideTokenizer]
        Peptide tokenizer for tokenizing sequences. Defaults to a new
        ``PeptideTokenizer()``.
    lance_dir : Optional[str]
        Directory in which lance datasets parsed from raw spectrum files are
        written. Defaults to a temporary directory name.
    shuffle : Optional[bool]
        Whether the training data should be shuffled. Forwarded to the
        training dataset when split indices are used.
    buffer_size : Optional[int]
        Shuffle buffer size; stored but currently not used within this
        module. See more here:
        https://huggingface.co/docs/datasets/v1.11.0/dataset_streaming.html#shuffling-the-dataset-shuffle
    residue_counts_path : Optional[str]
        Path to a matrix counting the occurrences of all residues, used for
        episodic training. Must be given together with
        ``episodic_distributions``.
    episodic_distributions : Optional[dict]
        The residue distributions per episode (i.e. epoch). When set, an
        additional "unseen" validation loader is provided.
    epoch_tracker : Optional[EpochTracker]
        Passed to the datasets as ``epoch`` during episodic training.
    proportion_real_data : Optional[float]
        Fraction of each training batch taken from real spectra; the rest is
        filled with simulated spectra. ``None`` or 1 trains on real data only.
    """

    def __init__(
        self,
        train_paths: Optional[Iterable[str]] = None,
        valid_paths: Optional[Iterable[str]] = None,
        test_paths: Optional[str] = None,
        train_indices: Optional[str] = None,
        valid_indices: Optional[str] = None,
        test_indices: Optional[str] = None,
        train_batch_size: int = 128,
        eval_batch_size: int = 1028,
        n_peaks: Optional[int] = 150,
        min_mz: float = 50.0,
        max_mz: float = 2500.0,
        min_intensity: float = 0.01,
        remove_precursor_tol: float = 2.0,
        n_workers: Optional[int] = None,
        random_state: Optional[int] = None,
        max_charge: Optional[int] = 10,
        tokenizer: Optional[PeptideTokenizer] = None,
        lance_dir: Optional[str] = None,
        shuffle: Optional[bool] = True,
        buffer_size: Optional[int] = 100_000,
        residue_counts_path: Optional[str] = None,
        episodic_distributions: Optional[dict] = None,
        epoch_tracker: Optional[EpochTracker] = None,
        proportion_real_data: Optional[float] = None,
    ):
        super().__init__()
        self.train_paths = train_paths
        self.valid_paths = valid_paths
        self.test_paths = test_paths
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size

        self.train_indices = train_indices
        self.valid_indices = valid_indices
        self.test_indices = test_indices

        self.residue_counts_path = residue_counts_path
        self.episodic_distributions = episodic_distributions
        self.epoch_tracker = epoch_tracker

        self.tokenizer = tokenizer if tokenizer is not None else PeptideTokenizer()
        # NOTE(review): only the *name* of the TemporaryDirectory is kept; the
        # object itself is discarded, and tempfile removes the directory when
        # the object is destroyed. Presumably lance recreates it on write —
        # confirm, and consider keeping a reference if cleanup is desired.
        self.lance_dir = lance_dir if lance_dir is not None else tempfile.TemporaryDirectory(suffix='.lance').name

        self.proportion_real_data = proportion_real_data

        self.train_dataset = None
        self.train_dataset_simulation = None
        self.valid_dataset = None
        self.inverse_valid_dataset = None
        self.test_dataset = None

        self.n_workers = n_workers if n_workers is not None else os.cpu_count()
        self.shuffle = shuffle if shuffle else None  # set to None if not wanted. Otherwise torch throws an error
        self.buffer_size = buffer_size

        self.valid_charge = np.arange(1, max_charge + 1)
        self.preprocessing_fn = [
            preprocessing.set_mz_range(min_mz=min_mz, max_mz=max_mz),
            preprocessing.remove_precursor_peak(remove_precursor_tol, "Da"),
            preprocessing.filter_intensity(min_intensity, n_peaks),
            preprocessing.scale_intensity("root", 1),
            scale_to_unit_norm,
        ]
        # Extra fields read from MGF files at test time ("scans" falls back to
        # the title when the MGF has no SCANS parameter).
        self.custom_field_test_mgf = [
            CustomField("scans",
                        lambda x: x["params"]["scans"] if 'scans' in x["params"] else x["params"]["title"],
                        pa.string()),
            CustomField("title",
                        lambda x: x["params"]["title"],
                        pa.string())
        ]
        # Extra fields read from mzML/mzXML files at test time.
        self.custom_field_test_mzml = [
            CustomField("scans", lambda x: x["id"], pa.string()),
            CustomField("title", lambda x: x["id"], pa.string()),
        ]

        # Peptide annotation field for annotated datasets.
        self.custom_field_anno = [CustomField("seq", lambda x: x["params"]["seq"], pa.string())]

    def _real_data_fraction(self) -> float:
        """Return the fraction of each training batch drawn from real data (``None`` means 1.0)."""
        return 1.0 if self.proportion_real_data is None else self.proportion_real_data

    def make_dataset(self,
                     paths,
                     annotated,
                     mode,
                     shuffle,
                     indices_in_split: str | None = None,
                     residue_counts_path: str | None = None,
                     episodic_distributions: dict | None = None,
                     is_inverse_episode: bool | None = None,
                     batch_size: int | None = None,
                     ):
        """
        Make spectrum datasets.

        Parameters
        ----------
        paths : Iterable[str]
            Paths to input datasets. If any of them is a ``.lance`` dataset,
            the first path is opened as a lance dataset; otherwise all paths
            are parsed into ``{lance_dir}/{mode}.lance``.
        annotated: bool
            True if peptide sequence annotations are available for the
            data.
        mode: str
            Name of the lance instance (e.g. "train", "valid", "test").
            "train" selects ``train_batch_size``, all other modes
            ``eval_batch_size``; "test" adds the MGF/mzML custom fields.
        shuffle: bool
            Whether to shuffle the data; forwarded to the dataset only when
            ``indices_in_split`` is given.
        indices_in_split: str | None
            Path to a pickled numpy array of all indices in the lance dataset
            to consider in this dataset. This is relevant when the same lance
            file is the basis for different splits but only a subset of it is
            considered in each split. If None, the full dataset is considered.
        residue_counts_path: str | None
            Path to a matrix counting the occurrences of all residues. Also
            contains a map from the string representation of the residue to
            its index in the matrix.
        episodic_distributions: dict | None
            The distributions of the residues in a given episode, i.e. epoch.
        is_inverse_episode: bool | None
            Flag to indicate if the inverse distribution should be used. This
            is used for episodic training, where the inverse dataloader loads
            data that contains at least one unseen residue (which thus was
            black-listed in the train and val set).
        batch_size: int | None
            Explicit batch size. ``None`` derives it from ``mode``.

        Raises
        ------
        AttributeError
            If exactly one of ``residue_counts_path`` and
            ``episodic_distributions`` is given.
        """
        custom_fields = self.custom_field_anno if annotated else []

        # Compare lower-cased suffixes against real tuples (a bare string
        # would do substring matching, e.g. '' in '.mgf' is True).
        suffixes = [Path(f).suffix.lower() for f in paths]

        if mode == "test":
            if all(suffix == ".mgf" for suffix in suffixes):
                custom_fields = custom_fields + self.custom_field_test_mgf
            elif all(suffix in (".mzml", ".mzxml") for suffix in suffixes):
                custom_fields = custom_fields + self.custom_field_test_mzml

        lance_path = f'{self.lance_dir}/{mode}.lance'

        parse_kwargs = dict(
            preprocessing_fn=self.preprocessing_fn,
            custom_fields=custom_fields,
            valid_charge=self.valid_charge,
        )

        if batch_size is None:
            batch_size = self.train_batch_size if mode == "train" else self.eval_batch_size
        dataset_params = dict(batch_size=batch_size)

        if indices_in_split is not None:  # i.e. we use the custom split
            dataset_params |= dict(
                shuffle=shuffle,
                split_indices=indices_in_split,
            )

        # The episodic inputs only make sense together: both or neither.
        if (residue_counts_path is None) != (episodic_distributions is None):
            raise AttributeError(f"Path to a residue matrix was provided while the episodic distributions were not (or the other way around). residue_counts_path: {residue_counts_path}, episodic_distributions: {episodic_distributions}")
        if residue_counts_path is not None:
            dataset_params |= dict(
                residue_counts_path=residue_counts_path,
                episodic_distributions=episodic_distributions,
                epoch=self.epoch_tracker,
                invert_episode=is_inverse_episode,
            )

        anno_dataset_params = dataset_params | dict(
            tokenizer=self.tokenizer,
            annotations='seq',
        )

        if indices_in_split is None:
            anno_dataset_cls, dataset_cls = AnnotatedSpectrumDataset, SpectrumDataset
        else:
            anno_dataset_cls, dataset_cls = SplitAwareAnnotatedSpectrumDataset, SplitAwareSpectrumDataset

        if any(suffix == ".lance" for suffix in suffixes):
            if annotated:
                dataset = anno_dataset_cls.from_lance(paths[0], **anno_dataset_params)
            else:
                dataset = dataset_cls.from_lance(paths[0], **dataset_params)
        elif annotated:
            dataset = anno_dataset_cls(
                spectra=paths,
                path=lance_path,
                parse_kwargs=parse_kwargs,
                **anno_dataset_params,
            )
        else:
            dataset = dataset_cls(
                spectra=paths,
                path=lance_path,
                parse_kwargs=parse_kwargs,
                **dataset_params,
            )

        print(f"Dataset len: {dataset.n_spectra}")

        return dataset

    def _make_simulated_dataset(self, batch_size: int):
        """Build the dataset of simulated spectra that fills the non-real part of training batches."""
        container = SimulatedSpectraContainer(
            self.tokenizer,
            min_mz=60,
            max_mz=300,
            min_peaks=5,
            max_peaks=20,
            min_N_noise_peaks=0,
            max_N_noise_peaks=10,
            charge_probs={
                1: 0.5,
                2: 0.25,
                3: 0.125,
                4: 0.125,
            },
            min_N_missing_peaks=0,
            max_N_missing_peaks=5,
            # Device of a freshly allocated tensor, i.e. torch's default device.
            device=torch.empty(1).device,
        )
        return SimulatedSpectrumDataset(
            simulation_container=container,
            reverse_peps=self.tokenizer.reverse,
            batch_size=batch_size,
        )

    def setup(self, stage: str = None, annotated: bool = True) -> None:
        """
        Set up the PyTorch Datasets.

        Safe to call repeatedly (Lightning calls it once per stage): the
        configured batch sizes are never modified.

        Parameters
        ----------
        stage : str {"fit", "validate", "test"}
            The stage indicating which Datasets to prepare. All are prepared by
            default.
        annotated: bool
            True if peptide sequence annotations are available for the test
            data.
        """
        if stage in (None, "fit", "validate"):
            if self.train_paths is not None:
                real_fraction = self._real_data_fraction()
                self.train_dataset = self.make_dataset(
                    self.train_paths,
                    annotated=True,
                    mode='train',
                    shuffle=self.shuffle,
                    indices_in_split=self.train_indices,
                    residue_counts_path=self.residue_counts_path,
                    episodic_distributions=self.episodic_distributions,
                    is_inverse_episode=False,
                    # Only this share of each training batch is real data.
                    batch_size=int(self.train_batch_size * real_fraction),
                )
                if real_fraction < 1:
                    self.train_dataset_simulation = self._make_simulated_dataset(
                        int(self.train_batch_size * (1 - real_fraction))
                    )
            if self.valid_paths is not None:
                self.valid_dataset = self.make_dataset(
                    self.valid_paths,
                    annotated=True,
                    mode='valid',
                    shuffle=False,
                    indices_in_split=self.valid_indices,
                    residue_counts_path=self.residue_counts_path,
                    episodic_distributions=self.episodic_distributions,
                    is_inverse_episode=False
                )
                # The "unseen" validation set is only consumed during episodic
                # training; it gets its own lance path so parsing raw files
                # does not overwrite the regular validation dataset.
                if self.episodic_distributions is not None:
                    self.inverse_valid_dataset = self.make_dataset(
                        self.valid_paths,
                        annotated=True,
                        mode='valid_inverse',
                        shuffle=False,
                        indices_in_split=self.valid_indices,
                        residue_counts_path=self.residue_counts_path,
                        episodic_distributions=self.episodic_distributions,
                        is_inverse_episode=True
                    )
        if stage in (None, "test"):
            if self.test_paths is not None:
                self.test_dataset = self.make_dataset(
                    self.test_paths,
                    annotated=annotated,
                    mode='test',
                    shuffle=False,
                    indices_in_split=self.test_indices
                )

    def _make_loader(
        self,
        dataset: torch.utils.data.Dataset,
        shuffle: Optional[bool] = None,
    ) -> torch.utils.data.DataLoader:
        """
        Create a PyTorch DataLoader.

        Parameters
        ----------
        dataset : torch.utils.data.Dataset
            A PyTorch Dataset. Batch size and shuffling are configured on the
            dataset itself (see ``make_dataset``).
        shuffle : bool, optional
            Accepted for interface compatibility but ignored: shuffling is
            done inside the datasets.

        Returns
        -------
        torch.utils.data.DataLoader
            A PyTorch DataLoader.
        """
        return DataLoader(
            dataset,
            shuffle=None,  # shuffling is done in the datasets directly
            num_workers=0,  # loads in the main process; self.n_workers is not used
            pin_memory=True,
        )

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        """Get the training DataLoader, mixing in simulated spectra when configured."""
        if self._real_data_fraction() >= 1:
            return self._make_loader(self.train_dataset, self.shuffle)
        return CombinedLoader({
            'real': self._make_loader(self.train_dataset, self.shuffle),
            'simulated': self._make_loader(self.train_dataset_simulation)
        }, mode='min_size')

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        """Get the validation DataLoader ("seen"/"unseen" pair for episodic training)."""
        if self.episodic_distributions is not None:  # It is an episodic training
            return CombinedLoader({'seen': self._make_loader(self.valid_dataset), 'unseen': self._make_loader(self.inverse_valid_dataset)}, mode='max_size')
        return self._make_loader(self.valid_dataset)

    def test_dataloader(self) -> torch.utils.data.DataLoader:
        """Get the test DataLoader."""
        return self._make_loader(self.test_dataset)

    def predict_dataloader(self) -> torch.utils.data.DataLoader:
        """Get the predict DataLoader."""
        return self._make_loader(self.test_dataset)


def scale_to_unit_norm(spectrum):
    """
    Scale peak intensities to unit Euclidean (L2) norm, in place.

    Scaling function used in Casanovo, slightly differing from the
    depthcharge implementation. It writes the scaled values to the private
    ``spectrum._inner._intensity`` attribute.

    Parameters
    ----------
    spectrum
        A spectrum object exposing ``intensity`` and ``_inner._intensity``
        (presumably a depthcharge/spectrum_utils spectrum — verify).

    Returns
    -------
    The same spectrum object. Spectra whose intensity norm is zero (no
    peaks, or only zero-intensity peaks) are returned unchanged instead of
    being filled with NaN by a 0/0 division.
    """
    intensity = spectrum.intensity
    norm = np.linalg.norm(intensity)
    if norm > 0:
        spectrum._inner._intensity = intensity / norm
    return spectrum

