import os
import sys, getopt
import socket
import numpy as np
import itertools
import matplotlib.pyplot as plt
from predict import predict_lr
from myanchor.anchor_text import TextGenerator
from myanchor.anchor_text import SentencePerturber

def getarg(argv):
    """Parse the command-line options understood by this script.

    Three long options are recognised, each of which takes a value:
    ``--anchorpath``, ``--limepath`` and ``--save``.

    Args:
        argv: Argument list without the program name (e.g. ``sys.argv[1:]``).

    Returns:
        A tuple ``(anchorpath, limepath, outputpath)``. An option that was
        not supplied is returned as the empty string.

    Raises:
        getopt.GetoptError: If an unknown option is given or an option is
            missing its value.
    """
    anchorpath = ""
    limepath = ""
    outputpath = ""
    # Every option carries a value, so each name needs the trailing "=".
    # Without it "--limepath X" is parsed as a flag, "X" becomes a positional
    # argument, and getopt stops parsing, dropping any later options.
    # Parse errors are left to propagate unchanged: re-raising the bare
    # GetoptError class would fail with TypeError (its constructor needs msg).
    opts, _ = getopt.getopt(argv, "", ["anchorpath=", "save=", "limepath="])
    for opt, arg in opts:
        if opt == "--anchorpath":
            anchorpath = arg
        elif opt == "--limepath":
            limepath = arg
        elif opt == "--save":
            outputpath = arg
    return anchorpath, limepath, outputpath


def coverage(words, rules,anchors):
    """Estimate coverage and precision of an anchor and of a set of word-pair rules.

    Draws 10000 random perturbations of ``words``, collects the distinct
    perturbed sentences, and records which of them satisfy every rule in
    ``rules`` and which satisfy the ``anchors`` constraint. Both groups are
    then scored with ``predict_lr``.

    Args:
        words: List of tokens of the original sentence (it is copied with
            ``.copy()`` and joined with single spaces).
        rules: Iterable of ``(wordx, wordy, dist)`` triples, see ``fit``.
        anchors: Iterable of anchor words, see ``fit_anchor``.

    Returns:
        A tuple ``(ori, cover, precision)``:
        ``ori`` is the first element of ``predict_lr`` applied to the original
        sentence; ``cover`` is ``[anchor_fraction, rule_fraction]`` relative to
        the number of distinct samples; ``precision`` is ``-1`` when no sample
        satisfied the rules, otherwise ``[anchor_precision, rule_precision]``.
    """
    tot = set()
    inrule = set()
    inanchor = set()
    sp = SentencePerturber(words, TextGenerator(), onepass=True)

    def fill(words):
        """Return a copy of ``words`` with each mask token replaced by a sampled word."""
        ret = words.copy()
        # presumably sp.probs yields one (candidates, probabilities) pair per
        # mask token, in sentence order - TODO confirm against SentencePerturber
        rs = sp.probs(' '.join(words))
        masked = np.where(np.array(words) == sp.mask)[0]
        for x, (w, p) in zip(masked, rs):
            ret[x] = np.random.choice(w, p=p)
        return ret

    def fit(now):
        """Return True when the token list ``now`` satisfies every rule.

        For each rule, both words must occur in ``now`` and the last
        occurrence of ``wordy`` must lie more than ``dist`` positions after
        the last occurrence of ``wordx``.
        """
        for rule in rules:
            wordx, wordy, dist = rule
            posx = -1
            posy = -1
            # The loop never breaks, so posx/posy end up at the LAST occurrence.
            for i, x in enumerate(now):
                if x == wordx:
                    posx = i
                if x == wordy:
                    posy = i
            # A dist of -1 is lowered to -2 (local copy only), which makes the
            # check below accept posy - posx >= -1.
            if dist == -1:
                dist -= 1
            if not (posx != -1 and posy != -1 and posy - posx > dist):
                return False
        return True

    def fit_anchor(now):
        """Return True when ``now`` keeps every anchor word in place.

        ``now`` must have the same length as the original sentence, and for
        each anchor word its first occurrence in ``now`` must sit at an index
        where the original sentence holds the same word.
        """
        if len(now) != len(words):
            return False
        for word in anchors:
            if word not in now:
                return False
            for i,x in enumerate(now):
                if x == word:
                    if words[i] != x:
                        return False
                    # Only the first occurrence is checked.
                    break
        return True


    def sample():
        """Return one random perturbation of the original sentence.

        A random prefix is deleted, a random span after it is masked, one
        random adjacent pair of positions is swapped, and the masks are then
        filled in by ``fill``.
        """
        text = words.copy()
        # NOTE(review): t is shuffled but never read; the shuffle still
        # advances numpy's global random state.
        t = list(range(len(text)))
        # Two indices drawn with replacement; after sorting, part[0] <= part[1].
        part = np.random.choice(list(range(len(text))), 2)
        part = sorted(part)
        np.random.shuffle(t)
        # print(part)

        # Positions before part[0] are blanked out (removed further down).
        for i in range(part[0]):
            text[i] = ''
        # Positions in [part[0], part[1]) become mask tokens.
        for i in range(part[0], part[1]):
            text[i] = sp.mask
        # NOTE(review): np.random.choice needs a positive argument here, so
        # this appears to require at least two tokens in words - confirm.
        x = np.random.choice(len(words) - 1)
        # The swap is done before blanks are removed, so it can move a blank
        # or a mask as well as a real token.
        text[x], text[x + 1] = text[x + 1], text[x]
        text = [x for x in text if x != '']
        text = fill(text)

        return text

    # Sets are used, so duplicate sampled sentences are counted once.
    for i in range(10000):
        seq = sample()
        tot.add(' '.join(seq))
        if fit(seq):
            inrule.add(' '.join(seq))
        if fit_anchor(seq):
            inanchor.add(' '.join(seq))

    cover = [len(inanchor)/len(tot),len(inrule)/len(tot)]
    ori = predict_lr([' '.join(words)])[0]
    if len(inrule)==0:
        # Sentinel: no sample satisfied the rules, precision is undefined.
        precision = -1
    else:
        myres = predict_lr(list(inrule))
        anchores = predict_lr(list(inanchor))
        # NOTE(review): if inanchor is empty while inrule is not, the division
        # by len(inanchor) raises ZeroDivisionError.
        precision = [sum(anchores)/len(inanchor),sum(myres)/len(inrule)]
        # presumably predict_lr yields 0/1 labels, so flipping the fractions
        # measures agreement with an original prediction of 0 - TODO confirm
        if ori == 0:
            precision = [1-precision[0], 1-precision[1]]
    return ori,cover,precision


# Running totals updated by process(): `results` sums the prediction of every
# file that produced a usable precision, `tot_results` sums it for every file.
results = 0
tot_results = 0


def process(filename):
    """Evaluate one explanation file and write its per-file report.

    The file is read from the module-level ``anchorpath`` directory and is
    expected to hold four lines: the token list, the anchor list, the list
    of anchor precisions, and a list of strings that each evaluate to a rule.
    The report is written under the module-level ``outpath`` directory using
    the same file name.

    Args:
        filename: Name of the file inside ``anchorpath``.

    Returns:
        ``None`` when the name does not contain "txt" or when ``coverage``
        reports a precision of ``-1``; otherwise ``(cover, precision)``.
    """
    global results, tot_results
    print("process "+filename)
    if "txt" not in filename:
        return

    # SECURITY NOTE: every line below is passed to eval(); only run this on
    # files you trust.
    with open(os.path.join(anchorpath, filename), "r") as src:
        # Turn a space-separated repr such as ['a' 'b'] into a valid list
        # literal by putting a comma between the items.
        token_line = src.readline().replace(', ', ',').replace(' ', ',')
        words = eval(token_line)
        anchors = eval(src.readline())
        anchor_precision = eval(src.readline())

        # Keep the anchor only up to the first entry reaching 0.9 precision.
        cutoff = next(
            (idx for idx, value in enumerate(anchor_precision) if value >= 0.9),
            None,
        )
        if cutoff is not None:
            anchors = anchors[:cutoff + 1]
            anchor_precision = anchor_precision[:cutoff + 1]

        rule = [eval(item) for item in eval(src.readline())]
        res, cover, precision = coverage(words, rule, anchors)
        tot_results += res

        with open(os.path.join(outpath, filename), "w") as report:
            report.writelines([
                f"anchor: {anchors!s}\n",
                f"anchor_dist: {rule!s}\n",
                f"result: {res!s}\n",
                f"coverage: {cover!s}\n",
                f"precision: {precision!s}\n",
            ])
            print("wrote")

        if precision == -1:
            return None
        results += res
        return cover, precision


if __name__ == "__main__":
    # Script entry point: process every file in the anchor directory, then
    # write the averaged coverage/precision figures to average_result.txt.
    # The first four names are running sums that become averages below.
    totanchor = 0
    totanchordist = 0
    anchorprecision = 0
    anchordistprecision = 0
    cnt = 0  # number of files that produced a usable (cover, precision) pair

    # anchorpath and outpath are module globals that process() reads.
    anchorpath, limepath, outpath = getarg(sys.argv[1:])
    # makedirs with exist_ok avoids the exists()/mkdir() race and also
    # creates missing parent directories of a nested output path.
    os.makedirs(outpath, exist_ok=True)

    for filename in os.listdir(anchorpath):
        res = process(filename)
        if res is None:
            continue
        cover, precision = res
        totanchor += cover[0]
        totanchordist += cover[1]
        anchorprecision += precision[0]
        anchordistprecision += precision[1]
        cnt += 1

    if cnt:
        totanchor /= cnt
        totanchordist /= cnt
        anchorprecision /= cnt
        anchordistprecision /= cnt
    else:
        # No file yielded a usable result. Dividing by zero would abort the
        # run before the summary is written, so the zero sums are reported.
        print("no file produced a usable result; averages are reported as 0",
              file=sys.stderr)

    with open("average_result.txt", "w") as f:
        f.write("\tanchor\tanchor+dist\n")
        f.write("coverage: %f %f\n" % (totanchor, totanchordist))
        f.write("precision: %f %f\n" % (anchorprecision, anchordistprecision))
        f.write("number of positive output: %d %d\n" % (results, tot_results))
