import json
import jsonlines

import glob
import os
import sys
from numpy import mean
from scipy.stats import sem, norm
from evaluators import get_evaluator
from grammar.grammar_parser import ACPGrammarParser

GRAMMAR_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "grammar","acp_grammar.lark")


def remove_garbage(s):
    while True:
        if s.endswith("."):
            s=s[:-1]
        elif s.endswith("\n"):
            s=s[:-2]
        else:
            break
    return s.rstrip()

def compare_str(s1, s2):
    return remove_garbage(s1).lower() == remove_garbage(s2).lower()
    
def compare(l1, l2):
    if not isinstance(l1, list):
        return compare_str(l1,l2)
    if not isinstance(l2, list):
        return False
    for i, v in enumerate (l1):
        if not compare(v, l2[i]):
            return False
    return True

def get_type(file):
    return "gen" if "_gen_" in file else None

def check_prog_response(resp):
    if "Positive Effects".lower() in resp.lower() and "Negative Effects".lower() in resp.lower():
        if "[" not in resp:
            return True
    return False

def get_parsed_answer(resp, task, parser):
    if "acp_prog_gen" in task:
        # Check for Positive Effects and Negative Effects instead of separation
        if check_prog_response(resp):
            # replace **Positive Effects** with "["
            # replace **Negative Effects** with "] ["
            # append "]" to the end
            resp2 = resp.lower()
            resp2 = resp2.replace("*","")
            resp2 = resp2.replace("positive effects","[")
            resp2 = resp2.replace("negative effects","] [")
            resp2 = resp2 + "]"
            return parser.parse(resp2)
    if "acp_just_gen" in task:
        # Check for "simplified plan:"
        if "simplified plan:" in resp.lower():
            resp2 = resp.lower()
            resp2 = resp2.replace("*","")
            resp2 = resp2.split("simplified plan:")[1]
            return parser.parse(resp2)
    return parser.parse(resp)

def _task_name(file):
    return "acp_" + os.path.basename(file).split("_acp_")[1].replace(".jsonl","")

def get_subtasks(task):
    if task == "acp_gen":
        return ["acp_just_gen", "acp_areach_gen", "acp_prog_gen", "acp_reach_gen", "acp_val_gen", "acp_land_gen", "acp_app_gen", "acp_nexta_gen"]
    if task == "acp_gen_cot":   
        return ["acp_just_gen_cot", "acp_areach_gen_cot", "acp_prog_gen_cot", "acp_reach_gen_cot", "acp_val_gen_cot", "acp_land_gen_cot", "acp_app_gen_cot", "acp_nexta_gen_cot"]
    if task == "acp_gen_cot_2shot":
        return ["acp_just_gen_cot_2shot", "acp_areach_gen_cot_2shot", "acp_prog_gen_cot_2shot", "acp_reach_gen_cot_2shot", "acp_val_gen_cot_2shot", "acp_land_gen_cot_2shot", "acp_app_gen_cot_2shot", "acp_nexta_gen_cot_2shot"]
    if task == "acp_gen_2shot":
        return ["acp_just_gen_2shot", "acp_areach_gen_2shot", "acp_prog_gen_2shot", "acp_reach_gen_2shot", "acp_val_gen_2shot", "acp_land_gen_2shot", "acp_app_gen_2shot", "acp_nexta_gen_2shot"]
    print("Task " + task + " is not defined!") 
    exit(1)


def get_keys_from_results_data(data):
    # Getting the key (starts with "exact_match_stderr," and "exact_match,")
    for k in data.keys():
        if k.startswith("exact_match,") or k.startswith("exact_match_mean_k,"):
            yield k
    for k in data.keys():
        if k.startswith("exact_match_stderr,") or k.startswith("exact_match_mean_k_stderr,"):
            yield k


def get_task_name(tasks, task):
    matches = [t for t in tasks if t.startswith(task)]
    if len(matches) == 1:
        return matches[0]
    if len(matches) > 1:
        return min(matches, key=len)
    return None

def get_grammar_task(task):
    task = task.split('_')[1]
    # print(task)
    if task == 'reach':
        return 'act'
    elif task == 'prog':
        return 'progression_list'
    elif task == 'val':
        return 'index'
    elif task == 'areach':
        return 'act'
    elif task == 'just':
        return 'action_list'
    elif task == 'land':
        return 'act'
    elif task == 'nexta':
        return 'action_name'
    elif task == 'app':
        return 'action_list'

if __name__ == "__main__":            
    dir = sys.argv[1]
    scores_by_tasks = {}
    data = {}
    for file in glob.glob(f"{dir}/*.jsonl"):
        fname = os.path.basename(file)
        print(f"===== Begin =======")
        print(f"file: {file}")
        assert get_type(fname) == "gen", f'Task type is {get_type(fname)}'
        scores = []
        task = _task_name(file)
        assert  os.path.isfile(GRAMMAR_FILE), f"The grammar file is missing at {GRAMMAR_FILE}"
        parser = ACPGrammarParser(GRAMMAR_FILE, get_grammar_task(task))

        num_parser_errors = 0
        with jsonlines.open(file) as f:
            results = {}
            for line in f.iter():
                group = line["doc"]["group"]
                evaluator = get_evaluator(group)
                ans = [get_parsed_answer(resp, task, parser) for resp in line["resps"][0]]
                if any(elem is None for elem in ans) or any(elem is None for elem in ans[0]):
                    num_parser_errors += 1
                    line["filtered_resps"] = "PARSER ERROR"
                    score = 0
                    scores.append(0)
                    results[line['doc']['id']]=0
                    continue
                if isinstance(line["resps"][0], list):
                    ans = [ans]
                if not compare(ans,line["filtered_resps"]):
                    line["filtered_resps"] = ans

                if isinstance(line["resps"][0], list):
                    score = evaluator.get_score(ans[0], line["doc"])
                else:
                    score = evaluator.get_score(ans, line["doc"])
                line["exact_match_mean_k"] = score
                scores.append(score)
                results[line['doc']['id']]=score

            print(f"The number of parser errors is {num_parser_errors}")
            scores_by_tasks[task] = scores 
            data[task]=results.copy()
            # print(data)
            # print(task, mean(scores), sem(scores))
        print(f"====== End =====\n")
        
    print("Computed scores by tasks:")
    json_results = {}
    json_results["scores"]= {}
    for rname in glob.glob(f"{dir}/results*.json"):
        res = json.load(open(rname))
        for task, results in res["results"].items():
            if task == 'acp':
                continue        
            if task in ["acp_gen", "acp_gen_cot_2shot", "acp_gen_cot","acp_gen_2shot"]:
                scores = []
                for s in get_subtasks(task):      
                    scores.extend(scores_by_tasks[s])
            else:
                task_name = get_task_name(scores_by_tasks.keys(), task)
                if task_name is None:
                    continue
                scores = scores_by_tasks[task_name]
                json_results[task] = data[task_name]
            print(task, mean(scores), sem(scores))
            json_results["scores"][task]= {"mean": mean(scores), "std": sem(scores)}

    with open('evaluation_results.json', 'w') as outfile:
        json.dump(json_results, outfile, indent=2)
            

