from pymdp.maths import spm_log_single, softmax
import numpy as np
import torch
from botorch.acquisition import AcquisitionFunction
from matplotlib import pyplot as plt

class BOBA:
    """Grid-based active-inference acquisition over a GP surrogate.

    On construction the GP posterior is evaluated on a regular grid over the
    unit hypercube of the non-time inputs at a fixed ``current_time`` (the
    last input column). Calling the instance returns the grid point that
    minimises ``extrinsic - beta * intrinsic``.

    Parameters:
    - Y_best: scalar (float, NumPy scalar or tensor) used as the preferred
      outcome / threshold.
    - model: fitted model exposing ``train_inputs``, ``eval()``,
      ``posterior(X)`` and ``likelihood.noise``.
    - num_parameters: stored as ``self.dims``; only used by callers, the grid
      itself is sized from the model's input dimension.
    - num_observation_levels: stored, not used inside this class.
    - current_time: value placed in the last input column of every grid point.
    - beta: weight of the intrinsic (epistemic) term.
    - BOBA_normalization: if truthy, both terms are min-max normalised to
      [0, 1] before being combined.

    NOTE(review): ``prediction_cache`` is keyed on ``current_time`` only and a
    cache hit does not restore ``self.mean`` / ``self.stddev``.
    """

    # The grid resolution is chosen so the grid holds roughly this many points.
    TARGET_GRID_POINTS = 262144
    # Lower bound on the predictive std to keep log / CDF evaluations finite.
    _MIN_STD = 1e-12

    def __init__(self, Y_best, model, num_parameters, num_observation_levels, current_time, beta, BOBA_normalization):
        self.Y_best = Y_best
        self.dims = num_parameters
        self.model = model
        train_X = self.model.train_inputs[0]
        self.input_dims = train_X.shape[1]
        # Build the grid where the model's data lives; assuming "cuda" merely
        # because a GPU exists breaks CPU-resident models.
        self.device = train_X.device
        self.num_observation_levels = num_observation_levels
        self.current_time = current_time
        num_spatial = self.input_dims - 1  # every input column except time
        if num_spatial > 0:
            self.grid_size = int(np.ceil(np.power(self.TARGET_GRID_POINTS, 1 / num_spatial)))
        else:
            # Time-only model: the grid degenerates to a single point.
            self.grid_size = 1
        self.beta = beta
        self.norm = BOBA_normalization
        # Cache for GP predictions
        self.prediction_cache = {}
        self.grid_cache = None
        self.batch_size = 10000  # Adjust based on available memory
        self.A = self.create_A_matrix()

    @staticmethod
    def _min_max_normalize(values):
        """Rescale ``values`` to [0, 1]; a constant array maps to all zeros."""
        lowest = values.min()
        span = values.max() - lowest
        if span == 0:
            # Avoid 0/0 -> NaN, which would poison the argmin in __call__.
            return np.zeros_like(values)
        return (values - lowest) / span

    def _y_best_numpy(self):
        """Return ``Y_best`` in a form that can be mixed with NumPy arrays."""
        y_best = self.Y_best
        if torch.is_tensor(y_best):
            y_best = y_best.detach().cpu().numpy()
        return y_best

    def create_A_matrix(self):
        """
        Evaluate the GP posterior on the grid and build the likelihood matrix.

        Side effects: sets ``self.grid_cache`` (grid points), ``self.mean`` and
        ``self.stddev`` (posterior moments per grid point), switches the model
        to eval mode and stores the result in ``self.prediction_cache``.

        Returns:
        - A: [2 x N] NumPy array; row 0 is P(y > Y_best), row 1 is
          P(y <= Y_best) under the Gaussian posterior at each grid point.
        """
        # Check if we have cached predictions for this time
        cache_key = f"{self.current_time:.4f}"
        cached = self.prediction_cache.get(cache_key)
        if cached is not None:
            return cached

        # Use the model's own dtype so every axis (and the model input) agrees.
        dtype = self.model.train_inputs[0].dtype

        # One [0, 1] axis per non-time input, plus a single-valued time axis.
        grid_tensors = [
            torch.linspace(0, 1, self.grid_size, device=self.device, dtype=dtype)
            for _ in range(self.input_dims - 1)
        ]
        grid_tensors.append(torch.tensor([self.current_time], device=self.device, dtype=dtype))

        # 'ij' indexing + flatten gives row-major ordering of the grid points;
        # __call__ relies on this when decoding the flat index.
        grid_meshes = torch.meshgrid(grid_tensors, indexing='ij')
        X_test_full = torch.stack([grid.flatten() for grid in grid_meshes], dim=1).to(self.device)

        # Cache the grid for future use
        self.grid_cache = X_test_full

        num_points = X_test_full.shape[0]
        self.mean = torch.zeros(num_points, device=self.device, dtype=dtype)
        self.stddev = torch.zeros(num_points, device=self.device, dtype=dtype)

        # Get posterior predictions in batches to avoid memory issues
        self.model.eval()
        with torch.no_grad():
            for start in range(0, num_points, self.batch_size):
                stop = min(start + self.batch_size, num_points)
                posterior = self.model.posterior(X_test_full[start:stop])
                self.mean[start:stop] = posterior.mean.reshape(-1)
                # clamp guards against tiny negative variances from round-off
                self.stddev[start:stop] = posterior.variance.clamp_min(0).sqrt().reshape(-1)

        # as_tensor avoids the copy/warning torch.tensor() emits for tensors.
        threshold = torch.as_tensor(self.Y_best, device=self.device, dtype=self.mean.dtype).detach()

        # Normal requires a strictly positive scale.
        normal_dist = torch.distributions.normal.Normal(
            self.mean, self.stddev.clamp_min(self._MIN_STD)
        )
        prob_0 = 1 - normal_dist.cdf(threshold)

        A = torch.stack([
            prob_0,
            1 - prob_0
        ]).cpu().numpy()

        # Cache the results
        self.prediction_cache[cache_key] = A

        return A

    def extrinsic_value_closed_form(self):
        """
        Closed-form KL divergence between the predictive Gaussian at each grid
        point, N(mean, stddev^2), and the preferred Gaussian
        N(Y_best, noise_variance), where noise_variance is the model's
        likelihood noise.

        Reads ``self.mean`` and ``self.stddev`` (set by ``create_A_matrix``).

        Returns:
        - kl: [N] NumPy array of KL divergence values
        """
        mean = self.mean.cpu().numpy()
        std = self.stddev.cpu().numpy()
        var = std ** 2
        noise_variance = self.model.likelihood.noise.item()
        target = self._y_best_numpy()
        # A zero predictive std would give log(x / 0) = inf.
        safe_std = np.maximum(std, self._MIN_STD)
        kl = np.log(np.sqrt(noise_variance) / safe_std) + \
            (var + (mean - target) ** 2) / (2 * noise_variance) - 0.5
        return kl

    def intrinsic_value(self):
        """
        Epistemic term for every grid point:
        0.5 * log(1 + predictive_variance / noise_variance).

        Reads ``self.stddev`` (set by ``create_A_matrix``) and the model's
        likelihood noise.

        Returns:
        - intrinsic_values: [N] NumPy array
        """
        variances = self.stddev.cpu().numpy() ** 2
        noise_variance = self.model.likelihood.noise.item()
        return 0.5 * np.log1p(variances / noise_variance)

    def digitize_state(self, x):
        """Return a one-hot integer vector with a 1 at index ``x``."""
        # NOTE(review): the original read ``self.batch``, which is never set by
        # this class. An externally assigned ``batch`` is still honoured;
        # otherwise the vector spans the grid points -- confirm intended size.
        size = getattr(self, "batch", None)
        if size is None:
            size = self.mean.shape[0]
        discretized_x = np.zeros(size, dtype=int)
        discretized_x[x] = 1
        return discretized_x

    def __call__(self):
        """
        Pick the grid point with the lowest expected free energy
        ``G = extrinsic - beta * intrinsic``.

        Returns:
        - probe: 1-D tensor holding the selected grid coordinates in [0, 1]
          followed by ``current_time``
        - intrinsic value at the selected point (normalised if ``self.norm``)
        - extrinsic value at the selected point (normalised if ``self.norm``)
        """
        # Clear cache if it gets too large
        if len(self.prediction_cache) > 10:  # Keep last 10 predictions
            self.prediction_cache.clear()

        extrinsic_val = self.extrinsic_value_closed_form()
        intrinsic_val = self.intrinsic_value()
        if self.norm:
            extrinsic_val = self._min_max_normalize(extrinsic_val)
            intrinsic_val = self._min_max_normalize(intrinsic_val)

        G = extrinsic_val - self.beta * intrinsic_val

        # softmax is strictly increasing, so argmax(softmax(-G)) == argmin(G);
        # taking the argmin directly skips a full pass over the grid.
        max_index = int(np.argmin(G))

        # Decode with the shape the grid was actually built with
        # (input_dims - 1 axes of grid_size points, row-major order).
        grid_shape = (self.grid_size,) * (self.input_dims - 1)
        indices = np.unravel_index(max_index, grid_shape)

        # Create probe tensor with all dimensions plus time
        probe_values = [idx / (self.grid_size - 1) for idx in indices]  # Normalize to [0,1]
        probe_values.append(self.current_time)  # Add time dimension
        probe = torch.tensor(probe_values, device=self.device)

        return probe, intrinsic_val[max_index], extrinsic_val[max_index]

class ActiveInferenceAcquisitionFunction(AcquisitionFunction):
    """BoTorch acquisition-function wrapper around ``BOBA``.

    The ``BOBA`` helper is built once in the constructor from the given
    arguments and stored as ``self.active_inference``.
    """

    def __init__(self, Y_best, model, num_parameters, num_observation_levels, current_time, beta, BOBA_normalization):
        super().__init__(model)
        self.active_inference = BOBA(
            Y_best,
            model,
            num_parameters,
            num_observation_levels,
            current_time,
            beta,
            BOBA_normalization,
        )

    def forward(self, X=None):
        """Return whatever calling the wrapped ``BOBA`` instance returns.

        Parameters:
        - X: accepted only so the signature matches
          ``AcquisitionFunction.forward(X)``; it is ignored because the
          wrapped helper takes no candidate points. Defaults to None so
          existing zero-argument calls keep working.
        """
        return self.active_inference()






