import numpy as np
import cvxpy as cp
import math
from itertools import product
from itertools import combinations

class OnlineGradientDescent:
    """Online gradient descent step with optional Frobenius-ball projection.

    Attributes:
        learning_rate: current step size (overwritten by the adaptive rule).
        diameter: radius used by the projection and by the adaptive step size.
        sum_norm: running sum of squared Frobenius norms of the gradients seen
            by ``update_learning_rate``; starts at 1e-8 so the first adaptive
            step size is finite.
    """

    def __init__(self, learning_rate, diameter = 2):
        self.learning_rate = learning_rate
        self.diameter = diameter
        self.sum_norm = 1e-8

    def update(self, w, gradient, adaptive_flag = 0, project = False, project_half=False):
        """Return ``w`` after one gradient step (``w`` itself is not modified).

        Args:
            w: current iterate; must be 2-D when ``project`` is set because the
                Frobenius norm is taken.
            gradient: gradient with the same shape as ``w``.
            adaptive_flag: any value other than 0 refreshes ``learning_rate``
                from the accumulated gradient norms before stepping.
            project: rescale the result onto the ball when its Frobenius norm
                exceeds the radius.
            project_half: use ``diameter / 2`` instead of ``diameter`` as the
                projection radius.
        """
        if adaptive_flag != 0:
            self.update_learning_rate(gradient)

        stepped = w - self.learning_rate * gradient
        if not project:
            return stepped

        norm = np.linalg.norm(stepped, 'fro')
        radius = self.diameter/2 if project_half else self.diameter
        if norm > radius:
            stepped = (stepped/norm)*self.diameter
            if project_half:
                stepped = stepped/2
        return stepped

    def update_learning_rate(self, gradient):
        """Accumulate ``||gradient||_F^2`` and set the step size to
        ``diameter / sqrt(2 * sum_norm)``."""
        squared_norm = np.linalg.norm(gradient,'fro')**2
        self.sum_norm += squared_norm
        self.learning_rate = self.diameter/np.sqrt(2*self.sum_norm)

class OnlineStructuredPrediction_bandit:
    """Linear online learner that is updated from bandit (loss-only) feedback.

    ``W`` has shape ``(d, k)``; scores are ``W @ x`` and are handed to a
    decoder object that returns the prediction, the regularized prediction and
    the probability with which the prediction was drawn.
    """

    def __init__(self,k, d):
        self.k = k
        self.d = d
        self.W = np.zeros([self.d,self.k])

    def prediction(self, x, rd):
        """Score ``x``, decode with ``rd.randomized_decoding`` and return ``y_hat``.

        The input and the three decoder outputs are cached on the instance
        because ``cal_grad`` needs them afterwards.
        """
        self.x = x
        scores = self.W@x
        decoded = rd.randomized_decoding(scores)
        self.y_hat, self.y_omega, self.probability = decoded
        return self.y_hat

    def update(self, loss, ogd, project=False, project_half=False):
        """Replace ``W`` with the result of one adaptive ``ogd.update`` step."""
        gradient = self.cal_grad(loss)
        self.W = ogd.update(
            self.W, gradient, 1, project=project, project_half=project_half
        )

    def cal_grad(self,loss):
        """Return the gradient estimate for the last prediction.

        Only a zero loss yields a non-zero estimate:
        ``outer(y_omega - y_hat, x) / probability``.  Any other loss value
        gives a zero matrix of shape ``(d, k)``.
        """
        if loss == 0:
            residual = self.y_omega-self.y_hat
            return np.outer(residual,self.x)/self.probability
        return np.zeros([self.d,self.k])


class RandomizedDecoding_multiclass_bandit:
    """Randomized decoding with uniform exploration for multiclass prediction.

    A score vector ``theta`` is mapped to a softmax distribution ``y_omega``.
    The base decoder returns the arg-max one-hot vector ``y_star``, except that
    with probability ``p = min(1, 2 * ||y_star - y_omega||_1 / nu)`` it returns
    a one-hot vector sampled from ``y_omega`` instead.  On top of that, with
    probability ``q`` the prediction is replaced by a uniformly random class.

    NOTE: this module defines another class with the same name further down;
    when the whole module is imported, that later definition is the one bound
    to the name.

    Args:
        q: probability of uniform exploration.
        nu: scale in the definition of ``p`` above.
    """

    def __init__(self, q, nu=1):
        self.nu = nu
        self.q = q

    def randomized_decoding(self, theta):
        """Draw a prediction and report the probability of that exact prediction.

        Returns:
            ``(y_hat, y_omega, probability)`` where ``y_hat`` is an integer
            one-hot vector, ``y_omega`` is the softmax of ``theta`` and
            ``probability = (1 - q) * P_base(y_hat) + q / d``.
        """
        d = len(theta)
        explore = np.random.rand() < self.q
        y_omega, largest_entry, p, distribution = self._decoding_state(theta)
        # The base decoder is always sampled so that the global random stream
        # advances the same way whether or not exploration happens.
        chosen = self._sample_base_index(y_omega, largest_entry, p)
        if explore:
            chosen = int(np.random.default_rng().integers(0, d))
        y_hat = np.zeros(d, dtype=int)
        y_hat[chosen] = 1
        # The probability must belong to the class that is actually returned,
        # and a uniform draw over d classes contributes q / d.
        output_probability = (1 - self.q) * distribution[chosen] + self.q / d
        return y_hat, y_omega, output_probability

    def randomized_decoding_sub(self, theta):
        """Base decoder without exploration.

        Returns:
            ``(y, y_omega, probability)`` where ``y`` is a float one-hot vector
            and ``probability`` is the chance of the base decoder returning it.
        """
        y_omega, largest_entry, p, distribution = self._decoding_state(theta)
        chosen = self._sample_base_index(y_omega, largest_entry, p)
        y = np.zeros(len(theta))
        y[chosen] = 1
        return y, y_omega, distribution[chosen]

    def _decoding_state(self, theta):
        """Return ``(y_omega, argmax index, p, base output distribution)``."""
        y_omega = self.regularized_prediction_function(theta)
        largest_entry = int(np.argmax(y_omega))
        y_star = np.zeros(len(theta))
        y_star[largest_entry] = 1
        delta_star = np.sum(np.abs(y_star - y_omega))
        p = min(1, 2 * delta_star / self.nu)
        # With probability p a class is sampled from y_omega, otherwise the
        # arg-max is returned, so the mixture puts the (1 - p) mass on y_star.
        distribution = p * y_omega + (1 - p) * y_star
        return y_omega, largest_entry, p, distribution

    def _sample_base_index(self, y_omega, largest_entry, p):
        """Sample the class index returned by the base decoder."""
        sampled = int(np.random.choice(len(y_omega), p=y_omega))
        if np.random.rand() < p:
            return sampled
        return largest_entry

    def indicator_func(self, y_1, y_2):
        """Return 1 if the two arrays are equal (``np.array_equal``), else 0."""
        return 1 if np.array_equal(y_1, y_2) else 0

    def regularized_prediction_function(self, theta):
        """Softmax of ``theta``, shifted by its maximum for numerical stability."""
        exp = np.exp(theta - np.max(theta))
        return exp / exp.sum()

def loss_01(y_hat,y):
    """Zero-one loss: 0 when ``np.array_equal(y_hat, y)`` holds, otherwise 1."""
    return 0 if np.array_equal(y_hat,y) else 1

def run_bandit_multiclass(Y_, d,  X, q, diameter, project = False, project_half = False):
    """Run the bandit multiclass learner over the data once and return the
    per-step zero-one losses.

    Args:
        Y_: labels; ``Y_[i]`` is used as an index into a length-``d`` vector to
            build the one-hot target.
        d: length of the target vectors.
        X: inputs; ``X[step]`` is one input and ``len(X[0])`` is its dimension
            (presumably a 2-D array of feature rows, since the row norms are
            taken with ``axis=1`` -- confirm with the caller).
        q: exploration probability passed to the decoder.
        diameter: diameter passed to the gradient-descent learner.
        project, project_half: forwarded to the learner update.

    Returns:
        List with one 0/1 loss per step.
    """
    def one_hot(label):
        encoded = np.zeros(d)
        encoded[label] = 1
        return encoded.astype(int)

    k = len(X[0])
    Y = [one_hot(Y_[i]) for i in range(len(Y_))]

    osp = OnlineStructuredPrediction_bandit(k,d)
    rd = RandomizedDecoding_multiclass_bandit(nu=2, q=q)
    ogd = OnlineGradientDescent(learning_rate=0,diameter=diameter)

    # Every input is divided by the largest row norm found in X.
    xnorm = np.max(np.linalg.norm(X, axis = 1))

    loss_list = []
    for step in range(len(Y)):
        x = X[step]/xnorm
        y = Y[step]
        y_hat = osp.prediction(x, rd)
        step_loss = loss_01(y_hat, y)
        loss_list.append(step_loss)
        osp.update(step_loss, ogd, project=project, project_half=project_half)
    return loss_list


class RandomizedDecoding_multiclass_bandit:
    """Randomized decoding with uniform exploration for multiclass prediction.

    A score vector ``theta`` is mapped to a softmax distribution ``y_omega``.
    The base decoder returns the arg-max one-hot vector ``y_star``, except that
    with probability ``p = min(1, 2 * ||y_star - y_omega||_1 / nu)`` it returns
    a one-hot vector sampled from ``y_omega`` instead.  On top of that, with
    probability ``q`` the prediction is replaced by a uniformly random class.

    NOTE: a class with the same name is defined earlier in this module; this
    later definition is the one bound to the name once the module is loaded.

    Args:
        q: probability of uniform exploration.
        nu: scale in the definition of ``p`` above.
    """

    def __init__(self, q, nu=2):
        self.nu = nu
        self.q = q

    def randomized_decoding(self, theta):
        """Draw a prediction and report the probability of that exact prediction.

        Returns:
            ``(y_hat, y_omega, probability)`` where ``y_hat`` is an integer
            one-hot vector, ``y_omega`` is the softmax of ``theta`` and
            ``probability = (1 - q) * P_base(y_hat) + q / d``.
        """
        d = len(theta)
        explore = np.random.rand() < self.q
        y_omega, largest_entry, p, distribution = self._decoding_state(theta)
        # The base decoder is always sampled so that the global random stream
        # advances the same way whether or not exploration happens.
        chosen = self._sample_base_index(y_omega, largest_entry, p)
        if explore:
            chosen = int(np.random.default_rng().integers(0, d))
        y_hat = np.zeros(d, dtype=int)
        y_hat[chosen] = 1
        # The probability must belong to the class that is actually returned,
        # not to the base sample that exploration may have discarded.
        output_probability = (1 - self.q) * distribution[chosen] + self.q / d
        return y_hat, y_omega, output_probability

    def randomized_decoding_sub(self, theta):
        """Base decoder without exploration.

        Returns:
            ``(y, y_omega, probability)`` where ``y`` is a float one-hot vector
            and ``probability`` is the chance of the base decoder returning it.
        """
        y_omega, largest_entry, p, distribution = self._decoding_state(theta)
        chosen = self._sample_base_index(y_omega, largest_entry, p)
        y = np.zeros(len(theta))
        y[chosen] = 1
        return y, y_omega, distribution[chosen]

    def _decoding_state(self, theta):
        """Return ``(y_omega, argmax index, p, base output distribution)``."""
        y_omega = self.regularized_prediction_function(theta)
        largest_entry = int(np.argmax(y_omega))
        y_star = np.zeros(len(theta))
        y_star[largest_entry] = 1
        delta_star = np.sum(np.abs(y_star - y_omega))
        p = min(1, 2 * delta_star / self.nu)
        # With probability p a class is sampled from y_omega, otherwise the
        # arg-max is returned, so the mixture puts the (1 - p) mass on y_star.
        distribution = p * y_omega + (1 - p) * y_star
        return y_omega, largest_entry, p, distribution

    def _sample_base_index(self, y_omega, largest_entry, p):
        """Sample the class index returned by the base decoder."""
        sampled = int(np.random.choice(len(y_omega), p=y_omega))
        if np.random.rand() < p:
            return sampled
        return largest_entry

    def indicator_func(self, y_1, y_2):
        """Return 1 if the two arrays are equal (``np.array_equal``), else 0."""
        return 1 if np.array_equal(y_1, y_2) else 0

    def regularized_prediction_function(self, theta):
        """Softmax of ``theta``, shifted by its maximum for numerical stability."""
        exp = np.exp(theta - np.max(theta))
        return exp / exp.sum()
    
def frank_wolfe(vertices, theta,K, m ,max_iter=1000):
    """Frank-Wolfe iterations for ``0.5 * ||y - theta||^2`` over the convex
    hull of 0/1 vectors with exactly ``m`` ones.

    Args:
        vertices: array whose rows are the candidate vertices; a row is matched
            against the vertex chosen in each iteration to update its weight.
        theta: target vector.
        K: number of vertices (length of the returned weight vector).
        m: number of ones in every vertex.
        max_iter: number of iterations.

    Returns:
        ``(y, lambdas)``: the final iterate and the convex-combination weights
        over the rows of ``vertices`` accumulated along the way.
    """
    lambdas = np.zeros(K)
    index = np.random.default_rng().integers(K)
    # Give the starting vertex weight one so the returned pair is consistent
    # even when max_iter == 0.  The first iteration uses alpha == 1, which
    # wipes this weight, so results for max_iter >= 1 are unaffected.
    lambdas[index] = 1.0

    y_current = vertices[index]
    for iteration in range(max_iter):
        gradient = y_current - theta

        # Linear minimization step: the vertex minimizing <gradient, s> has
        # its ones on the m smallest gradient coordinates.
        indices = np.argsort(gradient)[:m]
        s_k = np.zeros_like(gradient)
        s_k[indices] = 1

        idx = np.where(np.all(vertices == s_k, axis=1))[0]
        alpha = 2 / (iteration + 2)

        lambdas = (1 - alpha) * lambdas
        lambdas[idx] += alpha

        y_current = y_current + alpha * (s_k - y_current)
    return y_current, lambdas

class RandomizedDecoding_multilabel_bandit_fixed:
    """Randomized decoding with uniform exploration for multilabel prediction
    with exactly ``m`` active labels out of ``d``.

    The ``K = C(d, m)`` feasible label vectors are enumerated up front.  For a
    score vector ``theta`` the regularized prediction ``y_omega`` and a convex
    combination ``c`` of vertices representing it come from ``frank_wolfe``.
    The base decoder returns the nearest vertex ``y_star``, except that with
    probability ``p = min(1, 2 * ||y_star - y_omega||_2 / nu)`` it returns a
    vertex sampled from ``c``.  With probability ``q`` the prediction is
    replaced by a uniformly random vertex.

    Attributes:
        binary_vectors: ``(K, m)`` array; row ``i`` lists the active label
            indices of vertex ``i``.
        A: ``(K, d)`` array; row ``i`` is the 0/1 indicator vector of vertex ``i``.
        counter: number of times the base decoder returned ``y_star``.
    """

    def __init__(self, q, d,  m, nu=1):
        self.nu = nu
        self.q = q
        self.d = d
        self.m = m
        self.K = math.comb(self.d,self.m)
        self.binary_vectors = np.array([np.array(v) for v in combinations(range(self.d), self.m)])
        self.A = np.zeros((len(self.binary_vectors), self.d))
        for i, indices in enumerate(self.binary_vectors):
            self.A[i, indices] = 1

        self.counter = 0

    def randomized_decoding(self, theta):
        """Draw a prediction and report the probability of that exact prediction.

        Returns:
            ``(y_hat, y_omega, probability)`` where ``y_hat`` is an integer 0/1
            vector with ``m`` ones and
            ``probability = (1 - q) * P_base(y_hat) + q / K``.
        """
        explore = np.random.rand() < self.q
        y_omega, y_star, c, p = self._decoding_state(theta)
        star_index = self._vertex_index(y_star)
        # The base decoder is always sampled, as before, so ``counter`` and the
        # global random stream advance whether or not exploration happens.
        index = self._sample_base_index(star_index, c, p)
        if explore:
            index = int(np.random.default_rng().integers(0, self.K))
        y_hat = np.zeros(self.d, dtype=int)
        y_hat[self.binary_vectors[index]] = 1
        # The probability must belong to the vertex that is actually returned.
        base_probability = self._sub_probability(index, star_index, c, p)
        output_probability = base_probability*(1-self.q) + self.q/self.K
        return y_hat, y_omega, output_probability

    def randomized_decoding_sub(self, theta):
        """Base decoder without exploration.

        Returns:
            ``(y, y_omega, probability)`` where ``probability`` is the scalar
            chance of the base decoder returning ``y``.
        """
        y_omega, y_star, c, p = self._decoding_state(theta)
        star_index = self._vertex_index(y_star)
        index = self._sample_base_index(star_index, c, p)
        if index == star_index:
            y = y_star
        else:
            y = np.zeros(self.d)
            y[self.binary_vectors[index]] = 1
        return y, y_omega, self._sub_probability(index, star_index, c, p)

    def _decoding_state(self, theta):
        """Return ``(y_omega, y_star, c, p)`` for the score vector ``theta``."""
        y_omega, c = frank_wolfe(self.A, theta, self.K, self.m, max_iter=100)
        y_star = (np.round(self.cal_nearest_extreme_point(y_omega))).astype(int)
        delta_star = np.linalg.norm(y_star-y_omega)
        p = min(1, 2*delta_star/self.nu)
        # Renormalize so the weights can be used as sampling probabilities.
        c = c / c.sum()
        return y_omega, y_star, c, p

    def _vertex_index(self, y):
        """Return the row of ``A`` equal to ``y``; raise ``ValueError`` if none is."""
        matches = np.where(np.all(self.A == y, axis=1))[0]
        if matches.size == 0:
            raise ValueError("vector is not one of the enumerated label vectors")
        return int(matches[0])

    def _sample_base_index(self, star_index, c, p):
        """Sample the vertex index returned by the base decoder."""
        if np.random.rand() < p:
            return int(np.random.choice(self.K, p=c))
        self.counter += 1
        return star_index

    def _sub_probability(self, index, star_index, c, p):
        """Probability that the base decoder returns vertex ``index``."""
        return p*c[index] + (1-p)*(1 if index == star_index else 0)

    def cal_convex_combination(self, y_omega):
        """Find weights ``x >= 0`` summing to one with ``A.T @ x == y_omega``
        by solving a cvxpy feasibility problem; negative entries of the
        solution are clipped to zero."""
        x = cp.Variable(self.K)
        objective = cp.Minimize(0)
        constraints = [
            x>=0,
            x@np.ones(self.K)==1,
            # A is (K, d), so the weights combine its rows through A.T.
            self.A.T@x==y_omega]
        problem = cp.Problem(objective, constraints)
        problem.solve()
        weights = x.value
        weights[weights<0]=0
        return weights

    def indicator_func(self,y_1,y_2):
        """Return 1 if the two arrays are equal (``np.array_equal``), else 0."""
        return 1 if np.array_equal(y_1,y_2) else 0

    def regularized_prediction_function(self, theta):
        """Solve ``min -theta @ x + 0.5 * ||x||^2`` subject to ``0 <= x <= 1``
        and ``sum(x) == m`` with cvxpy and return the solution."""
        x = cp.Variable(self.d)
        objective = cp.Minimize(-theta@x+0.5*(cp.norm2(x)**2))
        constraints = [
            x >= 0,
            x <= 1,
            cp.sum(x) == self.m
        ]
        problem = cp.Problem(objective, constraints)
        problem.solve()
        return x.value

    def cal_nearest_extreme_point(self, y_omega):
        """Return the 0/1 vector with ``m`` ones that is nearest to ``y_omega``.

        The linear objective ``s @ x`` with ``s = |1 - y|^2 - |y|^2`` over
        vectors with exactly ``m`` ones is minimized by putting the ones on
        the ``m`` smallest entries of ``s``.  Selecting them directly needs no
        solver and always yields a valid vertex, including when entries tie.

        Returns:
            Float array of length ``d`` containing ``m`` ones.
        """
        y_omega = np.asarray(y_omega, dtype=float)
        s = np.abs(1 - y_omega) ** 2 - np.abs(y_omega) ** 2
        chosen = np.argsort(s, kind="stable")[:self.m]
        x = np.zeros(self.d)
        x[chosen] = 1
        return x

def hamming_loss(y_1,y_2):
    """Fraction of positions ``i < len(y_1)`` at which ``y_1[i] != y_2[i]``."""
    d = len(y_1)
    mismatches = sum(1 for i in range(d) if y_1[i] != y_2[i])
    return mismatches/d

def run_bandit_multilabel_fixed(Y, d, m,  X, q, diameter, normed = True, project = False, project_half = False):
    """Run the bandit multilabel learner (exactly ``m`` active labels) over the
    data once and return the per-step Hamming losses.

    Args:
        Y: targets; ``Y[step]`` is compared element-wise with the prediction.
        d: length of the label vectors.
        m: number of active labels per prediction.
        X: inputs; ``X[step]`` is one input and ``len(X[0])`` is its dimension
            (presumably a 2-D array of feature rows, since the row norms are
            taken with ``axis=1`` -- confirm with the caller).
        q: exploration probability passed to the decoder.
        diameter: diameter passed to the gradient-descent learner.
        normed: divide every input by the largest row norm of ``X`` when true.
        project, project_half: forwarded to the learner update.

    Returns:
        List with one Hamming loss per step.
    """
    k = len(X[0])
    osp = OnlineStructuredPrediction_bandit(k,d)
    rd = RandomizedDecoding_multilabel_bandit_fixed(d=d, m=m, nu=1, q=q)
    ogd = OnlineGradientDescent(learning_rate=0,diameter=diameter)

    loss_list = []
    xnorm = np.max(np.linalg.norm(X, axis = 1))
    for step in range(len(Y)):
        x = X[step]/xnorm if normed else X[step]
        y = Y[step]
        y_hat = osp.prediction(x, rd)
        step_loss = hamming_loss(y_hat, y)
        loss_list.append(step_loss)
        osp.update(step_loss, ogd, project=project, project_half=project_half)
    return loss_list