import numpy as np
import warnings
import torch
from escort.escort_improvements import ImprovedESCORT, ImprovedSVGD, ImprovedRBFKernel, ImprovedGSWD

class NoProjectionGSWD(ImprovedGSWD):
    """
    Ablation variant of ImprovedGSWD whose projections are never optimized.

    Projections are drawn by the base class's ``_init_projections`` (with
    ``projection_method='random'``) and then reused unchanged: no optimization
    steps are taken and correlation awareness is disabled. Projection
    optimization is replaced by returning the current projections together
    with uniform weights.
    """

    def __init__(self, n_projections=10, projection_method='random',
                 optimization_steps=0, correlation_aware=False,
                 learning_rate=0.01, momentum=0.9):
        """
        Initialize with random, non-optimized, correlation-unaware projections.

        Args:
            n_projections: Number of projection directions requested.
            projection_method: Accepted for signature compatibility with
                ImprovedGSWD; always overridden to ``'random'``.
            optimization_steps: Accepted for compatibility; always forced to 0.
            correlation_aware: Accepted for compatibility; always forced to
                False.
            learning_rate: Forwarded to ImprovedGSWD.
            momentum: Forwarded to ImprovedGSWD.

        Warns:
            UserWarning: if any of the overridden options is given a value
            other than the one this ablation forces, so the override is not
            silent.
        """
        forced = {
            'projection_method': 'random',
            'optimization_steps': 0,
            'correlation_aware': False,
        }
        requested = {
            'projection_method': projection_method,
            'optimization_steps': optimization_steps,
            'correlation_aware': correlation_aware,
        }
        ignored = sorted(k for k, v in requested.items() if v != forced[k])
        if ignored:
            warnings.warn(
                f"NoProjectionGSWD ignores {', '.join(ignored)}; using "
                + ", ".join(f"{k}={forced[k]!r}" for k in ignored),
                stacklevel=2,
            )

        super().__init__(
            n_projections=n_projections,
            learning_rate=learning_rate,
            momentum=momentum,
            **forced,
        )

    def _ensure_projections(self, dim):
        """(Re)draw projections via the base initializer if missing or sized for another dimension."""
        if self.projections is None or self.projections.shape[1] != dim:
            self.projections, self.projection_weights = self._init_projections(dim)

    def _optimize_projections(self, source, target):
        """
        Return the current projections unchanged, with uniform weights.

        Overrides the base optimizer so that no optimization is performed.
        ``target`` is unused.

        Returns:
            (projections, weights): a copy of ``self.projections`` and an
            array with one equal weight per projection row.
        """
        self._ensure_projections(source.shape[1])

        projections = self.projections.copy()
        # Size the weights from the projection rows actually present rather
        # than ``self.n_projections`` so the two arrays cannot disagree.
        n_rows = projections.shape[0]
        weights = np.ones(n_rows) / n_rows

        return projections, weights

    def fit(self, source, target):
        """
        Prepare projections for ``source``'s dimensionality without learning them.

        Only the feature dimension (``shape[1]``) of ``source`` is used;
        ``target`` is accepted for interface compatibility and ignored.
        Sets ``self.fitted = True``.
        """
        # nan_to_num also turns array-likes into ndarrays so ``.shape`` works.
        source = np.nan_to_num(source, nan=0.0, posinf=1e10, neginf=-1e10)
        self._ensure_projections(source.shape[1])
        self.fitted = True


class ESCORTNoProj(ImprovedESCORT):
    """
    Ablation of ImprovedESCORT in which GSWD projections are not optimized.

    Particles are still updated by ImprovedSVGD with an adaptive
    ImprovedRBFKernel, but the GSWD component is a NoProjectionGSWD
    (random projections, zero optimization steps, no correlation awareness).
    ``ImprovedESCORT.__init__`` is intentionally not called; the attributes
    below are assigned directly instead.
    """
    def __init__(self, n_particles=100, state_dim=2, kernel_bandwidth=0.1, 
                step_size=0.01, lambda_corr=0.1, lambda_temp=0.1, 
                n_projections=10, learning_rate=1e-3, device=None, 
                discrete_actions=True, verbose=False):
        """
        Build the kernel, projection-free GSWD, SVGD updater and initial particles.

        Args:
            n_particles: Number of belief particles (stored on the instance).
            state_dim: Particle dimensionality; also scales the kernel
                bandwidth and lower-bounds the number of projections.
            kernel_bandwidth: Initial bandwidth for the adaptive RBF kernel.
            step_size: SVGD step size.
            lambda_corr: Correlation weight, stored and forwarded to SVGD.
            lambda_temp: Temporal-consistency weight (stored on the instance).
            n_projections: Requested number of GSWD projections; at least
                ``state_dim`` are used.
            learning_rate: Stored alongside the policy-network settings.
            device: Torch device; when None, CUDA if available else CPU.
            discrete_actions: Stored flag for the policy.
            verbose: Forwarded to SVGD; also prints a message when done.
        """
        # Scalar hyperparameters.
        self.n_particles = n_particles
        self.state_dim = state_dim
        self.lambda_corr = lambda_corr
        self.lambda_temp = lambda_temp
        self.verbose = verbose

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # Adaptive RBF kernel; bandwidth scale grows with sqrt(state_dim).
        scale = 0.5 * state_dim ** 0.5
        self.kernel = ImprovedRBFKernel(
            bandwidth=kernel_bandwidth,
            adaptive=True,
            bandwidth_scale=scale,
        )

        # Projection-free GSWD: never fewer projections than state dimensions.
        gswd_options = dict(
            n_projections=max(n_projections, state_dim),
            projection_method='random',
            optimization_steps=0,
            correlation_aware=False,
        )
        self.gswd = NoProjectionGSWD(**gswd_options)

        svgd_options = dict(
            kernel=self.kernel,
            gswd=self.gswd,
            step_size=step_size,
            lambda_corr=lambda_corr,
            max_iter=50,
            tol=1e-5,
            verbose=verbose,
        )
        self.svgd = ImprovedSVGD(**svgd_options)

        # Initial belief plus a copy kept as the "previous" belief for
        # temporal consistency.
        self.particles = self._initialize_particles()
        self.prev_particles = self.particles.copy()
        self.prev_action = None
        self.prev_observation = None

        # Policy-related state; the network and optimizer start unset
        # (presumably built later once the action dimension is known).
        self.action_dim = None
        self.policy_initialized = False
        self.hidden_dim = 128
        self.learning_rate = learning_rate
        self.discrete_actions = discrete_actions
        self.policy_network = None
        self.optimizer = None

        # Per-update loss history.
        self.training_loss = []

        if verbose:
            print("Initialized ESCORT-NoProj (projection matrices disabled)")
