"""
Embeddings Analysis Integration

This module provides a wrapper class to integrate embeddings analysis into the
experiment pipeline, following the same patterns as other LLM components.
"""

import pathlib
import time
from typing import Any, Dict, List, Optional, Tuple, Union

import pandas as pd
import matplotlib.pyplot as plt

from src.core.registry import embeddings_registry
from src.utils.decorator_utils import with_logger


class EmbeddingsAnalyzer:
    """
    A wrapper class to integrate embeddings analysis into experiment pipelines.

    This class extracts prompts from experiment results, generates an
    embeddings similarity analysis (heatmap figure plus similarity matrix),
    and saves the resulting artefacts to disk.

    Attributes:
        model: Name of the embedding model handed to the embeddings client.
        output_formats: File extensions used when heatmaps are saved to the
            default output directory.
        kwargs: Extra keyword arguments forwarded to the embeddings client.
        embeddings_client: Client created lazily through
            ``embeddings_registry``; ``None`` until the first analysis runs.
    """

    # NOTE(review): ``logger`` is not defined in this module; it is presumably
    # injected by the ``@with_logger`` decorator -- confirm against
    # ``src.utils.decorator_utils``.

    @with_logger
    def __init__(
        self,
        model: str = "bedrock-cohere-embed-eng-v3",
        output_formats: Optional[List[str]] = None,
        **kwargs: Any,
    ):
        """
        Initialize the embeddings analyzer.

        Args:
            model: The embedding model to use
            output_formats: List of output formats for heatmaps (e.g., ['png']).
                A single string is accepted and treated as a one-item list.
                Defaults to ``['png']`` when omitted or empty.
            **kwargs: Additional keyword arguments forwarded to the embeddings
                client when it is created
        """
        logger.info(f"Initializing EmbeddingsAnalyzer with model: {model}")

        # A bare string would otherwise be iterated character by character
        # when the heatmaps are saved.
        if isinstance(output_formats, str):
            output_formats = [output_formats]

        self.model = model
        self.output_formats = output_formats or ["png"]
        self.kwargs = kwargs

        # Will be initialized when needed
        self.embeddings_client = None

        logger.info("EmbeddingsAnalyzer initialized successfully")

    @with_logger
    def extract_prompts_from_results(
        self,
        base_results: Optional[Dict[str, Any]] = None,
        optimized_results: Optional[Dict[str, Any]] = None,
        prompt_optimization_results: Optional[Dict[str, Any]] = None,
    ) -> List[str]:
        """
        Extract prompts from experiment results for embeddings analysis.

        Only ``prompt_optimization_results`` is consulted: its ``base_prompt``
        and ``optimised_prompt`` entries are collected in that order. Duplicate
        prompts are added once, and entries that are not non-blank strings are
        skipped with a warning.

        Args:
            base_results: Results from base prompt run (currently unused)
            optimized_results: Results from optimized prompt run (currently
                unused)
            prompt_optimization_results: Results from prompt optimization step

        Returns:
            List of prompts to analyze (possibly empty)
        """
        logger.info("Extracting prompts from experiment results")

        prompts: List[str] = []
        sources = prompt_optimization_results or {}

        # (dictionary key, wording used in the log messages)
        wanted = (("base_prompt", "base"), ("optimised_prompt", "optimized"))
        for key, description in wanted:
            if key not in sources:
                continue

            candidate = sources[key]
            if not isinstance(candidate, str) or not candidate.strip():
                # A None/blank entry cannot be embedded; skipping it keeps the
                # remaining prompt usable instead of failing the extraction.
                logger.warning(
                    f"Ignoring {description} prompt: expected a non-empty string"
                )
                continue

            # Only add if different from the prompts collected so far
            if candidate in prompts:
                continue

            prompts.append(candidate)
            logger.info(f"Extracted {description} prompt ({len(candidate)} chars)")

        # If no prompts found in optimization results, try to extract from other sources
        if not prompts:
            logger.warning(
                "No prompts found in optimization results, attempting alternative extraction"
            )
            # You can add more extraction logic here based on your specific needs

        logger.info(f"Total prompts extracted: {len(prompts)}")
        return prompts

    @with_logger
    def generate_embeddings_analysis(
        self, prompts: List[str], labels: Optional[List[str]] = None
    ) -> Tuple[plt.Figure, plt.Axes, pd.DataFrame]:
        """
        Generate embeddings analysis for the given prompts.

        A single prompt is compared with itself so that a 2x2 matrix can still
        be produced; a single label supplied for it is duplicated accordingly.
        The caller's ``prompts`` and ``labels`` lists are never modified.

        Args:
            prompts: List of prompts to analyze
            labels: Optional labels for the prompts; applied to the heatmap
                axes only when there is exactly one label per analyzed prompt

        Returns:
            Tuple of (figure, axes, similarity_dataframe)

        Raises:
            ValueError: If ``prompts`` is empty
        """
        logger.info(f"Generating embeddings analysis for {len(prompts)} prompts")

        if not prompts:
            raise ValueError("At least one prompt is required for embeddings analysis")

        if len(prompts) < 2:
            logger.warning("Need at least 2 prompts for meaningful similarity analysis")
            # Duplicate for comparison. The label is duplicated as well,
            # otherwise the length check below would silently drop it.
            prompts = [prompts[0], prompts[0]]
            if labels and len(labels) == 1:
                labels = [labels[0], labels[0]]

        # Initialize embeddings client if not already done
        if self.embeddings_client is None:
            logger.info("Initializing OpenAI Embeddings client")
            self.embeddings_client = embeddings_registry.create(
                "OpenAI_Embeddings", input=prompts, model=self.model, **self.kwargs
            )
        else:
            # Update the input for existing client
            self.embeddings_client.input = prompts

        # Generate embeddings and heatmap
        started = time.perf_counter()
        fig, ax = self.embeddings_client.generate(prompts)

        # Get the similarity matrix for additional analysis
        embeddings = self.embeddings_client.get_embeddings(prompts)
        similarity_df = self.embeddings_client.similarity_matrix(embeddings)

        duration = time.perf_counter() - started
        logger.info(f"Embeddings analysis completed in {duration:.2f} seconds")

        # Customize the plot if labels are provided
        if labels:
            if len(labels) == len(prompts):
                ax.set_xticklabels(labels, rotation=45, ha="right")
                ax.set_yticklabels(labels, rotation=0)
                logger.info("Applied custom labels to heatmap")
            else:
                logger.warning(
                    f"Ignoring custom labels: got {len(labels)} labels "
                    f"for {len(prompts)} prompts"
                )

        return fig, ax, similarity_df

    @with_logger
    def save_analysis_results(
        self,
        fig: plt.Figure,
        similarity_df: pd.DataFrame,
        output_dir: Union[str, pathlib.Path],
        experiment_name: str = "embeddings_analysis",
        config: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, str]:
        """
        Save the embeddings analysis results to files.

        When ``config["output"]`` provides ``embeddings_heatmap`` and/or
        ``embeddings_similarity``, exactly those files are written. Otherwise
        the results go to ``output_dir``: one heatmap per entry of
        ``self.output_formats`` plus the similarity matrix as CSV and JSON.
        Failures to write an individual file are logged and do not abort the
        remaining saves.

        Args:
            fig: The matplotlib figure to save
            similarity_df: The similarity dataframe to save
            output_dir: Directory to save results when no configured paths
                are available
            experiment_name: Name for the output files (default paths only)
            config: Optional configuration whose ``output`` section may hold
                the ``embeddings_heatmap`` / ``embeddings_similarity`` paths

        Returns:
            Dictionary with paths to saved files. Keys are ``heatmap`` and
            ``similarity_csv`` for configured paths; ``heatmap_<fmt>``,
            ``similarity_csv`` and ``similarity_json`` for default paths.
        """
        logger.info("Saving embeddings analysis results")

        saved_files: Dict[str, str] = {}

        # Tolerate a missing config, a missing "output" section and an
        # "output" section that is empty (None) or not a mapping.
        output_config = (config or {}).get("output")
        if not isinstance(output_config, dict):
            output_config = {}

        heatmap_path = output_config.get("embeddings_heatmap")
        similarity_path = output_config.get("embeddings_similarity")

        if heatmap_path or similarity_path:
            # Use configured output paths
            if heatmap_path:
                heatmap_path = pathlib.Path(heatmap_path)
                try:
                    # Directory creation sits inside the try so a failure here
                    # cannot prevent the similarity matrix from being saved.
                    heatmap_path.parent.mkdir(parents=True, exist_ok=True)
                    fig.savefig(heatmap_path, bbox_inches="tight", dpi=300)
                    saved_files["heatmap"] = str(heatmap_path)
                    logger.info(f"Saved heatmap to configured path: {heatmap_path}")
                except Exception as e:
                    logger.error(f"Failed to save heatmap to configured path: {str(e)}")

            if similarity_path:
                similarity_path = pathlib.Path(similarity_path)
                try:
                    similarity_path.parent.mkdir(parents=True, exist_ok=True)
                    similarity_df.to_csv(similarity_path)
                    saved_files["similarity_csv"] = str(similarity_path)
                    logger.info(
                        f"Saved similarity matrix to configured path: {similarity_path}"
                    )
                except Exception as e:
                    logger.error(
                        f"Failed to save similarity matrix to configured path: {str(e)}"
                    )
        else:
            # Fallback to default paths. This also covers a config whose
            # "output" section exists but carries no embeddings paths, which
            # previously resulted in nothing being saved at all.
            logger.info("No configured output paths found, using defaults")
            output_dir = pathlib.Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)

            # Save heatmap in requested formats
            for fmt in self.output_formats:
                heatmap_path = (
                    output_dir / f"{experiment_name}_embeddings_heatmap.{fmt}"
                )
                try:
                    fig.savefig(heatmap_path, bbox_inches="tight", dpi=300)
                    saved_files[f"heatmap_{fmt}"] = str(heatmap_path)
                    logger.info(f"Saved heatmap as {fmt}: {heatmap_path}")
                except Exception as e:
                    logger.error(f"Failed to save heatmap as {fmt}: {str(e)}")

            # Save similarity matrix as CSV
            similarity_path = output_dir / f"{experiment_name}_similarity_matrix.csv"
            try:
                similarity_df.to_csv(similarity_path)
                saved_files["similarity_csv"] = str(similarity_path)
                logger.info(f"Saved similarity matrix: {similarity_path}")
            except Exception as e:
                logger.error(f"Failed to save similarity matrix: {str(e)}")

            # Save similarity matrix as JSON for easier programmatic access
            similarity_json_path = (
                output_dir / f"{experiment_name}_similarity_matrix.json"
            )
            try:
                similarity_df.to_json(similarity_json_path, indent=2)
                saved_files["similarity_json"] = str(similarity_json_path)
                logger.info(f"Saved similarity matrix as JSON: {similarity_json_path}")
            except Exception as e:
                logger.error(f"Failed to save similarity matrix as JSON: {str(e)}")

        logger.info(f"Successfully saved {len(saved_files)} files")
        return saved_files

    @with_logger
    def analyze_experiment_results(
        self,
        base_results: Optional[Dict[str, Any]] = None,
        optimized_results: Optional[Dict[str, Any]] = None,
        prompt_optimization_results: Optional[Dict[str, Any]] = None,
        output_dir: Optional[Union[str, pathlib.Path]] = None,
        experiment_name: str = "experiment",
        config: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Complete embeddings analysis workflow for experiment results.

        Extracts the prompts, generates the analysis, optionally saves the
        artefacts and computes summary metrics. Any exception raised along the
        way is logged and reported through the returned dictionary instead of
        being propagated.

        Args:
            base_results: Results from base prompt run
            optimized_results: Results from optimized prompt run
            prompt_optimization_results: Results from prompt optimization step
            output_dir: Directory to save results; ``"output"`` is used when
                it is omitted but ``config`` has an ``output`` section
            experiment_name: Name for the experiment
            config: Optional configuration; its ``output`` section may hold
                explicit paths for the saved artefacts

        Returns:
            Dictionary with analysis results and file paths (keys ``prompts``,
            ``labels``, ``similarity_matrix``, ``similarity_metrics``,
            ``saved_files``, ``figure``, ``axes``), or ``{"error": message}``
            when no prompts were found or the analysis failed
        """
        logger.info(
            f"Starting complete embeddings analysis for experiment: {experiment_name}"
        )

        try:
            # Extract prompts
            prompts = self.extract_prompts_from_results(
                base_results, optimized_results, prompt_optimization_results
            )

            if not prompts:
                logger.error("No prompts found for embeddings analysis")
                return {"error": "No prompts found for analysis"}

            # Create labels for the prompts: the first two are the base and
            # optimized prompt, any further ones get generic labels.
            labels = ["Base Prompt", "Optimized Prompt"][: len(prompts)]
            labels.extend(f"Prompt {i + 1}" for i in range(2, len(prompts)))

            # When only the optimized prompt could be extracted it ends up in
            # first position; do not mislabel it as the base prompt.
            sources = prompt_optimization_results or {}
            if (
                len(prompts) == 1
                and prompts[0] == sources.get("optimised_prompt")
                and prompts[0] != sources.get("base_prompt")
            ):
                labels = ["Optimized Prompt"]

            # Generate analysis
            fig, ax, similarity_df = self.generate_embeddings_analysis(prompts, labels)

            # Save results if output directory provided
            saved_files = {}
            if output_dir or (config and "output" in config):
                saved_files = self.save_analysis_results(
                    fig, similarity_df, output_dir or "output", experiment_name, config
                )

            # Calculate similarity metrics
            similarity_metrics = self._calculate_similarity_metrics(similarity_df)

            logger.info("Embeddings analysis completed successfully")

            return {
                "prompts": prompts,
                "labels": labels,
                "similarity_matrix": similarity_df,
                "similarity_metrics": similarity_metrics,
                "saved_files": saved_files,
                "figure": fig,
                "axes": ax,
            }

        except Exception as e:
            logger.error(f"Error in embeddings analysis: {str(e)}", exc_info=True)
            return {"error": str(e)}

    def _calculate_similarity_metrics(
        self, similarity_df: pd.DataFrame
    ) -> Dict[str, float]:
        """
        Calculate summary metrics from the similarity matrix.

        Args:
            similarity_df: The similarity dataframe

        Returns:
            Dictionary with similarity metrics. Empty when the matrix has
            fewer than two rows; otherwise it holds ``mean_similarity``,
            ``max_similarity``, ``min_similarity`` and ``std_similarity``
            over the off-diagonal entries, plus ``prompt_similarity`` (the
            entry at row 0, column 1) when the matrix has exactly two rows.
        """
        metrics: Dict[str, float] = {}

        if similarity_df.shape[0] < 2:
            return metrics

        # Imported locally because numpy is not imported at module level.
        import numpy as np

        # Get off-diagonal values (excluding self-similarity)
        mask = np.ones(similarity_df.shape, dtype=bool)
        np.fill_diagonal(mask, False)
        off_diagonal_values = similarity_df.values[mask]

        metrics["mean_similarity"] = float(np.mean(off_diagonal_values))
        metrics["max_similarity"] = float(np.max(off_diagonal_values))
        metrics["min_similarity"] = float(np.min(off_diagonal_values))
        metrics["std_similarity"] = float(np.std(off_diagonal_values))

        # If we have exactly 2 prompts, get the direct similarity
        if similarity_df.shape[0] == 2:
            metrics["prompt_similarity"] = float(similarity_df.iloc[0, 1])

        return metrics


# Convenience function for easy integration
@with_logger
def create_embeddings_analyzer(config: Dict[str, Any]) -> Optional[EmbeddingsAnalyzer]:
    """
    Create an embeddings analyzer from configuration.

    The ``embeddings`` section of ``config`` is read for ``enabled``,
    ``model``, ``output_format`` (a string or a list of strings) and
    ``kwargs`` (extra keyword arguments for the analyzer). Entries that are
    present but empty (``None``, as produced by an empty YAML key) are treated
    as if they were missing.

    Args:
        config: Configuration dictionary

    Returns:
        EmbeddingsAnalyzer instance or None if not configured or disabled
    """
    if not config or "embeddings" not in config:
        logger.info("No embeddings configuration found")
        return None

    # An empty section parses to None; treat it like an empty mapping so it
    # is reported as "disabled" instead of raising AttributeError.
    embeddings_config = config["embeddings"] or {}

    if not embeddings_config.get("enabled", False):
        logger.info("Embeddings analysis is disabled in configuration")
        return None

    logger.info("Creating embeddings analyzer from configuration")

    # Extract configuration parameters, falling back to the defaults when a
    # key is missing or explicitly empty.
    model = embeddings_config.get("model") or "bedrock-cohere-embed-eng-v3"
    output_formats = embeddings_config.get("output_format") or ["png"]

    # Ensure output_formats is a list
    if isinstance(output_formats, str):
        output_formats = [output_formats]

    # ``**None`` would raise TypeError, so normalise an empty kwargs entry.
    extra_kwargs = embeddings_config.get("kwargs") or {}

    return EmbeddingsAnalyzer(
        model=model,
        output_formats=output_formats,
        **extra_kwargs,
    )
