import os
import json
import csv
from typing import Dict, Tuple, List
from collections import defaultdict

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tabulate import tabulate


class ScoreAggregator:
    """
    Processes evaluation JSON files to compute composite scores and normalized counts.

    ``process`` walks ``<eval_root>/<folder>/<name>.json`` and keys every result by
    ``(folder, name)``; the rest of the class refers to that pair as (SOP, model).
    Individual values are read from ``data[metric]["score"][field]``.
    """

    # Composite metrics reported by ``compute_composite`` and ``write_table``,
    # in the column order used by the output table.
    _REPORTED_METRICS = (
        "structural_alignment",
        "property_fidelity",
        "semantic_fidelity",
    )

    def __init__(
        self,
        eval_root: str,
        processed_dir: str,
    ):
        """
        Args:
            eval_root: Directory whose sub-folders hold the evaluation JSON files.
            processed_dir: Directory that receives the CSV and per-model JSON output.
        """
        self.eval_root = eval_root
        self.processed_dir = processed_dir
        # Weights for composite metrics. Each group is divided by the sum of its
        # weights later on, so only the ratios between weights matter.
        self.weights = {
            "structural_alignment": {
                "role_coverage": 0.5,
                "transition_logic": 0.5,
                "module_define_usage": 0.5,
            },
            "property_fidelity": {
                "coverage": 0.5,
                "logical_equivalence": 0.5,
                "operator_correctness": 0.5,
            },
            "semantic_fidelity": {
                "behavior_match": 0.5,
                "edge_case_handling": 0.5,
                "naming_clarity": 0.5,
            },
        }
        # Keys for count-based normalization
        self.count_keys = {
            "structural_alignment": ["exploration_count"],
            "property_fidelity": ["relevance_count"],
            "semantic_fidelity": ["penalty_count"],
            "conciseness": [
                "additional_concepts",
                "redundant_modules",
                "additional_properties",
            ],
        }
        # All count fields and their weights
        self.all_counts = [
            "exploration_count",
            "relevance_count",
            "penalty_count",
            "additional_concepts",
            "redundant_modules",
            "additional_properties",
        ]
        self.count_weights = {
            "exploration_count": 1,
            "relevance_count": 1,
            "penalty_count": -1,
            "additional_concepts": 1,
            "redundant_modules": -1,
            "additional_properties": 1,
        }
        # Model identifiers
        self.all_models = [f"eval_model_{i+1:02d}" for i in range(10)]

    @staticmethod
    def _as_number(value):
        """Coerce one JSON score field to a number; ``None`` counts as 0, like a missing field."""
        if value is None:
            return 0
        if isinstance(value, (int, float)):
            return value
        # Numeric strings such as "7.5"; anything else raises ValueError/TypeError.
        return float(value)

    @staticmethod
    def _score_fields(data, metric):
        """Return the ``data[metric]["score"]`` mapping, or ``{}`` when it is absent or not a mapping."""
        section = data.get(metric)
        if not isinstance(section, dict):
            return {}
        fields = section.get("score")
        return fields if isinstance(fields, dict) else {}

    @staticmethod
    def _weighted_mean(total, weight):
        """Return ``total / weight``, or 0.0 when the weight is zero (same guard ``process`` uses)."""
        return total / weight if weight else 0.0

    def _metric_means(self, vals):
        """Return the weight-normalized value of each reported composite metric, in table order."""
        return [
            self._weighted_mean(
                vals.get(f"combined_{metric}", 0.0),
                vals.get(f"combined_{metric}_weight", 1.0),
            )
            for metric in self._REPORTED_METRICS
        ]

    def _build_entry(self, data) -> Dict[str, float]:
        """Turn one parsed evaluation document into the flat metrics dict stored by ``process``."""
        entry: Dict[str, float] = {"overall": 0.0, "combined_weight": 0.0}
        # Composite metrics
        for metric, wdict in self.weights.items():
            fields = self._score_fields(data, metric)
            weighted_sum = 0.0
            sum_w = 0.0
            for key, weight in wdict.items():
                weighted_sum += weight * self._as_number(fields.get(key))
                sum_w += weight
            entry[f"combined_{metric}"] = weighted_sum
            entry[f"combined_{metric}_weight"] = sum_w
            entry["overall"] += weighted_sum
            entry["combined_weight"] += sum_w
        # Normalize overall: weighted mean across all groups, then divided by 10.
        if entry["combined_weight"] > 0:
            entry["overall"] = round(
                (entry["overall"] / entry["combined_weight"]) / 10, 2
            )
        else:
            entry["overall"] = 0.0
        # Raw counts
        for metric, keys in self.count_keys.items():
            fields = self._score_fields(data, metric)
            for key in keys:
                entry[key] = self._as_number(fields.get(key))
        return entry

    def process(self) -> Dict[Tuple[str, str], Dict[str, float]]:
        """
        Read evaluation JSON files and compute raw combined metrics and counts.

        Non-directories under ``eval_root``, non-``.json`` files and empty files are
        skipped. Missing or ``null`` sections and fields are treated as 0.

        Returns:
            A dictionary keyed by (SOP, model) -> metrics dict, where SOP is the
            folder name and model is the file name without its extension.

        Raises:
            json.JSONDecodeError: If a non-empty file is not valid JSON.
        """
        scores: Dict[Tuple[str, str], Dict[str, float]] = {}
        for folder in sorted(os.listdir(self.eval_root)):
            folder_path = os.path.join(self.eval_root, folder)
            if not os.path.isdir(folder_path):
                continue
            for fname in sorted(os.listdir(folder_path)):
                if not fname.endswith(".json"):
                    continue
                model_key = os.path.splitext(fname)[0]
                file_path = os.path.join(folder_path, fname)
                # Explicit encoding so the result does not depend on the platform locale.
                with open(file_path, "r", encoding="utf-8") as f:
                    raw = f.read().strip()
                if not raw:
                    continue
                scores[(folder, model_key)] = self._build_entry(json.loads(raw))
        return scores

    def normalize_counts(
        self, score_dict: Dict[Tuple[str, str], Dict[str, float]]
    ) -> Dict[Tuple[str, str], Dict[str, float]]:
        """
        Normalize count metrics per model to [0,1] scale.

        Min-max scaling is applied to each count separately, across all SOPs of one
        model. When every SOP has the same value the normalized value is 0.0.
        Only models listed in ``self.all_models`` are normalized; entries for any
        other model are absent from the result.

        Returns:
            A ``defaultdict`` keyed by (SOP, model) -> {"norm_<count>": value}.
        """
        norm_scores: Dict[Tuple[str, str], Dict[str, float]] = defaultdict(dict)
        for model in self.all_models:
            # Filter by model
            filtered = {
                key: vals
                for key, vals in score_dict.items()
                if key[1] == model
            }
            if not filtered:
                continue
            for count in self.all_counts:
                values = [v.get(count, 0) for v in filtered.values()]
                min_v, max_v = min(values), max(values)
                spread = max_v - min_v
                for key, val in zip(filtered, values):
                    norm = (val - min_v) / spread if max_v > min_v else 0.0
                    norm_scores[key][f"norm_{count}"] = norm
        return norm_scores

    def compute_composite(
        self, score_dict: Dict[Tuple[str, str], Dict[str, float]]
    ) -> Dict[Tuple[str, str], Dict[str, float]]:
        """
        Compute final composite_score per (SOP, model).

        The score is the mean of the three weight-normalized composite metrics,
        divided by 10 and rounded to two decimals. A metric whose stored weight is
        zero contributes 0.0 instead of raising ``ZeroDivisionError``.

        Returns:
            The same ``score_dict`` object, with ``"composite_score"`` added to each
            entry in place.
        """
        for vals in score_dict.values():
            struct, prop, sem = self._metric_means(vals)
            comp = (struct + prop + sem) / 3.0 / 10.0
            vals["composite_score"] = round(comp, 2)
        return score_dict

    def write_table(
        self,
        score_dict: Dict[Tuple[str, str], Dict[str, float]],
        norm_counts: Dict[Tuple[str, str], Dict[str, float]],
    ) -> None:
        """
        Generate and save a CSV and print a formatted table of final scores.

        Writes ``<processed_dir>/composite_scores.csv`` plus one
        ``<processed_dir>/<SOP>/<model>.json`` per entry. ``norm_counts`` may be any
        mapping; entries missing from it are treated as all-zero normalized counts.
        """
        headers = [
            "Model",
            "SOP",
            "structural_alignment",
            "property_fidelity",
            "semantic_fidelity",
            "code_bonus",
            "code_compliance",
        ]
        rows: List[List] = []
        n_counts = len(self.all_counts)
        os.makedirs(self.processed_dir, exist_ok=True)
        for (sop, mdl), vals in sorted(score_dict.items()):
            sa, pf, sf = (round(mean, 2) for mean in self._metric_means(vals))
            # .get keeps this working for plain dicts and avoids inserting empty
            # entries into a defaultdict as a side effect of reading it.
            norms = norm_counts.get((sop, mdl), {})
            # code bonus = average weighted norm counts
            bonus_total = sum(
                norms.get(f"norm_{c}", 0.0) * self.count_weights[c]
                for c in self.all_counts
            )
            code_bonus = round(bonus_total / n_counts, 2)
            # code compliance = 80% overall score + 20% average weighted norm counts
            code_comp = round(
                0.8 * vals.get("overall", 0.0) + 0.2 * bonus_total / n_counts,
                2,
            )
            rows.append([mdl, sop, sa, pf, sf, code_bonus, code_comp])
            # save per-model JSON
            out_dir = os.path.join(self.processed_dir, sop)
            os.makedirs(out_dir, exist_ok=True)
            data = {"composite_score": vals.get("composite_score", 0.0)}
            for c in self.all_counts:
                data[f"norm_{c}"] = norms.get(f"norm_{c}", 0.0)
            with open(os.path.join(out_dir, f"{mdl}.json"), "w") as f:
                json.dump(data, f, indent=4)
        df = pd.DataFrame(rows, columns=headers).sort_values(["SOP", "Model"])
        print(tabulate(df, headers=headers, tablefmt="fancy_grid"))
        out_csv = os.path.join(self.processed_dir, "composite_scores.csv")
        df.to_csv(out_csv, index=False)
        print(f"✅ CSV file saved to: {out_csv}")


class ScoreAggregatorV2:
    """
    Processes evaluation JSON v2 for count-based metrics and computes execution compliance.

    ``process`` walks ``<eval_root>/<folder>/<name>.json`` and keys every result by
    ``(folder, name)``, referred to as (SOP, model). Values are read from the
    ``"counts"`` object of each file.
    """

    def __init__(
        self,
        eval_root: str,
        processed_dir: str,
    ):
        """
        Args:
            eval_root: Directory whose sub-folders hold the v2 evaluation JSON files.
            processed_dir: Directory that receives the CSV and per-model JSON output.
        """
        self.eval_root = eval_root
        self.processed_dir = processed_dir
        self.count_keys = {"counts": ["counterexample_traces", "minor_issues"]}
        self.all_counts = ["counterexample_traces", "minor_issues"]
        self.all_models = [f"eval_model_{i+1:02d}" for i in range(10)]

    @staticmethod
    def _as_count(value) -> int:
        """Coerce one JSON count to a rounded int; ``None`` counts as 0, like a missing field."""
        if value is None:
            return 0
        if isinstance(value, (int, float)):
            return int(round(value))
        # Numeric strings such as "3"; anything else raises ValueError/TypeError.
        return int(round(float(value)))

    @staticmethod
    def _has_errors(flag) -> bool:
        """Return True when the raw ``errors`` flag says errors occurred ("Yes" in any case, or true)."""
        if isinstance(flag, str):
            return flag.strip().lower() == "yes"
        return flag is True

    def process(self) -> Dict[Tuple[str, str], Dict[str, int]]:
        """
        Read JSON v2 files and extract raw count metrics.

        Non-directories, non-``.json`` files and empty files are skipped. A missing
        or ``null`` ``"counts"`` section or count is treated as 0.

        Returns:
            (SOP, model) -> entry. Each entry holds the rounded integer counts plus
            ``"errors"``, which is the raw flag from the file (``None`` when absent)
            and therefore not an int despite the annotation.

        Raises:
            json.JSONDecodeError: If a non-empty file is not valid JSON.
        """
        scores: Dict[Tuple[str, str], Dict[str, int]] = {}
        for folder in sorted(os.listdir(self.eval_root)):
            folder_p = os.path.join(self.eval_root, folder)
            if not os.path.isdir(folder_p):
                continue
            for fname in sorted(os.listdir(folder_p)):
                if not fname.endswith(".json"):
                    continue
                mdl = os.path.splitext(fname)[0]
                path = os.path.join(folder_p, fname)
                # Explicit encoding so the result does not depend on the platform locale.
                with open(path, "r", encoding="utf-8") as f:
                    raw = f.read().strip()
                if not raw:
                    continue
                data = json.loads(raw)
                counts = data.get("counts")
                if not isinstance(counts, dict):
                    counts = {}
                entry: Dict[str, int] = {
                    k: self._as_count(counts.get(k))
                    for k in self.count_keys["counts"]
                }
                entry["errors"] = counts.get("errors")
                scores[(folder, mdl)] = entry
        return scores

    def normalize_counts(
        self, score_dict: Dict[Tuple[str, str], Dict[str, float]]
    ) -> Dict[Tuple[str, str], Dict[str, float]]:
        """
        Normalize count metrics per model to [0,1] scale.

        Min-max scaling is applied to each count separately, across all SOPs of one
        model, and rounded to two decimals. When every SOP has the same value the
        normalized value is 0.0. Only models listed in ``self.all_models`` are
        normalized; entries for any other model are absent from the result.

        Returns:
            A ``defaultdict`` keyed by (SOP, model) -> {"norm_<count>": value}.
        """
        norm_scores: Dict[Tuple[str, str], Dict[str, float]] = defaultdict(dict)
        for model in self.all_models:
            flt = {
                key: vals
                for key, vals in score_dict.items()
                if key[1] == model
            }
            if not flt:
                continue
            for cnt in self.all_counts:
                vals = [v.get(cnt, 0) for v in flt.values()]
                min_v, max_v = min(vals), max(vals)
                for key, val in zip(flt, vals):
                    norm = (val - min_v) / (max_v - min_v) if max_v > min_v else 0.0
                    norm_scores[key][f"norm_{cnt}"] = round(norm, 2)
        return norm_scores

    def compute_compliance(
        self,
        raw_scores: Dict[Tuple[str, str], Dict[str, int]],
        normed: Dict[Tuple[str, str], Dict[str, float]],
    ) -> Dict[Tuple[str, str], float]:
        """
        Compute execution compliance: 1 - penalty from normalized counts and errors flag.

        An entry whose ``errors`` flag is set scores 0.0. Otherwise the penalty is
        ``0.8 * norm_counterexample_traces + 0.2 * norm_minor_issues``. The flag is
        matched case-insensitively and ignoring surrounding whitespace; a JSON
        ``true`` also counts as set.

        Returns:
            (SOP, model) -> compliance, for the keys present in ``normed`` only.
        """
        compliance: Dict[Tuple[str, str], float] = {}
        for key in sorted(normed):
            if self._has_errors(raw_scores.get(key, {}).get("errors")):
                compliance[key] = 0.0
                continue
            nc = normed.get(key, {})
            cet = nc.get("norm_counterexample_traces", 0.0)
            mid = nc.get("norm_minor_issues", 0.0)
            penalty = cet * 0.8 + mid * 0.2
            compliance[key] = round(1.0 - penalty, 2)
        return compliance

    def write_table(
        self,
        raw_scores: Dict[Tuple[str, str], Dict[str, int]],
        normed: Dict[Tuple[str, str], Dict[str, float]],
        compliance: Dict[Tuple[str, str], float],
    ) -> None:
        """
        Generate and save a CSV and print a formatted table of count-based scores.

        Writes ``<processed_dir>/count_scores.csv`` plus one
        ``<processed_dir>/<SOP>/<model>.json`` per entry of ``raw_scores``. In the
        table an absent ``errors`` flag is shown as "Unknown"; the per-model JSON
        keeps the raw value.
        """
        headers = ["Model", "SOP"] + [f"norm_{c}" for c in self.all_counts] + [
            "errors", "execution_compliance"
        ]
        rows: List[List] = []
        os.makedirs(self.processed_dir, exist_ok=True)
        for (sop, mdl), vals in sorted(raw_scores.items()):
            nvals = normed.get((sop, mdl), {})
            errors = vals.get("errors")
            row = [mdl, sop]
            row += [nvals.get(f"norm_{c}", 0.0) for c in self.all_counts]
            # process() always stores the key, so a dict default would never apply;
            # test the value itself to make the "Unknown" fallback effective.
            row.append("Unknown" if errors is None else errors)
            row.append(compliance.get((sop, mdl), 0.0))
            rows.append(row)
            # per-model JSON
            odir = os.path.join(self.processed_dir, sop)
            os.makedirs(odir, exist_ok=True)
            out = {f"norm_{c}": nvals.get(f"norm_{c}", 0.0) for c in self.all_counts}
            out["errors"] = errors
            out["execution_compliance"] = compliance.get((sop, mdl))
            with open(os.path.join(odir, f"{mdl}.json"), "w") as f:
                json.dump(out, f, indent=4)
        df = pd.DataFrame(rows, columns=headers).sort_values(["SOP", "Model"]).reset_index(drop=True)
        print(tabulate(df, headers="keys", tablefmt="fancy_grid", showindex=False))
        csv_path = os.path.join(self.processed_dir, "count_scores.csv")
        df.to_csv(csv_path, index=False)
        print(f"✅ CSV file saved to: {csv_path}")


class ScoreMerger:
    """
    Merges composite and count-based scores into a final compliance table.
    """

    # Spec-quality columns of the final table, in output order.
    _QUALITY_COLUMNS = (
        "structural_alignment",
        "property_fidelity",
        "semantic_fidelity",
    )
    # Remaining value columns of the final table, in output order.
    _OTHER_COLUMNS = (
        "code_bonus",
        "code_compliance",
        "norm_counterexample_traces",
        "norm_minor_issues",
        "errors",
        "execution_compliance",
    )

    def __init__(
        self,
        v1_csv: str,
        v2_csv: str,
        output_path: str,
    ):
        """
        Args:
            v1_csv: CSV with the composite scores (one row per Model/SOP).
            v2_csv: CSV with the count-based scores (one row per Model/SOP).
            output_path: Destination of the merged CSV.
        """
        self.v1_csv = v1_csv
        self.v2_csv = v2_csv
        self.output_path = output_path

    def merge(self) -> pd.DataFrame:
        """
        Inner-join both CSVs on (Model, SOP), add ``Final_score`` and save the result.

        ``Final_score`` is ``0.4 * code_compliance + 0.6 * execution_compliance``,
        rounded to two decimals. The spec-quality columns are accepted either under
        their plain names (``structural_alignment`` ...) or with a ``combined_``
        prefix; the output always uses the plain names.

        Returns:
            The merged table sorted by SOP, then Model.

        Raises:
            KeyError: If an expected column is missing after the join.
        """
        df1 = pd.read_csv(self.v1_csv)
        df2 = pd.read_csv(self.v2_csv)
        merged = pd.merge(df1, df2, on=["Model", "SOP"], how="inner")

        # Accept the prefixed spelling, but never clobber a plain column that
        # is already present.
        prefixed = {
            f"combined_{name}": name
            for name in self._QUALITY_COLUMNS
            if name not in merged.columns and f"combined_{name}" in merged.columns
        }
        merged = merged.rename(columns=prefixed)

        wanted = ["Model", "SOP", *self._QUALITY_COLUMNS, *self._OTHER_COLUMNS]
        missing = [col for col in wanted if col not in merged.columns]
        if missing:
            raise KeyError(
                f"Merged score table is missing expected column(s): {missing}"
            )

        # Work on an explicit copy so the assignments below never target a
        # slice of ``merged`` (chained assignment).
        df_final = merged[wanted].copy()
        df_final["Final_score"] = (
            df_final["code_compliance"] * 0.4
            + df_final["execution_compliance"] * 0.6
        ).round(2)
        df_final = df_final.sort_values(by=["SOP", "Model"])
        df_final.to_csv(self.output_path, index=False)
        print(f"✅ Final compliance table saved to {self.output_path}")
        return df_final


class QualityVisualizer:
    """
    Draws a grouped bar chart of the per-model values in a table and saves it as an image.

    The table is expected to have one row per SOP (used as x-axis groups) and one
    numeric column per model (used as the bars within each group).
    """

    def __init__(
        self,
        data: pd.DataFrame,
        output_path: str,
    ):
        """
        Args:
            data: Table to plot; stored as ``self.df``.
            output_path: File the chart is written to.
        """
        self.output_path = output_path
        self.df = data

    @classmethod
    def default(cls, output_path: str):
        """Return a visualizer loaded with the built-in table of ten SOPs and five models."""
        sop_names = [
            "Shuttle Guidance system",
            "PCI Bus protocol",
            "Robotics controller",
            "Priority Queue Buffer",
            "Traffic Collision Avoidance System",
            "Ring Oscillator",
            "Mutual Exclusion",
            "Semaphore",
            "Gigamax Cache Coherence protocol",
            "Bounded Reliable Protocol",
        ]
        # One column per model; values line up with ``sop_names`` by position.
        steps_by_model = {
            "SpecMAS": [6, 10, 3, 7, 8, 8, 8, 5, 7, 16],
            "DeepSeek R1": [0, 1, 0, 3, 1, 1, 1, 0, 2, 9],
            "Gemini": [5, 7, 12, 3, 14, 2, 0, 1, 14, 14],
            "Claude": [0, 2, 1, 0, 1, 0, 1, 1, 6, 1],
            "Qwen": [2, 9, 9, 2, 13, 5, 1, 2, 1, 3],
        }
        sop_index = pd.Index(sop_names, name="SOP Name")
        table = pd.DataFrame(steps_by_model, index=sop_index).apply(pd.to_numeric)
        return cls(table, output_path)

    def plot(self) -> None:
        """Render ``self.df`` as a bar chart, save it to ``self.output_path`` and close the figure."""
        figure, axes = plt.subplots(figsize=(12, 8))
        self.df.plot(kind="bar", ax=axes, width=0.8)

        # Titles and axis labels.
        axes.set_title("Comparison of Debugging Steps Across Models", fontsize=16)
        axes.set_xlabel("System Specifications (SOP Name)", fontsize=14)
        axes.set_ylabel("Debugging Steps (Errors Count)", fontsize=14)

        # Start slightly below zero so zero-height bars remain visible as a line.
        axes.set_ylim(bottom=-0.5)
        axes.grid(axis="y", linestyle="--", alpha=0.7)
        plt.xticks(rotation=45, ha="right")

        # Legend sits outside the plotting area, to the right.
        axes.legend(title="Models", bbox_to_anchor=(1.05, 1), loc="upper left")

        figure.tight_layout()
        figure.savefig(self.output_path, dpi=300, bbox_inches="tight")
        plt.close(figure)
        print(f"✅ Comparison plot saved to {self.output_path}")


class PropertyVisualizer:
    """
    Draws a slope graph linking each model's average spec-quality to its execution compliance.
    """

    def __init__(
        self,
        excel_path: str,
        output_path: str,
    ):
        """
        Args:
            excel_path: Spreadsheet read by ``plot``.
            output_path: File the chart is written to.
        """
        self.output_path = output_path
        self.excel_path = excel_path

    def plot(self) -> None:
        """
        Read the spreadsheet, draw one line per row and save the figure.

        The sheet must provide the columns ``Model``, ``structural_alignment``,
        ``property_fidelity``, ``semantic_fidelity`` and ``execution_compliance``.
        """
        quality_columns = [
            "structural_alignment",
            "property_fidelity",
            "semantic_fidelity",
        ]
        sheet = pd.read_excel(self.excel_path)
        sheet["avg_spec_quality"] = sheet[quality_columns].mean(axis=1)
        # Cap compliance at 1.0 so it cannot leave the axis range set below.
        sheet["execution_compliance"] = sheet["execution_compliance"].clip(upper=1.0)

        figure, left_axis = plt.subplots(figsize=(6, 4))
        right_axis = left_axis.twinx()

        lines = zip(
            sheet["Model"].tolist(),
            sheet["avg_spec_quality"].tolist(),
            sheet["execution_compliance"].tolist(),
        )
        for name, quality, compliance in lines:
            # Everything is drawn on the left axis; compliance is multiplied by 10
            # to share its 0-10 range. The right axis only carries a label and limits.
            left_axis.plot([0, 1], [quality, compliance * 10], marker='o', label=name)

        left_axis.set_xticks([0, 1])
        left_axis.set_xticklabels(['Avg Spec‑Quality', 'Execution Compliance'])
        left_axis.set_title('Spec‑Quality → Execution Compliance Slope Graph')

        left_axis.set_ylim(0, 10)
        left_axis.set_ylabel('Spec‑Quality Score (0–10)', color='blue')
        right_axis.set_ylim(0, 1)
        right_axis.set_ylabel('Execution Compliance (0–1)', color='green')

        left_axis.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        figure.tight_layout()
        figure.savefig(self.output_path, dpi=300, bbox_inches='tight')
        plt.close(figure)
        print(f"✅ Slope graph saved to {self.output_path}")


class PearsonVisualizer:
    """
    Generates one radar chart per metric, with domains as spokes and one trace per model.

    Data comes from the final evaluation CSV, averaged per (Model, Domain).
    """

    def __init__(
        self,
        csv_path: str,
        out_dir: str,
    ):
        """
        Args:
            csv_path: Final evaluation CSV read by ``generate``.
            out_dir: Directory that receives one ``<metric>_radar.png`` per metric.
        """
        self.csv_path = csv_path
        self.out_dir = out_dir
        # Domain -> identifiers belonging to it. The identifiers are matched against
        # the CSV's "SOP" column after the "eval_" text has been removed.
        members_by_domain = {
            "Protocols": ("model_10", "model_02"),
            "Sync/Concurrency": ("model_07", "model_08"),
            "Memory Mgmt": ("model_09", "model_04"),
            "Digital Logic": ("model_06",),
            "Reactive Ctrl": ("model_03", "model_01"),
            "Safety-Critical": ("model_05",),
        }
        # Flat lookup: identifier -> domain
        self.domain_map = {
            member: domain
            for domain, members in members_by_domain.items()
            for member in members
        }
        # Spoke order of the radar charts
        self.domains = list(members_by_domain)
        # One chart is produced for each of these CSV columns
        self.metrics = [
            "Final_score",
            "execution_compliance",
            "code_compliance",
            "structural_alignment",
            "property_fidelity",
            "semantic_fidelity",
        ]

    def generate(self) -> None:
        """
        Read the CSV, average every metric per (Model, Domain) and save the radar charts.

        A (model, domain) pair with no rows in the CSV is plotted as 0.0.
        """
        os.makedirs(self.out_dir, exist_ok=True)

        table = pd.read_csv(self.csv_path)
        table.columns = table.columns.str.strip()
        table["Model"] = table["Model"].str.strip()
        table["SOP_clean"] = table["SOP"].str.replace("eval_", "").str.strip()
        table["Domain"] = table["SOP_clean"].map(self.domain_map)
        means = table.groupby(["Model", "Domain"])[self.metrics].mean()

        # Spoke angles, plus the first angle repeated so each trace closes on itself.
        spokes = np.linspace(0, 2 * np.pi, len(self.domains), endpoint=False)
        closed_spokes = np.append(spokes, spokes[0])
        radial_ticks = [0.2, 0.4, 0.6, 0.8, 1.0]
        model_names = means.index.get_level_values(0).unique()

        for metric in self.metrics:
            figure, axes = plt.subplots(
                figsize=(7, 7), subplot_kw={"polar": True}
            )
            # First spoke at the top, remaining spokes clockwise.
            axes.set_theta_offset(np.pi / 2)
            axes.set_theta_direction(-1)
            axes.set_xticks(spokes)
            axes.set_xticklabels(self.domains, fontsize=9)
            axes.set_yticks(radial_ticks)
            axes.set_yticklabels([str(tick) for tick in radial_ticks], fontsize=8)
            # NOTE(review): the radial axis is fixed to [0, 1]; a metric stored on a
            # larger scale would be drawn outside it - confirm the CSV's value ranges.
            axes.set_ylim(0, 1)
            axes.grid(True, linestyle=':', linewidth=0.6)

            for model in model_names:
                ring = [
                    means.loc[(model, domain), metric]
                    if (model, domain) in means.index
                    else 0.0
                    for domain in self.domains
                ]
                ring.append(ring[0])
                axes.plot(closed_spokes, ring, label=model, linewidth=2)
                axes.fill(closed_spokes, ring, alpha=0.12)

            axes.set_title(metric.replace('_', ' ').title(), pad=20, fontsize=13)
            axes.legend(loc='upper right', bbox_to_anchor=(1.2, 1.05))
            figure.tight_layout()
            out_file = os.path.join(self.out_dir, f"{metric}_radar.png")
            figure.savefig(out_file, dpi=300)
            plt.close(figure)
            print(f"✅ Radar chart for {metric} saved to {out_file}")


def main() -> None:
    """
    Run the whole pipeline: aggregate v1 and v2 evaluations, merge them, draw the charts.

    All paths are relative to the current working directory.
    """
    # Written by the merge step and read back by the radar charts, so both
    # steps must use the very same path.
    final_csv = "./Eval_out/final_compliance_scores.csv"

    # Aggregate v1
    agg1 = ScoreAggregator("./Eval_out/Evaluations", "./Eval_out/ProcessedScores")
    raw1 = agg1.process()
    # Normalization only reads the raw counts, so it may run before or after
    # compute_composite (which adds "composite_score" to the same dict in place).
    norm1 = agg1.normalize_counts(raw1)
    comp1 = agg1.compute_composite(raw1)
    agg1.write_table(comp1, norm1)

    # Aggregate v2
    agg2 = ScoreAggregatorV2("./Eval_out/Evaluations_v2", "./Eval_out/Processed_v2")
    raw2 = agg2.process()
    norm2 = agg2.normalize_counts(raw2)
    comp2 = agg2.compute_compliance(raw2, norm2)
    agg2.write_table(raw2, norm2, comp2)

    # Merge
    merger = ScoreMerger(
        "./Eval_out/ProcessedScores/composite_scores.csv",
        "./Eval_out/Processed_v2/count_scores.csv",
        final_csv,
    )
    merger.merge()

    # Visualize quality and properties
    qv = QualityVisualizer.default("./comparison.png")
    qv.plot()
    # NOTE(review): this spreadsheet is not produced by any step above; it has
    # to exist in the working directory beforehand.
    pv = PropertyVisualizer("./case study.xlsx", "./case_study_dual_axis.png")
    pv.plot()
    pv2 = PearsonVisualizer(final_csv, "./radar_wheels")
    pv2.generate()


if __name__ == "__main__":
    main()
