from transformers import T5ForConditionalGeneration, T5TokenizerFast as T5Tokenizer
from fuzzywuzzy import fuzz
from fuzzywuzzy import process
from tqdm import tqdm
from sklearn.utils import shuffle
import string



#internal imports
from src.utils.data_generation.pretrained.model.trainer import *
from src.core.configuration.datagen_conf import *
from src.utils.data_generation import csr
from src.core.interface.data_generation import format_artificial_data

def get_attr_names(schema_name):
    """Return the CSR-formatted data for *schema_name*.

    Thin wrapper around ``csr.format_csr`` with the fixed arguments
    ``"small"`` and ``-1``. NOTE(review): the meaning of those two arguments
    (size preset / limit?) is defined in ``csr`` — confirm there.
    """
    return csr.format_csr(schema_name, "small", -1)


def remove_punctuation(input):
    """Return *input* with every character found in ``string.punctuation`` dropped.

    Args:
        input: Text to clean; it is iterated element by element. (The name
            shadows the ``input`` builtin but is kept for keyword callers.)

    Returns:
        A new string made of the remaining characters, in original order.
    """
    def _is_kept(char):
        return char not in string.punctuation

    return ''.join(filter(_is_kept, input))

    
    
def generate_utterance(utterance, keywords, key_slot_maps):
    """Tag slot values found inside a generated utterance.

    For every ``(slot_name, value)`` pair in *key_slot_maps*, a window of
    ``len(value.split())`` tokens slides left-to-right over the (progressively
    annotated) utterance. The first window whose ``fuzz.ratio`` against
    *value* (with punctuation stripped from the window) exceeds 80 is replaced
    by a single token ``vids-<slot_name>(<WINDOW>)``, where ``<WINDOW>`` is the
    matched tokens joined with ``_`` and upper-cased. At most one window is
    replaced per pair.

    Args:
        utterance: Sentence to annotate; it is tokenised with ``str.split()``.
        keywords: Unused; kept so existing callers keep working.
        key_slot_maps: Iterable of 2-item ``(slot_name, value)`` pairs. Pairs
            whose value is empty or whitespace-only are skipped.

    Returns:
        The annotated utterance, tokens re-joined with single spaces.
    """
    annotated = utterance.split()

    for key, val in key_slot_maps:
        tokens = val.split()
        if not tokens:
            # Nothing to match (empty / whitespace-only value).
            continue
        length = len(tokens)

        for start in range(len(annotated) - length + 1):
            candidate = annotated[start:start + length]
            ratio = fuzz.ratio(val, remove_punctuation(' '.join(candidate)))
            if ratio > 80:
                converted_val = ' '.join(candidate).replace(" ", "_").upper()
                # Replace exactly the window that matched. Looking the start
                # up with ``annotated.index(candidate[0])`` would jump to the
                # first occurrence of that token and could tag the wrong span.
                annotated[start:start + length] = [f'vids-{key}({converted_val})']
                break

    return ' '.join(annotated)

    
    
    
def split_list(lst, ratio):
    """Split a sliceable sequence into a head and a tail.

    Args:
        lst: Any sliceable sequence.
        ratio: Fraction of the elements that go into the head. The cut index
            is ``int(len(lst) * ratio)``, i.e. truncated toward zero.

    Returns:
        ``(head, tail)`` where ``head + tail == lst``.
    """
    cut = int(len(lst) * ratio)
    head = lst[:cut]
    tail = lst[cut:]
    return head, tail

def _write_lines(path, lines):
    """Write each string in *lines* to *path* as UTF-8, one per line.

    Missing parent directories are created, and the file is always closed,
    even if a write raises.
    """
    import os

    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf8") as writer:
        for line in lines:
            writer.write(line + "\n")


def generate_dataset(schema_name, model_type='base', dataset_size = 1000, random_state=None):
    """Generate synthetic utterances for *schema_name* and write them to disk.

    Steps:
      1. Load raw keyword records with ``csr.get_from_raw_csr``, shuffle them
         and keep the first *dataset_size*.
      2. Send roughly 40% of them to the vanilla T5 model. Split the rest
         evenly between the combined fine-tuned model and the schema-specific
         fine-tuned model.
      3. Generate utterances and inline slot annotations with each model.
      4. Write the annotated strings to the ``ner`` and ``squad`` raw files and
         the plain utterances to the ``csr`` raw file under
         ``ARTIFICIAL_DATA_LOC``. Then call ``format_artificial_data`` for the
         ``conll`` and ``squad`` formats.

    Args:
        schema_name: Schema whose raw CSR data and fine-tuned model are used.
        model_type: Tag embedded in the fine-tuned model paths and the output
            file name (``lm-<model_type>-per.txt``).
        dataset_size: Maximum number of records to process.
        random_state: Passed to ``sklearn.utils.shuffle``. The default ``None``
            keeps the previous non-reproducible shuffle; pass an int for
            reproducible runs.
    """
    print("Reading values from file ...\n")
    keyword_list = csr.get_from_raw_csr(schema_name)
    keyword_list = shuffle(keyword_list, random_state=random_state)
    keyword_list = keyword_list[:dataset_size]
    print("Finished Loading from file.\nStart processing ...\n")

    # 40% -> vanilla model; the remaining 60% is halved between the
    # combined and the schema-specific fine-tuned models.
    vanilla_list, finetuned_list = split_list(keyword_list, 0.4)
    combined_model_data, schema_based_data = split_list(finetuned_list, 0.5)

    vanilla_model_name = 'models/language_model/t5-base'
    combined_model_name = f'models/language_model/combined/t5-base-{model_type}-per'
    schematized_model_name = f'models/language_model/{schema_name}/t5-base-{model_type}-per'

    # Load every checkpoint before generating, so a missing model fails fast
    # instead of after minutes of generation work.
    vanilla_module = load_module(vanilla_model_name)
    combined_module = load_module(combined_model_name)
    schematized_module = load_module(schematized_model_name)

    print(f"producing training examples for the schema: {schema_name} and model type: {model_type}\n\n")
    vanilla_uttr, vanilla_annotated = get_utterances(vanilla_module, vanilla_list)
    combined_uttr, combined_annotated = get_utterances(combined_module, combined_model_data)
    schematised_uttr, schematised_annotated = get_utterances(schematized_module, schema_based_data)

    utterances = vanilla_uttr + combined_uttr + schematised_uttr
    annotated_strings = vanilla_annotated + combined_annotated + schematised_annotated

    file_name = f"lm-{model_type}-per.txt"
    conll_path = f"{ARTIFICIAL_DATA_LOC}/ner/{schema_name}/raw/{file_name}"
    squad_path = f"{ARTIFICIAL_DATA_LOC}/squad/{schema_name}/raw/{file_name}"
    csr_uttr_path = f"{ARTIFICIAL_DATA_LOC}/csr/{schema_name}/raw/{file_name}"

    # The ner and squad raw files get identical annotated lines; the csr raw
    # file gets the un-annotated utterances.
    _write_lines(conll_path, annotated_strings)
    _write_lines(squad_path, annotated_strings)
    _write_lines(csr_uttr_path, utterances)

    format_artificial_data(schema_name, fold_num=file_name, format="conll")
    format_artificial_data(schema_name, fold_num=file_name, format="squad")
    
    
def load_module(model_name):
    """Load a T5 checkpoint and wrap it in the project's ``trainer``.

    The tokenizer and model are both loaded from *model_name*. No dataframes
    are supplied to ``trainer`` (all ``None``). NOTE(review): this suggests
    the returned object is only used for ``predict``; confirm in ``trainer``.

    Args:
        model_name: Path or hub id accepted by ``from_pretrained``.

    Returns:
        The constructed ``trainer`` instance.
    """
    checkpoint = f"{model_name}"
    tokenizer = T5Tokenizer.from_pretrained(checkpoint)
    model = T5ForConditionalGeneration.from_pretrained(checkpoint, return_dict=True)

    trainer_settings = dict(
        train_df=None,
        test_df=None,
        val_df=None,
        source_max_token_len=512,
        target_max_token_len=512,
        batch_size=4,
        max_epochs=2,
        outputdir="saved_models",
    )
    return trainer(tokenizer=tokenizer, model=model, **trainer_settings)
    
    
    
def get_utterances(module, keyword_list):
    """Generate and annotate one utterance per keyword record.

    Each record's ``'concepts'`` are joined with spaces and passed to
    ``module.predict(keywords=...)``. The prediction is then annotated with
    ``generate_utterance`` using the record's ``'type_maps'``.

    Args:
        module: Object exposing ``predict(keywords=str)``.
        keyword_list: Iterable of mappings with ``'concepts'`` (iterable of
            strings) and ``'type_maps'`` (``(slot, value)`` pairs).

    Returns:
        ``(utterances, annotated_strings)``: two parallel lists, one entry per
        record.
    """
    raw_outputs = []
    tagged_outputs = []
    for record in tqdm(keyword_list):
        concepts = record['concepts']
        slot_pairs = record['type_maps']
        prompt = ' '.join(concepts)

        generated = module.predict(keywords=prompt)
        tagged = generate_utterance(generated, prompt, slot_pairs)

        raw_outputs.append(generated)
        tagged_outputs.append(tagged)

    return raw_outputs, tagged_outputs
        
    
    
        
        
if __name__ == "__main__":
    # Hard-coded run configuration. Guarded so that importing this module
    # (e.g. to reuse generate_dataset) does not start a 50k-example run.
    # NOTE(review): 'online_delivary' presumably matches the on-disk schema
    # directory name — verify before correcting the spelling.
    schema = 'online_delivary'
    model_type = '0.5'
    generate_dataset(schema, model_type, dataset_size=50000)
