import json, re

def postprocess_analytical_entailment(text):
    """Extract a yes/no verdict from a model response.

    The first "The answer is ..." sentence (matched case-insensitively) is
    located in ``text``. If that sentence contains the standalone word
    "yes", the function returns ``"yes"``; in every other case, including
    when no such sentence exists, it returns ``"no"``.

    Args:
        text: Raw model response.

    Returns:
        ``"yes"`` or ``"no"``.
    """
    pattern = r'The answer is:? (?:\\)?(?:\([\w\s]+\))?[^.?!\n]*[.?!]?'
    found = re.search(pattern, text, re.I)
    # No marker sentence means there is nothing to inspect, so the verdict
    # falls through to the default "no".
    answer_sentence = found.group(0) if found else ""
    if re.search(r"\byes\b", answer_sentence, re.IGNORECASE):
        return "yes"
    return "no"


def postprocess_known_unknowns(text):
    """Normalise the first sentence of a response into an answer tuple.

    Only the text before the first "." is considered. Quotes, braces,
    colons and backslashes are removed, the text is lower-cased, the phrase
    "the answer is" is dropped and runs of spaces/tabs are collapsed.

    Returns:
        ``(number, "option")`` when the cleaned text contains a
        parenthesised number such as "(2)", otherwise
        ``(cleaned_text, "raw")``.
    """
    first_sentence = text.split(".")[0].strip()
    # Strip the quoting / JSON-like punctuation around the answer.
    for junk in ("\"", "{", "}", ":", "\\"):
        first_sentence = first_sentence.replace(junk, "")
    normalized = first_sentence.lower().replace("the answer is", "").strip()
    normalized = re.sub(r"[ \t]+", " ", normalized)

    option_numbers = re.findall(r'\((\d+)\)', normalized)
    if option_numbers:
        return (option_numbers[0], "option")
    return (normalized, "raw")

from dateutil.parser import parse

def postprocess_date_understanding(text):
    """Extract a date from the "The answer is ..." sentence of a response.

    The first "The answer is ..." sentence (case-insensitive) is handed to
    ``dateutil``'s fuzzy parser and the result is formatted as MM/DD/YYYY.

    Args:
        text: Raw model response.

    Returns:
        The date as ``'%m/%d/%Y'``, or ``""`` when the response has no
        answer sentence or the sentence contains no parseable date.
    """
    pattern = r'The answer is:? (?:\\)?(?:\([\w\s]+\))?[^.?!\n]*[.?!]?'
    found = re.search(pattern, text, re.I)
    answer = found.group(0) if found else ""
    try:
        answer = parse(answer, fuzzy=True).strftime('%m/%d/%Y')
    except (ValueError, OverflowError):
        # dateutil signals "no date found" with ParserError (a ValueError
        # subclass) and out-of-range values with OverflowError. One bad
        # response should count as a miss, not abort the evaluation run.
        answer = ""
    print('>>>>', text, '>>>', answer)
    return answer

import datefinder
def postprocess_date_understanding_zero_shot(text):
    """Return the first date ``datefinder`` finds in ``text`` as MM/DD/YYYY.

    Returns ``""`` when no date is found. The input and the extracted
    value are echoed to stdout.
    """
    print("text", text)
    found_dates = list(datefinder.find_dates(text))
    answer = found_dates[0].strftime('%m/%d/%Y') if found_dates else ""
    print(f"For response {text} we extract {answer}")
    return answer


def postprocess_date_understanding_few_shot(text):
    """Return the first date after the last "Answer" marker as MM/DD/YYYY.

    Only the text following the final occurrence of "Answer" is searched,
    with all periods removed first. Returns ``""`` when ``datefinder``
    finds no date. The searched text and the result are echoed to stdout.
    """
    # Keep only the tail after the last "Answer" marker, without periods.
    text = text.split("Answer")[-1].replace(".", "")
    found_dates = list(datefinder.find_dates(text))
    answer = found_dates[0].strftime('%m/%d/%Y') if found_dates else ""
    print(f"-->For response {text} we extract -->{answer}<<--")
    return answer




def postprocess_gsm8k_zero_shot(text):
    """Extract the final numeric answer from a zero-shot GSM8K response.

    Commas are removed first so that "1,366" reads as "1366". A
    ``\\boxed{...}`` number is preferred; otherwise the last number on the
    last non-blank line is used.

    Args:
        text: Raw model response.

    Returns:
        The number as a string, or ``""`` when none can be found
        (including for empty or whitespace-only input).
    """
    text = text.replace(",", "")
    boxed = re.search(r'\\boxed\{(-?\d+(\.\d+)?)\}', text)
    if boxed:
        return boxed.group(1).strip()

    non_empty_lines = [line for line in text.split("\n") if line.strip()]
    # Guard against responses with no content at all: indexing the last
    # line of an empty list would raise IndexError.
    if non_empty_lines:
        numbers = re.findall(r'\d+(?:\.\d+)?', non_empty_lines[-1])
        if numbers:
            return numbers[-1]

    print("Handle this: ", non_empty_lines)
    return ""


def postprocess_anachronisms_zero_shot(text):
    """Turn a zero-shot response into an ``(answer, kind)`` tuple.

    If the first run of digits in ``text`` is at most two characters long
    it is returned as ``(digits, "option")``. Otherwise the result is
    ``("yes", "raw")`` when the standalone word "yes" appears
    (case-insensitive) and ``("no", "raw")`` when it does not.
    """
    digit_runs = re.findall(r'\(?(\d+)\)?', text)
    if digit_runs and len(digit_runs[0]) <= 2:
        print("-->>", digit_runs)
        return (digit_runs[0], "option")

    yes_hits = re.findall(r"\byes\b", text, re.IGNORECASE)
    if not yes_hits:
        return ("no", 'raw')
    verdict = (yes_hits[0].lower(), 'raw')
    print(verdict)
    return verdict


def postprocess_anachronisms_few_shot(text):
    """Return the text after the last "Answer" marker, cleaned up.

    Periods and colons are removed and surrounding whitespace is stripped.
    The cleaned value is also echoed to stdout.
    """
    tail = text.split("Answer")[-1]
    cleaned = tail.replace(".", "").replace(":", "").strip()
    print(f"->{cleaned}<-")
    return cleaned


def postprocess_anachronisms_ins_only(text):
    """Map a free-form response to "yes" or "no".

    The answer is "yes" when the response contains the standalone word
    "yes" (case-insensitive) or a standalone "1"; otherwise it is "no".
    The response and the extracted answer are echoed to stdout.
    """
    says_yes = re.search(r"\byes\b", text, re.IGNORECASE) is not None
    picks_one = re.search(r"\b1\b", text, re.IGNORECASE) is not None
    answer = "yes" if (says_yes or picks_one) else "no"
    print("For response: ", text, " we extract: ", answer)
    return answer

def postprocess_anachronisms(text):
    """Turn the "The answer is ..." sentence into an ``(answer, kind)`` tuple.

    The first "The answer is ..." sentence (case-insensitive) is extracted
    and stripped of backslashes and periods. If its first run of digits is
    at most two characters long the result is ``(digits, "option")``.
    Otherwise the result is ``("yes", "raw")`` when the sentence contains
    the standalone word "yes" and ``("no", "raw")`` when it does not.

    Args:
        text: Raw model response.

    Returns:
        A 2-tuple ``(answer, kind)`` with ``kind`` in ``{"option", "raw"}``.
        A response without an answer sentence yields ``("no", "raw")``.
    """
    pattern = r'The answer is:? (?:\\)?(?:\([\w\s]+\))?[^.?!\n]*[.?!]?'
    found = re.search(pattern, text, re.I)
    if found:
        answer_sentence = found.group(0).replace("\\", "").replace(".", "")
    else:
        # Surface responses that lack the expected marker sentence.
        print("eeeee", text)
        answer_sentence = ""

    digit_runs = re.findall(r'\(?(\d+)\)?', answer_sentence)
    # Longer digit runs (e.g. years) are not treated as option numbers.
    if digit_runs and len(digit_runs[0]) <= 2:
        print(digit_runs)
        return (digit_runs[0], "option")

    if re.search(r"\byes\b", answer_sentence, re.IGNORECASE):
        return ("yes", 'raw')
    return ("no", 'raw')

def get_mcqs(choices):
    """Return the lower-cased key of the first entry whose value equals 1.

    Raises:
        IndexError: if no entry in ``choices`` has the value 1.
    """
    correct = []
    for label, score in choices.items():
        if score == 1:
            correct.append(label.lower())
    return correct[0]

def calculate_accuracy(golds, predicted):
    """Return the fraction of positions where gold and prediction agree.

    Before returning, prints three diagnostics: the mismatching
    ``(gold, predicted)`` pairs, the number of matches, and the indices of
    the mismatches. The match count is divided by ``len(golds)``.
    """
    pairs = list(zip(golds, predicted))
    print([(g, p) for g, p in pairs if g != p])
    n_correct = sum([g == p for g, p in pairs])
    print(n_correct)
    print([idx for idx, (g, p) in enumerate(pairs) if g != p])
    return n_correct / len(golds)


def evaluate_known_unknowns(prediction_file, gold_file):
    """Score known-unknowns predictions against gold labels and print accuracy.

    Args:
        prediction_file: Path to a JSON file. Each top-level key maps to a
            record whose ``"final_response"`` is a sequence of response
            strings; only the first response of each record is scored.
        gold_file: Path to a JSON file with an ``"examples"`` list whose
            items carry a ``"target_scores"`` mapping.

    Predictions and gold examples are paired by position.
    NOTE(review): this assumes both files list examples in the same order;
    confirm against the code that produces the prediction file.
    """
    # Context managers close the files; the original left both handles open.
    with open(prediction_file) as pred_fh:
        predictions = json.load(pred_fh)
    with open(gold_file) as gold_fh:
        golds = json.load(gold_fh)["examples"]

    predicted = [
        [postprocess_known_unknowns(x) for x in predictions[pred_id]["final_response"]]
        for pred_id in predictions
    ]
    gold = [(get_mcqs(g["target_scores"]), g["target_scores"]) for g in golds]

    pp, gg = [], []
    for responses, (gold_label, gold_scores) in zip(predicted, gold):
        p, p_type = responses[0]
        if p_type == "option":
            # Translate a 1-based option number into the option text, using
            # the key order of the gold "target_scores" mapping.
            numbered = {str(idx + 1): key for idx, key in enumerate(gold_scores)}
            p = numbered[p].lower()
        pp.append(p)
        gg.append(gold_label)
    print(f"Accuracy: {calculate_accuracy(gg, pp)}")

def evaluate_date_understanding(pred_file, gold_file):
    """Score date-understanding predictions against gold labels and print accuracy.

    Args:
        pred_file: Path to a JSON file. Iterating the loaded value must
            yield records that each have an ``"ins_only_output"`` string.
            Presumably this is a list of dicts; confirm against the
            producer of the file.
        gold_file: Path to a JSON file with an ``"examples"`` list whose
            items carry a ``"target_scores"`` mapping.

    Predictions and gold examples are paired by position.
    """
    # Context managers close the files; the original left both handles open.
    with open(pred_file) as pred_fh:
        predictions = json.load(pred_fh)
    with open(gold_file) as gold_fh:
        golds = json.load(gold_fh)["examples"]

    predicted = [
        postprocess_date_understanding(entry["ins_only_output"])
        for entry in predictions
    ]
    gold = [get_mcqs(g["target_scores"]) for g in golds]

    # zip() truncates to the shorter list, so both lists passed on below
    # always have the same length.
    pp, gg = [], []
    for p, g in zip(predicted, gold):
        pp.append(p)
        gg.append(g)
    print(f"Accuracy: {calculate_accuracy(gg, pp)}")


def evaluate_anachronisms(pred_file, gold_file, answer_field="ins_only_response"):
    """Score anachronism predictions against gold labels and print accuracy.

    Args:
        pred_file: Path to a JSON file. Iterating the loaded value must
            yield records that each have an ``"ins_only_output"`` string.
        gold_file: Path to a JSON file with an ``"examples"`` list whose
            items carry a ``"target_scores"`` mapping.
        answer_field: Currently unused. The response is always read from
            the ``"ins_only_output"`` key.
            NOTE(review): decide whether this parameter should select the
            key; its default differs from the key actually read, so wiring
            it in would change which field is evaluated.

    Predictions and gold examples are paired by position. Both sides are
    lower-cased before comparison.
    """
    # Context managers close the files; the original left both handles open.
    with open(pred_file) as pred_fh:
        predictions = json.load(pred_fh)
    with open(gold_file) as gold_fh:
        golds = json.load(gold_fh)["examples"]

    predicted = []
    for entry in predictions:
        print(entry)
        predicted.append(postprocess_anachronisms(entry["ins_only_output"]))

    gold = [(get_mcqs(g["target_scores"]), g["target_scores"]) for g in golds]

    pp, gg = [], []
    for p, g in zip(predicted, gold):
        print(p, g)
        # p is an (answer, kind) tuple; g is (gold_label, target_scores).
        pp.append(p[0].lower())
        gg.append(g[0].lower())
    print(f"Accuracy: {calculate_accuracy(gg, pp)}")



def evaluate_anlytical_entailment(pred_file, gold_file):
    """Score analytical-entailment predictions and print accuracy.

    Args:
        pred_file: Path to a JSON file. Each top-level key maps to a record
            whose ``"final_response"`` is a sequence of response strings;
            only the first response of each record is scored.
        gold_file: Path to a JSON file with an ``"examples"`` list whose
            items carry a ``"target_scores"`` mapping.

    A "no" verdict is scored as "no-entailment" and anything else as
    "entailment". Predictions and gold examples are paired by position.
    """
    # Context managers close the files; the original left both handles open.
    with open(pred_file) as pred_fh:
        predictions = json.load(pred_fh)
    with open(gold_file) as gold_fh:
        golds = json.load(gold_fh)["examples"]

    predicted = [
        [postprocess_analytical_entailment(x) for x in predictions[pred_id]["final_response"]]
        for pred_id in predictions
    ]
    gold = [get_mcqs(g["target_scores"]) for g in golds]

    pp, gg = [], []
    for responses, gold_label in zip(predicted, gold):
        label = "no-entailment" if responses[0] == "no" else "entailment"
        print(label, gold_label)
        pp.append(label)
        gg.append(gold_label)
    print(f"Accuracy: {calculate_accuracy(gg, pp)}")

import os

def clean_hm_eval_tests(code):
    """Drop ``print(`` lines that sit outside function/class bodies.

    A line starting (after stripping) with "def " or "class " opens a
    definition; a blank line or a lone "}" closes it. Lines containing
    "print(" are removed only while no definition is open. The remaining
    lines are re-joined and the result is stripped.
    """
    kept = []
    in_definition = False
    for raw_line in code.split('\n'):
        content = raw_line.strip()

        if content.startswith(("def ", "class ")):
            in_definition = True

        # A blank line (or closing brace) ends the current definition.
        if in_definition and content in ("", "}"):
            in_definition = False

        if not in_definition and "print(" in content:
            continue

        kept.append(raw_line)

    return '\n'.join(kept).strip()


def clean_human_eval(programs, g_p):
    """Interactively display one generated program for manual inspection.

    Clears the terminal, then prints the task id, the raw response, the
    code found inside the first ```Python fenced block (empty if there is
    none), the same code after ``clean_hm_eval_tests``, and the task
    prompt. Finally it waits for the user to press Enter.

    Args:
        programs: Raw model response text.
        g_p: Mapping with at least ``"task_id"`` and ``"prompt"`` keys.
    """
    fenced = re.search(r"```\s*Python\s*(.*?)```", programs, re.DOTALL | re.IGNORECASE)
    python_code = fenced.group(1) if fenced else ""

    rule = "-"*100
    _ = os.system("clear")
    print(rule)
    print("Please create the file named: ", g_p["task_id"])

    print(rule)
    print(programs)
    print(rule)
    print(python_code)
    print(rule)
    print("Further cleaning")
    print(clean_hm_eval_tests(python_code))
    print(rule)
    print("---", g_p["prompt"])
    # Block until the user acknowledges, so each program can be reviewed.
    _ = input("Testing")

def postprocess_human_eval(prediction_file, gold_file):
    """Walk through every generated HumanEval program interactively.

    Args:
        prediction_file: Path to a JSON file. Each top-level key maps to a
            record whose ``"final_response"`` is a sequence of generated
            program strings.
        gold_file: Path to a JSON Lines file, one task record per line.

    Raises:
        AssertionError: if the two files hold different numbers of entries.

    Predictions and gold tasks are paired by position and each program is
    passed to ``clean_human_eval``.
    """
    # Context managers close the files; the original left both handles open.
    with open(prediction_file) as pred_fh:
        jsn = json.load(pred_fh)
    with open(gold_file) as gold_fh:
        gold_jsn = [json.loads(x) for x in gold_fh]
    assert len(jsn) == len(gold_jsn)

    for p_id, g_data in zip(jsn, gold_jsn):
        for program in jsn[p_id]["final_response"]:
            clean_human_eval(program, g_data)



def postprocess_gsm8k(text):
    """Extract the integer answer from the "The answer is ..." sentence.

    Args:
        text: Raw model response.

    Returns:
        The first run of digits in the answer sentence as an ``int``, with
        thousands separators removed ("1,000" -> 1000), or ``""`` when the
        response has no answer sentence or the sentence has no digits.
    """
    pattern = r'The answer is:? (?:\\)?(?:\([\w\s]+\))?[^.?!\n]*[.?!]?'
    found = re.search(pattern, text, re.I)
    if not found:
        return ""

    answer = found.group(0)
    for junk in (":", "\\", "\"", "the answer is answer{}", "The answer is answer{}"):
        answer = answer.replace(junk, "")
    print(answer)

    # Commas must go BEFORE the digit search. The original stripped them
    # from the already-matched digits, so "1,000" was read as 1.
    digits = re.search(r"\d+", answer.replace(",", ""))
    return int(digits.group(0)) if digits else ""


def postprocess_gsm8k_few_shot(text):
    """Extract the final numeric answer from a few-shot GSM8K response.

    Commas are removed first. The value of the last calculator annotation
    (``<<expr=value>>``) is preferred. Without annotations, the last number
    in the final sentence is used, where the final sentence is the segment
    before the last period, or the whole text if it contains no period.

    Args:
        text: Raw model response.

    Returns:
        The number as a string, or ``""`` when none can be found.
    """
    text = text.replace(",", "")
    annotated = re.findall(r'<<[^>]+=([-]?[\d]+[./]?[\d]*)>>', text)
    if annotated:
        return annotated[-1]

    number_pattern = r'[-]?[\d]+[./]?[\d]*'
    print('*'*200)
    print("text is:", text)
    print('*'*200)
    segments = text.strip().split('.')
    # A response that ends with "." splits into [..., last sentence, ""],
    # hence index -2. Without any period there is a single segment, and
    # indexing -2 used to raise IndexError; fall back to that segment.
    last_sentence = (segments[-2] if len(segments) >= 2 else segments[-1]).strip()
    print('-'*100)
    print("last last_sentence is:", text,">>>>", last_sentence)
    print('-'*100)
    numbers = re.findall(number_pattern, last_sentence)
    if numbers:
        return numbers[-1]

    print("Handle this: ", text)
    return ""

def evaluate_gsm8k(pred_file, gold_file):
    """Score few-shot GSM8K predictions against gold answers and print accuracy.

    Args:
        pred_file: Path to a JSON file. Iterating the loaded value must
            yield records that each have a ``"few_shot_output"`` string.
        gold_file: Path to a JSON Lines file. Each record's ``"answer"``
            ends with ``"\\n#### <number>"``, and that number is the gold.

    Predictions and golds are paired by position and compared as floats.
    A per-example line is printed, followed by a summary line with the
    accuracy.
    """
    # Context managers close the files; the original left both handles open.
    with open(pred_file) as pred_fh:
        predictions = json.load(pred_fh)
    with open(gold_file) as gold_fh:
        golds = [json.loads(x) for x in gold_fh]

    predicted = [
        postprocess_gsm8k_few_shot(entry["few_shot_output"])
        for entry in predictions
    ]
    gold = [
        float(g["answer"].split("\n####")[-1].strip().replace(",", ""))
        for g in golds
    ]

    xyz = 0  # always 0; kept so the summary line keeps its original format
    cc = cnt = 0
    for p, g in zip(predicted, gold):
        try:
            value = float(p) if p != "" else p
        except ValueError:
            # The extractor can return tokens such as "3/4" that float()
            # rejects; count them as wrong instead of aborting the run.
            value = p
        is_correct = value == g
        print(cnt, '>>>', p, g, is_correct)
        cc += is_correct
        cnt += 1
    # Avoid ZeroDivisionError when there is nothing to score.
    print('>>>', xyz, cc / cnt if cnt else 0.0)
    

# Input paths for the evaluation run at the bottom of the file. Only
# `gold_file` is active; the commented-out pairs are alternative
# prediction/gold inputs for the other tasks, to be swapped in by hand.
# pred_file = "promptd_output/date_understanding/date_understanding_gpt-4_final.json"
gold_file = "data/date_understanding/date_understanding.json"

# pred_file = "promptd_output/unkown_unknowns/hallucinations_gpt-4_final.json"
# gold_file = "data/known_unknowns/known_unknowns_halluicinations.json"

# pred_file = "promptd_output/analytical_entailment/analytical_entailment_gpt-4_final.json"
# gold_file = "data/analytical_entailment/analytical_entailment.json"



# The two gold files below are .jsonl; postprocess_human_eval and
# evaluate_gsm8k read them one JSON record per line.
# pred_file = "promptd_output/human_eval/human_instructions_gpt-4_attempt_1.json"
# gold_file = "data/human_eval/HumanEval.jsonl"###mind this


# pred_file = "promptd_output/gsm8k/gsm8k_few_shot_gpt4.json"
# gold_file = "data/gsm8k/test_gsm8k.jsonl"###mind this


# pred_file = ""
if __name__=="__main__":
    # Exactly one evaluation is enabled at a time; the others stay commented
    # out. The trailing "#Accuracy: ..." notes presumably record the result
    # of an earlier run of that line -- TODO confirm.
    # NOTE: `pred_file` is not defined at module level, so any line below
    # that uses it needs one of the assignments above to be uncommented.
    # evaluate_known_unknowns("promptd_output/date_understanding/date_understanding_gpt-4_zero_shot_final.json", gold_file)#Accuracy: 0.8913043478260869
    evaluate_date_understanding("outputs/date_understanding_gpt-4_ins_only_ins_only.json", gold_file)#Accuracy: 0.7235772357723578
    # evaluate_anlytical_entailment(pred_file, gold_file)#Accuracy: 0.8285714285714286
    # evaluate_anachronisms("promptd_output/anachronisms/anachronisms_gpt-4_ins_only_ins_only.json", "data/anachronisms/anachronisms.json", "ins_only_response")#Accuracy: 0.8217391304347826
    # postprocess_human_eval(pred_file, gold_file)
    # evaluate_gsm8k(pred_file, gold_file)
    
    # The bare string below is a no-op expression used as a scratch note.
    '''
    anachornisms turbo rewritten --- Accuracy: 0.6826086956521739

    '''
