from typing import Callable, Iterable, Optional, Union

import torch
import torch.nn as nn

def strat_select_id(layer: torch.Tensor, idx: int) -> torch.Tensor:
    """
    Selects a specific parameter vector from a layer's parameter tensor by its index.

    Used as an initialisation strategy: a new token's vector is copied from the
    row of an existing token.

    Parameters:
    layer (torch.Tensor): The parameter tensor of a layer (e.g. embedding or linear weights).
    idx (int): The index of the parameter vector to select.

    Returns:
    torch.Tensor: The selected parameter vector, ``layer[idx]``.
    """
    return layer[idx]

def strat_avg_all(layer: torch.Tensor) -> torch.Tensor:
    """
    Computes the average of all parameter vectors of a layer's parameter tensor.

    Used as an initialisation strategy: a new token starts from the mean of all
    existing rows.

    Parameters:
    layer (torch.Tensor): The parameter tensor of a layer (e.g. embedding or linear weights).

    Returns:
    torch.Tensor: The mean over the first dimension, ``layer.mean(0)``.
    """
    return layer.mean(0)

def strat_avg_by_ids(layer: torch.Tensor, ids: Iterable[int]) -> torch.Tensor:
    """
    Computes the average of the parameter vectors for the given indices of a layer's parameter tensor.

    Parameters:
    layer (torch.Tensor): The parameter tensor of a layer (e.g. embedding or linear weights).
    ids (Iterable[int]): The indices of the parameter vectors to average. Lists, tuples,
        ranges, sets, generators and index tensors are accepted.

    Returns:
    torch.Tensor: The mean of ``layer`` rows selected by ``ids``.
    """
    if not isinstance(ids, torch.Tensor):
        # Materialise generic iterables as a list so they act as a row selection:
        # a tuple would otherwise be read as a multi-dimensional index, and sets,
        # ranges or generators are not valid tensor indices at all.
        ids = list(ids)
    return layer[ids].mean(0)

class LifelongLearningEmbedding(nn.Module):
    """
    Lifelong-Learning Embedding layer.

    This class wraps a `torch.nn.Embedding` (stored as `self.embedding`) and adds
    functionality for dynamic, lifelong learning: managing a vocabulary map,
    supporting optional special indices (unknown/padding tokens), and enabling
    later adjustments to the vocabulary size or embedding dimension.

    The constructor forwards extra keyword arguments to `nn.Embedding` and
    `forward` delegates to it, so the layer can replace `nn.Embedding` for
    id-to-vector lookups.
    """
    def __init__(
            self,
            vocab_map: Union[dict, set],
            embedding_dim: int,
            unknown_idx: Optional[Union[bool, int]] = None,
            padding_idx: Optional[Union[bool, int]] = None,
            **kwargs,
        ) -> None:
        """
        Parameters
        ----------
        vocab_map : Union[dict, set]
            Mapping from tokens to unique indices. If a set is provided,
            it will be converted into a dictionary with automatically
            assigned indices. A dict is used as given.
        embedding_dim : int
            Dimensionality of the embedding vectors.
        unknown_idx : Union[bool, int], optional, default=None
            Index or flag for an unknown (out-of-vocabulary) token:
            * `True`: add an unknown token at the last available index.
            * `int`: use the specified index for the unknown token.
            * `None` or `False`: disable the unknown token.
        padding_idx : Union[bool, int], optional, default=None
            Index or flag for a padding token:
            * `True`: add a padding token at the last available index.
            * `int`: use the specified index for the padding token.
            * `None` or `False`: disable the padding token.
        **kwargs :
            Additional arguments passed directly to `torch.nn.Embedding`,
            including:
            - max_norm : Optional[float]
            - norm_type : float
            - scale_grad_by_freq : bool
            - sparse : bool
            - _weight : Optional[Tensor]
            - device, dtype

        Attributes
        ----------
        vocab_map : dict
            Mapping from tokens to their integer indices.
        vocab_size : int
            Size of the vocabulary, including unknown and padding tokens.
            This is equivalent to `num_embeddings` in `nn.Embedding`.
        embedding_dim : int
            Dimensionality of the embedding vectors.
        unknown_idx : Optional[int]
            Index of the unknown token if enabled, else None/False.
        padding_idx : Optional[int]
            Index of the padding token if enabled, else None/False.
        embedding : nn.Embedding
            The wrapped embedding layer mapping token indices to vectors.
        kwargs : dict
            Additional arguments passed to `nn.Embedding`.

        Notes
        -----
        - If both `unknown_idx` and `padding_idx` are set to True,
          they are assigned sequentially starting from the current
          vocabulary size (unknown first, then padding).
        - The `padding_idx` is forwarded to `nn.Embedding`, which excludes
          that row from gradient updates.
        - `_weight` only seeds the initial layer; layers rebuilt by
          `update_embedding` / `extend_embedding_dim` are freshly created.
        - The special indices are tested by truthiness, so an explicit
          int index of 0 is treated like `None` (disabled).
        """
        super().__init__()

        # A set is turned into a dict with contiguous ids. A dict is used as-is,
        # which keeps ids aligned with weights restored from a checkpoint.
        self.vocab_map = self.handle_vocab_map(vocab_map) if isinstance(vocab_map, set) else vocab_map
        self.embedding_dim = embedding_dim
        self.kwargs = kwargs

        # These flags must exist before `set_unknown_and_padding_idx` runs: it reads
        # them (int indices) and sets them (`True` indices). They remember that a
        # special token must stay at the end of the vocabulary when
        # `update_embedding` later resizes it.
        self._is_last_unknown_idx: bool = False
        self._is_last_padding_idx: bool = False

        # Resolves the special indices and sets `self.vocab_size`.
        self.set_unknown_and_padding_idx(len(self.vocab_map), unknown_idx, padding_idx)

        self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim, padding_idx=self.padding_idx, **self.kwargs)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Forward function for the embedding layer.

        :param inputs: A tensor with ids.
        :type inputs: torch.Tensor
        :return: The embedding vectors, shaped ``inputs.shape + (embedding_dim,)``.
        :rtype: torch.Tensor
        """
        return self.embedding(inputs)

    @torch.no_grad()
    def embed(self, sequence: torch.Tensor) -> torch.Tensor:
        """Embed a given sequence without tracking gradients.

        :param sequence: A sequence of ids that should be embedded.
        :type sequence: torch.Tensor

        :return: The embedded sequence, shaped ``sequence.shape + (embedding_dim,)``.
        :rtype: torch.Tensor
        """
        return self.embedding(sequence)

    def map_words(self, words: Iterable) -> list:
        """Maps word tokens to their ids in the vocabulary map.

        Unknown words map to `unknown_idx` when an unknown token is enabled;
        otherwise they are dropped from the result.

        :param words: An iterable of tokens.
        :type words: Iterable

        :return: A list of ids for the given words.
        :rtype: list
        """
        assert isinstance(words, Iterable), f"Parameter words must be an iterable, but type {type(words)} is not iterable."
        return [
            self.vocab_map.get(word, self.unknown_idx)
            for word in words
            if word in self.vocab_map or self.unknown_idx
        ]

    def handle_vocab_map(self, vocab_map: Union[dict, set]) -> dict:
        """Assigns contiguous ids ``0..n-1`` to the tokens of a vocabulary.

        Dict keys keep their insertion order. Set elements are sorted when they
        are mutually comparable, so ids do not depend on hash order (which varies
        between processes for strings); otherwise the set's iteration order is used.

        :param vocab_map: The vocabulary, as a dict (keys are the tokens) or a set of tokens.
        :type vocab_map: Union[dict, set]

        :return: A new vocabulary map with the tokens and their ids.
        :rtype: dict
        """
        assert isinstance(vocab_map, (dict, set)), f"Parameter vocab_map must be a dict or set, but type {type(vocab_map)} is not."

        tokens = list(vocab_map)  # dict -> keys in insertion order
        if isinstance(vocab_map, set):
            try:
                tokens.sort()
            except TypeError:
                # Mixed, non-comparable token types: fall back to iteration order.
                pass
        return {word: idx for idx, word in enumerate(tokens)}

    def set_unknown_and_padding_idx(self, vocab_size: int, unknown_idx: Optional[Union[bool, int]] = None,  padding_idx: Optional[Union[bool, int]] = None) -> None:
        """Resolves the special token indices and sets `self.vocab_size`.

        Each enabled special token adds one slot to ``vocab_size``. A token given
        as `True` (or one that was previously placed last) is moved to the last
        slot available at that point.

        :param vocab_size: Number of regular tokens.
        :param unknown_idx: Flag or index of the unknown token.
        :param padding_idx: Flag or index of the padding token.
        """
        self.unknown_idx = unknown_idx
        self.padding_idx = padding_idx
        if self.unknown_idx:
            # An unknown token is part of the input and the output.
            vocab_size += 1
            # `True` (or previously last) -> place it at the last index.
            if isinstance(self.unknown_idx, bool) or self._is_last_unknown_idx:
                self.unknown_idx = vocab_size - 1
                self._is_last_unknown_idx = True
        if self.padding_idx:
            # A padding token is part of the input.
            vocab_size += 1
            # `True` (or previously last) -> place it at the last index.
            if isinstance(self.padding_idx, bool) or self._is_last_padding_idx:
                self.padding_idx = vocab_size - 1
                self._is_last_padding_idx = True
            # Edge case: an explicit padding index at or before an unknown index that
            # must stay last pushes the unknown index one slot further.
            elif self._is_last_unknown_idx and self.padding_idx <= self.unknown_idx:
                self.unknown_idx += 1
        self.vocab_size = vocab_size

    def _rebuild_embedding(self, num_embeddings: int, embedding_dim: int) -> nn.Embedding:
        """Creates a fresh `nn.Embedding` on the current device for a resized layer."""
        # `_weight` describes the *initial* weight tensor; its shape no longer
        # matches after a resize, so it must not be forwarded again.
        kwargs = {key: value for key, value in self.kwargs.items() if key != "_weight"}
        current_device = next(self.parameters()).device
        return nn.Embedding(num_embeddings, embedding_dim, padding_idx=self.padding_idx, **kwargs).to(current_device)

    @torch.no_grad()
    def detect_unused_classes(self, input_sample: torch.Tensor = None, threshold: float = 1.0, model: Optional[nn.Module] = None) -> list:
        """Detects output classes with a low mean softmax activation.

        The embedding itself produces vectors, not class scores, so the scores come
        from `model`. It must expose the number of classes as ``out_size``, and its
        forward must map a batch of ids to an iterable of logit tensors (one per
        context position), e.g. a SkipGram model built on this layer. Defaults to
        ``self``, which only works if an ``out_size`` attribute was attached.

        :param input_sample: 1-D tensor of input ids; defaults to all vocabulary ids (plus the unknown id).
        :type input_sample: torch.Tensor
        :param threshold: Threshold relative to the uniform probability ``1 / out_size``, defaults to 1.0
        :type threshold: float, optional
        :param model: Module producing the class scores, defaults to ``self``.
        :type model: nn.Module, optional

        :raises AttributeError: If `model` has no ``out_size`` attribute.

        :return: A list of unused class indices.
        :rtype: list
        """
        assert isinstance(threshold, float), f"Threshold must be a float, but got {type(threshold)}."
        model = self if model is None else model
        out_size = getattr(model, "out_size", None)
        if out_size is None:
            raise AttributeError(
                f"{type(model).__name__} has no 'out_size'; pass the model that produces class logits via `model=`."
            )
        current_device = next(self.parameters()).device

        if input_sample is None:
            # Probe with every vocabulary id, plus the unknown id if enabled.
            ids = list(self.vocab_map.values())
            if self.unknown_idx:
                ids.append(self.unknown_idx)
            input_sample = torch.tensor(ids, dtype=torch.long)
        input_sample = input_sample.to(current_device)

        was_training = model.training
        model.eval()
        try:
            mean_softmax_activation = torch.zeros(out_size, device=current_device)
            for token_id in input_sample:
                out = model(token_id.unsqueeze(0))
                # Mean over the context positions of the softmax distribution over classes.
                mean_softmax_activation += torch.mean(torch.cat([torch.softmax(o, dim=-1) for o in out]), dim=0)
            mean_softmax_activation = mean_softmax_activation / len(input_sample)
        finally:
            model.train(was_training)

        scaled_threshold = threshold * (1 / out_size)  # scale threshold to the output size
        return torch.where(mean_softmax_activation < scaled_threshold)[0].tolist()

    @torch.no_grad()
    def extend_embedding_dim(self, new_dim: int) -> None:
        """Extends the embedding vectors to a larger dimension.

        Existing values are kept in the leading dimensions; the new dimensions are
        initialised to zero.

        :param new_dim: The new (larger) embedding dimension.
        :type new_dim: int
        """
        assert new_dim > self.embedding_dim, f"New dimension {new_dim} must be greater than the current embedding dimension {self.embedding_dim}."

        new_embedding = self._rebuild_embedding(self.vocab_size, new_dim)
        new_embedding.weight.zero_()
        new_embedding.weight[:, :self.embedding_dim] = self.embedding.weight

        self.embedding_dim = new_dim
        self.embedding = new_embedding

    @torch.no_grad()
    def update_embedding(self, new_vocab_map: Union[dict, set], strategy: Optional[Callable[..., torch.Tensor]] = None, strat_params: Optional[dict] = None) -> None:
        """Rebuilds the embedding for a new vocabulary, carrying over trained vectors.

        Tokens present in the old vocabulary keep their vectors. New tokens are
        initialised with ``strategy(old_weights, **strat_params)`` if a strategy
        is given, otherwise they keep `nn.Embedding`'s random initialisation. The
        unknown token keeps its vector even if its index moves.

        :param new_vocab_map: The new vocabulary (dict keys or set elements are the tokens).
        :type new_vocab_map: Union[dict, set]
        :param strategy: Initialisation strategy for new tokens, e.g. `strat_avg_all`.
        :type strategy: Callable, optional
        :param strat_params: Keyword arguments for `strategy`.
        :type strat_params: dict, optional
        """
        strat_params = {} if strat_params is None else strat_params
        # Always re-assign ids so they are contiguous and fit the new embedding size.
        new_vocab_map = self.handle_vocab_map(new_vocab_map)

        # `set_unknown_and_padding_idx` may move the unknown index; remember where
        # its trained vector currently lives.
        old_unknown_idx = self.unknown_idx
        old_weights = self.embedding.weight.detach()

        self.set_unknown_and_padding_idx(len(new_vocab_map), self.unknown_idx, self.padding_idx)

        new_embedding = self._rebuild_embedding(self.vocab_size, self.embedding_dim)
        new_weights = new_embedding.weight

        for word, word_id in new_vocab_map.items():
            if word in self.vocab_map:
                new_weights[word_id] = old_weights[self.vocab_map[word]]
            elif strategy is not None:
                new_weights[word_id] = strategy(old_weights, **strat_params)

        if self.unknown_idx:
            new_weights[self.unknown_idx] = old_weights[old_unknown_idx]

        self.vocab_map = new_vocab_map
        self.embedding = new_embedding

    @torch.no_grad()
    def remove_unused_classes(self, threshold: float = 0.01, model: Optional[nn.Module] = None) -> None:
        """Removes tokens whose output class is rarely activated.

        Only this layer is resized. If `model` has its own output layers, use the
        model's own removal method so those stay in sync.

        :param threshold: The threshold for detecting unused classes, defaults to 0.01
        :type threshold: float, optional
        :param model: Module producing the class scores; see `detect_unused_classes`.
        :type model: nn.Module, optional
        """
        assert isinstance(threshold, float), f"Threshold must be a float, but got {type(threshold)}."

        unused_ids = set(self.detect_unused_classes(threshold=threshold, model=model))
        # Detection returns class *ids*; keep every token whose id was not flagged.
        kept_vocab = {word for word, word_id in self.vocab_map.items() if word_id not in unused_ids}

        self.update_embedding(kept_vocab)
    
class SkipGram(nn.Module):
    """Skip-gram model (Mikolov et al., 2013) on top of a `LifelongLearningEmbedding`.

    A shared embedding feeds one linear head per context position; each head
    scores every output class.
    """

    @classmethod
    def create_from_checkpoint(
        cls,
        checkpoint_path: str,
        vocab_map: Union[dict, set],
        embedding_dim: int,
        context_size: Optional[int] = 1,
        sparse: Optional[bool] = False,
        unknown_idx: Optional[Union[bool, int]] = None,
        padding_idx: Optional[Union[bool, int]] = None
    ) -> "SkipGram":
        """
        Creates a SkipGram model instance from a saved checkpoint.

        This method initializes a new SkipGram model with the provided configuration
        and restores its state from a checkpoint file. The checkpoint must include the
        model's state_dict under the key `"model"`.

        Parameters:
        ----------
        checkpoint_path : str
            Path to the checkpoint file containing the saved model state.
        vocab_map : Union[dict, set]
            A mapping of tokens to unique indices. This defines the vocabulary to be embedded.
            Pass the same dict that was used when the checkpoint was saved so ids match the weights.
        embedding_dim : int
            The dimension of the embedding vectors that represent tokens in the latent space.
        context_size : Optional[int], default=1
            The number of context tokens to predict for each input token.
        sparse : Optional[bool], default=False
            If True, gradients with respect to the embedding weight matrix will be computed
            as sparse tensors. Note that PyTorch's Distributed Data Parallel (DDP) does not
            support sparse tensors.
        unknown_idx : Optional[Union[bool, int]], default=None
            Index or inclusion flag for an unknown token. See `SkipGram.__init__` for details.
        padding_idx : Optional[Union[bool, int]], default=None
            Index or inclusion flag for a padding token. See `SkipGram.__init__` for details.

        Returns:
        -------
        SkipGram
            An instance of the SkipGram model with the restored state from the checkpoint.

        Notes:
        -----
        - Ensure that the checkpoint file is compatible with the model's architecture and
        parameter settings. Mismatched configurations may cause errors when loading the
        state_dict.
        - The model is restored to CPU, regardless of the device used during saving.
        - `torch.load` unpickles the file: only load checkpoints from trusted sources.

        Example:
        -------
        >>> model = SkipGram.create_from_checkpoint(
        ...     "path/to/checkpoint.pth",
        ...     vocab_map=my_vocab_map,
        ...     embedding_dim=100,
        ...     context_size=2,
        ...     sparse=False,
        ...     unknown_idx=True,
        ...     padding_idx=None
        ... )
        >>> print(model)
        """
        # Keyword arguments are required: `__init__` takes `sparse` before the special
        # indices, so passing them positionally would shift every argument by one.
        skipgram = cls(
            vocab_map,
            embedding_dim,
            context_size=context_size,
            sparse=sparse,
            unknown_idx=unknown_idx,
            padding_idx=padding_idx,
        )

        # Load the checkpoint and restore the model state
        checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        skipgram.load_state_dict(checkpoint["model"])

        return skipgram

    def __init__(
            self,
            vocab_map: Union[dict, set],
            embedding_dim: int,
            context_size: Optional[int] = 1,
            sparse: Optional[bool] = False,
            unknown_idx: Optional[Union[bool, int]] = None,
            padding_idx: Optional[Union[bool, int]] = None,
        ) -> None:
        """
        Initializes the SkipGram model based on Mikolov et al.'s 2013 paper,
        'Efficient Estimation of Word Representations in Vector Space.'

        This implementation is designed to work with any kind of tokens, such as words in NLP,
        product identifiers, or action labels. The model learns embeddings for tokens and
        predicts context tokens based on a given input token.

        Parameters:
        ----------
        vocab_map: Union[dict, set]
            A mapping of tokens to unique indices. If a set is provided, it will be converted to a dictionary with indices for the one-hot encoding.
        embedding_dim : int
            The dimension of the embedding vectors that represent tokens in the latent space.
        context_size : Optional[int], default=1
            The number of context tokens to predict for each input token. A larger value means
            more context is considered.
        sparse : Optional[bool], default=False
            If True, gradients with respect to the embedding weight matrix will be computed as
            sparse tensors. Note that PyTorch's Distributed Data Parallel (DDP) does not support
            sparse tensors.
        unknown_idx : Optional[Union[bool, int]], default=None
            Index or inclusion flag for an unknown token. If an unknown token is used, it provides
            a representation for out-of-vocabulary tokens.
            Note: If unknown token is part of the vocabulary in vocab_map, please do not set unknown_idx.
            If set to:
            - `True`: Adds an unknown token with the last available index.
            - `int`: Uses the specified index for the unknown token.
            - `None` or `False`: Disables the unknown token.
        padding_idx : Optional[Union[bool, int]], default=None
            Index or inclusion flag for a padding token. If set to:
            - `True`: Adds a padding token with the last available index.
            - `int`: Uses the specified index for the padding token.
            - `None` or `False`: Disables the padding token.

        Attributes:
        ----------
        vocab_size : int
            The size of the input vocabulary, including any unknown or padding tokens
            (available as `embedding.vocab_size`).
        out_size : int
            The number of output classes of every linear head.
        embedding : LifelongLearningEmbedding
            The embedding layer that maps token indices to embedding vectors.
        linear_out : nn.ModuleList
            A list of linear layers, each mapping embedding vectors to output scores for
            context token prediction.

        Notes:
        -----
        - If both `unknown_idx` and `padding_idx` are `True`, their indices will be assigned
        sequentially, starting with the vocabulary size.
        - One output slot is dropped only when both a padding and an unknown token are
        enabled; with a padding token alone, `out_size` equals the full vocabulary size.
        """
        super().__init__()

        self.embedding = LifelongLearningEmbedding(vocab_map, embedding_dim, unknown_idx, padding_idx, sparse=sparse)

        self.context_size = context_size
        # NOTE(review): the padding slot is removed from the output only when an
        # unknown token is also enabled. Kept as-is because changing it would alter
        # layer shapes and break existing checkpoints; confirm intent.
        self.out_size = self.embedding.vocab_size - (1 if self.embedding.padding_idx and self.embedding.unknown_idx else 0)

        self.linear_out = nn.ModuleList([nn.Linear(self.embedding.embedding_dim, self.out_size) for _ in range(self.context_size)])

    def forward(self, inputs: torch.Tensor) -> list:
        """Forward function for the SkipGram model.

        :param inputs: A tensor with ids.
        :type inputs: torch.Tensor

        :return: One logit tensor per context position (list of length `context_size`).
        :rtype: list
        """
        embeddings = self.embedding(inputs)
        return [linear(embeddings) for linear in self.linear_out]

    @torch.no_grad()
    def extend_embedding(self, new_dim: int) -> None:
        """Extends the embedding and all linear heads to a larger dimension.

        Existing weights are copied into the leading columns. The new embedding
        dimensions are zero-initialised by the embedding layer, so the heads'
        randomly initialised new columns do not change the outputs until trained.

        :param new_dim: The new (larger) embedding dimension.
        :type new_dim: int
        """
        old_dim = self.embedding.embedding_dim
        # Validate before building or writing any layer.
        assert new_dim > old_dim, f"New dimension {new_dim} must be greater than the current embedding dimension {old_dim}."

        current_device = next(self.parameters()).device
        new_linear_out = nn.ModuleList([nn.Linear(new_dim, self.out_size).to(current_device) for _ in range(self.context_size)])

        for new_linear, old_linear in zip(new_linear_out, self.linear_out):
            new_linear.weight[:, :old_dim] = old_linear.weight
            new_linear.bias.copy_(old_linear.bias)

        self.embedding.extend_embedding_dim(new_dim)
        self.linear_out = new_linear_out

    @torch.no_grad()
    def update_embedding(self, new_vocab_map: Union[dict, set], strategy: Optional[Callable[..., torch.Tensor]] = None, strat_params: Optional[dict] = None) -> None:
        """Updates the embedding and the linear heads for a new vocabulary.

        Necessary if vocabulary is added to or removed from the relevant corpus,
        e.g. for incremental/online learning of embeddings. Known tokens keep
        their weights. New tokens use `strategy` (applied to each old weight and
        bias tensor) when given, otherwise they keep the random initialisation.

        :param new_vocab_map: The new vocabulary (dict keys or set elements are the tokens).
        :type new_vocab_map: Union[dict, set]
        :param strategy: Initialisation strategy for new tokens, e.g. `strat_avg_all`.
        :type strategy: Callable, optional
        :param strat_params: Keyword arguments for `strategy`.
        :type strat_params: dict, optional
        """
        strat_params = {} if strat_params is None else strat_params
        current_device = next(self.parameters()).device

        # The embedding update re-assigns ids and may move the unknown index, so
        # snapshot what is needed to locate the old output rows.
        old_unknown_idx = self.embedding.unknown_idx
        old_vocab_map = self.embedding.vocab_map

        self.embedding.update_embedding(new_vocab_map, strategy, strat_params)

        new_out_size = self.embedding.vocab_size - (1 if self.embedding.padding_idx and self.embedding.unknown_idx else 0)
        # The linear heads never contain a padding row: if the padding index sits at
        # or before an unknown index that is kept last, the unknown row shifts down by one.
        unknown_idx_adder = 0
        if self.embedding.padding_idx and self.embedding._is_last_unknown_idx and self.embedding.padding_idx <= self.embedding.unknown_idx:
            unknown_idx_adder = 1

        new_linear_out = nn.ModuleList([nn.Linear(self.embedding.embedding_dim, new_out_size).to(current_device) for _ in range(self.context_size)])

        # state_dict tensors share storage with the parameters, so writing into them
        # updates the new layers. Order: [linear0.weight, linear0.bias, linear1.weight, ...]
        param_pairs = list(zip(new_linear_out.state_dict().values(), self.linear_out.state_dict().values()))

        # Output row ids coincide with the word ids of the vocabulary map.
        for word, word_id in self.embedding.vocab_map.items():
            if word in old_vocab_map:
                old_word_id = old_vocab_map[word]
                for new_param, old_param in param_pairs:
                    new_param[word_id] = old_param[old_word_id]
            elif strategy is not None:
                for new_param, old_param in param_pairs:
                    new_param[word_id] = strategy(old_param, **strat_params)

        if self.embedding.unknown_idx:
            for new_param, old_param in param_pairs:
                new_param[self.embedding.unknown_idx - unknown_idx_adder] = old_param[old_unknown_idx - unknown_idx_adder]
        # The padding index needs no update: it is not part of the linear output.

        self.out_size = new_out_size
        self.linear_out = new_linear_out

    @torch.no_grad()
    def detect_unused_classes(self, input_sample: Optional[torch.Tensor] = None, threshold: float = 1.0) -> list:
        """Detects output classes with a low mean softmax activation.

        :param input_sample: 1-D tensor of input ids; defaults to all vocabulary ids (plus the unknown id).
        :type input_sample: torch.Tensor, optional
        :param threshold: Threshold relative to the uniform probability ``1 / out_size``, defaults to 1.0
        :type threshold: float, optional

        :return: A list of unused class indices.
        :rtype: list
        """
        assert isinstance(threshold, float), f"Threshold must be a float, but got {type(threshold)}."
        current_device = next(self.parameters()).device

        if input_sample is None:
            ids = list(self.embedding.vocab_map.values())
            if self.embedding.unknown_idx:
                ids.append(self.embedding.unknown_idx)
            input_sample = torch.tensor(ids, dtype=torch.long)
        input_sample = input_sample.to(current_device)

        was_training = self.training
        self.eval()
        try:
            mean_softmax_activation = torch.zeros(self.out_size, device=current_device)
            # One id at a time keeps memory at O(out_size) instead of O(N * out_size).
            for token_id in input_sample:
                out = self(token_id.unsqueeze(0))
                mean_softmax_activation += torch.cat([torch.softmax(o, dim=-1) for o in out]).mean(dim=0)
            mean_softmax_activation = mean_softmax_activation / len(input_sample)
        finally:
            self.train(was_training)

        scaled_threshold = threshold * (1 / self.out_size)  # scale threshold to the output size
        return torch.where(mean_softmax_activation < scaled_threshold)[0].tolist()

    @torch.no_grad()
    def remove_unused_classes(self, threshold: float = 0.01) -> None:
        """Removes tokens whose output class is rarely activated.

        The embedding and all linear heads are resized together.

        :param threshold: The threshold for detecting unused classes, defaults to 0.01
        :type threshold: float, optional
        """
        assert isinstance(threshold, float), f"Threshold must be a float, but got {type(threshold)}."

        unused_ids = set(self.detect_unused_classes(threshold=threshold))
        # Output class ids coincide with word ids; special tokens have no word
        # entry and are therefore never removed here.
        kept_vocab = {word for word, word_id in self.embedding.vocab_map.items() if word_id not in unused_ids}

        self.update_embedding(kept_vocab)

    def fit(self, dataloader: Iterable, optimizers: list, device = torch.device("cpu")) -> float:
        """Trains the SkipGram model for one pass over a dataloader.

        :param dataloader: An arbitrary dataloader that iterates over the batches.
            Each batch consists of features and targets.
            Features is of type torch.Tensor and has the shape (batch_size).
            Targets is of type torch.Tensor and has the shape (batch_size, context_size).
        :type dataloader: Iterable
        :param optimizers: List of optimizers to optimize the weights.
            If sparse embedding is used, a sparse optimizer has to be used additionally.
        :type optimizers: list
        :param device: Device where to compute the training process, defaults to torch.device("cpu")
        :type device: torch.device, optional

        :return: Mean loss over all batches, or NaN if the dataloader yielded no batches.
        :rtype: float
        """
        train_loss = []

        self.train()
        self.to(device)

        # Softmax + negative log-likelihood for each context head.
        criterion = nn.CrossEntropyLoss()

        for batch in dataloader:
            features = batch[0].long().to(device)
            targets = batch[1].long().to(device)

            self.zero_grad()

            out = self(features)
            loss = sum(criterion(logits, targets[:, position]) for position, logits in enumerate(out))

            loss.backward()

            for optimizer in optimizers:
                optimizer.step()

            train_loss.append(loss.item())

        if not train_loss:
            return float("nan")
        return sum(train_loss) / len(train_loss)