"""
Hierarchical Clustering for Box Embeddings

Implements agglomerative hierarchical clustering driven by the volume increase
incurred when merging boxes. Each merged cluster is represented by the join
(smallest enclosing box) of its children's boxes, yielding a hierarchical tree.
"""

import argparse
import heapq
import json
import math
import os
import pickle
import sys
from typing import Dict, List, Optional, Set, Tuple

import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from enhanced_cluster_visualizer import \
    visualize_clustering_tree_enhanced  # noqa

# sys.path.insert(0, os.path.join(sys.path[0], ".."))
sys.path.append("..")
from box.box_wrapper import BoxTensor, CenterDeltaBoxTensor
from box.utils import log1mexp
from box_sentence_trainer import MLPHead  # noqa


class Cluster:
    """
    A single node of the hierarchical clustering tree.

    Attributes:
        cluster_id: Identifier assigned by whoever creates the cluster
        box: BoxTensor bounding every member of this cluster
        box_tensor: Alias of ``box`` (the very same object), kept for convenience
        level: Merge level; 0 for leaves, larger for clusters produced by merges
        members: Set of original item indices covered by this cluster
        parent: The cluster this one was merged into, or None
        children: The clusters merged to form this one (empty for leaves)
        is_active: True while the cluster can still take part in a merge
    """

    def __init__(
        self,
        cluster_id: int,
        box: BoxTensor,
        level: int = 0,
        members: Optional[Set[int]] = None,
        parent: Optional["Cluster"] = None,
        children: Optional[List["Cluster"]] = None,
    ):
        self.cluster_id = cluster_id
        self.box = box
        # Same object as ``self.box``; both names are used by callers.
        self.box_tensor = self.box
        self.level = level
        # Build fresh containers per instance instead of sharing defaults.
        self.members = set() if members is None else members
        self.parent = parent
        self.children = [] if children is None else children
        # Cleared once the cluster is merged into a parent.
        self.is_active = True

    def __repr__(self):
        cid, lvl, mem = self.cluster_id, self.level, self.members
        return f"Cluster(id={cid}, level={lvl}, members={mem})"


class HierarchicalClustering:
    """
    Agglomerative hierarchical clustering for box embeddings.

    Algorithm:
    i)   Initialize clusters C_1, C_2, ..., C_n, one per input box
    ii)  Build a min-heap of merge costs ("volume increases") for all pairs
    iii) Merge the pair with the smallest cost; the merged box is the join
         (smallest enclosing box) of the two boxes
    iv)  Push costs for the new cluster paired with every remaining active cluster
    v)   Repeat until a single cluster remains

    Heap entries that reference already-merged clusters are not removed
    eagerly; they are skipped when popped.
    """

    def __init__(self, boxes: List[BoxTensor]):
        """
        Initialize hierarchical clustering.

        Args:
            boxes: List of BoxTensor objects to cluster
        """
        self.boxes = boxes
        self.n = len(boxes)
        self._reset_state()

    def _reset_state(self):
        """Clear all clusters and history so that ``fit`` always starts fresh."""
        self.clusters: Dict[int, Cluster] = {}
        self.next_cluster_id = 0
        # Entries: (level, cluster_i_id, cluster_j_id, new_cluster_id, volume_increase)
        self.merge_history: List[Tuple[int, int, int, int, float]] = []

    def _create_cluster(
        self,
        box: BoxTensor,
        level: int = 0,
        members: Optional[Set[int]] = None,
        children: Optional[List[Cluster]] = None,
    ) -> Cluster:
        """Create a new cluster with the next unique ID and register it."""
        cluster = Cluster(
            cluster_id=self.next_cluster_id,
            box=box,
            level=level,
            members=members,
            children=children,
        )
        self.clusters[self.next_cluster_id] = cluster
        self.next_cluster_id += 1
        return cluster

    def _helper_subtract_fixed(self, i, j):
        """Compute log(e^i - e^j) via ``log1mexp``.

        Written as ``i + log1mexp(j - i)``, which avoids exponentiating ``i``
        directly. Meaningful when ``j <= i``; the result for ``j > i`` depends
        on how ``box.utils.log1mexp`` treats positive arguments.

        Args:
            i: Log of the minuend
            j: Log of the subtrahend
        """
        subtraction = i + log1mexp(j - i)
        return subtraction

    def _helper_subtract(self, i, j):
        """Compute log(e^i - e^j) as ``log(exp(i - j) - 1) + j``.

        Numerically less robust than ``_helper_subtract_fixed`` (``exp(i - j)``
        can overflow). Not called elsewhere in this class; kept for
        backward compatibility.

        Args:
            i: Log of the minuend
            j: Log of the subtrahend
        """
        subtraction = torch.log(torch.exp(i - j) - 1) + j
        return subtraction

    def _calculate_volume_increase(
        self, cluster_i: Cluster, cluster_j: Cluster
    ) -> float:
        """
        Calculate the log-space cost of merging two clusters.

        The cost is the volume that the join box adds beyond the union of the
        two boxes:

            log( V(join(i, j)) - V(i ∪ j) ),
            V(i ∪ j) = V(i) + V(j) - V(i ∩ j)

        Args:
            cluster_i: First cluster
            cluster_j: Second cluster

        Returns:
            The log of the volume increase as a Python float. A NaN result is
            mapped to ``-inf`` so that it cannot corrupt the heap ordering.
        """
        log_vol_i = cluster_i.box.log_soft_volume()
        log_vol_j = cluster_j.box.log_soft_volume()

        # Log-volume of the join (smallest box enclosing both). The positional
        # False flag's meaning is defined by BoxTensor.join_volume.
        log_join_volume = cluster_i.box.join_volume(cluster_j.box, False)

        # log(V(i) + V(j))
        log_sum_volume = torch.logaddexp(log_vol_i, log_vol_j)
        log_intersection_volume = cluster_i.box.intersection(
            cluster_j.box, intersection_temp=0.001
        ).log_soft_volume()
        # Inclusion-exclusion: log(V(i) + V(j) - V(i ∩ j))
        log_union_volume = self._helper_subtract_fixed(
            log_sum_volume, log_intersection_volume
        )

        volume_increase = self._helper_subtract_fixed(
            log_join_volume, log_union_volume
        ).item()

        # NaN compares False with everything, which silently breaks heapq's
        # invariant. NaN here is assumed to come from a non-positive soft-volume
        # difference (an increase of at most zero, i.e. log 0 = -inf) — TODO
        # confirm against box.utils.log1mexp.
        if math.isnan(volume_increase):
            return float("-inf")
        return volume_increase

    def _initialize_clusters(self):
        """
        Step i: Initialize all clusters C_1, C_2, ..., C_n

        Each box becomes a cluster at level 0 with a single member (its index).
        """
        for idx, box in enumerate(self.boxes):
            self._create_cluster(box=box, level=0, members={idx})

    def _build_initial_priority_queue(self) -> List[Tuple[float, int, int]]:
        """
        Step ii: Compute the merge cost for every unordered pair of active clusters.

        Returns:
            A min-heap with n*(n-1)/2 entries (volume_increase, cluster_i_id, cluster_j_id)
        """
        active_ids = [cid for cid, c in self.clusters.items() if c.is_active]

        priority_queue: List[Tuple[float, int, int]] = []
        for pos, id_i in enumerate(tqdm(active_ids)):
            cluster_i = self.clusters[id_i]
            for id_j in active_ids[pos + 1 :]:
                cost = self._calculate_volume_increase(cluster_i, self.clusters[id_j])
                priority_queue.append((cost, id_i, id_j))

        # One O(m) heapify instead of m O(log m) pushes; pop order is the same
        # because every entry is a distinct tuple.
        heapq.heapify(priority_queue)
        return priority_queue

    def _merge_clusters(
        self, cluster_i: Cluster, cluster_j: Cluster, volume_increase: float
    ) -> Cluster:
        """
        Step iii: Create a new cluster by merging two clusters.

        Args:
            cluster_i: First cluster to merge
            cluster_j: Second cluster to merge
            volume_increase: The merge cost recorded in the history

        Returns:
            The newly created merged cluster
        """
        # The merged box is the join (smallest enclosing box) of both boxes.
        union_box = cluster_i.box.join(cluster_j.box)

        # New level is max of the two + 1
        new_level = max(cluster_i.level, cluster_j.level) + 1

        new_members = cluster_i.members | cluster_j.members

        new_cluster = self._create_cluster(
            box=union_box,
            level=new_level,
            members=new_members,
            children=[cluster_i, cluster_j],
        )

        cluster_i.parent = new_cluster
        cluster_j.parent = new_cluster

        # Deactivated clusters make their stale heap entries skippable.
        cluster_i.is_active = False
        cluster_j.is_active = False

        self.merge_history.append(
            (
                new_level,
                cluster_i.cluster_id,
                cluster_j.cluster_id,
                new_cluster.cluster_id,
                volume_increase,
            )
        )

        return new_cluster

    def _update_priority_queue(
        self, priority_queue: List[Tuple[float, int, int]], new_cluster: Cluster
    ):
        """
        Step iv: Push merge costs for the new cluster.

        Adds one entry for the new cluster paired with each other active cluster.

        Args:
            priority_queue: The min-heap to update in place
            new_cluster: The newly created cluster
        """
        for other_cluster in self.clusters.values():
            if not other_cluster.is_active or other_cluster is new_cluster:
                continue
            vol_increase = self._calculate_volume_increase(new_cluster, other_cluster)
            heapq.heappush(
                priority_queue,
                (vol_increase, new_cluster.cluster_id, other_cluster.cluster_id),
            )

    def _pop_next_valid_pair(
        self, priority_queue: List[Tuple[float, int, int]]
    ) -> Optional[Tuple[float, Cluster, Cluster]]:
        """Pop heap entries until one whose two clusters are both still active.

        Returns:
            (volume_increase, cluster_i, cluster_j), or None if the heap runs out.
        """
        while priority_queue:
            volume_increase, id_i, id_j = heapq.heappop(priority_queue)
            cluster_i = self.clusters[id_i]
            cluster_j = self.clusters[id_j]
            if cluster_i.is_active and cluster_j.is_active:
                return volume_increase, cluster_i, cluster_j
        return None

    def fit(self) -> Dict:
        """
        Perform hierarchical clustering on the input boxes.

        Executes steps i-v of the algorithm and returns the clustering results.
        Any state from a previous ``fit`` call is discarded first.

        Returns:
            Dictionary containing:
                - merge_history: List of (level, cluster_i_id, cluster_j_id, new_cluster_id, volume_increase)
                - levels: Dict mapping level -> list of cluster IDs at that level
                - cluster_members: Dict mapping cluster_id -> set of original prompt indices
                - final_tree: The root cluster of the tree (None for empty input)
                - all_clusters: Dict of all clusters created (including intermediate ones)
        """
        self._reset_state()

        # Step i: Initialize clusters
        print("Creating initial clusters")
        self._initialize_clusters()

        # Step ii: Build initial priority queue
        print("Building the priority queue")
        priority_queue = self._build_initial_priority_queue()

        # Steps iii-v: each merge deactivates two clusters and adds one, so the
        # number of active clusters drops by exactly one per merge.
        num_active = self.n
        total_merges = max(self.n - 1, 0)
        with tqdm(total=total_merges, desc="Clustering", unit="merge") as pbar:
            while num_active > 1:
                pair = self._pop_next_valid_pair(priority_queue)
                if pair is None:
                    # No more valid merges (shouldn't happen)
                    break
                volume_increase, cluster_i, cluster_j = pair

                # Step iii: Merge the clusters with smallest volume increase
                new_cluster = self._merge_clusters(
                    cluster_i, cluster_j, volume_increase
                )
                num_active -= 1

                # Step iv: Recalculate with new cluster
                self._update_priority_queue(priority_queue, new_cluster)

                pbar.update(1)

        return self._build_output()

    def _build_output(self) -> Dict:
        """
        Create the output structure with merge history, levels, and cluster members.

        Returns:
            Dictionary with clustering results (see ``fit``)
        """
        # The root is the only active cluster left (none for empty input).
        active_clusters = [c for c in self.clusters.values() if c.is_active]
        final_tree = active_clusters[0] if active_clusters else None

        levels: Dict[int, List[int]] = {}
        for cluster in self.clusters.values():
            levels.setdefault(cluster.level, []).append(cluster.cluster_id)

        cluster_members = {
            cluster.cluster_id: cluster.members for cluster in self.clusters.values()
        }

        return {
            "merge_history": self.merge_history,
            "levels": levels,
            "cluster_members": cluster_members,
            "final_tree": final_tree,
            "all_clusters": self.clusters,
        }

    @staticmethod
    def save_result(result: Dict, filepath: str):
        """
        Save clustering result to a file using pickle.

        Args:
            result: Clustering result dictionary from fit()
            filepath: Path to save the result (e.g., 'clustering_result.pkl')
        """
        with open(filepath, "wb") as f:
            pickle.dump(result, f)
        print(f"Clustering result saved to: {filepath}")

    @staticmethod
    def load_result(filepath: str) -> Dict:
        """
        Load clustering result from a file.

        Security: ``pickle.load`` can execute arbitrary code. Only load files
        you produced yourself or otherwise trust.

        Args:
            filepath: Path to the saved result file

        Returns:
            Clustering result dictionary
        """
        with open(filepath, "rb") as f:
            result = pickle.load(f)
        print(f"Clustering result loaded from: {filepath}")
        return result


def hierarchical_clustering(boxes: List[BoxTensor]) -> Dict:
    """
    Cluster ``boxes`` agglomeratively and return the fitted result.

    Shorthand for ``HierarchicalClustering(boxes).fit()``.

    Args:
        boxes: The BoxTensor objects to cluster, one per item

    Returns:
        The result dictionary produced by ``HierarchicalClustering.fit()``
    """
    return HierarchicalClustering(boxes).fit()


def print_cluster_tree(
    cluster,
    indent: int = 0,
    max_members: int = 5,
    max_depth: Optional[int] = None,
):
    """
    Pretty print a cluster tree to the console (text-based visualization).

    Nodes are printed in pre-order (parent before children, children in list
    order), each indented two spaces per depth. The traversal uses an explicit
    stack, so deep, chain-like trees do not hit Python's recursion limit.

    Args:
        cluster: Root node to print; needs ``cluster_id``, ``level``,
            ``members`` and ``children``. ``None`` (e.g. the ``final_tree`` of
            an empty clustering) prints ``(empty tree)``.
        indent: Indentation level of the root node
        max_members: Maximum number of members to show before truncating
        max_depth: If given, do not print nodes deeper than this many levels
            below ``cluster`` (0 prints only the root). ``None`` prints all.

    Example:
        >>> result = hierarchical_clustering(boxes)
        >>> print_cluster_tree(result['final_tree'])
    """
    if cluster is None:
        print(f"{'  ' * indent}(empty tree)")
        return

    stack = [(cluster, 0)]
    while stack:
        node, depth = stack.pop()
        prefix = "  " * (indent + depth)

        members = sorted(node.members)
        member_count = len(members)

        # Truncate members if too many
        if member_count <= max_members:
            members_str = f"[{', '.join(map(str, members))}]"
        else:
            shown = ", ".join(map(str, members[:max_members]))
            members_str = f"[{shown}, ... +{member_count - max_members} more]"

        # Icon for node type
        icon = "🍃" if len(node.children) == 0 else "📦"

        print(
            f"{prefix}{icon} Cluster {node.cluster_id} (Level {node.level}) - "
            f"{member_count} member{'s' if member_count != 1 else ''} {members_str}"
        )

        if max_depth is None or depth < max_depth:
            # Push in reverse so the first child is popped (printed) first.
            stack.extend((child, depth + 1) for child in reversed(node.children))


def _process_infinity_instruct(sample: Dict) -> Dict:
    """Process Infinity-Instruct dataset format."""
    anchor = None
    positive = None

    for conversation in sample["conversations"]:
        if conversation["from"] == "human":
            text = conversation["value"]
            anchor = text
        else:  # assistant response
            text = conversation["value"]
            positive = text

    return {"prompt": anchor}


def load_infinity_instruct_dataset(count: int = 100, sample_size: int = 1000) -> List[str]:
    """Load English prompts from the Infinity-Instruct (0625) dataset.

    Args:
        count: Maximum number of prompts to return
        sample_size: Number of leading training rows to download before filtering

    Returns:
        Up to ``count`` prompt strings (last human turn of each conversation)
        from rows whose ``langdetect`` is ``"en"``.
    """
    dataset = load_dataset(
        "BAAI/Infinity-Instruct",
        "0625",
        split=f"train[:{sample_size}]",
    )

    dataset = dataset.map(_process_infinity_instruct)

    # Filter and clean
    dataset = dataset.filter(lambda x: x["langdetect"] == "en").remove_columns(
        ["id", "conversations", "label", "langdetect", "source"]
    )

    # Slicing a `datasets.Dataset` returns a dict of columns (name -> list),
    # not a list of rows, so select the "prompt" column from the slice.
    return dataset[:count]["prompt"]


def load_synthetic_dataset(
    path: str = "../total_list.json",
    pool_size: int = 200,
    max_items: int = 100,
) -> List[str]:
    """Load prompts from the synthetic JSON file (a list of lists of strings).

    The nested lists are flattened and the first ``pool_size`` entries are
    deduplicated in first-occurrence order. Keeping that order (instead of
    going through a ``set``) makes the selection and ordering independent of
    string hash randomization, so item indices are reproducible across runs.

    Args:
        path: Path to the JSON file
        pool_size: Number of leading flattened entries to consider
        max_items: Maximum number of unique prompts to return

    Returns:
        Up to ``max_items`` unique prompts.
    """
    with open(path, "r", encoding="utf-8") as f:
        total_list = json.load(f)

    flattened = [item for group in total_list for item in group]
    unique_in_order = list(dict.fromkeys(flattened[:pool_size]))
    return unique_in_order[:max_items]


def _change_dataset_format(sample, model_name):
    """Transform UltraFeedback sample format."""
    temp = None
    for i in sample["completions"]:
        if i["model"] == model_name:
            temp = i
            break
    assert temp is not None
    sample["anchor"] = sample["instruction"]
    sample["positive"] = temp["response"]
    sample["score"] = temp["overall_score"]
    return sample


def load_ultrafeedback_dataset(count, model_name):
    """Load UltraFeedback prompts and the overall scores of one model's completions.

    Keeps only rows where ``model_name`` is listed in ``models`` and has a
    completion, reshapes them with ``_change_dataset_format``, and returns the
    first ``count`` rows.

    Args:
        count: Maximum number of (prompt, score) pairs to return
        model_name: Model whose completion supplies the score (e.g. "alpaca-7b")

    Returns:
        Tuple ``(prompts, scores)`` of parallel lists with at most ``count`` entries.
    """
    ds = load_dataset("openbmb/UltraFeedback", split="train[:]")

    def has_model_completion(example) -> bool:
        # Always return a bool (never None) so the filter predicate is explicit.
        if model_name not in example["models"]:
            return False
        return any(c["model"] == model_name for c in example["completions"])

    ds = ds.filter(has_model_completion)
    ds = ds.map(lambda sample: _change_dataset_format(sample, model_name))
    dataset = ds.remove_columns(
        [
            "source",
            "instruction",
            "models",
            "completions",
            "correct_answers",
            "incorrect_answers",
        ]
    )

    # Slicing a Dataset returns a dict mapping column name -> list of values.
    head = dataset[:count]
    return head["anchor"], head["score"]


def load_wildbench_dataset(count):
    """Load single-turn prompts from the first ``count`` WildBench v2 test rows.

    Rows whose ``conversation_input`` has more than one turn are dropped, so
    fewer than ``count`` prompts may be returned.

    Args:
        count: Number of leading test rows to load before filtering

    Returns:
        List of prompt strings (``""`` when a turn has no ``content`` key).
    """
    raw = load_dataset(
        "allenai/WildBench",
        "v2",
        split=f"test[:{count}]",
    )
    single_turn = raw.filter(lambda row: len(row["conversation_input"]) == 1)

    prompts = []
    for row in single_turn:
        last_turn = row["conversation_input"][-1]
        prompts.append(last_turn.get("content", ""))

    print(f"Loaded {len(prompts)} prompts")
    return prompts


def main():
    """
    Command-line entry point: load prompts, embed them as boxes, cluster them,
    then save the result and an HTML visualization.

    Outputs are written to ``./all_model_hierarchy/``, which is created if missing.
    """
    parser = argparse.ArgumentParser(
        description="Hierarchical Clustering for Box Embeddings"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="Model whose UltraFeedback completions are used (e.g. alpaca-7b, "
        "llama-2-13b-chat, vicuna-33b); also tags the output file names",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="ultrafeedback",
        choices=["ultrafeedback", "wildbench", "synthetic"],
        help="Prompt source to cluster (default: ultrafeedback)",
    )
    parser.add_argument(
        "--count",
        type=int,
        default=500,
        help="Number of prompts to load; not used by the synthetic loader (default: 500)",
    )
    args = parser.parse_args()
    model_name = args.model_name
    name = args.dataset
    count = args.count

    titles = {
        "ultrafeedback": "UltraFeedback Hierarchical Clustering",
        "wildbench": "WildBench Hierarchical Clustering",
        "synthetic": "Synthetic Hierarchical Clustering",
    }

    # Create the output directory up front so the slow clustering run does not
    # fail at save time because the directory is missing.
    output_dir = "./all_model_hierarchy"
    os.makedirs(output_dir, exist_ok=True)

    print("=" * 60)
    print(f"Hierarchical Clustering for {name} dataset")
    print("=" * 60)

    print(f"\n[1/4] Loading {name} dataset...")

    scores = None  # only UltraFeedback provides per-prompt scores
    if name == "ultrafeedback":
        prompts, scores = load_ultrafeedback_dataset(count, model_name)
    elif name == "wildbench":
        prompts = load_wildbench_dataset(count)
    else:  # "synthetic" (argparse restricts the choices)
        prompts = load_synthetic_dataset()

    print(f"    Loaded {len(prompts)} prompts")
    if not prompts:
        print("    No prompts to cluster; exiting.")
        return

    print("\n[2/4] Encoding prompts with box embeddings...")

    path = "../outputs/models/pretrained_ds50000_box_bs2048_mbs8_lr2e-05_vt1.0_it0.001_linksTrue_new_entailment_dataset_with_sister_with_negative_synth_neg_grad_norm_1.0/"

    # path = "../outputs/models/pretrained_ds50000_box_bs2048_mbs8_lr2e-05_vt1.0_it0.001_linksFalse_new_entailment_dataset_with_sister_with_negative_synth_neg_grad_norm_1.0/"
    # path = "../outputs/models/pretrained_ds50000_vector_bs2048_mbs8_lr2e-05_vt1.0_it0.001_linksTrue_new_entailment_dataset_with_sister_with_negative_synth_neg_grad_norm_1.0/"
    model = SentenceTransformer(path)

    corpus_embeddings = model.encode(
        prompts,
        normalize_embeddings=True,
        show_progress_bar=True,
        batch_size=32,  # batched for memory efficiency
        convert_to_tensor=True,
    )

    # Each embedding vector is split into a box (see CenterDeltaBoxTensor.from_split).
    corpus_box_embeddings = [
        CenterDeltaBoxTensor.from_split(sentence_embedding)
        for sentence_embedding in corpus_embeddings
    ]
    print(f"    Created {len(corpus_box_embeddings)} box embeddings")

    print("\n[3/4] Running hierarchical clustering...")
    result = hierarchical_clustering(corpus_box_embeddings)

    # Print text-based tree
    print("\n    Cluster Tree:")
    print_cluster_tree(result["final_tree"], max_members=3)

    print("\n    Saving clustering result...")
    HierarchicalClustering.save_result(
        result, f"{output_dir}/{name}_clustering_box_{count}_{model_name}.pkl"
    )

    print("\n[4/4] Generating HTML visualization...")

    output_file = (
        f"{output_dir}/{name}_cluster_visualization_normal_{count}_{model_name}.html"
    )

    extra_kwargs = {"scores": scores} if scores is not None else {}
    visualize_clustering_tree_enhanced(
        clustering_result=result,
        output_path=output_file,
        title=titles[name],
        item_names=prompts,  # Full, untruncated prompts
        **extra_kwargs,
    )

    print("\n" + "=" * 60)
    print("✓ Complete!")
    print(f"  Open '{output_file}' in your browser to view the interactive tree.")
    print("=" * 60)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
