import os
import sys
import re
import string
import tempfile
from tqdm import tqdm
from subprocess import run, check_output, CalledProcessError, STDOUT, PIPE, Popen, TimeoutExpired
import shutil, time, random
import json
import jsonlines

def readTestCaseDict(datasetType, maxTestCases=5):
    """Load the per-problem test cases for a dataset.

    Reads ``./dataset/<datasetType>/testcases_all.json`` (relative to the
    current working directory) and keeps only the first ``maxTestCases``
    entries of each value.

    Args:
        datasetType: dataset subdirectory name, e.g. ``pseudocode-cpp``.
        maxTestCases: keep at most this many test cases per key (default 5).

    Returns:
        dict mapping each key of the JSON object to (a slice of) its list of
        test cases, in the file's original key order.
    """
    path = f"./dataset/{datasetType}/testcases_all.json"
    # `with` guarantees the handle is closed even if json.load raises;
    # JSON is UTF-8 by spec, so don't depend on the locale default encoding.
    with open(path, encoding="utf-8") as f:
        TCdict = json.load(f)
    return {k: v[:maxTestCases] for k, v in TCdict.items()}

def readIDlist(datasetType):
    """Return the problem ids listed in ``test.<datasetType>.id``.

    Each line of ``./dataset/<datasetType>/test.<datasetType>.id`` is
    stripped and cut at its first underscore, so ``"123_4"`` yields ``"123"``.
    """
    idFilePath = f'./dataset/{datasetType}/test.{datasetType}.id'
    with open(idFilePath) as idFile:
        rawLines = idFile.readlines()
    problemIds = []
    for rawLine in rawLines:
        prefix, _sep, _rest = rawLine.strip().partition("_")
        problemIds.append(prefix)
    return problemIds

def savePredictions(progList, path):
    """Dump programs to ``<path>/pred<epoch-milliseconds>.jsonl``.

    Each program becomes one JSON Lines record of the form
    ``{"predC++": <program>}``, in input order.
    """
    stampMs = int(round(time.time() * 1000))
    outFile = os.path.join(path, f"pred{stampMs}.jsonl")
    records = [{"predC++": prog} for prog in progList]
    with jsonlines.open(outFile, 'w') as writer:
        writer.write_all(records)

def makeSubDir(writeDir):
    """Create and return a new, uniquely named subdirectory of ``writeDir``.

    The name is ``<epoch-milliseconds>_<random 4-digit int>``. ``writeDir``
    itself is created if missing. If the chosen name already exists, a new
    name is drawn and creation is retried.

    Returns:
        Path of the freshly created directory.
    """
    while True:
        stamp = str(int(round(time.time() * 1000))) + "_" + str(random.randint(1000, 9999))
        folderNm = os.path.join(writeDir, stamp)
        # EAFP: attempt creation directly instead of checking os.path.exists
        # first, so a directory created concurrently between a check and the
        # makedirs call triggers a retry rather than an uncaught error.
        try:
            os.makedirs(folderNm, exist_ok=False)
        except FileExistsError:
            continue
        return folderNm

def floatReplace_func(match):
    """``re.sub`` callback: re-render the matched number with one decimal place."""
    numericText = match.group()
    return f"{float(numericText):.1f}"

def checkOutputsSame(out1, out2):
    """Loosely compare a program's output with the expected output.

    Returns False if either side is None or both are blank. Returns True if
    the two are equal once all whitespace is removed. Otherwise both sides go
    through ``_normalizeOutput`` and the results are compared for equality.

    NUL characters are dropped from both sides before any comparison.
    """
    #CHECK CASE OF RANDOM
    # The None check must come before any string method is called on the inputs.
    if (out1 is None) or (out2 is None):
        return False
    out1 = out1.replace('\x00', '')
    out2 = out2.replace('\x00', '')
    if not out1.strip() and not out2.strip():
        return False
    if "".join(out1.split()) == "".join(out2.split()):
        return True

    out1 = _normalizeOutput(out1)
    out2 = _normalizeOutput(out2)
    print ("out1,out2", out1, out2, flush = True)
    return out1 == out2


def _normalizeOutput(text):
    """Canonicalize program output for the loose comparison in checkOutputsSame.

    Steps, in order:
    1. Rewrite every integer or decimal literal with exactly one digit after
       the point.
    2. Strip ASCII punctuation.
    3. Treat '▁' as a space.
    4. Remove all whitespace.
    5. Lowercase the result.

    Note: '-' and '.' are in string.punctuation, so signs and decimal points
    are removed after the numbers are reformatted.
    """
    text = re.sub(r'[-+]?[\d]*[\.][\d]+|[-+]?[\d]+', floatReplace_func, text)
    text = text.translate(str.maketrans('', '', string.punctuation))
    text = text.replace("▁", " ")
    return "".join(text.strip().split()).lower()

def check_CompAccRunEqAcc_cpp(programList, writeDir, datasetType):
    """Compile and run generated C++ programs; report CompAcc and FEqAcc.

    Processing steps:
    1. Strip each program's ``#include <...>`` and ``using namespace ...;``
       lines and prepend a fixed header.
    2. Save the resulting programs via ``savePredictions``.
    3. Compile each with ``g++``, keeping up to 100 compiles in flight.
    4. Run every program that compiles on the test cases of its problem id
       (from ``readTestCaseDict`` / ``readIDlist``), with a 5 s timeout per
       run.

    Args:
        programList: C++ source strings, aligned with the first
            ``len(programList)`` ids returned by ``readIDlist(datasetType)``.
        writeDir: directory for the saved predictions and a scratch
            subdirectory (removed before returning, even on error).
        datasetType: dataset name, e.g. ``pseudocode-cpp`` or
            ``SHORTpseudocode-cpp``.

    Returns:
        ``{"CompAcc": ..., "FEqAcc": ...}``:
        - ``CompAcc``: percent of programs that compiled.
        - ``FEqAcc``: mean over programs of the percent of their test cases
          passed. A program with no test cases scores 0.
        Both are 0.0 for an empty ``programList``.
    """
    genericCodePrefix =\
    '''#include <iostream>
    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #include <ctime>
    #include <cstdlib>
    #include <algorithm>
    #include <vector>
    #include <string>
    #include <map>
    #include <unordered_map>
    #include <set>
    #include <unordered_set>
    #include <deque>
    #include <random>
    #include <stack>
    #include <queue>
    #include <tuple>
    #include <list>
    #include <climits>
    #include <cassert>
    using namespace std;
    '''
    programList = [re.sub(r'#include\s*?<.*?>\s*?', "", p) for p in programList] #remove #include lines
    programList = [re.sub(r'using\s*?namespace.*?;', "", p) for p in programList] #remove using std ... lines
    programList = [genericCodePrefix + p.strip() for p in programList]
    TCdict = readTestCaseDict(datasetType)
    idList = readIDlist(datasetType)[:len(programList)]
    assert len(programList) == len(idList)
    savePredictions(programList, writeDir)

    numProgs = len(programList)
    compilePass_perCode = [False] * numProgs
    # Indexed by program (not appended), so it stays aligned with
    # TCpass_perCode even when some programs are skipped.
    TC_perCode = [0] * numProgs
    TCpass_perCode = [0] * numProgs
    folderNm = makeSubDir(writeDir)

    try:
        pending = []
        for progIndx, program in tqdm(enumerate(programList), total=numProgs):
            TClist = TCdict[idList[progIndx]]
            TC_perCode[progIndx] = len(TClist)
            if len(program.strip()) == 0:
                print ("ERROR", flush = True)
                continue

            filename = os.path.join(folderNm, 'main_{}.cpp'.format(progIndx))
            with open(filename, 'w', encoding='utf8') as fw:
                fw.write(program)

            outExeFilePath = os.path.join(folderNm, "a_{}.out".format(progIndx))
            errFile = tempfile.TemporaryFile()
            proc = Popen(["g++", filename, "-o", outExeFilePath], stderr = errFile)
            pending.append((proc, errFile, TClist, progIndx, outExeFilePath))

            if len(pending) >= 100:
                _drainCompileJobs(pending, compilePass_perCode, TCpass_perCode)
                pending = []
        # Unconditional final flush: queued jobs are processed even when the
        # last program was skipped as empty.
        _drainCompileJobs(pending, compilePass_perCode, TCpass_perCode)
    finally:
        shutil.rmtree(folderNm, ignore_errors = True)

    print ("\n-----------Num Tested-----------", flush = True)
    print ("len(programList):", numProgs)

    print ("\n-----------CompAcc-----------", flush = True)
    print ("compilePass_perCode:", compilePass_perCode)
    Csuccess = sum(compilePass_perCode)
    Cerror = numProgs - Csuccess
    print('CSuccess - {}, CErrors - {}'.format(Csuccess, Cerror))
    CompAcc = (Csuccess * 100.0) / numProgs if numProgs else 0.0
    print (f"CompAcc: {Csuccess}/{(Csuccess + Cerror)} =", CompAcc, "%")

    print ("\n-----------FEqAcc-----------", flush = True)
    print ("TCpass_perCode:", TCpass_perCode)
    print ("TC_perCode:", TC_perCode)
    FEqAcc_perCode = [
        (passed * 100.0) / total if total else 0.0
        for passed, total in zip(TCpass_perCode, TC_perCode)
    ]
    print ("FEqAcc_perCode:", FEqAcc_perCode)
    FEqAcc = sum(FEqAcc_perCode) / len(FEqAcc_perCode) if FEqAcc_perCode else 0.0
    print ("FEqAcc =", FEqAcc, "%")

    return {"CompAcc": CompAcc, "FEqAcc": FEqAcc}


def _drainCompileJobs(jobs, compilePass_perCode, TCpass_perCode):
    """Wait for queued g++ jobs, record compile status, and run passing binaries.

    ``jobs`` holds ``(Popen, stderr_tempfile, test_cases, prog_index, exe_path)``
    tuples. ``compilePass_perCode`` and ``TCpass_perCode`` are updated in place.
    """
    for proc, errFile, testCases, progIndx, exePath in jobs:
        try:
            returnCode = proc.wait()
            errFile.seek(0)
            error_msg = errFile.read().decode("utf-8", errors="replace")
        finally:
            errFile.close()

        # Use g++'s exit status rather than searching stderr for "error:". A
        # warning that echoes a source line containing "error:" then does not
        # count as a failure, and a compiler that died silently does not count
        # as a success.
        if returnCode != 0:
            print (error_msg)
            compilePass_perCode[progIndx] = False
            print ("pIndx {}: COMPILE-ERROR".format(progIndx), flush = True)
            continue

        compilePass_perCode[progIndx] = True
        print ("pIndx {}: COMPILE-SUCCESS".format(progIndx), flush = True)
        TCpass_perCode[progIndx] += _runTestCases(exePath, testCases, progIndx)


def _runTestCases(exePath, testCases, progIndx):
    """Run ``exePath`` on each ``(stdin_text, expected_stdout, ...)`` test case.

    Returns:
        The number of test cases whose output matches according to
        ``checkOutputsSame``.
    """
    passed = 0
    for TC in testCases:
        try:
            p2 = run([exePath], input = TC[0].encode('utf-8'), stdout=PIPE, stderr=PIPE, timeout=5)
        except TimeoutExpired:
            console_out = ""
            print ("pIndx {} TIMEOUT".format(progIndx), flush = True)
        except OSError:
            # The binary could not be executed at all.
            console_out = ""
            print ("pIndx {} RUNTIME-EXCEPTION".format(progIndx), flush = True)
        else:
            if p2.returncode != 0:
                print ("pIndx {} RUNTIME-EXCEPTION".format(progIndx), flush = True)
            # errors="replace": stray non-UTF-8 bytes still get compared
            # instead of aborting the test case.
            console_out = p2.stdout.decode("utf-8", errors="replace")

        if checkOutputsSame(console_out, TC[1]):
            print ("pIndx {} RUNTIME-SUCCESS".format(progIndx), flush = True)
            passed += 1
        else:
            print ("pIndx {} RUNTIME-ERROR".format(progIndx), flush = True)
    return passed



if __name__ == '__main__':
    # Score the model completions stored in the pseudocode-cpp test split;
    # each JSONL record carries its generated C++ program under "completion".
    jsonlPath = './dataset/pseudocode-cpp/test.pseudocode-cpp.jsonl'
    with open(jsonlPath) as jsonlFile:
        records = [json.loads(row) for row in jsonlFile]
    completions = [rec["completion"] for rec in records]
    check_CompAccRunEqAcc_cpp(completions, "./extras_junkFiles", "pseudocode-cpp")
