import requests
import json
import re
from SPARQLWrapper import SPARQLWrapper, JSON
import os
import openai
import time
import random
import pandas as pd
import csv
import json
import io

# Prefer a key supplied through the OPENAI_API_KEY environment variable so the
# secret does not have to be written into this file; when the variable is not
# set, fall back to the original placeholder string.
openai.api_key = os.environ.get("OPENAI_API_KEY", "Your API Key")


def read_prompt(file):
    """Return the entire text content of the prompt file at ``file``.

    The file is opened in text mode ('r') with the default encoding and its
    content is returned as a single string.
    """
    with open(file, 'r') as handle:
        return handle.read()

def prompt_text(create_prompt):
    """Request a completion for ``create_prompt`` and return its text.

    Calls ``openai.Completion.create`` with the ``text-davinci-003`` model,
    a 500-token limit and temperature 1.0, then returns the ``text`` field of
    the first entry in ``choices`` with surrounding whitespace stripped.
    """
    request = {
        "model": "text-davinci-003",
        "prompt": create_prompt,
        "max_tokens": 500,
        "temperature": 1.0,
    }
    response = openai.Completion.create(**request)
    first_choice = response['choices'][0]
    return first_choice['text'].strip()



from typing import List, Tuple

import editdistance
import pandas as pd
from fire import Fire
from nltk import word_tokenize


# Helpers for post-processing: span matching
def match_term_editdistance(term, sent_tokens):
    """Snap every word of ``term`` onto a token taken from ``sent_tokens``.

    ``term`` is split with ``word_tokenize``. A word that already occurs in
    ``sent_tokens`` is kept as is; any other word is replaced by the token
    with the smallest ``editdistance.eval`` distance to it (the earliest such
    token wins on ties).

    When ``sent_tokens`` is empty there is nothing to snap to, so the words of
    ``term`` are returned unchanged instead of raising ``ValueError`` from
    ``min()`` on an empty sequence.

    Returns:
        The resulting words joined by single spaces.
    """
    words = word_tokenize(term)
    if not sent_tokens:
        # No candidate tokens: keep the term's own words rather than crash.
        return " ".join(words)

    matched = []
    for word in words:
        if word in sent_tokens:
            matched.append(word)
            continue
        # min() with a key returns the first minimal element, which matches
        # taking the index of the first smallest distance.
        closest = min(sent_tokens, key=lambda tok: editdistance.eval(word, tok))
        matched.append(closest)
    return " ".join(matched)


def find_sublist_indices(items: list, query: list):
    """Return all start positions where ``query`` appears contiguously in ``items``.

    Each candidate start index is checked by comparing the slice of ``items``
    of ``len(query)`` elements against ``query``. The result is a list of
    matching indices in ascending order; it is empty when ``query`` is longer
    than ``items`` or never occurs.
    """
    width = len(query)
    last_start = len(items) - width
    return [
        start
        for start in range(last_start + 1)
        if items[start:start + width] == query
    ]


def find_span_pair(head, tail, tokens):
    """Locate the closest occurrences of ``head`` and ``tail`` inside ``tokens``.

    Both terms are first snapped onto ``tokens`` with
    ``match_term_editdistance``; every occurrence of each snapped term is then
    found with ``find_sublist_indices``. Among all (head, tail) occurrence
    pairs that do not start at the same index, the pair with the smallest gap
    ``min(|head_end - tail_start|, |head_start - tail_end|)`` is chosen.

    Pairs are enumerated with the shorter occurrence list as the outer loop
    (head outermost on equal lengths); the first pair reaching the smallest
    gap is kept.

    Returns:
        A tuple ``(head_indices, tail_indices)`` of token-index lists, or
        ``([], [])`` when no admissible pair exists.
    """
    head = match_term_editdistance(head, tokens)
    tail = match_term_editdistance(tail, tokens)

    head_words = head.split()
    tail_words = tail.split()
    head_len = len(head_words)
    tail_len = len(tail_words)

    head_starts = find_sublist_indices(tokens, head_words)
    tail_starts = find_sublist_indices(tokens, tail_words)

    # Enumeration order decides which pair wins a tie, so keep the shorter
    # list as the outer loop.
    if len(head_starts) <= len(tail_starts):
        candidates = [(h, t) for h in head_starts for t in tail_starts]
    else:
        candidates = [(h, t) for t in tail_starts for h in head_starts]

    best_pair = [], []
    best_gap = 1e9
    for h_start, t_start in candidates:
        if h_start == t_start:
            # Head and tail may not begin on the same token.
            continue
        h_end = h_start + head_len - 1
        t_end = t_start + tail_len - 1
        gap = min(abs(h_end - t_start), abs(h_start - t_end))
        if gap < best_gap:
            best_gap = gap
            best_pair = (
                list(range(h_start, h_end + 1)),
                list(range(t_start, t_end + 1)),
            )

    return best_pair


def test_find_span_pair(text,head,tail):
    """Tokenize ``text`` and find the head/tail token spans within it.

    ``text`` is split with ``word_tokenize`` and the tokens are handed to
    ``find_span_pair`` together with ``head`` and ``tail``. Returns the two
    token-index lists it produces as ``(head_span, tail_span)``.
    """
    head_span, tail_span = find_span_pair(head, tail, word_tokenize(text))
    return head_span, tail_span
