"""
Tests for the tamper detection task.
"""

import base64
from unittest import mock

import pytest
import pandas as pd

from src.tasks.tamper_detection.task_handler import TamperDetectionTask
from src.tasks.tamper_detection.data_handler import TamperDetectionDataHandler
from src.tasks.tamper_detection.eval_handler import TamperDetectionEvaluator
from src.llm.dummy_llm import DummyLLM


class TestTamperDetectionDataHandler:
    """Unit tests covering TamperDetectionDataHandler."""

    @staticmethod
    def _fake_image_file(filename, full_path):
        """Return a MagicMock posing as a directory entry for one image file."""
        entry = mock.MagicMock()
        entry.is_file.return_value = True
        entry.name = filename
        entry.__str__.return_value = full_path
        return entry

    def test_init(self):
        """Constructing the handler stores the mode and starts with empty lists."""
        # init_dataset is stubbed so that construction does not load anything.
        with mock.patch.object(TamperDetectionDataHandler, "init_dataset"):
            data_handler = TamperDetectionDataHandler("font")
            assert data_handler.font_semantics == "font"
            for list_attr in (
                "original_b64_lst",
                "original_type_lst",
                "tampered_b64_lst",
                "tampered_type_lst",
            ):
                assert getattr(data_handler, list_attr) == []

    def test_init_dataset(self):
        """init_dataset fills the four lists from the mocked image directories."""
        raw_bytes = b"test_image_data"

        # Patchers are only built here; each target is replaced when the
        # `with` statement below enters it, in the order listed.
        path_patch = mock.patch("src.tasks.tamper_detection.data_handler.Path")
        open_patch = mock.patch(
            "builtins.open", mock.mock_open(read_data=raw_bytes)
        )
        image_open_patch = mock.patch(
            "src.tasks.tamper_detection.data_handler.Image.open"
        )
        config_patch = mock.patch(
            "src.tasks.tamper_detection.data_handler.config_manager.get",
            return_value="data",
        )

        with path_patch as fake_path:
            with open_patch, image_open_patch as fake_image_open, config_patch:
                # Every opened image reports the JPEG format.
                fake_image = mock.MagicMock()
                fake_image.format = "JPEG"
                fake_image_open.return_value = fake_image

                # Path(...) / x yields the font directory; the font directory
                # divided twice yields the original and then the tampered one.
                root_dir = mock.MagicMock()
                font_dir = mock.MagicMock()
                original_dir = mock.MagicMock()
                tampered_dir = mock.MagicMock()
                fake_path.return_value = root_dir
                root_dir.__truediv__.return_value = font_dir
                font_dir.__truediv__.side_effect = [original_dir, tampered_dir]

                # Both directories list the same two image files.
                first_file = self._fake_image_file(
                    "image1.jpg", "/path/to/image1.jpg"
                )
                second_file = self._fake_image_file(
                    "image2.jpg", "/path/to/image2.jpg"
                )
                original_dir.iterdir.return_value = [first_file, second_file]
                tampered_dir.iterdir.return_value = [first_file, second_file]

                # Construction triggers the (un-mocked) init_dataset.
                data_handler = TamperDetectionDataHandler("font")

                # Two entries are expected in each of the four lists.
                for loaded in (
                    data_handler.original_b64_lst,
                    data_handler.original_type_lst,
                    data_handler.tampered_b64_lst,
                    data_handler.tampered_type_lst,
                ):
                    assert len(loaded) == 2

                # The image type is the lower-cased form of the mocked format.
                assert data_handler.original_type_lst[0] == "jpeg"
                assert data_handler.tampered_type_lst[0] == "jpeg"

                # The stored payload is the base64 text of the bytes read.
                encoded = base64.b64encode(raw_bytes).decode()
                assert data_handler.original_b64_lst[0] == encoded
                assert data_handler.tampered_b64_lst[0] == encoded

    def test_get_data(self):
        """get_data returns the original/tampered payload and type for an ID."""
        with mock.patch.object(TamperDetectionDataHandler, "init_dataset"):
            data_handler = TamperDetectionDataHandler("font")
            data_handler.original_b64_lst = ["original1", "original2"]
            data_handler.original_type_lst = ["jpeg", "png"]
            data_handler.tampered_b64_lst = ["tampered1", "tampered2"]
            data_handler.tampered_type_lst = ["jpeg", "png"]

        # ID 0
        first_item = data_handler.get_data(0)
        assert first_item == ("original1", "jpeg", "tampered1", "jpeg")

        # ID 1
        second_item = data_handler.get_data(1)
        assert second_item == ("original2", "png", "tampered2", "png")

    def test_get_size(self):
        """get_size reports two when two original images are stored."""
        with mock.patch.object(TamperDetectionDataHandler, "init_dataset"):
            data_handler = TamperDetectionDataHandler("font")
            data_handler.original_b64_lst = ["original1", "original2"]

        dataset_size = data_handler.get_size()
        assert dataset_size == 2


class TestTamperDetectionEvaluator:
    """Unit tests covering TamperDetectionEvaluator."""

    def setup_method(self):
        """Build a fresh evaluator and a four-row results frame for each test."""
        self.evaluator = TamperDetectionEvaluator()

        # One row per original/tampered answer combination:
        # No/Yes, No/No, Yes/Yes, Yes/No.
        frame_columns = {
            "img_id": [0, 1, 2, 3],
            "response_original": ["No", "No", "Yes", "Yes"],
            "response_tampered": ["Yes", "No", "Yes", "No"],
        }
        self.results = pd.DataFrame(frame_columns)

    def test_add_eval_score(self):
        """_add_eval_score adds an eval_score column with one score per row."""
        scored = self.evaluator._add_eval_score(self.results)

        assert "eval_score" in scored.columns

        expected_by_row = [
            (0, 1),  # No -> Yes (correct)
            (1, 0),  # No -> No (neutral)
            (2, 0),  # Yes -> Yes (neutral)
            (3, -1),  # Yes -> No (incorrect)
        ]
        for row_label, expected_score in expected_by_row:
            assert scored.loc[row_label, "eval_score"] == expected_score

    def test_calculate_accuracy(self):
        """_calculate_accuracy yields 0.25 for the fixture frame."""
        # 1 correct row out of 4.
        observed = self.evaluator._calculate_accuracy(self.results)
        assert observed == 0.25

    def test_calculate_precision(self):
        """_calculate_precision yields 0.5 for the fixture frame."""
        # 1 true positive (No -> Yes) among the 2 rows whose
        # response_tampered is "Yes".
        observed = self.evaluator._calculate_precision(self.results)
        assert observed == 0.5

    def test_calculate_recall(self):
        """_calculate_recall yields 0.25 for the fixture frame."""
        # 1 true positive (No -> Yes) out of 4 actual positives.
        observed = self.evaluator._calculate_recall(self.results)
        assert observed == 0.25

    def test_calculate_f1_score(self):
        """_calculate_f1_score yields roughly one third for the fixture frame."""
        # 2 * (0.5 * 0.25) / (0.5 + 0.25) = 0.33333...
        observed = self.evaluator._calculate_f1_score(self.results)
        assert observed == pytest.approx(0.3333, abs=0.001)

    def test_calculate_mean_score(self):
        """_calculate_mean_score yields 0.0 for the fixture frame."""
        # (1 + 0 + 0 + -1) / 4 = 0
        observed = self.evaluator._calculate_mean_score(self.results)
        assert observed == 0.0

    def test_get_eval_score(self):
        """get_eval_score returns the same value as _calculate_mean_score."""
        reported = self.evaluator.get_eval_score(self.results)
        mean_score = self.evaluator._calculate_mean_score(self.results)
        assert reported == mean_score


class TestTamperDetectionTask:
    """Tests for the TamperDetectionTask class.

    The data handler class referenced by the task module is replaced with a
    mock in every test, so constructing a task never loads a dataset.
    """

    # Patch target for the data handler as looked up by the task module.
    _HANDLER_TARGET = (
        "src.tasks.tamper_detection.task_handler.TamperDetectionDataHandler"
    )

    @staticmethod
    def _basic_template():
        """Return a fresh two-message (system + user) prompt template.

        A new list is built on every call so that tests never share a
        template object with each other or with a task.
        """
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What can you do?"},
        ]

    def _make_task(
        self, font_semantics="font", num_images=25, prompt_msg_template=None
    ):
        """Create a TamperDetectionTask while the data handler is mocked.

        Args:
            font_semantics: Value forwarded as ``font_semantics``.
            num_images: Value forwarded as ``num_images``.
            prompt_msg_template: Template forwarded to the task; when None a
                fresh copy of the basic two-message template is used.

        Returns:
            The constructed task. The patch is only active during
            construction, so the task is used outside the patch context.
        """
        if prompt_msg_template is None:
            prompt_msg_template = self._basic_template()
        with mock.patch(self._HANDLER_TARGET):
            return TamperDetectionTask(
                font_semantics=font_semantics,
                num_images=num_images,
                prompt_msg_template=prompt_msg_template,
            )

    def test_init(self):
        """Test initialisation of TamperDetectionTask."""
        # Create a task with the basic template
        task = self._make_task()

        # Check that the constructor arguments were stored
        assert task.font_semantics == "font"
        assert task.num_images == 25
        assert len(task.prompt_msg_template) == 2
        assert task.prompt_msg_template[0]["role"] == "system"
        assert task.prompt_msg_template[1]["role"] == "user"

        # Create a task with a different set of arguments
        task = self._make_task(
            font_semantics="semantics",
            num_images=10,
            prompt_msg_template=[{"role": "system", "content": "Custom prompt"}],
        )

        # Check that the new arguments were stored
        assert task.font_semantics == "semantics"
        assert task.num_images == 10
        assert task.prompt_msg_template == [
            {"role": "system", "content": "Custom prompt"}
        ]

    def test_load_data(self):
        """Test loading data."""
        # Mock the TamperDetectionDataHandler
        with mock.patch(self._HANDLER_TARGET) as mock_handler_class:
            # The mocked handler reports a dataset of 50 items
            mock_handler = mock_handler_class.return_value
            mock_handler.get_size.return_value = 50

            # Create a task
            task = TamperDetectionTask(
                font_semantics="font",
                num_images=25,
                prompt_msg_template=self._basic_template(),
            )

            # Load data
            data_ids = task.load_data()

            # Check that the correct number of data IDs were returned
            assert len(data_ids) == 25
            assert data_ids == list(range(25))

            # Asking for more images than exist is capped at the dataset size
            task.num_images = 100
            data_ids = task.load_data()
            assert len(data_ids) == 50
            assert data_ids == list(range(50))

    def test_create_prompt(self):
        """Test creating a prompt for a data item."""
        # Template whose user message has a text part and an image_url part
        task = self._make_task(
            prompt_msg_template=[
                {"role": "system", "content": "You are a helpful assistant."},
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Determine if the image is tampered",
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": "dummy_url"},
                        },
                    ],
                },
            ],
        )

        # Create a prompt for a data item
        data_item = ("base64_image", "jpeg")
        prompt = task.create_prompt(data_item)

        # Check that the prompt was created correctly
        assert len(prompt) == 2
        assert prompt[0]["role"] == "system"
        assert prompt[1]["role"] == "user"
        assert prompt[1]["content"][0]["type"] == "text"
        assert prompt[1]["content"][1]["type"] == "image_url"
        # The placeholder URL is replaced by a data URL built from the item
        assert (
            prompt[1]["content"][1]["image_url"]["url"]
            == "data:image/jpeg;base64,base64_image"
        )

    def test_run(self):
        """Test running the task."""
        # Mock the TamperDetectionDataHandler and ThreadPoolExecutor
        with mock.patch(
            self._HANDLER_TARGET
        ) as mock_handler_class, mock.patch(
            "src.tasks.tamper_detection.task_handler.ThreadPoolExecutor"
        ) as mock_executor_class, mock.patch(
            "src.tasks.tamper_detection.task_handler.as_completed"
        ) as mock_as_completed:

            # Mock the get_size and get_data methods
            mock_handler = mock_handler_class.return_value
            mock_handler.get_size.return_value = 2
            mock_handler.get_data.side_effect = [
                ("original1", "jpeg", "tampered1", "jpeg"),
                ("original2", "jpeg", "tampered2", "jpeg"),
            ]

            # Create mock futures
            mock_future1 = mock.MagicMock()
            mock_future1.result.return_value = (0, "No", "Yes")

            mock_future2 = mock.MagicMock()
            mock_future2.result.return_value = (1, "Yes", "No")

            # Mock the executor (the object yielded by its context manager)
            mock_executor = mock_executor_class.return_value.__enter__.return_value
            mock_executor.submit.side_effect = [mock_future1, mock_future2]

            # Mock as_completed to return the futures in order
            mock_as_completed.return_value = [mock_future1, mock_future2]

            # Create a task
            task = TamperDetectionTask(
                font_semantics="font",
                num_images=2,
                prompt_msg_template=self._basic_template(),
            )

            # Create a dummy LLM
            llm = DummyLLM(response="No")

            # Run the task; the third returned value is not checked here
            results_df, eval_score, _ = task.run(llm)

            # Check the results
            assert isinstance(results_df, pd.DataFrame)
            assert "img_id" in results_df.columns
            assert "response_original" in results_df.columns
            assert "response_tampered" in results_df.columns
            assert len(results_df) == 2

            # Check that the evaluation score was calculated
            assert isinstance(eval_score, float)

    def test_evaluate(self):
        """Test evaluating the results."""
        task = self._make_task(num_images=2)

        # Create test results
        results = pd.DataFrame(
            {
                "img_id": [0, 1],
                "response_original": ["No", "Yes"],
                "response_tampered": ["Yes", "No"],
            }
        )

        # Evaluate the results
        score = task.evaluate(results)

        # Check that the score was calculated
        assert isinstance(score, float)

    def test_get_prompt_msg_template(self):
        """Test getting the prompt message template."""
        task = self._make_task()

        # Get the prompt message template
        template = task.get_prompt_msg_template()

        # Check that the template was returned
        assert template == task.prompt_msg_template

        # Check that a copy was returned rather than the stored list itself.
        # NOTE(review): this identity check covers the outer list only; it
        # does not prove that the nested message dicts were copied as well.
        assert template is not task.prompt_msg_template

    def test_update_prompt_msg_template(self):
        """Test updating the prompt message template."""
        task = self._make_task()

        # Define a new template
        new_template = [{"role": "system", "content": "New prompt"}]

        # Update the template
        task.update_prompt_msg_template(new_template)

        # Check that the template was updated
        assert task.prompt_msg_template == new_template

        # Check that the task stored a copy rather than the caller's list.
        # NOTE(review): this identity check covers the outer list only; it
        # does not prove that the nested message dicts were copied as well.
        assert task.prompt_msg_template is not new_template
