# python 3.8
import pandas as pd
import numpy as np
import math
from typing import Union

def remove_transaction_events(data: pd.DataFrame, purchase_event: str, verbose: bool = True) -> pd.DataFrame:
    """Drop every event whose ``event_type`` equals ``purchase_event``.

    :param data: event table with an ``event_type`` column.
    :type data: pd.DataFrame
    :param purchase_event: the ``event_type`` value of the rows to drop.
    :type purchase_event: str
    :param verbose: print how many events were removed, defaults to True
    :type verbose: bool, optional
    :return: a new frame with the remaining events, re-indexed from 0;
        ``data`` itself is not modified.
    :rtype: pd.DataFrame
    """
    n_before = len(data)
    is_other_event = data.event_type != purchase_event
    remaining = data.loc[is_other_event].reset_index(drop=True)

    if verbose:
        n_after = len(remaining)
        print(f"Remaining events: {n_after}. {n_before - n_after} of {n_before} events removed!")

    return remaining

def remove_sessions_by_length(data: pd.DataFrame, min_len: Union[int,float,None] = None, max_len: Union[int,float,None] = None, verbose: bool = True) -> pd.DataFrame:
    """Drop sessions whose ``len`` lies outside the open interval ``(min_len, max_len)``.

    Each bound is interpreted by its type:

    * ``int`` (Python or NumPy integer): an absolute number of events;
    * ``float``: a quantile in ``[0, 1]`` of the ``len`` column;
    * ``None``: no bound on that side.

    Both bounds are exclusive: a row is kept only if ``len > lower`` and
    ``len < upper``. Quantile thresholds are always computed on the *input*
    data. So ``min_len=0.05, max_len=0.95`` trims both tails of the same
    distribution, rather than taking the upper quantile of already-trimmed data.

    :param data: one row per session with a numeric ``len`` column.
    :type data: pd.DataFrame
    :param min_len: lower bound (absolute or quantile), defaults to None
    :type min_len: Union[int,float,None], optional
    :param max_len: upper bound (absolute or quantile), defaults to None
    :type max_len: Union[int,float,None], optional
    :param verbose: print quantile thresholds and removal counts, defaults to True
    :type verbose: bool, optional
    :return: the kept rows, re-indexed from 0. If no bound applies, ``data``
        is returned unchanged. ``data`` itself is never modified.
    :rtype: pd.DataFrame
    """
    len_data_old = len(data)
    # One boolean mask per applied bound. Every mask is computed on the
    # untouched input so that quantiles refer to the original distribution.
    masks = []

    if isinstance(min_len, (int, np.integer)):
        masks.append((data.len > min_len).to_numpy())
    elif isinstance(min_len, (float, np.floating)):
        lower = data.len.quantile(min_len)
        if verbose:
            print(f"Sequences with less than {lower} ({(min_len*100):.2f}% quantile) events removed.")
        masks.append((data.len > lower).to_numpy())

    if isinstance(max_len, (int, np.integer)):
        masks.append((data.len < max_len).to_numpy())
    elif isinstance(max_len, (float, np.floating)):
        upper = data.len.quantile(max_len)
        if verbose:
            print(f"Sequences with more than {upper} ({(max_len*100):.2f}% quantile) events removed.")
        masks.append((data.len < upper).to_numpy())

    if masks:
        # Positional (numpy) masks avoid any index alignment, even with duplicate labels.
        data = data[np.logical_and.reduce(masks)].reset_index(drop=True)

    if verbose:
        print(f"Remaining sessions: {len(data)}. {len_data_old - len(data)} of {len_data_old} sequences removed!")

    return data

def create_ngram_dataset(data: pd.DataFrame, context: int, add_event_type: bool = True, verbose: bool = True) -> dict:
    """Build (center, context) training pairs and a token vocabulary for embedding training.

    Every position of every sequence yields one pair:

    * the token at that position (the "center");
    * ``context`` surrounding tokens: ``context // 2`` before it and
      ``ceil(context / 2)`` after it.

    Positions that fall outside the sequence are filled with the ``"START"``
    and ``"END"`` padding tokens. Duplicate pairs are dropped. The remaining
    pairs keep the order in which they first occur, so the output is
    reproducible across runs.

    :param data: frame with a ``touchpoint`` column holding one token sequence
        per row. When ``add_event_type`` is True, it also needs an
        ``event_type`` column whose sequences are zipped with ``touchpoint``.
    :type data: pd.DataFrame
    :param context: total number of context tokens per center token; must be >= 0.
    :type context: int
    :param add_event_type: if True, tokens are ``"<event_type>:<touchpoint>"``
        strings; otherwise the raw touchpoint values are used, defaults to True
    :type add_event_type: bool, optional
    :param verbose: print a short summary, defaults to True
    :type verbose: bool, optional
    :raises ValueError: if ``context`` is negative.
    :return: ``{"ngrams": [(center, [context tokens...]), ...], "vocab_map": {token: id}}``.
        Center tokens get ids ``0..k-1``, ranked by the number of distinct
        n-grams they are the center of (most first, ties in order of first
        appearance; missing values are skipped). ``"START"`` and ``"END"``
        get ids ``k`` and ``k + 1``.
    :rtype: dict
    """
    if context < 0:
        raise ValueError(f"context must be non-negative, got {context}")

    n_before = context // 2
    n_after = math.ceil(context / 2)
    # Relative positions of the context tokens, skipping the center (offset 0).
    offsets = [*range(-n_before, 0), *range(1, n_after + 1)]

    def iter_ngrams(actions) -> "Iterator[tuple]":
        """Yield one ``(center, context_tuple)`` pair per position of ``actions``."""
        actions = list(actions)
        padded = ["START"] * n_before + actions + ["END"] * n_after
        for pos in range(n_before, n_before + len(actions)):
            yield padded[pos], tuple(padded[pos + off] for off in offsets)

    if add_event_type:
        # Prefix each touchpoint with its event type, joined by ":".
        sequences = [
            [":".join(map(str, pair)) for pair in zip(row.event_type, row.touchpoint)]
            for row in data.itertuples()
        ]
    else:
        sequences = data.touchpoint.tolist()

    # dict.fromkeys de-duplicates while keeping first-occurrence order. A set
    # would make the order (and hence tie-broken vocab ids) depend on hash
    # randomization and therefore differ between runs.
    unique_ngrams = dict.fromkeys(ngram for seq in sequences for ngram in iter_ngrams(seq))
    ngrams = [(center, list(ctx)) for center, ctx in unique_ngrams]

    # Rank center tokens by how many distinct n-grams they head. Duplicates were
    # removed above, so this is not the raw occurrence count in ``data``.
    # factorize lists uniques in order of appearance and gives missing values
    # code -1, which is excluded below. A stable sort then breaks count ties
    # by first appearance.
    centers = pd.Series([center for center, _ in ngrams], dtype=object)
    codes, uniques = centers.factorize()
    counts = np.bincount(codes[codes >= 0], minlength=len(uniques))
    ranking = np.argsort(-counts, kind="stable")
    vocab_map = {token: i for i, token in enumerate(uniques.take(ranking))}
    # Padding tokens only ever occur as context, so they are appended last.
    n_tokens = len(vocab_map)
    vocab_map.update({"START": n_tokens, "END": n_tokens + 1})

    if verbose:
        print(
            f"""{len(ngrams)} ngrams created.\n{len(vocab_map)} different tokens exist for embedding Training.\n
            """
        )
    return {"ngrams": ngrams, "vocab_map": vocab_map}