#!/usr/bin/env python3
"""
Box t-SNE Visualization

This script projects sentence-transformer box embeddings of UltraFeedback prompts into
2-D boxes and writes one HTML visualization per model, colored by that model's score.
Command-line flags select among several similarity and target-distribution variants.
"""

import argparse
import json
import os
import random
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
from datasets import load_dataset, load_from_disk
from sentence_transformers import SentenceTransformer

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from box_viz_utils import (BaseBoxEmbeddingOptimizer, BaseConfig,
                           BaseDatasetLoader, CorrelationCalculator,
                           DataProcessor, HTMLGenerator)

from box.box_wrapper import CenterDeltaBoxTensor, CenterScalarDeltaBoxTensor
from box_sentence_trainer import MLPHead  # noqa
from box_similarity import similarity_function, similarity_function_entailment


@dataclass
class Config(BaseConfig):
    """Settings for the box t-SNE visualization.

    Extends ``BaseConfig`` with the number of samples to visualize.
    """

    # Maximum number of prompts taken from the loaded dataset for the plot.
    dataset_size: int = 100


class DatasetLoader(BaseDatasetLoader):
    """Handles loading and preprocessing of datasets for t-SNE visualization."""

    def load_ultrafeedback_dataset_multi_model(
        self, models: List[str]
    ) -> Tuple[List[Dict], Dict[str, List[float]]]:
        """
        Load UltraFeedback prompts that have a completion from every model in ``models``.

        Args:
            models: Non-empty list of model names. The first model's response
                is used as the ``"positive"`` text of each returned row.

        Returns:
            Tuple of:
                - List of dicts with 'anchor' (prompt) and 'positive' (first model's response)
                - Dict mapping each model name to its ``overall_score`` per prompt,
                  index-aligned with the list above

        Raises:
            ValueError: If ``models`` is empty (checked before any download).
        """
        if not models:
            raise ValueError("models must contain at least one model name")

        ds = load_dataset("openbmb/UltraFeedback", split="train")

        # Filter to only include samples that have ALL specified models
        ds = ds.filter(lambda x: all(model in x["models"] for model in models))

        wanted = set(models)
        results = []
        # Dict keys are unique, so duplicate names in `models` collapse here.
        scores_by_model = {model: [] for model in models}

        for sample in ds:
            # Collect the score and response of each requested model.
            model_scores = {}
            model_responses = {}
            for completion in sample["completions"]:
                name = completion["model"]
                if name in wanted:
                    model_scores[name] = completion["overall_score"]
                    # Alternative score: completion["fine-grained_score"] * 2
                    model_responses[name] = completion["response"]

            # Skip prompts where some requested model has no completion.
            if not wanted.issubset(model_scores):
                continue

            results.append(
                {
                    "anchor": sample["instruction"],
                    # Use first model's response
                    "positive": model_responses[models[0]],
                }
            )
            for name, score_list in scores_by_model.items():
                score_list.append(model_scores[name])

        return results, scores_by_model

    def load_ultrafeedback_dataset(self) -> List[Dict]:
        """Load and preprocess the UltraFeedback dataset.

        Keeps samples whose ``models`` column contains ``config.required_model``,
        reformats them with ``_change_dataset_format`` and drops the original
        UltraFeedback columns.

        Note: the value returned is the ``datasets`` object produced by
        ``remove_columns`` (indexable row by row), not a plain ``list``.
        """
        ds = load_dataset("openbmb/UltraFeedback", split="train")
        ds = ds.filter(lambda x: self.config.required_model in x["models"])
        ds = ds.map(self._change_dataset_format)
        train_dataset = ds.remove_columns(
            [
                "source",
                "instruction",
                "models",
                "completions",
                "correct_answers",
                "incorrect_answers",
            ]
        )
        return train_dataset

    def load_wildchat_dataset(self) -> List[str]:
        """Load and preprocess the WildChat dataset.

        Reads the first 10,000 training rows, keeps single-turn English
        conversations, applies ``_preprocess_wild_chat`` and returns each
        row's ``"instruction"`` field.
        """
        wild_chat_dataset = (
            load_dataset("allenai/WildChat", split="train[:10000]")
            .filter(lambda x: x["turn"] == 1 and x["language"] == "English")
            .map(self._preprocess_wild_chat)
        )
        return [i["instruction"] for i in wild_chat_dataset]

    def load_hierarchical_dataset(self) -> List[str]:
        """Load the hierarchical dataset from ``total_list.json``.

        The file (read from the current working directory) holds a list of
        lists. It is flattened one level and truncated to
        ``config.dataset_size`` items.
        """
        with open("total_list.json", encoding="utf-8") as fp:
            general_order_dataset = json.load(fp)

        # Flatten the hierarchical structure
        general_order_dataset_flattened = [
            i for sublist in general_order_dataset for i in sublist
        ]
        return general_order_dataset_flattened[: self.config.dataset_size]

    def load_entailment_dataset(self):
        """Load the entailment dataset saved on disk with ``save_to_disk``."""
        return load_from_disk("./data/raw/datasets/new_entailment_dataset_with_sister/")

    def process_wildbench_dataset(self, train_dataset, target_llm, scorer_llm):
        """Attach WildBench judge scores to rows from ``load_wildbench_dataset``.

        Args:
            train_dataset: Iterable of dicts with ``"content"`` and
                ``"session_id"`` keys.
            target_llm: Currently unused.
            scorer_llm: Currently unused; the results file path is fixed.

        Returns:
            ``(anchor, scores)``: the ``"content"`` of each row, and an
            index-aligned list of integer scores. Rows whose session has no
            result get a score of 0.
        """
        session_ids = [row["session_id"] for row in train_dataset]
        anchor = [row["content"] for row in train_dataset]

        # Load scores from wildbench results
        with open(
            "./skillverse/result_dirs/wild_bench_v2/gpt-wildbench-results.json",
            "r",
            encoding="utf-8",
        ) as f:
            gpt3_5_data = json.load(f)

        # Create mapping from session_id to score
        score_map = {item["session_id"]: int(item["score"]) for item in gpt3_5_data}
        scores = [score_map.get(session_id, 0) for session_id in session_ids]
        return anchor, scores

    def load_wildbench_dataset(self):
        """Load single-message WildBench v2 test conversations.

        The filtered dataset is also stored on ``self.dataset``.

        Returns:
            List of dicts with ``"content"`` (the single user message) and
            ``"session_id"``, so the mapping to WildBench results is preserved.
        """
        self.dataset = load_dataset(
            "allenai/WildBench",
            "v2",
            split="test",
        )

        self.dataset = self.dataset.filter(
            lambda example: len(example["conversation_input"]) == 1
        )

        # Return both content and session_id to preserve mapping
        return [
            {
                "content": i["conversation_input"][0]["content"],
                "session_id": i["session_id"],
            }
            for i in self.dataset
        ]


class TSNEDataProcessor(DataProcessor):
    """Handles data processing specific to t-SNE visualization."""

    @staticmethod
    def prepare_visualization_dataset(
        train_dataset, hierarchical_data, entailment_data, dataset_size: int
    ) -> Tuple[List[str], List[float]]:
        """Prepare the final dataset for visualization.

        Takes the ``"anchor"`` text of the first ``dataset_size`` rows of
        ``train_dataset`` and removes repeated prompts with
        ``DataProcessor.deduplicate_preserve_order``.

        Args:
            train_dataset: Iterable of dicts that each contain an ``"anchor"`` key.
            hierarchical_data: Currently unused; kept for interface compatibility.
            entailment_data: Currently unused; kept for interface compatibility.
            dataset_size: Maximum number of rows of ``train_dataset`` to use.

        Returns:
            ``(visualization_dataset, scores)``, where ``scores`` holds one
            placeholder score (0) per element of ``visualization_dataset``.
        """
        corpus = [row["anchor"] for row in train_dataset][:dataset_size]
        visualization_dataset = DataProcessor.deduplicate_preserve_order(corpus)

        # NOTE: scores are placeholders. If rows carry a "score" field (or scores
        # come from WildBench results keyed by session_id), they must be
        # re-aligned with the *deduplicated* prompts, not taken positionally
        # from train_dataset.
        scores = [0] * len(visualization_dataset)

        return visualization_dataset, scores


class BoxEmbeddingOptimizer(BaseBoxEmbeddingOptimizer):
    """Box embedding optimizer used by the t-SNE visualization.

    Inherits all behavior from ``BaseBoxEmbeddingOptimizer`` unchanged.
    """


class BoxTSNEVisualizer:
    """Main class that orchestrates the box t-SNE visualization.

    Loads UltraFeedback prompts, encodes them with a SentenceTransformer,
    optimizes 2-D box embeddings, reports correlation statistics and writes
    one score-colored HTML file per model.
    """

    # Models whose UltraFeedback scores are visualized when run() gets none.
    # Another set used previously: ("llama-2-13b-chat", "alpaca-7b", "vicuna-33b").
    DEFAULT_MODELS = ("llama-2-13b-chat", "llama-2-7b-chat", "llama-2-70b-chat")
    # SentenceTransformer checkpoint used when run() gets no model_path.
    # Other checkpoints used previously:
    #   ./outputs/models/new_pretrained_with_synth/
    #   ./outputs/models/pretrained_ds50000_box_bs2048_mbs8_lr2e-05_vt1.0_it0.1_linksFalse_new_entailment_dataset_with_sister_with_negative_grad_norm_1.0
    #   ./outputs/models/pretrained_100000_use_simcse_False_new_entailment_corrected_sister_with_mnli_with_negative/
    DEFAULT_MODEL_PATH = "./outputs/models/pretrained_ds50000_box_bs2048_mbs8_lr2e-05_vt1.0_it0.001_linksTrue_new_entailment_dataset_with_sister_with_negative_synth_neg_grad_norm_1.0/"
    # Epoch count forced onto the config before the optimizer is built.
    NUM_EPOCHS = 5000

    def __init__(self, config: Config):
        """Seed the RNGs and build the data loader, optimizer and correlation helper.

        Note: ``config.num_epochs`` is overwritten with ``NUM_EPOCHS``. This
        mutates the caller's config object and happens before the optimizer
        is created.
        """
        self.config = config
        random.seed(config.random_seed)
        torch.manual_seed(config.random_seed)

        config.num_epochs = self.NUM_EPOCHS

        self.data_loader = DatasetLoader(config)
        self.optimizer = BoxEmbeddingOptimizer(config)
        self.correlation_calc = CorrelationCalculator()

    def _target_distributions(self, all_similarity, all_similarity_entailment):
        """Return ``(target_prob_intersection, target_prob_entailment)`` for the optimizer.

        The variant is chosen by the ``use_flattened_entailment``,
        ``use_flattened_intersection``, ``use_joint`` and ``use_tsne_variance``
        config flags. All variants use ``config.sigma_n_d``.
        """
        sigma = self.config.sigma_n_d

        # TODO: Fix this. The flattened variant isn't correct for similarity.
        if self.config.use_flattened_entailment:
            target_prob_entailment = self.optimizer.get_probability_dist_flattened(
                all_similarity_entailment, sigma
            )
        else:
            # NOTE(review): use_tsne_variance previously had its own branch here
            # that also called get_probability_dist. The intersection target uses
            # get_probability_dist_target under that flag; confirm whether the
            # entailment target should as well.
            target_prob_entailment = self.optimizer.get_probability_dist(
                all_similarity_entailment, sigma
            )

        if self.config.use_flattened_intersection:
            target_prob_intersection = self.optimizer.get_probability_dist_flattened(
                all_similarity, sigma
            )
        elif self.config.use_joint:
            target_prob_intersection = self.optimizer.get_probability_dist_joint(
                all_similarity, sigma
            )
        elif self.config.use_tsne_variance:
            target_prob_intersection = self.optimizer.get_probability_dist_target(
                all_similarity, sigma
            )
        else:
            target_prob_intersection = self.optimizer.get_probability_dist(
                all_similarity, sigma
            )

        return target_prob_intersection, target_prob_entailment

    @staticmethod
    def _build_points(req_points, prompts, scores):
        """Build the point dicts passed to ``HTMLGenerator.create_score_visualization``.

        ``req_points[k]`` is ``[min_corner, max_corner]`` of box ``k``. ``prompts``
        and ``scores`` are index-aligned with ``req_points``. The box area is
        clamped to at least 1e-8.
        """
        return [
            {
                "coords": [low, high],
                "prompt": prompts[idx],
                "volume": max((high[0] - low[0]) * (high[1] - low[1]), 1e-8),
                "score": scores[idx],
            }
            for idx, (low, high) in enumerate(req_points)
        ]

    def run(
        self,
        models: Optional[List[str]] = None,
        model_path: Optional[str] = None,
    ):
        """Execute the complete visualization pipeline.

        Args:
            models: Model names whose UltraFeedback scores are visualized; one
                HTML file is written per model. Defaults to ``DEFAULT_MODELS``.
            model_path: SentenceTransformer checkpoint used to encode prompts.
                Defaults to ``DEFAULT_MODEL_PATH``.

        Returns:
            The HTML content generated for the last model in ``models``.

        Raises:
            ValueError: If ``models`` is empty.
        """
        models = list(self.DEFAULT_MODELS if models is None else models)
        if not models:
            raise ValueError("run() needs at least one model name")
        if model_path is None:
            model_path = self.DEFAULT_MODEL_PATH

        print("Loading datasets...")
        train_dataset, scores_by_models = (
            self.data_loader.load_ultrafeedback_dataset_multi_model(models)
        )
        hierarchical_data = self.data_loader.load_hierarchical_dataset()
        entailment_dataset = self.data_loader.load_entailment_dataset()

        print("Preparing visualization dataset...")
        visualization_dataset, _ = TSNEDataProcessor.prepare_visualization_dataset(
            train_dataset,
            hierarchical_data,
            entailment_dataset,
            self.config.dataset_size,
        )

        print(f"Final dataset size: {len(visualization_dataset)}")

        # Deduplication can drop repeated prompts, so position k in
        # visualization_dataset is not necessarily row k of train_dataset /
        # scores_by_models. Remember the first row of each prompt so scores can
        # be looked up by prompt text.
        anchors = [row["anchor"] for row in train_dataset][: self.config.dataset_size]
        first_row_index: Dict[str, int] = {}
        for row_idx, anchor in enumerate(anchors):
            first_row_index.setdefault(anchor, row_idx)

        # Load model and encode
        print("Loading model and encoding texts...")
        model = SentenceTransformer(model_path)
        embeddings = model.encode(
            visualization_dataset, normalize_embeddings=True, show_progress_bar=True
        )
        embeddings = torch.tensor(embeddings[: self.config.dataset_size])

        # Calculate similarities
        print("Calculating similarities...")
        all_similarity = similarity_function(embeddings, embeddings)
        all_similarity_entailment = similarity_function_entailment(
            embeddings, embeddings
        )

        # Initialize embeddings
        print("Initializing embeddings...")
        small_dim_embeddings = self.optimizer.initialize_embeddings(
            len(visualization_dataset), dim=2, nd_embeddings=embeddings
        )

        target_prob_intersection, target_prob_entailment = self._target_distributions(
            all_similarity, all_similarity_entailment
        )

        # Optimize embeddings
        print("Optimizing embeddings...")
        small_dim_embeddings = self.optimizer.optimize_embeddings(
            small_dim_embeddings, target_prob_intersection, target_prob_entailment
        )

        # Pairwise similarities of the optimized 2-D boxes, for correlation.
        small_dim_similarity = self.optimizer.intersection_function(
            small_dim_embeddings,
            small_dim_embeddings,
            volume_temp=1,
            intersection_temp=1e-3,
        )
        small_dim_similarity_entailment = self.optimizer.entailment_function(
            small_dim_embeddings,
            small_dim_embeddings,
            volume_temp=1,
        )

        # High-dimensional similarities are recomputed from `embeddings` instead
        # of reusing the tensors above. NOTE(review): presumably this guards
        # against the target-distribution helpers modifying them in place; confirm.
        all_similarity = similarity_function(embeddings, embeddings)
        all_similarity_entailment = similarity_function_entailment(
            embeddings, embeddings
        )

        # Drop the diagonal (self-similarity) and reshape every matrix to
        # (size, size - 1) for the correlation statistics.
        size = small_dim_similarity.shape[0]
        mask = ~torch.eye(size, dtype=torch.bool, device=small_dim_similarity.device)
        small_dim_similarity = small_dim_similarity[mask].view(size, size - 1)
        small_dim_similarity_entailment = small_dim_similarity_entailment[mask].view(
            size, size - 1
        )
        all_similarity = all_similarity[mask].view(size, size - 1)
        all_similarity_entailment = all_similarity_entailment[mask].view(size, size - 1)

        # Create visualization boxes
        print("Creating visualization...")
        small_box_cls = (
            CenterScalarDeltaBoxTensor if self.config.use_square else CenterDeltaBoxTensor
        )
        center_delta_small = [small_box_cls.from_split(i) for i in small_dim_embeddings]
        # High-dimensional boxes are always CenterDeltaBoxTensor.
        center_delta_large = [CenterDeltaBoxTensor.from_split(i) for i in embeddings]

        # Calculate volumes and correlations
        with torch.no_grad():
            low_dim_volumes = [i.log_soft_volume().item() for i in center_delta_small]
            high_dim_volumes = [i.log_soft_volume().item() for i in center_delta_large]

        correlation_stats = self.correlation_calc.get_correlation_stats(
            low_dim_volumes,
            high_dim_volumes,
            all_similarity,
            small_dim_similarity,
            all_similarity_entailment,
            small_dim_similarity_entailment,
            small_dim_embeddings,
        )
        self.correlation_calc.print_correlation_stats(
            correlation_stats,
            all_similarity_entailment.detach().flatten(),
            small_dim_similarity_entailment.detach().flatten(),
            all_similarity.detach().flatten(),
            small_dim_similarity.detach().flatten(),
        )

        # [min_corner, max_corner] of each 2-D box.
        req_points = [
            [embedding.z.tolist(), embedding.Z.tolist()]
            for embedding in center_delta_small
        ]

        # Generate one HTML file per model.
        html_content = ""
        for model_name in models:
            model_scores = scores_by_models[model_name]
            # Fall back to the positional index if a prompt is not found verbatim.
            aligned_scores = [
                model_scores[first_row_index.get(prompt, idx)]
                for idx, prompt in enumerate(visualization_dataset)
            ]
            points = self._build_points(req_points, visualization_dataset, aligned_scores)

            html_content = HTMLGenerator.create_score_visualization(
                points,
                self.config.required_model,  # required model is not used
            )

            output_filename = f"rectangle_visualization_mixed_ultrafeedback_{model_name}_{self.config.dataset_size}_size_comparison.html"
            with open(output_filename, "w", encoding="utf-8") as f:
                f.write(html_content)

            print(f"Visualization saved as '{output_filename}'")

        return html_content


def parse_arguments(argv: Optional[List[str]] = None) -> Config:
    """Parse command line arguments and return configuration.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, as before.

    Returns:
        A ``Config`` populated from the parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Box t-SNE Visualization")

    parser.add_argument(
        "-s2",
        "--sigma_two_d",
        type=float,
        default=1.0,
        help="Sigma for 2D probability distribution",
    )
    parser.add_argument(
        "-snd",
        "--sigma_n_d",
        type=float,
        default=1.0,
        help="Sigma for N-D probability distribution",
    )
    parser.add_argument(
        "-int_factor",
        "--intersection_factor",
        type=float,
        default=1.0,
        help="Weight for intersection loss",
    )
    parser.add_argument(
        "-ent_factor",
        "--entailment_factor",
        type=float,
        default=0.0,
        help="Weight for entailment loss",
    )
    parser.add_argument(
        "--use_square", action="store_true", help="Use square probability"
    )
    parser.add_argument(
        "--use_joint", action="store_true", help="Use joint probability"
    )
    parser.add_argument(
        "--use_opposite_entailment",
        action="store_true",
        help="Use entailment in the opposite direction",
    )
    parser.add_argument("--use_tsne", action="store_true", help="Use t-SNE centers")
    parser.add_argument(
        "--use_flattened_entailment",
        action="store_true",
        help="Use flattened for entailment",
    )
    parser.add_argument(
        "--use_flattened_intersection",
        action="store_true",
        help="Use flattened for intersection",
    )
    parser.add_argument(
        "--save_location", help="Location where the embeddings will be saved"
    )
    parser.add_argument(
        "--saved_embeddings_location", help="Location of saved embeddings"
    )
    parser.add_argument(
        "--required_model",
        default="llama-2-13b-chat",
        help="Model name to filter responses",
    )
    parser.add_argument(
        "--dataset_size", type=int, default=100, help="Number of samples to visualize"
    )
    parser.add_argument(
        "--use_tsne_variance",
        action="store_true",
        help="The tsne variance method for target",
    )

    args = parser.parse_args(argv)

    return Config(
        sigma_two_d=args.sigma_two_d,
        sigma_n_d=args.sigma_n_d,
        intersection_factor=args.intersection_factor,
        entailment_factor=args.entailment_factor,
        use_square=args.use_square,
        use_joint=args.use_joint,
        use_opposite_entailment=args.use_opposite_entailment,
        use_tsne=args.use_tsne,
        use_flattened_intersection=args.use_flattened_intersection,
        use_flattened_entailment=args.use_flattened_entailment,
        save_location=args.save_location,
        saved_embeddings_location=args.saved_embeddings_location,
        required_model=args.required_model,
        dataset_size=args.dataset_size,
        use_tsne_variance=args.use_tsne_variance,
    )


def main():
    """Entry point: build a Config from the command line and run the visualizer."""
    BoxTSNEVisualizer(parse_arguments()).run()


if __name__ == "__main__":
    main()
