from __future__ import annotations
from collections import defaultdict
from pathlib import Path

from common.task import Domain, Task
from common.verdict import SolverReply
import csv

def _pct(n, d):
    """Return ``n / d`` expressed as a percentage.

    Yields ``None`` instead of dividing when the denominator ``d`` is falsy
    (``None``, ``0``) or not strictly positive, so callers can emit an empty
    cell rather than crash on an empty group.
    """
    if d and d > 0:
        return (n / d) * 100.0
    return None


class OverallResults:
    """Bundle of the verifier's and the judges' result counters, plus reports.

    The verifier's verdicts act as reference labels. For the NL judge and the
    equivalence judge the class reports how often the judge agrees with the
    verifier (false accept / false reject rates, "inflation"), both on stdout
    and as CSV files. For the base NL judge only plain acceptance statistics
    are written.

    Verdicts are compared against the strings ``"success"``, ``"failure"``
    and ``"unknown"``.
    """

    results_verifier: ResultCounter
    results_nl_judge: ResultCounter
    results_equivalence_judge: ResultCounter
    results_nl_base_judge: ResultCounter

    def __init__(
        self,
        results_verifier: ResultCounter,
        results_nl_judge: ResultCounter,
        results_equivalence_judge: ResultCounter,
        results_nl_base_judge: ResultCounter
    ):
        self.results_verifier = results_verifier
        self.results_nl_judge = results_nl_judge
        self.results_equivalence_judge = results_equivalence_judge
        self.results_nl_base_judge = results_nl_base_judge

    # ------------------------------------------------------------------
    # Recording
    # ------------------------------------------------------------------

    def record_verifier(self, domain: Domain, task: Task, reply: SolverReply):
        """Store the verifier's reply for ``task`` in ``domain``."""
        self.results_verifier.record(domain, task, reply)

    def record_nl_judge(self, domain: Domain, task: Task, reply: SolverReply):
        """Store the NL judge's reply for ``task`` in ``domain``."""
        self.results_nl_judge.record(domain, task, reply)

    def record_equivalence_judge(self, domain: Domain, task: Task, reply: SolverReply):
        """Store the equivalence judge's reply for ``task`` in ``domain``."""
        self.results_equivalence_judge.record(domain, task, reply)

    def record_nl_base_judge(self, domain: Domain, task: Task, reply: SolverReply):
        """Store the base NL judge's reply for ``task`` in ``domain``."""
        self.results_nl_base_judge.record(domain, task, reply)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _ratio(numerator: int, denominator: int) -> float | None:
        """Return ``numerator / denominator`` or ``None`` for an empty group."""
        return (numerator / denominator) if denominator > 0 else None

    @staticmethod
    def _verdict_counts(replies) -> tuple[int, int, int]:
        """Return ``(accept_n, refuted_n, unknown_n)`` for a task-id -> reply map."""
        accept_n = refuted_n = unknown_n = 0
        for reply in replies.values():
            if reply.verdict == "success":
                accept_n += 1
            elif reply.verdict == "failure":
                refuted_n += 1
            elif reply.verdict == "unknown":
                unknown_n += 1
        return accept_n, refuted_n, unknown_n

    @staticmethod
    def _tally_domain(verifier_map, judge_map) -> dict[str, int]:
        """Cross-tabulate verifier labels against judge verdicts for one domain.

        A task is *labeled* when the verifier verdict is ``"success"`` or
        ``"failure"``; verifier-unknown tasks are excluded from every count
        (this exclusion is intentional and must be kept). A labeled task is
        *decided* when the judge verdict is ``"success"`` or ``"failure"``.
        A task the judge has no reply for is treated as undecided instead of
        raising ``KeyError``.

        Args:
            verifier_map: task id -> verifier reply.
            judge_map: task id -> judge reply (may lack some task ids).

        Returns:
            Dict with keys ``refuted_count``, ``success_count`` (labeled tasks
            per verifier verdict), ``refuted_decided``, ``success_decided``
            (the decided subset of those) and the confusion-matrix cells
            ``fp`` (verifier failure, judge success), ``tn`` (failure,
            failure), ``tp`` (success, success), ``fn`` (success, failure).
        """
        tally = dict.fromkeys(
            (
                "refuted_count", "success_count",
                "refuted_decided", "success_decided",
                "fp", "fn", "tp", "tn",
            ),
            0,
        )
        for task_id, verifier_reply in verifier_map.items():
            label = verifier_reply.verdict
            if label not in ("success", "failure"):
                # Unlabeled by the verifier: excluded from all statistics.
                continue
            refuted = label == "failure"
            tally["refuted_count" if refuted else "success_count"] += 1

            # Look the judge reply up only for labeled tasks, and tolerate
            # tasks the judge never answered.
            judge_reply = judge_map.get(task_id)
            decision = judge_reply.verdict if judge_reply is not None else "unknown"
            if decision not in ("success", "failure"):
                continue
            accepted = decision == "success"
            if refuted:
                tally["refuted_decided"] += 1
                tally["fp" if accepted else "tn"] += 1
            else:
                tally["success_decided"] += 1
                tally["tp" if accepted else "fn"] += 1
        return tally

    # ------------------------------------------------------------------
    # Console report
    # ------------------------------------------------------------------

    def summarize_all(self):
        """Print every counter's summary, then judge-vs-verifier statistics."""
        print("=" * 60)
        print("All Results".center(60))
        print("=" * 60)

        sections = (
            ("Verifier Results:", self.results_verifier),
            ("Base NL Judge Results:", self.results_nl_base_judge),
            ("NL Judge:", self.results_nl_judge),
            ("Equivalence Judge:", self.results_equivalence_judge),
        )
        for title, counter in sections:
            print("\n" + "─" * 60)
            print(title)
            print("─" * 60)
            counter.summarize()

        print("\n" + "=" * 60)
        print("Conditional Probability".center(60))
        print("=" * 60)

        self._print_conditional_probabilities("NL Judge", self.results_nl_judge)
        self._print_conditional_probabilities(
            "Equivalence Judge", self.results_equivalence_judge
        )

    def _print_conditional_probabilities(
        self, judge_name: str, judge_results: ResultCounter
    ):
        """Print, per verifier domain, how ``judge_results`` relate to the verifier.

        Reported per domain: the false accept rate P(accept | verifier
        failure), the false reject rate P(reject | verifier success), and the
        inflation (judge accept rate minus verifier accept rate), all
        restricted to labeled tasks the judge decided. Rates whose
        denominator is zero are printed as ``n/a``.
        """
        print(f"\n{judge_name} Conditional Probabilities:")
        print("─" * 50)

        judge_data = judge_results.results

        for domain, verifier_map in self.results_verifier.results.items():
            print(f"\n  Domain {domain.value}:")
            # .get() rather than [] so that reporting never inserts an empty
            # domain into a defaultdict-backed judge counter.
            tally = self._tally_domain(verifier_map, judge_data.get(domain, {}))

            refuted_count = tally["refuted_count"]
            success_count = tally["success_count"]
            refuted_decided = tally["refuted_decided"]
            success_decided = tally["success_decided"]
            false_positive = tally["fp"]
            false_negative = tally["fn"]
            judge_accepts_decided = tally["tp"] + tally["fp"]

            decided_labeled = refuted_decided + success_decided
            total_labeled = refuted_count + success_count

            # FAR
            p_accept_given_refuted = self._ratio(false_positive, refuted_decided)
            # FRR
            p_reject_given_success = self._ratio(false_negative, success_decided)

            judge_accept_rate = self._ratio(judge_accepts_decided, decided_labeled)
            solver_accept_rate = self._ratio(success_decided, decided_labeled)
            inflation = (
                (judge_accept_rate - solver_accept_rate)
                if judge_accept_rate is not None and solver_accept_rate is not None
                else None
            )

            # These are coverages, but we don't present them in the paper:
            cov_ref = self._ratio(refuted_decided, refuted_count)
            cov_succ = self._ratio(success_decided, success_count)

            print(
                f"    Labeled (solver): {total_labeled} | Judge-decided: {decided_labeled}",
                end="",
            )
            if cov_ref is not None and cov_succ is not None:
                print(f"  (coverage: Ref={cov_ref:.2f}, Succ={cov_succ:.2f})")
            else:
                print()

            if p_accept_given_refuted is not None:
                print(
                    f"    P(accept | solver=Refuted, decided) = {p_accept_given_refuted:.3f}  ({false_positive}/{refuted_decided})"
                )
            else:
                print("    P(accept | solver=Refuted, decided) = n/a")

            if p_reject_given_success is not None:
                print(
                    f"    P(reject | solver=Success, decided) = {p_reject_given_success:.3f}  ({false_negative}/{success_decided})"
                )
            else:
                print("    P(reject | solver=Success, decided) = n/a")

            if inflation is not None:
                print(
                    f"    Inflation (Judge-Solver on decided) = {inflation:.3f}  "
                    f"[Judge accept: {judge_accept_rate:.3f}, Solver accept: {solver_accept_rate:.3f}]"
                )
            else:
                print("    Inflation (Judge-Solver on decided) = n/a")

    # ------------------------------------------------------------------
    # CSV export
    # ------------------------------------------------------------------

    def write_csvs(self, model_name: str, out_dir: str):
        """Write the solver summary and one metrics CSV per judge into ``out_dir``.

        ``out_dir`` is created if missing. File names are prefixed with
        ``model_name``.
        """
        Path(out_dir).mkdir(parents=True, exist_ok=True)
        self._write_solver_summary_csv(model_name, out_dir)
        self._write_judge_metrics_csv(model_name, out_dir, judge_kind="nl_judge", judge_results=self.results_nl_judge)
        self._write_base_judge_metrics_csv(model_name, out_dir, judge_kind="nl_base_judge", judge_results=self.results_nl_base_judge)
        self._write_judge_metrics_csv(model_name, out_dir, judge_kind="equivalence_judge", judge_results=self.results_equivalence_judge)

    def _write_solver_summary_csv(self, model_name: str, out_dir: str):
        """Write per-domain verifier verdict counts/percentages to ``<model>_solver_summary.csv``.

        ``ce_rate_pct`` is the share of refuted tasks whose reply carries a
        counterexample; percentages with a zero denominator are left empty.
        """
        path = Path(out_dir) / f"{model_name}_solver_summary.csv"
        # The parent is re-created here because a model name containing a
        # path separator places the file in a subdirectory of out_dir.
        path.parent.mkdir(exist_ok=True, parents=True)
        with open(path, "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(["model","domain","accept_pct","refuted_pct","unknown_pct",
                        "accept_n","refuted_n","unknown_n","total_n","ce_rate_pct","ce_n"])
            for domain, dmap in self.results_verifier.results.items():
                total = len(dmap)
                accept_n, refuted_n, unknown_n = self._verdict_counts(dmap)
                ce_n = sum(1 for r in dmap.values()
                           if r.verdict == "failure" and (r.counterexample is not None))
                w.writerow([
                    model_name, domain.value,
                    _pct(accept_n, total), _pct(refuted_n, total), _pct(unknown_n, total),
                    accept_n, refuted_n, unknown_n, total,
                    _pct(ce_n, refuted_n), ce_n
                ])

    def _write_base_judge_metrics_csv(self, model_name: str, out_dir: str, judge_kind: str, judge_results: ResultCounter):
        """Write per-domain acceptance statistics of a judge, without verifier comparison.

        Output file: ``<model>_<judge_kind lowercased>_metrics.csv``.
        """
        path = Path(out_dir) / f"{model_name}_{judge_kind.lower()}_metrics.csv"
        path.parent.mkdir(exist_ok=True, parents=True)
        with open(path, "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow([
                "model","judge_kind","domain", "acceptance_rate", "unknown_pct",
                "accept_n", "refuted_n", "unknown_n", "total_n"
            ])

            for domain, dmap in judge_results.results.items():
                total = len(dmap)
                accept_n, refuted_n, unknown_n = self._verdict_counts(dmap)

                w.writerow([
                    model_name, judge_kind, domain.value,
                    _pct(accept_n, total), _pct(unknown_n, total),
                    accept_n, refuted_n, unknown_n, total
                ])

    def _write_judge_metrics_csv(self, model_name: str, out_dir: str, judge_kind: str, judge_results: ResultCounter):
        """Write per-domain judge-vs-verifier metrics as percentages.

        One row per verifier domain with inflation, FAR, FRR, accept rates,
        coverages and the raw confusion-matrix counts. Percentages with a
        zero denominator are left empty. Output file:
        ``<model>_<judge_kind lowercased>_metrics.csv``.
        """
        path = Path(out_dir) / f"{model_name}_{judge_kind.lower()}_metrics.csv"
        path.parent.mkdir(exist_ok=True, parents=True)
        with open(path, "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            # NOTE: the "prompt" column carries judge_kind; the header label
            # is kept as-is because it is part of the file's schema.
            w.writerow([
                "model","prompt","domain",
                "inflation_pct","far_pct","frr_pct",
                "judge_accept_rate_pct","solver_accept_rate_pct",
                "coverage_ref_pct","coverage_succ_pct",
                "decided_labeled","total_labeled",
                "refuted_decided","success_decided",
                "refuted_count","success_count",
                "fp","fn","tp","tn"
            ])

            j_all = judge_results.results

            for domain, vmap in self.results_verifier.results.items():
                # .get() avoids inserting empty domains into a defaultdict.
                tally = self._tally_domain(vmap, j_all.get(domain, {}))
                ref_total = tally["refuted_count"]
                succ_total = tally["success_count"]
                ref_dec = tally["refuted_decided"]
                succ_dec = tally["success_decided"]
                fp, fn, tp, tn = tally["fp"], tally["fn"], tally["tp"], tally["tn"]

                decided_labeled = ref_dec + succ_dec
                total_labeled = ref_total + succ_total

                far = _pct(fp, ref_dec)
                frr = _pct(fn, succ_dec)
                judge_accept_rate = _pct(tp + fp, decided_labeled)
                solver_accept_rate = _pct(succ_dec, decided_labeled)
                inflation = (
                    (judge_accept_rate - solver_accept_rate)
                    if (judge_accept_rate is not None and solver_accept_rate is not None)
                    else None
                )
                cov_ref = _pct(ref_dec, ref_total)
                cov_succ = _pct(succ_dec, succ_total)

                w.writerow([
                    model_name, judge_kind, domain.value,
                    inflation, far, frr,
                    judge_accept_rate, solver_accept_rate,
                    cov_ref, cov_succ,
                    decided_labeled, total_labeled,
                    ref_dec, succ_dec, ref_total, succ_total,
                    fp, fn, tp, tn
                ])

class ResultCounter:
    """Per-domain store of solver replies with simple success-rate reporting.

    ``results`` maps a domain to a mapping of ``task.id_in_domain`` to the
    reply recorded for that task; recording the same task again overwrites
    the earlier reply.
    """

    results: dict[Domain, dict[int, SolverReply]]

    def __init__(self):
        # A defaultdict lets record() file a reply under a not-yet-seen domain.
        self.results = defaultdict(dict)

    @staticmethod
    def _count(domain_results, verdict):
        """Number of replies in ``domain_results`` whose verdict equals ``verdict``."""
        return [reply.verdict for reply in domain_results.values()].count(verdict)

    def record(self, domain: Domain, task: Task, reply: SolverReply):
        """File ``reply`` under ``domain`` keyed by ``task.id_in_domain``."""
        per_domain = self.results[domain]
        per_domain[task.id_in_domain] = reply

    def summarize(self):
        """Print one line per domain followed by a boxed accuracy table.

        Prints ``No results`` and nothing else when no reply was recorded.
        """
        if not self.results:
            print("No results")
            return

        for domain, replies in self.results.items():
            n_total = len(replies)
            n_success = self._count(replies, "success")
            n_unknown = self._count(replies, "unknown")
            n_failure = self._count(replies, "failure")
            # Only failed replies are inspected for a counterexample.
            n_with_counterexample = len(
                [
                    reply
                    for reply in replies.values()
                    if reply.verdict == "failure" and reply.counterexample is not None
                ]
            )
            ids_unknown = sorted(
                task_id for task_id, reply in replies.items() if reply.verdict == "unknown"
            )

            share = n_success / n_total if n_total > 0 else 0.0
            print(
                f"Domain {domain.value}: {n_success}/{n_total} correct ({share:.2%}), {n_unknown} unknown, {n_with_counterexample} has counterexamples from {n_failure} failures"
            )

            if ids_unknown:
                print(f"  Unknown task IDs: {ids_unknown}")

        rule = "─" * 40
        print()
        print("┌" + rule + "┐")
        print("│" + "Overall Accuracy per Domain:".center(40) + "│")
        print("├" + rule + "┤")
        for domain, accuracy in self.accuracy_per_domain().items():
            row = f"  {domain.value}: {accuracy:.2%}"
            print("│" + row.ljust(40) + "│")
        print("├" + rule + "┤")
        footer = f"Overall Accuracy: {self.overall_accuracy():.2%}"
        print("│" + footer.center(40) + "│")
        print("└" + rule + "┘")

    def accuracy_per_domain(self) -> dict[Domain, float]:
        """Map each domain to its fraction of ``"success"`` replies (0.0 if empty)."""
        return {
            domain: (self._count(replies, "success") / len(replies) if len(replies) > 0 else 0.0)
            for domain, replies in self.results.items()
        }

    def overall_accuracy(self) -> float:
        """Unweighted mean of the per-domain accuracies (0.0 without domains).

        Every domain contributes equally, regardless of how many tasks it has.
        """
        per_domain = self.accuracy_per_domain()
        if len(per_domain) == 0:
            return 0.0
        return sum(per_domain.values()) / len(per_domain)

    def success_count_per_domain(self) -> dict[Domain, int]:
        """Map each domain to its number of ``"success"`` replies."""
        return {
            domain: self._count(replies, "success")
            for domain, replies in self.results.items()
        }

    def unknown_count_per_domain(self) -> dict[Domain, int]:
        """Map each domain to its number of ``"unknown"`` replies."""
        return {
            domain: self._count(replies, "unknown")
            for domain, replies in self.results.items()
        }
