from collections import defaultdict

def punctuate(line):
    """Make sure ``line`` ends with a space-separated terminal punctuation mark.

    - If the line already ends with ``.``, ``?`` or ``!``, that mark is
      separated from the preceding text by a space (``"hi."`` -> ``"hi ."``).
    - If the line ends with a closing quote (``'`` or ``"``), the handling is
      applied just inside the quote: existing terminal punctuation is
      space-separated (``'say "hi."'`` -> ``'say "hi ."'``), otherwise `` .`` is
      inserted before the quote (``"say 'hi'"`` -> ``"say 'hi .'"``).
    - Otherwise `` .`` is appended (``"hi"`` -> ``"hi ."``; ``""`` -> ``" ."``).

    Args:
        line: sentence text.

    Returns:
        The punctuated sentence.
    """
    terminals = ('.', '?', '!')
    quotes = ("'", '"')

    if line.endswith(terminals):
        return line[:-1] + ' ' + line[-1]

    if line.endswith(quotes):
        body, quote = line[:-1], line[-1]
        # Terminal punctuation already inside the quote: just detach it
        # instead of adding a second period.
        if body.endswith(terminals):
            return body[:-1] + ' ' + body[-1] + quote
        return body + ' .' + quote

    # endswith() is safe on "", so an empty line falls through to here.
    return line + ' .'

def edit_distance(word1, word2):
    """Return the Levenshtein distance between ``word1`` and ``word2``.

    Insertions, deletions and substitutions each cost 1.
    """
    rows, cols = len(word1) + 1, len(word2) + 1
    # table[i][j] holds the distance between word1[:i] and word2[:j].
    table = [[0] * cols for _ in range(rows)]

    # Transforming to/from the empty prefix costs one edit per character.
    for i in range(rows):
        table[i][0] = i
    for j in range(cols):
        table[0][j] = j

    for i, ch1 in enumerate(word1, start=1):
        for j, ch2 in enumerate(word2, start=1):
            if ch1 == ch2:
                table[i][j] = table[i - 1][j - 1]
            else:
                insert_cost = table[i][j - 1]
                delete_cost = table[i - 1][j]
                replace_cost = table[i - 1][j - 1]
                table[i][j] = 1 + min(insert_cost, delete_cost, replace_cost)

    return table[-1][-1]

def argmin(lst):
    """Return the index of the smallest element of ``lst``.

    The first such index is returned when the minimum occurs more than once.
    Raises ``ValueError`` if ``lst`` is empty.
    """
    positions = range(len(lst))
    return min(positions, key=lst.__getitem__)

def find_index(context, word):
    """Locate the whitespace-separated token of ``context`` closest to ``word``.

    Closeness is measured with ``edit_distance``. When several tokens tie,
    the earliest one wins.

    Returns:
        (index, token): the token's position in ``context.split()`` and the
        token itself.

    Raises:
        ValueError: if ``context`` contains no tokens (raised by ``min``).
    """
    tokens = context.split()
    best_index, best_token = min(
        enumerate(tokens),
        key=lambda pair: edit_distance(pair[1], word),
    )
    return best_index, best_token

# Lemma -> surface form substitutions applied to a target word when the
# inflected form appears in its context (presumably because the TSV lists the
# lemma; the substitution gives find_index a closer token to match).
_INFLECTION_FIXES = {
    'child': 'children',
    'cry': 'cries',
    'body': 'bodies',
}


def _surface_form(word, context):
    """Return the inflected form of ``word`` if it occurs in ``context``.

    Only words listed in ``_INFLECTION_FIXES`` are rewritten, and only when
    the inflected form is a substring of ``context``; otherwise ``word`` is
    returned unchanged.
    """
    inflected = _INFLECTION_FIXES.get(word)
    if inflected is not None and inflected in context:
        return inflected
    return word


def _read_whic(dataset):
    """Parse ``../data/whic/{dataset}.tsv`` into processed records.

    Each non-blank line must hold five tab-separated fields:
    context1, word1, context2, word2, label.

    Yields:
        ``(context1, w1, context2, w2, label)`` where each context is
        ``[punctuate(raw_context), index_of_target_token]``, each word has had
        ``_surface_form`` applied, and ``label`` is the raw string field.
    """
    with open(f"../data/whic/{dataset}.tsv", "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            c1, w1, c2, w2, label = line.strip().split("\t")

            w1 = _surface_form(w1, c1)
            w2 = _surface_form(w2, c2)

            # NOTE: the token index is computed on the raw context, before
            # punctuate() is applied.
            idx1 = find_index(c1, w1)[0]
            idx2 = find_index(c2, w2)[0]

            context1 = [punctuate(c1), idx1]
            context2 = [punctuate(c2), idx2]

            yield context1, w1, context2, w2, label


def load_whic(dataset='train'):
    """Load a WhiC split.

    Returns:
        A list of ``[context1, context2, label]`` entries, one per line, where
        each context is ``[punctuated sentence, target token index]`` and
        ``label`` is the raw string from the file.
    """
    return [[c1, c2, label] for c1, _, c2, _, label in _read_whic(dataset)]


def pairwise_direction(dataset='train'):
    """Build order-sensitive pairs from the positive WhiC examples.

    For every example labelled ``'1'``, ``[context1, context2, '1']`` goes to
    the positive list and the swapped ``[context2, context1, '0']`` goes to
    the negative list. Examples with any other label are ignored.

    Returns:
        ``(positive_samples, negative_samples)``, two aligned lists.
    """
    positive_samples = []
    negative_samples = []
    for c1, _, c2, _, label in _read_whic(dataset):
        if label == '1':
            positive_samples.append([c1, c2, label])
            negative_samples.append([c2, c1, '0'])
    return positive_samples, negative_samples


def pairwise_context(dataset='train'):
    """Pair each positive example with negatives that share its first context.

    A negative (label ``'0'``) matches a positive (label ``'1'``) when both
    have the same first target word and the same first context (punctuated
    text and token index) but a different second target word.

    Returns:
        pos: ``[context1, positive_context2, label]`` per match.
        neg: ``[context1, negative_context2, label]``, aligned with ``pos``.
        word_idx: ``defaultdict(list)`` mapping the first target word to the
            positions in ``pos``/``neg`` of its matches.
    """
    positive_samples = []
    # Negatives grouped by (word1, punctuated context1, token index). Order
    # within each group follows the file, so matches come out in the same
    # order as a full scan over all negatives, without the O(P*N) cost.
    negatives_by_key = defaultdict(list)
    for c1, w1, c2, w2, label in _read_whic(dataset):
        if label == '1':
            positive_samples.append((c1, w1, c2, w2, label))
        elif label == '0':
            negatives_by_key[(w1, c1[0], c1[1])].append((c2, w2, label))

    pos = []
    neg = []
    word_idx = defaultdict(list)
    for cp1, wp1, cp2, wp2, lp in positive_samples:
        # .get() avoids inserting empty groups into the defaultdict.
        for cn2, wn2, ln in negatives_by_key.get((wp1, cp1[0], cp1[1]), ()):
            if wp2 != wn2:
                word_idx[wp1].append(len(pos))  # index of the pair added below
                pos.append([cp1, cp2, lp])
                neg.append([cp1, cn2, ln])
    return pos, neg, word_idx

