from depthcharge.data import (
                                AnnotatedSpectrumDataset,
                                SpectrumDataset
)
from depthcharge.tokenizers import PeptideTokenizer
from depthcharge.constants import H2O, HYDROGEN

from collections.abc import Iterable
from os import PathLike
from ..utils import EpochTracker
import polars as pl
import numpy as np
from typing import Any
import pickle as pkl
from torch import isin
from torch.utils.data.dataset import IterableDataset
import torch


class SimulatedSpectraContainer:
    """Generate batches of synthetic MS/MS spectra from random residue masses.

    A simulated peptide is a row of random "residue" masses (``get_delta_masses``).
    Fragment peaks are the prefix (b-type) and suffix (y-type) cumulative sums of
    those masses divided by the precursor charge. Optionally some inner fragment
    peaks are hidden, and random noise peaks are added. Tensors returned by the
    public ``get_*`` methods are created on ``device``.
    """

    def __init__(self,
                 tokenizer: "PeptideTokenizer",
                 min_mz: float,
                 max_mz: float,
                 min_peaks: int,
                 max_peaks: int,
                 min_N_noise_peaks: int,
                 max_N_noise_peaks: int,
                 min_N_missing_peaks: int,
                 max_N_missing_peaks: int,
                 charge_probs: dict,
                 device,):
        """Store the simulation parameters.

        Args:
            tokenizer: Not used by this class (accepted for subclasses/callers).
            min_mz, max_mz: Range of the uniformly sampled residue masses.
                ``max_mz`` is also the margin added above the largest fragment
                when sampling noise peaks.
            min_peaks, max_peaks: Peptide-length range used by
                ``get_delta_masses`` (``np.random.randint``: upper bound exclusive).
            min_N_noise_peaks, max_N_noise_peaks: Range for the number of noise
                peaks per spectrum (upper bound exclusive).
            min_N_missing_peaks, max_N_missing_peaks: Inclusive range for the
                number of inner positions that lose their b or y peak, clamped
                to the positions available. ``max_N_missing_peaks <= 0``
                disables hiding.
            charge_probs: Mapping ``charge -> probability`` for ``get_charges``.
            device: Torch device on which output tensors are created.
        """
        self.device = device
        self.min_mz = min_mz
        self.max_mz = max_mz
        self.min_peaks = min_peaks
        self.max_peaks = max_peaks
        self.min_N_noise_peaks = min_N_noise_peaks
        self.max_N_noise_peaks = max_N_noise_peaks

        self.min_N_missing_peaks = min_N_missing_peaks
        self.max_N_missing_peaks = max_N_missing_peaks

        self.charges = np.array(list(charge_probs.keys()))
        self.z_probs = np.array(list(charge_probs.values()))

    def _tensorize(self, obj):
        """Convert ``obj`` to tensor(s) on ``self.device``.

        A tensor is moved to the device as-is. Rectangular array-like input
        becomes one tensor. Ragged input, or input numpy cannot convert (e.g.
        tensors living on an accelerator), becomes a list of per-element
        tensors suitable for ``_pad_tensor``.
        """
        if isinstance(obj, torch.Tensor):
            return obj.to(self.device)
        try:
            return torch.tensor(np.array(obj), device=self.device)
        except (ValueError, TypeError, RuntimeError):
            return [torch.as_tensor(elem, device=self.device) for elem in obj]

    def _pad_tensor(self, tensor: list[torch.Tensor]):
        """Zero-pad a sequence of tensors along their first dim into one batch-first tensor."""
        return torch.nn.utils.rnn.pad_sequence(tensor, batch_first=True)

    def _get_hidden_masks(self,
                          b_ions_tensor):
        """Sample which inner fragment peaks of one peptide are hidden.

        The first and last positions are always kept. For every chosen inner
        position a coin flip hides either its b peak or its y peak, never both.

        Returns:
            ``(mask_b, mask_y)``: CPU bool tensors of length
            ``len(b_ions_tensor)``; True means the peak is kept.
        """
        pep_len = len(b_ions_tensor)
        n_inner = pep_len - 2

        if n_inner <= 0:
            # No inner positions to hide: keep everything.
            keep = torch.ones(pep_len, dtype=torch.bool)
            return keep, keep.clone()

        high = min(n_inner, self.max_N_missing_peaks)
        low = min(self.min_N_missing_peaks, high)  # randint needs low <= high
        x = np.random.randint(low, high + 1)

        hidden_indices = set(np.random.choice(range(n_inner), x, replace=False))
        partition = np.random.choice([0, 1], n_inner, replace=True)
        kept = [i not in hidden_indices for i in range(n_inner)]

        # The extra Trues at both ends cover the positions shared by both ion series.
        mask_b = torch.tensor([True] + [k or p == 0 for k, p in zip(kept, partition)] + [True])
        mask_y = torch.tensor([True] + [k or p == 1 for k, p in zip(kept, partition)] + [True])
        return mask_b, mask_y

    def get_mzs(self,
                masses: torch.Tensor,
                prec_charges: torch.Tensor):
        """Simulate fragment m/z arrays for a batch of peptides.

        Args:
            masses: Zero-padded residue masses, one row per peptide. Zero
                entries are treated as padding and dropped.
            prec_charges: One charge per peptide; all fragments are divided by it.

        Returns:
            Zero-padded tensors ``(mzs, noise_mask, missing_mask)``. ``mzs`` holds
            the sorted fragment and noise peaks, and ``noise_mask`` flags the
            noise peaks in it. ``missing_mask`` has one entry per residue
            position and is True where that position's b or y peak was hidden.
        """
        residues = [(residue_masses[residue_masses != 0], z)
                    for residue_masses, z in zip(masses, prec_charges)]
        # b-type: prefix sums; y-type: suffix sums (flip, cumsum, flip back).
        # NOTE(review): H2O is added after dividing by z, so the y series only has
        # the textbook form for z == 1 — confirm this simplification is intended.
        mzs_b = [np.cumsum(res, dtype=np.float32) / z + HYDROGEN for res, z in residues]
        mzs_y = [torch.flip(np.cumsum(torch.flip(res, dims=[0]), dtype=np.float32), dims=[0]) / z + H2O + HYDROGEN
                 for res, z in residues]

        # Hide some peaks
        if self.max_N_missing_peaks > 0:
            missing_masks = []
            for i, b_ions in enumerate(mzs_b):
                mask_b, mask_y = self._get_hidden_masks(b_ions)
                mzs_b[i] = b_ions[mask_b]
                mzs_y[i] = mzs_y[i][mask_y]
                # One flag per position: True if its b or y peak is hidden.
                missing_masks.append(~(mask_b & mask_y))
        else:
            # Hiding disabled: nothing is missing at any position.
            missing_masks = [torch.zeros(len(b_ions), dtype=torch.bool) for b_ions in mzs_b]

        mzs = [torch.sort(torch.cat([b, y]))[0] for b, y in zip(mzs_b, mzs_y)]

        # Add noise peaks in [0, largest fragment + max_mz), rounded to 0.01.
        max_mzs = [torch.max(mz_tensor, dim=-1)[0] + self.max_mz for mz_tensor in mzs]
        num_noise_peaks = np.random.randint(self.min_N_noise_peaks, self.max_N_noise_peaks, (masses.shape[0],))
        noise_peaks = [np.round(np.random.uniform(0, float(max_mz), (num,)), decimals=2)
                       for num, max_mz in zip(num_noise_peaks, max_mzs)]

        merged_mzs, noise_masks = [], []
        for mz, noise in zip(mzs, noise_peaks):
            # Track the noise flag through the sort instead of re-finding noise values
            # with np.isin: comparing float32 peaks with float64 noise misses matches.
            combined = np.concatenate((mz.detach().cpu().numpy().astype(np.float32),
                                       noise.astype(np.float32)))
            is_noise = np.concatenate((np.zeros(len(mz), dtype=np.bool_),
                                       np.ones(len(noise), dtype=np.bool_)))
            order = np.argsort(combined, kind="stable")
            merged_mzs.append(combined[order])
            noise_masks.append(is_noise[order])

        return (self._pad_tensor(self._tensorize(merged_mzs)),
                self._pad_tensor(self._tensorize(noise_masks)),
                self._pad_tensor(self._tensorize(missing_masks)))

    def get_intensities(self,
                        mzs: torch.Tensor,
                        noise_mask: torch.Tensor):
        """Sample peak intensities for padded m/z arrays.

        Fragment peaks are drawn from N(1, 0.1), noise peaks from N(0.4, 0.1),
        and padding positions (``mz == 0``) are set to 0.
        """
        ints = torch.normal(1.0, 0.1, mzs.shape, dtype=torch.float32, device=mzs.device)
        ints[mzs == 0] = 0
        ints[noise_mask] = torch.normal(0.4, 0.1, ints[noise_mask].shape,
                                        dtype=torch.float32, device=mzs.device)

        return self._pad_tensor(self._tensorize(ints))

    def get_precursor_mzs(self,
                          delta_masses: torch.Tensor,
                          prec_charges: torch.Tensor):
        """Return ``sum(residue masses) / charge + HYDROGEN`` for each peptide.

        NOTE(review): unlike the y-ion series in ``get_mzs`` this adds no H2O;
        confirm that the precursor is meant to exclude it.
        """
        masses = torch.sum(delta_masses, dim=-1)
        prec_mzs = masses / prec_charges + HYDROGEN

        return self._tensorize(prec_mzs)

    def get_delta_masses(self, N: int):
        """Sample ``N`` random peptides as a zero-padded (N, max_len) float32 tensor.

        Lengths are drawn from ``[min_peaks, max_peaks)``. Masses are uniform in
        ``[min_mz, max_mz)`` and rounded to 3 decimals.
        """
        lengths = np.random.randint(self.min_peaks, self.max_peaks, (N,))
        masses = [np.round(np.random.uniform(self.min_mz, self.max_mz, length), 3).astype(np.float32)
                  for length in lengths]
        return self._pad_tensor(self._tensorize(masses))

    def get_charges(self, N: int):
        """Sample ``N`` int32 precursor charges according to ``charge_probs``."""
        prec_charges = np.random.choice(self.charges, N, replace=True, p=self.z_probs).astype(np.int32)
        return self._tensorize(prec_charges)

class SimulatedLanceContainer(SimulatedSpectraContainer):
    """Simulate spectra for the real peptide sequences of an annotated dataset.

    Instead of random residue masses, each call to ``get_delta_masses`` pulls
    the next batch from ``lance_dataset`` and uses the token masses of its
    sequences. Precursor charges and sequences come from that same batch.
    """

    def __init__(self,
                 lance_dataset: AnnotatedSpectrumDataset,
                 device,
                 min_N_missing_peaks: int = 0,
                 max_N_missing_peaks: int = 0):
        """Wrap ``lance_dataset`` as the source of peptide sequences.

        Args:
            lance_dataset: Dataset whose batches provide ``'seq'`` and
                ``'precursor_charge'``; its ``tokenizer.masses`` maps token
                indices to masses.
            device: Torch device for the generated tensors.
            min_N_missing_peaks, max_N_missing_peaks: Forwarded to the base
                class. The default of 0 disables peak hiding.
        """
        super().__init__(
            lance_dataset.tokenizer,
            min_mz=60,
            max_mz=300,
            min_peaks=5,
            max_peaks=20,
            min_N_noise_peaks=0,
            max_N_noise_peaks=10,
            # Required (non-default) parameters of the base-class constructor.
            min_N_missing_peaks=min_N_missing_peaks,
            max_N_missing_peaks=max_N_missing_peaks,
            charge_probs={},
            device=device,
        )
        self.dataset_iter = iter(lance_dataset)
        self.dataset = lance_dataset
        # Map a sequence of token indices to their tokenizer masses rounded to 0.01.
        self.deMass = lambda sequence: [np.round(self.dataset.tokenizer.masses[aa_idx], decimals=2) for aa_idx in sequence]

    def get_delta_masses(self, N: int):
        """Fetch the next dataset batch and return its per-token masses.

        ``N`` is ignored; the batch size is whatever the dataset yields.
        ``StopIteration`` propagates when the dataset is exhausted.
        """
        self.current_batch = next(self.dataset_iter)
        seqs = self.current_batch['seq']
        token_masses = torch.tensor([self.deMass(sequence) for sequence in seqs], device=self.device)

        return self._pad_tensor(self._tensorize(token_masses))

    def get_charges(self, N: int):
        """Return the precursor charges of the batch fetched by the last ``get_delta_masses`` (``N`` ignored)."""
        return self.current_batch['precursor_charge']

    def get_seq(self):
        """Return the sequences of the batch fetched by the last ``get_delta_masses``."""
        return self.current_batch['seq']


class SimulatedSpectrumDataset(IterableDataset):
    def __init__(self,
                 simulation_container: SimulatedSpectraContainer,
                 reverse_peps: bool,
                 batch_size: int,
                 num_peptides: int | None = None):
        super().__init__()
        self.sim_container = simulation_container
        self.reverse = reverse_peps
        self.bs = batch_size
        self.batch_counter = 0
        self.num_peptides = num_peptides

    def __iter__(self):
        self.batch_counter = 0

        return self
    
    def __next__(self):
        if self.num_peptides is not None and self.batch_counter * self.bs > self.num_peptides:
            raise StopIteration
        
        peak_files = ['Simulated_Data'] * self.bs
        scan_ids = [str(id) for id in list(range(self.bs * self.batch_counter, self.bs * (self.batch_counter + 1)))]
        ms_level = self.sim_container._tensorize(np.full(self.bs, 2, dtype=np.int64))

        delta_masses = self.sim_container.get_delta_masses(self.bs)
        prec_charges = self.sim_container.get_charges(self.bs)
        mzs, noise_mask, missing_mask = self.sim_container.get_mzs(delta_masses, prec_charges)
        prec_mz = self.sim_container.get_precursor_mzs(delta_masses, prec_charges)
        intensities = self.sim_container.get_intensities(mzs, noise_mask)    

        batch = {
            'peak_file': peak_files,
            'scan_id': scan_ids,
            'ms_level': ms_level,
            'precursor_mz': prec_mz,
            'precursor_charge': prec_charges,
            'mz_array': mzs,
            'intensity_array': intensities,
            'delta_masses': delta_masses,
            'noise_mask': noise_mask,
            'missing_mask': missing_mask,
        }

        try:
            batch |= {'seq': self.sim_container.get_seq()}
        except:
            pass

        self.batch_counter += 1

        return batch

class SplitAwareSpectrumDataset(SpectrumDataset):
    """Unannotated spectrum dataset restricted to a pickled subset of row indices.

    Iteration yields batches of ``batch_size`` rows taken from the Lance dataset
    in split order, reshuffled on every ``__iter__`` when ``shuffle`` is True.
    Random access through ``__getitem__`` is deliberately unsupported.
    """

    def __init__(
        self,
        spectra: pl.DataFrame | PathLike | Iterable[PathLike],
        batch_size: int,
        split_indices: str | PathLike,
        shuffle: bool = True,
        path: PathLike | None = None,
        parse_kwargs: dict | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize a SplitAwareSpectrumDataset by calling super and storing the split.

        Args:
            split_indices: Path to a pickle holding the row indices of this split.
            shuffle: Permute the split at the start of every iteration.
            Remaining arguments are forwarded to ``SpectrumDataset``.
        """
        super().__init__(
            spectra,
            batch_size,
            path,
            parse_kwargs,
            **kwargs,
        )
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted split files.
        with open(split_indices, "rb") as f:
            self.original_split = pkl.load(f)  # the array of row indices
        self._split = None
        self._shuffle = shuffle

    @property
    def n_spectra(self) -> int:
        """The number of spectra in the subsetted Lance dataset."""
        return len(self.original_split)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Unsupported: this dataset is meant to be consumed by iteration only."""
        raise RuntimeError("__getitem__ was called on SplitAwareSpectrumDataset. This is not well defined")

    def __iter__(self):
        """Fix the (optionally shuffled) row order for this pass and reset the batch counter."""
        if self._shuffle:
            self._split = np.random.permutation(self.original_split)
        else:
            self._split = self.original_split
        self._batch_counter = 0
        return self

    def __next__(self):
        """Return the next batch of rows as tensors, without the ``seq`` key."""
        start = self._batch_counter * self.batch_size
        if start >= self.n_spectra:
            raise StopIteration

        subset = self._split[start:start + self.batch_size]
        self._batch_counter += 1

        # Table.to_batches: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_batches
        batch = self._to_tensor(self.dataset.take(subset).to_batches()[0])
        batch.pop('seq', None)  # Delete the seq key -> NOT annotated dataset
        return batch

    ### DEBUG IMPLEMENTATIONS ###
    def add_spectra(
        self,
        spectra: pl.DataFrame | PathLike | Iterable[PathLike],
    ):
        """Not supported: the split is fixed at construction time."""
        raise NotImplementedError
        

class SplitAwareAnnotatedSpectrumDataset(AnnotatedSpectrumDataset):
    """Annotated spectrum dataset restricted to a pickled subset of row indices.

    Optionally supports episodic training: when ``residue_counts_path`` is
    given, each iteration keeps only the PSMs that match the episode selected
    by ``epoch()``. By default these are PSMs that contain no residue with zero
    probability in that episode's distribution. With ``invert_episode`` set,
    they are the PSMs that contain at least one such residue.
    """

    def __init__(
        self,
        spectra: pl.DataFrame | PathLike | Iterable[PathLike],
        annotations: str,
        tokenizer: PeptideTokenizer,
        batch_size: int,
        split_indices: str | PathLike,
        shuffle: bool = True,
        path: PathLike | None = None,
        parse_kwargs: dict | None = None,
        residue_counts_path: str | None = None,
        episodic_distributions: dict | None = None,
        epoch: EpochTracker | None = None,
        invert_episode: bool | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize an AnnotatedSpectrumDataset by calling super and storing the split.

        Args:
            split_indices: Path to a pickle holding the row indices of this split.
            shuffle: Permute the split at the start of every iteration.
            residue_counts_path: Pickle of ``(residue2idx, count_matrix)`` with one
                count-matrix row per PSM of the Lance file. Enables episodic mode.
            episodic_distributions: One ``{residue: count}`` mapping per episode;
                required when ``residue_counts_path`` is given.
            epoch: Callable returning the current epoch; must return 0 at
                construction in episodic mode.
            invert_episode: Select the complement of each episode (None -> False).
            Remaining arguments are forwarded to ``AnnotatedSpectrumDataset``.

        Raises:
            RuntimeError: If the episodic configuration is inconsistent.
        """
        super().__init__(
            spectra,
            annotations,
            tokenizer,
            batch_size,
            path,
            parse_kwargs,
            **kwargs,
        )
        self.epoch = epoch
        self.invert_episode = invert_episode
        self._episodic = False

        # NOTE(review): pickle.load can execute arbitrary code; only load trusted split files.
        with open(split_indices, "rb") as f:
            self.original_split = pkl.load(f)  # the array of row indices
        self._split = None
        self._shuffle = shuffle
        print(f"\nBuilt dataset with shuffle: {shuffle} and {len(self.original_split)} datapoints\n")

        if residue_counts_path is not None:
            self._setup_episodic(residue_counts_path, episodic_distributions)

    def _setup_episodic(self, residue_counts_path, episodic_distributions):
        """Load residue counts, validate the episode configuration and build per-episode filters."""
        self._episodic = True
        # This is the train dataset that has activated episodic training
        with open(residue_counts_path, "rb") as f:
            self._residue2idx, self._residue_count_mat = pkl.load(f)

        # Check if counts exist for all PSMs in the lance file
        num_total_PSMs = self.dataset.count_rows()
        num_PSMs_count_matrix = self._residue_count_mat.shape[0]
        if num_total_PSMs != num_PSMs_count_matrix:
            raise RuntimeError(f"Dataset creation: The number of total PSMs in the provided lance file ({num_total_PSMs}) does not match the number of PSMs for which a residue occurence has been registered in the ``residue_counts_path`` file ({num_PSMs_count_matrix})")

        if episodic_distributions is None:
            raise RuntimeError("Dataset creation: ``residue_counts_path`` was provided but ``episodic_distributions`` is None")

        target_counts = list(episodic_distributions.values())
        num_episodes = len(target_counts)
        num_unique_tokens = 1 + np.max(np.array(list(self._residue2idx.values())))
        target_distributions = np.zeros((num_episodes, num_unique_tokens))

        # One representative residue per count-matrix column: the first residue mapped
        # to it (collapses aliases that share an index, e.g. I/L).
        first_residue_for_idx = {}
        for residue, idx in self._residue2idx.items():
            first_residue_for_idx.setdefault(idx, residue)
        residues_in_count_mat = [first_residue_for_idx[idx] for idx in range(num_unique_tokens)]

        # Every residue named in an episode must be a representative in the count matrix.
        residues_in_target_dists = {residue for dist in target_counts for residue in dist}
        if not residues_in_target_dists.issubset(residues_in_count_mat):
            raise RuntimeError(f"Dataset creation: The residues specified in the episodes do not match the residues in the counting matrix.\nEpisodic Tokens:\n{residues_in_target_dists}\nCount Matrix tokens:\n{residues_in_count_mat}")

        # The epoch object must exist and give the correct epoch
        if self.epoch is None or self.epoch() != 0:
            current_epoch = None if self.epoch is None else self.epoch()
            raise RuntimeError(f"Dataset creation: The epoch tracking object does either not exist ({self.epoch is None}) or is not set the the correct default value which was {current_epoch} and expected to be 0")

        # The invert_episode should be set to either true or false
        if self.invert_episode is None:
            print("Dataset creation: [Warning] invert_episode was set to None despite the dataset being initialized for episodic training. It will be treated as invert_episode = False i.e. return data following the provided distributions")
            self.invert_episode = False

        for i, distribution in enumerate(target_counts):
            for residue, count in distribution.items():
                target_distributions[i, self._residue2idx[residue]] = count

        # Normalize each episode's counts into a probability distribution.
        self.target_distributions = target_distributions / np.sum(target_distributions, axis=-1).reshape(-1, 1)

        # Token indices of the residues that are not allowed in each episode
        self.black_listed_tokens = []
        for episode_distribution in self.target_distributions:
            banned_residues = [token for token, idx in self._residue2idx.items() if episode_distribution[idx] == 0]
            self.black_listed_tokens.append(self.tokenizer.tokenize(banned_residues))

        for e in range(len(self.target_distributions)):
            print(f"Epoch {e} contains {self._get_n_spectra_in_epoch(e, self.original_split)} PSMs after filtering for the provided distribution")

    @property
    def n_spectra(self) -> int:
        """The number of spectra in the current (possibly filtered) split, else the full split."""
        split = getattr(self, "_split", None)
        if split is None:
            return len(self.original_split)
        return len(split)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Unsupported: this dataset is meant to be consumed by iteration only."""
        raise RuntimeError(f"__getitem__ was called on {type(self).__name__}. This is not well defined")

    def __iter__(self):
        """Fix this pass's row order (shuffled, then episode-filtered) and reset the counter."""
        if self._shuffle:
            self._split = np.random.permutation(self.original_split)
        else:
            self._split = self.original_split

        if self._episodic:
            self._split = self._split[self._get_split_mask_for_epoch(self.epoch())]

        self._batch_counter = 0
        return self

    def __next__(self):
        """Return the next batch of rows as tensors, without the ``scans``/``title`` keys."""
        start = self._batch_counter * self.batch_size
        if start >= self.n_spectra:
            raise StopIteration

        subset = self._split[start:start + self.batch_size]
        self._batch_counter += 1

        batch = self._to_tensor(self.dataset.take(subset).to_batches()[0])

        # black_listed_tokens only exists in episodic mode.
        if self._episodic and self.invert_episode:
            # NOTE(review): if tokenize() pads with index 0, padding positions of
            # 'seq' are masked out here as well — confirm intended.
            batch['mask'] = ~isin(batch['seq'], self.black_listed_tokens[self.epoch()])

        batch.pop('scans', None)  # Delete the scans key -> not part of annotated dataset
        batch.pop('title', None)  # Delete the title key -> not part of annotated dataset
        return batch

    def _get_split_mask_for_epoch(self, epoch: int, split: np.ndarray | None = None):
        """Boolean mask over ``split`` (default: the current split) selecting the PSMs of ``epoch``."""
        if split is None:
            split = self._split
        split_counts = self._residue_count_mat[split, :]

        episodic_black_list_mask = self.target_distributions[epoch] == 0
        if self.invert_episode:
            # PSMs containing at least one forbidden residue.
            return np.any(split_counts[:, episodic_black_list_mask] != 0, axis=-1)
        # PSMs containing no forbidden residue.
        return np.all(split_counts[:, episodic_black_list_mask] == 0, axis=-1)

    def _get_n_spectra_in_epoch(self, epoch: int, split: np.ndarray):
        """Count the PSMs of ``split`` that are selected for ``epoch``."""
        return int(np.count_nonzero(self._get_split_mask_for_epoch(epoch, split)))

    ### DEBUG IMPLEMENTATIONS ###
    def add_spectra(
        self,
        spectra: pl.DataFrame | PathLike | Iterable[PathLike],
    ):
        """Not supported: the split is fixed at construction time."""
        raise NotImplementedError