"""Retry handling with exponential backoff and circuit breaker patterns."""

import time
import random
import logging
from dataclasses import dataclass
from typing import Optional, Callable, Any
from enum import Enum


class RetryStrategy(Enum):
    """Named backoff schemes understood by ``RetryPolicy.calculate_delay``.

    Member values are short lowercase identifiers, so a strategy can be looked
    up from a plain string, e.g. ``RetryStrategy("linear")``.
    """

    # Delay grows geometrically: base_delay * exponential_base ** attempt.
    EXPONENTIAL_BACKOFF = "exponential"
    # Delay grows by one base_delay step per attempt.
    LINEAR_BACKOFF = "linear"
    # Delay is base_delay for every attempt.
    FIXED_DELAY = "fixed"


@dataclass
class RetryPolicy:
    """Configuration for retry behavior.

    All fields have defaults, so a policy is usually built with keyword
    arguments, e.g. ``RetryPolicy(max_attempts=5, jitter=False)``.
    """
    max_attempts: int = 3          # total attempts, including the first call
    base_delay: float = 1.0        # seconds; starting (or fixed) delay
    max_delay: float = 60.0        # seconds; upper bound on every returned delay
    exponential_base: float = 2.0  # growth factor for EXPONENTIAL_BACKOFF
    strategy: RetryStrategy = RetryStrategy.EXPONENTIAL_BACKOFF
    jitter: bool = True            # randomize each delay by up to ±25%

    def calculate_delay(self, attempt: int) -> float:
        """Return the delay in seconds to wait after a failed attempt.

        Args:
            attempt: 0-based index of the attempt that just failed.

        Returns:
            A non-negative float that never exceeds ``max_delay``, whatever
            the strategy. With ``jitter`` enabled the delay is randomized by
            up to ±25% and then clamped to ``[0, max_delay]`` again, so the
            cap is a true upper bound.
        """
        if self.strategy == RetryStrategy.EXPONENTIAL_BACKOFF:
            try:
                delay = self.base_delay * (self.exponential_base ** attempt)
            except OverflowError:
                # For large attempt counts the power (or its conversion to
                # float) overflows; the result would be capped anyway.
                delay = self.max_delay
        elif self.strategy == RetryStrategy.LINEAR_BACKOFF:
            delay = self.base_delay * (attempt + 1)
        else:  # FIXED_DELAY
            delay = self.base_delay

        # Cap before jitter so huge integer delays never reach float math below.
        delay = min(delay, self.max_delay)

        if self.jitter:
            # Add random jitter (±25% of delay) to spread out concurrent retries.
            jitter_amount = delay * 0.25
            delay += random.uniform(-jitter_amount, jitter_amount)

        # Clamp again after jitter so the result stays within [0, max_delay].
        return float(max(0.0, min(delay, self.max_delay)))


class CircuitBreakerOpenError(Exception):
    """Raised by ``CircuitBreaker.call`` while the circuit is open."""


class CircuitBreaker:
    """Circuit breaker to prevent cascading failures.

    States (stored as strings in ``state``):

    * ``"closed"``: calls pass through. Consecutive counted failures (those
      matching ``expected_exception``) are tallied in ``failure_count``. Any
      success resets the tally.
    * ``"open"``: calls are rejected immediately with
      :class:`CircuitBreakerOpenError` until more than ``recovery_timeout``
      seconds have passed since ``last_failure_time``.
    * ``"half-open"``: one trial call is let through. Success closes the
      circuit, and a counted failure re-opens it.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        expected_exception: type = Exception
    ):
        """Initialize circuit breaker.

        Args:
            failure_threshold: Number of consecutive counted failures before
                opening the circuit.
            recovery_timeout: Seconds to wait after the last failure before
                allowing a trial call.
            expected_exception: Exception type (or tuple of types) counted as
                a failure. Other exceptions propagate without being counted.
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception

        self.failure_count = 0
        self.last_failure_time = None  # wall-clock time.time() of last failure
        self.state = "closed"  # closed, open, half-open

    def call(self, func: Callable, *args, **kwargs) -> Any:
        """Call ``func(*args, **kwargs)`` with circuit breaker protection.

        Returns:
            Whatever ``func`` returns.

        Raises:
            CircuitBreakerOpenError: If the circuit is open and the recovery
                timeout has not yet elapsed. This is a subclass of
                ``Exception``.
            Exception: Whatever ``func`` raises, re-raised unchanged.
        """
        if self.state == "open":
            elapsed = time.time() - self.last_failure_time
            if elapsed > self.recovery_timeout:
                self.state = "half-open"
            else:
                raise CircuitBreakerOpenError(
                    f"Circuit breaker is open (failures: {self.failure_count})"
                )

        try:
            result = func(*args, **kwargs)
        except self.expected_exception:
            self.failure_count += 1
            self.last_failure_time = time.time()
            # A failed trial call must re-open the circuit regardless of count.
            if self.state == "half-open" or self.failure_count >= self.failure_threshold:
                self.state = "open"
            raise

        # Success: forget earlier failures so only consecutive failures trip
        # the breaker. This also closes the circuit after a half-open trial.
        self.state = "closed"
        self.failure_count = 0
        return result

    def reset(self):
        """Reset the circuit breaker to the closed state with no failures."""
        self.failure_count = 0
        self.last_failure_time = None
        self.state = "closed"


class RetryHandler:
    """Handler for retry logic with various strategies.

    Statistics accumulate across calls in ``retry_stats``:

    * ``total_retries``: retries actually scheduled, i.e. failed attempts
      that were followed by another attempt.
    * ``successful_retries``: calls that succeeded after at least one retry.
    * ``failed_retries``: calls whose every attempt failed.
    * ``total_delay``: sum of the backoff delays slept between attempts.
    """

    def __init__(
        self,
        policy: Optional["RetryPolicy"] = None,
        logger: Optional[logging.Logger] = None,
        use_circuit_breaker: bool = False,
        retryable_exceptions: tuple = (Exception,),
    ):
        """Initialize retry handler.

        Args:
            policy: Retry policy configuration (defaults to ``RetryPolicy()``).
            logger: Logger instance (defaults to this module's logger).
            use_circuit_breaker: Whether to route calls through a
                default-configured ``CircuitBreaker``.
            retryable_exceptions: Exception type or tuple of types that
                trigger a retry. Anything else propagates immediately. The
                default ``(Exception,)`` retries on every ordinary exception.
        """
        self.policy = policy or RetryPolicy()
        self.logger = logger or logging.getLogger(__name__)
        self.retryable_exceptions = retryable_exceptions

        self.circuit_breaker = CircuitBreaker() if use_circuit_breaker else None
        self.retry_stats = self._empty_stats()

    @staticmethod
    def _empty_stats() -> dict:
        """Return a fresh, zeroed statistics dictionary."""
        return {
            'total_retries': 0,
            'successful_retries': 0,
            'failed_retries': 0,
            'total_delay': 0.0,
        }

    def execute_with_retry(
        self,
        func: Callable,
        *args,
        on_retry: Optional[Callable[[int, Exception], None]] = None,
        **kwargs
    ) -> Any:
        """Execute function with retry logic.

        Args:
            func: Function to execute
            *args: Positional arguments for function
            on_retry: Callback called before each retry with
                (0-based failed attempt index, exception)
            **kwargs: Keyword arguments for function

        Returns:
            Result of function execution

        Raises:
            ValueError: If ``policy.max_attempts`` is less than 1.
            Exception: The last retryable exception if all attempts fail, or
                a non-retryable exception as soon as it is raised.
        """
        max_attempts = self.policy.max_attempts
        if max_attempts < 1:
            # Without this guard the loop never runs and nothing meaningful
            # could be raised or returned.
            raise ValueError(f"max_attempts must be at least 1, got {max_attempts}")

        for attempt in range(max_attempts):
            try:
                # Use circuit breaker if enabled
                if self.circuit_breaker is not None:
                    result = self.circuit_breaker.call(func, *args, **kwargs)
                else:
                    result = func(*args, **kwargs)
            except self.retryable_exceptions as exc:
                if attempt == max_attempts - 1:
                    # Final attempt: no retry follows, so it is not counted
                    # in total_retries.
                    self.retry_stats['failed_retries'] += 1
                    self.logger.error("All %d attempts failed", max_attempts)
                    raise

                delay = self.policy.calculate_delay(attempt)
                self.retry_stats['total_retries'] += 1
                self.retry_stats['total_delay'] += delay

                self.logger.warning(
                    "Attempt %d/%d failed: %s. Retrying in %.2f seconds...",
                    attempt + 1, max_attempts, exc, delay,
                )

                if on_retry:
                    on_retry(attempt, exc)

                time.sleep(delay)
            else:
                if attempt > 0:
                    self.retry_stats['successful_retries'] += 1
                    # `attempt` equals the number of attempts that failed.
                    self.logger.info(
                        "Retry successful after %d failed attempt(s)", attempt
                    )
                return result

    def wait_before_retry(self, attempt: int):
        """Wait before retry based on policy.

        Args:
            attempt: Current attempt number (0-based)
        """
        delay = self.policy.calculate_delay(attempt)
        self.logger.info("Waiting %.2f seconds before retry...", delay)
        time.sleep(delay)

    def get_statistics(self) -> dict:
        """Return a copy of the retry statistics plus derived ratios.

        ``average_delay`` is ``total_delay / total_retries`` and
        ``success_rate`` is ``successful_retries / total_retries``. Both are
        0 when no retry has been scheduled.
        """
        retries = self.retry_stats['total_retries']
        return {
            **self.retry_stats,
            'average_delay': (
                self.retry_stats['total_delay'] / retries if retries > 0 else 0
            ),
            'success_rate': (
                self.retry_stats['successful_retries'] / retries if retries > 0 else 0
            ),
        }

    def reset_statistics(self):
        """Reset retry statistics and, if enabled, the circuit breaker."""
        self.retry_stats = self._empty_stats()

        if self.circuit_breaker is not None:
            self.circuit_breaker.reset()
