import json
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple, Set

import numpy as np
import tqdm
from sklearn.cluster import AgglomerativeClustering


def _similarity_to_distance(S: np.ndarray) -> Tuple[np.ndarray, float]:
    """Convert a similarity matrix on a fixed 0-100 scale into a distance matrix.

    The matrix is first symmetrised as ``(S + S.T) / 2``. Each entry ``s`` then
    becomes ``max(0, 1 - s / 100)``, and the diagonal is forced to zero. The
    input array is not modified.

    Returns:
        ``(D, smax)``: the distance matrix and the similarity scale used (100.0).
    """
    scale = 100.0
    symmetric = (S + S.T) / 2.0
    # Similarities above the scale would give negative distances; clamp them to 0.
    distances = np.maximum(1.0 - symmetric / scale, 0.0)
    np.fill_diagonal(distances, 0.0)
    return distances, scale


def _cluster_from_matrix(
    S: np.ndarray,
    *,
    threshold_similarity_0to1: float,
    linkage: str = "average",
    compute_full_tree: str | bool = "auto",
) -> Tuple[List[int], List[List[int]], float, float]:
    """Agglomeratively cluster the rows of a similarity matrix, cut at a similarity threshold.

    The similarity matrix is converted to distances with ``_similarity_to_distance``.
    The tree is cut at distance ``1 - threshold_similarity_0to1``.

    Args:
        S: square similarity matrix. Callers in this file only pass at least 2 rows,
            presumably because sklearn refuses fewer samples.
        threshold_similarity_0to1: minimum similarity (as a fraction) for merging.
        linkage: sklearn linkage criterion, used with ``metric="precomputed"``.
        compute_full_tree: forwarded unchanged to ``AgglomerativeClustering``.

    Returns:
        ``(labels, clusters, distance_threshold, smax)``:
            - ``labels`` holds the cluster label of every row.
            - ``clusters`` lists the 0-based row indices of each label, in ascending label order.
    """
    distances, scale = _similarity_to_distance(S)
    cut = 1.0 - float(threshold_similarity_0to1)

    clusterer = AgglomerativeClustering(
        n_clusters=None,
        metric="precomputed",
        linkage=linkage,
        compute_full_tree=compute_full_tree,
        distance_threshold=cut,
    )
    assignment = clusterer.fit_predict(distances)

    groups = []
    for label in np.unique(assignment):
        groups.append(np.flatnonzero(assignment == label).tolist())
    return assignment.tolist(), groups, cut, scale


def _load_matrix_and_indices(jsonl_path: str, idx: int) -> Tuple[Optional[np.ndarray], List[int]]:
    """Return the ``"matrix"`` and ``"indices"`` stored for record ``idx`` in a JSONL file.

    The file is scanned top-down, and the first non-blank line whose ``"idx"``
    equals ``idx`` is used.

    Returns:
        ``(matrix, indices)``:
            - ``matrix`` is a float32 array, or ``None`` when the record has no matrix.
            - ``indices`` is the record's list, or ``[]`` when the key is absent or null.
        ``(None, [])`` is returned when the file does not exist or no record matches.
    """
    path = Path(jsonl_path)
    if not path.exists():
        return None, []

    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            if obj.get("idx") != idx:
                continue
            # `"indices": null` must not leak None to callers that iterate the result.
            indices = obj.get("indices") or []
            matrix = obj.get("matrix")
            if matrix is None:
                return None, indices
            return np.asarray(matrix, dtype=np.float32), indices
    return None, []


def _cluster_with_indices(
    matrix: np.ndarray,
    indices_1b: List[int],
    threshold_similarity_0to1: float,
    linkage: str,
    compute_full_tree: str | bool,
) -> Tuple[List[List[int]], Dict]:
    if matrix.shape[0] < 2:
        return [indices_1b], {
            "n_clusters": 1,
            "labels": [0] if indices_1b else [],
            "distance_threshold": 1.0 - threshold_similarity_0to1,
            "smax": 100.0,
            "n_samples": len(indices_1b)
        }
    labels, clusters_local, dist_th, smax = _cluster_from_matrix(
        matrix,
        threshold_similarity_0to1=threshold_similarity_0to1,
        linkage=linkage,
        compute_full_tree=compute_full_tree,
    )
    
    clusters_original = []
    for cluster_local in clusters_local:
        cluster_original = [indices_1b[i] for i in cluster_local]
        clusters_original.append(sorted(cluster_original))
    
    info = {
        "n_clusters": len(clusters_original),
        "labels": labels,
        "distance_threshold": dist_th,
        "smax": smax,
        "n_samples": len(indices_1b)
    }
    
    return clusters_original, info


def _index_jsonl_offsets(jsonl_path: Optional[str]) -> Dict[Any, int]:
    """Map each record's ``"idx"`` to the byte offset of its first line in a JSONL file.

    The map is built in one linear pass, so per-sample lookups need not rescan the file.
    An empty dict is returned when ``jsonl_path`` is falsy or the file does not exist.
    Assumes ``"idx"`` values are hashable JSON scalars (ints/strings).
    """
    offsets: Dict[Any, int] = {}
    if not jsonl_path or not Path(jsonl_path).exists():
        return offsets
    position = 0
    # Binary mode keeps the running position an exact byte offset usable with seek().
    with Path(jsonl_path).open("rb") as f:
        for raw in f:
            start = position
            position += len(raw)
            if not raw.strip():
                continue
            key = json.loads(raw).get("idx")
            if key is not None:
                # The first occurrence wins, as in a top-down scan.
                offsets.setdefault(key, start)
    return offsets


def _read_side_record(
    jsonl_path: Optional[str], offsets: Dict[Any, int], idx: Any
) -> Tuple[Optional[np.ndarray], List[int]]:
    """Return ``(matrix, indices)`` stored for ``idx`` in a correct/wrong side file.

    ``matrix`` is a float32 array, or ``None`` when the record has none.
    ``indices`` defaults to ``[]`` when the key is absent or null.
    ``(None, [])`` is returned when ``idx`` is not in ``offsets``.
    """
    offset = offsets.get(idx)
    if offset is None or not jsonl_path:
        return None, []
    with Path(jsonl_path).open("rb") as f:
        f.seek(offset)
        obj = json.loads(f.readline())
    matrix = obj.get("matrix")
    indices = obj.get("indices") or []
    if matrix is None:
        return None, indices
    return np.asarray(matrix, dtype=np.float32), indices


def _cluster_group(
    S: np.ndarray,
    group_1b: List[int],
    stored_matrix: Optional[np.ndarray],
    stored_indices: List[int],
    t: float,
    linkage: str,
    compute_full_tree: str | bool,
) -> Tuple[List[List[int]], Dict]:
    """Cluster one group of samples given as sorted, unique, 1-based ids.

    A side-file matrix is reused only when its ids are exactly this group. In that
    case it is permuted into sorted-id order so that its rows line up with
    ``group_1b``. Otherwise the group's rows and columns are sliced out of the
    full matrix ``S``. Returns ``([], {})`` for an empty group.
    """
    if not group_1b:
        return [], {}
    k = len(group_1b)
    if (
        stored_matrix is not None
        and stored_matrix.ndim == 2
        and stored_matrix.shape == (k, k)
        and sorted(stored_indices) == group_1b
    ):
        # The stored rows follow the side file's index order, which may be unsorted.
        order = np.argsort(np.asarray(stored_indices), kind="stable")
        matrix = stored_matrix[np.ix_(order, order)]
    else:
        rows_0b = [i - 1 for i in group_1b]
        matrix = S[np.ix_(rows_0b, rows_0b)]
    return _cluster_with_indices(matrix, group_1b, t, linkage, compute_full_tree)


def save_clusters_split_by_correct_wrong(
    all_jsonl_path: str,
    correct_jsonl_path: Optional[str],
    wrong_jsonl_path: Optional[str],
    *,
    similarity_threshold_0to100: float,
    linkage: str = "average",
    compute_full_tree: str | bool = "auto",
) -> Path:
    """Cluster each sample matrix overall and separately per correct/wrong/unlabeled group.

    Selection: only records of ``all_jsonl_path`` with an ``"idx"`` and a square
    ``"matrix"`` of at least 2 rows are processed; all others are skipped.

    Processing, for each selected record:
        - The full matrix is clustered.
        - The samples (1-based ids ``1..n``) are split into groups using the
          ``"indices"`` stored for the same ``idx`` in the optional side files.
        - Ids listed as both correct and wrong count as correct; the remaining
          ids are "unlabeled".
        - Each group is clustered on its own.

    Args:
        all_jsonl_path: JSONL with one record per ``idx`` holding the full matrix.
        correct_jsonl_path: optional side file with the correct ids per ``idx``.
            It may also carry a precomputed group ``"matrix"``.
        wrong_jsonl_path: optional side file with the wrong ids per ``idx``,
            in the same format as ``correct_jsonl_path``.
        similarity_threshold_0to100: merge threshold in percent.
        linkage: forwarded to the clustering helpers.
        compute_full_tree: forwarded to the clustering helpers.

    Returns:
        Path of the written JSONL. It is the input path with its last suffix
        removed and ``_cluster_cw.jsonl`` appended, with ``/sim_logs/`` replaced
        by ``/sim_clusters/``.
    """
    t = float(similarity_threshold_0to100) / 100.0

    in_path = Path(all_jsonl_path)
    out_path_str = in_path.with_suffix("").as_posix() + "_cluster_cw.jsonl"
    out_path = Path(out_path_str.replace("/sim_logs/", "/sim_clusters/"))
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Index the side files once instead of rescanning them for every sample.
    correct_offsets = _index_jsonl_offsets(correct_jsonl_path)
    wrong_offsets = _index_jsonl_offsets(wrong_jsonl_path)

    with in_path.open("r", encoding="utf-8") as fin:
        total_lines = sum(1 for line in fin if line.strip())

    with (
        in_path.open("r", encoding="utf-8") as fin,
        out_path.open("w", encoding="utf-8") as fout,
    ):
        for line in tqdm.tqdm(fin, total=total_lines, desc="Processing samples"):
            if not line.strip():
                continue
            obj = json.loads(line)
            idx = obj.get("idx")
            M = obj.get("matrix")
            if M is None or idx is None:
                continue

            S = np.asarray(M, dtype=np.float32)
            if S.ndim != 2 or S.shape[0] != S.shape[1] or S.shape[0] < 2:
                continue
            n = int(S.shape[0])

            labels_all, clusters_all_0b, dist_th_all, smax_all = _cluster_from_matrix(
                S,
                threshold_similarity_0to1=t,
                linkage=linkage,
                compute_full_tree=compute_full_tree,
            )
            clusters_all_1b = [[int(i + 1) for i in members] for members in clusters_all_0b]

            correct_matrix, correct_indices = _read_side_record(
                correct_jsonl_path, correct_offsets, idx
            )
            wrong_matrix, wrong_indices = _read_side_record(
                wrong_jsonl_path, wrong_offsets, idx
            )

            # Keep only in-range ids, deduplicated; correct takes precedence over wrong.
            corr_set = {i for i in correct_indices if 1 <= i <= n}
            wrng_set = {i for i in wrong_indices if 1 <= i <= n} - corr_set
            corr_1b = sorted(corr_set)
            wrng_1b = sorted(wrng_set)
            unlabeled_1b = sorted(set(range(1, n + 1)) - corr_set - wrng_set)

            correct_clusters, correct_info = _cluster_group(
                S, corr_1b, correct_matrix, correct_indices, t, linkage, compute_full_tree
            )
            wrong_clusters, wrong_info = _cluster_group(
                S, wrng_1b, wrong_matrix, wrong_indices, t, linkage, compute_full_tree
            )
            unlabeled_clusters, unlabeled_info = _cluster_group(
                S, unlabeled_1b, None, [], t, linkage, compute_full_tree
            )

            final_clusters_labeled_1based = correct_clusters + wrong_clusters
            final_clusters_all_1based = final_clusters_labeled_1based + unlabeled_clusters

            # Position m-1 holds the (0-based) cluster number of sample id m in clusters_all_1b.
            assignment_all_1b = [None] * n
            for cid, members_1b in enumerate(clusters_all_1b):
                for m in members_1b:
                    assignment_all_1b[m - 1] = cid

            result = {
                "idx": idx,
                "n_samples": n,
                "n_clusters_all": len(clusters_all_1b),
                "labels_all_0based": labels_all,
                "clusters_all_1based": clusters_all_1b,
                "cluster_assignment_all_1based": assignment_all_1b,
                "correct_clusters_1based": correct_clusters,
                "wrong_clusters_1based": wrong_clusters,
                "unlabeled_clusters_1based": unlabeled_clusters,
                "num_correct_clusters": len(correct_clusters),
                "num_wrong_clusters": len(wrong_clusters),
                "num_unlabeled_clusters": len(unlabeled_clusters),
                "final_n_clusters": len(final_clusters_labeled_1based),
                "final_clusters_1based": final_clusters_labeled_1based,
                "final_n_clusters_including_unlabeled": len(final_clusters_all_1based),
                "final_clusters_all_1based": final_clusters_all_1based,
                "linkage": linkage,
                "similarity_threshold": t,
                "distance_threshold_all": dist_th_all,
                "smax_all": smax_all,
                "correct_clustering_info": correct_info,
                "wrong_clustering_info": wrong_info,
                "unlabeled_clustering_info": unlabeled_info,
                "correct_indices": corr_1b,
                "wrong_indices": wrng_1b,
                "unlabeled_indices": unlabeled_1b,
            }
            fout.write(json.dumps(result, ensure_ascii=False) + "\n")

    return out_path


if __name__ == "__main__":
    import glob
    import sys
    import traceback

    # MODELS = [
    #     # "Qwen2.5-1.5B-Instruct_Long_CoT",
    #     # "Qwen2.5-1.5B-Instruct_olympiads_qwq_math_checkpoint-7500",
    #     # "Qwen2.5-1.5B-Instruct_orca-math_checkpoint-7500",
    #     # "Qwen-2.5-1.5B-SimpleRL-Zoo",
    #     # "Qwen2.5-3B-Instruct_Long_CoT",
    #     # "Qwen2.5-3B-Instruct_olympiads_qwq_math_checkpoint-7500",
    #     # "Qwen2.5-3B-Instruct_orca-math_checkpoint-7500",
    #     # "Qwen-2.5-3B-SimpleRL-Zoo",
    #     # "Qwen2.5-7B-Instruct_Long_CoT",
    #     # "Qwen2.5-7B-Instruct_olympiads_qwq_math_checkpoint-7500",
    #     # "Qwen2.5-7B-Instruct_orca-math_checkpoint-7500",
    #     # "Qwen-2.5-7B-SimpleRL-Zoo",
    #     # "Qwen2.5-1.5B-Instruct",
    #     # "Qwen2.5-3B-Instruct",
    #     # "Qwen2.5-7B-Instruct",
    #     # "DeepSeek-R1-Distill-Qwen-1.5B",
    #     # "Nemotron-Research-Reasoning-Qwen-1.5B",
    #     # "Qwen-2.5-Math-7B-SimpleRL-Zero",
    #     "Qwen-2.5-Math-7B",
    #     "Qwen-2.5-7B-SimpleRL-Zoo",
    #     # "Qwen2.5-Math-7B-Instruct",
    #     # "Llama-3.1-8B-Instruct",
    #     # "Llama-3.1-8B-Instruct_Long_CoT",
    #     # "Llama-3.1-8B-SimpleRL-Zoo",
    #     "Qwen2.5-7B",
    #     "DeepSeek-R1-Distill-Qwen-7B",
    #     "Qwen2.5-Math-7B-Oat-Zero",
    #     "AceReason-Nemotron-7B",
    #     "AceReason-Nemotron-1.1-7B",
    #     "Qwen2.5-14B",
    #     "DeepSeek-R1-Distill-Qwen-14B",
    #     "Qwen-2.5-14B-SimpleRL-Zoo",
    #     "AceReason-Nemotron-14B",
    #     # "Qwen2.5-32B",
    #     # "DeepSeek-R1-Distill-Qwen-32B",
    #     # "Qwen-2.5-32B-SimpleRL-Zoo",
    # ]
    # MODELS = [
    #     "Qwen2.5-14B",
    #     "Qwen-2.5-14B-SimpleRL-Zoo",
    #     "DeepSeek-R1-Distill-Qwen-14B",
    #     "AceReason-Nemotron-14B",
    # ]
    # MODELS=[
    #     "Qwen2.5-Math-1.5B",
    #     "Qwen2.5-Math-1.5B-Oat-Zero",
    #     "DeepSeek-R1-Distill-Qwen-1.5B",
    #     "Nemotron-Research-Reasoning-Qwen-1.5B",
    # ]
    # MODELS = [
    #     # "Qwen2.5-7B",
    #     # "Qwen-2.5-7B-SimpleRL-Zoo",
    #     # "AceReason-Nemotron-1.1-7B",
    #     # "Llama-3.1-8B",
    #     # "Llama-3.1-8B-SimpleRL-Zoo",
    #     "DeepSeek-R1-Distill-Llama-8B",
    # ]
    MODELS = [
        "Qwen2.5-Math-7B",
        "Qwen2.5-Math-7B-Oat-Zero",
        "DeepSeek-R1-Distill-Qwen-7B",
        "AceReason-Nemotron-7B",
    ]

    DATAS = [
        # "minerva_math",
        # "olympiadbench",
        # "math500",
        "aime24",
        # "aime25",
        # "amc23",
    ]

    # NOTE: placeholder; point this at the directory holding the *_sim_matrix_all.jsonl files.
    base_dir = "XXXX"
    all_suffix = "_all.jsonl"

    for model in MODELS:
        for data in DATAS:
            pattern = f"{base_dir}/{model}_{data}_*_sim_matrix_all.jsonl"
            # Sorted so that runs process files in a reproducible order.
            matching_files = sorted(glob.glob(pattern))
            if not matching_files:
                print(f"No files match {pattern}")
                continue

            for all_jsonl_path in matching_files:
                # The glob pattern guarantees this suffix; swap only the suffix, not
                # any other occurrence of it elsewhere in the path.
                stem = all_jsonl_path[: -len(all_suffix)]
                correct_jsonl_path = stem + "_correct.jsonl"
                wrong_jsonl_path = stem + "_wrong.jsonl"

                print(f"Processing: {all_jsonl_path}")
                try:
                    out = save_clusters_split_by_correct_wrong(
                        all_jsonl_path=all_jsonl_path,
                        correct_jsonl_path=correct_jsonl_path
                        if Path(correct_jsonl_path).exists()
                        else None,
                        wrong_jsonl_path=wrong_jsonl_path
                        if Path(wrong_jsonl_path).exists()
                        else None,
                        similarity_threshold_0to100=50,
                        linkage="average",
                        compute_full_tree="auto",
                    )
                    print(f"  Written: {out}")
                except Exception:
                    # Batch driver: keep going with other files, but report which
                    # input failed and the full traceback.
                    print(f"  Error while processing {all_jsonl_path}:", file=sys.stderr)
                    traceback.print_exc()
                print()
