import os
import sys
import re
import string
import tempfile
from tqdm import tqdm
from subprocess import run, check_output, CalledProcessError, STDOUT, PIPE, Popen, TimeoutExpired
import shutil, time, random
import json
import jsonlines
from transformers import GemmaTokenizer
import numpy as np
import copy

def makeSubDir(writeDir, devc):
    """Create a fresh, uniquely named scratch folder inside ``writeDir``.

    The folder name is ``<epoch-milliseconds>_<random 4-digit int>_<devc>``.

    Args:
        writeDir: parent directory; missing intermediate directories are
            created by ``os.makedirs``.
        devc: device identifier, appended to the folder name via ``str()``.

    Returns:
        Path of the folder that was created by this call.
    """
    while True:
        time_rand_str = "{}_{}_{}".format(int(round(time.time() * 1000)),
                                          random.randint(1000, 9999), devc)
        folderNm = os.path.join(writeDir, time_rand_str)
        try:
            #the mkdir itself is the atomic "claim" on the name: checking
            #os.path.exists first leaves a window in which another process
            #can create the same folder and make makedirs raise
            os.makedirs(folderNm, exist_ok = False)
        except FileExistsError:
            continue #name already taken -> draw a new one
        return folderNm

def check_CompAcc_Boolean_cpp(programList, writeDir, device = "0"):
    """Compile every C++ program with g++ and report which ones build.

    Each program has its own ``#include <...>`` and ``using namespace ...;``
    statements removed and a fixed prefix of standard includes prepended.
    The result is written to a scratch folder under ``writeDir`` and compiled
    with ``g++``; up to 100 compilers run concurrently.  A program counts as
    compiling when g++'s stderr does not contain the text ``error:``
    (warnings and notes are tolerated).  Programs that are blank after the
    stripping step are reported as failures without invoking the compiler.

    Args:
        programList: list of C++ source strings.
        writeDir: parent directory for the temporary build folder.
        device: identifier forwarded to ``makeSubDir`` for the folder name.

    Returns:
        list of floats, ``1.0`` for a program that compiled and ``0.0``
        otherwise, in the order of ``programList``.

    Raises:
        OSError: if ``g++`` cannot be started or the scratch files cannot be
            written.
    """
    genericCodePrefix =\
    '''#include <iostream>
    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #include <ctime>
    #include <cstdlib>
    #include <algorithm>
    #include <vector>
    #include <string>
    #include <map>
    #include <unordered_map>
    #include <set>
    #include <unordered_set>
    #include <deque>
    #include <random>
    #include <stack>
    #include <queue>
    #include <tuple>
    #include <list>
    #include <climits>
    #include <cassert>
    using namespace std;
    '''
    MAX_PARALLEL_COMPILES = 100 #compilers launched before their results are collected

    programList = [re.sub(r'#include\s*?<.*?>\s*?', "", p) for p in programList] #remove #include lines
    programList = [re.sub(r'using\s*?namespace.*?;', "", p) for p in programList] #remove using std ... lines
    #blankness has to be decided BEFORE the prefix is added; afterwards no
    #program is ever empty and the check could never fire
    isBlank = [len(p.strip()) == 0 for p in programList]
    programList = [genericCodePrefix + p for p in programList]

    compilePass_perCode = [False for _ in programList]
    folderNm = makeSubDir(writeDir, device)

    def collectBatch(batch):
        #wait for every compiler in the batch and classify its stderr
        for p, f, progIndx_tmp in batch:
            try:
                p.wait()
                f.seek(0)
                try:
                    error_msg = f.read().decode("utf-8")
                except UnicodeDecodeError:
                    error_msg = "error:" #undecodable diagnostics count as a failure
            finally:
                f.close()
            #empty stderr, "warning:" or "note:" only -> compiled
            compilePass_perCode[progIndx_tmp] = "error:" not in error_msg

    try:
        compileProcesses = []
        for progIndx, program in tqdm(enumerate(programList), total=len(programList)):
            if isBlank[progIndx]:
                continue #stays False; nothing to compile

            filename = os.path.join(folderNm, 'main_{}.cpp'.format(progIndx))
            with open(filename, 'w', encoding='utf8') as fw:
                fw.write(program)

            outExeFilePath = os.path.join(folderNm, "a_{}.out".format(progIndx))
            f = tempfile.TemporaryFile()
            try:
                p = Popen(["g++", filename, "-o", outExeFilePath], stderr = f)
            except OSError:
                f.close()
                raise
            compileProcesses.append((p, f, progIndx))

            if len(compileProcesses) >= MAX_PARALLEL_COMPILES:
                collectBatch(compileProcesses)
                compileProcesses = []
        #collect whatever is left, independent of how the last iteration ended
        collectBatch(compileProcesses)
    finally:
        #always drop the scratch folder, even when a step above raised
        shutil.rmtree(folderNm, ignore_errors = True)

    numProgs = len(compilePass_perCode)
    Csuccess = sum(compilePass_perCode)
    Cerror = numProgs - Csuccess
    #an empty batch has no meaningful accuracy; report 0.0 instead of dividing by zero
    CompAcc = (Csuccess * 100.0) / numProgs if numProgs else 0.0

    print ("\n-----------CompAcc-----------", flush = True)
    print ("compilePass_perCode:", compilePass_perCode)
    print('CSuccess - {}, CErrors - {}'.format(Csuccess, Cerror))
    print (f"CompAcc: {Csuccess}/{(Csuccess + Cerror)} =", CompAcc, "%")

    return [float(i) for i in compilePass_perCode]

def floatReplace_func(match):
    """``re.sub`` callback: rewrite a numeric token with one digit after the point.

    Plain ASCII integers (optional sign followed by digits) are normalised
    textually - sign kept only when negative, leading zeros dropped, ``.0``
    appended - so arbitrarily large values stay exact.  Routing them through
    ``float`` would round everything above 2**53 and make distinct integers
    (e.g. two different 64-bit answers) look identical.  Every other token is
    formatted as ``"{:.1f}"`` of its ``float`` value.

    Args:
        match: regex match object whose full match is the numeric token.

    Returns:
        The normalised token as a string.

    Raises:
        ValueError: if a non-integer token cannot be parsed by ``float``.
    """
    token = match.group()
    digits = token[1:] if token[:1] in ("+", "-") else token
    if digits and all(ch in "0123456789" for ch in digits):
        sign = "-" if token.startswith("-") else ""
        #"or '0'" keeps a lone zero when every digit was a leading zero
        return sign + (digits.lstrip("0") or "0") + ".0"
    return "{:.1f}".format(float(token))

def checkOutputsSame(out1, out2):
    """Leniently compare two program outputs.

    The comparison is done in two stages:
      1. after removing NUL characters, the texts match if they are equal
         once all whitespace is removed;
      2. otherwise every number is rewritten by ``floatReplace_func``, all
         ``string.punctuation`` characters and whitespace are dropped, the
         character "▁" is treated as a space, and the lower-cased results
         are compared.

    Args:
        out1: first output text, or None.
        out2: second output text, or None.

    Returns:
        True when the outputs are considered equal.  False when either
        output is None, and also when BOTH outputs are blank (two empty
        outputs are deliberately not counted as a match).
    """
    #CHECK CASE OF RANDOM
    #the None test has to come first: calling .replace on None raises
    if (out1 is None) or (out2 is None):
        return False
    out1 = out1.replace('\x00', '')
    out2 = out2.replace('\x00', '')
    if (len(out1.strip())==0) and (len(out2.strip())==0):
        return False
    if "".join(out1.split()) == "".join(out2.split()):
        return True

    def normalise(text):
        #replaces all floating points or integers --> uniform representation i.e. 1 digit after .
        text = re.sub(r'[-+]?[\d]*[\.][\d]+|[-+]?[\d]+', floatReplace_func, text)
        #removes all punctuations
        text = text.translate(str.maketrans('', '', string.punctuation))
        text = text.replace("▁", " ")
        #removes all whitespaces and changes to lowercase
        return "".join(text.split()).lower()

    return normalise(out1) == normalise(out2)

def getRewardForCorrectCodes(rewardVec, fileSepTokenIndx, FEqAcc, correctRewardScheme):
    """Write the rewards of a program that compiled into ``rewardVec`` (in place).

    Tokens before ``fileSepTokenIndx`` and the very last token receive the
    reward; the separator token and everything between it and the last token
    are set to 0.  The reward is 2.0 under "correct_allOne" and
    ``FEqAcc + 1`` under any other scheme name.

    Args:
        rewardVec: per-token reward array, modified in place (scalar slice
            assignment is used, so a numpy array is expected).
        fileSepTokenIndx: index of the file-separator token.
        FEqAcc: fraction of test cases passed.
        correctRewardScheme: "correct_allOne" or "correct_propTCpassed".

    Returns:
        None.
    """
    print ("--------------------COMPILE-SUCCESS--------------------")
    print ("numTokens:", len(rewardVec))
    print ("fileSepTokenIndx", fileSepTokenIndx)
    print ("FEqAcc", FEqAcc)

    #silence the separator and any trailing tokens except the final one
    lastIndx = len(rewardVec) - 1
    rewardVec[fileSepTokenIndx] = 0.0
    if fileSepTokenIndx < lastIndx:
        rewardVec[fileSepTokenIndx + 1: -1] = 0.0

    if correctRewardScheme == "correct_allOne":
        codeReward = +2.0
    else: #correctRewardScheme == "correct_propTCpassed"
        codeReward = FEqAcc + 1

    #reward the code tokens and the final token
    rewardVec[0: fileSepTokenIndx] = codeReward
    rewardVec[-1] = codeReward
    print ("rewardVec", rewardVec)

def getRewardForIncorrectCodes(actualErrLineIndices, newLineTokenIndx_list,
                                    newLineTokenType_list, fileSepTokenIndx,
                                    rewardVec, tokenID_npList,
                                    wrongRewardScheme, newLineBaseTokenID = 108):
    """Zero the rewards of the tokens belonging to erroneous source lines (in place).

    The token sequence is cut into source lines at the newline tokens.  A
    newline token may stand for several line breaks; its id minus
    ``newLineBaseTokenID`` plus one is taken as the number of breaks, so blank
    lines advance the line counter without owning tokens.  An error reported
    on a line without tokens is attributed to the closest earlier line that
    has some.

    Args:
        actualErrLineIndices: 0-based indices of the lines with compile
            errors; None or empty means "unknown" and penalises line 0.
        newLineTokenIndx_list: positions of the newline tokens that precede
            the file separator.
        newLineTokenType_list: token ids found at those positions.
        fileSepTokenIndx: index of the file-separator token.
        rewardVec: per-token reward array, modified in place (a numpy array
            is expected, scalar slice assignment is used).
        tokenID_npList: token ids of the program; only its length is used.
        wrongRewardScheme: "wrong_penalizeAllAfterFirst" zeroes everything
            from the first erroneous line to the end and returns; any other
            value zeroes only the erroneous lines, the separator and
            everything after it.
        newLineBaseTokenID: id of the token that encodes exactly ONE line
            break.  NOTE(review): the default assumes ids 108..138 encode
            1..31 consecutive newlines, matching the TOKENIDs_NEWLINE range
            used by the caller in this file - confirm against the tokenizer
            vocabulary.

    Returns:
        None.
    """
    print ("--------------------COMPILE-ERROR--------------------")
    print ("numTokens:", len(rewardVec))
    print ("actualErrLineIndices:", actualErrLineIndices)
    print ("newLineTokenIndx_list:", newLineTokenIndx_list)
    print ("newLineTokenType_list:", newLineTokenType_list)
    print ("fileSepTokenIndx:", fileSepTokenIndx)

    #source line index -> (first token index, last token index)
    dict_actualLineIndx_TokenstrtStop = {}
    if len(newLineTokenIndx_list) == 0:
        dict_actualLineIndx_TokenstrtStop[0] = (0, fileSepTokenIndx - 1)
    else:
        actualLineIndx = 0
        strtIndx = 0
        for nlTokenIndx, nlTokenType in zip(newLineTokenIndx_list, newLineTokenType_list):
            nlTokenIndx = int(nlTokenIndx)
            dict_actualLineIndx_TokenstrtStop[actualLineIndx] = (strtIndx, nlTokenIndx)
            #count the breaks relative to the single-newline token id, NOT
            #relative to whichever newline token happens to come first in
            #this program (that shifts every line when the program opens
            #with a multi-newline token)
            actualLineIndx += int(nlTokenType) - newLineBaseTokenID + 1
            strtIndx = nlTokenIndx + 1
        #last line: runs from after the final newline up to the separator
        dict_actualLineIndx_TokenstrtStop[actualLineIndx] = (strtIndx, fileSepTokenIndx)

    if actualErrLineIndices is None or len(actualErrLineIndices) == 0:
        actualErrLineIndices = [0]
    for errLineIndx in actualErrLineIndices:
        strt = 0
        #walk back to the nearest line that owns tokens
        iterBack = int(errLineIndx)
        while iterBack >= 0 and iterBack not in dict_actualLineIndx_TokenstrtStop:
            iterBack -= 1
        if iterBack in dict_actualLineIndx_TokenstrtStop:
            strt, end = dict_actualLineIndx_TokenstrtStop[iterBack]
            rewardVec[strt: end + 1] = 0.0
        if wrongRewardScheme == "wrong_penalizeAllAfterFirst":
            #only the first reported line matters under this scheme
            rewardVec[strt: len(tokenID_npList)] = 0.0
            rewardVec[fileSepTokenIndx] = 0.0
            print ("rewardVec", rewardVec)
            return
    if fileSepTokenIndx < len(tokenID_npList) - 1:
        rewardVec[fileSepTokenIndx + 1: -1] = 0.0
    #set last reward value
    rewardVec[fileSepTokenIndx] = 0.0
    rewardVec[-1] = 0.0
    print ("rewardVec", rewardVec)
             

def check_CompAcc_vector_cpp(programList, tokenID_tensorList, 
                                writeDir, TClist_perCode,
                                wrongRewardScheme = "wrong_penalizeLines", 
                                correctRewardScheme = "correct_propTCpassed", 
                                device = "0"):
    """Compile and run C++ programs and derive a per-token reward vector for each.

    Only the text before the first ``<|file_separator|>`` of every program is
    used.  Its own ``#include <...>`` / ``using namespace ...;`` statements
    are removed, a fixed prefix of standard includes is prepended and the
    result is compiled with ``g++`` (up to 100 compilers run concurrently).

    * A program whose g++ stderr contains ``error:`` (or that is blank after
      stripping) is a compile failure; the reported line numbers, shifted by
      the length of the prefix, are handed to ``getRewardForIncorrectCodes``.
    * Any other program is executed once per test case (20 s timeout) and
      its stdout compared with ``checkOutputsSame``; the pass fraction is
      handed to ``getRewardForCorrectCodes``.

    Args:
        programList: list of generated program strings.
        tokenID_tensorList: one token-id tensor per program; each element
            must provide ``.numpy()``.
        writeDir: parent directory for the temporary build folder.
        TClist_perCode: for every program a list of test cases, each indexed
            as ``[stdin_text, expected_stdout]``.
        wrongRewardScheme: "wrong_penalizeLines" or
            "wrong_penalizeAllAfterFirst", forwarded unchanged.
        correctRewardScheme: "correct_allOne" or "correct_propTCpassed",
            forwarded unchanged.
        device: identifier forwarded to ``makeSubDir`` for the folder name.

    Returns:
        list of float numpy arrays, one per program, shaped like the
        program's token-id array.

    Raises:
        OSError: if ``g++`` cannot be started or the scratch files cannot be
            written.
    """
    #wrongRewardScheme = ["wrong_penalizeLines", "wrong_penalizeAllAfterFirst"]
    #correctRewardScheme = ["correct_allOne", "correct_propTCpassed"]
    TOKENID_FILESEP = 70 #presumably the id of <|file_separator|> in the tokenizer vocabulary - confirm
    TOKENIDs_NEWLINE = [i for i in range(108, 139)]
    MAX_PARALLEL_COMPILES = 100 #compilers launched before their results are collected
    RUN_TIMEOUT_SEC = 20 #per test-case execution
    genericCodePrefix =\
    '''#include <iostream>
    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #include <ctime>
    #include <cstdlib>
    #include <algorithm>
    #include <vector>
    #include <string>
    #include <map>
    #include <unordered_map>
    #include <set>
    #include <unordered_set>
    #include <deque>
    #include <random>
    #include <stack>
    #include <queue>
    #include <tuple>
    #include <list>
    #include <climits>
    #include <cassert>
    using namespace std;
    '''
    #g++ line numbers are shifted by this many lines w.r.t. the program text
    prefixLineCount = genericCodePrefix.count('\n')

    tokenID_npList = [t.numpy() for t in tokenID_tensorList]
    rewardVec_perCode = [np.ones_like(t, dtype=float) for t in tokenID_npList]

    #---------regarding <|file_separator|>---------
    #first separator position, or the last token when there is none
    fileSepTokenIndx_perCode = [np.where(t == TOKENID_FILESEP)[0] for t in tokenID_npList]
    fileSepTokenIndx_perCode = [f[0] if len(f) > 0 else (len(tokenID_npList[fIndx]) - 1) for\
                                 fIndx, f in enumerate(fileSepTokenIndx_perCode)]
    programList = [t.split("<|file_separator|>")[0].strip() for t in programList]
    numProgs = len(programList)

    #---------regarding new lines---------
    #positions (and ids) of the newline tokens located before the separator
    newLineTokenIndx_perCode = [np.where(np.isin(t, TOKENIDs_NEWLINE))[0] for t in tokenID_npList]
    for i in range(numProgs):
        newLineTokenIndx_perCode[i] = newLineTokenIndx_perCode[i][\
                                newLineTokenIndx_perCode[i] < fileSepTokenIndx_perCode[i]]
    newLineTokenType_perCode = [tokenID_npList[i][newLineTokenIndx_perCode[i]] for i in range(numProgs)]

    #skipping #include lines...NOTE: new lines are preserved here
    programList = [re.sub(r'#include\s*?<.*?>\s*?', "", p) for p in programList] #remove #include lines
    programList = [re.sub(r'using\s*?namespace.*?;', "", p) for p in programList] #remove using std ... lines
    #blankness has to be decided BEFORE the prefix is added; afterwards no
    #program is ever empty and the check could never fire
    isBlank = [len(p.strip()) == 0 for p in programList]
    programList = [genericCodePrefix + p for p in programList]

    compilePass_perCode = [False for _ in range(numProgs)]
    TCpass_perCode = [0 for _ in range(numProgs)]
    #one entry per program so that it stays aligned with TCpass_perCode
    TC_perCode = [len(TClist_perCode[i]) for i in range(numProgs)]
    #empty == "no line information"; the reward helper then penalises line 0
    actualErrLineIndices_perCode = [np.array([], dtype=int) for _ in range(numProgs)]
    folderNm = makeSubDir(writeDir, device)

    def collectBatch(batch):
        #wait for every compiler of the batch, then either record the error
        #lines or run the executable on its test cases
        for p, f, progIndx_tmp, exePath in batch:
            try:
                p.wait()
                f.seek(0)
                try:
                    error_msg = f.read().decode("utf-8")
                except UnicodeDecodeError:
                    error_msg = "error:" #undecodable diagnostics count as a failure
            finally:
                f.close()

            if "error:" in error_msg: #error
                errlineNumsInCode = re.findall(r"(\d+?):\d+?: error:", error_msg)
                actualErrLineNums = sorted({int(i) - prefixLineCount for i in errlineNumsInCode}) #1-indexed
                actualErrLineIndices_perCode[progIndx_tmp] = np.array(actualErrLineNums, dtype=int) - 1
                continue

            #empty stderr, "warning:" or "note:" only -> compiled
            compilePass_perCode[progIndx_tmp] = True
            for TC in TClist_perCode[progIndx_tmp]:
                try:
                    p2 = run([exePath], input = TC[0].encode('utf-8'), stdout=PIPE, stderr=PIPE,
                             timeout=RUN_TIMEOUT_SEC)
                    console_out = p2.stdout.decode("utf-8")
                except Exception:
                    #timeout, missing executable, undecodable output, ... ->
                    #treated as "no output"; Ctrl-C is no longer swallowed
                    console_out = ""
                if checkOutputsSame(console_out, TC[1]):
                    TCpass_perCode[progIndx_tmp] += 1

    try:
        compileProcesses = []
        for progIndx, program in tqdm(enumerate(programList), total=numProgs):
            if isBlank[progIndx]:
                continue #stays a compile failure; nothing to build

            filename = os.path.join(folderNm, 'main_{}.cpp'.format(progIndx))
            with open(filename, 'w', encoding='utf8') as fw:
                fw.write(program)

            outExeFilePath = os.path.join(folderNm, "a_{}.out".format(progIndx))
            f = tempfile.TemporaryFile()
            try:
                p = Popen(["g++", filename, "-o", outExeFilePath], stderr = f)
            except OSError:
                f.close()
                raise
            compileProcesses.append((p, f, progIndx, outExeFilePath))

            if len(compileProcesses) >= MAX_PARALLEL_COMPILES:
                collectBatch(compileProcesses)
                compileProcesses = []
        #collect whatever is left, independent of how the last iteration ended
        collectBatch(compileProcesses)
    finally:
        #always drop the scratch folder, even when a step above raised
        shutil.rmtree(folderNm, ignore_errors = True)

    #a program without test cases gets 0.0 instead of a division by zero
    FEqAcc_perCode = [(TCpass_perCode[i] / TC_perCode[i]) if TC_perCode[i] else 0.0
                      for i in range(numProgs)]

    for progIndx in range(numProgs):
        if not compilePass_perCode[progIndx]:
            getRewardForIncorrectCodes(actualErrLineIndices_perCode[progIndx], 
                                        newLineTokenIndx_perCode[progIndx],
                                        newLineTokenType_perCode[progIndx], 
                                        fileSepTokenIndx_perCode[progIndx],
                                        rewardVec_perCode[progIndx], 
                                        tokenID_npList[progIndx], 
                                        wrongRewardScheme)
        else:
            getRewardForCorrectCodes(rewardVec_perCode[progIndx], 
                                    fileSepTokenIndx_perCode[progIndx],
                                    FEqAcc_perCode[progIndx],
                                    correctRewardScheme)

    Csuccess = sum(compilePass_perCode)
    Cerror = numProgs - Csuccess
    #an empty batch has no meaningful accuracy; report 0.0 instead of dividing by zero
    CompAcc = (Csuccess * 100.0) / numProgs if numProgs else 0.0

    print ("\n-----------CompAcc-----------", flush = True)
    print ("compilePass_perCode:", compilePass_perCode)
    print('CSuccess - {}, CErrors - {}'.format(Csuccess, Cerror))
    print (f"CompAcc: {Csuccess}/{(Csuccess + Cerror)} =", CompAcc, "%")

    print ("\n-----------FEqAcc-----------", flush = True)
    print ("TCpass_perCode:", TCpass_perCode)
    print ("TC_perCode:", TC_perCode)
    print ("FEqAcc_perCode", FEqAcc_perCode)
    FEqAcc = sum(FEqAcc_perCode) / len(FEqAcc_perCode) if FEqAcc_perCode else 0.0
    print ("FEqAcc =", FEqAcc, "%")

    return rewardVec_perCode


if __name__ == '__main__':
    #---------------------------------------------------
    #demo run of check_CompAcc_vector_cpp on four hand-written samples
    #---------------------------------------------------
    writeDir = "./extras_junkFiles"
    cppProgList = [
        '''#include <iostream>
        #include <cmath>
        using namespace std;

        const int MIN = INT32_MIN;

        int main()
        {
        int n; cin >> n;
        int k; cin >> k;
        cout << "vector solution" << endl;
        if (k == MIN) cout << -1 << endl;
        else {
            int c = 0, b = 1, d = 0, e;
            for (int i = 1; i < n + 1; i++) {
                cin >> e;
                if (e >= k) continue;
                if (e < 8) c += e;
                else {
                    d += e - 8;
                    d *= -1;
                }

                cout << e << endl;

                if (c >= k) break;
                else if (e >= 8) c += 8;
                else {
                    cout << c << endl;
                    break;
                }

                while (b < n && d + b < 7) {
                    cout << d + b << endl;
                    b += 1;
                }
            }
            cout << b - 1 << endl;
        }
        return 0;
        }<|file_separator|>''',
        '''class Solution {
        public:
            string pseudo_to_code(string kei_nai_ko_sayek){

                let N be a constant integer with N = 1e2 + 5
                f = array of integers of length N

                let n, m, mn, mx be integers
                read n, m, mn, mx
                create a vector of integers with name v and size m
                for i = 0 to m exclusive
                    read v[i]
                    if v[i] is greater than mx or v[i] is less than mn , return 0 , print Incorrect and newline
                    f[v[i]] is equal to 1
                decrement n by m
                m is equal to (mx - mn + 1) - accumulate(f + mn, f + mx + 1, 0)
                if n is equal to 1 and not f[mn] and and f[mx] print Incorrect and newline else print Correct and newline

                return kei_nai_ko_sayek;
            }
        };

        <|file_separator|>class SVNPlayer & class PunishPlayer'. Null, Call All Players.py
        <|fim_prefix|><|fim_suffix|><|fim_middle|>Null Ping command ==> "all pong" 
        Null Punish command ==> "all pong"


        2 players hit each other, One player lose the game after having the card printed
        Game end question ==> If 2 players tank ==> blackjack ==> null Ping. Else ==> Exotic Pong.
        <|file_separator|>codeheaven.py
        <|fim_prefix|><|fim_suffix|><|fim_middle|># customer nth with unique barcode number(unique number) ==> 1 ..... NC SES or whatever.... twitch


        n = NCardNumberBarCode Number


        n == NDC
        n == SND
        no ==> call customer

        To find no ==>MANUFACTFORCING mackbook ==>鬼灭戴裙 elektron}<|file_separator|>Tic Tac Toe Game Example (Tic Tac Toe Bot Concept).py
        <|fim_prefix|><|fim_suffix|>
        turing = Player()

        player1.name = cust1
        player2.name = cust2
        turing.name = Hugo

        widgets.Object(bounds=[0.15, 0.09, 0.73, 0.9], text=player1.name[0:13] + " " + customerMsg + "\n" + player2.name[0:13] + " ", state='Discrete', stateVal=ta[index[p_cust2counter]]).showWindow()

        ''',
        '''const int maxn = 100000;
            int a[maxn], b[maxn], n, r = 1, ans;
            int main() {
            cin >> n;
            for (int i = 1; i <= n; i++) { cin >> a[i]; }
            sort(a + 1, a + 1 + n);
            for (int i = 1; i <= n; i++) {
            while (r <= n && a[r] <= a[i]) r++;
            if (r <= n) ans++, r++;
            }
            cout << ans << endl;
            return 0;
            }<|file_separator|> bla bla\n bla bla bla''',
        '''long long int a[100050] = {0};
        int main() {
        int p = 0, res = 0, ui = 0;
        int T;
        cin >> T;
        while (T--) cin >> a[p++];
        sort(a, a + p);
        for (int i = 0; i <= p - 1; i++) {
        for (int k = ui; k <= p - 1; k++)
        if (a[ui] <= a[i])
        ui++;
        else
        break;
        if (ui < p) res++, ui++;
        }
        cout << res << endl;
        return 0;
        }'''
    ]

    #every sample program is checked against the same seven
    #[stdin_text, expected_stdout] pairs
    sampleTCs = [
        ["2\n10 1 1 1 5 5 3", "4"],
        ["5\n1 1 1 1 1", "0"],
        ["6\n300000000 200000000 300000000 200000000 1000000000 300000000", "3"],
        ["10\n1 2 3 4 5 6 7 8 9 10", "9"],
        ["1\n1", "0"],
        ["7\n3 5 2 2 5 2 4", "4"],
        ["5\n1 5 4 2 3", "4"],
    ]
    #one independent copy of the test cases per program
    TClist = [[list(tc) for tc in sampleTCs] for _ in cppProgList]

    model_id = "google/codegemma-2b"
    tokenizer = GemmaTokenizer.from_pretrained(model_id,
                                               truncation_side = "left",
                                               padding_side = "left")
    tokenizer.pad_token = tokenizer.eos_token

    #encode returns a batch of one sequence; keep that single sequence
    tokenID_list = []
    for x in cppProgList:
        tokenID_list.append(tokenizer.encode(x, return_tensors = 'pt')[0])

    out = check_CompAcc_vector_cpp(cppProgList, tokenID_list, writeDir, TClist)

    '''
    #---------------------------------------------------
    #for check_CompAcc_Boolean_cpp
    #---------------------------------------------------
    writeDir = "./extras_junkFiles"
    with open('./dataset/pseudocode-cpp/test.pseudocode-cpp.jsonl') as f:
        data = [json.loads(line) for line in f]

    cppProgList = []
    for dataInstance in data:
        cppProgList.append(dataInstance["completion"])

    out = check_CompAcc_Boolean_cpp(cppProgList, writeDir)
    print ("out", out)
    '''
