import argparse
import json
import os
import importlib
from SPARQLWrapper import SPARQLWrapper, JSON, SPARQLExceptions
from tqdm import tqdm
import sys


# SPARQL endpoint that execute_sparql queries by default.
# Can be overridden with the KB_ENDPOINT environment variable; the default is
# the originally hard-coded (private-network) Freebase endpoint.
kb_endpoint = os.environ.get("KB_ENDPOINT", "http://10.201.38.151:3001/sparql")

def parse_sparql_results(response):
    """Flatten a SPARQL JSON response into a flat list of plain values.

    - An ASK response (one with a ``"boolean"`` key) yields ``[boolean]``.
    - If the first binding row has a ``callret-0`` column, only that value is
      returned, converted with ``int`` (presumably an aggregate such as COUNT;
      TODO confirm against the endpoint's output).
    - Otherwise every bound value of every row is collected in row order, with
      the Freebase namespace prefix text removed.
    """
    if "boolean" in response:
        return [response["boolean"]]

    rows = response.get("results", {}).get("bindings", [])
    if rows and "callret-0" in rows[0]:
        return [int(rows[0]["callret-0"]["value"])]

    freebase_ns = 'http://rdf.freebase.com/ns/'
    return [
        cell["value"].replace(freebase_ns, "")
        for row in rows
        for cell in row.values()
    ]


def execute_sparql(query: str, endpoint=None):
    """Run *query* against a SPARQL endpoint and return its parsed results.

    Args:
        query: SPARQL query text.
        endpoint: URL of the SPARQL endpoint. Defaults to the module-level
            ``kb_endpoint``. Accepting it here keeps callers that pass an
            explicit endpoint (e.g. ``sparql_evaluate_backup``) working.

    Returns:
        The list produced by ``parse_sparql_results``, or ``None`` if querying
        or parsing raised any exception (malformed query, transport error,
        unexpected response). Failures are deliberately swallowed so a single
        bad query does not abort a whole evaluation run.
    """
    # NOTE(review): this turns "FILTER(NOT EXISTS {...})" into
    # "FILTER NOT EXISTS {...})", which appears to leave the original closing
    # parenthesis behind — presumably the queries fed here are shaped so the
    # result is still balanced; confirm.
    clean_q = query.replace("FILTER(NOT EXISTS", "FILTER NOT EXISTS")
    # Drop blank lines and strip surrounding whitespace from every line.
    clean_q = "\n".join(line.strip() for line in clean_q.splitlines() if line.strip())

    sparql = SPARQLWrapper(endpoint if endpoint is not None else kb_endpoint)
    sparql.setReturnFormat(JSON)
    sparql.setQuery(clean_q)
    try:
        resp = sparql.query().convert()
        return parse_sparql_results(resp)
    except Exception:
        # Also covers SPARQLExceptions.QueryBadFormed, which the original
        # handled in a separate but identical branch.
        return None


def _precision_recall_f1(gold_set, pred_set):
    """Return ``(precision, recall, f1)`` of *pred_set* against *gold_set*.

    Two empty sets are treated as a perfect match ``(1.0, 1.0, 1.0)``.
    """
    if not gold_set and not pred_set:
        return 1.0, 1.0, 1.0
    tp = len(gold_set & pred_set)
    p = tp / len(pred_set) if pred_set else 0.0
    r = tp / len(gold_set) if gold_set else 0.0
    f = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
    return p, r, f


def sparql_evaluate(preds, gold_res_list):
    """Execute predicted SPARQL queries and score them against gold results.

    Args:
        preds: Iterable of predicted SPARQL query strings, aligned with
            *gold_res_list*.
        gold_res_list: Sequence of gold answer lists. A ``None`` entry marks an
            example without a usable gold answer; it is counted as skipped.

    Returns:
        A dict with the macro-averaged ``avg_precision``, ``avg_recall`` and
        ``avg_f1`` (each formatted as a 4-decimal string), plus ``total_valid``
        (examples scored) and ``skipped`` (examples with ``None`` gold).

    Raises:
        ValueError: If *preds* and *gold_res_list* differ in length. Without
            this check, ``zip`` would silently drop the unmatched examples and
            average over fewer items than intended.
    """
    # Materialize so generators are still accepted and their length is known.
    preds = list(preds)
    if len(preds) != len(gold_res_list):
        raise ValueError(
            f"sparql_evaluate: got {len(preds)} predictions for "
            f"{len(gold_res_list)} gold results"
        )

    skipped = total = 0
    sum_p = sum_r = sum_f = 0.0

    for gold_res, pred_query in tqdm(zip(gold_res_list, preds), total=len(gold_res_list), desc='Evaluating CompWebQ'):
        if gold_res is None:
            skipped += 1
            continue

        pred_res = execute_sparql(pred_query)
        # A failed predicted query (None) is scored as an empty answer set.
        pred_set = set(pred_res) if pred_res else set()
        p, r, f = _precision_recall_f1(set(gold_res), pred_set)

        sum_p += p
        sum_r += r
        sum_f += f
        total += 1

    avg_p = sum_p / total if total else 0.0
    avg_r = sum_r / total if total else 0.0
    avg_f = sum_f / total if total else 0.0

    return {
        "avg_precision": f"{avg_p:.4f}",
        "avg_recall": f"{avg_r:.4f}",
        "avg_f1": f"{avg_f:.4f}",
        "total_valid": total,
        "skipped": skipped,
    }

def sparql_evaluate_backup(preds, test_data):
    """Score predicted SPARQL queries by executing both gold and predicted queries.

    Variant of ``sparql_evaluate`` that takes gold *queries* rather than
    precomputed gold results.

    Args:
        preds: Predicted SPARQL query strings, aligned with *test_data*.
        test_data: Gold SPARQL query strings. An empty/falsy query, or one whose
            execution fails or returns no results, is counted as skipped.

    Returns:
        A dict with macro-averaged ``avg_precision``, ``avg_recall`` and
        ``avg_f1`` (4-decimal strings), plus ``total_valid`` and ``skipped``.
    """
    skipped = total = 0
    sum_p = sum_r = sum_f = 0.0

    for gold_q, pred_query in tqdm(zip(test_data, preds), total=len(test_data), desc='Evaluating CompWebQ'):
        if not gold_q:
            skipped += 1
            continue

        # execute_sparql is called with the query only; it uses kb_endpoint by
        # default (passing kb_endpoint positionally used to raise TypeError).
        gold_res = execute_sparql(gold_q)
        if not gold_res:
            # Gold query failed (None) or returned nothing: not scorable.
            skipped += 1
            continue

        pred_res = execute_sparql(pred_query)

        gold_set = set(gold_res)
        # A failed predicted query (None) is scored as an empty answer set.
        pred_set = set(pred_res) if pred_res else set()

        # gold_set is never empty here (empty gold results were skipped above),
        # so recall's denominator is always positive.
        tp = len(gold_set & pred_set)
        p = tp / len(pred_set) if pred_set else 0.0
        r = tp / len(gold_set)
        f = 2 * p * r / (p + r) if (p + r) > 0 else 0.0

        sum_p += p
        sum_r += r
        sum_f += f
        total += 1

    avg_p = sum_p / total if total else 0.0
    avg_r = sum_r / total if total else 0.0
    avg_f = sum_f / total if total else 0.0

    return {
        "avg_precision": f"{avg_p:.4f}",
        "avg_recall":    f"{avg_r:.4f}",
        "avg_f1":        f"{avg_f:.4f}",
        "total_valid":   total,
        "skipped":       skipped
    }


def evaluate(preds, golds):
    """Return only the average F1 from ``sparql_evaluate(preds, golds)``.

    The value is the 4-decimal string produced by ``sparql_evaluate``, not a
    float.
    """
    metrics = sparql_evaluate(preds, golds)
    return metrics["avg_f1"]
