import numpy as np
from sklearn.linear_model import Lasso,LinearRegression,LassoCV
import tqdm
### Base Algorithm (BSLB)
class BaseAlgorithm():
    """Explore-then-exploit bandit over a fixed set of arms, each pulled at most once.

    For the first ``E`` steps an arm is drawn uniformly at random from the arms
    that have not been pulled yet. When ``t == E`` a regression model is fit
    (see ``compute_estimates``) and from then on the not-yet-pulled arm with the
    highest predicted reward is selected.

    Args:
        N: Number of arms (rows of ``arm_vectors``).
        d: Feature dimension; only used to size ``self.estimate``.
        E: Number of exploration steps.
        T: Total number of steps (size of ``order_of_arms``).
        arm_vectors: Array indexed as ``arm_vectors[arm]`` / ``arm_vectors[:, cols]``.
        true_parameter: If given, rewards are ``arm_vectors[arm] @ true_parameter``
            plus Gaussian noise with standard deviation 0.1.
        true_rewards: Per-arm rewards used when ``true_parameter`` is ``None``.
        linear_model: If not ``None``, ``compute_estimates`` fits a plain
            ``LinearRegression`` instead of the Lasso-then-refit pipeline. Only
            its ``None``-ness is inspected.
    """

    def __init__(self, N,d, E, T,arm_vectors,true_parameter=None,true_rewards=None,linear_model=None):
        self.N = N  # Number of arms
        self.E = E  # Exploration period
        self.T = T  # Total time period
        self.t = 0  # Current time step
        self.estimate = np.zeros(d)

        self.arm_vectors = arm_vectors
        self.true_parameter = true_parameter
        self.true_rewards = true_rewards
        self.linear_model = linear_model
        self.initialize()

    def initialize(self):
        """Reset the step counter, the available arms and all per-run buffers."""
        self.t = 0
        self.available_arms = np.arange(self.N)
        self.estimates = np.zeros(self.N)
        self.exploration_arms = []
        self.rewards = np.zeros(self.N)
        self.order_of_arms = np.zeros(self.T)
        if self.true_parameter is not None:
            # Noise-free expected reward of every arm.
            self.true_rewards = self.arm_vectors@self.true_parameter

    def make_decision(self,return_arm=False):
        """Select one arm, pull it, and return its index.

        ``return_arm`` is accepted for interface compatibility and is not used.
        """
        # The model is fit exactly once, at the boundary between the two phases.
        if self.t == self.E:
            self.compute_estimates()

        if self.t < self.E:
            # Exploration phase
            selected_arm = self.explore()
        else:
            # Exploitation phase
            selected_arm = self.exploit()

        self.pull_arm(selected_arm)
        return selected_arm

    def explore(self):
        """Return a uniformly random arm among those not pulled yet."""
        return np.random.choice(self.available_arms)

    def exploit(self):
        """Return the not-yet-pulled arm with the highest predicted reward.

        Raises:
            IndexError: If no arm is available any more.
        """
        predictions = self.model.predict(self.arm_vectors[:,self.non_zero_coef_indices])
        ranked = np.argsort(predictions)[::-1]
        # Vectorized membership test; keeps the ranking order of the survivors.
        still_available = ranked[np.isin(ranked, self.available_arms)]
        return still_available[0]

    def get_reward(self,arm):
        """Return the reward of ``arm`` (noisy when ``true_parameter`` is set)."""
        if self.true_parameter is not None:
            return self.arm_vectors[arm]@self.true_parameter + np.random.normal(0,0.1)
        return self.true_rewards[arm]

    def pull_arm(self,arm):
        """Record a pull of ``arm``: advance time, retire the arm, store its reward."""
        self.t += 1

        self.available_arms = self.available_arms[self.available_arms != arm]
        self.rewards[arm] = self.get_reward(arm)
        self.order_of_arms[self.t-1] = arm
        # ``t`` was already incremented, so pulls 1..E are the exploration pulls.
        if self.t <= self.E:
            self.exploration_arms.append(arm)

    def compute_estimates(self):
        """Fit the reward model and set ``self.model`` / ``self.non_zero_coef_indices``.

        ``self.non_zero_coef_indices`` always lists exactly the columns the
        stored model was fit on, because ``exploit`` slices ``arm_vectors`` with
        it before calling ``predict``.
        """
        if self.linear_model is None:
            explored = self.exploration_arms
            self.n_explore = len(explored)
            explored_vectors = self.arm_vectors[explored]
            explored_rewards = self.rewards[explored]

            # Lasso is used only for support selection ...
            lasso = Lasso(alpha = 0.01,max_iter=10000)
            lasso.fit(explored_vectors, explored_rewards)
            support = np.flatnonzero(lasso.coef_)
            if support.size == 0:
                # Lasso shrank every coefficient to zero; refitting on a
                # zero-column matrix would raise, so fall back to all features.
                support = np.arange(explored_vectors.shape[1])

            # ... and ordinary least squares is refit on the selected columns.
            refit = LinearRegression()
            refit.fit(explored_vectors[:, support], explored_rewards)
            self.model = refit
            self.non_zero_coef_indices = support
        else:
            # NOTE(review): this fits on *all* arms, including unpulled ones
            # whose stored reward is still 0 -- confirm that is intended.
            model = LinearRegression()
            model.fit(self.arm_vectors,self.rewards)
            self.model = model
            # The model was fit on every column, so predict with every column
            # (dropping exactly-zero coefficients would change the feature count).
            self.non_zero_coef_indices = np.arange(self.arm_vectors.shape[1])

    def run(self):
        """Run ``T`` decisions in sequence."""
        for _ in range(self.T):
            self.make_decision()
from scipy import optimize
### CORRAL

import numpy as np
from scipy.optimize import minimize_scalar

def LOG_BARRIER_OMD(p, loss, eta):
    """
    One step of the LOG-BARRIER-OMD update.

    Finds lambda such that

        sum_i 1 / (1/p_i + eta_i * (loss_i - lambda)) == 1

    and returns the resulting (renormalized) distribution.

    Args:
    p (np.array): Previous distribution.
    loss (np.array): Current loss vector (entries may be negative).
    eta (float or np.array): Learning rate(s), broadcast against ``p``.

    Returns:
    np.array: Updated distribution p_{t+1}.

    Raises:
    AssertionError: If the resulting vector is not a valid distribution.
    """
    p = np.asarray(p, dtype=float)
    loss = np.asarray(loss, dtype=float)
    eta = np.asarray(eta, dtype=float)

    def updated(lambda_val):
        """Unnormalized new distribution for a given lambda."""
        return 1 / (1/p + eta * (loss - lambda_val))

    def objective(lambda_val):
        """Squared violation of the normalization constraint."""
        return (np.sum(updated(lambda_val)) - 1) ** 2

    # lambda lies in [min loss, max loss]; no clamping at 0, so negative
    # losses (e.g. negated rewards) are handled as well.
    min_loss = float(np.min(loss))
    max_loss = float(np.max(loss))

    # At the solution every term is a probability (<= 1), i.e. every
    # denominator is >= 1, which gives lambda <= loss_i + (1/p_i - 1)/eta_i.
    # Capping the search there keeps it away from the poles of the equation,
    # where the bounded minimizer could otherwise settle on a spurious optimum.
    with np.errstate(divide="ignore", invalid="ignore"):
        caps = np.where(eta > 0, loss + (1/p - 1) / eta, np.inf)
    upper = min(max_loss, float(np.min(caps)))

    if upper > min_loss:
        result = minimize_scalar(objective, bounds=(min_loss, upper), method='bounded')
        lambda_val = float(result.x)
    else:
        # Degenerate interval (e.g. all losses equal): lambda is pinned.
        lambda_val = min_loss

    p_t_plus_1 = updated(lambda_val)
    # Absorb the solver's finite tolerance on lambda.
    p_t_plus_1 /= np.sum(p_t_plus_1)

    # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
    if not np.isclose(np.sum(p_t_plus_1), 1):
        raise AssertionError(f"Error: new distribution {list(p_t_plus_1)} doesn't sum to 1")

    if not np.all(p_t_plus_1 >= 0):
        print(f"Warning: new distribution {list(p_t_plus_1)} contains negative values. Adjusting...")
        # Rescale so the most negative entry is -1, shift it to 0, renormalize.
        x = np.min(p_t_plus_1)
        p_t_plus_1 /= np.abs(x)
        p_t_plus_1 += 1
        p_t_plus_1 /= np.sum(p_t_plus_1)

    if not (np.all(p_t_plus_1 >= 0) and np.all(p_t_plus_1 <= 1)):
        raise AssertionError(f"Error: invalid probability distribution {list(p_t_plus_1)}")

    return p_t_plus_1

def CORRAL(learning_rate, base_algorithms, T):
    """
    Run the CORRAL master algorithm over a set of base algorithms.

    Each round every base algorithm makes (and records) its own decision, one
    base algorithm is sampled from the smoothed distribution ``p_bar``, its arm
    is played, and the importance-weighted loss is used to update the
    distribution with LOG_BARRIER_OMD.

    Args:
    learning_rate (float): Initial learning rate shared by all base algorithms.
    base_algorithms (list): Objects exposing ``make_decision()``,
        ``get_reward(arm)`` and a writable ``rewards`` array.
    T (int): Number of rounds (expected to be >= 1).

    Returns:
    tuple: ``(p_store, p_bar, order_of_arms)`` -- the (T, M) history of the
    distribution after each update, the final smoothed distribution, and the
    arm played in each round.
    """
    M = len(base_algorithms)

    # Initialize
    gamma = 1/T
    # log(1) == 0 would divide by zero; with a single round beta is never used
    # for a later update anyway.
    beta = np.exp(1/(np.log(T))) if T > 1 else 1.0
    # Float dtype on purpose: both arrays receive non-integer values below and
    # an integer array would silently truncate them.
    eta = np.full(M, learning_rate, dtype=float)
    rho = np.full(M, 2*M, dtype=float)
    p = np.full(M, 1/M)
    p_bar = np.full(M, 1/M)
    p_store = np.zeros((T,M))
    order_of_arms = np.zeros(T)

    for t in range(1, T+1):
        # Every base algorithm commits to (and internally pulls) its own arm.
        decisions = [algorithm.make_decision() for algorithm in base_algorithms]

        # Sample which base algorithm's decision is actually played.
        i_t = np.random.choice(M, p=p_bar)
        selected_arm = decisions[i_t]
        loss = -base_algorithms[i_t].get_reward(selected_arm)
        order_of_arms[t-1] = selected_arm

        # Send feedback to base algorithms: importance-weighted for the sampled
        # one, zero for all others.
        for i, algorithm in enumerate(base_algorithms):
            feedback = (loss / p_bar[i]) if i == i_t else 0
            algorithm.rewards[selected_arm] = -feedback

        # Update p_t+1 using LOG-BARRIER-OMD
        e_it = np.zeros(M)
        e_it[i_t] = 1
        p = LOG_BARRIER_OMD(p, (loss / p_bar[i_t]) * e_it, eta)
        p_store[t-1,:] = p
        # Set p_bar_t+1 by mixing with the uniform distribution
        p_bar = (1 - gamma) * p + gamma * (1/M)

        # Update rho and eta: raise the learning rate of any base algorithm
        # whose probability dropped below its current threshold.
        for i in range(M):
            if 1/p_bar[i] > rho[i]:
                rho[i] = 2 / p_bar[i]
                eta[i] *= beta

    return p_store, p_bar, order_of_arms
