from collections import defaultdict
from tqdm import tqdm

def eval_pred_list(pred_list, answer_processor, *, progress=True):
    """Score predictions against gold answers and build accuracy breakdowns.

    Args:
        pred_list: Iterable of entries. Each entry is a mapping with
            ``"question_id"``, ``"pred_answer"`` and ``"question"``. The
            question record is read for ``"answer"``, ``"question"``,
            ``"semantic"``, ``"types"``, ``"groups"`` and ``"isBalanced"``.
        answer_processor: Callable applied to each raw ``pred_answer`` before
            it is compared (with ``==``) to the gold ``"answer"``. The gold
            answer itself is used as-is.
        progress: If True (the default), iterate under a ``tqdm`` progress
            bar. Pass False for quiet or non-interactive runs.

    Returns:
        A tuple ``(out_scores, dist, eval_qa)``:

        - ``out_scores``: ``"accuracy"``, ``"binary"`` and ``"open"`` are
          percentages (0.0 when there are no scored questions). Each
          ``"accuracyPer*"`` key maps a category value to
          ``(accuracy_percent, count)``. ``"distribution"`` is the
          size-weighted mean of per-group chi-square scores, divided by 100.
        - ``dist``: ``{"gold": ..., "predicted": ...}``, nested answer counts
          keyed by the question's global group.
        - ``eval_qa``: maps question_id to 100.0 if correct, else 0.0.

    Questions whose ``isBalanced`` is falsy are skipped entirely. A missing
    ``isBalanced`` counts as balanced.
    """

    def to_score(flag):
        return 1.0 if flag else 0.0

    def avg(values):
        return (sum(values) / len(values)) if values else 0.0

    def words_num(question):
        text = (question.get("question") or "").strip()
        return len(text.split()) if text else 0

    def steps_num(question):
        # Count semantic steps, excluding those whose "operation: argument"
        # signature contains one of the trivial markers below.
        steps = 0
        for step in question.get("semantic") or []:
            op = f"{step.get('operation')}".strip()
            arg = f"{step.get('argument')}".strip()
            sig = f"{op}: {arg}"
            if any(s in sig for s in ["exist", "query: name", "choose name"]):
                continue
            steps += 1
        return steps

    def chi_square(gold_dist, predicted_dist):
        # Per group: sum over gold answers of (observed - expected)^2 / expected,
        # then average across groups weighted by the group's gold count.
        sum_score, sum_overall = 0.0, 0.0
        for group, gold_ans in gold_dist.items():
            score, overall = 0.0, 0.0
            pred_ans = predicted_dist.get(group, {})
            for ans, e in gold_ans.items():
                if e <= 0:
                    continue
                o = pred_ans.get(ans, 0)
                score += (float(o - e) ** 2) / float(e)
                overall += e
            sum_score += score * overall
            sum_overall += overall
        if sum_overall == 0:
            return 0.0
        return float(sum_score) / float(sum_overall)

    scores = {
        "accuracy": [],
        "binary": [],
        "open": [],
        "accuracyPerStructuralType": defaultdict(list),
        "accuracyPerSemanticType": defaultdict(list),
        "accuracyPerLength": defaultdict(list),
        "accuracyPerSteps": defaultdict(list),
    }
    dist = {
        "gold": defaultdict(lambda: defaultdict(int)),
        "predicted": defaultdict(lambda: defaultdict(int)),
    }
    eval_qa = {}

    entries = tqdm(pred_list) if progress else pred_list
    for entry in entries:
        question = entry["question"]

        # Filter before processing: unbalanced questions contribute to no
        # metric, so there is no reason to run the answer processor on them.
        if not question.get("isBalanced", True):
            continue

        pred_answer = answer_processor(entry["pred_answer"])
        gold_answer = question["answer"]
        s = to_score(pred_answer == gold_answer)

        # "types"/"groups" may be present but None; treat that like missing.
        types = question.get("types") or {}
        structural = types.get("structural", "unknown")
        semantic = types.get("semantic", "unknown")

        scores["accuracy"].append(s)
        scores["accuracyPerLength"][words_num(question)].append(s)
        scores["accuracyPerSteps"][steps_num(question)].append(s)
        scores["accuracyPerStructuralType"][structural].append(s)
        scores["accuracyPerSemanticType"][semantic].append(s)

        # Only "query" questions are open-ended; everything else is binary.
        answer_type = "open" if structural == "query" else "binary"
        scores[answer_type].append(s)

        group = (question.get("groups") or {}).get("global")
        if group is not None:
            dist["gold"][group][gold_answer] += 1
            dist["predicted"][group][pred_answer] += 1

        eval_qa[entry["question_id"]] = s * 100

    out_scores = {}
    for key, values in scores.items():
        if isinstance(values, list):
            out_scores[key] = avg(values) * 100.0
        else:
            out_scores[key] = {
                category: (avg(vals) * 100.0, len(vals))
                for category, vals in values.items()
            }

    out_scores["distribution"] = chi_square(dist["gold"], dist["predicted"]) / 100.0

    return out_scores, dist, eval_qa