"""
Dataset Mix Visualization

This script creates HTML visualizations of box embeddings from multiple datasets,
with each dataset identified by a different color. It supports dynamically adding
and combining datasets from various sources.
"""

import argparse
import itertools
import json
import os
import random
import sys
from collections import Counter
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence, Tuple

import torch
from datasets import load_dataset, load_from_disk
from sentence_transformers import SentenceTransformer

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from box_viz_utils import (BaseBoxEmbeddingOptimizer, BaseConfig,
                           BaseDatasetLoader, CorrelationCalculator,
                           HTMLGenerator)

from box.box_wrapper import CenterDeltaBoxTensor, CenterScalarDeltaBoxTensor
from box_sentence_trainer import MLPHead  # noqa
from box_similarity import similarity_function, similarity_function_entailment


@dataclass
class DatasetEntry:
    """Normalized dataset entry with consistent structure.

    Every ``load_and_normalize_*`` loader in this file converts its raw rows
    into this shape, so the visualizer can handle all sources as one list.
    """

    # Text of the entry; this is what gets encoded by the sentence model.
    text: str
    # Source-dataset label, copied into each visualization point's "dataset"
    # field for color coding.
    # NOTE(review): the loaders in this file also emit "hierarchical" and
    # "wildbench", which are not in the list below -- confirm how the HTML
    # legend treats those labels.
    dataset: str  # Must match HTML legend: "UltraFeedback", "wild_chat", "alpaca_eval", "other"
    # Per-entry score; only the UltraFeedback loader reads one from the raw
    # row, every other loader passes 0.0.
    score: float = 0.0
    # Loader-specific extras (e.g. "positive" for UltraFeedback, "session_id"
    # for WildBench).
    metadata: Dict = field(default_factory=dict)


@dataclass
class DatasetMixConfig(BaseConfig):
    """Configuration class for dataset mix visualization.

    Adds the two fields below to ``BaseConfig`` (imported from
    ``box_viz_utils``); all other settings read through ``self.config``
    elsewhere in this file are not declared here.
    """

    # Path passed to SentenceTransformer(...) when the texts are encoded.
    model_path: str = "./outputs/models/pretrained_ds50000_box_bs2048_mbs8_lr2e-05_vt1.0_it0.001_linksTrue_new_entailment_dataset_with_sister_with_negative_synth_neg_grad_norm_1.0/"
    # Requested sample count per dataset. parse_arguments() fills the keys
    # "ultrafeedback", "wildchat", "hierarchical" and "wildbench"; a config
    # built without it starts with an empty dict.
    dataset_sizes: Dict[str, int] = field(default_factory=dict)


class DatasetMixLoader(BaseDatasetLoader):
    """Handles loading datasets with normalization to consistent format.

    The ``load_*_dataset`` methods fetch raw data. The
    ``load_and_normalize_*`` methods take the first ``size`` raw items and wrap
    them in ``DatasetEntry`` objects tagged with a source-dataset label.

    ``self.config``, ``self._change_dataset_format`` and
    ``self._preprocess_wild_chat`` are not defined in this class; presumably
    they come from ``BaseDatasetLoader`` -- confirm in ``box_viz_utils``.
    """

    # TODO: Move these raw loaders into the common utilities file.
    def load_ultrafeedback_dataset(self) -> List[Dict]:
        """Load and preprocess UltraFeedback dataset.

        Keeps only rows whose ``models`` field contains
        ``self.config.required_model``, maps ``self._change_dataset_format``
        over them, then drops the raw columns listed below.

        NOTE(review): the value returned is whatever ``remove_columns``
        produces (a ``datasets`` object), not a plain list; the annotation is
        left unchanged so the signature stays the same.
        """
        ds = load_dataset("openbmb/UltraFeedback", split="train")
        ds = ds.filter(lambda x: self.config.required_model in x["models"])
        ds = ds.map(self._change_dataset_format)
        return ds.remove_columns(
            [
                "source",
                "instruction",
                "models",
                "completions",
                "correct_answers",
                "incorrect_answers",
            ]
        )

    def load_wildchat_dataset(self) -> List[str]:
        """Load and preprocess WildChat dataset.

        Reads the first 10000 training rows, keeps rows with ``turn == 1`` and
        ``language == "English"``, maps ``self._preprocess_wild_chat`` over
        them and returns each row's ``"instruction"`` value.
        """
        wild_chat_dataset = (
            load_dataset("allenai/WildChat", split="train[:10000]")
            .filter(lambda x: x["turn"] == 1 and x["language"] == "English")
            .map(self._preprocess_wild_chat)
        )
        return [row["instruction"] for row in wild_chat_dataset]

    def load_hierarchical_dataset(self) -> List[str]:
        """Load hierarchical dataset from total_list.json.

        The file is read from the current working directory and is iterated as
        a list of lists, which is flattened into one list. The result is capped
        at ``config.dataset_sizes["hierarchical"]`` when that key is present;
        without the key the whole flattened list is returned.
        """
        # Explicit encoding so the read does not depend on the platform locale.
        with open("total_list.json", encoding="utf-8") as fp:
            nested_items = json.load(fp)

        # Flatten the hierarchical structure
        flattened = [item for sublist in nested_items for item in sublist]

        # dataset_sizes defaults to an empty dict, so index with .get(): a
        # config that was not built by the CLI must not raise KeyError here.
        # A limit of None leaves the slice unbounded.
        limit = self.config.dataset_sizes.get("hierarchical")
        return flattened[:limit]

    def load_entailment_dataset(self):
        """Load entailment dataset from its fixed on-disk location."""
        return load_from_disk("./data/raw/datasets/new_entailment_dataset_with_sister/")

    def load_wildbench_dataset(self):
        """Load WildBench (config "v2", split "test").

        Keeps only examples whose ``conversation_input`` has exactly one
        element. The filtered dataset is also stored on ``self.dataset``.

        Returns:
            A list of ``{"content": ..., "session_id": ...}`` dicts, one per
            kept example.
        """
        self.dataset = load_dataset(
            "allenai/WildBench",
            "v2",
            split="test",
        )

        self.dataset = self.dataset.filter(
            lambda example: len(example["conversation_input"]) == 1
        )

        # Return both content and session_id to preserve mapping
        return [
            {
                "content": example["conversation_input"][0]["content"],
                "session_id": example["session_id"],
            }
            for example in self.dataset
        ]

    def load_and_normalize_ultrafeedback(self, size: int = 100) -> List[DatasetEntry]:
        """Load UltraFeedback dataset and normalize to DatasetEntry format.

        Args:
            size: Maximum number of entries to return. Zero or a negative
                value yields no entries.

        Returns:
            Entries labelled ``"UltraFeedback"``; ``text`` comes from the row's
            ``"anchor"`` field (presumably produced by
            ``_change_dataset_format`` -- confirm).
        """
        raw_data = self.load_ultrafeedback_dataset()

        # islice stops after `size` rows instead of first converting the whole
        # filtered dataset into a Python list just to slice its head.
        entries = [
            DatasetEntry(
                text=item["anchor"],
                dataset="UltraFeedback",
                score=item.get("score", 0.0),
                metadata={"positive": item.get("positive", "")},
            )
            for item in itertools.islice(raw_data, max(size, 0))
        ]

        print(f"Loaded {len(entries)} entries from UltraFeedback dataset")
        return entries

    def load_and_normalize_wildchat(self, size: int = 100) -> List[DatasetEntry]:
        """Load WildChat dataset and normalize to DatasetEntry format.

        Args:
            size: Number of leading instructions to keep.

        Returns:
            Entries labelled ``"wild_chat"`` with a score of 0.0.
        """
        raw_data = self.load_wildchat_dataset()
        entries = [
            DatasetEntry(text=text, dataset="wild_chat", score=0.0)
            for text in raw_data[:size]
        ]

        print(f"Loaded {len(entries)} entries from WildChat dataset")
        return entries

    def load_and_normalize_hierarchical(self, size: int = 100) -> List[DatasetEntry]:
        """Load hierarchical dataset and normalize to DatasetEntry format.

        Args:
            size: Number of leading items to keep (applied on top of any cap
                already applied by ``load_hierarchical_dataset``).

        Returns:
            Entries labelled ``"hierarchical"`` with a score of 0.0.
        """
        raw_data = self.load_hierarchical_dataset()
        entries = [
            DatasetEntry(
                text=text,
                # Original note: maps to the alpaca_eval color in the HTML
                # legend. NOTE(review): the label sent is "hierarchical";
                # confirm the HTML generator handles it.
                dataset="hierarchical",
                score=0.0,
            )
            for text in raw_data[:size]
        ]

        print(f"Loaded {len(entries)} entries from Hierarchical dataset")
        return entries

    def load_and_normalize_wildbench(self, size: int = 100) -> List[DatasetEntry]:
        """Load WildBench dataset and normalize to DatasetEntry format.

        Args:
            size: Number of leading examples to keep.

        Returns:
            Entries labelled ``"wildbench"`` with a score of 0.0 and the
            example's ``session_id`` stored in ``metadata``.
        """
        raw_data = self.load_wildbench_dataset()
        entries = [
            DatasetEntry(
                text=item["content"],
                # Original note: maps to the "other" color in the HTML legend.
                # NOTE(review): the label sent is "wildbench"; confirm.
                dataset="wildbench",
                score=0.0,
                metadata={"session_id": item.get("session_id", "")},
            )
            for item in raw_data[:size]
        ]

        print(f"Loaded {len(entries)} entries from WildBench dataset")
        return entries


class DatasetMixOptimizer(BaseBoxEmbeddingOptimizer):
    """Box-embedding optimizer used for mixed datasets.

    Adds nothing to ``BaseBoxEmbeddingOptimizer``; every method used by the
    visualizer is inherited unchanged.
    """


class DatasetMixVisualizer:
    """Main class that orchestrates multi-dataset visualization with color coding.

    Typical use: construct with a config, call one or more ``add_*_dataset``
    methods (each returns ``self`` so calls can be chained), then call
    ``run`` to encode the texts, fit 2-D boxes and write the HTML file.
    """

    # File that ``run`` writes to unless the caller passes another name.
    DEFAULT_OUTPUT_FILENAME = (
        "dataset_mix_visualization_tsne_large_diff_setting_diff.html"
    )

    def __init__(self, config: DatasetMixConfig):
        """Seed the RNGs and create the loader, optimizer and correlation helper.

        Args:
            config: Run configuration. Note that ``config.num_epochs`` is
                overwritten with 5000 on the object that is passed in.
        """
        self.config = config
        random.seed(config.random_seed)
        torch.manual_seed(config.random_seed)

        # NOTE(review): hard-coded override kept from the original; it
        # replaces whatever num_epochs the caller configured.
        config.num_epochs = 5000

        self.data_loader = DatasetMixLoader(config)
        self.optimizer = DatasetMixOptimizer(config)
        self.correlation_calc = CorrelationCalculator()
        self.datasets: List[DatasetEntry] = []

    def add_ultrafeedback_dataset(self, size: int = 100) -> "DatasetMixVisualizer":
        """Add UltraFeedback dataset to the mix and return ``self``."""
        self.datasets.extend(self.data_loader.load_and_normalize_ultrafeedback(size))
        return self

    def add_wildchat_dataset(self, size: int = 100) -> "DatasetMixVisualizer":
        """Add WildChat dataset to the mix and return ``self``."""
        self.datasets.extend(self.data_loader.load_and_normalize_wildchat(size))
        return self

    def add_hierarchical_dataset(self, size: int = 100) -> "DatasetMixVisualizer":
        """Add hierarchical dataset to the mix and return ``self``."""
        self.datasets.extend(self.data_loader.load_and_normalize_hierarchical(size))
        return self

    def add_wildbench_dataset(self, size: int = 100) -> "DatasetMixVisualizer":
        """Add WildBench dataset to the mix and return ``self``."""
        self.datasets.extend(self.data_loader.load_and_normalize_wildbench(size))
        return self

    def get_combined_data(self) -> Tuple[List[str], List[str], List[float]]:
        """
        Get combined dataset ready for visualization.

        Also prints how many entries each dataset label contributes, in the
        order the labels were first added.

        Returns:
            Tuple of (texts, dataset_labels, scores)

        Raises:
            ValueError: If no dataset has been added yet.
        """
        if not self.datasets:
            raise ValueError("No datasets added! Use add_*_dataset() methods first.")

        texts = [entry.text for entry in self.datasets]
        dataset_labels = [entry.dataset for entry in self.datasets]
        scores = [entry.score for entry in self.datasets]

        print("\nDataset composition:")
        # Counter tallies in one pass and keeps first-seen order, so the
        # report is the same on every run (iterating a set is not).
        for dataset_name, count in Counter(dataset_labels).items():
            print(f"  {dataset_name}: {count} samples")

        return texts, dataset_labels, scores

    @staticmethod
    def _deduplicate(
        texts: List[str], labels: List[str]
    ) -> Tuple[List[str], List[str]]:
        """Drop repeated texts, keeping the label of each text's first occurrence."""
        seen = set()
        unique_texts: List[str] = []
        unique_labels: List[str] = []
        for text, label in zip(texts, labels):
            if text in seen:
                continue
            seen.add(text)
            unique_texts.append(text)
            unique_labels.append(label)
        return unique_texts, unique_labels

    def _target_distributions(self, all_similarity, all_similarity_entailment):
        """Build the (intersection, entailment) target distributions from config flags."""
        cfg = self.config
        opt = self.optimizer

        # Entailment target is computed first, as in the original.
        # NOTE(review): the original had a separate `use_tsne_variance` branch
        # here that made the same call as the default branch, so the two are
        # merged. The intersection side calls get_probability_dist_target for
        # that flag -- confirm whether entailment was meant to as well.
        if cfg.use_flattened_entailment:
            target_prob_entailment = opt.get_probability_dist_flattened(
                all_similarity_entailment, cfg.sigma_n_d
            )
        else:
            target_prob_entailment = opt.get_probability_dist(
                all_similarity_entailment, cfg.sigma_n_d
            )

        # First matching flag wins: flattened, joint, t-SNE variance, default.
        if cfg.use_flattened_intersection:
            build_intersection = opt.get_probability_dist_flattened
        elif cfg.use_joint:
            build_intersection = opt.get_probability_dist_joint
        elif cfg.use_tsne_variance:
            build_intersection = opt.get_probability_dist_target
        else:
            build_intersection = opt.get_probability_dist
        target_prob_intersection = build_intersection(all_similarity, cfg.sigma_n_d)

        return target_prob_intersection, target_prob_entailment

    @staticmethod
    def _build_points(boxes, texts: List[str], labels: List[str]) -> List[Dict]:
        """Turn the 2-D boxes into the point dicts handed to the HTML generator."""
        points = []
        for box, text, label in zip(boxes, texts, labels):
            # z / Z are presumably the box's lower / upper corners -- confirm
            # against box.box_wrapper.
            corner_z = box.z.tolist()
            corner_big_z = box.Z.tolist()
            area = (corner_big_z[0] - corner_z[0]) * (corner_big_z[1] - corner_z[1])
            points.append(
                {
                    "coords": [corner_z, corner_big_z],
                    "prompt": text,
                    # Floor the area so the stored volume is never below 1e-8.
                    "volume": max(area, 1e-8),
                    "dataset": label,  # Dataset identifier for color coding
                }
            )
        return points

    def run(self, output_filename: str = DEFAULT_OUTPUT_FILENAME):
        """Execute the complete visualization pipeline for mixed datasets.

        Steps: combine and deduplicate the added entries, encode the texts,
        compute similarities, fit 2-D box embeddings, print correlation
        statistics, then generate the HTML and write it to disk.

        Args:
            output_filename: Path of the HTML file to write. Defaults to the
                file name this script has always used.

        Returns:
            The generated HTML content, as returned by
            ``HTMLGenerator.create_dataset_visualization``.

        Raises:
            ValueError: If no dataset has been added (from
                ``get_combined_data``).
        """
        banner = "=" * 60
        print(banner)
        print("DATASET MIX VISUALIZATION")
        print(banner)

        # Get combined data (scores are not used by the pipeline below)
        print("\nPreparing combined dataset...")
        texts, dataset_labels, _ = self.get_combined_data()

        # Deduplicate while preserving order (keeps first occurrence's dataset label)
        print("\nDeduplicating texts while preserving dataset labels...")
        unique_texts, unique_labels = self._deduplicate(texts, dataset_labels)
        print(
            f"After deduplication: {len(unique_texts)} unique texts (removed {len(texts) - len(unique_texts)} duplicates)"
        )

        # Load model and encode
        print("\nLoading model and encoding texts...")
        model = SentenceTransformer(self.config.model_path)
        embeddings = model.encode(
            unique_texts, normalize_embeddings=True, show_progress_bar=True
        )
        embeddings = torch.tensor(embeddings)

        # Calculate similarities
        print("\nCalculating similarities...")
        all_similarity = similarity_function(embeddings, embeddings)
        all_similarity_entailment = similarity_function_entailment(
            embeddings, embeddings
        )

        # Initialize 2D embeddings
        print("\nInitializing 2D embeddings...")
        small_dim_embeddings = self.optimizer.initialize_embeddings(
            len(unique_texts), dim=2, nd_embeddings=embeddings
        )

        # Get target probability distributions
        target_prob_intersection, target_prob_entailment = self._target_distributions(
            all_similarity, all_similarity_entailment
        )

        # Optimize embeddings
        print("\nOptimizing 2D embeddings...")
        small_dim_embeddings = self.optimizer.optimize_embeddings(
            small_dim_embeddings, target_prob_intersection, target_prob_entailment
        )

        # Calculate final similarities for correlation
        small_dim_similarity = self.optimizer.intersection_function(
            small_dim_embeddings,
            small_dim_embeddings,
            volume_temp=1,
            intersection_temp=1e-3,
        )

        # Boolean mask that selects every off-diagonal (i != j) pair.
        size = small_dim_similarity.shape[0]
        mask = ~torch.eye(size, dtype=torch.bool, device=small_dim_similarity.device)

        small_dim_similarity_entailment = self.optimizer.entailment_function(
            small_dim_embeddings,
            small_dim_embeddings,
            volume_temp=1,
        )

        # NOTE(review): same computation as before optimization; kept in case
        # the target-distribution helpers modify their input in place --
        # confirm and drop if they do not.
        all_similarity = similarity_function(embeddings, embeddings)
        all_similarity_entailment = similarity_function_entailment(
            embeddings, embeddings
        )

        # Convert to nicer format for correlations: one row per item, with
        # the self-pair removed.
        small_dim_similarity = small_dim_similarity[mask].view(size, size - 1)
        small_dim_similarity_entailment = small_dim_similarity_entailment[mask].view(
            size, size - 1
        )
        all_similarity = all_similarity[mask].view(size, size - 1)
        all_similarity_entailment = all_similarity_entailment[mask].view(size, size - 1)

        # Create visualization points
        print("\nCreating visualization...")
        # Only the 2-D boxes depend on use_square; the high-dimensional boxes
        # use CenterDeltaBoxTensor either way.
        small_box_type = (
            CenterScalarDeltaBoxTensor
            if self.config.use_square
            else CenterDeltaBoxTensor
        )
        center_delta_small = [
            small_box_type.from_split(i) for i in small_dim_embeddings
        ]
        center_delta_large = [CenterDeltaBoxTensor.from_split(i) for i in embeddings]

        # Calculate volumes and correlations
        with torch.no_grad():
            low_dim_volumes = [i.log_soft_volume().item() for i in center_delta_small]
            high_dim_volumes = [i.log_soft_volume().item() for i in center_delta_large]

        correlation_stats = self.correlation_calc.get_correlation_stats(
            low_dim_volumes,
            high_dim_volumes,
            all_similarity,
            small_dim_similarity,
            all_similarity_entailment,
            small_dim_similarity_entailment,
            small_dim_embeddings,
        )
        self.correlation_calc.print_correlation_stats(
            correlation_stats,
            all_similarity_entailment.detach().flatten(),
            small_dim_similarity_entailment.detach().flatten(),
            all_similarity.detach().flatten(),
            small_dim_similarity.detach().flatten(),
        )

        # Create visualization points with dataset labels
        points = self._build_points(center_delta_small, unique_texts, unique_labels)

        # Generate HTML with dataset-based coloring
        html_content = HTMLGenerator.create_dataset_visualization(
            points, "mixed_datasets"
        )

        # Save visualization
        with open(output_filename, "w", encoding="utf-8") as f:
            f.write(html_content)

        print(f"\n{banner}")
        print(f"Visualization saved as '{output_filename}'")
        print(f"Total points visualized: {len(points)}")
        print(f"{banner}")

        return html_content


def parse_arguments(argv: Optional[Sequence[str]] = None) -> DatasetMixConfig:
    """Parse command line arguments and return configuration.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what this function always did;
            passing a list allows programmatic use.

    Returns:
        A ``DatasetMixConfig`` built from the parsed options, with
        ``dataset_sizes`` keyed by "ultrafeedback", "wildchat",
        "hierarchical" and "wildbench".
    """
    parser = argparse.ArgumentParser(
        description="Multi-Dataset Box Visualization with Color Coding"
    )

    # Dataset selection and sizes: (dataset_sizes key, name shown in --help).
    dataset_options = (
        ("ultrafeedback", "UltraFeedback"),
        ("wildchat", "WildChat"),
        ("hierarchical", "Hierarchical"),
        ("wildbench", "WildBench"),
    )
    for key, display_name in dataset_options:
        parser.add_argument(
            f"--{key}_size",
            type=int,
            default=0,
            help=f"Number of samples from {display_name} dataset (0 to skip)",
        )

    # Visualization parameters: (short flag, long flag, default, help text).
    float_options = (
        ("-s2", "--sigma_two_d", 1.0, "Sigma for 2D probability distribution"),
        ("-snd", "--sigma_n_d", 1.0, "Sigma for N-D probability distribution"),
        ("-int_factor", "--intersection_factor", 1.0, "Weight for intersection loss"),
        ("-ent_factor", "--entailment_factor", 0.0, "Weight for entailment loss"),
    )
    for short_flag, long_flag, default, help_text in float_options:
        parser.add_argument(
            short_flag, long_flag, type=float, default=default, help=help_text
        )

    # Boolean switches, all off unless given on the command line.
    switch_options = (
        ("--use_square", "Use square boxes"),
        ("--use_joint", "Use joint probability"),
        ("--use_opposite_entailment", "Use entailment in the opposite direction"),
        ("--use_tsne", "Use t-SNE initialization"),
        ("--use_flattened_entailment", "Use flattened for entailment"),
        ("--use_flattened_intersection", "Use flattened for intersection"),
        ("--use_tsne_variance", "Use t-SNE variance method for target"),
    )
    for flag, help_text in switch_options:
        parser.add_argument(flag, action="store_true", help=help_text)

    parser.add_argument(
        "--model_path",
        # Take the default from the config class so the path literal lives in
        # one place only.
        default=DatasetMixConfig.model_path,
        help="Path to the trained model",
    )
    parser.add_argument(
        "--required_model",
        default="llama-2-13b-chat",
        help="Model name to filter responses (for UltraFeedback)",
    )

    args = parser.parse_args(argv)

    return DatasetMixConfig(
        sigma_two_d=args.sigma_two_d,
        sigma_n_d=args.sigma_n_d,
        intersection_factor=args.intersection_factor,
        entailment_factor=args.entailment_factor,
        use_square=args.use_square,
        use_joint=args.use_joint,
        use_opposite_entailment=args.use_opposite_entailment,
        use_tsne=args.use_tsne,
        use_flattened_intersection=args.use_flattened_intersection,
        use_flattened_entailment=args.use_flattened_entailment,
        use_tsne_variance=args.use_tsne_variance,
        model_path=args.model_path,
        required_model=args.required_model,
        dataset_sizes={
            key: getattr(args, f"{key}_size") for key, _ in dataset_options
        },
    )


def main():
    """Entry point: parse the CLI, add the requested datasets and run.

    A dataset is added only when its configured size is greater than zero.
    If nothing ends up in the mix, a usage hint is printed and the function
    returns without running the visualization.
    """
    config = parse_arguments()
    visualizer = DatasetMixVisualizer(config)

    # (dataset_sizes key, adder) pairs, in the order entries join the mix.
    dataset_adders = (
        ("ultrafeedback", visualizer.add_ultrafeedback_dataset),
        ("wildchat", visualizer.add_wildchat_dataset),
        ("hierarchical", visualizer.add_hierarchical_dataset),
        ("wildbench", visualizer.add_wildbench_dataset),
    )
    for size_key, add_dataset in dataset_adders:
        requested = config.dataset_sizes[size_key]
        if requested > 0:
            add_dataset(requested)

    # Something was added: run the visualization and finish.
    if visualizer.datasets:
        visualizer.run()
        return

    # Nothing was requested: explain how to select datasets.
    print(
        "Error: No datasets specified! Use --ultrafeedback_size, --wildchat_size, etc."
    )
    print("\nExample usage:")
    print(
        "  python dataset_mix_refactored.py --ultrafeedback_size 50 --wildchat_size 30"
    )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
