"""Unit tests for the decorators module.

Tests timing, caching, validation, and other decorators.
"""

from __future__ import annotations

import time

import pytest
import torch
from torch import Tensor

from expected_gradcam.utils.decorators import (
    LRUCache,
    TensorSpec,
    TimingStats,
    cached,
    cached_method,
    gpu_sync,
    inference_mode,
    no_grad,
    require_grad,
    retry,
    timed,
    validate_input,
    validate_output,
)


class TestTimingStats:
    """Unit tests for the ``TimingStats`` accumulator."""

    def test_initial_values(self):
        """A freshly constructed TimingStats reports zero time and zero calls."""
        fresh = TimingStats()
        assert fresh.total_time == 0.0
        assert fresh.call_count == 0
        assert fresh.avg_time == 0.0

    def test_record(self):
        """Recording durations updates the count, total, mean, min and max."""
        tracker = TimingStats()
        for duration in (0.1, 0.2):
            tracker.record(duration)

        assert tracker.call_count == 2
        # Float sums are compared approximately to tolerate rounding error.
        expectations = {
            "total_time": 0.3,
            "avg_time": 0.15,
            "min_time": 0.1,
            "max_time": 0.2,
        }
        for attribute, expected in expectations.items():
            assert getattr(tracker, attribute) == pytest.approx(expected)

    def test_reset(self):
        """reset() clears previously recorded calls and accumulated time."""
        tracker = TimingStats()
        tracker.record(0.1)
        tracker.reset()

        assert tracker.call_count == 0
        assert tracker.total_time == 0.0


class TestTimedDecorator:
    """Test @timed decorator."""

    def test_basic_timing(self):
        """Test basic timing functionality."""

        @timed(collect_stats=True)
        def slow_function():
            time.sleep(0.01)
            return 42

        result = slow_function()

        assert result == 42
        assert slow_function.stats.call_count == 1
        # The wrapped call sleeps 10 ms, so the recorded time cannot be smaller.
        assert slow_function.stats.total_time >= 0.01

    def test_with_logging(self, caplog):
        """Timing with ``log=True`` still returns the result and records stats.

        ``caplog.set_level`` is used rather than mutating the root logger
        directly: pytest restores the original level after the test, so the
        change cannot leak into other tests.
        """
        import logging

        caplog.set_level(logging.INFO)

        @timed(log=True, collect_stats=True)
        def logged_function():
            return "done"

        result = logged_function()
        assert result == "done"
        # Enabling logging must not disable stats collection.
        assert logged_function.stats.call_count == 1

    def test_stats_accumulation(self):
        """Test stats accumulate across calls."""

        @timed(collect_stats=True)
        def counter():
            return 1

        for _ in range(5):
            counter()

        assert counter.stats.call_count == 5

    def test_reset_stats(self):
        """Test resetting stats."""

        @timed(collect_stats=True)
        def func():
            return 1

        func()
        func()
        func.reset_stats()

        assert func.stats.call_count == 0


class TestLRUCache:
    """Unit tests for the standalone ``LRUCache`` container."""

    def test_basic_get_set(self):
        """A stored value is returned together with a True 'found' flag."""
        lru = LRUCache(maxsize=10)
        lru.set("key1", "value1")

        hit, stored = lru.get("key1")
        assert hit is True
        assert stored == "value1"

    def test_miss(self):
        """Looking up an absent key yields (False, None)."""
        hit, stored = LRUCache(maxsize=10).get("nonexistent")

        assert hit is False
        assert stored is None

    def test_maxsize(self):
        """Inserting beyond capacity evicts the oldest entry."""
        lru = LRUCache(maxsize=3)

        # Inserting the fourth key ("d") should push out the first one ("a").
        for key, value in zip("abcd", range(1, 5)):
            lru.set(key, value)

        a_present, _ = lru.get("a")
        d_present, _ = lru.get("d")

        assert a_present is False
        assert d_present is True

    def test_ttl(self):
        """Entries become unreachable once their time-to-live has elapsed."""
        lru = LRUCache(maxsize=10, ttl=0.05)
        lru.set("key", "value")

        # Read straight away, well inside the 50 ms window.
        fresh_hit, _ = lru.get("key")
        assert fresh_hit is True

        # Wait past the TTL; the entry should now be treated as expired.
        time.sleep(0.1)
        stale_hit, _ = lru.get("key")
        assert stale_hit is False

    def test_hit_rate(self):
        """hit_rate is hits / (hits + misses) across all lookups."""
        lru = LRUCache(maxsize=10)
        lru.set("a", 1)

        # Two hits on "a", then one miss on "b".
        for probe in ("a", "a", "b"):
            lru.get(probe)

        assert lru.hit_rate == pytest.approx(2 / 3)


class TestCachedDecorator:
    """Test @cached decorator."""

    def test_basic_caching(self):
        """Test basic function caching."""
        call_count = 0

        @cached(maxsize=10)
        def expensive(x):
            nonlocal call_count
            call_count += 1
            return x * 2

        result1 = expensive(5)
        result2 = expensive(5)

        assert result1 == 10
        assert result2 == 10
        assert call_count == 1  # Only called once

    def test_different_args(self):
        """Test caching with different arguments."""
        call_count = 0

        @cached(maxsize=10)
        def func(x):
            nonlocal call_count
            call_count += 1
            return x

        func(1)
        func(2)
        func(1)

        assert call_count == 2  # Called twice for different args

    def test_cache_info(self):
        """Test cache_info method."""

        @cached(maxsize=10)
        def func(x):
            return x

        func(1)
        func(1)
        func(2)

        info = func.cache_info()
        assert info["hits"] == 1
        assert info["misses"] == 2

    def test_cache_clear(self):
        """Test cache_clear method."""

        @cached(maxsize=10)
        def func(x):
            return x

        func(1)
        func.cache_clear()

        info = func.cache_info()
        assert info["size"] == 0

    def test_with_tensor_args(self):
        """Tensor arguments are accepted and always produce the correct result.

        Whether two distinct tensors with equal contents share a cache entry
        depends on how the decorator hashes tensors. This test therefore pins
        the returned values (which must be right whether the call hits or
        misses the cache) and only bounds the number of underlying calls.
        """
        call_count = 0

        @cached(maxsize=10)
        def func(x: Tensor):
            nonlocal call_count
            call_count += 1
            return x.sum()

        t1 = torch.tensor([1.0, 2.0, 3.0])
        t2 = torch.tensor([1.0, 2.0, 3.0])  # Same content, different object

        r1 = func(t1)
        r2 = func(t2)

        expected = torch.tensor(6.0)
        assert torch.equal(r1, expected)
        assert torch.equal(r2, expected)
        # At most one real call per invocation; fewer if content-hashed.
        assert 1 <= call_count <= 2


class TestTensorSpec:
    """Test TensorSpec validation."""

    def test_shape_validation(self):
        """Test shape validation."""
        spec = TensorSpec(shape=(1, 3, 224, 224))
        tensor = torch.randn(1, 3, 224, 224)

        errors = spec.validate(tensor, "image")
        assert len(errors) == 0

    def test_shape_mismatch(self):
        """Test shape mismatch detection."""
        spec = TensorSpec(shape=(1, 3, 224, 224))
        tensor = torch.randn(1, 3, 128, 128)

        errors = spec.validate(tensor, "image")
        assert len(errors) > 0

    def test_dynamic_dimensions(self):
        """Test dynamic dimensions with None."""
        spec = TensorSpec(shape=(None, 3, 224, 224))
        tensor = torch.randn(5, 3, 224, 224)

        errors = spec.validate(tensor, "image")
        assert len(errors) == 0

    def test_dtype_validation(self):
        """Test dtype validation."""
        spec = TensorSpec(dtype=torch.float32)

        tensor_good = torch.randn(10).float()
        tensor_bad = torch.randn(10).double()

        assert len(spec.validate(tensor_good, "x")) == 0
        assert len(spec.validate(tensor_bad, "x")) > 0

    def test_device_validation(self):
        """Test device validation."""
        spec = TensorSpec(device="cpu")
        tensor = torch.randn(10)

        errors = spec.validate(tensor, "x")
        assert len(errors) == 0

    def test_value_range(self):
        """Values outside [min_val, max_val] are reported on either side.

        Fixed tensors are used so the test is deterministic. The lower and
        upper bounds are exercised separately so that a missing check on
        either side is caught. Exact endpoint values are avoided because
        bound inclusivity is not pinned down here.
        """
        spec = TensorSpec(min_val=0.0, max_val=1.0)

        tensor_good = torch.tensor([0.1, 0.5, 0.9])
        tensor_below = torch.tensor([-0.5, 0.5])
        tensor_above = torch.tensor([0.5, 1.5])

        assert len(spec.validate(tensor_good, "x")) == 0
        assert len(spec.validate(tensor_below, "x")) > 0
        assert len(spec.validate(tensor_above, "x")) > 0


class TestValidateInputDecorator:
    """Tests for the ``@validate_input`` argument-checking decorator."""

    def test_valid_input(self):
        """A 4-D tensor satisfies an ``ndim=4`` spec and the call goes through."""
        four_d_spec = TensorSpec(ndim=4)

        @validate_input(x=four_d_spec)
        def process(x):
            return x.sum()

        assert process(torch.randn(1, 3, 32, 32)) is not None

    def test_invalid_input_raises(self):
        """A tensor with the wrong rank is rejected with a ValueError."""

        @validate_input(x=TensorSpec(ndim=4))
        def process(x):
            return x.sum()

        # One dimension short of what the spec demands.
        three_d = torch.randn(3, 32, 32)

        with pytest.raises(ValueError, match="expected 4D"):
            process(three_d)

    def test_dict_spec(self):
        """A plain dict is accepted in place of a TensorSpec instance."""
        dict_spec = {"ndim": 2, "dtype": torch.float32}

        @validate_input(x=dict_spec)
        def process(x):
            return x.sum()

        matrix = torch.randn(10, 10)
        assert process(matrix) is not None


class TestValidateOutputDecorator:
    """Tests for the ``@validate_output`` return-value checking decorator."""

    def test_valid_output(self):
        """A 2-D return value passes an ``ndim=2`` output spec untouched."""
        matrix_spec = TensorSpec(ndim=2)

        @validate_output(matrix_spec)
        def generate():
            return torch.randn(10, 10)

        assert generate().shape == (10, 10)

    def test_invalid_output_raises(self):
        """Returning a 3-D tensor against an ``ndim=2`` spec raises ValueError."""
        matrix_spec = TensorSpec(ndim=2)

        @validate_output(matrix_spec)
        def generate():
            # One dimension more than the spec allows.
            return torch.randn(10, 10, 10)

        with pytest.raises(ValueError, match="Output validation failed"):
            generate()


class TestGradientDecorators:
    """Test gradient control decorators."""

    def test_require_grad(self):
        """Test @require_grad decorator."""

        @require_grad
        def compute():
            return torch.randn(10)

        # Should work normally with grad enabled
        result = compute()
        assert result is not None

        # Should fail inside no_grad
        with torch.no_grad():
            with pytest.raises(RuntimeError, match="requires gradients"):
                compute()

    def test_no_grad_decorator(self):
        """Test @no_grad decorator."""

        @no_grad
        def inference(x):
            return x * 2

        x = torch.randn(10, requires_grad=True)
        result = inference(x)

        # Result should not require grad (computed in no_grad)
        assert not result.requires_grad

    def test_inference_mode_decorator(self):
        """@inference_mode computes the same value without gradient tracking.

        The input requires grad so that the test can detect whether gradient
        tracking was actually disabled. Merely checking for a non-None result
        would pass even if the decorator did nothing.
        """

        @inference_mode
        def fast_inference(x):
            return x * 2

        x = torch.randn(10, requires_grad=True)
        result = fast_inference(x)

        assert not result.requires_grad
        # Values must match the undecorated computation.
        assert torch.equal(result, x.detach() * 2)


class TestRetryDecorator:
    """Test @retry decorator."""

    def test_successful_on_first_try(self):
        """Test function succeeds on first try."""
        call_count = 0

        @retry(max_attempts=3)
        def succeed():
            nonlocal call_count
            call_count += 1
            return "success"

        result = succeed()
        assert result == "success"
        assert call_count == 1

    def test_retry_on_failure(self):
        """Test retry on failure."""
        call_count = 0

        @retry(max_attempts=3, delay=0.01)
        def fail_twice():
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise ValueError("Not yet")
            return "success"

        result = fail_twice()
        assert result == "success"
        assert call_count == 3

    def test_max_attempts_exceeded(self):
        """The last exception propagates after ``max_attempts`` failed calls."""
        call_count = 0

        @retry(max_attempts=3, delay=0.01)
        def always_fail():
            nonlocal call_count
            call_count += 1
            raise ValueError("Always fails")

        with pytest.raises(ValueError, match="Always fails"):
            always_fail()

        # max_attempts is taken as the total number of calls, first one included.
        assert call_count == 3

    def test_specific_exceptions(self):
        """Exceptions not listed in ``exceptions`` propagate without retrying."""
        call_count = 0

        @retry(max_attempts=3, delay=0.01, exceptions=(ValueError,))
        def fail_with_type_error():
            nonlocal call_count
            call_count += 1
            raise TypeError("Wrong type")

        with pytest.raises(TypeError):
            fail_with_type_error()

        # TypeError is not in the retry list, so exactly one call must happen.
        assert call_count == 1


class TestGpuSyncDecorator:
    """Tests for the ``@gpu_sync`` decorator."""

    def test_cpu_operation(self):
        """Wrapping a CPU-only function must not raise and must return a result."""

        @gpu_sync
        def cpu_op():
            return torch.randn(10)

        output = cpu_op()
        assert output is not None
