                                                                                   
 
                                                                       
                                                                                        
 
        
                                             
                                       
 
                                                         
                                                            
 
                                                               
                                                       
 
"""
This script performs a comparative analysis of two vector databases, a 'default'
and an 'expanded' one. It fetches prompt embeddings and metadata, cleans the data by
re-classifying outliers, and generates publication-quality comparison plots for
the TMLR journal format.

Key features include:
- Robust data fetching from ChromaDB.
- Efficient, multi-threaded re-classification of prompts against a predefined taxonomy.
- Caching of processed data using Apache Parquet for faster subsequent runs.
- Optional down-sampling for quick testing and iteration.
- Generation of side-by-side t-SNE and category distribution plots.
- Automatic export of plots in both PDF and high-resolution PNG formats.
- Dynamic legend placement to prevent overlap with plot axes.
"""

import logging
import os
import sys
import json
import uuid
import math
import argparse
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple, Set
from concurrent.futures import ThreadPoolExecutor, as_completed

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.manifold import TSNE
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from pydantic import ValidationError, BaseModel, Field
from tqdm import tqdm

                                                                            

              
# Benign prompt categories.
SAFE_CATEGORIES = [
    "information_retrieval",
    "problem_solving",
    "learning",
    "content_creation",
    "ethical_inquiry",
]
# Harmful prompt categories, prefixed s1..s13.
UNSAFE_CATEGORIES = [
    "s1_violent_crimes",
    "s2_non_violent_crimes",
    "s3_sex_related_crimes",
    "s4_child_sexual_exploitation",
    "s5_defamation",
    "s6_specialized_advice",
    "s7_privacy",
    "s8_intellectual_property",
    "s9_indiscriminate_weapons",
    "s10_hate",
    "s11_suicide_self_harm",
    "s12_sexual_content",
    "s13_elections",
]
# The closed taxonomy every prompt is mapped into; anything else is an outlier.
PREDEFINED_CATEGORIES: Set[str] = {*SAFE_CATEGORIES, *UNSAFE_CATEGORIES}

                                                 
class NLPFeatures(BaseModel):
    """Optional NLP-derived features that can be mixed into a prompt record."""
    # Presumably a language-model perplexity score for the prompt — confirm; None when not computed.
    perplexity: Optional[float] = None

class InputPromptRecord(BaseModel):
    """Pydantic schema for a single input prompt.

    NOTE(review): neither this model nor its subclass is instantiated by the
    functions or the main block of this script; confirm whether it is still needed.
    """
    # Raw prompt text (required).
    original_prompt: str
    # A new random UUID4 is generated when no id is supplied.
    prompt_id: uuid.UUID = Field(default_factory=uuid.uuid4)
    # Required integer label. Presumably 0 = safe, 1 = unsafe, matching the
    # mapping the main block applies to the 'label' column — confirm.
    label: int
    # Optional provenance of the prompt.
    source_file: Optional[str] = None
    # Optional category; the analysis expects one of PREDEFINED_CATEGORIES.
    prompt_category: Optional[str] = None

class DatabasePromptRecord(InputPromptRecord, NLPFeatures):
    """Prompt record combining the input fields, NLP features and an optional embedding.

    Presumably mirrors what is stored in the vector database — confirm.
    """
    # Embedding vector for the prompt; None when not available.
    embedding: Optional[List[float]] = None

                                                     
# chromadb is required for every data-loading path, so fail fast at import time.
try:
    import chromadb
except ImportError:
    # sys.exit(str) writes the message to stderr (not stdout) and exits with status 1.
    sys.exit("Error: chromadb is not installed. Please run 'pip install chromadb'")

def get_config():
    """Return a hard-coded configuration so the script can run standalone.

    A new dict is built on every call, so callers may mutate the result freely.

    Returns:
        ``{"vector_database": {"collection_name": ...},
        "clustering": {"cluster_field_name": ...}}``
    """
    vector_database_section = {"collection_name": "fortress_prompts_collection"}
    clustering_section = {"cluster_field_name": "prompt_category"}
    return {
        "vector_database": vector_database_section,
        "clustering": clustering_section,
    }

class VectorStoreInterface(ABC):
    """Abstract contract for the vector-store operations used by this script.

    Concrete stores must implement all three methods; the class itself cannot
    be instantiated.
    """

    @abstractmethod
    def query_similar(
        self,
        embedding: List[float],
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Similarity search for ``embedding``; result semantics are implementation-defined."""

    @abstractmethod
    def assign_clusters_with_weights_to_new_prompt(
        self,
        prompt_embedding: List[float],
        cluster_field_name: str,
        top_k_neighbors: int,
        where_filter: Optional[Dict[str, Any]],
    ) -> List[Tuple[Any, float]]:
        """Return ``(cluster, weight)`` pairs for a single prompt embedding."""

    @abstractmethod
    def batch_assign_clusters(
        self,
        prompt_embeddings: List[List[float]],
        cluster_field_name: str,
        top_k_neighbors: int,
        where_filter: Optional[Dict[str, Any]],
    ) -> List[str]:
        """Return one cluster name per embedding in ``prompt_embeddings``."""

class ChromaVectorStore(VectorStoreInterface):
    """Minimal ChromaDB-backed vector store used by this analysis script.

    Opens an existing persistent collection and supports counting records and
    nearest-neighbour category assignment (batched and single-prompt).
    ``query_similar`` is not implemented and raises ``NotImplementedError``.
    """

    def __init__(self, collection_name: str, db_path: str):
        """Connect to ``collection_name`` in the persistent store at ``db_path``.

        Raises:
            FileNotFoundError: ``db_path`` does not exist on disk.
            ConnectionError: the client could not be created or the collection
                could not be opened (the underlying error is chained).
        """
        self.db_path = db_path
        self.collection_name = collection_name
        if not os.path.exists(db_path):
            raise FileNotFoundError(f"Database path does not exist: {db_path}")

        try:
            self.client = chromadb.PersistentClient(path=self.db_path)
            self.collection = self.client.get_collection(name=self.collection_name)
            logger.info(f"Successfully connected to ChromaDB collection '{self.collection_name}' at '{self.db_path}'.")
        except Exception as e:
            logger.error(f"Failed to connect to ChromaDB at {db_path}: {e}", exc_info=True)
            raise ConnectionError(f"ChromaDB connection failed for collection '{self.collection_name}'") from e

    def get_collection_size(self) -> int:
        """Return the number of records in the collection."""
        return self.collection.count()

    def query_similar(self, embedding: List[float], top_k: int = 10, filters: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Not supported by this analysis-only store."""
        raise NotImplementedError

    def batch_assign_clusters(
        self,
        prompt_embeddings: List[List[float]],
        cluster_field_name: str,
        top_k_neighbors: int,
        where_filter: Optional[Dict[str, Any]]
    ) -> List[str]:
        """Assign a cluster to every embedding in one ChromaDB query.

        Each embedding receives the ``cluster_field_name`` metadata value of its
        first returned neighbour. Missing neighbours or metadata yield
        ``'Uncategorized'`` for that embedding only.

        Args:
            prompt_embeddings: Query embeddings (a list of vectors).
            cluster_field_name: Metadata key holding the cluster name.
            top_k_neighbors: ``n_results`` passed to the query.
            where_filter: Optional ChromaDB metadata filter for candidate neighbours.

        Returns:
            A list with exactly one cluster name per input embedding, in input
            order. If the query fails, or returns a different number of result
            lists than embeddings, every entry is ``'Uncategorized'``.
        """
        n_queries = len(prompt_embeddings)

        # getattr tolerates instances built without __init__ (e.g. via __new__).
        if getattr(self, "collection", None) is None:
            logger.warning("Collection not initialized. Cannot assign clusters.")
            return ['Uncategorized'] * n_queries

        # len() rather than truthiness: a NumPy batch has no unambiguous truth value.
        if n_queries == 0:
            return []

        try:
            query_results = self.collection.query(
                query_embeddings=prompt_embeddings,
                n_results=int(top_k_neighbors),
                where=where_filter,
                include=["metadatas"]
            )
        except Exception as e:
            logger.error(f"Error during batch cluster assignment: {e}", exc_info=True)
            return ['Uncategorized'] * n_queries

        # One list of neighbour metadatas per query; element 0 is treated as the
        # top neighbour (presumably nearest-first ordering from ChromaDB).
        results_metadatas = query_results.get('metadatas') or []
        assigned_clusters: List[str] = []
        for neighbour_metadatas in results_metadatas:
            top_meta = neighbour_metadatas[0] if neighbour_metadatas else None
            cluster = top_meta.get(cluster_field_name) if top_meta else None
            assigned_clusters.append('Uncategorized' if cluster is None else cluster)

        if len(assigned_clusters) != n_queries:
            # Misaligned results cannot be matched back to inputs safely.
            logger.warning(
                f"Expected {n_queries} query results but received {len(assigned_clusters)}; "
                "marking this batch as 'Uncategorized'."
            )
            return ['Uncategorized'] * n_queries
        return assigned_clusters

    def assign_clusters_with_weights_to_new_prompt(
        self,
        prompt_embedding: List[float],
        cluster_field_name: str = "prompt_category",
        top_k_neighbors: int = 1,
        where_filter: Optional[Dict[str, Any]] = None
    ) -> List[Tuple[Any, float]]:
        """Assign a cluster to a single embedding via ``batch_assign_clusters``.

        Returns:
            ``[(cluster, 1.0)]`` when a cluster was found, otherwise ``[]``.
        """
        assigned_cluster = self.batch_assign_clusters(
            prompt_embeddings=[prompt_embedding],
            cluster_field_name=cluster_field_name,
            top_k_neighbors=top_k_neighbors,
            where_filter=where_filter
        )[0]

        if assigned_cluster != 'Uncategorized':
            return [(assigned_cluster, 1.0)]
        return []

                                              

                           
# Root logging configuration applied when this module is imported.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-wide logger and Rich console shared by the functions in this script.
logger = logging.getLogger(__name__)
console = Console()

def setup_matplotlib_for_tmlr():
    """Configure process-wide Matplotlib defaults for TMLR-style figures.

    Applies the ``seaborn-v0_8-paper`` style, then overrides these settings:
    - font sizes;
    - a serif font family (Times New Roman first);
    - no LaTeX rendering;
    - a 10x5 inch default figure;
    - 300 dpi for both display and saving;
    - all four axes spines enabled.
    """
    font_sizes = {
        'font.size': 10,
        'axes.labelsize': 10,
        'axes.titlesize': 12,
        'xtick.labelsize': 9,
        'ytick.labelsize': 9,
        'legend.fontsize': 9,
    }
    typography = {
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'serif'],
        'text.usetex': False,
    }
    figure_output = {
        'figure.figsize': (10, 5),
        'figure.dpi': 300,
        'savefig.dpi': 300,
    }
    spines = {f'axes.spines.{side}': True for side in ('top', 'right', 'left', 'bottom')}

    # The base style must be applied first so the overrides below take precedence.
    plt.style.use('seaborn-v0_8-paper')
    plt.rcParams.update({**font_sizes, **typography, **figure_output, **spines})
    logger.info("Matplotlib configured for TMLR styling.")

def fetch_all_data_for_analysis(vector_store: ChromaVectorStore, batch_size: int = 5000) -> pd.DataFrame:
    """Fetch all ids, embeddings and metadata from a collection into a DataFrame.

    Records are read page by page (``limit=batch_size``, increasing ``offset``).
    Records whose embedding is missing or empty are dropped.

    Args:
        vector_store: Connected store whose ``collection`` is read.
        batch_size: Page size for ``collection.get``.

    Returns:
        A DataFrame with an ``id`` column, an ``embedding`` column and one column
        per metadata key (metadata keys are merged after ``id``/``embedding``,
        so a same-named key overwrites them). If a page fetch raises, the error
        is logged and the records fetched so far are returned.
    """
    records: List[Dict[str, Any]] = []
    offset = 0
    total_docs = vector_store.get_collection_size()
    console.print(f"Fetching {total_docs} documents from [cyan]{vector_store.collection_name}[/cyan]...")

    with tqdm(total=total_docs, desc="Fetching data") as pbar:
        while True:
            try:
                results = vector_store.collection.get(
                    limit=batch_size, offset=offset, include=["embeddings", "metadatas"]
                )
            except Exception as e:
                logger.error(f"Error fetching batch from ChromaDB: {e}", exc_info=True)
                break
            if not results or not results.get('ids'):
                break

            batch_ids = results['ids']
            # Explicit None checks instead of `or []`: the embeddings container may be a
            # NumPy array, whose truth value is ambiguous.
            batch_embeddings = results.get('embeddings')
            if batch_embeddings is None:
                batch_embeddings = [None] * len(batch_ids)
            batch_metadatas = results.get('metadatas')
            if batch_metadatas is None:
                batch_metadatas = [None] * len(batch_ids)

            for doc_id, embedding, metadata in zip(batch_ids, batch_embeddings, batch_metadatas):
                if embedding is None or len(embedding) == 0:
                    continue
                record: Dict[str, Any] = {'id': doc_id, 'embedding': embedding}
                if metadata:
                    record.update(metadata)
                records.append(record)

            pbar.update(len(batch_ids))
            # A short page means we've reached the end of the collection.
            if len(batch_ids) < batch_size:
                break
            offset += len(batch_ids)

    return pd.DataFrame(records)

def reclassify_and_filter_prompts(df: pd.DataFrame, vector_store: ChromaVectorStore, cluster_field: str, batch_size: int = 1024, max_workers: int = 12) -> pd.DataFrame:
    """
    Reclassify prompts whose category is outside PREDEFINED_CATEGORIES, then
    keep only prompts that belong to the predefined set.

    Each outlier takes the category of its single nearest neighbour among the
    records already carrying a predefined category (a ``$in`` metadata filter
    on ``cluster_field``). Queries are sent in batches of ``batch_size``
    embeddings and run concurrently with multi-threading.

    Args:
        df: Input prompts; must have ``cluster_field`` and ``embedding`` columns.
            Not modified.
        vector_store: Store used for nearest-neighbour lookups.
        cluster_field: Column / metadata key holding the category.
        batch_size: Embeddings per query (values below 1 are treated as 1).
        max_workers: Upper bound on worker threads (values below 1 are treated as 1).

    Returns:
        A new DataFrame with rows whose ``cluster_field`` is in
        PREDEFINED_CATEGORIES. These are the original in-taxonomy rows plus the
        successfully reassigned outliers; rows left 'Uncategorized' are dropped.
    """
    console.print(f"\nRe-classifying outliers against {len(PREDEFINED_CATEGORIES)} predefined categories...")

    in_taxonomy = df[cluster_field].isin(PREDEFINED_CATEGORIES)
    df_predefined = df[in_taxonomy].copy()
    df_to_reclassify = df[~in_taxonomy].copy()

    if df_to_reclassify.empty:
        console.print("No outliers to reclassify. All prompts are within the predefined taxonomy.", style="green")
        return df_predefined

    # Non-positive values would make range() / ThreadPoolExecutor raise.
    batch_size = max(1, int(batch_size))
    max_workers = max(1, int(max_workers))

    console.print(f"Found {len(df_to_reclassify)} prompts to reclassify (using {max_workers} workers, batch size {batch_size})...")

    # Sorted so the filter (and thus the query) is identical from run to run.
    where_filter = {cluster_field: {"$in": sorted(PREDEFINED_CATEGORIES)}}
    embeddings_to_query = df_to_reclassify['embedding'].tolist()
    chunks = [embeddings_to_query[i:i + batch_size] for i in range(0, len(embeddings_to_query), batch_size)]

    results: List[List[str]] = [[] for _ in chunks]
    # No point spawning more threads than there are chunks.
    with ThreadPoolExecutor(max_workers=min(max_workers, len(chunks))) as executor:
        future_to_batch = {
            executor.submit(
                vector_store.batch_assign_clusters, chunk, cluster_field, 1, where_filter
            ): i for i, chunk in enumerate(chunks)
        }
        with tqdm(total=len(chunks), desc="Re-classifying Batches") as pbar:
            for future in as_completed(future_to_batch):
                batch_index = future_to_batch[future]
                expected_len = len(chunks[batch_index])
                try:
                    batch_result = list(future.result())
                except Exception as exc:
                    logger.error(f'Batch {batch_index} generated an exception: {exc}')
                    batch_result = ['Uncategorized'] * expected_len
                if len(batch_result) != expected_len:
                    # A misaligned batch cannot be mapped back to its rows; discard it
                    # rather than failing the column assignment below.
                    logger.warning(
                        f'Batch {batch_index} returned {len(batch_result)} results for '
                        f'{expected_len} prompts; marking the batch as Uncategorized.'
                    )
                    batch_result = ['Uncategorized'] * expected_len
                results[batch_index] = batch_result
                pbar.update(1)

    # Flatten in chunk order so categories line up with df_to_reclassify rows.
    new_categories = [item for sublist in results for item in sublist]

    uncategorized_count = new_categories.count('Uncategorized')
    total_reclassified = len(new_categories)
    if total_reclassified > 0:
        success_rate = (total_reclassified - uncategorized_count) / total_reclassified * 100
        console.print(
            f"Re-classification report: {total_reclassified - uncategorized_count} / {total_reclassified} "
            f"prompts successfully re-assigned ({success_rate:.1f}%). "
            f"{uncategorized_count} failed (assigned 'Uncategorized')."
        )

    df_to_reclassify[cluster_field] = new_categories

    df_combined = pd.concat([df_predefined, df_to_reclassify])
    df_cleaned = df_combined[df_combined[cluster_field].isin(PREDEFINED_CATEGORIES)].copy()

    console.print(f"Re-classification complete. Final dataset size: {len(df_cleaned)}.", style="green")
    return df_cleaned

def plot_comparison_distribution(df_base: pd.DataFrame, df_exp: pd.DataFrame, cluster_field: str, output_path: str):
    """Plot per-category prompt counts for both databases side by side.

    Both panels share the y axis and use the sorted PREDEFINED_CATEGORIES as
    the category order with a matching viridis palette.

    Args:
        df_base: Cleaned data of the default database.
        df_exp: Cleaned data of the expanded database.
        cluster_field: Column holding the category.
        output_path: Primary output file (format inferred from its extension,
            normally ``.pdf``). A PNG with the same stem is also written.

    Returns without drawing if either DataFrame is empty.
    """
    if df_base.empty or df_exp.empty:
        return

    category_order = sorted(PREDEFINED_CATEGORIES)
    palette = sns.color_palette("viridis", n_colors=len(category_order))

    fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharey=True)
    # Replace only the extension, not every ".pdf" substring in the path.
    png_output_path = os.path.splitext(output_path)[0] + '.png'
    try:
        panels = (
            (axes[0], df_base, "Default Database", "Prompt Category"),
            (axes[1], df_exp, "Expanded Database", ""),
        )
        for ax, df, title, ylabel in panels:
            sns.countplot(
                ax=ax, data=df, y=cluster_field, hue=cluster_field,
                order=category_order, hue_order=category_order, palette=palette,
                orient='h', legend=False
            )
            ax.set_title(title)
            ax.set_xlabel("Number of Prompts")
            ax.set_ylabel(ylabel)
            ax.grid(axis='x', linestyle='--', alpha=0.7)

        fig.tight_layout()
        fig.savefig(output_path, bbox_inches='tight')
        fig.savefig(png_output_path, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

    logger.info(f"Comparison distribution plot saved to {output_path} and {png_output_path}")

def _build_tsne(perplexity: float, n_iterations: int = 400):
    """Build a 2-D TSNE, passing the iteration budget under the keyword this scikit-learn accepts."""
    common_kwargs = dict(
        n_components=2, random_state=42, perplexity=perplexity,
        learning_rate='auto', init='pca'
    )
    try:
        # scikit-learn >= 1.5 calls it ``max_iter``; ``n_iter`` was removed in 1.7.
        return TSNE(max_iter=n_iterations, **common_kwargs)
    except TypeError:
        # Older releases only accept ``n_iter``.
        return TSNE(n_iter=n_iterations, **common_kwargs)


def plot_comparison_tsne(df_base: pd.DataFrame, df_exp: pd.DataFrame, hue_field: str, output_path: str, perplexity: float = 30.0):
    """
    Project each database's embeddings with t-SNE and draw the two scatter plots
    side by side, coloured by ``hue_field``. One shared legend sits below the axes.

    Each panel is an independent t-SNE fit (``random_state=42``). The figure is
    saved to ``output_path`` (format from its extension, normally ``.pdf``) and
    to a PNG with the same stem.

    Args:
        df_base: Data of the default database (needs ``embedding`` and ``hue_field``).
        df_exp: Data of the expanded database (same columns).
        hue_field: Column used for colouring. When it is ``'label'`` with
            exactly the values ``{'safe', 'unsafe'}``, fixed green/red colours
            are used.
        output_path: Primary output file.
        perplexity: Requested t-SNE perplexity, capped at ``n_samples - 1`` per panel.

    Returns without drawing if either DataFrame is empty. A panel is left blank
    when it has fewer than two usable embeddings.
    """
    if df_base.empty or df_exp.empty:
        return

    from matplotlib.lines import Line2D

    data_to_plot = {
        "Default Database": df_base,
        "Expanded Database": df_exp
    }
    all_hues = sorted(pd.concat([df_base[hue_field], df_exp[hue_field]]).dropna().unique())

    if hue_field == 'label' and set(all_hues) == {'safe', 'unsafe'}:
        palette = ['#2DA02D', '#D62727']
        color_map = dict(zip(['safe', 'unsafe'], palette))
    else:
        # tab20 only has 20 distinct colours; switch to a continuous map beyond that.
        if len(all_hues) > 20:
            palette = sns.color_palette("gist_ncar", n_colors=len(all_hues))
        else:
            palette = sns.color_palette("tab20", n_colors=len(all_hues))
        color_map = dict(zip(all_hues, palette))

    fig, axes = plt.subplots(1, 2, sharex=True, sharey=True)
    # Replace only the extension, not every ".pdf" substring in the path.
    png_output_path = os.path.splitext(output_path)[0] + '.png'
    try:
        for i, (title, df) in enumerate(data_to_plot.items()):
            ax = axes[i]
            console.print(f"Performing t-SNE on {len(df)} embeddings for '{title}'...")

            valid_embeddings = df['embedding'].apply(lambda x: isinstance(x, (list, np.ndarray)) and len(x) > 0)
            df_plot = df[valid_embeddings].copy()

            if df_plot.empty:
                logger.warning(f"No valid embeddings found for '{title}'. Skipping t-SNE plot.")
                continue

            embeddings_np = np.array(df_plot['embedding'].tolist())
            n_samples = embeddings_np.shape[0]
            if n_samples < 2:
                # Perplexity would be capped to 0, which scikit-learn rejects.
                logger.warning(f"Only {n_samples} valid embedding for '{title}'; t-SNE needs at least 2. Skipping t-SNE plot.")
                continue

            # scikit-learn requires perplexity < n_samples.
            valid_perplexity = min(perplexity, n_samples - 1)
            tsne_results = _build_tsne(valid_perplexity).fit_transform(embeddings_np)
            df_plot['t-SNE Dimension 1'] = tsne_results[:, 0]
            df_plot['t-SNE Dimension 2'] = tsne_results[:, 1]

            sns.scatterplot(
                ax=ax, x="t-SNE Dimension 1", y="t-SNE Dimension 2",
                hue=hue_field, palette=color_map, data=df_plot,
                legend=False, alpha=0.7, s=3, hue_order=all_hues,
                edgecolor='none'
            )
            ax.set_title(title, fontsize=14)
            ax.set_xlabel("t-SNE Dimension 1")
            ax.set_ylabel("t-SNE Dimension 2" if i == 0 else "")

        # Build the shared legend manually so it covers the hues of both panels.
        legend_elements = [
            Line2D([0], [0], marker='o', color='w',
                   label=str(label).replace('_', ' ').title().replace("S ", "S"),
                   markerfacecolor=color_map.get(label, 'gray'), markersize=6)
            for label in all_hues
        ]

        # More categories -> more legend columns, keeping the legend compact.
        num_hues = len(all_hues)
        if num_hues > 15:
            ncol = 5
        elif num_hues > 8:
            ncol = 4
        elif num_hues > 4:
            ncol = 3
        else:
            ncol = max(1, num_hues)

        # Reserve bottom space proportional to the number of legend rows.
        num_rows_in_legend = math.ceil(num_hues / ncol)
        bottom_margin = 0.05 + (num_rows_in_legend * 0.035)

        fig.tight_layout(rect=[0, bottom_margin, 1, 0.96])

        # Centre the legend inside the reserved bottom strip.
        fig.legend(
            handles=legend_elements,
            title=hue_field.replace('_', ' ').title(),
            loc='center',
            bbox_to_anchor=(0.5, bottom_margin / 2),
            ncol=ncol,
            fontsize=9,
            title_fontsize=10,
            frameon=False
        )

        fig.savefig(output_path, bbox_inches='tight')
        fig.savefig(png_output_path, bbox_inches='tight')
    finally:
        # Release the figure even if t-SNE, drawing or saving fails.
        plt.close(fig)

    logger.info(f"Comparison t-SNE plot saved to {output_path} and {png_output_path}")

def check_sampling_ratio(value: str) -> float:
    """Argparse ``type=`` callable: parse ``value`` as a float within [0.01, 1.0].

    Raises:
        argparse.ArgumentTypeError: if ``value`` cannot be parsed as a float,
            or if the parsed value lies outside [0.01, 1.0] (NaN included).
    """
    try:
        ratio = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"{value} is not a valid float.")

    if 0.01 <= ratio <= 1.0:
        return ratio
    raise argparse.ArgumentTypeError(
        f"{ratio} is an invalid sampling ratio. Must be between 0.01 and 1.0."
    )

                              
if __name__ == "__main__":
    # escape() keeps '[...]' inside exception text from being parsed as Rich markup.
    from rich.markup import escape

    parser = argparse.ArgumentParser(
        description="Database Comparison Analysis Script with Caching and Sampling."
    )
    parser.add_argument(
        '--force-recache', action='store_true',
        help="Force re-fetching and re-classifying data, ignoring any existing cache."
    )
    # Float default so its type matches what check_sampling_ratio returns
    # (argparse does not apply ``type`` to non-string defaults).
    parser.add_argument(
        '--sampling-ratio', type=check_sampling_ratio, default=1.0,
        help="A float between 0.01 and 1.0 to sample the database for faster analysis. Default is 1.0 (no sampling)."
    )
    args = parser.parse_args()

    console.print(Panel(Text("Database Comparison Analysis (TMLR-Ready, v5)", style="bold blue"), expand=False))
    setup_matplotlib_for_tmlr()

    # Make the parent directory importable so the project's config package resolves.
    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
    from config.constants import BASE_PROJECT_DIR

    base_db_path = str(BASE_PROJECT_DIR / "data/07_vector_db/gemma3_1b_base_backup")
    exp_db_path = str(BASE_PROJECT_DIR / "data/07_vector_db/gemma3_1b_exp_backup")
    output_directory = str(BASE_PROJECT_DIR / "data/reports/comparison_analysis_reports")
    cache_directory = str(BASE_PROJECT_DIR / "data/reports/comparison_analysis_reports/analysis_cache")
    os.makedirs(output_directory, exist_ok=True)
    os.makedirs(cache_directory, exist_ok=True)

    RECLASSIFY_BATCH_SIZE = 64
    # Leave a few cores free; always use at least one worker.
    RECLASSIFY_MAX_WORKERS = max(1, (os.cpu_count() or 4) - 4)

    config = get_config()
    collection_name = config["vector_database"]["collection_name"]
    cluster_field = config["clustering"]["cluster_field_name"]

    # Cache files are keyed by database directory name and sampling ratio.
    sampling_suffix = f"_sample_{args.sampling_ratio}" if args.sampling_ratio < 1.0 else ""
    base_db_name = os.path.basename(os.path.normpath(base_db_path))
    exp_db_name = os.path.basename(os.path.normpath(exp_db_path))
    base_cache_path = os.path.join(cache_directory, f"cleaned_df_{base_db_name}{sampling_suffix}.parquet")
    exp_cache_path = os.path.join(cache_directory, f"cleaned_df_{exp_db_name}{sampling_suffix}.parquet")

    console.print(f"Reports will be saved to: [cyan]{output_directory}[/cyan]")
    console.print(f"Cache will use robust Parquet format in: [cyan]{cache_directory}[/cyan]")
    if args.sampling_ratio < 1.0:
        console.print(f"[yellow]Analysis will run on a {args.sampling_ratio:.0%} random sample of the data.[/yellow]")

    df_base_cleaned, df_exp_cleaned = None, None
    # Non-zero on any failure so callers (shells, CI) can detect it.
    exit_code = 0

    try:
        if not args.force_recache and os.path.exists(base_cache_path) and os.path.exists(exp_cache_path):
            console.print(Panel(Text("Loading from Parquet cache to accelerate startup.", style="bold yellow"), expand=False))
            df_base_cleaned = pd.read_parquet(base_cache_path)
            df_exp_cleaned = pd.read_parquet(exp_cache_path)
            console.print("Successfully loaded cleaned data from cache.", style="green")
        else:
            if args.force_recache:
                console.print("[bold yellow]--force-recache set. Ignoring and overwriting existing cache.[/bold yellow]")

            console.print("\n[bold]Step 1: Loading data from databases...[/bold]")
            vector_store_base = ChromaVectorStore(collection_name=collection_name, db_path=base_db_path)
            df_base = fetch_all_data_for_analysis(vector_store_base)

            vector_store_exp = ChromaVectorStore(collection_name=collection_name, db_path=exp_db_path)
            df_exp = fetch_all_data_for_analysis(vector_store_exp)

            if args.sampling_ratio < 1.0:
                console.print(f"\n[bold]Applying {args.sampling_ratio:.0%} sampling...[/bold]")
                df_base = df_base.sample(frac=args.sampling_ratio, random_state=42).reset_index(drop=True)
                df_exp = df_exp.sample(frac=args.sampling_ratio, random_state=42).reset_index(drop=True)
                console.print(f"Sampled base DB size: {len(df_base)}, Sampled expanded DB size: {len(df_exp)}")

            if df_base.empty or df_exp.empty:
                console.print("[bold red]One or both databases are empty after fetching/sampling. Aborting.[/bold red]")
                sys.exit(1)

            console.print("\n[bold]Step 2: Cleaning and standardizing data...[/bold]")
            for df in [df_base, df_exp]:
                df[cluster_field] = df[cluster_field].fillna('Uncategorized')
                df['label'] = df['label'].map({0.0: 'safe', 1.0: 'unsafe', 0: 'safe', 1: 'unsafe'}).fillna('Unknown')

            df_base_cleaned = reclassify_and_filter_prompts(
                df_base, vector_store_base, cluster_field,
                batch_size=RECLASSIFY_BATCH_SIZE, max_workers=RECLASSIFY_MAX_WORKERS
            )
            df_exp_cleaned = reclassify_and_filter_prompts(
                df_exp, vector_store_exp, cluster_field,
                batch_size=RECLASSIFY_BATCH_SIZE, max_workers=RECLASSIFY_MAX_WORKERS
            )

            console.print("\n[bold]Caching cleaned data to Parquet for future runs...[/bold]")
            df_base_cleaned.to_parquet(base_cache_path, index=False)
            df_exp_cleaned.to_parquet(exp_cache_path, index=False)
            console.print(f"Saved base data to: [cyan]{base_cache_path}[/cyan]")
            console.print(f"Saved expanded data to: [cyan]{exp_cache_path}[/cyan]")

        if df_base_cleaned is not None and not df_base_cleaned.empty and df_exp_cleaned is not None and not df_exp_cleaned.empty:
            console.print("\n[bold]Step 3: Generating comparison plots...[/bold]")

            dist_path = os.path.join(output_directory, f"comparison_distribution_{cluster_field}{sampling_suffix}.pdf")
            plot_comparison_distribution(df_base_cleaned, df_exp_cleaned, cluster_field=cluster_field, output_path=dist_path)

            # NOTE(review): plot_comparison_tsne is defined in this module but not
            # invoked here, so only the category-distribution figure is produced.

            console.print(Panel(Text("Analysis complete. TMLR-ready plots generated successfully.", style="bold green"), expand=False))
        else:
            console.print("[bold red]Error: Cleaned dataframes are empty. Cannot generate plots.[/bold red]")
            exit_code = 1

    except FileNotFoundError as e:
        console.print(f"[bold red]Error: Database path not found. {escape(str(e))}[/bold red]")
        exit_code = 1
    except ConnectionError as e:
        console.print(f"[bold red]Error: Could not connect to ChromaDB. {escape(str(e))}[/bold red]")
        exit_code = 1
    except Exception as e:
        console.print(f"[bold red]An unexpected error occurred: {escape(str(e))}[/bold red]")
        logger.error("Unexpected error in main analysis script:", exc_info=True)
        exit_code = 1

    sys.exit(exit_code)