import numpy as np 
import spacy 
import os  
from tqdm import tqdm   
from string import punctuation 
import re 
from spacy.tokenizer import Tokenizer
import string
import yaml
from text import cleaned_text_to_sequence, _clean_text
import torch 
from utils.tools import intersperse, prepad

# Load the small English spaCy pipeline once at import time; make_sample_batch()
# runs text through it and get_nps() reads the resulting noun chunks.
nlp = spacy.load('en_core_web_sm')
# Swap in a bare Tokenizer (no prefix/suffix/infix rules are passed) whose
# token_match accepts any run of non-whitespace characters.
# NOTE(review): presumably this keeps spaCy token indices aligned with the
# space-separated words that get_idx_list() counts in the phone string - confirm.
nlp.tokenizer = Tokenizer(nlp.vocab, token_match=re.compile(r'\S+').match)
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
# Only commented-out code in the visible part of this file refers to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_nps(doc):
    """Locate the noun chunks of a parsed document.

    Args:
        doc: Parsed document that exposes ``noun_chunks`` (an iterable of
            token spans whose tokens carry an integer ``i`` index) and
            supports ``len()``, e.g. a spaCy ``Doc``.

    Returns:
        A tuple ``(nps_ids, end)``. ``nps_ids`` is a list with one
        ``[first_token_index, last_token_index]`` pair (both inclusive) per
        noun chunk, and ``end`` is the index of the last token in ``doc``
        (``-1`` when the document has no tokens).
    """
    nps_ids = []
    for chunk in doc.noun_chunks:
        # Named ``token_ids`` rather than ``id`` so the builtin is not shadowed.
        token_ids = [int(token.i) for token in chunk]
        nps_ids.append([token_ids[0], token_ids[-1]])
    # len() gives the token count directly; no need to materialise a list.
    end = len(doc) - 1
    return nps_ids, end

def context_words(nps_ids, end):
    """Compute the left and right context window of every noun phrase.

    Each window spans as many words as the phrase itself and sits directly
    next to it.

    Args:
        nps_ids: Sequence of ``[first, last]`` word indices (inclusive).
        end: Index of the last word in the sentence.

    Returns:
        One ``[left_start, left_end, right_start, right_end]`` list per
        phrase. A side whose window would fall before index 0 or after
        ``end`` is reported as ``-1, -1``.
    """
    windows = []
    for span in nps_ids:
        first = span[0]
        width = span[-1] - first + 1

        # Left window: the `width` words immediately before the phrase.
        if first - width >= 0:
            left = [first - width, first - 1]
        else:
            left = [-1, -1]

        # Right window: the `width` words immediately after the phrase.
        if span[1] + width <= end:
            right = [span[1] + 1, span[1] + width]
        else:
            right = [-1, -1]

        windows.append(left + right)

    return windows
    
def get_idx_list(text, phones):
    """Return the character span of every space-separated word in ``phones``.

    Args:
        text: Unused; kept so existing callers keep working.
        phones: String whose words are separated by single spaces.

    Returns:
        A tuple ``(word_start, word_end)`` of equal-length lists.
        ``word_start[i]`` and ``word_end[i]`` are the inclusive character
        offsets of the i-th item of ``phones.split(" ")``. An empty item
        (from an empty string or consecutive spaces) yields
        ``word_end[i] == word_start[i] - 1``.
    """
    word_start, word_end = [], []
    offset = 0
    for word in phones.split(" "):
        word_start.append(offset)
        word_end.append(offset + len(word) - 1)
        # Step over the word and the single space that follows it.
        offset += len(word) + 1
    return word_start, word_end

def wid2pid(idxs, word_start, word_end):
    """Translate word-index spans into spans over ``word_start``/``word_end``.

    Args:
        idxs: Iterable of ``[first_word, last_word]`` pairs. A pair that
            contains ``-1`` marks a missing span.
        word_start: Start offset of each word, indexed by word position.
        word_end: End offset of each word, indexed by word position.

    Returns:
        A NumPy array with one ``[start, end]`` row per input pair; missing
        spans are passed through as ``[-1, -1]``.
    """
    spans = [
        [-1, -1] if -1 in pair else [word_start[pair[0]], word_end[pair[1]]]
        for pair in idxs
    ]
    return np.array(spans)

def preprocess_english(text, phones, np_ids, end):
    """Convert noun-phrase word spans and their context windows to phone spans.

    Args:
        text: Forwarded to ``get_idx_list``.
        phones: Space-separated phone string the word offsets are taken from.
        np_ids: List of ``[first_word, last_word]`` noun-phrase spans.
        end: Index of the last word in the sentence.

    Returns:
        A tuple ``(np_ids, cw_ids)`` of NumPy arrays. ``np_ids`` has one
        ``[start, end]`` row per noun phrase; ``cw_ids`` has one row per noun
        phrase holding the left-context span followed by the right-context
        span, with ``-1`` entries where a context window does not exist.
    """
    windows = context_words(np_ids, end)
    left_words = [window[:2] for window in windows]
    right_words = [window[2:4] for window in windows]

    word_start, word_end = get_idx_list(text, phones)

    phrase_spans = wid2pid(np_ids, word_start, word_end)
    left_spans = wid2pid(left_words, word_start, word_end)
    right_spans = wid2pid(right_words, word_start, word_end)

    # Left and right windows are joined column-wise into a single row.
    context_spans = np.concatenate([left_spans, right_spans], axis=1)
    return phrase_spans, context_spans

def make_sample_batch(text, preprocess_config):
    """Assemble a one-sentence batch together with its noun-phrase indices.

    Args:
        text: Raw input sentence.
        preprocess_config: Accepted but not read by this function.

    Returns:
        ``(ids, raw_texts, speakers, texts, src_lens, max_src_len, npids,
        cwids)`` when the parser finds at least one noun chunk, otherwise
        ``None``.
    """
    strip_punctuation = str.maketrans('', '', string.punctuation)
    text = prepad(text, item="")

    # Only the parser sees the punctuation-free text; the phone sequence is
    # derived from the text with punctuation intact.
    np_ids, end = get_nps(nlp(text.translate(strip_punctuation)))

    phones = _clean_text(prepad(text), ["english_cleaners2"])
    sequence = cleaned_text_to_sequence(phones)
    texts = np.array([intersperse(sequence)])

    if not np_ids:
        return None

    np_ids, cw_ids = preprocess_english(text, phones, np_ids, end)
    text_lens = np.array([texts.shape[1]])
    return (
        np.array(["sample"]),
        [text],
        np.array([1]),
        texts,
        text_lens,
        max(text_lens),
        # A leading axis of size 1 is added to both index arrays.
        np.expand_dims(np_ids, axis=0),
        np.expand_dims(cw_ids, axis=0),
    )

def sample_batch_preprocess(text, text_lens, np_ids, cw_ids, half=False):
    """Repeat one sentence once per noun phrase.

    Args:
        text: Tensor that is expanded along its first dimension.
        text_lens: Tensor that is expanded along its first dimension.
        np_ids: Noun-phrase index tensor; its first two dimensions are swapped.
        cw_ids: Context-window index tensor; treated like ``np_ids``.
        half: When true, keep only every second phrase (indices 0, 2, ...).

    Returns:
        ``(text, text_lens, np_ids, cw_ids, np_nums)`` where ``np_nums`` is a
        list of ones with one entry per kept phrase.

    NOTE(review): assumes torch tensors whose leading (batch) dimension is 1,
    as ``expand`` requires - confirm against the caller.
    """
    # After the swap the phrase axis comes first.
    np_ids, cw_ids = np_ids.transpose(0, 1), cw_ids.transpose(0, 1)
    if half:
        np_ids, cw_ids = np_ids[::2], cw_ids[::2]

    phrase_count = np_ids.size(0)
    expanded_text = text.expand(phrase_count, -1)
    expanded_lens = text_lens.expand(phrase_count)

    return expanded_text, expanded_lens, np_ids, cw_ids, [1] * phrase_count

if __name__ == '__main__':
    # Manual smoke test: build a sample batch for one hard-coded sentence.
    from utils.dpp_tools import to_device
    from utils.dpp_tools import get_random_phrase

    # Use a context manager so the config file handle is closed right after
    # parsing instead of being left to the garbage collector.
    # NOTE(review): FullLoader is kept as-is; switch to yaml.safe_load if this
    # file could ever come from an untrusted source.
    with open("config/LJSpeech/preprocess.yaml", 'r') as config_file:
        preprocess_config = yaml.load(config_file, Loader=yaml.FullLoader)
    #text = "Probability smoothing is a language modeling technique that assigns some non-zero probability to events that were unseen in the training data."
    text = "Consequently, although there never can be more than fifteen, there may be only fourteen, or thirteen, or twelve."
    batchs = make_sample_batch(text, preprocess_config)
    #batch = to_device(batchs, device)
    #print(batch[-2], batch[-1])
    #sample_batch_preprocess(batch[3], batch[4], batch[-2], batch[-1])