"""
Integration tests for the embeddings functionality.

This module tests the complete embeddings workflow including:
- Registry integration
- OpenAI embeddings generation
- Embeddings analyzer workflow
- Integration with experiment pipeline
- File output and configuration handling
"""

import os
import tempfile
import pathlib
import json
from unittest import mock
from unittest.mock import MagicMock, patch

import pytest
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from src.core.experiment import ExperimentRunner
from src.core.registry import embeddings_registry
from src.embeddings.openai_embeddings import OpenAIEmbeddings
from src.embeddings.embeddings_analyzer import (
    EmbeddingsAnalyzer,
    create_embeddings_analyzer,
)
from src.llm.dummy_llm import DummyLLM
from src.tasks.tamper_detection.task_handler import TamperDetectionTask
from src.prompt_optimisation.psao_optuna_optimiser import PSAOOptunaPromptOptimiser


class TestEmbeddingsRegistryIntegration:
    """Checks that the embeddings registry exposes and builds OpenAI embeddings."""

    def test_embeddings_registry_registration(self):
        """The ``OpenAI_Embeddings`` key is listed and resolves to the class."""
        registry_key = "OpenAI_Embeddings"

        # The key has to appear among the registered component names ...
        assert registry_key in embeddings_registry.list_registered()

        # ... and looking it up must hand back the concrete embeddings class.
        assert embeddings_registry.get(registry_key) == OpenAIEmbeddings

    @patch("src.embeddings.openai_embeddings.OpenAI")
    @patch.object(OpenAIEmbeddings, "configure_environment")
    @patch.object(
        OpenAIEmbeddings,
        "load_api_info",
        return_value=("https://test-api.com", "test-key"),
    )
    def test_embeddings_registry_creation(
        self, mock_load_api, mock_configure, mock_openai
    ):
        """Building through the registry stores the inputs and sets up the API.

        ``load_api_info``, ``configure_environment`` and the ``OpenAI`` client
        class are all patched, so no credentials or network are needed.
        """
        # Reuse the values the patched loader is configured to return, so the
        # expectations below cannot drift from the decorator.
        api_url, api_key = mock_load_api.return_value
        prompts = ["Test prompt 1", "Test prompt 2"]
        model_name = "test-embedding-model"

        instance = embeddings_registry.create(
            "OpenAI_Embeddings", input=prompts, model=model_name
        )

        # The registry must return the right type with the arguments stored.
        assert isinstance(instance, OpenAIEmbeddings)
        assert instance.input == prompts
        assert instance.model == model_name

        # Each API set-up step runs exactly once with the loaded credentials.
        mock_load_api.assert_called_once()
        mock_configure.assert_called_once_with(api_url, api_key)
        mock_openai.assert_called_once_with(base_url=api_url, api_key=api_key)


class TestOpenAIEmbeddingsIntegration:
    """Integration tests for OpenAI embeddings functionality.

    The OpenAI client, the API credential loading and the plotting/similarity
    libraries are patched, so the tests run without network access.
    """

    # Tolerance used when comparing similarity values, so the assertions do not
    # depend on bit-exact floating point results.
    _TOLERANCE = 1e-9

    class _FakeEmbedding:
        """Minimal stand-in for one entry of an embeddings response's ``data``."""

        def __init__(self, embedding):
            self.embedding = embedding

    @classmethod
    def _fake_embeddings_response(cls, vectors):
        """Return a mock response whose ``data`` wraps each vector in *vectors*."""
        response = MagicMock()
        response.data = [cls._FakeEmbedding(vector) for vector in vectors]
        return response

    def setup_method(self):
        """Set up the prompts and model name shared by the tests."""
        self.test_prompts = [
            "You are an expert in detecting tampering in images.",
            "You are a specialist in identifying image tampering and fraud detection.",
        ]
        self.test_model = "bedrock-cohere-embed-eng-v3"

    @patch("src.embeddings.openai_embeddings.OpenAI")
    @patch.object(OpenAIEmbeddings, "configure_environment")
    @patch.object(
        OpenAIEmbeddings,
        "load_api_info",
        return_value=("https://test-api.com", "test-key"),
    )
    def test_embeddings_generation_workflow(
        self, mock_load_api, mock_configure, mock_openai
    ):
        """Test the complete embeddings generation workflow.

        ``generate`` must request embeddings for the prompts with the
        configured model and return the (patched) figure and axes.
        """
        # Mock OpenAI client and its embeddings response
        mock_client = MagicMock()
        mock_openai.return_value = mock_client
        mock_client.embeddings.create.return_value = self._fake_embeddings_response(
            [[1.0, 0.0, 0.0], [0.8, 0.6, 0.0]]
        )

        # Mock matplotlib components
        mock_fig = MagicMock()
        mock_ax = MagicMock()

        with patch(
            "matplotlib.pyplot.subplots", return_value=(mock_fig, mock_ax)
        ), patch("seaborn.heatmap", return_value=mock_ax), patch(
            "sklearn.metrics.pairwise.cosine_similarity",
            return_value=np.array([[1.0, 0.8], [0.8, 1.0]]),
        ):

            # Create embeddings instance
            embeddings = OpenAIEmbeddings(
                input=self.test_prompts, model=self.test_model
            )

            # Test the complete workflow
            fig, ax = embeddings.generate(self.test_prompts)

            # Verify API calls
            mock_client.embeddings.create.assert_called_with(
                model=self.test_model, input=self.test_prompts
            )

            # Verify return values
            assert fig == mock_fig
            assert ax == mock_ax

    @patch("src.embeddings.openai_embeddings.OpenAI")
    @patch.object(OpenAIEmbeddings, "configure_environment")
    @patch.object(
        OpenAIEmbeddings,
        "load_api_info",
        return_value=("https://test-api.com", "test-key"),
    )
    def test_similarity_matrix_creation(
        self, mock_load_api, mock_configure, mock_openai
    ):
        """Test similarity matrix creation with real data structures.

        The result must be a 2x2 DataFrame labelled by the prompts, with ones
        on the diagonal and zeros elsewhere for the two orthogonal vectors.
        """
        mock_client = MagicMock()
        mock_openai.return_value = mock_client

        # Two orthogonal unit vectors
        mock_embeddings_response = self._fake_embeddings_response(
            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
        )

        # Mock cosine similarity to return predictable results
        expected_similarity = np.array([[1.0, 0.0], [0.0, 1.0]])

        with patch(
            "sklearn.metrics.pairwise.cosine_similarity",
            return_value=expected_similarity,
        ):
            embeddings = OpenAIEmbeddings(
                input=self.test_prompts, model=self.test_model
            )
            similarity_df = embeddings.similarity_matrix(mock_embeddings_response)

            # Verify DataFrame structure
            assert isinstance(similarity_df, pd.DataFrame)
            assert list(similarity_df.columns) == self.test_prompts
            assert list(similarity_df.index) == self.test_prompts
            assert similarity_df.shape == (2, 2)

            # Verify similarity values with a tolerance rather than exact float
            # equality, so the check does not hinge on bit-exact arithmetic.
            assert abs(similarity_df.iloc[0, 0] - 1.0) < self._TOLERANCE
            assert abs(similarity_df.iloc[1, 1] - 1.0) < self._TOLERANCE
            assert abs(similarity_df.iloc[0, 1] - 0.0) < self._TOLERANCE
            assert abs(similarity_df.iloc[1, 0] - 0.0) < self._TOLERANCE


class TestEmbeddingsAnalyzerIntegration:
    """Integration tests for embeddings analyzer functionality."""

    def setup_method(self):
        """Create a temporary output directory and sample experiment results."""
        self.temp_dir = tempfile.TemporaryDirectory()
        self.output_dir = pathlib.Path(self.temp_dir.name)

        self.test_results = {
            "base_results": {"accuracy": 0.8},
            "optimized_results": {"accuracy": 0.9},
            "prompt_optimization_results": {
                "base_prompt": "You are an expert in detecting tampering.",
                "optimised_prompt": "You are a highly skilled specialist in identifying image tampering.",
            },
        }

    def teardown_method(self):
        """Remove the temporary output directory."""
        self.temp_dir.cleanup()

    @patch("src.embeddings.embeddings_analyzer.embeddings_registry")
    def test_embeddings_analyzer_workflow(self, mock_registry):
        """Test the complete embeddings analyzer workflow.

        The registry used by the analyzer module is patched so that it hands
        out a mock embeddings client with canned figure, embeddings and
        similarity data.
        """
        # Mock the embeddings client
        mock_client = MagicMock()
        mock_fig = MagicMock()
        mock_ax = MagicMock()
        mock_embeddings = MagicMock()
        mock_similarity_df = pd.DataFrame(
            [[1.0, 0.85], [0.85, 1.0]],
            columns=["Base Prompt", "Optimized Prompt"],
            index=["Base Prompt", "Optimized Prompt"],
        )

        mock_client.generate.return_value = (mock_fig, mock_ax)
        mock_client.get_embeddings.return_value = mock_embeddings
        mock_client.similarity_matrix.return_value = mock_similarity_df
        mock_registry.create.return_value = mock_client

        # Create analyzer
        analyzer = EmbeddingsAnalyzer(model="test-model", output_formats=["png", "pdf"])

        # Test complete workflow
        result = analyzer.analyze_experiment_results(
            base_results=self.test_results["base_results"],
            optimized_results=self.test_results["optimized_results"],
            prompt_optimization_results=self.test_results[
                "prompt_optimization_results"
            ],
            output_dir=self.output_dir,
            experiment_name="test_experiment",
        )

        # Verify results structure
        assert "prompts" in result
        assert "labels" in result
        assert "similarity_matrix" in result
        assert "similarity_metrics" in result
        assert "saved_files" in result
        assert "figure" in result
        assert "axes" in result

        # Verify prompts extraction
        assert len(result["prompts"]) == 2
        assert result["prompts"][0] == "You are an expert in detecting tampering."
        assert (
            result["prompts"][1]
            == "You are a highly skilled specialist in identifying image tampering."
        )

        # Verify labels
        assert result["labels"] == ["Base Prompt", "Optimized Prompt"]

        # Verify similarity metrics
        assert "mean_similarity" in result["similarity_metrics"]
        assert "prompt_similarity" in result["similarity_metrics"]
        # Compare with a tolerance instead of exact float equality, so the
        # check does not depend on how the analyzer derives the value from
        # the matrix.
        assert abs(result["similarity_metrics"]["prompt_similarity"] - 0.85) < 1e-9

    def test_embeddings_analyzer_file_output(self):
        """Test that embeddings analyzer properly saves files.

        Directory creation and the DataFrame writers are patched, and the
        figure is a mock, so the test only checks which save calls are made
        and which keys are reported.
        """
        analyzer = EmbeddingsAnalyzer(output_formats=["png"])

        # Create mock figure and similarity data
        mock_fig = MagicMock()
        mock_similarity_df = pd.DataFrame([[1.0, 0.8], [0.8, 1.0]])

        # Test file saving
        with patch("pathlib.Path.mkdir") as mock_mkdir, patch(
            "pandas.DataFrame.to_csv"
        ) as mock_to_csv, patch("pandas.DataFrame.to_json") as mock_to_json:

            saved_files = analyzer.save_analysis_results(
                mock_fig, mock_similarity_df, self.output_dir, "test_experiment"
            )

            # Verify directory creation
            mock_mkdir.assert_called()

            # Verify file saving calls
            mock_fig.savefig.assert_called_once()
            mock_to_csv.assert_called_once()
            mock_to_json.assert_called_once()

            # Verify saved files structure
            assert "heatmap_png" in saved_files
            assert "similarity_csv" in saved_files
            assert "similarity_json" in saved_files

    def test_create_embeddings_analyzer_from_config(self):
        """Test creating embeddings analyzer from configuration."""
        config = {
            "embeddings": {
                "enabled": True,
                "model": "custom-embedding-model",
                "output_format": ["png", "pdf"],
                "kwargs": {"custom_param": "value"},
            }
        }

        analyzer = create_embeddings_analyzer(config)

        # Every configured value must be carried over onto the analyzer.
        assert analyzer is not None
        assert isinstance(analyzer, EmbeddingsAnalyzer)
        assert analyzer.model == "custom-embedding-model"
        assert analyzer.output_formats == ["png", "pdf"]
        assert analyzer.kwargs == {"custom_param": "value"}

    def test_create_embeddings_analyzer_disabled(self):
        """Test that analyzer is not created when disabled."""
        config = {"embeddings": {"enabled": False}}

        analyzer = create_embeddings_analyzer(config)
        assert analyzer is None

    def test_create_embeddings_analyzer_no_config(self):
        """Test that analyzer is not created when no config exists."""
        config = {}

        analyzer = create_embeddings_analyzer(config)
        assert analyzer is None


class TestEmbeddingsExperimentIntegration:
    """Integration tests for embeddings with experiment pipeline."""

    def setup_method(self):
        """Create a temporary output directory and an embeddings-enabled config."""
        self.temp_dir = tempfile.TemporaryDirectory()
        self.output_dir = pathlib.Path(self.temp_dir.name)

        # Create test configuration with embeddings enabled
        self.config = {
            "project": {
                "name": "embeddings_integration_test",
                "version": "0.1.0",
            },
            "paths": {
                "data": "data",
                "output": str(self.output_dir),
            },
            "llm": {
                "type": "DummyLLM",
                "version_name": "dummy",
                "response": "No",
            },
            "embeddings": {
                "enabled": True,
                "default": "OpenAI_Embeddings",
                "OpenAI_Embeddings": {
                    "model": "test-embedding-model",
                    "output_format": ["png"],
                },
            },
            "prompt_optimiser": {
                "type": "tone",
                "r_seed": 42,
                "optuna_n_trials": 2,
            },
            "task": {
                "type": "tamper_detection",
                "font_semantics": "font",
                "num_images": 3,
            },
            "output": {
                "embeddings_heatmap": str(self.output_dir / "test_heatmap.png"),
                "embeddings_similarity": str(self.output_dir / "test_similarity.csv"),
            },
        }

    def teardown_method(self):
        """Remove the temporary output directory."""
        self.temp_dir.cleanup()

    @patch("src.embeddings.embeddings_analyzer.embeddings_registry")
    @patch("src.tasks.tamper_detection.task_handler.TamperDetectionDataHandler")
    @patch("optuna.create_study")
    def test_experiment_with_embeddings_integration(
        self, mock_optuna, mock_data_handler, mock_embeddings_registry
    ):
        """Test running a complete experiment with embeddings integration.

        The data handler, the Optuna study and the analyzer's registry are
        patched, and the analyzer's ``analyze_experiment_results`` is replaced
        by a stub, so the test exercises how the steps are wired together
        rather than the analysis itself.
        """
        # Mock data handler
        mock_handler = mock_data_handler.return_value
        mock_handler.get_size.return_value = 3
        mock_handler.get_data.return_value = (
            "base64_image",
            "jpeg",
            "base64_image",
            "jpeg",
        )

        # Mock Optuna study
        mock_study = MagicMock()
        mock_study.best_params = {"tone_w_0": 10, "tone_w_1": 7}
        mock_optuna.return_value = mock_study

        # Mock embeddings components
        mock_embeddings_client = MagicMock()
        mock_fig = MagicMock()
        mock_ax = MagicMock()
        mock_similarity_df = pd.DataFrame([[1.0, 0.9], [0.9, 1.0]])

        mock_embeddings_client.generate.return_value = (mock_fig, mock_ax)
        mock_embeddings_client.get_embeddings.return_value = MagicMock()
        mock_embeddings_client.similarity_matrix.return_value = mock_similarity_df
        mock_embeddings_registry.create.return_value = mock_embeddings_client

        # Create test analyzer with test model
        test_analyzer = EmbeddingsAnalyzer(model="test-embedding-model")

        # Mock the analyze_experiment_results method to return expected results
        mock_analysis_results = {
            "prompts": [
                "You are an expert in detecting tampering.",
                "Optimized prompt",
            ],
            "labels": ["Base Prompt", "Optimized Prompt"],
            "similarity_matrix": mock_similarity_df,
            "similarity_metrics": {"prompt_similarity": 0.9, "mean_similarity": 0.95},
            "saved_files": {
                "heatmap_png": "test_heatmap.png",
                "similarity_csv": "test_similarity.csv",
            },
            "figure": mock_fig,
            "axes": mock_ax,
        }
        test_analyzer.analyze_experiment_results = MagicMock(
            return_value=mock_analysis_results
        )

        # Create experiment runner
        runner = ExperimentRunner(
            "embeddings_test", config=self.config, output_dir=self.output_dir
        )

        # NOTE(review): the step functions below take parameters named after
        # earlier steps (llm, optimiser, task, init_embeddings, ...);
        # presumably the runner supplies the corresponding step results —
        # confirm against ExperimentRunner.

        def setup_llm(config):
            """Set up the LLM."""
            return DummyLLM(**config["llm"])

        def setup_task(config):
            """Set up the task."""
            return TamperDetectionTask(
                **config["task"],
                prompt_msg_template="You are an expert in detecting tampering."
            )

        def init_embeddings(config):
            """Initialise embeddings analyzer (always the stubbed test analyzer)."""
            return test_analyzer

        def optimise_prompt(llm, optimiser, task):
            """Return fixed base and optimised prompts without running optimisation."""
            base_prompt = "You are an expert in detecting tampering."

            optimised_prompt = (
                "You are a highly skilled specialist in identifying image tampering."
            )
            return {
                "base_prompt": base_prompt,
                "optimised_prompt": optimised_prompt,
            }

        def run_task(llm, optimiser, task):
            """Run the task with ``task.run`` patched to return canned results."""
            with patch.object(task, "run") as mock_run:
                results_df = pd.DataFrame(
                    {
                        "img_id": list(range(3)),
                        "response_original": ["No"] * 3,
                        "response_tampered": ["Yes"] * 3,
                    }
                )
                mock_run.return_value = (results_df, 1.0)

                results_df, eval_score = task.run(llm, optimiser)
                return {
                    "results_df": results_df,
                    "eval_score": eval_score,
                }

        def generate_embeddings_analysis(
            llm, optimiser, task, init_embeddings, optimise_prompt, run_task
        ):
            """Generate embeddings analysis."""
            embeddings_analyzer = init_embeddings
            prompt_results = optimise_prompt
            task_results = run_task

            if embeddings_analyzer is None:
                return {"embeddings_skipped": True}

            # Use the actual results from previous steps
            base_results = {"eval_score": task_results.get("eval_score", 0.0)}
            optimized_results = {"eval_score": task_results.get("eval_score", 0.0)}
            prompt_optimization_results = prompt_results

            return embeddings_analyzer.analyze_experiment_results(
                base_results=base_results,
                optimized_results=optimized_results,
                prompt_optimization_results=prompt_optimization_results,
                output_dir=self.output_dir,
                experiment_name="embeddings_test",
                config=self.config,
            )

        def setup_optimiser(config):
            """Set up the optimiser."""
            return PSAOOptunaPromptOptimiser(
                psao_intro_prompt="dummy",
                psao_struct_ann="dummy annotation",
                **config["prompt_optimiser"]
            )

        # Add experiment steps
        runner.add_step("setup_llm", setup_llm, config=self.config)
        runner.add_step("setup_optimiser", setup_optimiser, config=self.config)
        runner.add_step("setup_task", setup_task, config=self.config)
        runner.add_step("init_embeddings", init_embeddings, config=self.config)
        runner.add_step("optimise_prompt", optimise_prompt)
        runner.add_step("run_task", run_task)
        runner.add_step("generate_embeddings_analysis", generate_embeddings_analysis)

        # Mock file saving to avoid actual file operations
        with patch.object(runner, "_save_results"), patch("pathlib.Path.mkdir"), patch(
            "pandas.DataFrame.to_csv"
        ), patch("pandas.DataFrame.to_json"):

            # Run the experiment
            results = runner.run()

            # Verify all steps completed
            assert "setup_llm" in results
            assert "setup_optimiser" in results
            assert "setup_task" in results
            assert "init_embeddings" in results
            assert "optimise_prompt" in results
            assert "run_task" in results
            assert "generate_embeddings_analysis" in results

            # Verify embeddings analyzer was created
            embeddings_analyzer = results["init_embeddings"]
            assert isinstance(embeddings_analyzer, EmbeddingsAnalyzer)
            assert embeddings_analyzer.model == "test-embedding-model"

            # Verify prompt optimization results
            prompt_results = results["optimise_prompt"]
            assert "base_prompt" in prompt_results
            assert "optimised_prompt" in prompt_results

            # Verify embeddings analysis results
            embeddings_results = results["generate_embeddings_analysis"]
            assert "prompts" in embeddings_results
            assert "similarity_matrix" in embeddings_results
            assert "similarity_metrics" in embeddings_results

            # The registry patch only isolates the analyzer module, and a
            # patched mock is never None, so checking it proves nothing.
            # Instead verify that the pipeline really handed control to the
            # stubbed analyzer, using the literals passed by the
            # generate_embeddings_analysis step above.
            test_analyzer.analyze_experiment_results.assert_called_once()
            _, call_kwargs = test_analyzer.analyze_experiment_results.call_args
            assert call_kwargs["output_dir"] == self.output_dir
            assert call_kwargs["experiment_name"] == "embeddings_test"


class TestEmbeddingsConfigurationIntegration:
    """Checks on the shape and handling of embeddings configuration."""

    def test_embeddings_config_loading(self):
        """An embeddings section can be read back from a YAML-like dict.

        Only plain dictionary lookups are exercised here; no project code is
        called.
        """
        provider_settings = {
            "model": "bedrock-cohere-embed-eng-v3",
            "output_format": ["png", "pdf"],
        }
        config = {
            "embeddings": {
                "enabled": True,
                "default": "OpenAI_Embeddings",
                "OpenAI_Embeddings": provider_settings,
            }
        }

        # Top-level switches of the embeddings section.
        section = config.get("embeddings", {})
        assert section.get("enabled") is True
        assert section.get("default") == "OpenAI_Embeddings"

        # Provider-specific settings nested under the provider's name.
        provider = section.get("OpenAI_Embeddings", {})
        assert provider.get("model") == "bedrock-cohere-embed-eng-v3"
        assert provider.get("output_format") == ["png", "pdf"]

    def test_embeddings_output_path_configuration(self):
        """Output paths for the heatmap and similarity file can be read back.

        As above, this only exercises dictionary lookups.
        """
        heatmap_path = "custom/path/heatmap.png"
        similarity_path = "custom/path/similarity.csv"
        config = {
            "output": {
                "embeddings_heatmap": heatmap_path,
                "embeddings_similarity": similarity_path,
            }
        }

        paths = config.get("output", {})
        assert paths.get("embeddings_heatmap") == "custom/path/heatmap.png"
        assert paths.get("embeddings_similarity") == "custom/path/similarity.csv"

    def test_embeddings_config_validation(self):
        """An analyzer is built only for an enabled embeddings config."""
        enabled_config = {
            "embeddings": {
                "enabled": True,
                "default": "OpenAI_Embeddings",
                "OpenAI_Embeddings": {"model": "test-model"},
            }
        }

        # Enabled section: an analyzer instance is returned.
        assert create_embeddings_analyzer(enabled_config) is not None

        # Disabled section, then no section at all: nothing is created.
        for inactive_config in ({"embeddings": {"enabled": False}}, {}):
            assert create_embeddings_analyzer(inactive_config) is None


class TestEmbeddingsEndToEndWorkflow:
    """End-to-end integration tests for embeddings workflow."""

    def setup_method(self):
        """Create a temporary output directory."""
        self.temp_dir = tempfile.TemporaryDirectory()
        self.output_dir = pathlib.Path(self.temp_dir.name)

    def teardown_method(self):
        """Remove the temporary output directory."""
        self.temp_dir.cleanup()

    @patch("src.embeddings.openai_embeddings.OpenAI")
    @patch.object(OpenAIEmbeddings, "configure_environment")
    @patch.object(
        OpenAIEmbeddings,
        "load_api_info",
        return_value=("https://test-api.com", "test-key"),
    )
    def test_end_to_end_embeddings_workflow(
        self, mock_load_api, mock_configure, mock_openai
    ):
        """Test complete end-to-end embeddings workflow with mocked API calls.

        Covers creating embeddings through the registry, generating the
        heatmap, and running the analyzer on a base/optimised prompt pair,
        with the OpenAI client, plotting and file writers patched.
        """
        # Mock OpenAI client
        mock_client = MagicMock()
        mock_openai.return_value = mock_client

        # Mock embeddings API response
        class MockEmbedding:
            def __init__(self, embedding):
                self.embedding = embedding

        mock_embeddings_response = MagicMock()
        mock_embeddings_response.data = [
            MockEmbedding([1.0, 0.0, 0.0]),
            MockEmbedding([0.9, 0.1, 0.0]),
        ]
        mock_client.embeddings.create.return_value = mock_embeddings_response

        # Mock matplotlib components
        mock_fig = MagicMock()
        mock_ax = MagicMock()

        with patch(
            "matplotlib.pyplot.subplots", return_value=(mock_fig, mock_ax)
        ), patch("seaborn.heatmap", return_value=mock_ax), patch(
            "sklearn.metrics.pairwise.cosine_similarity",
            return_value=np.array([[1.0, 0.95], [0.95, 1.0]]),
        ), patch(
            "pathlib.Path.mkdir"
        ), patch(
            "pandas.DataFrame.to_csv"
        ), patch(
            "pandas.DataFrame.to_json"
        ):

            # Test data
            test_prompts = [
                "You are an expert in detecting tampering in images.",
                "You are a highly skilled specialist in identifying image tampering and fraud.",
            ]

            # Step 1: Create embeddings through registry
            embeddings = embeddings_registry.create(
                "OpenAI_Embeddings",
                input=test_prompts,
                model="bedrock-cohere-embed-eng-v3",
            )

            # Step 2: Generate embeddings and heatmap
            fig, ax = embeddings.generate(test_prompts)

            # The generated figure and axes must be the patched plotting
            # objects, as in the standalone generation test; previously the
            # return values were assigned but never checked.
            assert fig == mock_fig
            assert ax == mock_ax

            # Step 3: Create analyzer and run complete workflow
            analyzer = EmbeddingsAnalyzer(
                model="bedrock-cohere-embed-eng-v3", output_formats=["png"]
            )

            # Mock prompt optimization results
            prompt_optimization_results = {
                "base_prompt": test_prompts[0],
                "optimised_prompt": test_prompts[1],
            }

            # Step 4: Run complete analysis
            results = analyzer.analyze_experiment_results(
                prompt_optimization_results=prompt_optimization_results,
                output_dir=self.output_dir,
                experiment_name="end_to_end_test",
            )

            # Verify complete workflow
            assert "prompts" in results
            assert "similarity_matrix" in results
            assert "similarity_metrics" in results
            assert "saved_files" in results

            # Verify API calls
            mock_client.embeddings.create.assert_called()

            # Verify similarity metrics
            assert "prompt_similarity" in results["similarity_metrics"]
            # Use approximate comparison for floating point values.
            # NOTE(review): the 0.1 tolerance is wide. The true cosine
            # similarity of the two mocked vectors is about 0.994, so this
            # check passes whether or not the cosine_similarity patch is the
            # function actually used — confirm which is intended before
            # tightening it.
            assert abs(results["similarity_metrics"]["prompt_similarity"] - 0.95) < 0.1

            # Verify file saving was attempted
            mock_fig.savefig.assert_called()
