"""
Vector-Based Hierarchical Clustering

Implements hierarchical clustering using standard vector embeddings with scipy's
linkage method. Uses Euclidean distance and single linkage for similarity measurement.
"""

import argparse
import json
import os
import pickle
import sys
from typing import Dict, List, Optional, Set, Tuple

sys.path.insert(1, os.path.join(sys.path[0], ".."))
import numpy as np
from box_sentence_trainer import MLPHead, VectorEntailmentHead  # noqa
from datasets import load_dataset
from scipy.cluster.hierarchy import linkage
from sentence_transformers import SentenceTransformer

from enhanced_cluster_visualizer import visualize_clustering_tree_enhanced


class Cluster:
    """
    A single node of the hierarchical clustering tree.

    Leaf nodes cover one item; a merged node covers the union of the items of
    the nodes it was built from.

    Attributes:
        cluster_id: Unique integer identifier of this node
        box: Optional box payload (None for vector clustering)
        box_tensor: The same object as ``box``, stored under a second name
        level: Merge level (0 for initial leaf clusters, increases with each merge)
        members: Set of original prompt/item indices covered by this node
        parent: Node this one was merged into (None for the root or unmerged nodes)
        children: Nodes merged to form this one (empty for leaves)
        is_active: True while this node is still available for merging
    """

    def __init__(
        self,
        cluster_id: int,
        box=None,
        level: int = 0,
        members: Optional[Set[int]] = None,
        parent: Optional["Cluster"] = None,
        children: Optional[List["Cluster"]] = None,
    ):
        # Build fresh containers per instance so no two nodes share a set/list.
        if members is None:
            members = set()
        if children is None:
            children = []

        self.cluster_id = cluster_id
        self.box = self.box_tensor = box
        self.level = level
        self.members = members
        self.parent = parent
        self.children = children
        self.is_active = True

    def __repr__(self):
        return "Cluster(id={}, level={}, members={})".format(
            self.cluster_id, self.level, self.members
        )


class VectorHierarchicalClustering:
    """
    Agglomerative hierarchical clustering of embedding vectors via scipy.

    Algorithm:
    1. Initialize one leaf cluster per sample (cluster IDs 0..n-1)
    2. Compute the linkage matrix with scipy (Ward linkage + Euclidean
       distance by default; configurable through ``method`` / ``metric``)
    3. Convert the linkage matrix into a tree of ``Cluster`` objects
       (merged clusters get IDs n..2n-2, in merge order)
    4. Return results compatible with the existing visualizer
    """

    def __init__(
        self,
        vectors: np.ndarray,
        method: str = "ward",
        metric: str = "euclidean",
    ):
        """
        Initialize vector-based hierarchical clustering.

        Args:
            vectors: Array-like of shape (n_samples, embedding_dim).
            method: Linkage method passed to
                ``scipy.cluster.hierarchy.linkage`` (e.g. "ward", "single",
                "complete", "average"). Defaults to "ward".
            metric: Distance metric passed to ``linkage``. scipy only accepts
                "euclidean" for the "ward", "centroid" and "median" methods.
        """
        self.vectors = np.asarray(vectors)
        self.n = len(self.vectors)
        self.method = method
        self.metric = metric
        self._reset_state()

    def _reset_state(self):
        """Clear all clusters and merge history so a fit starts from scratch."""
        self.clusters: Dict[int, Cluster] = {}
        self.next_cluster_id = 0
        # Each entry: (new_level, child_id_i, child_id_j, merged_id, distance)
        self.merge_history: List[Tuple[int, int, int, int, float]] = []

    def _initialize_leaf_clusters(self):
        """Create leaf clusters (level 0) for each sample, with IDs 0..n-1."""
        for idx in range(self.n):
            cluster = Cluster(
                cluster_id=self.next_cluster_id,
                box=None,
                level=0,
                members={idx},
                children=[],
            )
            self.clusters[self.next_cluster_id] = cluster
            self.next_cluster_id += 1

    def _build_tree_from_linkage(self, linkage_matrix: np.ndarray):
        """
        Convert a scipy linkage matrix into a tree of ``Cluster`` objects.

        scipy numbers the original observations 0..n-1 and gives the cluster
        formed in row ``k`` of the matrix the ID ``n + k``. Leaves were
        created with our IDs 0..n-1, so scipy leaf IDs map to ours directly.
        Merged clusters are added to the mapping as they are created.

        Args:
            linkage_matrix: scipy linkage matrix of shape (n-1, 4).
                Each row: [cluster_i, cluster_j, distance, n_samples]
        """
        n = self.n
        scipy_to_our_id = {i: i for i in range(n)}

        for merge_idx, (scipy_i, scipy_j, dist, _n_samples) in enumerate(
            linkage_matrix
        ):
            cluster_i = self.clusters[scipy_to_our_id[int(scipy_i)]]
            cluster_j = self.clusters[scipy_to_our_id[int(scipy_j)]]

            new_id = self.next_cluster_id
            new_level = max(cluster_i.level, cluster_j.level) + 1
            new_cluster = Cluster(
                cluster_id=new_id,
                box=None,
                level=new_level,
                members=cluster_i.members | cluster_j.members,
                children=[cluster_i, cluster_j],
            )
            self.clusters[new_id] = new_cluster

            # Link both children to their parent and retire them from merging.
            for child in (cluster_i, cluster_j):
                child.parent = new_cluster
                child.is_active = False

            scipy_to_our_id[n + merge_idx] = new_id
            self.merge_history.append(
                (
                    new_level,
                    cluster_i.cluster_id,
                    cluster_j.cluster_id,
                    new_id,
                    float(dist),
                )
            )
            self.next_cluster_id += 1

    def _build_output(self) -> Dict:
        """
        Build output structure compatible with visualizer.

        Returns:
            Dictionary containing:
                - merge_history: List of merge operations
                - levels: Dict mapping level -> cluster IDs
                - cluster_members: Dict mapping cluster_id -> member set
                - final_tree: Root cluster (None when there are no samples)
                - all_clusters: Dict of all clusters
        """
        # After the full merge sequence only the root is still active; with a
        # single sample that is its leaf, with no samples there is none.
        final_tree = next((c for c in self.clusters.values() if c.is_active), None)

        levels: Dict[int, List[int]] = {}
        for cluster in self.clusters.values():
            levels.setdefault(cluster.level, []).append(cluster.cluster_id)

        cluster_members = {
            cluster.cluster_id: cluster.members for cluster in self.clusters.values()
        }

        return {
            "merge_history": self.merge_history,
            "levels": levels,
            "cluster_members": cluster_members,
            "final_tree": final_tree,
            "all_clusters": self.clusters,
        }

    def fit(self) -> Dict:
        """
        Perform hierarchical clustering using scipy linkage.

        Safe to call more than once: every call rebuilds the tree from scratch.

        Returns:
            Dictionary with clustering results compatible with visualizer
        """
        # Reset first, so a repeated fit() cannot mix cluster IDs from two runs.
        self._reset_state()

        # Step 1: Initialize leaf clusters
        self._initialize_leaf_clusters()

        # scipy's linkage needs at least two observations; with 0 or 1 samples
        # there is nothing to merge.
        if self.n < 2:
            return self._build_output()

        # Step 2: Compute linkage matrix using scipy
        print("    Computing linkage matrix...")
        linkage_matrix = linkage(self.vectors, method=self.method, metric=self.metric)

        # Step 3: Convert linkage matrix to Cluster tree
        print("    Building cluster tree...")
        self._build_tree_from_linkage(linkage_matrix)

        # Step 4: Build output structure
        return self._build_output()

    @staticmethod
    def save_result(result: Dict, filepath: str):
        """
        Save clustering result to a file using pickle.

        Missing parent directories of ``filepath`` are created.

        Args:
            result: Clustering result dictionary from fit()
            filepath: Path to save the result (e.g., 'clustering_result.pkl')
        """
        parent_dir = os.path.dirname(filepath)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(filepath, "wb") as f:
            pickle.dump(result, f)
        print(f"    Clustering result saved to: {filepath}")

    @staticmethod
    def load_result(filepath: str) -> Dict:
        """
        Load clustering result from a file.

        SECURITY: this uses ``pickle.load``, which can execute arbitrary code.
        Only load files you produced yourself or otherwise trust.

        Args:
            filepath: Path to the saved result file

        Returns:
            Clustering result dictionary
        """
        with open(filepath, "rb") as f:
            result = pickle.load(f)
        print(f"    Clustering result loaded from: {filepath}")
        return result


def _change_dataset_format(sample, model_name):
    """Transform UltraFeedback sample format."""
    temp = None
    for i in sample["completions"]:
        if i["model"] == model_name:
            temp = i
            break
    assert temp is not None
    sample["anchor"] = sample["instruction"]
    sample["positive"] = temp["response"]
    sample["score"] = temp["overall_score"]
    return sample


def load_ultrafeedback_dataset(size, model_name):
    """Load UltraFeedback prompts together with ``model_name``'s overall scores.

    Keeps only rows that contain a completion from ``model_name`` and rewrites
    each row with ``_change_dataset_format``. Returns the first ``size`` rows.

    Args:
        size: Maximum number of prompts to return.
        model_name: Model whose completions are kept (e.g. "alpaca-7b").

    Returns:
        Tuple ``(prompts, scores)`` of two parallel lists.
    """
    ds = load_dataset("openbmb/UltraFeedback", split="train[:]")

    def has_model_completion(x):
        # Always return a real bool. The previous version fell through to an
        # implicit ``None`` when the model was listed but had no completion.
        if model_name not in x["models"]:
            return False
        return any(c["model"] == model_name for c in x["completions"])

    ds = ds.filter(has_model_completion)
    ds = ds.map(lambda sample: _change_dataset_format(sample, model_name))
    dataset = ds.remove_columns(
        [
            "source",
            "instruction",
            "models",
            "completions",
            "correct_answers",
            "incorrect_answers",
        ]
    )

    # Slice once; indexing a Dataset with a slice yields a dict of column lists.
    head = dataset[:size]
    prompts = list(head["anchor"])
    scores = list(head["score"])

    print(f"Loaded {len(prompts)} prompts")
    return prompts, scores


def load_wildbench_dataset(count):
    """Load single-turn prompts from the WildBench v2 test split.

    ``count`` limits the rows fetched *before* multi-turn conversations are
    filtered out, so fewer than ``count`` prompts may be returned.

    Args:
        count: Number of rows to take from the start of the test split.

    Returns:
        List of prompt strings: the ``content`` of each single-turn
        conversation, or "" when that entry has no ``content`` key.
    """
    dataset = load_dataset(
        "allenai/WildBench",
        "v2",
        split=f"test[:{count}]",
    )

    # Keep only conversations whose ``conversation_input`` has exactly one entry.
    dataset = dataset.filter(lambda example: len(example["conversation_input"]) == 1)
    prompts = [
        example["conversation_input"][-1].get("content", "") for example in dataset
    ]
    print(f"Loaded {len(prompts)} prompts")

    return prompts


def load_synthetic_dataset(path="../total_list.json", pool_size=200, max_items=100):
    """Load synthetic prompts from a JSON file containing lists of prompts.

    The file must hold a list of lists. The lists are flattened, the first
    ``pool_size`` entries are de-duplicated (keeping first-seen order), and at
    most ``max_items`` unique prompts are returned.

    Args:
        path: JSON file to read. A relative path resolves against the current
            working directory.
        pool_size: Number of flattened entries considered before de-duplication.
        max_items: Maximum number of unique prompts returned.

    Returns:
        List of unique prompts, in order of first appearance.
    """
    with open(path, "r", encoding="utf-8") as f:
        total_list = json.load(f)

    flattened = [x for xs in total_list for x in xs]
    # dict.fromkeys de-duplicates while preserving order. The old list(set(...))
    # ordered strings by their hash, which is randomized per interpreter run,
    # so a different subset of prompts was kept on each run.
    unique = list(dict.fromkeys(flattened[:pool_size]))
    return unique[:max_items]


def main():
    """
    Run the full pipeline: load prompts, embed them, cluster, save, visualize.

    Command-line arguments:
        --model_name: UltraFeedback model whose completions (and scores) are
            used when the "ultrafeedback" dataset is selected. It is also part
            of the output file names.

    Raises:
        ValueError: If the configured dataset name is unknown or no prompts
            were loaded.
    """
    parser = argparse.ArgumentParser(
        description="Vector-Based Hierarchical Clustering"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="Model name for UltraFeedback dataset (e.g. alpaca-7b, llama-2-13b-chat, vicuna-33b)",
    )
    args = parser.parse_args()
    model_name = args.model_name

    print("=" * 60)
    print("Vector-Based Hierarchical Clustering")
    # VectorHierarchicalClustering is used with its defaults below (Ward linkage).
    print("Using: Ward Linkage + Euclidean Distance")
    print("=" * 60)

    # `size` only tags the output file names; `count` is the number of rows
    # requested from the dataset loader.
    size = 500
    name = "ultrafeedback"
    count = 500

    output_dir = "./all_model_hierarchy"
    # NOTE(review): the path name ("linksFalse_none") suggests a model trained
    # without entailment links, matching the "_no_entailment" output suffix.
    # Alternative checkpoint used previously:
    # "../outputs/models/pretrained_ds50000_vector_bs2048_mbs8_lr2e-05_vt1.0_it0.001_linksTrue_new_entailment_dataset_with_sister_with_negative_synth_neg_grad_norm_1.0"
    model_path = "../outputs/models/pretrained_ds50000_vector_bs2048_mbs8_lr2e-05_vt1.0_it0.001_linksFalse_none_grad_norm_1.0/"

    # [1/4] Load dataset
    print(f"\n[1/4] Loading {name} dataset...")
    # Only the UltraFeedback loader returns per-prompt scores.
    scores = None
    if name == "ultrafeedback":
        prompts, scores = load_ultrafeedback_dataset(count, model_name)
    elif name == "wildbench":
        prompts = load_wildbench_dataset(count)
    elif name == "synthetic":
        prompts = load_synthetic_dataset()
    else:
        raise ValueError(f"Unknown dataset name: {name!r}")

    if not prompts:
        raise ValueError(f"No prompts were loaded from the {name!r} dataset")

    # [2/4] Encode with SentenceTransformer
    print("\n[2/4] Encoding prompts with SentenceTransformer...")
    model = SentenceTransformer(model_path)
    corpus_embeddings = model.encode(
        prompts,
        normalize_embeddings=True,
        show_progress_bar=True,
        batch_size=32,
        convert_to_tensor=False,  # Return numpy array
    )

    print(f"    Created {len(corpus_embeddings)} vector embeddings")
    print(f"    Embedding dimension: {corpus_embeddings.shape[1]}")

    # [3/4] Perform clustering
    print("\n[3/4] Performing hierarchical clustering...")
    clustering = VectorHierarchicalClustering(corpus_embeddings)
    result = clustering.fit()

    print(f"    Total merges: {len(result['merge_history'])}")
    print(f"    Total levels: {len(result['levels'])}")

    # Both the pickle and the HTML file are written here; make sure it exists.
    os.makedirs(output_dir, exist_ok=True)

    print("\n    Saving clustering result...")
    VectorHierarchicalClustering.save_result(
        result,
        f"{output_dir}/{name}_clustering_vector_{size}_{model_name}_no_entailment.pkl",
    )

    # [4/4] Visualize
    print("\n[4/4] Generating HTML visualization...")
    output_file = f"{output_dir}/{name}_cluster_visualization_vector_{size}_{model_name}_no_entailment.html"

    # Pass `scores` only when the dataset provides them.
    # NOTE(review): assumes the visualizer's `scores` parameter is optional; confirm.
    optional_kwargs = {} if scores is None else {"scores": scores}
    visualize_clustering_tree_enhanced(
        clustering_result=result,
        output_path=output_file,
        title="Vector-Based Hierarchical Clustering (Ward Linkage + Euclidean)",
        item_names=prompts,
        **optional_kwargs,
    )

    print("\n" + "=" * 60)
    print("✓ Complete!")
    print(f"  Open '{output_file}' in your browser to view the interactive tree.")
    print("=" * 60)


if __name__ == "__main__":
    main()
