import numpy as np
from scipy.optimize import minimize
from scipy.stats import norm
import os
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Kernel, RBF


class Matern52Kernel(Kernel):
    """Matern covariance kernel with smoothness 5/2 and per-feature length scales.

    For two samples ``a`` and ``b`` the kernel value is
    ``(1 + sqrt(5) * d + 5 * d**2 / 3) * exp(-sqrt(5) * d)``, where ``d`` is
    the Euclidean distance between ``a / length_scale`` and
    ``b / length_scale``.

    The class defines no ``hyperparameter_*`` attributes, so ``length_scale``
    is only ever changed through ``set_params`` or direct assignment.

    Args:
        length_scale: ``None`` (every feature uses scale 1.0), a scalar
            applied to every feature, or a sequence holding one scale per
            feature. A sequence longer than the number of features is
            allowed; only its leading entries are used.
    """

    def __init__(self, length_scale=None):
        self.length_scale = length_scale

    def _scales(self, n_features):
        """Return the length scales as a float array of length ``n_features``."""
        if self.length_scale is None:
            return np.ones(n_features)
        scales = np.asarray(self.length_scale, dtype=float)
        if scales.ndim == 0:
            # A scalar length scale is shared by every feature.
            return np.full(n_features, float(scales))
        scales = scales.ravel()
        if scales.shape[0] < n_features:
            raise ValueError(
                f"length_scale has {scales.shape[0]} entries but the inputs "
                f"have {n_features} features"
            )
        return scales[:n_features]

    def __call__(self, X, Y=None, eval_gradient=False):
        """Evaluate the kernel matrix between the rows of ``X`` and ``Y``.

        Args:
            X: Array-like of shape (n_samples_X, n_features). A 1-D input is
                treated as a single sample.
            Y: Array-like of shape (n_samples_Y, n_features), or ``None`` to
                evaluate ``X`` against itself.
            eval_gradient: When true, also return the gradient with respect
                to the kernel hyperparameters. This kernel exposes none, so
                the gradient has an empty last axis.

        Returns:
            The (n_samples_X, n_samples_Y) kernel matrix, or a tuple
            ``(K, gradient)`` with ``gradient`` of shape
            (n_samples_X, n_samples_X, 0) when ``eval_gradient`` is true.

        Raises:
            ValueError: If ``X`` and ``Y`` have different numbers of
                features, if ``length_scale`` has too few entries, or if
                ``eval_gradient`` is requested together with an explicit
                ``Y``.
        """
        X = np.atleast_2d(np.asarray(X, dtype=float))
        if Y is None:
            Y = X
        else:
            if eval_gradient:
                raise ValueError("Gradient can only be evaluated when Y is None.")
            Y = np.atleast_2d(np.asarray(Y, dtype=float))

        if X.shape[1] != Y.shape[1]:
            raise ValueError(
                f"X has {X.shape[1]} features but Y has {Y.shape[1]}"
            )

        scales = self._scales(X.shape[1])

        # Broadcast to (n_samples_X, n_samples_Y, n_features) scaled
        # differences, then reduce over features for the pairwise distance.
        scaled_diff = (X[:, np.newaxis, :] - Y[np.newaxis, :, :]) / scales
        d = np.sqrt(np.sum(scaled_diff ** 2, axis=-1))

        sqrt5_d = np.sqrt(5.0) * d
        K = (1.0 + sqrt5_d + 5.0 * d ** 2 / 3.0) * np.exp(-sqrt5_d)

        if eval_gradient:
            return K, np.empty((X.shape[0], X.shape[0], 0))
        return K

    def diag(self, X):
        """Return the diagonal of ``self(X)``; it is all ones because k(x, x) = 1."""
        # atleast_2d keeps this consistent with __call__, which treats a 1-D
        # input as a single sample, and lets plain lists be passed in.
        return np.ones(np.atleast_2d(X).shape[0])

    def is_stationary(self):
        """Return True: the kernel value depends only on the difference X - Y."""
        return True

    def get_params(self, deep=True):
        """Return the constructor parameters as a dict (``deep`` is ignored)."""
        return {"length_scale": self.length_scale}

    def set_params(self, **params):
        """Set ``length_scale`` if given; other keys are ignored. Returns self."""
        if "length_scale" in params:
            self.length_scale = params["length_scale"]
        return self

class BayesianWeightOptimizer:
    """Propose weight vectors by maximising a locality-weighted expected improvement.

    Each call to ``optimize`` may record one (weights, score) observation and
    then searches for the weight vector that maximises
    ``acquisition_function`` over all observations recorded so far.

    Args:
        weights_names: Names of the weights, in the order used for vectors.
        initial_weights: Starting vector used while no observation exists;
            defaults to 0.1 for every weight.
    """

    def __init__(self, weights_names, initial_weights=None):
        self.weights_names = weights_names
        self.n_weights = len(weights_names)
        self.initial_weights = (
            initial_weights if initial_weights is not None else [0.1] * self.n_weights
        )
        self.X = []  # observed weight vectors, one list of floats per observation
        self.y = []  # score recorded for the matching entry of self.X
        self.length_scale = [1.0] * self.n_weights

        self.kernel = Matern52Kernel(length_scale=self.length_scale)
        self.gp = GaussianProcessRegressor(
            kernel=self.kernel,
            alpha=1e-6,
            normalize_y=True,
            n_restarts_optimizer=5,
        )

        # Optional list of (low, high) pairs; see initialize_bounds().
        self.bounds = None

    def acquisition_function(self, x, X, y, beta=2.0):
        """Expected improvement at ``x``, damped by distance from the best observation.

        The posterior mean and variance at ``x`` are computed from
        ``self.kernel`` with a 1e-6 jitter on the diagonal and a zero prior
        mean. The expected improvement over ``max(y)`` is then multiplied by
        ``exp(-sum_j length_scale[j] * (x_j - x_star_j) ** 2)``, where
        ``x_star`` is the observation with the highest score.

        Args:
            x: Candidate point; 1-D input is treated as a single point.
            X: Observed points, one per row.
            y: Observed scores, one per row of ``X``.
            beta: Unused; kept for interface compatibility.

        Returns:
            The acquisition value as a float; 1.0 when ``X`` is empty and 0.0
            when the posterior variance is not positive.

        Raises:
            Exception: Any error raised during the computation is reported
                with ``print`` and re-raised unchanged.
        """
        if len(X) == 0:
            return 1.0

        x = np.atleast_2d(x)
        X = np.atleast_2d(X)
        y = np.atleast_1d(y)

        try:
            K = self.kernel(X, X) + 1e-6 * np.eye(len(X))
            k = np.asarray(self.kernel(X, x)).reshape(-1)

            # A single linear solve yields both K^-1 y and K^-1 k; this is
            # better conditioned than forming the explicit inverse.
            solved = np.linalg.solve(K, np.column_stack((y, k)))
            mu = float(k @ solved[:, 0])
            prior_var = np.asarray(self.kernel(x, x)).item()
            sigma2 = float(prior_var - k @ solved[:, 1])

            if sigma2 <= 0:
                return 0.0

            sigma = np.sqrt(sigma2)

            y_star = np.max(y)
            x_star = X[np.argmax(y)]

            z = (mu - y_star) / sigma
            ei = sigma * (norm.pdf(z) + z * norm.cdf(z))

            n_dims = x.shape[1]
            if self.length_scale is None:
                lambdas = np.ones(n_dims)
            else:
                lambdas = np.asarray(self.length_scale, dtype=float)[:n_dims]
            offset = x[0] - x_star[:n_dims]
            w_x = np.exp(-np.sum(lambdas * offset ** 2))

            return float(w_x * ei)

        except Exception as e:
            print(f"Compute EI error: {str(e)}")
            print(f"Input demension - x: {x.shape}, X: {X.shape}, y: {y.shape}")
            raise

    def optimize(self, current_score=None, previous_weights=None, beta=2.0, exploration_weight=0.1,
                 extra_metrics=None, current_success_rate=None):
        """Record an optional observation and propose the next weights.

        Args:
            current_score: Score obtained with ``previous_weights``. The
                observation is recorded only when both this and
                ``previous_weights`` are given.
            previous_weights: Dict mapping weight name to value, or a
                sequence of values. A dict replaces ``self.weights_names``
                with its keys.
            beta: Unused; kept for interface compatibility.
            exploration_weight: Unused; kept for interface compatibility.
            extra_metrics: Unused; kept for interface compatibility.
            current_success_rate: Unused; kept for interface compatibility.

        Returns:
            A tuple ``(next_weights, acquisition_value)``: a dict mapping each
            weight name to its proposed float value, and the acquisition value
            at that point.

        Raises:
            Exception: Errors while recording the observation or fitting
                ``self.gp`` are reported with ``print`` and re-raised.
        """
        if current_score is not None and previous_weights is not None:
            try:
                # Convert everything before touching state, so a bad value
                # cannot leave self.X and self.y with different lengths.
                if isinstance(previous_weights, dict):
                    names = list(previous_weights.keys())
                    observed = [float(previous_weights[name]) for name in names]
                else:
                    names = None
                    observed = [float(v) for v in previous_weights]
                score = float(current_score)

                if names is not None:
                    self.weights_names = names
                self.X.append(observed)
                self.y.append(score)

                self.gp.fit(np.array(self.X), np.array(self.y))
            except Exception as e:
                print(f"Error: {e}")
                raise

        # The observations do not change during the search; build them once.
        has_data = len(self.X) > 0
        X_obs = np.array(self.X)
        y_obs = np.array(self.y)

        def objective(x):
            if has_data:
                return -self.acquisition_function(np.atleast_2d(x), X_obs, y_obs)
            return 0.0

        start = np.asarray(self.X[-1] if has_data else self.initial_weights, dtype=float)

        # Honour the configured box constraints when they match the problem
        # size; otherwise run the unconstrained search.
        bounds = None
        method = 'BFGS'
        if self.bounds is not None and len(self.bounds) == len(start):
            bounds = list(self.bounds)
            lower = np.array([-np.inf if lo is None else lo for lo, _ in bounds], dtype=float)
            upper = np.array([np.inf if hi is None else hi for _, hi in bounds], dtype=float)
            method = 'L-BFGS-B'

        n_restarts = 5
        best_x = None
        best_value = float('inf')

        for restart in range(n_restarts):
            if restart > 0:
                candidate = start + np.random.normal(0, 0.2, size=len(start))
            else:
                candidate = start
            if bounds is not None:
                candidate = np.clip(candidate, lower, upper)

            res = minimize(
                objective,
                x0=candidate,
                method=method,
                bounds=bounds,
                options={'eps': 1e-3},
            )

            if res.fun < best_value:
                best_value = res.fun
                best_x = res.x

        if best_x is None:
            # Every restart produced a non-comparable (NaN) objective; fall
            # back to the unperturbed start instead of failing in zip().
            best_x = np.clip(start, lower, upper) if bounds is not None else start
            best_value = objective(best_x)

        next_weights = {
            name: float(value) for name, value in zip(self.weights_names, best_x)
        }

        return next_weights, -best_value

    def _adjust_weights_dimension(self, weights):
        """Truncate or zero-pad ``weights`` to ``self.n_weights`` entries.

        Always returns a float ndarray of length ``self.n_weights``.
        """
        adjusted = np.zeros(self.n_weights)
        values = np.asarray(weights, dtype=float)[:self.n_weights]
        adjusted[:len(values)] = values
        return adjusted

    def initialize_bounds(self, n_dims):
        """Set ``self.bounds`` to ``n_dims`` copies of the interval (0.05, 0.8)."""
        self.bounds = [(0.05, 0.8) for _ in range(n_dims)]