import torch
import torch.nn.functional as F
import math
from typing import Optional, Union, Dict, Any
from enum import Enum


class SamplingMethod(Enum):
    """Closed set of sampling strategies.

    Every member's value is its lower-case name, so both
    ``SamplingMethod("bernoulli")`` (value lookup) and
    ``SamplingMethod["BERNOULLI"]`` (name lookup) yield the same member.
    """

    BERNOULLI = "bernoulli"                    # hard 0/1 draws
    SIGMOID = "sigmoid"                        # deterministic transform, no noise
    GUMBEL_SOFTMAX = "gumbel_softmax"          # relaxed, differentiable samples
    CONCRETE = "concrete"                      # distinct member, same routine as GUMBEL_SOFTMAX
    STRAIGHT_THROUGH = "straight_through"      # hard forward pass, soft gradients
    GAUSSIAN = "gaussian"                      # probabilities plus Gaussian noise
    UNIFORM_THRESHOLD = "uniform_threshold"    # temperature-sharpened threshold draws
    RELAXED_BERNOULLI = "relaxed_bernoulli"    # binary Concrete distribution
    SPIKE_AND_SLAB = "spike_and_slab"          # Bernoulli gate times noisy slab

def select_sampling_method(method: Union[str, SamplingMethod]) -> SamplingMethod:
    """
    Convert a method name (or an existing enum member) to a SamplingMethod.

    String lookup is by member name, case-insensitive, and ignores surrounding
    whitespace, so "bernoulli", "BERNOULLI" and " Gumbel_Softmax " all resolve.

    Args:
        method: Method name as string or SamplingMethod enum

    Returns:
        Corresponding SamplingMethod enum value

    Raises:
        ValueError: If the string does not name a SamplingMethod member.
        TypeError: If ``method`` is neither a str nor a SamplingMethod.
    """
    if isinstance(method, SamplingMethod):
        return method
    if not isinstance(method, str):
        raise TypeError(f"Expected str or SamplingMethod, got {type(method)}")
    try:
        return SamplingMethod[method.strip().upper()]
    except KeyError:
        # The internal KeyError is an implementation detail; report only the
        # user-facing problem.
        raise ValueError(f"Unknown sampling method: {method}") from None

class ConditionalSampler:
    """
    Conditional sampler exposing several probabilistic sampling methods,
    aimed at Restricted Boltzmann Machines and similar probabilistic models.

    ``sample`` converts its input to probabilities, dispatches to one of the
    ``_*_sampling`` methods (each returns ``(samples, method_info)``) and
    packages everything into a result dictionary.
    """

    def __init__(self, device: Optional[torch.device] = None):
        """
        Args:
            device: Device that inputs are moved to before sampling and on
                which Gumbel noise is drawn. Defaults to the CPU.
        """
        self.device = device or torch.device('cpu')
        self.eps = 1e-8  # Numerical stability constant for logs and clamps

    def sample(self,
               input_tensor: Union[torch.Tensor, Dict[str, torch.Tensor]],
               method: Union[str, SamplingMethod] = SamplingMethod.GUMBEL_SOFTMAX,
               temperature: float = 1.0,
               **kwargs) -> Dict[str, Any]:
        """
        Sample from the Bernoulli-style distribution described by the input.

        Args:
            input_tensor: Logits or probabilities as a tensor, or the dict
                returned by a previous call (its 'samples' entry, else its
                'probabilities' entry, is used) so calls can be chained.
            method: SamplingMethod member, or its string value
                (case-insensitive, e.g. 'gumbel_softmax').
            temperature: Temperature parameter controlling randomness.
            **kwargs: Method-specific options forwarded to the chosen method
                (e.g. ``hard`` for Gumbel-Softmax, ``std`` for Gaussian,
                ``slab_std`` for spike-and-slab). ``force_clamp=True`` clips
                the input to [0, 1] instead of auto-detecting logits.

        Returns:
            Dictionary containing:
            - 'samples': Sampled values
            - 'probabilities': Probabilities derived from the input
            - 'log_probs': Element-wise log probabilities
            - 'entropy': Mean element-wise binary entropy (0-dim tensor)
            - 'method_info': Method-specific information
            - 'temperature': The temperature that was used

        Raises:
            ValueError: If a dict input has neither 'samples' nor
                'probabilities', or a string method is not a SamplingMethod value.
        """
        # Accept a previous result dict so sampling calls can be chained.
        if isinstance(input_tensor, dict):
            if 'samples' in input_tensor:
                actual_tensor = input_tensor['samples']
            elif 'probabilities' in input_tensor:
                actual_tensor = input_tensor['probabilities']
            else:
                raise ValueError("Dictionary input must contain 'samples' or 'probabilities' key")
        else:
            actual_tensor = input_tensor

        # Strings are matched against enum *values*, e.g. 'gumbel_softmax'.
        if isinstance(method, str):
            method = SamplingMethod(method.lower())

        actual_tensor = actual_tensor.to(self.device)
        # sigmoid/clamp/log below require a floating dtype; integer or bool
        # 0/1 tensors would otherwise make them raise.
        if not actual_tensor.is_floating_point():
            actual_tensor = actual_tensor.float()

        force_clamp = kwargs.get('force_clamp', False)
        if force_clamp:
            # Legacy path: interpret the input as probabilities clipped to [0, 1].
            probs = torch.clamp(actual_tensor, min=0.0, max=1.0)
        elif actual_tensor.numel() > 0 and (actual_tensor.min() < 0 or actual_tensor.max() > 1):
            # Values outside [0, 1] cannot be probabilities: treat them as logits.
            # (The numel guard is needed because min()/max() raise on empty tensors.)
            probs = torch.sigmoid(actual_tensor)
        else:
            # Already in [0, 1]; lift exact zeros to eps so the logs stay finite.
            probs = torch.clamp(actual_tensor, min=self.eps, max=1.0 - self.eps)

        # Element-wise statistics of Bernoulli(probs).
        log_probs = torch.log(probs + self.eps)
        entropy = -probs * log_probs - (1 - probs) * torch.log(1 - probs + self.eps)

        sampling_methods = {
            SamplingMethod.BERNOULLI: self._bernoulli_sampling,
            SamplingMethod.SIGMOID: self._sigmoid_sampling,
            SamplingMethod.GUMBEL_SOFTMAX: self._gumbel_softmax_sampling,
            SamplingMethod.CONCRETE: self._gumbel_softmax_sampling,  # Alias
            SamplingMethod.STRAIGHT_THROUGH: self._straight_through_sampling,
            SamplingMethod.GAUSSIAN: self._gaussian_sampling,
            SamplingMethod.UNIFORM_THRESHOLD: self._uniform_threshold_sampling,
            SamplingMethod.RELAXED_BERNOULLI: self._relaxed_bernoulli_sampling,
            SamplingMethod.SPIKE_AND_SLAB: self._spike_and_slab_sampling,
        }

        samples, method_info = sampling_methods[method](probs, temperature, **kwargs)

        # Return the documented result dictionary (callers index 'samples',
        # 'entropy', 'method_info', and chaining feeds this dict back in).
        return {
            'samples': samples,
            'probabilities': probs,
            'log_probs': log_probs,
            'entropy': entropy.mean(),
            'method_info': method_info,
            'temperature': temperature,
        }

    def _bernoulli_sampling(self, probs: torch.Tensor, temperature: float, **kwargs) -> tuple:
        """Classic Bernoulli sampling: independent 0/1 draws with P(1) = probs.

        ``temperature`` is accepted for interface uniformity but unused.
        """
        samples = torch.bernoulli(probs)

        method_info = {
            'type': 'discrete',
            'differentiable': False,
            'description': 'Classic binary sampling for RBMs'
        }

        return samples, method_info

    def _sigmoid_sampling(self, probs: torch.Tensor, temperature: float, **kwargs) -> tuple:
        """Deterministic sigmoid of the given values (no noise, temperature unused)."""
        # NOTE(review): `probs` is already in [0, 1] (sigmoid'd or clamped by
        # sample()), so this maps it into [0.5, ~0.731]. It is kept unchanged
        # because conditional_sampling promises exact legacy behaviour. Confirm
        # whether `probs` itself was intended.
        samples = torch.sigmoid(probs)

        method_info = {
            'type': 'sigmoid',
            # torch.sigmoid propagates gradients.
            'differentiable': True,
            'description': 'Classic sigmoid activation sampling for RBMs'
        }

        return samples, method_info

    def _gumbel_softmax_sampling(self, probs: torch.Tensor, temperature: float,
                                 hard: bool = False, **kwargs) -> tuple:
        """
        Binary Gumbel-Softmax (Concrete) relaxation of Bernoulli(probs).

        For the two classes {1, 0} with logits (logit(p), 0), Gumbel-Softmax
        reduces to sigmoid((logit(p) + g1 - g0) / temperature) with g1 and g0
        iid Gumbel(0, 1). Their difference is Logistic(0, 1), so
        P(sample > 0.5) equals p.

        Args:
            probs: Probabilities in [0, 1].
            temperature: Relaxation temperature; lower values push samples
                toward 0 and 1.
            hard: If True, returns discrete samples with straight-through gradients

        Returns:
            (samples, method_info)
        """
        # Two independent Gumbel draws are required. A single draw biases the
        # result: at p = 0.5 it would exceed 0.5 about 63% of the time.
        logistic_noise = self._sample_gumbel(probs.shape) - self._sample_gumbel(probs.shape)

        logits = torch.log(probs + self.eps) - torch.log(1 - probs + self.eps)
        y = (logits + logistic_noise) / temperature
        samples = torch.sigmoid(y)

        # Hard version with straight-through estimator
        if hard:
            samples_hard = (samples > 0.5).to(samples.dtype)
            # Forward value is samples_hard; the gradient is that of the soft samples.
            samples = samples_hard - samples.detach() + samples

        method_info = {
            'type': 'continuous' if not hard else 'discrete_with_gradients',
            'differentiable': True,
            'description': 'Differentiable relaxation of Bernoulli sampling',
            'hard_sampling': hard,
            'effective_temperature': temperature
        }

        return samples, method_info

    def _straight_through_sampling(self, probs: torch.Tensor, temperature: float, **kwargs) -> tuple:
        """Straight-through estimator: Bernoulli draws forward, gradients of probs backward."""
        # Forward pass: discrete sampling
        samples_discrete = torch.bernoulli(probs)

        # Backward pass: use continuous probabilities
        samples = samples_discrete - probs.detach() + probs

        method_info = {
            'type': 'discrete_with_gradients',
            'differentiable': True,
            'description': 'Discrete sampling with straight-through gradients'
        }

        return samples, method_info

    def _gaussian_sampling(self, probs: torch.Tensor, temperature: float,
                           std: float = 0.1, **kwargs) -> tuple:
        """probs plus N(0, (std * temperature)^2) noise, clipped to [0, 1]."""
        noise = torch.randn_like(probs) * std * temperature
        samples = torch.clamp(probs + noise, 0, 1)

        method_info = {
            'type': 'continuous',
            'differentiable': True,
            'description': 'Gaussian noise around probabilities',
            'noise_std': std * temperature
        }

        return samples, method_info

    def _uniform_threshold_sampling(self, probs: torch.Tensor, temperature: float, **kwargs) -> tuple:
        """Bernoulli draws against probs ** (1 / temperature).

        temperature > 1 raises each firing probability, < 1 lowers it, and
        == 1 is ordinary Bernoulli sampling.
        """
        uniform_noise = torch.rand_like(probs)
        adjusted_probs = torch.pow(probs, 1.0 / temperature)
        samples = (uniform_noise < adjusted_probs).float()

        method_info = {
            'type': 'discrete',
            'differentiable': False,
            'description': 'Threshold sampling with uniform noise'
        }

        return samples, method_info

    def _relaxed_bernoulli_sampling(self, probs: torch.Tensor, temperature: float, **kwargs) -> tuple:
        """Relaxed Bernoulli (Binary Concrete): sigmoid((logit(p) + logistic noise) / T)."""
        # Logistic(0, 1) noise via the inverse sigmoid of a clamped uniform draw.
        u = torch.rand_like(probs)
        u = torch.clamp(u, self.eps, 1 - self.eps)
        logistic_noise = torch.log(u) - torch.log(1 - u)

        logits = torch.log(probs + self.eps) - torch.log(1 - probs + self.eps)
        samples = torch.sigmoid((logits + logistic_noise) / temperature)

        method_info = {
            'type': 'continuous',
            'differentiable': True,
            'description': 'Relaxed Bernoulli distribution (Binary Concrete)'
        }

        return samples, method_info

    def _spike_and_slab_sampling(self, probs: torch.Tensor, temperature: float,
                                 slab_std: float = 0.1, **kwargs) -> tuple:
        """
        Spike-and-slab sampling for sparse representations.

        A Bernoulli(probs) gate (spike) multiplies a continuous slab value
        drawn as probs + N(0, (slab_std * temperature)^2), clipped to [0, 1].
        """
        spike_samples = torch.bernoulli(probs)

        slab_samples = torch.randn_like(probs) * slab_std * temperature + probs
        slab_samples = torch.clamp(slab_samples, 0, 1)

        samples = spike_samples * slab_samples

        method_info = {
            'type': 'mixed',
            'differentiable': False,
            'description': 'Spike-and-slab for sparse representations',
            # Expected fraction of gated-off units.
            'sparsity_level': (1 - probs.mean()).item()
        }

        return samples, method_info

    def _sample_gumbel(self, shape: torch.Size) -> torch.Tensor:
        """Draw iid Gumbel(0, 1) noise of ``shape`` on ``self.device`` as -log(-log(U))."""
        u = torch.rand(shape, device=self.device)
        # Clamp so log(u) never sees 0.
        u = torch.clamp(u, self.eps, 1 - self.eps)
        return -torch.log(-torch.log(u))

    def anneal_temperature(self, initial_temp: float, final_temp: float,
                           step: int, total_steps: int,
                           schedule: str = 'exponential') -> float:
        """
        Temperature annealing schedules for training.

        Args:
            initial_temp: Starting temperature
            final_temp: Final temperature
            step: Current training step (clamped to [0, total_steps])
            total_steps: Total training steps; a non-positive value means
                the schedule is already complete, so final_temp is returned.
            schedule: Annealing schedule ('exponential', 'linear', 'cosine')

        Returns:
            The temperature for ``step``.

        Raises:
            ValueError: If ``schedule`` is not one of the supported names.
        """
        if total_steps <= 0:
            # Avoid ZeroDivisionError; an empty schedule sits at its end point.
            progress = 1.0
        else:
            # Clamp both ends so negative steps do not extrapolate past initial_temp.
            progress = min(max(step / total_steps, 0.0), 1.0)

        if schedule == 'exponential':
            return initial_temp * (final_temp / initial_temp) ** progress
        elif schedule == 'linear':
            return initial_temp + (final_temp - initial_temp) * progress
        elif schedule == 'cosine':
            return final_temp + (initial_temp - final_temp) * 0.5 * (1 + math.cos(math.pi * progress))
        else:
            raise ValueError(f"Unknown schedule: {schedule}")

    def get_method_recommendation(self, use_case: str) -> SamplingMethod:
        """Return the recommended method for ``use_case`` (case-insensitive).

        Unknown use cases fall back to GUMBEL_SOFTMAX.
        """
        recommendations = {
            'traditional_rbm': SamplingMethod.BERNOULLI,
            'differentiable_rbm': SamplingMethod.GUMBEL_SOFTMAX,
            'fast_training': SamplingMethod.STRAIGHT_THROUGH,
            'continuous_relaxation': SamplingMethod.RELAXED_BERNOULLI,
            'sparse_representation': SamplingMethod.SPIKE_AND_SLAB,
            'stable_gradients': SamplingMethod.GAUSSIAN
        }

        return recommendations.get(use_case.lower(), SamplingMethod.GUMBEL_SOFTMAX)


# Utility functions for easy usage
def conditional_sampling(input_tensor: Union[torch.Tensor, Dict],
                         condition: str = 'gumbel_softmax',
                         temperature: float = 1.0,
                         **kwargs) -> torch.Tensor:
    """
    Simple wrapper function for backward compatibility and quick usage.
    Maintains exact behavior of original function.

    The input is always clipped to [0, 1] (``force_clamp=True`` overrides any
    caller-supplied value). Sampling runs on the device of the input tensor.

    Args:
        input_tensor: Input tensor or previous sampling result dict
        condition: Sampling method name. The legacy aliases 'gumbel', 'normal'
            and 'uniform' are accepted, and matching is case-insensitive.
        temperature: Temperature parameter
        **kwargs: Additional method-specific parameters

    Returns:
        torch.Tensor: Sampled values only (for backward compatibility)
    """
    # Sample on the input's own device. A default ConditionalSampler() would
    # move CUDA inputs to the CPU and return CPU samples.
    source = input_tensor
    if isinstance(source, dict):
        # Same key preference that ConditionalSampler.sample uses for dict input.
        source = source['samples'] if 'samples' in source else source.get('probabilities')
    device = source.device if isinstance(source, torch.Tensor) else None
    sampler = ConditionalSampler(device=device)

    # Force original clamping behavior for backward compatibility
    kwargs['force_clamp'] = True

    # Map original condition names to new method names
    condition_mapping = {
        'gumbel': 'gumbel_softmax',
        'normal': 'gaussian',
        'uniform': 'uniform_threshold',
    }
    key = condition.lower() if isinstance(condition, str) else condition
    mapped_condition = condition_mapping.get(key, key)

    result = sampler.sample(input_tensor, method=mapped_condition, temperature=temperature, **kwargs)
    # sample() is documented to return a dict; tolerate a bare tensor as well.
    return result['samples'] if isinstance(result, dict) else result


def safe_tensor_extract(input_data: Union[torch.Tensor, Dict, Any]) -> torch.Tensor:
    """
    Safely extract tensor from various input types.

    Resolution order:
    1. a tensor is returned as-is;
    2. a dict returns the first tensor found under the known keys;
    3. an object whose ``.data`` attribute is a tensor returns that tensor;
    4. anything else is passed to ``torch.tensor``.

    Args:
        input_data: Input that might be tensor, dict, or other type

    Returns:
        torch.Tensor: Extracted tensor

    Raises:
        ValueError: If a dict input holds no tensor under any known key.
        TypeError: If input cannot be converted to tensor
    """
    # Keys probed, in priority order, when a dict is passed.
    candidate_keys = ['samples', 'probabilities', 'logits', 'values', 'data']

    if isinstance(input_data, torch.Tensor):
        return input_data

    if isinstance(input_data, dict):
        for key in candidate_keys:
            value = input_data.get(key)
            if isinstance(value, torch.Tensor):
                return value
        raise ValueError(
            f"Dictionary input must contain a tensor in one of these keys: {candidate_keys}")

    wrapped = getattr(input_data, 'data', None)
    if isinstance(wrapped, torch.Tensor):
        return wrapped

    try:
        return torch.tensor(input_data)
    except (TypeError, ValueError, RuntimeError) as err:
        # Narrow catch: a bare `except:` would also swallow KeyboardInterrupt
        # and SystemExit. Chaining keeps torch's original error visible.
        raise TypeError(f"Cannot convert input of type {type(input_data)} to torch.Tensor") from err


# Example usage and demonstration
def demonstrate_enhanced_sampling():
    """Print a short tour of ConditionalSampler.

    The tour covers several sampling methods, chaining a result dict back into
    sample(), the legacy conditional_sampling wrapper, exponential temperature
    annealing and method recommendations.
    """
    sampler = ConditionalSampler()
    # Stand-in for RBM hidden-unit activations: 32 rows x 100 units of N(0, 1) values.
    input_logits = torch.randn(32, 100)

    print("Enhanced Conditional Sampling Demonstration")
    print("=" * 50)

    showcased = (
        SamplingMethod.BERNOULLI,
        SamplingMethod.GUMBEL_SOFTMAX,
        SamplingMethod.STRAIGHT_THROUGH,
        SamplingMethod.RELAXED_BERNOULLI,
    )
    for chosen in showcased:
        outcome = sampler.sample(input_logits, method=chosen, temperature=0.5)
        print(f"\nMethod: {chosen.value}")
        drawn = outcome['samples']
        print(f"Sample shape: {drawn.shape}")
        print(f"Sample range: [{drawn.min():.3f}, {drawn.max():.3f}]")
        print(f"Mean entropy: {outcome['entropy']:.3f}")
        print(f"Differentiable: {outcome['method_info']['differentiable']}")
        print(f"Type: {outcome['method_info']['type']}")

    # A result dict can be passed straight back in; its 'samples' become the new input.
    print("\nChaining Operations Example:")
    first = sampler.sample(input_logits, method='gumbel_softmax', temperature=2.0)
    second = sampler.sample(first, method='bernoulli', temperature=0.5)
    print(f"First sampling: {first['samples'].shape}")
    print(f"Second sampling: {second['samples'].shape}")

    print("\nBackward Compatibility Test:")
    legacy = conditional_sampling(input_logits, 'gumbel_softmax', temperature=1.0)
    print(f"Simple function output shape: {legacy.shape}")

    print("\nTemperature Annealing Example:")
    for step in range(0, 1001, 250):
        current = sampler.anneal_temperature(2.0, 0.1, step, 1000, 'exponential')
        print(f"Step {step:4d}: Temperature = {current:.3f}")

    print("\nMethod Recommendations:")
    for use_case in ('traditional_rbm', 'differentiable_rbm', 'fast_training', 'sparse_representation'):
        print(f"{use_case}: {sampler.get_method_recommendation(use_case).value}")


if __name__ == "__main__":
    # The demo runs only when this file is executed as a script.
    demonstrate_enhanced_sampling()