import numpy as np
import cvxpy as cp
from itertools import product
import math
from itertools import combinations
import time

class OnlineGradientDescent:
    """Online (optionally projected) gradient descent on a parameter array.

    With ``adaptive_flag == 0`` a fixed step ``learning_rate`` is used.
    Otherwise, before each step, the step size is reset to
    ``diameter / sqrt(2 * sum_t ||g_t||^2)``, where the norm runs over all
    entries of each gradient. Optionally the new iterate is rescaled onto the
    ball of radius ``diameter`` (or ``diameter / 2``) in that same norm.

    Args:
        learning_rate: Step size for the fixed schedule. The adaptive schedule
            overwrites it.
        diameter: Radius used for projection and in the adaptive step size.
    """

    def __init__(self, learning_rate, diameter = 2):
        self.learning_rate = learning_rate
        self.diameter = diameter
        # Running sum of squared gradient norms. Starting slightly above zero
        # keeps the adaptive step size finite even if the first gradient is 0.
        self.sum_norm = 1e-8

    def update(self, w, gradient, adaptive_flag = 0, project = False, project_half=False):
        """Return the next iterate ``w - learning_rate * gradient``.

        Args:
            w: Current parameters (any shape, same shape as ``gradient``).
            gradient: Gradient at ``w``.
            adaptive_flag: ``0`` selects the fixed step. Any other value first
                folds ``gradient`` into the adaptive step size.
            project: If true, rescale the result onto the norm ball when it
                lies outside it.
            project_half: When ``project`` is set, use radius
                ``diameter / 2`` instead of ``diameter``.

        Returns:
            A new parameter array. ``w`` itself is not modified.
        """
        if adaptive_flag != 0:
            self.update_learning_rate(gradient)
        w_ = w - self.learning_rate * gradient
        if project:
            radius = self.diameter / 2 if project_half else self.diameter
            w_ = self._project(w_, radius)
        return w_

    def update_learning_rate(self, gradient):
        """Accumulate ``||gradient||^2`` and reset the adaptive step size."""
        self.sum_norm += np.linalg.norm(gradient) ** 2
        self.learning_rate = self.diameter / np.sqrt(2 * self.sum_norm)

    @staticmethod
    def _project(w, radius):
        """Rescale ``w`` onto the ball of the given radius if it lies outside."""
        # ord=None gives the Frobenius norm for matrices and the 2-norm of the
        # flattened array for any other shape, so vector parameters work too.
        w_norm = np.linalg.norm(w)
        if w_norm > radius:
            return (w / w_norm) * radius
        return w

class OnlineStructuredPrediction_bandit_multilabel_fixed_self:
    """Linear multilabel predictor trained from bandit (loss-only) feedback.

    Scores are ``theta = W @ x`` with ``W`` of shape (d, k). Turning the scores
    into an m-hot label vector is delegated to the decoder object passed to
    :meth:`prediction`.

    Args:
        k: Feature dimension.
        d: Number of labels.
        m: Number of labels per candidate (ones per row of ``object_matrix``).
        V, b, c: Loss-model coefficients used in :meth:`cal_grad`.
        q: Weight of the uniform second-moment matrix ``Q`` in ``P_t``.
    """

    def __init__(self, k, d, m, V, b, c, q):
        self.k, self.d, self.m = k, d, m
        self.b, self.c, self.q = b, c, q
        self.V = V
        self.W = np.zeros((d, k))
        self.round = 0

        # Row i is the i-th m-subset of the labels as a 0/1 vector, in
        # itertools.combinations order.
        subsets = list(combinations(range(d), m))
        self.object_matrix = np.zeros((len(subsets), d), dtype=int)
        for row, subset in zip(self.object_matrix, subsets):
            row[list(subset)] = 1

        # matrix_matrix[i] is the outer product of row i with itself.
        n_subsets = math.comb(d, m)
        as_columns = self.object_matrix.reshape(n_subsets, d, 1)
        as_rows = self.object_matrix.reshape(n_subsets, 1, d)
        self.matrix_matrix = np.matmul(as_columns, as_rows)

        # Diagonal m/d, off-diagonal m(m-1)/(d(d-1)): E[y y^T] for y uniform
        # over the m-hot vectors. Undefined (division by zero) when d == 1.
        identity_part = self.m * (self.d - self.m) * np.identity(self.d)
        ones_part = self.m * (self.m - 1) * np.ones((self.d, self.d))
        self.Q = (identity_part + ones_part) / (self.d * (self.d - 1))

    def prediction(self, x, rd):
        """Decode ``W @ x`` with ``rd`` and cache what the next update needs.

        ``rd.randomized_decoding(theta)`` must return a triple
        ``(y_hat, y_omega, probability)``. All three are stored on ``self``
        together with ``x``.

        Returns:
            The decoded label vector ``y_hat``.
        """
        self.x = x
        scores = self.W @ x
        self.y_hat, self.y_omega, self.probability = rd.randomized_decoding(scores)
        return self.y_hat

    def update(self, loss, alg, step, project=False, project_half=False):
        """Take one step of ``alg`` (adaptive flag 1) on the estimated gradient."""
        gradient = self.cal_grad(loss, step)
        self.W = alg.update(self.W, gradient, 1, project=project, project_half=project_half)

    def cal_grad(self, loss, step):
        """Return the estimated gradient ``outer(y_omega - y_tilde, x)``.

        ``P_t = q * Q + (1 - q) * sum_i probability[i] * matrix_matrix[i]``.
        ``y_tilde`` solves ``(P_t @ V) y_tilde = (loss - y_hat @ b - c) * y_hat``.
        This is presumably an importance-weighted estimate of the true label;
        confirm against the accompanying derivation. ``step`` is unused.
        """
        support = np.nonzero(self.probability)[0]
        mixture = np.tensordot(self.probability[support], self.matrix_matrix[support], axes=1)
        P_t = self.q * self.Q + (1 - self.q) * mixture

        residual = loss - self.y_hat @ self.b - self.c
        y_tilde = np.linalg.solve(P_t @ self.V, residual * self.y_hat)
        return np.outer((self.y_omega - y_tilde), self.x)


def frank_wolfe(vertices, theta,K, m ,max_iter=1000):
    """Approximately minimise ``0.5 * ||y - theta||^2`` over conv(vertices).

    Runs Frank-Wolfe with the classic step size ``2 / (t + 2)``. The
    linear-minimisation step places ones on the ``m`` coordinates with the
    smallest gradient entries. It therefore assumes that conv(vertices) is the
    set of {0, 1}^d vectors with exactly ``m`` ones, each appearing as a row of
    ``vertices``.

    Args:
        vertices: Array of shape (K, d) whose rows are the candidate vertices.
        theta: Target point of length d.
        K: Number of rows of ``vertices``.
        m: Number of ones in every vertex.
        max_iter: Number of Frank-Wolfe iterations.

    Returns:
        ``(y, lambdas)``. ``y`` is the final iterate. ``lambdas`` (length K)
        are convex weights with ``vertices.T @ lambdas == y``, including when
        ``max_iter == 0``.
    """
    lambdas = np.zeros(K)
    start = np.random.default_rng().integers(K)
    # Copy so the caller never receives a view into `vertices`.
    y_current = np.array(vertices[start], dtype=float)
    # Record the starting vertex so the weights always reproduce y_current.
    # With the 2/(t+2) schedule the first step (alpha = 1) discards it anyway.
    lambdas[start] = 1.0
    for t in range(max_iter):
        gradient = y_current - theta

        # Linear-minimisation oracle: the best m-hot vertex for this gradient.
        s_k = np.zeros_like(gradient)
        s_k[np.argsort(gradient)[:m]] = 1

        idx = np.where(np.all(vertices == s_k, axis=1))[0]
        alpha = 2 / (t + 2)

        lambdas = (1 - alpha) * lambdas
        lambdas[idx] += alpha

        y_current = y_current + alpha * (s_k - y_current)
    return y_current, lambdas

class RandomizedDecoding_multilabel_bandit_fixed:
    """Randomised decoding of a score vector into an m-hot label vector.

    The candidate outputs are the K = C(d, m) vectors in {0, 1}^d with exactly
    m ones (the rows of ``self.A``). Decoding proceeds in three steps:

    1. Approximately project ``theta`` onto their convex hull with
       ``frank_wolfe``.
    2. Either sample a candidate from the Frank-Wolfe convex weights or return
       the nearest candidate.
    3. With probability ``q``, replace the output by a uniformly random
       candidate.

    Args:
        q: Probability of uniform exploration in :meth:`randomized_decoding`.
        d: Number of labels.
        m: Number of ones in every output vector.
        nu: Scale in the mixing probability ``p = min(1, 2 * dist / nu)``.
    """

    def __init__(self, q, d,  m, nu=1):
        self.nu = nu
        self.q = q
        self.d = d
        self.m = m
        self.K = math.comb(self.d,self.m)
        # Row i lists the label indices of the i-th m-subset
        # (itertools.combinations order).
        self.binary_vectors = np.array([np.array(v) for v in combinations(range(self.d), self.m)])
        # After the transpose, A[i] is the i-th m-subset as a 0/1 vector of length d.
        self.A = np.zeros((self.d, len(self.binary_vectors)))
        for i, indices in enumerate(self.binary_vectors):
            self.A[indices, i] = 1
        self.A = self.A.transpose()

        # Diagnostic: number of decodes whose mixing probability p was exactly 0 or 1.
        self.counter = 0

    def randomized_decoding(self, theta):
        """Decode ``theta`` and, with probability ``q``, explore uniformly.

        Returns:
            ``(y_hat, y_omega, output_probability)``:

            - ``y_hat`` is an int m-hot vector.
            - ``y_omega`` is the Frank-Wolfe point.
            - ``output_probability`` is the distribution over the K candidates
              produced by :meth:`randomized_decoding_sub`. It does not include
              the uniform-exploration component.
        """
        explore_draw = np.random.rand()
        y_hat, y_omega, output_probability = self.randomized_decoding_sub(theta)
        if explore_draw < self.q:
            # Use the global NumPy RNG (like every other draw in this class) so
            # np.random.seed makes the whole decoding reproducible.
            index = np.random.randint(self.K)
            y_hat = np.zeros(self.d)
            y_hat[self.binary_vectors[index]] = 1
        return y_hat.astype(int), y_omega, output_probability

    def randomized_decoding_sub(self, theta):
        """Decode ``theta`` without uniform exploration.

        With probability ``p = min(1, 2 * ||y_star - y_omega|| / nu)``, sample
        a candidate from the Frank-Wolfe weights ``c``. Otherwise return the
        nearest candidate ``y_star``.

        Returns:
            ``(y, y_omega, output_probability)``, where
            ``output_probability = p * c + (1 - p) * one_hot(y_star)``.
        """
        y_omega, c = frank_wolfe(self.A, theta, self.K, self.m, max_iter=100)
        y_star = (np.round(self.cal_nearest_extreme_point(y_omega))).astype(int)
        delta_star = np.linalg.norm(y_star - y_omega)
        p = min(1, 2 * delta_star / self.nu)
        if p==0 or p==1:
            self.counter += 1
        probability = np.random.rand()
        c = c / c.sum()
        one_hot_vector = np.zeros(self.K)
        one_hot_vector[np.where(np.all(self.A == y_star, axis=1))[0]] = 1
        output_probability = p * c + (1 - p) * one_hot_vector
        if probability < p:
            index = np.random.choice(self.K, p=c)
            y_tilde = np.zeros(self.d)
            y_tilde[self.binary_vectors[index]] = 1
            return y_tilde, y_omega, output_probability
        else:
            return y_star, y_omega, output_probability

    def cal_convex_combination(self, y_omega):
        """Find weights ``x >= 0`` summing to 1 with ``A @ x == y_omega`` (SCS).

        Negative solver round-off is clipped to 0.
        """
        x = cp.Variable(self.K)
        objective = cp.Minimize(0)
        constraints = [
            x >= 0,
            x @ np.ones(self.K) == 1,
            self.A @ x == y_omega]
        problem = cp.Problem(objective, constraints)
        problem.solve(solver=cp.SCS, eps=1e-3, max_iters = 1000)
        x.value[x.value<0]=0
        return x.value

    def indicator_func(self,y_1,y_2):
        """Return 1 if the two arrays are equal, else 0."""
        if np.array_equal(y_1,y_2):
            return 1
        else:
            return 0

    def regularized_prediction_function(self, theta):
        """Solve ``min -theta @ x + 0.5 * ||x||^2`` s.t. ``0 <= x <= 1`` and ``sum(x) == m``."""
        x = cp.Variable(self.d)
        objective = cp.Minimize(-theta@x+0.5*(cp.norm2(x)**2))
        constraints = [
            x >= 0,
            x <= 1,
            cp.sum(x) == self.m
        ]
        problem = cp.Problem(objective, constraints)
        problem.solve()
        return x.value

    def cal_nearest_extreme_point(self, y_omega):
        """Return the m-hot vector closest (in Euclidean norm) to ``y_omega``.

        For binary x, ``||x - y||^2 = ((1 - y)^2 - y^2) @ x + ||y||^2`` and
        ``(1 - y)^2 - y^2 = 1 - 2y``. Under ``sum(x) == m`` the minimiser
        therefore puts ones on the m largest entries of ``y_omega``. This is
        the exact (and always integral) solution of that linear program over
        ``{0 <= x <= 1, sum(x) == m}``, so no LP or mixed-integer solver is
        needed. Ties go to the lower index.

        Returns:
            A float array of length d with exactly ``m`` ones.
        """
        y_omega = np.asarray(y_omega, dtype=float)
        x = np.zeros(self.d)
        x[np.argsort(-y_omega, kind="stable")[:self.m]] = 1.0
        return x
    
def hamming_loss(y_1,y_2):
    """Return the fraction of positions at which ``y_1`` and ``y_2`` differ.

    Args:
        y_1, y_2: Equal-length 1-D sequences or arrays of labels.

    Returns:
        float: ``(# positions with y_1[i] != y_2[i]) / len(y_1)``.

    Raises:
        ValueError: If the two inputs have different lengths.
        ZeroDivisionError: If the inputs are empty.
    """
    d = len(y_1)
    if len(y_2) != d:
        raise ValueError(f"hamming_loss: length mismatch ({d} vs {len(y_2)})")
    mismatches = np.count_nonzero(np.asarray(y_1) != np.asarray(y_2))
    return mismatches / d

def run_bandit_multilabel_fixed(Y, d, m,  X, q, diameter, normed = True, project = False, project_half = False):
    """Run the bandit multilabel learner over a stream and return per-round losses.

    Args:
        Y: Sequence of true label vectors of length ``d``, one round per entry.
        d: Number of labels.
        m: Number of labels predicted per round.
        X: Feature matrix with one row per round (at least ``len(Y)`` rows).
        q: Uniform-exploration probability.
        diameter: Diameter passed to ``OnlineGradientDescent``.
        normed: If true, divide each feature row by the largest row norm of
            ``X``. Normalisation is skipped when that norm is 0.
        project, project_half: Forwarded to the gradient step.

    Returns:
        list of float: Hamming loss of the prediction in each round.
    """
    X = np.asarray(X)
    k = len(X[0])
    # Hamming loss written as y_hat @ b + c + y_hat @ V @ y, i.e.
    # (1/d) * sum_i (y_hat_i + y_i - 2 * y_hat_i * y_i) with b = 1/d,
    # V = -2 I / d and c = m/d. Using c = m/d for (1/d) * sum_i y_i assumes
    # every true label vector has exactly m ones. This is presumably guaranteed
    # in the "fixed" setting; confirm against the data.
    V = -2 * np.identity(d)/d
    b = np.ones(d)/d
    c = m/d
    osp = OnlineStructuredPrediction_bandit_multilabel_fixed_self(k,d,m,V,b,c,q)
    rd = RandomizedDecoding_multilabel_bandit_fixed(d=d, m=m, nu=1, q=q)
    # learning_rate=0 is only a placeholder: the learner's update requests the
    # adaptive schedule, which overwrites the step size every round.
    ogd = OnlineGradientDescent(learning_rate=0, diameter=diameter)

    scale = 1.0
    if normed:
        xnorm = np.max(np.linalg.norm(X, axis = 1))
        # An all-zero feature matrix would otherwise turn every x into NaN.
        if xnorm > 0:
            scale = xnorm

    loss_list = []
    for step, y in enumerate(Y):
        x = X[step] / scale if normed else X[step]
        y_hat = osp.prediction(x, rd)
        loss = hamming_loss(y_hat, y)
        loss_list.append(loss)
        osp.update(loss, ogd, step, project=project, project_half=project_half)
    return loss_list