"""
Unit tests for the OpenAIEmbeddings class.
"""

import contextlib
import os
import time
from unittest import mock

import numpy as np
import pandas as pd
import pytest

# Import the modules we need to test
from src.embeddings.openai_embeddings import OpenAIEmbeddings
from src.embeddings.base import EmbeddingsInterface


class TestOpenAIEmbeddings:
    """Test cases for the OpenAIEmbeddings class.

    Every test constructs ``OpenAIEmbeddings`` inside ``_patched()``, which
    replaces the ``OpenAI`` client class, ``configure_environment`` and
    ``load_api_info`` with mocks.
    """

    # Endpoint/key pair handed back by the patched ``load_api_info``.
    API_URL = "https://custom-api.example.com"
    API_KEY = "test-api-key"

    @contextlib.contextmanager
    def _patched(self):
        """Patch the external collaborators of ``OpenAIEmbeddings``.

        Yields:
            tuple: ``(mock_openai_cls, mock_configure)`` - the mock replacing
            ``OpenAI`` in the module under test and the mock replacing
            ``OpenAIEmbeddings.configure_environment``. A test may set
            ``return_value`` / ``side_effect`` on the former before it
            constructs an ``OpenAIEmbeddings``.
        """
        with mock.patch(
            "src.embeddings.openai_embeddings.OpenAI"
        ) as mock_openai_cls, mock.patch.object(
            OpenAIEmbeddings, "configure_environment"
        ) as mock_configure, mock.patch.object(
            OpenAIEmbeddings,
            "load_api_info",
            return_value=(self.API_URL, self.API_KEY),
        ):
            yield mock_openai_cls, mock_configure

    def test_init(self):
        """Test initialisation of OpenAIEmbeddings."""
        test_input = ["Hello world", "Hi there"]
        test_model = "test-embedding-model"

        with self._patched() as (mock_openai_cls, mock_configure):
            embeddings = OpenAIEmbeddings(input=test_input, model=test_model)

            assert embeddings.input == test_input
            assert embeddings.model == test_model
            assert embeddings.kwargs == {}

            # The loaded API info must be forwarded to configure_environment
            mock_configure.assert_called_once_with(self.API_URL, self.API_KEY)

            # ... and used to build the OpenAI client
            mock_openai_cls.assert_called_once_with(
                base_url=self.API_URL, api_key=self.API_KEY
            )

    def test_init_with_kwargs(self):
        """Test initialisation with additional kwargs."""
        test_input = ["Test text"]

        with self._patched():
            embeddings = OpenAIEmbeddings(
                input=test_input,
                model="custom-model",
                custom_param="value",
                another_param=123,
            )

            assert embeddings.kwargs == {"custom_param": "value", "another_param": 123}

    def test_init_failure(self):
        """Test initialisation failure when OpenAI client fails."""
        test_input = ["Test text"]

        with self._patched() as (mock_openai_cls, _):
            # Make construction of the client itself blow up
            mock_openai_cls.side_effect = Exception("API Error")

            with pytest.raises(Exception, match="API Error"):
                OpenAIEmbeddings(input=test_input)

    def test_get_embeddings(self):
        """Test get_embeddings method."""
        test_input = ["Hello", "World"]
        mock_embeddings_response = mock.MagicMock()

        # Client instance that the patched OpenAI class will return
        mock_client = mock.MagicMock()
        mock_client.embeddings.create.return_value = mock_embeddings_response

        with self._patched() as (mock_openai_cls, _):
            mock_openai_cls.return_value = mock_client

            embeddings = OpenAIEmbeddings(input=test_input, model="test-model")
            result = embeddings.get_embeddings()

            # Check that the client was called correctly
            mock_client.embeddings.create.assert_called_once_with(
                model="test-model", input=test_input
            )

            # Check that the client response is returned unchanged
            assert result is mock_embeddings_response

    def test_get_embeddings_with_input_parameter(self):
        """Test get_embeddings method with input parameter."""
        original_input = ["Original"]
        new_input = ["New", "Input"]
        mock_embeddings_response = mock.MagicMock()

        # Client instance that the patched OpenAI class will return
        mock_client = mock.MagicMock()
        mock_client.embeddings.create.return_value = mock_embeddings_response

        with self._patched() as (mock_openai_cls, _):
            mock_openai_cls.return_value = mock_client

            embeddings = OpenAIEmbeddings(input=original_input, model="test-model")
            result = embeddings.get_embeddings(new_input)

            # An explicitly passed input is expected to win over the one given
            # at construction time.
            # NOTE(review): an implementation of the form
            # `self.input or input_texts` would prefer the constructor input
            # and fail this assertion - confirm the intended precedence in
            # OpenAIEmbeddings.get_embeddings.
            mock_client.embeddings.create.assert_called_once_with(
                model="test-model", input=new_input
            )

            # Same return contract as the no-argument call
            assert result is mock_embeddings_response

    def test_similarity_matrix(self):
        """Test similarity_matrix method."""
        test_input = ["Hello", "World"]

        # Minimal stand-in for one item of an embeddings response
        class MockEmbedding:
            def __init__(self, embedding):
                self.embedding = embedding

        mock_embeddings = mock.MagicMock()
        mock_embeddings.data = [
            MockEmbedding([1.0, 0.0, 0.0]),
            MockEmbedding([0.0, 1.0, 0.0]),
        ]

        # Two orthogonal unit vectors: their cosine-similarity matrix is the
        # 2x2 identity.
        expected_similarity = np.array([[1.0, 0.0], [0.0, 1.0]])

        # NOTE(review): this patches the function on the sklearn module. If
        # the module under test binds it with
        # `from sklearn.metrics.pairwise import cosine_similarity`, the patch
        # does not reach that name. The value assertion below holds either
        # way, because the real result for these vectors is the same matrix.
        with self._patched(), mock.patch(
            "sklearn.metrics.pairwise.cosine_similarity",
            return_value=expected_similarity,
        ):
            embeddings = OpenAIEmbeddings(input=test_input)
            result = embeddings.similarity_matrix(mock_embeddings)

            # Check that result is a DataFrame labelled with the input texts
            assert isinstance(result, pd.DataFrame)
            assert list(result.columns) == test_input
            assert list(result.index) == test_input
            assert result.shape == (2, 2)

            # Check the similarity values themselves, not just the labels
            np.testing.assert_allclose(result.to_numpy(), expected_similarity)

    def test_similarity_heatmap(self):
        """Test similarity_heatmap method."""
        test_input = ["A", "B"]
        test_df = pd.DataFrame(
            [[1.0, 0.5], [0.5, 1.0]], columns=test_input, index=test_input
        )

        # Mock matplotlib components
        mock_fig = mock.MagicMock()
        mock_ax = mock.MagicMock()

        with self._patched(), mock.patch(
            "matplotlib.pyplot.subplots", return_value=(mock_fig, mock_ax)
        ), mock.patch("seaborn.heatmap", return_value=mock_ax) as mock_heatmap:
            embeddings = OpenAIEmbeddings(input=test_input, model="test-model")
            fig, ax = embeddings.similarity_heatmap(test_df)

            # Check that heatmap was created with correct parameters
            mock_heatmap.assert_called_once_with(
                test_df, annot=True, cmap="coolwarm", vmin=-1, vmax=1, ax=mock_ax
            )

            # Check that title was set and mentions the model
            mock_ax.set_title.assert_called_once()
            title_call = mock_ax.set_title.call_args[0][0]
            assert "test-model" in title_call
            assert "Cosine Similarity Heatmap" in title_call

            # Check return values
            assert fig is mock_fig
            assert ax is mock_ax

    def test_generate_end_to_end(self):
        """Test the generate method (end-to-end pipeline)."""
        test_input = ["Hello", "World"]

        # Canned results for each pipeline stage
        mock_embeddings = mock.MagicMock()
        mock_similarity_df = pd.DataFrame([[1.0, 0.5], [0.5, 1.0]])
        mock_fig = mock.MagicMock()
        mock_ax = mock.MagicMock()

        with self._patched(), mock.patch.object(
            OpenAIEmbeddings, "get_embeddings", return_value=mock_embeddings
        ) as mock_get_emb, mock.patch.object(
            OpenAIEmbeddings, "similarity_matrix", return_value=mock_similarity_df
        ) as mock_sim_matrix, mock.patch.object(
            OpenAIEmbeddings, "similarity_heatmap", return_value=(mock_fig, mock_ax)
        ) as mock_heatmap:
            embeddings = OpenAIEmbeddings(input=test_input)
            fig, ax = embeddings.generate(test_input)

            # Each stage must receive the previous stage's output
            mock_get_emb.assert_called_once_with(test_input)
            mock_sim_matrix.assert_called_once_with(mock_embeddings)
            mock_heatmap.assert_called_once_with(mock_similarity_df)

            # Check return values
            assert fig is mock_fig
            assert ax is mock_ax

            # Timing/logging done inside generate() is not asserted here;
            # the call completing without raising is the only check on it.

    def test_model_info(self):
        """Test model_info property."""
        test_input = ["Test"]
        test_model = "custom-embedding-model"
        test_kwargs = {"param1": "value1", "param2": 42}

        with self._patched():
            embeddings = OpenAIEmbeddings(
                input=test_input, model=test_model, **test_kwargs
            )

            info = embeddings.model_info

            # Check the model info structure
            assert info["name"] == test_model
            assert info["vendor"] == "OpenAI"
            assert info["param1"] == "value1"
            assert info["param2"] == 42

    def test_inheritance(self):
        """Test that OpenAIEmbeddings properly inherits from EmbeddingsInterface."""
        with self._patched():
            embeddings = OpenAIEmbeddings(input=["test"])

            # Check that it's an instance of the base interface
            assert isinstance(embeddings, EmbeddingsInterface)

            # Check that it has the required methods
            for required in (
                "generate",
                "model_info",
                "load_api_info",
                "configure_environment",
            ):
                assert hasattr(embeddings, required)
