from scipy.stats import norm
from sklearn.gaussian_process import GaussianProcessClassifier, GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from sklearn.gaussian_process.kernels import ConstantKernel, Matern
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy.spatial.distance import cdist


def rbf_kernel(xa, xb, sigma=1.0):
    """Evaluate the Radial Basis Function (RBF) kernel between two point sets.

    Computes ``exp(-0.5 * ||a - b||^2 / sigma**2)`` for every pair of rows
    ``a`` in ``xa`` and ``b`` in ``xb``.

    Args:
        xa: 2-D array with one point per row.
        xb: 2-D array with one point per row (same number of columns as ``xa``).
        sigma: Length scale of the kernel.

    Returns:
        Array of shape ``(len(xa), len(xb))`` holding the kernel values.
    """
    squared_distances = cdist(xa, xb, 'sqeuclidean')
    exponent = -0.5 * squared_distances / sigma**2
    return np.exp(exponent)




class GPWithPillarReRun:
    """Gaussian-process point selector that periodically re-runs "pillar" points.

    A zero-mean GP with an RBF kernel is fitted to the 1-D points observed so
    far.  On every ``x_iteration``-th call of :meth:`select_points_GP` one of
    the pillar points (chosen once through ``pillar_holder``) is returned
    again; on all other calls the next point is the minimiser of a lower
    confidence bound over random candidates drawn from ``[c_min, c_max]``.
    """

    def __init__(self, pillar_holder, x_iteration=5, num_pillar_points=3, noise_level=0.05, sigma=1.0):
        """Store the configuration and reset all fitted / iteration state.

        Args:
            pillar_holder: Object providing ``resample_pillar_indices(point_set)``,
                which must return ``(pillar_indices, non_pillar_indices)``.
            x_iteration: A pillar point is re-run on every ``x_iteration``-th call.
            num_pillar_points: Upper bound on how many pillar points are cycled.
            noise_level: Observation noise; its square is added to the
                diagonal of the training kernel matrix.
            sigma: Length scale handed to ``rbf_kernel``.
        """
        self.noise_level = noise_level
        self.sigma = sigma

        # State of the fitted model (filled in by fit_model).
        self.K_train_inv = None
        self.point_set = np.empty((0, 1))
        self.point_scores = np.empty((0,))

        # Outcome of the most recent acquisition step.
        self.next_sample_point = None
        self.acquisition_function_values = None

        # Current sampling range for candidate points.
        self.c_min = 0.0
        self.c_max = 10.0

        # Pillar bookkeeping.
        self.x_iteration = x_iteration
        self.num_pillar_points = num_pillar_points
        self.current_iteration = 0
        self.pillar_point_index = 0
        self.pillar_indices = None
        self.non_pillar_indices = None
        self.pillar_points = None
        self.pillar_points_values = None
        self.pillar_holder = pillar_holder

    def fit_model(self, point_set, point_scores):
        """Fit the GP by inverting the noisy training kernel matrix.

        Args:
            point_set: Observed points; flattened into a column of 1-D points.
            point_scores: Score observed at each point.
        """
        point_set = np.atleast_2d(point_set).reshape(-1, 1)  # one 1-D point per row
        point_scores = np.atleast_1d(point_scores)
        K_train = rbf_kernel(point_set, point_set, self.sigma) + self.noise_level**2 * np.eye(len(point_set))
        self.K_train_inv = np.linalg.inv(K_train)
        self.point_set = point_set
        self.point_scores = point_scores

    def predict(self, X_test):
        """Return the posterior mean and standard deviation at ``X_test``.

        ``fit_model`` must have been called beforehand.

        Args:
            X_test: Query points; flattened into a column of 1-D points.

        Returns:
            Tuple ``(mean, std)`` of 1-D arrays with one entry per query point.
        """
        X_test = np.atleast_2d(X_test).reshape(-1, 1)  # same layout as self.point_set
        K_train_test = rbf_kernel(self.point_set, X_test, self.sigma)
        K_test = rbf_kernel(X_test, X_test, self.sigma)
        posterior_mean = K_train_test.T @ self.K_train_inv @ self.point_scores
        posterior_covariance = K_test - K_train_test.T @ self.K_train_inv @ K_train_test
        # Round-off can push a variance slightly below zero; without the clip
        # sqrt() yields NaN, which np.argmin in the acquisition step would select.
        posterior_variance = np.clip(np.diag(posterior_covariance), 0.0, None)
        return posterior_mean.ravel(), np.sqrt(posterior_variance)

    def dynamic_range_adjustment(self, point_set, point_scores, initial_expansion_factor=1.0, decay_rate=0.95):
        """Derive a sampling range from the points whose scores lie in the IQR.

        The range is centred on the mean of the points whose score is between
        the first and third score quartiles, and extends by the spread
        (max - min) of those points times an expansion factor.  The factor
        shrinks with the variance of those points and decays geometrically
        with ``self.current_iteration``.

        Args:
            point_set: Observed points (array-like, indexed along axis 0).
            point_scores: Score of each point.
            initial_expansion_factor: Expansion factor at iteration 0.
            decay_rate: Per-iteration multiplier applied to the expansion factor.

        Returns:
            Tuple ``(new_c_min, new_c_max)``.  The current ``(self.c_min,
            self.c_max)`` is returned unchanged when there are no scores, no
            point falls inside the interquartile range, or the selected points
            have zero spread (which would collapse the range to one value).
        """
        point_set = np.asarray(point_set, dtype=float)
        point_scores = np.asarray(point_scores, dtype=float)
        if point_scores.size == 0:
            return self.c_min, self.c_max

        # First and third quartiles of the scores.
        q1 = np.percentile(point_scores, 25)
        q3 = np.percentile(point_scores, 75)

        # Points whose score falls within the interquartile range.
        iqr_points = point_set[(point_scores >= q1) & (point_scores <= q3)]
        if len(iqr_points) == 0:
            return self.c_min, self.c_max

        spread = np.ptp(iqr_points)
        if spread == 0:
            # A zero-width range would make every sampled candidate identical.
            return self.c_min, self.c_max

        # Interquartile mean and variance of the selected points.
        iqm = np.mean(iqr_points)
        iqr_variance = np.var(iqr_points)

        # Larger variance -> smaller expansion, bounded to [0.1, 1.5].
        dynamic_expansion_factor = min(1.5, max(0.1, 1.0 / (1.0 + iqr_variance)))

        # Geometric decay with the iteration counter.
        iteration_expansion_factor = initial_expansion_factor * (decay_rate ** self.current_iteration)
        final_expansion_factor = dynamic_expansion_factor * iteration_expansion_factor

        # The lower bound is clamped at 0.
        new_c_min = max(0, iqm - final_expansion_factor * spread)
        new_c_max = iqm + final_expansion_factor * spread
        return new_c_min, new_c_max

    def _ensure_pillar_points(self, point_set, point_scores):
        """Pick the pillar points through ``pillar_holder`` if not done yet."""
        if self.pillar_points is not None:
            return
        self.pillar_indices, self.non_pillar_indices = self.pillar_holder.resample_pillar_indices(point_set)
        self.pillar_points = point_set[self.pillar_indices]
        # Float copy: predicted scores are written back into this array later.
        self.pillar_points_values = np.array(point_scores[self.pillar_indices], dtype=float)

    def select_points_GP(self, point_set, point_scores, total_prob, total_obs, **kwargs):
        """Fit the GP on the observations and choose the next point to evaluate.

        Args:
            point_set: Observed points (array-like).
            point_scores: Score of each observed point.
            total_prob: Divisor applied to the predicted means before the
                acquisition step (NOTE(review): presumably must be non-zero).
            total_obs: Multiplier applied to the predicted means.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            ``np.array([new_point])`` holding the selected point.

        Raises:
            ValueError: If a pillar re-run is due but no pillar points exist.
        """
        point_set = np.asarray(point_set)
        point_scores = np.asarray(point_scores)
        self.fit_model(point_set, point_scores)

        self.c_min, self.c_max = self.dynamic_range_adjustment(point_set, point_scores)

        if self.current_iteration % self.x_iteration == 0:
            # Pillar iteration: revisit the next pillar point in round-robin order.
            self._ensure_pillar_points(point_set, point_scores)
            # Never cycle past the pillar points that actually exist.
            pillar_count = min(self.num_pillar_points, len(self.pillar_points))
            if pillar_count <= 0:
                raise ValueError("no pillar points available to re-run")
            new_point = self.pillar_points[self.pillar_point_index % pillar_count]
            self.pillar_point_index = (self.pillar_point_index + 1) % pillar_count
        else:
            # Acquisition iteration: minimise a lower confidence bound over
            # random candidates drawn from the current range.
            new_points = np.random.uniform(self.c_min, self.c_max, 500).reshape(-1, 1)
            new_points_predictions, new_std = self.predict(new_points)

            # Rescale the posterior mean by total_obs / total_prob.
            adjusted_predictions = new_points_predictions / total_prob * total_obs

            kappa = 3.5
            lcb = adjusted_predictions - kappa * new_std
            self.acquisition_function_values = lcb
            ind_next_sample = np.argmin(lcb)
            self.next_sample_point = new_points[ind_next_sample]
            new_point = new_points[ind_next_sample]

        new_point_score, _ = self.predict(new_point)

        if self.pillar_points is None:
            self._ensure_pillar_points(point_set, point_scores)
        else:
            # If the chosen point is a pillar point, refresh its stored value
            # with the GP prediction (as a scalar, not a size-1 array).
            matches = np.where(self.pillar_points == new_point)[0]
            if matches.size:
                self.pillar_points_values[matches[0]] = new_point_score[0]

        self.current_iteration += 1
        return np.array([new_point])

        


        



    
