import json
import json5
import re
from pathlib import Path
from collections import Counter
from typing import List, Dict
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, confusion_matrix
from scipy.optimize import linear_sum_assignment

# -------------------------------
# Utility functions
# -------------------------------
def load_json(file_path: Path):
    """Extract and parse the JSON object embedded in a model-output text file.

    The file is read as UTF-8. If it contains ``</think>`` markers, only the
    text after the last marker is searched, so reasoning traces are skipped.
    Candidate spans are tried in order until one parses with ``json5``
    (which accepts lenient syntax such as trailing commas):

    1. the contents of a fenced ```` ``` ```` / ```` ```json ```` block;
    2. the span from the first ``{`` to the last ``}`` (keeps nested objects);
    3. the span from the first ``{`` to the first ``}`` (flat object followed
       by prose that itself contains braces).

    Args:
        file_path: Path of the text file to read.

    Returns:
        The parsed object, or ``None`` if no candidate span parses.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        text = f.read()

    # Text after the last </think> marker; the whole text when there is none.
    search_region = text.split("</think>")[-1]

    candidate_patterns = (
        r"```(?:json)?\s*({.*?})\s*```",
        # Greedy: a non-greedy pattern stops at the first "}" and truncates
        # any object that contains a nested object.
        r"({.*})",
        r"({.*?})",
    )
    for pattern in candidate_patterns:
        match = re.search(pattern, search_region, re.DOTALL)
        if match is None:
            continue
        try:
            return json5.loads(match.group(1))
        except Exception:
            # Best-effort extraction: an unparseable span just means
            # "try the next candidate".
            continue
    return None

def normalize_path_list(x):
    """Normalize a decision path into a list of stripped node-name strings.

    Accepts ``None`` (returns ``[]``), a list/tuple of node names, or any other
    value, which is converted with ``str`` and split on ``->``, ``>`` or ``→``.
    Every node is converted to ``str`` and stripped, and empty nodes are
    dropped in both the sequence and the string form, so the two forms
    normalize consistently.
    """
    if x is None:
        return []
    if isinstance(x, (list, tuple)):
        parts = x
    else:
        parts = re.split(r"\s*->\s*|\s*>\s*|\s*→\s*", str(x))
    stripped = (str(part).strip() for part in parts)
    return [node for node in stripped if node]


def filter_path_17_node(lst):
    """Trim trailing NSCL-17 terminal nodes from a path.

    At most two nodes are removed, checked in this order: a final
    ``"NSCL-17-10"``, then a final ``"NSCL-17-1"``. A node is never removed when
    it is the only one left, so a non-empty input yields a non-empty result.
    Falsy input (``None`` or an empty sequence) is returned as-is. Otherwise a
    new list is returned and the input is left untouched.
    """
    if not lst:
        return lst
    trimmed = list(lst)
    for terminal in ("NSCL-17-10", "NSCL-17-1"):
        if len(trimmed) > 1 and trimmed[-1] == terminal:
            trimmed.pop()
    return trimmed

def compare_lists(paths_filtered):
    """Measure pairwise agreement within a collection of decision paths.

    Every unordered pair of paths is compared once.

    Args:
        paths_filtered: Sequence of paths (lists or tuples of node names).

    Returns:
        ``(path_frac, treatment_frac, None, None)``:

        * ``path_frac``: fraction of pairs whose paths are identical.
        * ``treatment_frac``: fraction of pairs whose paths are both non-empty
          and end on the same final node (the treatment).

        With fewer than two paths, ``(1.0, 1.0, None, None)`` is returned. The
        two trailing ``None`` slots are kept for callers that unpack four values.
    """
    from itertools import combinations

    if len(paths_filtered) < 2:
        return 1.0, 1.0, None, None

    pairs = list(combinations(paths_filtered, 2))
    path_frac = np.mean([a == b for a, b in pairs])
    # Treatment agreement only needs the final node to match; empty paths
    # have no treatment and therefore never match.
    treatment_frac = np.mean([bool(a) and bool(b) and a[-1] == b[-1] for a, b in pairs])
    return path_frac, treatment_frac, None, None

# -------------------------------
# Clustering Analysis
# -------------------------------

def cluster_accuracy_and_f1(y_true, y_pred):
    """Best-match cluster labels to true labels, then compute accuracy and F1 (macro).

    Cluster ids are assigned one-to-one to true labels with the Hungarian
    algorithm so that the number of agreeing samples is maximized. Labels may
    be any sortable values, not only ``0..k-1``.

    Args:
        y_true: True labels, array-like of shape (n_samples,).
        y_pred: Cluster assignments, array-like of shape (n_samples,).

    Returns:
        ``(accuracy, macro_f1, y_pred_aligned)`` where ``y_pred_aligned`` holds
        each sample's cluster translated into the true-label space.
    """
    from sklearn.metrics import accuracy_score, f1_score

    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Confusion-matrix rows/columns index this sorted label array, not the raw
    # label values, so the assignment must be translated back through it.
    labels = np.unique(np.concatenate([y_true, y_pred]))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    # linear_sum_assignment minimizes cost; negate to maximize matched counts.
    row_ind, col_ind = linear_sum_assignment(-cm)
    mapping = {labels[col]: labels[row] for row, col in zip(row_ind, col_ind)}
    y_pred_aligned = np.array([mapping[label] for label in y_pred])

    acc = accuracy_score(y_true, y_pred_aligned)
    f1 = f1_score(y_true, y_pred_aligned, average="macro")
    return acc, f1, y_pred_aligned

def run_clustering_analysis(X, y_treat=None, n_clusters=2,
                            output_dir="./CancerGUIDE/results/figures"):
    """Cluster *X* with KMeans, optionally score against true labels, and plot.

    Features are standardized, clustered with ``KMeans(random_state=42)``, and
    projected to 2-D with PCA for plotting. Centroids (and metrics, when
    *y_treat* is given) are printed. Figures are saved under *output_dir*,
    which is created if missing:

    * ``clustering_analysis.png``: PCA scatter colored by true labels (or by
      predicted labels when *y_treat* is None) next to predicted labels.
    * ``clustering_confusion_matrix.png``: only when *y_treat* is given.

    Args:
        X: Feature matrix of shape (n_samples, n_features).
        y_treat: Optional true labels of shape (n_samples,).
        n_clusters: Number of KMeans clusters.
        output_dir: Directory where the figures are written.

    Returns:
        ``(cluster_labels, cm, metrics)``:

        * ``cluster_labels``: raw KMeans labels.
        * ``cm``: confusion matrix of true vs. aligned cluster labels, or ``None``.
        * ``metrics``: dict with ARI/NMI/Accuracy/F1, empty without *y_treat*.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Scale features
    X_scaled = StandardScaler().fit_transform(X)

    # KMeans clustering
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    cluster_labels = kmeans.fit_predict(X_scaled)
    print(f"Cluster centroids:\n{kmeans.cluster_centers_}")

    metrics = {}
    cm = None
    if y_treat is not None:
        ari = adjusted_rand_score(y_treat, cluster_labels)
        nmi = normalized_mutual_info_score(y_treat, cluster_labels)
        acc, f1, aligned_labels = cluster_accuracy_and_f1(y_treat, cluster_labels)
        cm = confusion_matrix(y_treat, aligned_labels)
        metrics = {
            "ARI": ari,
            "NMI": nmi,
            "Accuracy": acc,
            "F1 (macro)": f1,
        }
        print(f"Adjusted Rand Index: {ari:.3f}, Normalized Mutual Info: {nmi:.3f}, Accuracy: {acc:.3f}, F1 (macro): {f1:.3f}")

    # PCA projection to 2D for visualization
    X_pca = PCA(n_components=2).fit_transform(X_scaled)

    # --- Plot 1: clusters colored by true (left) vs predicted (right) labels ---
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    if y_treat is not None:
        left_hue, left_title = y_treat, "Clusters Colored by True Labels"
    else:
        left_hue, left_title = cluster_labels, "Clusters Colored by Predicted Labels"
    panels = (
        (axes[0], left_hue, left_title),
        (axes[1], cluster_labels, "Clusters Colored by Predicted Labels"),
    )
    for ax, hue, title in panels:
        sns.scatterplot(ax=ax, x=X_pca[:, 0], y=X_pca[:, 1], hue=hue, palette="Set2", s=80)
        ax.set_title(title)
        ax.set_xlabel("PC1")
        ax.set_ylabel("PC2")
        # Remove the legend to avoid clutter; guard because an axis may have none.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
    fig.tight_layout()
    _save_then_show(fig, out_dir / "clustering_analysis.png")

    # --- Plot 2: Confusion Matrix ---
    if cm is not None:
        fig_cm = plt.figure(figsize=(6, 5))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
        plt.title("Confusion Matrix (aligned clusters → true labels)")
        plt.xlabel("Predicted Cluster")
        plt.ylabel("True Label")
        _save_then_show(fig_cm, out_dir / "clustering_confusion_matrix.png")

    return cluster_labels, cm, metrics


def _save_then_show(fig, path):
    """Save *fig* to *path*, display it, then close it to free memory."""
    # Save before show(): once the shown window is closed the figure may be
    # destroyed, and a later savefig would write an empty image.
    fig.savefig(path)
    plt.show()
    plt.close(fig)

def load_trend_data(data_path="./CancerGUIDE/results/heatmap_results.json",
                    total_patients: int = 121) -> Dict[str, Dict[str, float]]:
    """
    Load benchmark results and extract both average_treatment_match
    and average_path_match metrics for each model/benchmark.

    The file holds a JSON list of entries with ``model`` and ``benchmark``
    keys and optional metric values. The ``human_new_prompt``,
    ``treatment_aggregation`` and ``path_aggregation`` benchmarks are
    skipped.

    A ``gpt-5_`` marker in a benchmark name is stripped (keeping the text
    after it), so such entries share keys with the unprefixed benchmark.
    Metrics of ``path_filter`` / ``treatment_filter`` benchmarks (judged after
    stripping) are scaled by ``total_patients_matched / total_patients`` to
    penalize unanswered patients. A missing or zero count gives a penalty of 0.

    Args:
        data_path: Path of the results JSON file.
        total_patients: Denominator of the coverage penalty.

    Returns:
        Dict[str, Dict[str, float]]:
            model -> { "benchmark_metric": value, ... }
    """
    skipped_benchmarks = {"human_new_prompt", "treatment_aggregation", "path_aggregation"}
    filter_benchmarks = {"path_filter", "treatment_filter"}
    metrics = ("average_treatment_match", "average_path_match")

    with open(Path(data_path), "r", encoding="utf-8") as f:
        data = json.load(f)

    reformatted: Dict[str, Dict[str, float]] = {}
    for entry in data:
        model = entry["model"]
        benchmark = entry["benchmark"]
        if benchmark in skipped_benchmarks:
            continue
        model_metrics = reformatted.setdefault(model, {})

        # Normalize the name once, before the metric loop, so every metric of
        # this entry gets the same key prefix and the same penalty decision.
        if "gpt-5_" in benchmark:
            benchmark = benchmark.split("gpt-5_")[1]
        is_filter = benchmark in filter_benchmarks
        if is_filter:
            answered = entry.get("total_patients_matched", 0)
            penalty = answered / total_patients if answered else 0

        for metric in metrics:
            value = entry.get(metric)
            if value is None:
                continue  # metric not reported for this entry
            if is_filter:
                value *= penalty
            model_metrics[f"{benchmark}_{metric}"] = float(value)

    return reformatted


# -------------------------------
# Analyzer Class
# -------------------------------
class AccuracyAnalyzer:
    """Build per-(model, patient) feature rows for the accuracy/clustering analysis.

    Ground-truth paths come from the ``label`` key of ``patient_<id>.json``
    files under *annotations_path*. Model rollouts are read from the
    per-model directories returned by :meth:`results_dir_for_model`.
    """

    def __init__(self, annotations_path: Path, patient_id_pool: int = 360):
        """Store the annotation location and eagerly load the ground truth.

        Args:
            annotations_path: Directory holding ``patient_<id>.json`` files.
            patient_id_pool: Highest patient id probed (inclusive).
        """
        self.annotations_path = Path(annotations_path)
        self.patient_id_pool = patient_id_pool
        self.ground_truth = self.load_ground_truth()

    def load_ground_truth(self) -> Dict[str, List[str]]:
        """Load the normalized ``label`` path of every patient id in ``0..patient_id_pool``.

        Missing files and files without a ``label`` key are skipped. Files that
        fail to load are reported and skipped.

        Returns:
            Mapping of patient id (as a string) to its ground-truth node list.
        """
        ground_truth: Dict[str, List[str]] = {}
        for patient_id in range(self.patient_id_pool + 1):
            file_path = self.annotations_path / f"patient_{patient_id}.json"
            if not file_path.exists():
                continue
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                if "label" in data:
                    ground_truth[str(patient_id)] = normalize_path_list(data["label"])
            except Exception as e:
                # Best-effort loading: report the bad file and keep going.
                print(f"Error loading GT for patient {patient_id}: {e}")
        return ground_truth

    @staticmethod
    def results_dir_for_model(base_path: Path, model_name: str) -> Path:
        """Return the rollout results directory of *model_name* under *base_path*."""
        return base_path / "results" / "rollout_results_0815_benchmark" / f"rollout_experiment_{model_name}"

    @staticmethod
    def _filtered_paths_and_mode(paths: List[List[str]]):
        """Filter non-empty paths to tuples and find the most frequent one.

        Returns:
            ``(filtered, mode_path)``. ``filtered`` holds each non-empty path
            after ``filter_path_17_node``, as a tuple. ``mode_path`` is the most
            common of them as a list, with ties going to the first one seen
            (``Counter.most_common`` order). It is ``None`` when nothing remains.
        """
        filtered = [tuple(filter_path_17_node(p)) for p in paths if p]
        if not filtered:
            return [], None
        mode_path, _ = Counter(filtered).most_common(1)[0]
        return filtered, list(mode_path)

    def load_patient_paths(self, patient_dir: Path, pid: str, num_iterations: int = 10) -> List[List[str]]:
        """Read ``final_path`` from each ``patient_<pid>_iteration_<i>.json`` in *patient_dir*.

        Missing files, files ``load_json`` cannot parse, and files without
        ``final_path`` are skipped. The result may therefore hold fewer than
        *num_iterations* paths.
        """
        paths = []
        for i in range(num_iterations):
            fpath = patient_dir / f"patient_{pid}_iteration_{i}.json"
            if not fpath.exists():
                continue
            data = load_json(fpath)
            if data and "final_path" in data:
                paths.append(normalize_path_list(data["final_path"]))
        return paths

    def collect_rows_for_model(self, results_dir: Path, model_name: str, num_iterations: int,
                               all_model_paths: Dict[str, Dict[str, List[str]]],
                               model_trend_features: Dict[str, Dict[str, float]]):
        """Yield one feature row per patient with both rollouts and ground truth.

        Args:
            results_dir: Directory containing ``patient_<pid>`` subdirectories.
            model_name: Model whose rollouts are read.
            num_iterations: Number of rollout iterations probed per patient.
            all_model_paths: ``model -> {pid: mode path}`` used for the
                ``aggregated`` feature.
            model_trend_features: ``model -> {feature name: value}``, appended
                to every row. A model missing here is skipped with a warning.

        Yields:
            ``(feats, (path_match_fraction, treatment_match_fraction), pid,
            model_name, treat_match_target, path_overlap_target)``
        """
        if not results_dir.exists():
            print(f"[WARN] Missing results_dir for {model_name}: {results_dir}")
            return

        trend = model_trend_features.get(model_name)
        if trend is None:
            print(f"[WARN] No trend features for {model_name}; skipping model.")
            return

        for patient_dir in results_dir.glob("patient_*"):
            pid = patient_dir.name.split("_")[1]
            gt = self.ground_truth.get(pid)
            if gt is None:
                continue
            gt_path = filter_path_17_node(gt)

            paths = self.load_patient_paths(patient_dir, pid, num_iterations=num_iterations)
            paths_filtered, mode_path = self._filtered_paths_and_mode(paths)
            if mode_path is None:
                continue

            path_frac, treatment_frac, _, _ = compare_lists(paths_filtered)

            # Fraction of models whose mode path for this patient ends on the
            # same final node. When all_model_paths is built by build_dataset it
            # includes this model itself.
            final_node = mode_path[-1]
            other_modes = [per_patient[pid] for per_patient in all_model_paths.values() if pid in per_patient]
            aggregated_score = (
                sum(other[-1] == final_node for other in other_modes) / len(other_modes)
                if other_modes else 0.0
            )

            feats = {
                "path_match_fraction": path_frac,
                "treatment_match_fraction": treatment_frac,
                "aggregated": aggregated_score,
                **trend,
            }

            # An empty ground-truth path counts as a treatment mismatch (0).
            treat_match_target = int(bool(gt_path) and gt_path[-1] == final_node)
            path_overlap_target = compare_lists([mode_path, gt_path])[0]
            yield feats, (path_frac, treatment_frac), pid, model_name, treat_match_target, path_overlap_target

    def build_dataset(self, base_path: Path, model_names: List[str],
                      model_trend_features: Dict[str, Dict[str, float]],
                      num_iterations: int = 10):
        """Assemble the feature matrix and targets for *model_names*.

        Args:
            base_path: Project root passed to :meth:`results_dir_for_model`.
            model_names: Models to include.
            model_trend_features: ``model -> {feature name: value}``.
            num_iterations: Rollout iterations probed per patient.

        Returns:
            ``(X, y_path, y_treat, groups, ordered_models, model_labels)``, or
            ``None`` when no rows were collected. Columns of ``X`` follow the
            first row's feature-key order; keys seen only in later rows are
            appended, and values a row lacks are NaN.
        """
        # Mode path per (model, patient), used by the cross-model "aggregated" feature.
        all_model_paths: Dict[str, Dict[str, List[str]]] = {}
        for mname in model_names:
            rdir = self.results_dir_for_model(base_path, mname)
            patient_paths: Dict[str, List[str]] = {}
            for patient_dir in rdir.glob("patient_*"):
                pid = patient_dir.name.split("_")[1]
                paths = self.load_patient_paths(patient_dir, pid, num_iterations=num_iterations)
                _, mode_path = self._filtered_paths_and_mode(paths)
                if mode_path is not None:
                    patient_paths[pid] = mode_path
            all_model_paths[mname] = patient_paths

        rows = []
        for model_name in model_names:
            rdir = self.results_dir_for_model(base_path, model_name)
            rows.extend(self.collect_rows_for_model(rdir, model_name, num_iterations,
                                                    all_model_paths, model_trend_features))

        if not rows:
            print("No data gathered from the selected models.")
            return None

        ordered_models = list(dict.fromkeys(row[3] for row in rows))

        # Index columns by key, not by each dict's own insertion order. Trend
        # keys arrive in per-model file order, so positional flattening could
        # silently misalign columns across models.
        feature_keys = list(dict.fromkeys(key for row in rows for key in row[0]))
        X = np.asarray([[row[0].get(key, np.nan) for key in feature_keys] for row in rows], dtype=float)
        y_path = np.asarray([row[5] for row in rows], dtype=float)
        y_treat = np.asarray([row[4] for row in rows], dtype=int)
        groups = np.asarray([row[2] for row in rows])
        model_labels = [row[3] for row in rows]

        return X, y_path, y_treat, groups, ordered_models, model_labels

# -------------------------------
# Main
# -------------------------------
if __name__ == "__main__":
    BASE_PATH = Path("./CancerGUIDE")
    ANNOTATIONS_PATH = BASE_PATH / "data" / "benchmarks" / "human_annotations"
    model_names = ["gpt-5", "gpt-5-med", "gpt-4.1", "o4-mini", "o3", "deepseek", "llama", "gpt-5-high"]

    # For a quick single-model run, uncomment:
    # model_names = model_names[5:6]

    # Per-model benchmark metrics appended to every feature row.
    model_trend_features = load_trend_data()

    analyzer = AccuracyAnalyzer(annotations_path=ANNOTATIONS_PATH)
    dataset = analyzer.build_dataset(BASE_PATH, model_names, model_trend_features)
    if dataset is None:
        raise SystemExit("No rows collected for the selected models; nothing to cluster.")
    X, y_path, y_treat, groups, ordered_models, model_labels = dataset

    BASE_FEATURES = ["path_match_fraction", "treatment_match_fraction", "aggregated"]
    # Trend column names taken from gpt-4.1's keys; presumably the same for
    # every model -- confirm against heatmap_results.json.
    feature_names = BASE_FEATURES + list(model_trend_features.get("gpt-4.1", {}).keys())

    # Cluster on the three base features only: path_match_fraction,
    # treatment_match_fraction, aggregated.
    selected_idx = [0, 1, 2]
    X_selected = X[:, selected_idx]
    run_clustering_analysis(X_selected, y_treat=y_treat, n_clusters=2)
