import itertools
from uu import decode

from sklearn.compose import ColumnTransformer
from sqlalchemy import desc
from tblib import Traceback

from .enc_dec import EncDec
import numpy as np
import copy

from sklearn.preprocessing import OneHotEncoder, FunctionTransformer, OrdinalEncoder
from typing import List, Union
import spacy

from ...utils import Concept

__all__ = ["TextEnc"]




class TextEnc(EncDec):
    """
    Encoder/decoder for variations of a single reference text (``base_str``).

    The reference text is split into tokens once, at construction time. A text
    is encoded as a binary vector with one entry per reference token: 1 where
    the text's token at that position equals the reference token, 0 otherwise.
    Decoding does the reverse: every 0 position of the reference text is
    replaced by ``UNK_TOKEN`` and the tokens are joined with single spaces.

    Tokenization uses the class-level spaCy pipeline ``nlp``; setting ``nlp``
    to ``None`` falls back to whitespace splitting (``str.split``).
    Target class values are passed through unchanged (copied).
    """

    # Loaded once when the class body executes (i.e. at module import time)
    # and shared by every instance. Set to None to tokenize with str.split.
    nlp = spacy.load('en_core_web_sm')

    # Placeholder that ``decode`` writes for masked-out tokens. ``tokenize``
    # keeps it as a single token so decoded texts can be encoded again.
    UNK_TOKEN = '[UNK]'

    # Minimum width of the fixed-width unicode dtype used for token arrays.
    # It must be >= len(UNK_TOKEN) so that ``decode`` can write the
    # placeholder into a copy of ``base_str`` without numpy truncating it.
    _MIN_TOKEN_WIDTH = 80

    @classmethod
    def _to_token_array(cls, tokens):
        """
        Convert a list of token strings into a 1-D unicode numpy array.

        The dtype width is the larger of ``_MIN_TOKEN_WIDTH`` and the longest
        token, so long tokens (URLs, hashes, ...) are never silently truncated.
        """
        width = max([cls._MIN_TOKEN_WIDTH] + [len(token) for token in tokens])
        return np.array(tokens, dtype=f'<U{width}')

    def _tokenize_segment(self, text):
        """Split a text containing no ``UNK_TOKEN`` into a list of token strings."""
        if self.nlp is None:
            return text.split()
        return [token.text for token in self.nlp(text)]

    def tokenize(self, text):
        """
        Split ``text`` into a 1-D numpy array of token strings.

        Every occurrence of ``UNK_TOKEN`` becomes exactly one token; the text
        between occurrences is tokenized with ``nlp`` (or ``str.split`` when
        ``nlp`` is None). Without this, the default spaCy English tokenizer
        splits the brackets of ``'[UNK]'`` off as punctuation, and a text
        produced by ``decode`` would no longer line up with ``base_str``.
        A text without ``UNK_TOKEN`` is tokenized exactly as a whole.

        :param [str] text: Text to tokenize
        :return [Numpy array]: Unicode array with one entry per token
        """
        segments = text.split(self.UNK_TOKEN)
        last = len(segments) - 1
        tokens = []
        for i, segment in enumerate(segments):
            if i > 0:
                tokens.append(self.UNK_TOKEN)
                # Whitespace next to the marker is a separator, not a token.
                segment = segment.lstrip()
            if i < last:
                segment = segment.rstrip()
            tokens.extend(self._tokenize_segment(segment))
        return self._to_token_array(tokens)

    def __init__(self, base_str: Union[str, List[str]], descriptor: dict):
        """
        :param base_str: Reference text, either a raw string (split with
            ``tokenize``) or an already-tokenized list/tuple of strings
            (used verbatim)
        :param descriptor: Feature descriptor, forwarded to ``EncDec``
        :raises ValueError: if ``base_str`` is neither a string nor a
            list/tuple of strings
        """
        super().__init__(descriptor)
        self.type = 'text'
        self.encoded_descriptor = None
        if isinstance(base_str, str):
            self.base_str = self.tokenize(base_str)
        elif isinstance(base_str, (list, tuple)) and all(
            isinstance(token, str) for token in base_str
        ):
            self.base_str = self._to_token_array(list(base_str))
        else:
            raise ValueError('base_str must be a string or a list')

    def encode(self, X: Union[np.ndarray, List[str]]):
        """
        Encode texts as binary masks over the tokens of ``base_str``.

        Each text is tokenized and compared position by position with
        ``base_str``: 1 where the tokens are equal, 0 where they differ
        (e.g. where ``decode`` wrote ``UNK_TOKEN``).

        :param [Numpy array or list of str] X: Texts to encode
        :return [Numpy array]: Integer array of shape (len(X), len(base_str))
        :raises ValueError: if a text does not have exactly len(base_str) tokens
        """
        n_tokens = len(self.base_str)
        rows = [self.tokenize(x) for x in X]
        if not rows:
            # Empty batch: keep the column count so callers can still stack.
            return np.zeros((0, n_tokens), dtype=int)
        for i, row in enumerate(rows):
            if len(row) != n_tokens:
                raise ValueError(
                    f'text {i} has {len(row)} tokens but base_str has {n_tokens}; '
                    'only texts token-aligned with base_str can be encoded'
                )
        # Broadcast the (n_tokens,) reference against every row at once.
        return (np.stack(rows) == self.base_str).astype(int)

    def decode(self, Z: np.ndarray):
        """
        Rebuild texts from binary masks over the tokens of ``base_str``.

        :param [Numpy array] Z: Mask array, expected to have shape
            (n, len(base_str)); entries equal to 0 mark tokens to hide
        :return [list of str]: One text per row of ``Z``: the tokens of
            ``base_str`` joined by single spaces, with every masked token
            replaced by ``UNK_TOKEN``
        """
        Z = np.asarray(Z)
        decoded = np.repeat(self.base_str[np.newaxis, :], Z.shape[0], axis=0)
        decoded[Z == 0] = self.UNK_TOKEN
        return [' '.join(row) for row in decoded]

    def decode_target_class(self, Z: np.ndarray):
        """
        Decode target class values.

        Target classes are not transformed by this encoder, so this returns an
        unchanged copy of ``Z``.

        :param [Numpy array] Z: Target class values to decode
        :return [Numpy array]: Copy of ``Z``
        """
        return Z.copy()

    def encode_target_class(self, X: np.ndarray):
        """
        Encode target class values.

        Target classes are not transformed by this encoder, so this returns an
        unchanged copy of ``X``.

        :param [Numpy array] X: Target class values to encode
        :return [Numpy array]: Copy of ``X``
        """
        return X.copy()