import numpy as np
from src.bandits.NonStationaryBandit import NonStationaryBandit
from sklearn.gaussian_process.kernels import Matern, RBF
from sklearn.metrics.pairwise import pairwise_kernels

def kernel_matrix(A, B, config):
    """Return the kernel (Gram) matrix between the rows of ``A`` and ``B``.

    Args:
        A: First set of points, passed to ``pairwise_kernels`` as ``X``.
        B: Second set of points, passed to ``pairwise_kernels`` as ``Y``.
        config: Dict with a ``'name'`` key selecting the kernel (``'matern'``
            or ``'rbf'``). Every other key is forwarded as a keyword argument
            to the corresponding scikit-learn kernel constructor. The dict is
            not modified.

    Returns:
        The result of ``pairwise_kernels(A, B, metric=kernel)``.

    Raises:
        KeyError: If ``config`` has no ``'name'`` key.
        ValueError: If ``config['name']`` is not a supported kernel.
    """
    name = config['name']
    # Everything except the selector key is a constructor argument.
    params = {key: value for key, value in config.items() if key != 'name'}
    if name == 'matern':
        kernel = Matern(**params)
    elif name == 'rbf':
        kernel = RBF(**params)
    else:
        raise ValueError(f'Invalid kernel: {name}')
    return pairwise_kernels(A, B, metric=kernel)


def compute_rkhs_norm_numpy(function_values, K, epsilon=1e-14):
    """Return ``sqrt(f^T K^{-1} f)`` for the given function values and Gram matrix.

    The value is computed through the Cholesky factor ``L`` of
    ``K + epsilon * I`` as ``||L^{-1} f||_2``, which avoids forming an explicit
    inverse.

    Args:
        function_values: Values of the function; flattened to a column vector.
        K: Square Gram matrix matching the length of ``function_values``. It is
            not modified.
        epsilon: Diagonal jitter added to ``K`` so the Cholesky factorization
            succeeds on numerically semi-definite matrices.

    Returns:
        The norm as a numpy float.

    Raises:
        numpy.linalg.LinAlgError: If ``K + epsilon * I`` is not positive definite.
    """
    f = np.asarray(function_values, dtype=float).reshape(-1, 1)
    K = np.asarray(K, dtype=float)
    # Build a new array instead of ``K += ...``: the in-place form mutated the
    # caller's matrix and failed outright for integer-typed K.
    K_jittered = K + epsilon * np.eye(K.shape[0])

    L = np.linalg.cholesky(K_jittered)
    # With K = L L^T, f^T K^{-1} f = ||L^{-1} f||^2.
    whitened = np.linalg.solve(L, f)
    return np.sqrt(np.sum(whitened ** 2))



class NSKernelBandit(NonStationaryBandit):
    """Non-stationary bandit whose arm means are smooth functions of arm features.

    Arms are ``num_actions`` random unit-norm vectors in ``R^d`` (see
    :meth:`set_arms`). Mean rewards come from one of two generators, selected
    by ``reward_generation_method``:

    * ``'random_reward'``: a zero-mean multivariate-normal draw whose
      covariance is the kernel matrix over the arms. The draw is rescaled so
      that its largest absolute value equals ``reward_bound``. After the first
      call, only a random subset of arms is redrawn.
    * ``'kernel_sum'``: ``f(x) = sum_i alpha_i k(x, z_i)`` over fixed random
      support points ``z_i``. Each call perturbs ``alpha`` while trying to keep
      ``alpha^T K alpha <= reward_bound**2``.
    """

    # Fallback kernel used when the caller passes no kernel_config. It lives
    # here rather than in the signature to avoid a shared mutable default.
    _DEFAULT_KERNEL_CONFIG = {'name': 'rbf', 'length_scale': 0.2}

    def __init__(
        self,
        num_actions,
        noise_variance,
        d,
        mean_low,
        mean_high,
        reward_bound,
        kernel_config=None,
        reward_generation_method='random_reward',
        rkhs_index=1,
        continuous=False
    ):
        """Create the bandit.

        Args:
            num_actions: Number of arms.
            noise_variance: Forwarded to ``NonStationaryBandit``.
            d: Dimension of each arm's feature vector.
            mean_low: Forwarded to ``NonStationaryBandit``.
            mean_high: Forwarded to ``NonStationaryBandit``.
            reward_bound: Scale/bound for the generated means (see the class
                docstring). Also forwarded to ``NonStationaryBandit``.
            kernel_config: Dict with a ``'name'`` key (``'rbf'`` or
                ``'matern'``) plus keyword arguments for that kernel. Defaults
                to an RBF kernel with ``length_scale=0.2``.
            reward_generation_method: ``'random_reward'`` or ``'kernel_sum'``.
            rkhs_index: Stored on the instance; not read elsewhere in this class.
            continuous: Forwarded to ``NonStationaryBandit``.
        """
        if kernel_config is None:
            kernel_config = dict(self._DEFAULT_KERNEL_CONFIG)

        self.reward_means = None

        # Set before super().__init__. Presumably the base initializer calls
        # set_arms()/get_reward_means(), which read these. TODO verify.
        self.d = d
        self.kernel_config = kernel_config
        self.num_points = 200  # number of support points for 'kernel_sum'
        self.reward_generation_method = reward_generation_method
        self.rkhs_index = rkhs_index
        super().__init__(
            arm_type="normal",
            num_actions=num_actions,
            noise_variance=noise_variance,
            mean_low=mean_low,
            mean_high=mean_high,
            reward_bound=reward_bound,
            continuous=continuous
        )
        # State used by gradual_change().
        self.target_reward_function = None
        self.init_reward_function = None

        # NOTE(review): kernel_config, reward_generation_method and rkhs_index
        # are not recorded here. If this dict is used to rebuild the bandit,
        # those settings fall back to their defaults. Verify against the
        # base class.
        self.init_params = {
            'num_actions': num_actions,
            'noise_variance': noise_variance,
            'd': d,
            'mean_low': mean_low,
            'mean_high': mean_high,
            'reward_bound': reward_bound,
            'continuous': continuous
        }

    def set_arms(self):
        """Return ``num_actions`` random Gaussian vectors, each scaled to unit L2 norm."""
        arms = np.random.randn(self.num_actions, self.d)
        arms /= np.sqrt(np.square(arms).sum(axis=1))[:, np.newaxis]
        return arms

    def get_reward_means(self):
        """Generate new means, store them in ``self.reward_means`` and return them.

        Raises:
            ValueError: If ``reward_generation_method`` is not recognised.
        """
        if self.reward_generation_method == 'kernel_sum':
            self.reward_means = self.generate_kernel_reward()
            return self.reward_means
        elif self.reward_generation_method == 'random_reward':
            self.reward_means = self.generate_random_reward()
            return self.reward_means
        else:
            raise ValueError('Invalid reward generation method.')

    def generate_kernel_reward(self, c=1.6):
        """Perturb the kernel-sum weights and return the means at each arm.

        On the first call this draws ``num_points`` support points uniformly
        from ``[-1, 1]^d`` and random weights ``alpha``. It then scales
        ``alpha`` so that ``sqrt(alpha^T K alpha) <= reward_bound``.

        Each call then adds a sparse perturbation to ``alpha`` (at most 5
        weights change). The perturbation is scaled to a random K-norm in
        ``[0.1, c)``. If the result would exceed ``reward_bound``, the step is
        rescaled to land on the boundary.

        Args:
            c: Upper end of the uniform range for the perturbation's K-norm.

        Returns:
            1-D array of ``f(arm)`` for every arm in ``self.arms``.

        Raises:
            RuntimeError: If no perturbation with nonzero K-norm is found.
        """
        num_points = self.num_points
        B = self.reward_bound

        if not hasattr(self, 'sample_points'):
            self.sample_points = np.random.uniform(-1, 1, size=(num_points, self.d))
            self.kern_alpha = kernel_matrix(self.sample_points, self.sample_points, self.kernel_config)
            self.alpha = np.random.uniform(-1, 1, size=(num_points, 1))
            norm_alpha = np.sqrt(self.alpha.T @ self.kern_alpha @ self.alpha)
            if norm_alpha > B:
                self.alpha *= B / norm_alpha

        sample_points = self.sample_points
        kern_alpha = self.kern_alpha
        alpha_t = self.alpha.copy()

        # Redraw until the perturbation has nonzero K-norm. This replaces the
        # original unbounded recursion with a bounded loop; the RNG draw order
        # is unchanged.
        num_changes = min(5, num_points)
        for _ in range(1000):
            delta_alpha = np.zeros_like(alpha_t)
            indices_to_change = np.random.choice(np.arange(num_points), size=num_changes, replace=False)
            delta_alpha[indices_to_change] = np.random.randn(num_changes, 1)
            delta_n = delta_alpha.T @ kern_alpha @ delta_alpha
            if not delta_n == 0:
                break
        else:
            raise RuntimeError('Could not draw a perturbation with nonzero RKHS norm.')

        # Scale the perturbation so its K-norm equals change_in_norm.
        change_in_norm = np.random.uniform(0.1, c)
        s = change_in_norm / np.sqrt(delta_n)
        delta_alpha_scaled = s * delta_alpha

        alpha_t_plus_1 = alpha_t + delta_alpha_scaled
        norm_f_t_plus_1_squared = alpha_t_plus_1.T @ kern_alpha @ alpha_t_plus_1

        if norm_f_t_plus_1_squared <= B**2:
            self.alpha = alpha_t_plus_1
        else:
            # Solve ||alpha_t + s * delta||_K^2 = B^2 for s, i.e. the roots
            # of a*s^2 + b*s + c_eq = 0.
            a = delta_alpha_scaled.T @ kern_alpha @ delta_alpha_scaled
            b = 2 * alpha_t.T @ kern_alpha @ delta_alpha_scaled
            c_eq = (alpha_t.T @ kern_alpha @ alpha_t) - B**2

            discriminant = b**2 - 4 * a * c_eq
            if discriminant >= 0:
                sqrt_disc = np.sqrt(discriminant)
                s1 = (-b + sqrt_disc) / (2 * a)
                s2 = (-b - sqrt_disc) / (2 * a)
                s_adjusted = min(s1, s2, key=abs) if s1 > 0 else s2
                self.alpha = alpha_t + s_adjusted * delta_alpha_scaled
            else:
                # With a > 0, a negative discriminant implies c_eq > 0 (the
                # current norm already exceeds B). In that case
                # max_delta_norm_squared below is negative and alpha stays put.
                norm_alpha_t_squared = alpha_t.T @ kern_alpha @ alpha_t
                max_delta_norm_squared = B**2 - norm_alpha_t_squared
                if max_delta_norm_squared > 0:
                    max_delta_norm = np.sqrt(max_delta_norm_squared)
                    delta_alpha_scaled *= max_delta_norm / np.sqrt(delta_n)
                    self.alpha = alpha_t + delta_alpha_scaled
                else:
                    self.alpha = alpha_t

        kern_vals = kernel_matrix(self.arms, sample_points, self.kernel_config)
        f = kern_vals @ self.alpha
        return f.flatten()

    def generate_random_reward(self):
        """Return means drawn from a zero-mean Gaussian with the arms' kernel matrix as covariance.

        Each draw is rescaled so its largest absolute entry equals
        ``reward_bound``.

        If means already exist, only a random subset of arms takes new values.
        The draw is repeated (at most 500 times) until at least one of those
        arms moves by at least 0.1. The existing ``self.reward_means`` array is
        not modified; a new array is returned.
        """
        prev_reward = self.reward_means
        K = kernel_matrix(self.arms, self.arms, self.kernel_config)

        def draw():
            sample = np.random.multivariate_normal(np.zeros(self.num_actions), K)
            return sample / np.max(np.abs(sample)) * self.reward_bound

        m = draw()
        if prev_reward is None:
            return m

        # Between 1 and num_actions - 1 arms change. A single-arm bandit
        # always changes its only arm; randint(1, 1) used to raise here.
        arms_to_change = np.random.randint(1, max(self.num_actions, 2))
        arms = np.random.choice(np.arange(self.num_actions), size=arms_to_change, replace=False)

        ctr = 0
        while np.max(np.abs(m[arms] - prev_reward[arms])) < 0.1 and ctr < 500:
            m = draw()
            ctr += 1  # previously missing, which made the 500-try cap ineffective

        new_reward = np.copy(prev_reward)
        new_reward[arms] = m[arms]
        return new_reward

    def share_reward_params(self):
        """Return ``(self.R, self.lamda, self.B)``.

        NOTE(review): none of these attributes are assigned in this class. They
        presumably come from the base class or external setup; otherwise this
        raises AttributeError. Verify.
        """
        return self.R, self.lamda, self.B

    def get_mean_reward(self, action):
        """Return the mean reward of ``action`` as a float.

        The means are generated first if none exist yet.
        """
        if self.reward_means is None:
            self.reward_means = self.get_reward_means()
        return float(self.reward_means[action])

    def abrupt_change(self):
        """Replace all means with a freshly generated set."""
        self.reward_means = self.get_reward_means()

    def gradual_change(self, change_rate=0.01):
        """Set the means to a convex mix of the initial and target means.

        On the first call, the current means are frozen as the start point. A
        target is then drawn with :meth:`generate_kernel_reward`. If no means
        exist yet, they are generated first.

        NOTE(review): the mix depends only on ``change_rate``, not on how many
        times this is called. Callers presumably pass an increasing
        ``change_rate``; verify.
        """
        if getattr(self, "target_reward_function", None) is None:
            if self.reward_means is None:
                self.reward_means = self.get_reward_means()
            self.init_reward_function = self.reward_means.copy()
            self.target_reward_function = self.generate_kernel_reward()

        self.reward_means = (1 - change_rate) * self.init_reward_function + change_rate * self.target_reward_function

    def re_init(self):
        """Delegate re-initialisation to the base class."""
        super().re_init()
