"""
Unit tests for the EmbeddingsAnalyzer class.
"""

import os
import pathlib
from unittest import mock
import pytest
import pandas as pd
import numpy as np

# Import the modules we need to test
from src.embeddings.embeddings_analyzer import (
    EmbeddingsAnalyzer,
    create_embeddings_analyzer,
)


class TestEmbeddingsAnalyzer:
    """Test cases for the EmbeddingsAnalyzer class.

    The embeddings registry and all file-writing calls are replaced with
    mocks, so these tests exercise only the analyzer's own orchestration
    logic and never contact an embeddings backend or write to disk.
    """

    def test_init_default(self):
        """Test initialisation with default parameters."""
        analyzer = EmbeddingsAnalyzer()

        assert analyzer.model == "bedrock-cohere-embed-eng-v3"
        assert analyzer.output_formats == ["png"]
        assert analyzer.kwargs == {}
        # No client is expected to exist until an analysis is requested.
        assert analyzer.embeddings_client is None

    def test_init_custom(self):
        """Test initialisation with custom parameters.

        Unrecognised keyword arguments are expected to be collected in
        ``analyzer.kwargs``.
        """
        analyzer = EmbeddingsAnalyzer(
            model="custom-model", output_formats=["png", "pdf"], custom_param="value"
        )

        assert analyzer.model == "custom-model"
        assert analyzer.output_formats == ["png", "pdf"]
        assert analyzer.kwargs == {"custom_param": "value"}

    def test_extract_prompts_from_results_with_optimization_results(self):
        """Test extracting prompts from optimization results.

        The base prompt is expected first, followed by the optimised prompt.
        """
        analyzer = EmbeddingsAnalyzer()

        prompt_optimization_results = {
            "base_prompt": "This is the base prompt",
            "optimised_prompt": "This is the optimized prompt",
        }

        prompts = analyzer.extract_prompts_from_results(
            prompt_optimization_results=prompt_optimization_results
        )

        assert len(prompts) == 2
        assert prompts[0] == "This is the base prompt"
        assert prompts[1] == "This is the optimized prompt"

    def test_extract_prompts_duplicate_handling(self):
        """Test that duplicate prompts are not added."""
        analyzer = EmbeddingsAnalyzer()

        prompt_optimization_results = {
            "base_prompt": "Same prompt",
            "optimised_prompt": "Same prompt",  # Duplicate
        }

        prompts = analyzer.extract_prompts_from_results(
            prompt_optimization_results=prompt_optimization_results
        )

        # Should only have one prompt since they're identical
        assert len(prompts) == 1
        assert prompts[0] == "Same prompt"

    def test_extract_prompts_no_results(self):
        """Test extracting prompts when no results are provided."""
        analyzer = EmbeddingsAnalyzer()

        prompts = analyzer.extract_prompts_from_results()

        assert prompts == []

    def test_extract_prompts_empty_optimization_results(self):
        """Test extracting prompts from empty optimization results."""
        analyzer = EmbeddingsAnalyzer()

        prompts = analyzer.extract_prompts_from_results(prompt_optimization_results={})

        assert prompts == []

    @mock.patch("src.embeddings.embeddings_analyzer.embeddings_registry")
    def test_generate_embeddings_analysis(self, mock_registry):
        """Test generating embeddings analysis.

        Verifies that a client is created through the registry, that the
        client's ``generate``/``get_embeddings``/``similarity_matrix`` methods
        are each called once, that the supplied labels are applied to the
        axes, and that the figure, axes and similarity frame are returned.
        """
        analyzer = EmbeddingsAnalyzer(model="test-model")

        # Mock the embeddings client
        mock_client = mock.MagicMock()
        mock_fig = mock.MagicMock()
        mock_ax = mock.MagicMock()
        mock_embeddings = mock.MagicMock()
        mock_similarity_df = pd.DataFrame(
            [[1.0, 0.8], [0.8, 1.0]], columns=["A", "B"], index=["A", "B"]
        )

        mock_client.generate.return_value = (mock_fig, mock_ax)
        mock_client.get_embeddings.return_value = mock_embeddings
        mock_client.similarity_matrix.return_value = mock_similarity_df
        mock_registry.create.return_value = mock_client

        prompts = ["First prompt", "Second prompt"]
        labels = ["Label 1", "Label 2"]

        fig, ax, similarity_df = analyzer.generate_embeddings_analysis(prompts, labels)

        # Check that registry was called correctly
        mock_registry.create.assert_called_once_with(
            "OpenAI_Embeddings", input=prompts, model="test-model"
        )

        # Check that client methods were called
        mock_client.generate.assert_called_once_with(prompts)
        mock_client.get_embeddings.assert_called_once_with(prompts)
        mock_client.similarity_matrix.assert_called_once_with(mock_embeddings)

        # Check that labels were applied
        mock_ax.set_xticklabels.assert_called_once_with(labels, rotation=45, ha="right")
        mock_ax.set_yticklabels.assert_called_once_with(labels, rotation=0)

        # Check return values
        assert fig == mock_fig
        assert ax == mock_ax
        assert similarity_df.equals(mock_similarity_df)

    @mock.patch("src.embeddings.embeddings_analyzer.embeddings_registry")
    def test_generate_embeddings_analysis_single_prompt(self, mock_registry):
        """Test generating analysis with single prompt (should duplicate).

        A lone prompt is expected to be passed to the registry twice so that
        a pairwise comparison is still possible.
        """
        analyzer = EmbeddingsAnalyzer()

        mock_client = mock.MagicMock()
        mock_fig = mock.MagicMock()
        mock_ax = mock.MagicMock()
        mock_embeddings = mock.MagicMock()
        mock_similarity_df = pd.DataFrame(
            [[1.0, 1.0], [1.0, 1.0]], columns=["A", "A"], index=["A", "A"]
        )

        mock_client.generate.return_value = (mock_fig, mock_ax)
        mock_client.get_embeddings.return_value = mock_embeddings
        mock_client.similarity_matrix.return_value = mock_similarity_df
        mock_registry.create.return_value = mock_client

        prompts = ["Single prompt"]

        fig, ax, similarity_df = analyzer.generate_embeddings_analysis(prompts)

        # Should have duplicated the prompt for comparison
        expected_prompts = ["Single prompt", "Single prompt"]
        mock_registry.create.assert_called_once_with(
            "OpenAI_Embeddings",
            input=expected_prompts,
            model="bedrock-cohere-embed-eng-v3",
        )

        # Check return values
        assert fig == mock_fig
        assert ax == mock_ax
        assert similarity_df.equals(mock_similarity_df)

    @mock.patch("src.embeddings.embeddings_analyzer.embeddings_registry")
    def test_generate_embeddings_analysis_reuse_client(self, mock_registry):
        """Test that existing client is reused when available.

        When ``analyzer.embeddings_client`` is already set, the registry must
        not be asked for a new client; instead the existing client's ``input``
        attribute is expected to be updated with the (duplicated) prompts.
        """
        analyzer = EmbeddingsAnalyzer()

        # Set up existing client
        existing_client = mock.MagicMock()
        mock_fig = mock.MagicMock()
        mock_ax = mock.MagicMock()
        mock_embeddings = mock.MagicMock()
        mock_similarity_df = pd.DataFrame(
            [[1.0, 1.0], [1.0, 1.0]], columns=["A", "A"], index=["A", "A"]
        )

        existing_client.generate.return_value = (mock_fig, mock_ax)
        existing_client.get_embeddings.return_value = mock_embeddings
        existing_client.similarity_matrix.return_value = mock_similarity_df
        analyzer.embeddings_client = existing_client

        prompts = ["Test prompt"]
        fig, ax, similarity_df = analyzer.generate_embeddings_analysis(prompts)

        # Should not create new client
        mock_registry.create.assert_not_called()

        # Should update existing client's input - single prompts get duplicated
        expected_prompts = ["Test prompt", "Test prompt"]
        assert existing_client.input == expected_prompts

        # Check return values
        assert fig == mock_fig
        assert ax == mock_ax
        assert similarity_df.equals(mock_similarity_df)

    def test_save_analysis_results_default_paths(self):
        """Test saving analysis results with default paths.

        With two output formats configured, the figure is expected to be
        saved twice and the similarity matrix written once as CSV and once as
        JSON. Directory creation and the DataFrame writers are patched so
        nothing is written to disk.
        """
        analyzer = EmbeddingsAnalyzer(output_formats=["png", "pdf"])

        mock_fig = mock.MagicMock()
        mock_similarity_df = pd.DataFrame([[1.0, 0.5], [0.5, 1.0]])
        output_dir = pathlib.Path("test_output")
        experiment_name = "test_experiment"

        with mock.patch("pathlib.Path.mkdir") as mock_mkdir, mock.patch(
            "pandas.DataFrame.to_csv"
        ) as mock_to_csv, mock.patch("pandas.DataFrame.to_json") as mock_to_json:

            saved_files = analyzer.save_analysis_results(
                mock_fig, mock_similarity_df, output_dir, experiment_name
            )

            # Check that directory was created
            mock_mkdir.assert_called()

            # Check that figure was saved in both formats
            assert mock_fig.savefig.call_count == 2

            # Check that CSV and JSON were saved
            assert "similarity_csv" in saved_files
            assert "similarity_json" in saved_files
            assert "heatmap_png" in saved_files
            assert "heatmap_pdf" in saved_files

            # Verify the actual method calls
            mock_to_csv.assert_called_once()
            mock_to_json.assert_called_once()

    def test_save_analysis_results_configured_paths(self):
        """Test saving analysis results with configured paths.

        When the config supplies explicit heatmap and similarity paths, the
        figure and the CSV are each expected to be written exactly once and
        reported under the ``heatmap`` and ``similarity_csv`` keys.
        """
        analyzer = EmbeddingsAnalyzer()

        mock_fig = mock.MagicMock()
        mock_similarity_df = pd.DataFrame([[1.0, 0.5], [0.5, 1.0]])

        config = {
            "output": {
                "embeddings_heatmap": "custom/heatmap.png",
                "embeddings_similarity": "custom/similarity.csv",
            }
        }

        # NOTE(review): ``DataFrame.to_json`` is not patched here; confirm the
        # configured-path branch never writes a JSON file.
        with mock.patch("pathlib.Path.mkdir") as mock_mkdir, mock.patch(
            "pandas.DataFrame.to_csv"
        ) as mock_to_csv:

            saved_files = analyzer.save_analysis_results(
                mock_fig, mock_similarity_df, "output", "test", config
            )

            # Check that configured paths were used
            mock_fig.savefig.assert_called_once()
            mock_to_csv.assert_called_once()

            assert "heatmap" in saved_files
            assert "similarity_csv" in saved_files

    def test_save_analysis_results_error_handling(self):
        """Test error handling in save_analysis_results.

        ``savefig`` is forced to raise; the call must swallow the failure and
        still return a dict. Directory creation and the DataFrame writers are
        patched, mirroring the other save tests, so that this test does not
        create a real ``output`` directory or files in the working directory.
        """
        analyzer = EmbeddingsAnalyzer()

        mock_fig = mock.MagicMock()
        mock_fig.savefig.side_effect = Exception("Save error")
        mock_similarity_df = pd.DataFrame([[1.0]])

        with mock.patch("pathlib.Path.mkdir"), mock.patch(
            "pandas.DataFrame.to_csv"
        ), mock.patch("pandas.DataFrame.to_json"):
            # Should not raise exception; the failure is expected to be
            # handled inside save_analysis_results.
            saved_files = analyzer.save_analysis_results(
                mock_fig, mock_similarity_df, "output", "test"
            )

        # Make sure the failing code path was actually exercised.
        mock_fig.savefig.assert_called()

        # Should still return some results (CSV/JSON might succeed)
        assert isinstance(saved_files, dict)

    def test_calculate_similarity_metrics(self):
        """Test calculation of similarity metrics.

        Uses a symmetric 3x3 matrix and checks that the mean, max and min are
        computed over the off-diagonal entries only.
        """
        analyzer = EmbeddingsAnalyzer()

        # Create a test similarity matrix
        similarity_df = pd.DataFrame(
            [[1.0, 0.8, 0.6], [0.8, 1.0, 0.7], [0.6, 0.7, 1.0]],
            columns=["A", "B", "C"],
            index=["A", "B", "C"],
        )

        metrics = analyzer._calculate_similarity_metrics(similarity_df)

        # Check that metrics are calculated correctly
        assert "mean_similarity" in metrics
        assert "max_similarity" in metrics
        assert "min_similarity" in metrics
        assert "std_similarity" in metrics

        # Check values (off-diagonal: 0.8, 0.6, 0.8, 0.7, 0.6, 0.7)
        expected_mean = (0.8 + 0.6 + 0.8 + 0.7 + 0.6 + 0.7) / 6
        # Tolerance comparison: the mean is a floating-point aggregate.
        assert abs(metrics["mean_similarity"] - expected_mean) < 1e-10
        assert metrics["max_similarity"] == 0.8
        assert metrics["min_similarity"] == 0.6

    def test_calculate_similarity_metrics_two_prompts(self):
        """Test similarity metrics calculation for exactly two prompts.

        For a 2x2 matrix the metrics are expected to include a
        ``prompt_similarity`` entry equal to the single off-diagonal value.
        """
        analyzer = EmbeddingsAnalyzer()

        similarity_df = pd.DataFrame(
            [[1.0, 0.75], [0.75, 1.0]], columns=["A", "B"], index=["A", "B"]
        )

        metrics = analyzer._calculate_similarity_metrics(similarity_df)

        # Should include direct prompt similarity
        assert "prompt_similarity" in metrics
        assert metrics["prompt_similarity"] == 0.75

    @mock.patch("src.embeddings.embeddings_analyzer.embeddings_registry")
    def test_analyze_experiment_results_success(self, mock_registry):
        """Test complete experiment analysis workflow.

        The registry client is mocked and ``save_analysis_results`` is stubbed
        out, so only the shape of the returned result dict and the derived
        prompts/labels are checked.
        """
        analyzer = EmbeddingsAnalyzer()

        # Mock the embeddings client
        mock_client = mock.MagicMock()
        mock_fig = mock.MagicMock()
        mock_ax = mock.MagicMock()
        mock_similarity_df = pd.DataFrame([[1.0, 0.8], [0.8, 1.0]])

        mock_client.generate.return_value = (mock_fig, mock_ax)
        mock_client.get_embeddings.return_value = mock.MagicMock()
        mock_client.similarity_matrix.return_value = mock_similarity_df
        mock_registry.create.return_value = mock_client

        prompt_optimization_results = {
            "base_prompt": "Base prompt",
            "optimised_prompt": "Optimized prompt",
        }

        # Stub out saving so no files are written during the workflow.
        with mock.patch.object(
            analyzer, "save_analysis_results", return_value={"file": "path"}
        ):
            result = analyzer.analyze_experiment_results(
                prompt_optimization_results=prompt_optimization_results,
                output_dir="test_output",
                experiment_name="test_exp",
            )

            # Check that result contains expected keys
            assert "prompts" in result
            assert "labels" in result
            assert "similarity_matrix" in result
            assert "similarity_metrics" in result
            assert "saved_files" in result
            assert "figure" in result
            assert "axes" in result

            # Check prompts and labels
            assert len(result["prompts"]) == 2
            assert result["labels"] == ["Base Prompt", "Optimized Prompt"]

    def test_analyze_experiment_results_no_prompts(self):
        """Test experiment analysis when no prompts are found."""
        analyzer = EmbeddingsAnalyzer()

        result = analyzer.analyze_experiment_results()

        # Should return error
        assert "error" in result
        assert result["error"] == "No prompts found for analysis"

    @mock.patch("src.embeddings.embeddings_analyzer.embeddings_registry")
    def test_analyze_experiment_results_exception(self, mock_registry):
        """Test experiment analysis error handling.

        A failure while creating the embeddings client must be reported in
        the result's ``error`` entry rather than raised to the caller.
        """
        analyzer = EmbeddingsAnalyzer()

        # Make registry raise an exception
        mock_registry.create.side_effect = Exception("Registry error")

        prompt_optimization_results = {"base_prompt": "Base prompt"}

        result = analyzer.analyze_experiment_results(
            prompt_optimization_results=prompt_optimization_results
        )

        # Should return error
        assert "error" in result
        assert "Registry error" in result["error"]


class TestCreateEmbeddingsAnalyzer:
    """Tests for the ``create_embeddings_analyzer`` factory function."""

    def test_create_embeddings_analyzer_enabled(self):
        """An enabled, fully specified config yields a configured analyzer."""
        embeddings_section = {
            "enabled": True,
            "model": "custom-model",
            "output_format": ["png", "pdf"],
            "kwargs": {"param": "value"},
        }

        built = create_embeddings_analyzer({"embeddings": embeddings_section})

        assert built is not None
        assert isinstance(built, EmbeddingsAnalyzer)
        assert built.model == "custom-model"
        assert built.output_formats == ["png", "pdf"]

    def test_create_embeddings_analyzer_disabled(self):
        """A config with embeddings switched off produces no analyzer."""
        built = create_embeddings_analyzer({"embeddings": {"enabled": False}})

        assert built is None

    def test_create_embeddings_analyzer_no_config(self):
        """A config lacking an embeddings section produces no analyzer."""
        empty_settings = {}

        built = create_embeddings_analyzer(empty_settings)

        assert built is None

    def test_create_embeddings_analyzer_string_format(self):
        """A bare string output format is normalised to a one-element list."""
        embeddings_section = {
            "enabled": True,
            # A plain string rather than a list of formats.
            "output_format": "pdf",
        }

        built = create_embeddings_analyzer({"embeddings": embeddings_section})

        assert built is not None
        assert built.output_formats == ["pdf"]

    def test_create_embeddings_analyzer_defaults(self):
        """Only ``enabled`` is given, so model and formats fall back to defaults."""
        built = create_embeddings_analyzer({"embeddings": {"enabled": True}})

        assert built is not None
        assert built.model == "bedrock-cohere-embed-eng-v3"
        assert built.output_formats == ["png"]
