import re
import string
import numpy as np
import spacy

from concept_processing.io import get_datapoint_iterator
from concept_processing.pam import count_datapoints_in_each_feature

def concept_dict_to_list(concept_dict):
    """
    Invert a concept -> index mapping into a positional list.

    parameters
    ----------
    concept_dict - dictionary mapping from concept (e.g. strs) to index (int)
        where indices are contiguous and starting from zero.

    returns
    -------
    concepts - a list of concepts where concepts[i] is the key k such that
        concept_dict[k] = i. A KeyError is raised if some index in
        range(len(concept_dict)) is not present as a value.
    """
    index_to_concept = {}
    for concept, index in concept_dict.items():
        index_to_concept[index] = concept
    return list(map(index_to_concept.__getitem__, range(len(concept_dict))))


def display_most_frequent_concepts(concepts, pam, K=None):
    """
    Print the concepts ranked by how many datapoints they occur in.

    parameters
    ----------
    concepts - sequence where concepts[k] is the label of feature k.
    pam - the datapoint/concept matrix passed to
        count_datapoints_in_each_feature.
    K - number of top concepts to print; all concepts when None.
    """
    counts = count_datapoints_in_each_feature(pam)
    limit = len(counts) if K is None else K
    # argsort is ascending, so reverse it to rank the most frequent first.
    ranked_ids = np.argsort(counts)[::-1][:limit]
    print(f"Top {limit} concepts:")
    print("\tRank\tID\tCount\tConcept")
    for rank, concept_id in enumerate(ranked_ids):
        print(f"\t{rank}\t{concept_id}\t{counts[concept_id]} : {concepts[concept_id]}")
        
def build_embedding_matrix(concepts, transformer_model, include_articles=False):
    """
    Encode every concept with `transformer_model` and stack the results.

    parameters
    ----------
    concepts - sequence of concept strings.
    transformer_model - object with an `encode(str)` method returning a
        1-D vector (anything np.asarray accepts); all vectors must share one
        length.
    include_articles - when False (default), each concept is passed through
        remove_articles before encoding.

    returns
    -------
    embeds - float array of shape (len(concepts), D), where D is the length of
        the encoder's output vectors. For an empty `concepts` the historical
        shape (0, 768) is returned.
    """
    rows = []
    for concept in concepts:
        if not include_articles:
            concept = remove_articles(concept)
        # Infer the embedding width from the encoder instead of hard-coding
        # 768, so models with other output sizes work too. ravel() mirrors the
        # old slice assignment, which also accepted a leading singleton axis.
        rows.append(np.asarray(transformer_model.encode(concept), dtype=float).ravel())
    if not rows:
        return np.empty((0, 768))
    return np.vstack(rows)
    
def remove_articles(text):
    """
    We remove articles (and other troublesome words) that interfere with the
    clustering.

    Every whitespace-separated word equal to 'a', 'an' or 'the'
    (case-insensitive) is dropped. Then a leading 'then' (case-insensitive) is
    dropped from what remains. Words are rejoined with single spaces.

    parameters
    ----------
    text - the concept string to clean.

    returns
    -------
    The cleaned string; an empty string when no words remain.
    """
    articles = {'a', 'an', 'the'}
    # Compare lower-cased, consistent with the case-insensitive 'then' check.
    rest = [word for word in text.split() if word.lower() not in articles]
    # Guard empty / article-only input, which previously raised IndexError.
    if rest and rest[0].lower() == 'then':
        rest.pop(0)
    return ' '.join(rest)
    
    
def complete_pam_with_substrings(
        path, grouped_pam, grouped_concept_ids, grouped_concepts, nlp=None, rows_to_remove=None):
    """
    Uses the substring lookup method, on string concepts discovered via some
    process, to check for all occurrences of those substrings in the dataset.
    Then marks every such occurrence in a copy of the pam.

    parameters
    ----------
    path - dataset location, passed to get_datapoint_iterator.
    grouped_pam - the pam NxK where N is the number of datapoints and K is the
        number of pruned grouped concepts.
    grouped_concept_ids - list of lists, the ith list (of K lists) is the cluster with the final id i.
        The list contains the original indices of the child concepts.
    grouped_concepts (elsewhere called concept_group_labels) - list of lists,
        the ith list is the cluster with the final id i.
        The list contains the original lowercase/no-punct strings of the child concepts.
    nlp - optional spacy pipeline used to split texts into sentences; when
        None, "en_core_web_lg" is loaded.
    rows_to_remove - passed through to get_datapoint_iterator.

    returns
    -------
    completed_pam - a copy of grouped_pam in which entry [i, k] is also set to 1
        whenever any string of group k occurs as a substring of a (lowercased,
        punctuation-stripped) sentence of the ith datapoint. grouped_pam
        itself is not modified.
        NOTE(review): assumes the iterator yields datapoints in the same order
        as the rows of grouped_pam — confirm against callers.
    """
    completed_pam = np.copy(grouped_pam)
    if nlp is None:
        nlp = spacy.load("en_core_web_lg")
    # Loop-invariant: translation table that deletes ASCII punctuation.
    strip_punctuation = str.maketrans('', '', string.punctuation)
    for i, datapoint in enumerate(get_datapoint_iterator(path, rows_to_remove=rows_to_remove)):
        # Each datapoint is an (id, label, text) triple; only the text is used.
        _id, _label, text = datapoint
        # Drop replacement characters and collapse runs of spaces.
        text = text.replace('�', '')
        text = re.sub(r' +', ' ', text)
        doc = nlp(text)
        bagofgroupids = []
        for sent in doc.sents:
            sentstr = sent.text.translate(strip_punctuation).lower()
            bagofgroupids.extend(
                get_concept_indices_by_substring(sentstr, grouped_concept_ids, grouped_concepts))
        # astype(int) keeps the empty case a valid (integer) index array.
        setofnewfeatures = np.unique(bagofgroupids).astype(int)
        completed_pam[i, setofnewfeatures] = 1
    return completed_pam
            
def get_concept_indices_by_substring(sentstr, grouped_concept_ids, grouped_concepts):
    """
    Helper method for complete_pam_with_substrings.

    Returns a list with one entry per (group, member string) match: the group
    index is appended once for every member string of that group that occurs
    as a plain substring of `sentstr`. Duplicates are therefore possible.
    Member strings are paired with member ids via zip, so only as many strings
    as there are ids are checked.
    """
    grouped = enumerate(zip(grouped_concept_ids, grouped_concepts))
    return [
        group_index
        for group_index, (member_ids, member_strings) in grouped
        for _member_id, member_string in zip(member_ids, member_strings)
        if member_string in sentstr
    ]
    
    

