"""
Tests for the experiment module.
"""

import json
import pathlib
import tempfile
from unittest import mock

import pandas as pd
import pytest
import yaml

from src.core.experiment import (
    ExperimentRunner,
    ExperimentStep,
    load_experiment_results,
)
from src.utils.error_tracking import ErrorTracker


class TestExperimentStep:
    """Unit tests covering construction and execution of ExperimentStep."""

    def test_init(self):
        """A freshly built step exposes its name, callable and keyword arguments."""

        def add(a, b):
            return a + b

        step = ExperimentStep("test_step", add, a=1, b=2)

        expected_kwargs = {"a": 1, "b": 2}
        assert step.name == "test_step"
        assert step.func == add
        assert step.kwargs == expected_kwargs

    def test_execute(self):
        """execute() combines the step's kwargs with any previous results."""

        def add_all(a, b, c=None):
            # ``c`` is only supplied through previous_results in this test.
            return a + b if c is None else a + b + c

        logger = mock.MagicMock()
        step = ExperimentStep("test_step", add_all, a=1, b=2)

        # Nothing from earlier steps: only the stored kwargs are used.
        assert step.execute(logger=logger) == 3

        # An earlier result named ``c`` is passed to the callable.
        assert step.execute(previous_results={"c": 3}, logger=logger) == 6

        # A clashing ``a`` in previous_results must not replace the stored
        # kwarg: a=1 and b=2 come from kwargs, c=3 from previous_results.
        clashing = {"a": 10, "c": 3}
        assert step.execute(previous_results=clashing, logger=logger) == 6


class TestExperimentRunner:
    """Unit tests for ExperimentRunner: construction, steps, running, saving and comparing."""

    # Value that time.strftime is patched to return while results are saved,
    # which makes the name of the experiment output directory predictable.
    _FIXED_TIMESTAMP = "20250625-150000"

    def setup_method(self):
        """Create a scratch output directory and a config that points at it."""
        self.temp_dir = tempfile.TemporaryDirectory()
        self.output_dir = pathlib.Path(self.temp_dir.name)

        project_section = {"name": "test_project", "version": "0.1.0"}
        paths_section = {"data": "data", "output": str(self.output_dir)}
        test_section = {"param1": "value1", "param2": 42}
        self.config = {
            "project": project_section,
            "paths": paths_section,
            "test": test_section,
        }

    def teardown_method(self):
        """Remove the scratch output directory."""
        self.temp_dir.cleanup()

    def _runner_in_temp_dir(self):
        """Return a runner named "test_experiment" that writes into the scratch directory."""
        return ExperimentRunner(
            "test_experiment", config=self.config, output_dir=self.output_dir
        )

    def _save_at_fixed_time(self, runner):
        """Call runner._save_results() with time.strftime pinned.

        Returns the directory the tests expect the results to be written to.
        """
        with mock.patch("time.strftime", return_value=self._FIXED_TIMESTAMP):
            runner._save_results()
        return self.output_dir / f"test_experiment_{self._FIXED_TIMESTAMP}"

    def test_init(self):
        """Constructor accepts a config dict, a config name or a custom output_dir."""
        # Built directly from a config dictionary.
        from_dict = ExperimentRunner("test_experiment", config=self.config)
        assert from_dict.name == "test_experiment"
        assert from_dict.config == self.config
        assert from_dict.output_dir == pathlib.Path(self.config["paths"]["output"])
        assert from_dict.steps == []
        assert from_dict.results == {}
        assert from_dict.metrics == {}
        assert from_dict.start_time is None
        assert from_dict.end_time is None

        # Built from a config name; the config manager lookup is mocked.
        load_config_target = "src.core.experiment.config_manager.load_config"
        with mock.patch(load_config_target, return_value=self.config):
            from_name = ExperimentRunner("test_experiment", config_name="test_config")
        assert from_name.config == self.config

        # An explicit output_dir is stored as given.
        custom_output = pathlib.Path("custom/output")
        with_custom_dir = ExperimentRunner(
            "test_experiment", config=self.config, output_dir=custom_output
        )
        assert with_custom_dir.output_dir == custom_output

    def test_get_config(self):
        """get_config() hands back the configuration given at construction."""
        assert (
            ExperimentRunner("test_experiment", config=self.config).get_config()
            == self.config
        )

    def test_add_step(self):
        """add_step() appends a step and returns the runner for chaining."""

        def add(a, b):
            return a + b

        runner = ExperimentRunner("test_experiment", config=self.config)
        returned = runner.add_step("test_step", add, a=1, b=2)

        # Exactly one step is registered and it carries what was passed in.
        assert len(runner.steps) == 1
        added = runner.steps[0]
        assert added.name == "test_step"
        assert added.func == add
        assert added.kwargs == {"a": 1, "b": 2}

        # The same runner object comes back so calls can be chained.
        assert returned is runner

    def test_run(self):
        """run() executes the steps in order and records timing information."""

        def sum_step(a, b):
            return a + b

        def product_step(step1, c):
            # The ``step1`` parameter receives the result stored under "step1".
            return step1 * c

        runner = self._runner_in_temp_dir()
        runner.add_step("step1", sum_step, a=2, b=3)
        runner.add_step("step2", product_step, c=4)

        # Saving is replaced with a mock so the run does not write any files.
        with mock.patch.object(runner, "_save_results"):
            results = runner.run()

        assert results["step1"] == 5
        assert results["step2"] == 20

        # Timing bookkeeping must be filled in by the run.
        assert "duration" in runner.metrics
        assert runner.start_time is not None
        assert runner.end_time is not None

    def test_add_metric(self):
        """add_metric() stores the value under the given metric name."""
        runner = ExperimentRunner("test_experiment", config=self.config)
        runner.add_metric("accuracy", 0.95)
        assert runner.metrics["accuracy"] == 0.95

    def test_save_results(self):
        """_save_results() writes the config and skips non-serializable results."""
        runner = self._runner_in_temp_dir()
        runner.results = {
            "step1": 5,
            "step2": 20,
            "dataframe": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}),
            "non_serializable": object(),  # This should be skipped
        }
        runner.metrics = {"accuracy": 0.95, "duration": 1.23}

        exp_dir = self._save_at_fixed_time(runner)
        assert exp_dir.exists()

        # The config is written out as YAML and round-trips unchanged.
        config_path = exp_dir / "config.yaml"
        assert config_path.exists()
        with open(config_path, "r") as f:
            saved_config = yaml.safe_load(f)
        assert saved_config == self.config

        # metrics.json is no longer created (metrics are in experiment_summary.json).
        assert not (exp_dir / "metrics.json").exists()

        # No file is produced for the non-serializable result.
        assert not (exp_dir / "non_serializable.json").exists()

    def test_compare(self):
        """compare() pairs metric values and yields NaN for metrics one side lacks."""
        first = ExperimentRunner("experiment1", config=self.config)
        second = ExperimentRunner("experiment2", config=self.config)

        first.metrics = {"accuracy": 0.95, "precision": 0.90, "recall": 0.85}
        second.metrics = {"accuracy": 0.92, "precision": 0.88, "f1": 0.90}

        comparison = first.compare(second)

        # Metrics present on both sides come back as (first, second) pairs.
        assert comparison["accuracy"] == (0.95, 0.92)
        assert comparison["precision"] == (0.90, 0.88)

        # NaN never compares equal, so the missing side is checked with pd.isna.
        recall_first, recall_second = comparison["recall"][0], comparison["recall"][1]
        assert recall_first == 0.85
        assert pd.isna(recall_second)

        f1_first, f1_second = comparison["f1"][0], comparison["f1"][1]
        assert pd.isna(f1_first)
        assert f1_second == 0.90

    def test_save_results_with_error_tracker(self):
        """Recorded errors are written to errors.json and summarised in the metrics."""
        runner = self._runner_in_temp_dir()

        # Record two distinct errors on a real ErrorTracker.
        tracker = ErrorTracker()
        recorded = [
            (ValueError("Test error 1"), "Test question 1", 3, "task_1"),
            (RuntimeError("Test error 2"), "Test question 2", 2, "task_2"),
        ]
        for exc, question, retries, task_id in recorded:
            tracker.record_error(
                error=exc,
                input_messages=[{"role": "user", "content": question}],
                model_config={"name": "test-model", "version": "1.0"},
                retry_attempts=retries,
                task_id=task_id,
            )

        # Results that include the error tracker (simulating QATask.run() return):
        # a (dataframe, score, tracker) tuple whose third element should be detected.
        task_frame = pd.DataFrame({"question": ["Q1", "Q2"], "answer": [1, 2]})
        runner.results = {
            "step1": 5,
            "run_task": (task_frame, 0.85, tracker),
            "step2": 20,
        }
        runner.metrics = {"accuracy": 0.85, "duration": 1.23}

        exp_dir = self._save_at_fixed_time(runner)
        assert exp_dir.exists()

        # errors.json must exist and list both errors in recording order.
        errors_path = exp_dir / "errors.json"
        assert errors_path.exists()
        with open(errors_path, "r") as f:
            saved_errors = json.load(f)

        assert len(saved_errors) == 2
        expected_entries = [
            ("Test error 1", "ValueError", "task_1", 3),
            ("Test error 2", "RuntimeError", "task_2", 2),
        ]
        for entry, (message, type_name, task_id, retries) in zip(
            saved_errors, expected_entries
        ):
            assert entry["error_message"] == message
            assert entry["error_type"] == type_name
            assert entry["task_id"] == task_id
            assert entry["retry_attempts"] == retries

        # metrics.json is no longer created (metrics are in experiment_summary.json).
        assert not (exp_dir / "metrics.json").exists()

        # Original metrics survive and error statistics are added in memory.
        metrics = runner.metrics
        assert metrics["accuracy"] == 0.85
        assert metrics["duration"] == 1.23
        assert metrics["failed_llm_calls"] == 2
        assert metrics["error_types"]["ValueError"] == 1
        assert metrics["error_types"]["RuntimeError"] == 1
        assert len(metrics["errors"]) == 2

    def test_save_results_without_error_tracker(self):
        """Without an error tracker no errors.json and no error metrics appear."""
        runner = self._runner_in_temp_dir()
        runner.results = {"step1": 5, "step2": 20}
        runner.metrics = {"accuracy": 0.95, "duration": 1.23}

        exp_dir = self._save_at_fixed_time(runner)
        assert exp_dir.exists()

        # Neither an errors file nor the retired metrics.json is written.
        assert not (exp_dir / "errors.json").exists()
        assert not (exp_dir / "metrics.json").exists()

        # The metrics keep their values and gain no error-related keys.
        metrics = runner.metrics
        assert metrics["accuracy"] == 0.95
        assert metrics["duration"] == 1.23
        assert "failed_llm_calls" not in metrics
        assert "error_types" not in metrics

    def test_save_results_with_empty_error_tracker(self):
        """An error tracker with no recorded errors behaves like having none."""
        runner = self._runner_in_temp_dir()

        # Tracker is present in the result tuple but nothing was recorded on it.
        empty_tracker = ErrorTracker()
        task_frame = pd.DataFrame({"question": ["Q1"], "answer": [1]})
        runner.results = {
            "step1": 5,
            "run_task": (task_frame, 0.95, empty_tracker),
        }
        runner.metrics = {"accuracy": 0.95, "duration": 1.23}

        exp_dir = self._save_at_fixed_time(runner)
        assert exp_dir.exists()

        # Neither an errors file nor the retired metrics.json is written.
        assert not (exp_dir / "errors.json").exists()
        assert not (exp_dir / "metrics.json").exists()

        # The metrics keep their values and gain no failure counter.
        metrics = runner.metrics
        assert metrics["accuracy"] == 0.95
        assert metrics["duration"] == 1.23
        assert "failed_llm_calls" not in metrics


class TestLoadExperimentResults:
    """Tests for the load_experiment_results function.

    Every test runs against a fresh temporary experiment directory that
    ``setup_method`` pre-populates with a ``config.yaml`` and an
    ``experiment.log``. No ``metrics.json`` is written by the fixture.
    """

    def setup_method(self):
        """Create the temporary experiment directory and its fixture files."""
        # Create temporary directory for experiment results
        self.temp_dir = tempfile.TemporaryDirectory()
        self.exp_dir = pathlib.Path(self.temp_dir.name)

        # Create test files
        self.config = {
            "project": {
                "name": "test_project",
                "version": "0.1.0",
            },
        }
        # An explicit encoding keeps the fixture independent of the platform's
        # locale default.
        with open(self.exp_dir / "config.yaml", "w", encoding="utf-8") as f:
            yaml.dump(self.config, f)

        self.metrics = {
            "accuracy": 0.95,
            "duration": 1.23,
        }
        # Note: metrics.json is no longer created - metrics are included in experiment_summary.json

        # Create a mock experiment log file
        with open(self.exp_dir / "experiment.log", "w", encoding="utf-8") as f:
            f.write(
                "2025-07-01 10:48:00 - experiment.test_experiment - INFO - Experiment step results:\n"
            )
            f.write(
                '2025-07-01 10:48:00 - experiment.test_experiment - INFO -   result1: {"value": 42, "list": [1, 2, 3]}\n'
            )

    def teardown_method(self):
        """Remove the temporary experiment directory."""
        self.temp_dir.cleanup()

    def test_load_experiment_results(self):
        """Test loading experiment results."""
        # Load the results
        config, results = load_experiment_results(self.exp_dir)

        # Check the loaded config
        assert config == self.config

        # Check the loaded results
        # In the current implementation, results are only available in the log file
        # and not loaded into the results dictionary
        assert "_log_file" in results
        assert str(self.exp_dir / "experiment.log") == results["_log_file"]

    def test_load_experiment_results_nonexistent(self):
        """Test loading experiment results from a non-existent directory."""
        # Try to load results from a non-existent directory
        with pytest.raises(FileNotFoundError):
            load_experiment_results(self.exp_dir / "nonexistent")

    def test_load_experiment_results_missing_files(self):
        """Test loading experiment results with missing files."""
        # Create a new directory with no files
        empty_dir = self.exp_dir / "empty"
        empty_dir.mkdir()

        # Load the results
        config, results = load_experiment_results(empty_dir)

        # Check that empty dictionaries were returned
        assert config == {}
        assert results == {}

    def test_load_experiment_results_with_errors(self):
        """Test loading experiment results that include error information."""
        # Create test error data
        test_errors = [
            {
                "timestamp": 1672531200.0,
                "error_message": "Test error 1",
                "error_type": "ValueError",
                "input_messages": [{"role": "user", "content": "Test question 1"}],
                "model_config": {"name": "test-model", "version": "1.0"},
                "retry_attempts": 3,
                "task_id": "task_1",
            },
            {
                "timestamp": 1672531260.0,
                "error_message": "Test error 2",
                "error_type": "RuntimeError",
                "input_messages": [{"role": "user", "content": "Test question 2"}],
                "model_config": {"name": "test-model", "version": "1.0"},
                "retry_attempts": 2,
                "task_id": "task_2",
            },
        ]

        # Save errors.json file
        with open(self.exp_dir / "errors.json", "w", encoding="utf-8") as f:
            json.dump(test_errors, f)

        # Load the results
        config, results = load_experiment_results(self.exp_dir)

        # Check that errors were loaded
        assert "_errors" in results
        assert len(results["_errors"]) == 2
        assert results["_errors"][0]["error_message"] == "Test error 1"
        assert results["_errors"][0]["error_type"] == "ValueError"
        assert results["_errors"][0]["task_id"] == "task_1"
        assert results["_errors"][1]["error_message"] == "Test error 2"
        assert results["_errors"][1]["error_type"] == "RuntimeError"
        assert results["_errors"][1]["task_id"] == "task_2"

        # Check that other data is still loaded correctly
        assert config == self.config
        assert "_log_file" in results

    def test_load_experiment_results_with_empty_errors_file(self):
        """Test loading experiment results when errors.json holds an empty list."""
        # Create an errors.json file containing an empty JSON list
        with open(self.exp_dir / "errors.json", "w", encoding="utf-8") as f:
            json.dump([], f)

        # Load the results
        config, results = load_experiment_results(self.exp_dir)

        # Either outcome is accepted for an empty errors list: the "_errors"
        # key is omitted entirely, or it is present and holds an empty list.
        assert "_errors" not in results or results["_errors"] == []

        # Check that other data is still loaded correctly
        assert config == self.config

    def test_load_experiment_results_without_metrics_json(self):
        """Test loading experiment results when metrics.json no longer exists."""
        # Guard the precondition this test is named after: the fixture
        # directory must not contain a metrics.json file.
        assert not (self.exp_dir / "metrics.json").exists()

        # Load the results
        config, results = load_experiment_results(self.exp_dir)

        # Check that other data is still loaded correctly
        assert config == self.config
