import numpy as np 
import spacy 
import re 
from spacy.tokenizer import Tokenizer
import os  
from tqdm import tqdm   
from string import punctuation 
import string
import yaml
from utils.tools import prepad
from text import _clean_text

# Shared spaCy pipeline used by the whole module; loaded once at import time.
nlp = spacy.load('en_core_web_sm')
# Swap in a bare tokenizer (no prefix/suffix/infix rules) whose token_match
# accepts any run of non-whitespace characters. The intent appears to be that
# spaCy token indices line up with the space-separated words used elsewhere
# in this file -- NOTE(review): confirm for inputs with repeated spaces.
nlp.tokenizer = Tokenizer(nlp.vocab, token_match=re.compile(r'\S+').match)

def get_nps(doc):
    """Locate the noun chunks of a parsed document.

    Args:
        doc: A parsed document exposing ``noun_chunks`` and ``len()``, whose
            tokens carry an integer ``i`` index (e.g. a spaCy ``Doc``).

    Returns:
        A ``(nps_ids, end)`` tuple. ``nps_ids`` is a list of
        ``[first_token_index, last_token_index]`` pairs (both inclusive), one
        per noun chunk in document order. ``end`` is the index of the last
        token of ``doc`` (``-1`` when the document is empty).
    """
    nps_ids = []
    for chunk in doc.noun_chunks:
        # Named ``token_ids`` rather than ``id`` to avoid shadowing the builtin.
        token_ids = [int(token.i) for token in chunk]
        nps_ids.append([token_ids[0], token_ids[-1]])
    # len(doc) gives the token count directly; no need to build a list first.
    return nps_ids, len(doc) - 1

def context_words(nps_ids, end):
    """Compute the left and right context window of each noun-phrase span.

    Each context window has the same number of words as the noun phrase it
    surrounds and sits directly next to it.

    Args:
        nps_ids: Iterable of ``[first, last]`` inclusive word-index pairs.
        end: Index of the last word in the sentence.

    Returns:
        A list with one ``[left_first, left_last, right_first, right_last]``
        entry per span. A side whose window would fall outside ``[0, end]``
        is reported as ``-1, -1``.
    """
    windows = []
    for span in nps_ids:
        first = span[0]
        width = span[-1] - first + 1
        # Left window: the `width` words immediately before the span.
        left = [first - width, first - 1] if first - width >= 0 else [-1, -1]
        # Right window: the `width` words immediately after the span.
        if span[1] + width <= end:
            right = [span[1] + 1, span[1] + width]
        else:
            right = [-1, -1]
        windows.append(left + right)
    return windows
    
def get_idx_list(text, phones):
    """Return the character span of every space-separated word in ``phones``.

    Args:
        text: Unused; kept so existing callers do not need to change.
        phones: String whose words are separated by single spaces.

    Returns:
        A ``(word_start, word_end)`` tuple of lists. For the i-th word,
        ``word_start[i]`` is the offset of its first character in ``phones``
        and ``word_end[i]`` the offset of its last character (inclusive).
        An empty word (from an empty string or consecutive spaces) therefore
        has ``word_end[i] == word_start[i] - 1``.
    """
    word_start, word_end = [], []
    offset = 0
    for word in phones.split(" "):
        word_start.append(offset)
        word_end.append(offset + len(word) - 1)
        # Skip past the word and the single space that follows it.
        offset += len(word) + 1
    return word_start, word_end

def wid2pid(idxs, word_start, word_end):
    """Convert word-index spans into character-offset spans.

    Args:
        idxs: Iterable of ``[first_word, last_word]`` pairs; a pair that
            contains ``-1`` marks a missing span.
        word_start: Start offset of each word, indexed by word position.
        word_end: End offset of each word, indexed by word position.

    Returns:
        A NumPy array with one ``[start_offset, end_offset]`` row per input
        pair; missing spans are passed through as ``[-1, -1]``.
    """
    spans = [
        [-1, -1] if -1 in pair else [word_start[pair[0]], word_end[pair[1]]]
        for pair in idxs
    ]
    return np.array(spans)

def preprocess_english(text, phones, np_ids, end):
    """Map noun-phrase spans and their context windows onto ``phones`` offsets.

    Args:
        text: Sentence text; only forwarded to ``get_idx_list``.
        phones: Space-separated string whose words are indexed by ``np_ids``.
        np_ids: List of ``[first_word, last_word]`` noun-phrase spans.
        end: Index of the last word in the sentence.

    Returns:
        A ``(np_ids, cw_ids)`` tuple of NumPy arrays. ``np_ids`` has one
        ``[start, end]`` character span per noun phrase. ``cw_ids`` has one
        ``[left_start, left_end, right_start, right_end]`` row per noun
        phrase, with ``-1`` entries where a context window does not exist.
    """
    windows = context_words(np_ids, end)
    left_words = [window[:2] for window in windows]
    right_words = [window[2:4] for window in windows]

    word_start, word_end = get_idx_list(text, phones)

    def to_offsets(spans):
        # Translate word-index spans into character spans within `phones`.
        return wid2pid(spans, word_start, word_end)

    np_ids = to_offsets(np_ids)
    left_offsets = to_offsets(left_words)
    right_offsets = to_offsets(right_words)
    cw_ids = np.concatenate([left_offsets, right_offsets], axis=1)
    return np_ids, cw_ids

def pr_text(text, translator):
    """Apply ``translator`` to ``text`` and drop stray ``'nt'`` tokens.

    Args:
        text: Input sentence.
        translator: Translation table accepted by ``str.translate`` (for
            example one built with ``str.maketrans`` to delete punctuation).

    Returns:
        The translated text, re-joined with single spaces, with every
        space-separated token equal to ``'nt'`` removed.
    """
    words = text.translate(translator).split(" ")
    # list.remove() returns None, so the previous `text_r = text_r.remove(...)`
    # made the following join raise TypeError. Filter into a new list instead;
    # this drops every 'nt' token rather than only the first one.
    words = [word for word in words if word != 'nt']
    return ' '.join(words)

def get_len(text):
    """Return the number of pieces obtained by splitting ``text`` on single spaces.

    An empty string counts as one (empty) piece, and consecutive spaces
    produce empty pieces that are counted as well.
    """
    return len(text.split(" "))

class DPP_preprocessor:
    """Extract noun-phrase and context-word phone spans for a file list.

    ``config`` must provide ``config['path']['preprocessed_path']``, the
    directory under which the ``ids`` output folder is created.
    """

    def __init__(self, config):
        """Store the configuration and resolve the output directory."""
        self.config = config
        self.in_dir = "./raw_data/LJSpeech"
        self.out_dir = config['path']['preprocessed_path']

    def build_from_path(self):
        """Process the training file list and write span arrays to disk.

        Reads ``filelist/ljs_audio_train_filelist.txt``, whose lines are
        ``|``-separated with the file name, text and phones in the first
        three fields. For every line whose text yields at least one noun
        chunk, two ``.npy`` files (noun-phrase spans and context-word spans)
        are saved under ``<out_dir>/ids`` and the line, with its text field
        replaced by the ``prepad`` output, is copied to
        ``filelist/ljs_audio_train_filelist1.txt``. Lines without noun chunks
        are dropped.
        """
        ids_dir = os.path.join(self.out_dir, "ids")
        os.makedirs(ids_dir, exist_ok=True)

        print("Processing Data ...")
        with open("filelist/ljs_audio_train_filelist.txt", 'r') as src:
            with open("filelist/ljs_audio_train_filelist1.txt", 'w') as dst:
                entries = src.readlines()
                strip_punct = str.maketrans('', '', string.punctuation)
                for entry in tqdm(entries):
                    fields = entry.strip("\n").split("|")
                    name = fields[0].split(".")[0]
                    text = fields[1]
                    phone = fields[2]
                    text = prepad(text, item="")

                    # Noun chunks are found on the punctuation-free text.
                    doc = nlp(text.translate(strip_punct))
                    np_ids, end = get_nps(doc)
                    if not np_ids:
                        # No noun chunk: the line is left out of the new list.
                        continue

                    np_ids, cw_ids = preprocess_english(text, phone, np_ids, end)
                    np_path = os.path.join(ids_dir, "{}-npids-{}".format("LJSpeech", name))
                    cw_path = os.path.join(ids_dir, "{}-cwids-{}".format("LJSpeech", name))
                    np.save(np_path, np_ids)
                    np.save(cw_path, cw_ids)

                    fields[1] = text
                    dst.write("|".join(fields) + "\n")

            print("Done!")


if __name__ == "__main__":
    # The previous body was a stale TextGrid experiment: it used names that
    # are neither imported nor defined in this file (`tgt`,
    # `get_dict_from_word_tier`, `get_dict_from_phone_tier`,
    # `get_text_from_word_tier`) and called `preprocess_english` with five
    # arguments, so it failed with NameError on its first statement.
    # Run the preprocessor defined above instead.
    # NOTE(review): taking the config path as a positional argument is an
    # assumption about the intended CLI -- confirm against project usage.
    import argparse

    parser = argparse.ArgumentParser(
        description="Extract noun-phrase / context-word spans for the LJSpeech file list."
    )
    parser.add_argument(
        "config",
        type=str,
        help="path to a YAML config providing path.preprocessed_path",
    )
    args = parser.parse_args()

    with open(args.config, "r") as config_file:
        # safe_load builds plain Python objects only, never arbitrary ones.
        config = yaml.safe_load(config_file)

    DPP_preprocessor(config).build_from_path()
    