import pickle as pkl
import re
from typing import List
import os

def load_pkl(file_loc):
    '''
    file_loc : path of the pickle file to read
    desc: opens the file in binary mode and returns the first pickled object in it
    note: unpickling can execute arbitrary code, only use on trusted files
    '''
    with open(file_loc, mode="rb") as pickle_file:
        loaded_obj = pkl.load(pickle_file)
    return loaded_obj

def save_pkl(obj, file_loc):
    '''
    obj : object to serialise
    file_loc : path of the file to write (opened in "wb" mode, so existing content is replaced)
    desc: pickles obj into the file at file_loc
    '''
    with open(file_loc, mode="wb") as pickle_file:
        pkl.dump(obj, pickle_file)

def load_txt(file_loc):
    '''
    file_loc : path of the text file to read
    desc: returns the whole file content as one str
          (no encoding is passed to open, so the platform default is used)
    '''
    with open(file_loc, mode="r") as text_file:
        contents = text_file.read()
    return contents
    
def remove_non_letters(sentence: str):
    '''
    sentence : str to clean
    desc: deletes every run of regex non-word characters from the str.
          Despite the name, digits and underscores count as word characters
          and are kept; whitespace and punctuation are removed.
    '''
    non_word_run = re.compile(r'\W+')
    return non_word_run.sub('', sentence)

def indices_where_in_range(indices, value_range_start, value_range_stop):
    '''
    indices : iterable of comparable values
    desc: returns the positions (not the values) of the entries that satisfy
          value_range_start <= value < value_range_stop
    '''
    positions = []
    for position, value in enumerate(indices):
        if value_range_start <= value < value_range_stop:
            positions.append(position)
    return positions

def verification_format(transcript_str):
    '''
    transcript_str : str to normalise
    desc: passes the str through remove_non_letters and upper-cases the result,
          giving the form used when comparing story text with transcript text
    '''
    return remove_non_letters(transcript_str).upper()

def check_if_any_shared_indicies(list_1:List[int], list_2:List[int]):
    '''
    list_1 : list of token indicies
    list_2 : list of token indicies
    desc: returns True as soon as one element of list_1 is found in list_2,
          False when there is none (including when list_1 is empty)
    '''
    return any(token_index in list_2 for token_index in list_1)
    
def story_words_to_transcript_words_map(story_word_times, transcript):
    '''
    story_word_times : list of dicts, each with a "word" entry holding one story word (str)
    transcript : transcript text as a single str
    desc: walks through the story words in order and, for each one, records which
          transcript word indices are needed so that the normalised transcript text
          (see verification_format) read so far starts with the normalised story
          text read so far.
          - if the transcript word reached so far already covers the story word, the
            entry is [transcript_index] (so several story words may share an index)
          - otherwise further transcript words are consumed and the entry lists the
            newly consumed indices only, starting at transcript_index + 1
    returns: (story_words_map, story_words, transcript_words) where story_words_map[i]
             is the list of transcript word indices for story_words[i]
    raises: ValueError if the transcript runs out before a story word can be matched.
            Previously the word was skipped silently, leaving story_words_map shorter
            than story_words and no longer aligned with it.
    '''
    # mild cleaning of transcript: every whitespace run (newlines included) becomes one space
    transcript = re.sub(r"\s+", " ", transcript)

    transcript_words = transcript.split(" ")
    story_words = [x["word"] for x in story_word_times]
    story_words_map = []

    transcript_index = 0
    # verification_format drops/upper-cases characters one by one, so the normalised
    # form of a concatenation equals the concatenation of the normalised pieces.
    # That lets both prefixes grow incrementally instead of being rebuilt per word.
    transcript_prefix = verification_format(transcript_words[0])
    story_prefix = ""

    for i, story_word in enumerate(story_words):
        story_prefix += verification_format(story_word)
        if transcript_prefix.startswith(story_prefix):
            # the transcript word we are standing on already covers this story word
            story_words_map.append([transcript_index])
            continue

        words_map = []
        candidate_prefix = transcript_prefix
        for j in range(transcript_index + 1, len(transcript_words)):
            words_map.append(j)
            candidate_prefix += verification_format(transcript_words[j])
            if candidate_prefix.startswith(story_prefix):
                story_words_map.append(words_map)
                transcript_index = j
                transcript_prefix = candidate_prefix
                break
        else:
            # no amount of remaining transcript text matches: the two texts diverge
            raise ValueError(
                f"could not align story word {i} ({story_word!r}) with the transcript; "
                "story and transcript text diverge at or before this word"
            )
    return story_words_map, story_words, transcript_words

def _join_with_shared_piece(word_indices, transcript_words, shared_word_index, shared_piece):
    # Concatenates the transcript words at word_indices, each followed by a space,
    # using shared_piece in place of the word stored at shared_word_index.
    return "".join(
        (shared_piece if word_index == shared_word_index else transcript_words[word_index]) + " "
        for word_index in word_indices
    )


def remove_duplicate_words(story_words_map, story_words, transcript_words):
    '''
    story_words_map : per story word, the list of transcript word indices mapped to it
    story_words : list of story words (str)
    transcript_words : list of transcript words (str), MODIFIED IN PLACE: when a
                       transcript word is split between story words, its entry is
                       replaced by the part that has not been handed out yet
    desc: builds, for every story word, the transcript text that belongs to it.
          When two consecutive story words share a transcript word, the shortest
          leading part of that word which makes the normalised text contain the
          normalised story word is given to the first story word and the rest is
          left for the following one(s).
    returns: list of str, one per story word. Entries built through the
             shared-word path end with a trailing space, the others do not.
    '''
    if not story_words and not story_words_map:
        return []

    out_words = []
    shared_word_index = None
    shared_word_string = ""
    for i in range(len(story_words) - 1):
        word_indices = story_words_map[i]
        if not check_if_any_shared_indicies(word_indices, story_words_map[i + 1]):
            out_words.append(" ".join(transcript_words[x] for x in word_indices))
            continue

        last_index = word_indices[-1]
        if last_index != shared_word_index:
            # first time this transcript word is shared: start from the full word
            shared_word_string = transcript_words[last_index]
            shared_word_index = last_index

        # normalise the story word the same way as the text it is compared with,
        # otherwise lower-case or punctuated story words can never match
        target = verification_format(story_words[i])

        # fallback when no leading part matches (or nothing is left of the shared
        # word): use everything that remains, without trimming it
        word_string = _join_with_shared_piece(
            word_indices, transcript_words, shared_word_index, shared_word_string
        )
        # pop off the part that covers the first word and leave the rest
        for char_index in range(len(shared_word_string)):
            candidate = _join_with_shared_piece(
                word_indices, transcript_words, shared_word_index,
                shared_word_string[:char_index + 1],
            )
            if target in verification_format(candidate):
                word_string = candidate
                shared_word_string = shared_word_string[char_index + 1:]
                transcript_words[shared_word_index] = shared_word_string
                break
        out_words.append(word_string)
    out_words.append(" ".join(transcript_words[x] for x in story_words_map[-1]))
    return out_words
    
def words_to_transcript(story, dataset_loc = "./data"):
    '''
    story : story name, used as the file stem of the inputs and of the output
    dataset_loc : dataset root folder
    desc: loads {dataset_loc}/clean_words_and_times/{story}.pkl and
          {dataset_loc}/transcripts/{story}.txt, aligns the story words with the
          transcript, stores the matching transcript text under the "transcript"
          key of every word entry and pickles the result to
          {dataset_loc}/words_and_times_transcripts/{story}.pkl
    returns: the updated list of word entries
    '''
    save_folder = f"{dataset_loc}/words_and_times_transcripts"
    os.makedirs(save_folder, exist_ok=True)

    story_word_times = load_pkl(f"{dataset_loc}/clean_words_and_times/{story}.pkl")
    raw_transcript = load_txt(f"{dataset_loc}/transcripts/{story}.txt")

    alignment = story_words_to_transcript_words_map(
        story_word_times = story_word_times,
        transcript = raw_transcript,
    )
    story_words_map, story_words, transcript_words = alignment
    per_word_text = remove_duplicate_words(story_words_map, story_words, transcript_words)

    # index explicitly so that a too-short result raises instead of being truncated
    for i, word_entry in enumerate(story_word_times):
        word_entry["transcript"] = per_word_text[i]

    save_pkl(story_word_times, f"{save_folder}/{story}.pkl")
    return story_word_times