from depthcharge.data import (
                                AnnotatedSpectrumDataset,
                                SpectrumDataset
)
from depthcharge.tokenizers import PeptideTokenizer
from depthcharge.constants import H2O, HYDROGEN

from collections.abc import Iterable
from os import PathLike
from ..utils import EpochTracker
import polars as pl
import numpy as np
from typing import Any
import pickle as pkl
from torch import isin
from torch.utils.data.dataset import IterableDataset
import torch
from scipy.stats import gamma

# Set of token masses for a multi-PTM vocabulary; the trailing comment on each
# entry names the residue / modified residue the value belongs to
# (presumably monoisotopic residue masses in Da -- TODO confirm units).
#
# NOTE: this is a ``set``, so entries with an identical value collapse into a
# single element. That applies to 113.084064 (I and L), 115.026943
# (D and N[Deamidated]) and 129.042593 (E and Q[Deamidated]); the resulting
# set therefore has fewer elements than there are lines below.
# E[-18.011] and Q[-17.027] differ in the sixth decimal and stay distinct.
MULTI_PTM_TOKEN_MASS_SET = {
    57.021464, # G
    71.037114, # A
    87.032028, # S
    97.052764, # P
    99.068414, # V
    101.04767, # T
    160.030649, # C[Carbamidomethyl]
    113.084064, # I
    113.084064, # L
    114.042927, # N
    115.026943, # D
    128.058578, # Q
    128.094963, # K
    129.042593, # E
    131.040485, # M
    137.058912, # H
    147.068414, # F
    156.101111, # R
    163.063329, # Y
    186.079313, # W
    147.0354, # M[Oxidation]
    115.026943, # N[Deamidated]
    129.042593, # Q[Deamidated]
    111.03202800000001, # E[-18.011]
    242.13789000000003, # K[+114.043]
    170.105528, # K[+42.011]
    142.110613, # K[+14.016]
    111.03202900000001, # Q[-17.027]
    170.116761, # R[+14.016]
    157.085127, # R[+0.984]
    166.998359, # S[+79.966]
    290.111401, # S[+203.079]
    181.01400999999998, # T[+79.966]
    304.127052, # T[+203.079]
    243.02966, # Y[+79.966]
    156.089878, # K[+27.995]
    198.136828, # K[+70.042]
    196.121178, # K[+68.026]
    242.12665700000002, # K[+114.032]
    214.131742, # K[+86.037]
    156.126263, # K[+28.031]
    170.14191300000002, # K[+42.047]
    184.12117800000001, # K[+56.026]
    228.111007, # K[+100.016]
    214.095357, # K[+86.000]
    113.047679, # P[+15.995]
    184.132411, # R[+28.031]
    208.048407, # Y[+44.985]
    357.257895, # K[+229.163]
}

class SimulatedSpectraContainer:
    """Random generator for the pieces of a batch of simulated fragment spectra.

    The container draws residue ("delta") masses and precursor charges, turns
    them into b-/y-style fragment m/z ladders, optionally hides fragment peaks,
    mixes in uniformly distributed noise peaks and assigns intensities.
    All randomness comes from the global NumPy and torch generators.
    """

    def __init__(self,
                 tokenizer: "PeptideTokenizer",
                 min_mz: float,
                 max_mz: float,
                 min_peaks: int,
                 max_peaks: int,
                 min_N_noise_peaks: int,
                 max_N_noise_peaks: int,
                 charge_probs: dict,
                 min_N_missing_peaks: int,
                 max_N_missing_peaks: int,
                 device,):
        """Store the simulation settings.

        Args:
            tokenizer: Accepted for interface compatibility; not stored here.
            min_mz: Lower bound of the uniform range residue masses are drawn from.
            max_mz: Upper bound of that range; also added to the largest
                fragment m/z to obtain the upper bound for noise peaks.
            min_peaks: Minimum number of residues per simulated peptide.
            max_peaks: Upper bound (exclusive in ``_get_lengths``) on the
                number of residues per simulated peptide.
            min_N_noise_peaks: Minimum number of noise peaks per spectrum.
            max_N_noise_peaks: Exclusive upper bound on the number of noise
                peaks per spectrum; ``0`` disables noise peaks.
            charge_probs: Mapping ``charge -> probability`` used to sample
                precursor charges.
            min_N_missing_peaks: Minimum number of hidden fragment positions.
            max_N_missing_peaks: Maximum (inclusive) number of hidden fragment
                positions; ``0`` disables hiding.
            device: Device passed to ``torch.tensor`` in ``_tensorize``.

        Raises:
            AttributeError: If ``min_N_missing_peaks >= min_peaks``.
        """
        self.device = device
        self.min_mz = min_mz
        self.max_mz = max_mz
        self.min_peaks = min_peaks
        self.max_peaks = max_peaks
        self.min_N_noise_peaks = min_N_noise_peaks
        self.max_N_noise_peaks = max_N_noise_peaks
        self.min_N_missing_peaks = min_N_missing_peaks
        self.max_N_missing_peaks = max_N_missing_peaks

        if min_N_missing_peaks >= min_peaks:
            raise AttributeError(f"The minimum number of missing peaks ({min_N_missing_peaks}) is greater or equal to the minimum number of peaks ({min_peaks}) -> it may happen that more peaks are hidden than exist")

        self.charges = np.array(list(charge_probs.keys()))
        self.z_probs = np.array(list(charge_probs.values()))

    def _tensorize(self, obj):
        """Convert ``obj`` to a tensor on ``self.device``.

        Returns a single tensor when ``obj`` can be packed into one array,
        otherwise a list with one tensor per element of ``obj``.
        """
        try:
            return torch.tensor(np.array(obj), device=self.device)
        except Exception:
            # Typically ragged input that cannot be packed into one rectangular
            # array; fall back to per-element tensors (padded later).
            return [torch.tensor(elem, device=self.device) for elem in obj]

    def _pad_tensor(self, tensor: list[torch.tensor]):
        """Zero-pad a list of sequences into one batch-first tensor."""
        return torch.nn.utils.rnn.pad_sequence(tensor, batch_first=True)

    def _get_hidden_masks(self,
                          b_ions_tensor):
        """Draw keep-masks that hide some b-/y-ion peaks of one peptide.

        The first and last positions are always kept. For every hidden inner
        position exactly one of the two series loses its peak (chosen at
        random), so ``mask_b | mask_y`` is True everywhere.

        Args:
            b_ions_tensor: The b-ion ladder of one peptide; only its length is used.

        Returns:
            ``(mask_b, mask_y)``: boolean tensors of length ``len(b_ions_tensor)``
            where ``True`` means "keep this peak".
        """
        pep_len = len(b_ions_tensor)
        n_inner = pep_len - 2

        if n_inner < 1:
            # Too short to hide anything: keep every peak instead of failing
            # on unbound masks.
            keep_all = torch.ones(pep_len, dtype=torch.bool)
            return keep_all, keep_all.clone()

        upper = min(n_inner, self.max_N_missing_peaks)
        # Clamp the lower bound so short peptides do not produce an empty
        # randint range.
        lower = min(self.min_N_missing_peaks, upper)
        n_hidden = np.random.randint(lower, upper + 1)

        hidden = np.random.choice(range(n_inner), n_hidden, replace=False)
        # 0 -> the b-ion survives at a hidden position, 1 -> the y-ion survives.
        partition = np.random.choice([0, 1], n_inner, replace=True)

        visible = np.ones(n_inner, dtype=np.bool_)
        visible[hidden] = False

        # The additional Trues at both ends are for the overlap in ions.
        edge = np.ones(1, dtype=np.bool_)
        mask_b = torch.from_numpy(np.concatenate((edge, visible | (partition == 0), edge)))
        mask_y = torch.from_numpy(np.concatenate((edge, visible | (partition == 1), edge)))

        return mask_b, mask_y

    @staticmethod
    def _merge_noise_peaks(mz, noise):
        """Merge real and noise peaks into one ascending m/z array.

        The noise flag travels with each value through the sort, so it does not
        depend on comparing float32 against float64 values.

        Returns:
            ``(merged, is_noise)``: float32 m/z values in ascending order and a
            boolean array marking the entries that came from ``noise``.
        """
        combined = np.concatenate((mz, noise))
        order = np.argsort(combined, kind="stable")
        is_noise = np.zeros(combined.shape[0], dtype=np.bool_)
        is_noise[len(mz):] = True
        return combined[order].astype(np.float32), is_noise[order]

    def get_mzs(self,
                masses: torch.tensor,
                prec_charges: torch.tensor):
        """Build the peak m/z values for a batch of peptides.

        Args:
            masses: Zero-padded residue masses, one row per peptide.
            prec_charges: One charge per peptide.

        Returns:
            Three zero-padded, batch-first tensors: the sorted m/z values
            (fragments plus noise), a mask flagging noise peaks, and a mask
            flagging residue positions where one fragment peak was hidden.
        """
        # Zeros are padding and are dropped before accumulating. The results
        # are used as torch tensors below (torch.flip / mask indexing).
        mzs_b = [np.cumsum(residue_masses[residue_masses!=0], dtype=np.float32) / z + HYDROGEN for residue_masses, z in zip(masses, prec_charges)]
        mzs_y = [torch.flip(np.cumsum(torch.flip(residue_masses[residue_masses!=0], dims=[0]), dtype=np.float32), dims=[0]) / z + H2O + HYDROGEN for residue_masses, z in zip(masses, prec_charges)]

        # Hide some peaks
        if self.max_N_missing_peaks > 0:
            missing_masks = []
            for i, b_ions in enumerate(mzs_b):
                m_b, m_y = self._get_hidden_masks(b_ions)
                mzs_b[i] = mzs_b[i][m_b]
                mzs_y[i] = mzs_y[i][m_y]
                missing_masks.append(~(m_b & m_y))  # single mask indicating if a token's peak is hidden
        else:
            # Nothing is hidden: emit all-False masks so the return value is
            # always defined.
            missing_masks = [torch.zeros(len(b_ions), dtype=torch.bool) for b_ions in mzs_b]

        mzs = [torch.sort(torch.cat([b, y]))[0] for b, y in zip(mzs_b, mzs_y)]

        # Add noise peaks
        max_mzs = [torch.max(mz_tensor, dim=-1)[0] + self.max_mz for mz_tensor in mzs]
        if self.max_N_noise_peaks > 0:
            num_noise_peaks = np.random.randint(self.min_N_noise_peaks, self.max_N_noise_peaks, (masses.shape[0],))
        else:
            # randint rejects an empty range; zero noise peaks per spectrum.
            num_noise_peaks = np.zeros(masses.shape[0], dtype=np.int64)
        noise_peaks = [np.round(np.random.uniform(0, max_mz, (num,)), decimals=2) for num, max_mz in zip(num_noise_peaks, max_mzs)]

        merged_mzs = []
        noise_masks = []
        for mz, noise in zip(mzs, noise_peaks):
            merged, is_noise = self._merge_noise_peaks(mz, noise)
            merged_mzs.append(merged)
            noise_masks.append(is_noise)

        return self._pad_tensor(self._tensorize(merged_mzs)), self._pad_tensor(self._tensorize(noise_masks)), self._pad_tensor(self._tensorize(missing_masks))

    def get_intensities(self,
                        mzs: torch.tensor,
                        noise_mask: torch.tensor):
        """Sample an intensity for every peak.

        Peaks get values from N(1, 0.1), padding (m/z == 0) gets 0, noise peaks
        are re-drawn from N(0.4, 0.1); negative values are clamped to 0.
        """
        ints = torch.normal(1, 0.1, mzs.shape, dtype=torch.float32)
        ints[mzs==0] = 0
        ints[noise_mask] = torch.normal(0.4, 0.1, ints[noise_mask].shape, dtype=torch.float32)

        ints = torch.clamp(ints, min=0.0)

        return self._pad_tensor(self._tensorize(ints))

    def get_precursor_mzs(self,
                          delta_masses: torch.tensor,
                          prec_charges: torch.tensor):
        """Return ``sum(delta_masses) / charge + HYDROGEN`` per peptide.

        NOTE(review): unlike the y-ion series in ``get_mzs`` no H2O term is
        added here -- confirm this is intended.
        """
        masses = torch.sum(delta_masses, dim=-1)
        prec_mzs = masses / prec_charges + HYDROGEN

        return self._tensorize(prec_mzs)

    def _get_lengths(self, N: int):
        """Draw ``N`` peptide lengths uniformly from ``[min_peaks, max_peaks)``."""
        lengths = np.random.randint(self.min_peaks, self.max_peaks, (N,))
        return lengths

    def get_delta_masses(self, N: int):
        """Draw ``N`` peptides of uniform random residue masses (3 decimals), zero-padded."""
        lengths = self._get_lengths(N)
        masses = [np.round(np.random.uniform(self.min_mz, self.max_mz, length), 3).astype(np.float32) for length in lengths]
        return self._pad_tensor(self._tensorize(masses))

    def get_charges(self, N: int):
        """Sample ``N`` precursor charges (int32) according to ``charge_probs``."""
        prec_charges = np.random.choice(self.charges, N, replace=True, p=self.z_probs).astype(np.int32)
        return self._tensorize(prec_charges)

class SimulatedLanceContainer(SimulatedSpectraContainer):
    """Simulation container that takes peptides and charges from a real dataset.

    Residue masses come from the sequences of batches yielded by
    ``lance_dataset``; the remaining simulation settings are fixed in
    ``__init__``.
    """

    def __init__(self,
                 lance_dataset: AnnotatedSpectrumDataset,
                 device):
        """Wrap ``lance_dataset`` and start iterating over it.

        Args:
            lance_dataset: Dataset whose batches provide ``'seq'`` and
                ``'precursor_charge'``; its tokenizer supplies the token masses.
            device: Device used when creating tensors.
        """
        fixed_settings = dict(
            min_mz=60,
            max_mz=300,
            min_peaks=5,
            max_peaks=20,
            min_N_noise_peaks=0,
            max_N_noise_peaks=10,
            charge_probs={},  # charges are read from the dataset instead
            min_N_missing_peaks=0,
            max_N_missing_peaks=5,
            device=device,
        )
        super().__init__(lance_dataset.tokenizer, **fixed_settings)

        self.dataset_iter = iter(lance_dataset)
        self.dataset = lance_dataset
        # Maps a sequence of token indices to their masses, rounded to 2 decimals.
        self.deMass = lambda sequence: [
            np.round(self.dataset.tokenizer.masses[aa_idx], decimals=2)
            for aa_idx in sequence
        ]

    def get_delta_masses(self, N: int):
        """Advance to the next dataset batch and return its token masses.

        ``N`` is ignored; the batch size is whatever the dataset yields.
        The fetched batch is cached in ``self.current_batch`` for
        ``get_charges`` and ``get_seq``.
        """
        self.current_batch = next(self.dataset_iter)
        mass_rows = [self.deMass(sequence) for sequence in self.current_batch['seq']]
        token_masses = torch.tensor(mass_rows, device=self.device)

        return self._pad_tensor(self._tensorize(token_masses))

    def get_charges(self, N: int):
        """Return the precursor charges of the cached batch (``N`` is ignored)."""
        return self.current_batch['precursor_charge']

    def get_seq(self):
        """Return the token sequences of the cached batch."""
        return self.current_batch['seq']

class SimulatedSetContainer(SimulatedSpectraContainer):
    """Simulation container that draws residue masses from a fixed set of values."""

    def __init__(self,
                 tokenizer: PeptideTokenizer,
                 min_peaks: int,
                 max_peaks: int,
                 min_N_noise_peaks: int,
                 max_N_noise_peaks: int,
                 charge_probs: dict,
                 min_N_missing_peaks: int,
                 max_N_missing_peaks: int,
                 mass_set: set,
                 device):
        """Forward the settings to the base class and remember ``mass_set``.

        ``min_mz`` is passed as ``None`` and ``max_mz`` is fixed to 300; all
        other arguments are handed through unchanged.
        """
        super().__init__(
            tokenizer,
            None,                   # min_mz
            300,                    # max_mz
            min_peaks,
            max_peaks,
            min_N_noise_peaks,
            max_N_noise_peaks,
            charge_probs,
            min_N_missing_peaks,
            max_N_missing_peaks,
            device,
        )
        self.mass_set = mass_set

    def get_delta_masses(self, N: int):
        """Draw ``N`` peptides whose residue masses are sampled from ``mass_set``.

        Lengths are uniform in ``[min_peaks, max_peaks)``; masses are sampled
        with replacement, rounded to 3 decimals and returned zero-padded.
        """
        lengths = np.random.randint(self.min_peaks, self.max_peaks, (N,))
        candidates = list(self.mass_set)

        masses = []
        for length in lengths:
            drawn = np.random.choice(candidates, length)
            masses.append(np.round(drawn, 3).astype(np.float32))

        return self._pad_tensor(self._tensorize(masses))

class SimulatedBasicSpectrumContainer(SimulatedSpectraContainer):
    """Simplified simulation: a single cumulative ladder and no hidden peaks."""

    def __init__(self,
                 tokenizer: PeptideTokenizer,
                 min_mz: float,
                 max_mz: float,
                 min_peaks: int,
                 max_peaks: int,
                 min_N_noise_peaks: int,
                 max_N_noise_peaks: int,
                 charge_probs: dict,
                 device,
    ):
        """Forward the settings to the base class with peak hiding disabled."""
        super().__init__(
            tokenizer = tokenizer,
            min_mz = min_mz,
            max_mz = max_mz,
            min_peaks = min_peaks,
            max_peaks = max_peaks,
            min_N_noise_peaks = min_N_noise_peaks,
            max_N_noise_peaks = max_N_noise_peaks,
            charge_probs = charge_probs,
            min_N_missing_peaks = 0,
            max_N_missing_peaks = 0,
            device = device,
        )

    def get_mzs(self,
                masses: torch.tensor,
                prec_charges: torch.tensor):
        """Build peak m/z values as ``cumsum(residue masses) / charge`` plus noise.

        Args:
            masses: Zero-padded residue masses, one row per peptide.
            prec_charges: One charge per peptide.

        Returns:
            Three zero-padded, batch-first tensors: sorted m/z values, a mask
            flagging noise peaks, and an all-False missing-peak mask.
        """
        # Zeros are padding and are dropped before accumulating.
        mzs = [np.cumsum(residue_masses[residue_masses!=0], dtype=np.float32) / z for residue_masses, z in zip(masses, prec_charges)]
        missing_masks = [np.zeros_like(mz, dtype=np.bool_) for mz in mzs] # No missing peaks

        if self.max_N_noise_peaks > 0:
            max_mzs = [torch.max(mz_tensor, dim=-1)[0] + self.max_mz for mz_tensor in mzs]

            num_noise_peaks = np.random.randint(self.min_N_noise_peaks, self.max_N_noise_peaks, (masses.shape[0],))

            noise_peaks = [np.round(np.random.uniform(0, max_mz, (num,)), decimals=2) for num, max_mz in zip(num_noise_peaks, max_mzs)]

            merged_mzs = []
            noise_masks = []
            for mz, noise in zip(mzs, noise_peaks):
                combined = np.concatenate((mz, noise))
                # Carry the noise flag through the sort. Comparing the float32
                # result against the float64 noise values (np.isin) misses
                # every value that is not exactly representable in float32.
                order = np.argsort(combined, kind="stable")
                is_noise = np.zeros(combined.shape[0], dtype=np.bool_)
                is_noise[len(mz):] = True
                merged_mzs.append(combined[order].astype(np.float32))
                noise_masks.append(is_noise[order])
        else:
            merged_mzs = mzs
            noise_masks = [np.zeros_like(mz, dtype=np.bool_) for mz in mzs]

        return self._pad_tensor(self._tensorize(merged_mzs)), self._pad_tensor(self._tensorize(noise_masks)), self._pad_tensor(self._tensorize(missing_masks))

class SimulatedParameterizedSpectrumContainer(SimulatedSpectraContainer):
    """Simulation container whose peptide lengths follow a fitted gamma distribution."""

    def _get_lengths(self, N: int):
        """Draw ``N`` peptide lengths from a gamma distribution.

        Samples are rounded to integers and clipped to
        ``[min_peaks, max_peaks]`` (both ends inclusive, per ``np.clip``).
        """
        # Gamma parameters fit on MSV-V1.
        gamma_shape = 2.9001868632925243
        gamma_loc = 5.761051746843277
        gamma_scale = 3.390570114998205

        sampled = gamma.rvs(gamma_shape, loc=gamma_loc, scale=gamma_scale, size=N)
        rounded = np.round(sampled).astype(int)
        return np.clip(rounded, self.min_peaks, self.max_peaks)

class SimulatedSpectrumDataset(IterableDataset):
    def __init__(self,
                 simulation_container: SimulatedSpectraContainer,
                 reverse_peps: bool,
                 batch_size: int,
                 num_peptides: int | None = None):
        super().__init__()
        self.sim_container = simulation_container
        self.reverse = reverse_peps
        self.bs = batch_size
        self.batch_counter = 0
        self.num_peptides = num_peptides

    def __iter__(self):
        self.batch_counter = 0

        return self
    
    def _normalize_spectra(self, mzs: torch.tensor, intensities: torch.tensor):
        # Same way as is done for real spectra
        if mzs.shape[1] > 300:
            top_300_indices = torch.topk(intensities, 300, dim=-1)[1]
            mzs = torch.gather(mzs, 1, top_300_indices)
            intensities = torch.gather(intensities, 1, top_300_indices)

        intensities = torch.sqrt(intensities)
        intensities = intensities * (1 / torch.max(intensities, dim=-1).values).unsqueeze(dim=-1)

        intensities = torch.nn.functional.normalize(intensities, p=2, dim=-1)

        return mzs, intensities

    
    def __next__(self):
        if self.num_peptides is not None and self.batch_counter * self.bs > self.num_peptides:
            raise StopIteration
        
        peak_files = ['Simulated_Data'] * self.bs
        scan_ids = [str(id) for id in list(range(self.bs * self.batch_counter, self.bs * (self.batch_counter + 1)))]
        ms_level = self.sim_container._tensorize(np.full(self.bs, 2, dtype=np.int64))

        delta_masses = self.sim_container.get_delta_masses(self.bs)
        prec_charges = self.sim_container.get_charges(self.bs)
        mzs, noise_mask, missing_mask = self.sim_container.get_mzs(delta_masses, prec_charges)
        prec_mz = self.sim_container.get_precursor_mzs(delta_masses, prec_charges)
        intensities = self.sim_container.get_intensities(mzs, noise_mask)

        # mzs, intensities = self._normalize_spectra(mzs, intensities)

        batch = {
            'peak_file': peak_files,
            'scan_id': scan_ids,
            'ms_level': ms_level,
            'precursor_mz': prec_mz,
            'precursor_charge': prec_charges,
            'mz_array': mzs,
            'intensity_array': intensities,
            'delta_masses': delta_masses,
            'missing_masks': missing_mask,
            'noise_msk': noise_mask,
        }

        try:
            batch |= {'seq': self.sim_container.get_seq()}
        except:
            pass

        self.batch_counter += 1

        return batch

class SplitAwareSpectrumDataset(SpectrumDataset):
    """Unannotated spectrum dataset that only iterates over a pickled index split."""

    def __init__(
        self,
        spectra: pl.DataFrame | PathLike | Iterable[PathLike],
        batch_size: int,
        split_indices: str,
        shuffle: bool = True,
        path: PathLike | None = None,
        parse_kwargs: dict | None = None,
        **kwargs: dict,
    ) -> None:
        """Initialize a SplitAwareSpectrumDataset by calling super and storing the split.

        Args:
            spectra: Forwarded to the parent constructor.
            batch_size: Forwarded to the parent constructor; also the number
                of indices taken per batch in ``__next__``.
            split_indices: Path to a pickle file holding the row indices of
                this split.
            shuffle: Re-permute the split on every ``__iter__`` call.
            path: Forwarded to the parent constructor.
            parse_kwargs: Forwarded to the parent constructor.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(
            spectra,
            batch_size,
            path,
            parse_kwargs,
            **kwargs,
        )
        # SECURITY NOTE: pickle can execute arbitrary code while loading;
        # only point ``split_indices`` at trusted files.
        with open(split_indices, "rb") as f:
            self.original_split = pkl.load(f) # load the np array
        self._split = None  # set by __iter__
        self._shuffle = shuffle

    @property
    def n_spectra(self) -> int:
        """The number of spectra in the subsetted Lance dataset."""
        return len(self.original_split)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Random access is not supported for split-aware datasets.

        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError("__getitem__ was called on SplitAwareSpectrumDataset. This is not well defined")

    def __iter__(self):
        """Start a new pass over the split (shuffled if requested)."""
        if self._shuffle:
            self._split = np.random.permutation(self.original_split)
        else:
            self._split = self.original_split
        self._batch_counter = 0
        return self

    def __next__(self):
        """Return the next batch of the current pass.

        ``__iter__`` must have been called first.

        Raises:
            StopIteration: When all indices of the split have been consumed.
        """
        start = self._batch_counter*self.batch_size
        if start >= self.n_spectra:
            raise StopIteration

        end = start + self.batch_size
        subset = self._split[start:end]
        self._batch_counter += 1

        # to_batches(self, max_chunksize=None) https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_batches
        # NOTE(review): only the first record batch is used -- confirm the
        # taken table always arrives as a single batch.
        batch = self._to_tensor(self.dataset.take(subset).to_batches()[0])
        batch.pop('seq', None) # Delete the seq key -> NOT annotated dataset
        return batch


    ### DEBUG IMPLEMENTATIONS ###
    def add_spectra(
        self,
        spectra: pl.DataFrame | PathLike | Iterable[PathLike],
    ):
        """Adding spectra is not supported for split-aware datasets."""
        raise NotImplementedError
        

class SplitAwareAnnotatedSpectrumDataset(AnnotatedSpectrumDataset):
    """Annotated spectrum dataset restricted to a pickled index split.

    Optionally supports "episodic" training: when ``residue_counts_path`` is
    given, each epoch only yields PSMs compatible with (or, if
    ``invert_episode`` is set, violating) the residue distribution configured
    for that epoch.
    """

    def __init__(
        self,
        spectra: pl.DataFrame | PathLike | Iterable[PathLike],
        annotations: str,
        tokenizer: PeptideTokenizer,
        batch_size: int,
        split_indices: str | PathLike,
        shuffle: bool = True,
        path: PathLike | None = None,
        parse_kwargs: dict | None = None,
        residue_counts_path: str | None = None,
        episodic_distributions: dict | None = None,
        epoch: EpochTracker | None = None,
        invert_episode: bool | None = None,
        **kwargs: dict,
    ) -> None:
        """Initialize an AnnotatedSpectrumDataset by calling super and storing the split.

        Args:
            spectra: Forwarded to the parent constructor.
            annotations: Forwarded to the parent constructor.
            tokenizer: Forwarded to the parent constructor.
            batch_size: Forwarded to the parent constructor; also the number
                of indices taken per batch in ``__next__``.
            split_indices: Path to a pickle file holding the row indices of
                this split.
            shuffle: Re-permute the split on every ``__iter__`` call.
            path: Forwarded to the parent constructor.
            parse_kwargs: Forwarded to the parent constructor.
            residue_counts_path: Path to a pickle holding
                ``(residue -> column index, count matrix)``; enables episodic
                training when given.
            episodic_distributions: One ``residue -> count`` mapping per
                episode (the dict values are used in order). Required when
                ``residue_counts_path`` is given.
            epoch: Callable returning the current epoch; required (and must
                return 0) when episodic training is enabled.
            invert_episode: If true, yield the PSMs that contain a residue
                with zero weight in the current episode instead of excluding
                them.
            **kwargs: Forwarded to the parent constructor.

        Raises:
            RuntimeError: If the episodic configuration is inconsistent.
        """
        super().__init__(
            spectra,
            annotations,
            tokenizer,
            batch_size,
            path,
            parse_kwargs,
            **kwargs,
        )
        self.epoch = epoch
        self.invert_episode = invert_episode
        self._episodic = False

        # SECURITY NOTE: pickle can execute arbitrary code while loading;
        # only point ``split_indices`` at trusted files.
        with open(split_indices, "rb") as f:
            self.original_split = pkl.load(f) # load the np array
        self._split = None  # set by __iter__
        self._shuffle = shuffle
        print(f"\nBuilt dataset with shuffle: {shuffle} and {len(self.original_split)} datapoints\n")

        if residue_counts_path is not None:
            self._setup_episodes(residue_counts_path, episodic_distributions)

    def deMass(self, sequence):
        """Map a sequence of token indices to the tokenizer's token masses."""
        return [self.tokenizer.masses[aa_idx] for aa_idx in sequence]

    def _setup_episodes(self, residue_counts_path, episodic_distributions):
        """Load residue counts, validate the episode config and derive per-episode filters."""
        if episodic_distributions is None:
            raise RuntimeError("Dataset creation: ``residue_counts_path`` was provided but ``episodic_distributions`` is None")

        self._episodic = True
        # This is the train dataset that has activated episodic training
        # (see the pickle security note in __init__).
        with open(residue_counts_path, "rb") as f:
            self._residue2idx, self._residue_count_mat = pkl.load(f)

        # Check if counts exist for all PSMs in the lance file
        num_total_PSMs = self.dataset.count_rows()
        num_PSMs_count_matrix = self._residue_count_mat.shape[0]
        if num_total_PSMs != num_PSMs_count_matrix:
            raise RuntimeError(f"Dataset creation: The number of total PSMs in the provided lance file ({num_total_PSMs}) does not match the number of PSMs for which a residue occurence has been registered in the ``residue_counts_path`` file ({num_PSMs_count_matrix})")

        target_counts = list(episodic_distributions.values())
        num_episodes = len(target_counts)
        residue_names = list(self._residue2idx.keys())
        residue_indices = list(self._residue2idx.values())
        num_unique_tokens = 1 + np.max(np.array(residue_indices))
        target_distributions = np.zeros((num_episodes, num_unique_tokens))

        # Check if the provided episodic distributions are made up of only tokens that are in the expected set of tokens for the count matrix.
        # For every column index the first residue mapping to it is used (duplicate filtering i.e. replacements like I/L).
        residues_in_count_mat = [residue_names[residue_indices.index(idx)] for idx in range(num_unique_tokens)]
        residues_in_target_dists = {residue for dist in target_counts for residue in dist}
        if not residues_in_target_dists.issubset(residues_in_count_mat):
            raise RuntimeError(f"Dataset creation: The residues specified in the episodes do not match the residues in the counting matrix.\nEpisodic Tokens:\n{residues_in_target_dists}\nCount Matrix tokens:\n{residues_in_count_mat}")

        # The epoch object must exist and give the correct epoch. The two
        # cases are checked separately so a missing tracker is reported as a
        # RuntimeError instead of failing while formatting the message.
        if self.epoch is None:
            raise RuntimeError("Dataset creation: The epoch tracking object does not exist but is required for episodic training")
        current_epoch = self.epoch()
        if current_epoch != 0:
            raise RuntimeError(f"Dataset creation: The epoch tracking object does either not exist ({self.epoch is None}) or is not set the the correct default value which was {current_epoch} and expected to be 0")

        # The invert_episode should be set to either true or false
        if self.invert_episode is None:
            print("Dataset creation: [Warning] invert_episode was set to None despite the dataset being initialized for episodic training. It will be treated as invert_episode = False i.e. return data following the provided distributions")
            self.invert_episode = False

        for i, distribution in enumerate(target_counts):
            for residue, count in distribution.items():
                idx = self._residue2idx[residue]
                target_distributions[i, idx] = count

        # Row-normalize the counts into one distribution per episode.
        self.target_distributions = target_distributions / np.sum(target_distributions, axis=-1).reshape(-1,1)

        # Generate a list of tokens which are not allowed in the respective epoch
        self.black_listed_tokens = []
        for episode in range(num_episodes):
            black_listed_tokens_peptide = [token for token, idx in self._residue2idx.items() if self.target_distributions[episode][idx] == 0]
            self.black_listed_tokens.append(self.tokenizer.tokenize(black_listed_tokens_peptide)) # Getting the indices of all forbidden tokens

        for e in range(len(self.target_distributions)):
            print(f"Epoch {e} contains {self._get_n_spectra_in_epoch(e, self.original_split)} PSMs after filtering for the provided distribution")

    @property
    def n_spectra(self) -> int:
        """The number of spectra in the subsetted Lance dataset.

        Uses the split of the current pass once ``__iter__`` has run and the
        full original split before that.
        """
        split = getattr(self, "_split", None)
        if split is None:
            split = self.original_split
        return len(split)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Random access is not supported for split-aware datasets.

        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError("__getitem__ was called on SplitAwareSpectrumDataset. This is not well defined")

    def __iter__(self):
        """Start a new pass: shuffle if requested and apply the episodic filter."""
        if self._shuffle:
            self._split = np.random.permutation(self.original_split)
        else:
            self._split = self.original_split

        if self._episodic:
            self._split = self._split[self._get_split_mask_for_epoch(self.epoch())]

        self._batch_counter = 0
        return self

    def __next__(self):
        """Return the next batch of the current pass.

        Adds ``'delta_masses'`` when the batch carries sequences, and
        ``'mask'`` when ``invert_episode`` is set. ``__iter__`` must have been
        called first.

        Raises:
            StopIteration: When all indices of the split have been consumed.
        """
        start = self._batch_counter*self.batch_size
        if start >= self.n_spectra:
            raise StopIteration

        end = start + self.batch_size
        subset = self._split[start:end]
        self._batch_counter += 1

        # NOTE(review): only the first record batch is used -- confirm the
        # taken table always arrives as a single batch.
        batch = self._to_tensor(self.dataset.take(subset).to_batches()[0])

        if self.invert_episode:
            # True for tokens that are allowed in the current episode.
            batch['mask'] = ~isin(batch['seq'], self.black_listed_tokens[self.epoch()])

        if "seq" in batch:
            batch['delta_masses'] = torch.tensor([self.deMass(sequence) for sequence in batch['seq']])

        batch.pop('scans', None) # Delete the scans key -> not part of annotated dataset
        batch.pop('title', None) # Delete the title key -> not part of annotated dataset
        return batch

    def _get_split_mask_for_epoch(self, epoch: int, split=None):
        """Return a boolean mask over ``split`` selecting the PSMs for ``epoch``.

        Args:
            epoch: Index of the episode whose distribution is applied.
            split: Row indices to evaluate; defaults to ``self._split``.

        Returns:
            Without ``invert_episode``: True where a PSM contains none of the
            residues that have zero weight in this episode. With
            ``invert_episode``: True where it contains at least one of them.
        """
        if split is None:
            split = self._split
        split_counts = self._residue_count_mat[split, :]

        episodic_black_list_mask = self.target_distributions[epoch] == 0
        if self.invert_episode:
            return np.any(split_counts[:, episodic_black_list_mask] != 0, axis=-1)
        else:
            return np.all(split_counts[:, episodic_black_list_mask] == 0, axis=-1)

    def _get_n_spectra_in_epoch(self, epoch: int, split: np.ndarray):
        """Count the PSMs of ``split`` that pass the filter of ``epoch``."""
        return int(np.count_nonzero(self._get_split_mask_for_epoch(epoch, split)))

    ### DEBUG IMPLEMENTATIONS ###
    def add_spectra(
        self,
        spectra: pl.DataFrame | PathLike | Iterable[PathLike],
    ):
        """Adding spectra is not supported for split-aware datasets."""
        raise NotImplementedError