# %%
import utils
from sklearn.linear_model import Lasso as LassoSklearn
from math import log, sqrt
import numpy as np
import network_lasso as nl
import networkx as nx
from scipy.linalg import sqrtm
from scipy.special import log1p


class LinUcbMulti():
    """
        LinUCB with one independent ridge-regression model per task.

        Both ``LinUcbItl`` (one task per user) and ``LinUcbClusterOracle`` (one task
        per cluster) inherit from this class. Subclasses must set ``self.n_tasks``
        before :meth:`initialize` is called, otherwise no per-task arrays are allocated.
    """
    def __init__(self, l_reg= 1.0, use_ucb = True, dtype= 'float32'):
        self.l = l_reg          # ridge regularization lambda (V_0 = lambda * I)
        self.use_ucb = use_ucb  # if False, play greedily w.r.t. the current estimate
        self.dtype = dtype      # stored only; the arrays below use numpy's default dtype
        self.n_tasks = None

    def initialize(self, bandit):
        """Read ``std`` and ``dim`` from ``bandit`` and reset the per-task statistics."""
        self.std = bandit.std
        self.d = bandit.dim
        if self.n_tasks is not None:
            # V_u^{-1} starts at (lambda I)^{-1} for every task
            self.invV = 1/self.l * np.tile(np.eye(bandit.dim), (self.n_tasks,1,1))
            self.b = np.zeros((self.n_tasks, bandit.dim))
            self.theta = np.zeros((self.n_tasks, bandit.dim))
            self.counts = np.zeros(self.n_tasks, dtype= "int64")

        return self

    def play_arm(self, cxts, u, t):
        """Return the row index of ``cxts`` maximizing the UCB index of task ``u``.

        The index is ``x^T theta_u + std * coef_t * ||x||_{V_u^{-1}}``, where
        ``||x||_{V^{-1}} = sqrt(x^T V^{-1} x)``.
        """
        # TODO make it data dependent
        ucb_inds = cxts @ self.theta[u]
        if self.use_ucb:
            coef = self.compute_coef(self.counts[u], t)
            sq_norms = np.sum((cxts @ self.invV[u]) * cxts, axis= 1)
            # clip tiny negative round-off before the square root
            widths = np.sqrt(np.maximum(sq_norms, 0.0))
            ucb_inds += self.std * coef * widths
        return np.argmax(ucb_inds)


    def update(self, x, y, u, t):
        """Add observation ``(x, y)`` to the statistics of task ``u``."""
        # presumably utils.rank1_update returns the Sherman-Morrison correction term
        # V^{-1} x x^T V^{-1} / (1 + x^T V^{-1} x) -- TODO confirm against utils
        self.invV[u] -= utils.rank1_update(self.invV[u], x)
        self.b[u] += y * x
        self.theta[u] = self.invV[u] @ self.b[u]
        self.counts[u] += 1
        return self

    def compute_coef(self, count, t):
        """Confidence radius after ``count`` observations at round ``t``.

        Appears to be the OFUL radius with delta = 1/t^2 (so 2 log(1/delta) = 4 log t)
        and a unit bound on ||theta||; ``log(t)`` requires ``t >= 1``.
        """
        return sqrt(4*log(t) + self.d*log(1+count/(self.d*self.l))) + sqrt(self.l)


class LinUcbItl(LinUcbMulti):
    """
        Independent Task Learning variant of LinUCB: each user owns a separate
        model and no information is shared between users.
    """
    def __init__(self, l_reg = 1.0, use_ucb= True):
        super().__init__(l_reg=l_reg, use_ucb=use_ucb)

    def initialize(self, bandit):
        # exactly one task per user of the bandit
        self.n_tasks = bandit.n_users
        return super().initialize(bandit)


class LinUcbClusterOracle(LinUcbMulti):
    """
        LinUCB with oracle knowledge of the user clustering: every user of a
        cluster shares the model stored at index ``cluster_inds[user]``.
    """
    def __init__(self, cluster_inds, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cluster_inds = cluster_inds
        # one task per cluster; indices are expected to be non-negative integers
        self.n_tasks = max(cluster_inds) + 1

    def play_arm(self, cxts, u, t):
        cluster = self.cluster_inds[u]
        return super().play_arm(cxts, cluster, t)

    def update(self, x, y, u, t):
        cluster = self.cluster_inds[u]
        return super().update(x, y, cluster, t)
  

# class GOBLin():
#     """
#         GOBLin as described in the original paper
#     """
#     def __init__(self, adjacency, delta= 0.01, use_ucb = True, dtype= 'float32'):
#         self.delta = delta
#         self.laplacian = np.diag(adjacency.sum(0))-adjacency
#         self.use_ucb = use_ucb
#         self.dtype = dtype
    
#     def initialize(self, bandit):
#         self.std = bandit.std
#         self.invV = np.eye(bandit.dim*bandit.n_users, dtype= self.dtype)
#         self.detV = 1.0
#         self.b = np.zeros(bandit.dim * bandit.n_users, dtype= self.dtype)
#         self.theta = self.invV @ self.b 
#         self.Iu = np.eye(bandit.n_users, dtype= self.dtype)
#         inv_sqrt_lap_id = np.linalg.inv(sqrtm(self.laplacian + self.Iu))
#         self.invSqrtA = np.kron(inv_sqrt_lap_id, np.eye(bandit.dim))
#         return self
    
#     def play_arm(self, cxts, u, t): 
#         cxts_aug = np.kron(self.Iu[u], cxts) @ self.invSqrtA
#         indices = cxts_aug @ self.theta
#         if self.use_ucb:
#             radius = np.sum((cxts_aug @ self.invV)*cxts_aug, axis=1)
#             indices +=  (1 + self.std * sqrt(log(self.detV/self.delta)))*radius
#         return np.argmax(indices)

#     def update(self, x, y, u, t):
#         x_aug = np.kron(self.Iu[u], x) @ self.invSqrtA
#         self.b += y*x_aug
#         invV_x_aug = self.invV @ x_aug
#         denom = 1+ x_aug @ invV_x_aug
#         self.detV *= denom
#         self.invV -= np.outer(invV_x_aug, invV_x_aug)/denom
#         self.theta = self.theta = self.invV @ self.b 
#         return self

class GOBLin():
    """
        GOBLin as described in the original paper (Cesa-Bianchi, Gentile & Zappella, 2013).

        Contexts are lifted to a ``dim * n_users`` feature space through
        ``(L + I)^{-1/2} (x) x`` (Kronecker product), where ``L`` is the graph
        Laplacian built from ``adjacency``. A single ridge model is then run on
        the lifted features.
    """
    def __init__(self, adjacency, l_reg= 1.0, delta= 0.01, use_ucb = True, dtype= 'float32'):
        self.delta = delta
        self.laplacian = np.diag(adjacency.sum(0))-adjacency
        self.use_ucb = use_ucb
        self.dtype = dtype
        self.l = l_reg

    def initialize(self, bandit):
        """Reset the statistics for ``bandit`` (reads ``std``, ``dim``, ``n_users``)."""
        self.std = bandit.std
        self.invV = 1/self.l*np.eye(bandit.dim * bandit.n_users, dtype= self.dtype)
        # log(det V_t / det V_0): zero at the start. Only this ratio enters the
        # self-normalized confidence width, and it keeps the sqrt argument >= 0.
        self.logdetV = 0.0
        self.b = np.zeros(bandit.dim * bandit.n_users, dtype= self.dtype)
        self.theta = self.invV @ self.b
        self.inv_sqrt_lap_id = np.linalg.inv(sqrtm(self.laplacian + np.eye(bandit.n_users, dtype= self.dtype)))
        return self

    def play_arm(self, cxts, u, t):
        """Return the row index of ``cxts`` maximizing the GOBLin UCB index for user ``u``."""
        cxts_aug = np.kron(self.inv_sqrt_lap_id[u], cxts)
        indices = cxts_aug @ self.theta
        if self.use_ucb:
            sq_norms = np.sum((cxts_aug @ self.invV)*cxts_aug, axis=1)
            # ||x||_{V^{-1}}; clip round-off negatives before the square root
            radius = np.sqrt(np.maximum(sq_norms, 0.0))
            # indices +=  (1 + self.std * sqrt(self.logdetV + log(1/self.delta)))*radius
            indices +=  (1 + self.std * sqrt(self.logdetV + 4*log(t)))*radius
        return np.argmax(indices)

    def update(self, x, y, u, t):
        """Add observation ``(x, y)`` of user ``u`` in the lifted feature space."""
        x_aug = np.kron(self.inv_sqrt_lap_id[u], x)
        self.b += y * x_aug
        invV_x_aug = self.invV @ x_aug
        quad = x_aug @ invV_x_aug
        # matrix determinant lemma: det(V + x x^T) = det(V) * (1 + x^T V^{-1} x)
        self.logdetV += log1p(quad)
        # Sherman-Morrison rank-one update of V^{-1}
        self.invV -= np.outer(invV_x_aug, invV_x_aug)/(1 + quad)
        self.theta = self.invV @ self.b
        return self
        

class GraphLinUCB():
    """Graph-regularized LinUCB agent.

    The joint estimate ``theta`` (one row per user) solves a ridge problem whose
    regularizer is ``alpha * kron(L, I_d)``. The exploration width of user ``u``
    is built from per-user statistics in :meth:`update_beta`.
    """
    # TODO store estimated Theta (with Laplacian) and use its rows for beta

    def __init__(self, laplacian, alpha=1.0, delta=0.01, state= 1, use_ucb = True, symmetrize= False, dtype= 'float32'):
        """Store the regularized Laplacian and the hyper-parameters.

        Parameters
        ----------
        laplacian : array_like, shape (n_users, n_users)
            graph Laplacian over users
        alpha : float, optional
            regularization parameter, by default 1.0
        delta : float, optional
            high probability parameter, by default 0.01
        state : int, optional
            1 builds ``user_avg`` from the per-user least-squares estimates; any
            other value reads ``self.true_theta``, which the caller must then set.
            By default 1.
        use_ucb : bool, optional
            whether to add the exploration bonus in ``play_arm``, by default True
        symmetrize : bool, optional
            whether to symmetrize the (random walk) laplacian as (L + L^T)/2, by default False
        dtype : str, optional
            dtype of the internal arrays, by default 'float32'
        """
        self.dtype= dtype
        self.delta = delta
        self.alpha = alpha
        self.lap = laplacian.astype(self.dtype)
        if symmetrize:
            # symmetrize the already-cast matrix so self.lap keeps ``dtype``
            self.lap = 0.5*(self.lap + self.lap.T)
        self.state = state
        self.user_num = len(self.lap)
        # small ridge on the diagonal, presumably so that kron(L, I) is invertible
        # (a graph Laplacian is singular)
        self.L = self.lap + 0.01*np.eye(self.user_num, dtype = self.dtype)
        self.use_ucb = use_ucb


    def initialize(self, bandit):
        """Reset all statistics for ``bandit`` (reads ``dim`` and ``std``); returns ``self``."""
        self.d = bandit.dim
        self.theta = np.zeros((self.user_num, self.d), dtype= self.dtype)

        self.A = np.kron(self.L, np.eye(self.d))
        self.A_inv = np.linalg.inv(self.A)  # XXX pinv --> inv
        self.XX = np.zeros(
            (self.user_num*self.d, self.user_num*self.d), dtype= self.dtype)
        self.sigma = bandit.std
        self.cov = self.alpha*self.A
        self.bias = np.zeros((self.user_num*self.d), dtype= self.dtype)
        self.beta_list = []
        self.user_v = self.alpha * \
            np.tile(np.eye(self.d, dtype= self.dtype), (self.user_num, 1, 1))
        self.user_avg = np.zeros((self.user_num, self.d), dtype= self.dtype)
        self.user_ls = np.zeros((self.user_num, self.d), dtype= self.dtype)
        self.user_ridge = np.zeros((self.user_num, self.d), dtype= self.dtype)
        self.user_xx = 0.1 * \
            np.tile(np.eye(self.d, dtype= self.dtype), reps=(self.user_num, 1, 1))
        self.user_bias = np.zeros((self.user_num, self.d), dtype= self.dtype)
        self.user_counter = np.zeros(self.user_num, dtype= self.dtype)
        self.user_h= np.zeros((self.user_num, self.d, self.d), dtype= self.dtype)
        # row j lists every user index except j
        self.users_but_u = np.array([[i for i in range(self.user_num) if i != j]
                                     for j in range(self.user_num)], dtype="int32")
        return self

    def play_arm(self, cxts, u, t):
        """Return the row index of ``cxts`` with the largest index for user ``u``."""
        self.update_beta(u)
        # x^T H_u^{-1} x for every candidate row x
        X_norm_sqr = np.einsum("ij,ji->i", cxts,
                               np.linalg.solve(self.user_h[u], cxts.T))
        index = np.dot(cxts, self.theta[u])
        if self.use_ucb:
            index = index + self.beta * np.sqrt(X_norm_sqr)
        return np.argmax(index)

    def get_theta(self):
        return self.theta

    def update(self, x, y, u, t):
        """Add observation ``(x, y)`` of user ``u`` and refit the joint estimate."""
        # x placed in user u's slot of the stacked (user_num * d) feature vector
        f = np.zeros((self.user_num*self.d))
        f[u*self.d:(u+1)*self.d]=x
        self.user_v[u]+=np.outer(x, x)
        self.user_xx[u]+=np.outer(x, x)
        self.user_bias[u]+=y*x
        self.user_ls[u]= np.linalg.solve(self.user_xx[u], self.user_bias[u]) # XXX inv --> solve

        self.XX+=np.outer(f, f)
        self.bias+=y*f

        self.cov+=np.outer(f, f)
        self.theta=np.linalg.solve(self.cov, self.bias).reshape((self.user_num, self.d)) #XXX inv --> solve

        self.user_avg = self.L @ self.user_ls # XXX vectorized
        return self

    def update_beta(self, u):
        """Recompute ``self.user_h[u]`` and the exploration coefficient ``self.beta``."""
        u_range = self.users_but_u[u]
        sum_A = np.sum(np.einsum("k,kij-> kij", self.L[u,u_range]**2,
                                 np.linalg.inv(self.user_xx[u_range])), axis= 0)
        self.user_h[u]=self.user_xx[u]+self.alpha**2*sum_A\
                        +2*self.alpha*self.L[u, u]*np.eye(self.d)
        # log( det(V_u)^{1/2} * det(alpha I)^{-1/2} ), computed in log space and in
        # float64: det() of a float32 Gram matrix overflows to inf after a few
        # thousand observations, which would make beta infinite.
        _, logdet_v = np.linalg.slogdet(np.asarray(self.user_v[u], dtype=np.float64))
        log_ab = 0.5*(logdet_v - self.d*np.log(self.alpha))
        width=self.sigma*np.sqrt(2*(log_ab - np.log(self.delta)))
        if self.state==1:
            source = self.user_ls
        else:
            # NOTE(review): true_theta is never assigned in this class; the caller must set it
            source = self.true_theta
        self.user_avg[u]=np.dot(source.T, self.L[u])

        c=np.sqrt(self.alpha)*np.linalg.norm(self.user_avg[u])
        self.beta=c+width

        return self
    
class Cluster:
    """Pooled ridge-regression statistics of a group of users (CLUB / SCLUB).

    Parameters
    ----------
    users : iterable
        users in the cluster (stored as a list)
    V : ndarray
        square regularized Gram matrix (the matrix M in CLUB)
    b : ndarray
        sum of reward-weighted contexts
    N : number
        number of observations pooled in the cluster
    checks : dict, optional
        ``user -> bool`` flags, used only by SCLUB. A fresh empty dict is created
        when omitted, so clusters never share one dict.
    """
    def __init__(self, users, V, b, N, checks= None):
        self.users = list(users) # a list/array of users
        self.V = V # this is the matrix M in CLUB
        self.b = b
        self.N = N
        self.invV = np.linalg.inv(self.V)
        self.theta = np.matmul(self.invV, self.b) # this is the vector w in CLUB

        # only used by SCLUB
        self.checks = {} if checks is None else checks
        # FIXME: compares counts only; does not verify the checked users are members
        self.checked = len(self.users) == sum(self.checks.values())

    def update_check(self, u):
        """Mark user ``u`` as checked and refresh ``self.checked``."""
        self.checks[u] = True
        self.checked = len(self.users) == sum(self.checks.values())

class BaseCLUB(LinUcbItl):
    """Shared machinery of the clustering bandits CLUB and SCLUB.

    Per-user statistics (``invV``, ``b``, ``theta``, ``counts``) come from
    ``LinUcbItl``. Subclasses must create ``self.clusters`` (a dict mapping a
    cluster index to a ``Cluster``) in their ``initialize``.
    """
    # TODO add sigma parameter

    def __init__(self, G, a=1.0, a2=0.01, *args, **kwargs):
        super(BaseCLUB, self).__init__(*args, **kwargs)
        self.G = G
        self.a = a  # stored; not read in this class
        # NOTE(review): a2 is accepted but unused -- kept for signature compatibility
        self.n_users = G.number_of_nodes()

    def initialize(self, bandit):
        """Reset per-user statistics and cluster bookkeeping; returns ``self``."""
        super(BaseCLUB, self).initialize(bandit)
        self.cluster_inds = np.zeros(self.n_users, dtype= 'int64')

        # per-user Gram matrices, each starting at l * I
        self.V = self.l * np.tile(np.eye(self.d, dtype= self.dtype), reps= (self.n_users,1,1))

        self.N = np.zeros(self.n_users)

        return self

    def play_arm(self, cxts, u, t):
        """Return the row of ``cxts`` maximizing the UCB index of ``u``'s cluster."""
        cluster = self.clusters[self.cluster_inds[u]]
        ucb_inds = cxts @ cluster.theta
        if self.use_ucb:
            beta = self.compute_coef(cluster.N, t)
            sq_norms = np.sum((cxts @ cluster.invV) * cxts, axis = 1)
            # width is ||x||_{M^{-1}} = sqrt(x^T M^{-1} x); clip round-off negatives
            ucb_inds += beta * np.sqrt(np.maximum(sq_norms, 0.0))
        return np.argmax(ucb_inds)

    def get_theta(self):
        """Per-user estimates, shape (n_users, d)."""
        return self.theta.reshape((self.n_users, self.d))


    def update(self, x, y, u, t):
        """Update user ``u``'s statistics and those of its current cluster."""
        super(BaseCLUB, self).update(x, y, u ,t)
        self.V[u] += np.outer(x, x)

        self.N[u] += 1

        c = self.cluster_inds[u]
        self.clusters[c].V += np.outer(x, x)
        self.clusters[c].b += y * x
        self.clusters[c].N += 1

        self.clusters[c].invV -= utils.rank1_update(self.clusters[c].invV, x)
        self.clusters[c].theta = self.clusters[c].invV @ self.clusters[c].b

        return self

    @staticmethod
    def factor(T):
        """Confidence term sqrt((1 + log(1 + T)) / (1 + T)) used by split/merge tests."""
        return sqrt((1 + log(1 + T)) / (1 + T))
    

class CLUB(BaseCLUB):
    """CLUB: clusters are the connected components of ``self.G``.

    Edges between users whose estimates differ too much are removed after each
    update, which may split the current cluster.
    NOTE(review): ``self.G`` is the graph object passed to the constructor and is
    mutated in place (edges removed); pass a copy if it is reused elsewhere.
    """

    def initialize(self, bandit):
        super(CLUB, self).initialize(bandit)
        self.clusters = {0:Cluster(users=range(self.n_users), V= self.l*np.eye(self.d),
                             b=np.zeros(self.d), N=0)}
        return self

    def update(self, x, y, u, t):
        super(CLUB, self).update(x, y, u, t)

        self.update_clusters_func(u)

        return self

    def _cluster_from_users(self, users):
        """Build a Cluster by pooling the per-user statistics of ``users``."""
        users_arr = np.array(list(users))
        reg = self.l * np.eye(self.d)
        # every self.V[i] already contains the prior l*I; keep exactly one copy of it
        V = np.sum(self.V[users_arr] - reg, axis= 0) + reg
        return Cluster(list(users), V= V,
                       b= np.sum(self.b[users_arr], axis= 0),
                       N= np.sum(self.N[users_arr]))

    def update_clusters_func(self, u):
        """Prune edges around ``u`` and re-split its cluster if it became disconnected."""
        c = self.cluster_inds[u]
        edges_removed = False

        # materialize the neighbors: edges are removed while looping
        for v in list(self.G.neighbors(u)):
            if self.N[u] and self.N[v] and self._if_split(self.theta[u] - self.theta[v], self.N[u], self.N[v]):
                self.G.remove_edge(u, v)
                edges_removed = True

        if edges_removed:
            C = nx.node_connected_component(self.G, u)
            if len(C) < len(self.clusters[c].users):
                remain_users = set(self.clusters[c].users)
                # u's component keeps index c; the rest get fresh indices
                self.clusters[c] = self._cluster_from_users(C)

                remain_users = remain_users - set(C)
                c = max(self.clusters) + 1
                while len(remain_users) > 0:
                    # NOTE: draws from numpy's global RNG
                    v = np.random.choice(list(remain_users))
                    C = nx.node_connected_component(self.G, v)
                    self.clusters[c] = self._cluster_from_users(C)
                    self.cluster_inds[np.array(list(C))] = c

                    c += 1
                    remain_users = remain_users - set(C)
        return self

    def _if_split(self, theta, N1, N2, alpha= 1):
        """True when ``||theta||`` exceeds ``alpha * (factor(N1) + factor(N2))``."""
        thd = alpha * (CLUB.factor(N1) + CLUB.factor(N2))
        return np.linalg.norm(theta) > thd
    
class SCLUB(BaseCLUB):
    """SCLUB: clusters are split per user and merged pairwise.

    Splits and merges are driven by the distance between estimates and between
    average play frequencies.
    """

    def __init__(self, horizon, *args, **kwargs):
        super(SCLUB, self).__init__(*args, **kwargs)
        self.horizon = horizon

    def initialize(self, bandit):
        super(SCLUB, self).initialize(bandit)
        self.clusters = {0:Cluster(users= list(range(self.n_users)), V= self.l*np.eye(self.d),
                             b=np.zeros(self.d), N=0, checks = {i: False for i in range(self.n_users)})}
        return self

    def play_arm(self, cxts, u, t):
        # NOTE(review): this resets every cluster's checks on every round,
        # not once per stage -- confirm this is intended
        self._init_each_stage()
        return super(SCLUB,self).play_arm(cxts, u, t + self.horizon - 1)

    def update(self, x, y, u, t):
        super(SCLUB, self).update(x, y, u, t)
        tau = t + self.horizon -1
        self.split(u, tau)
        self.merge(tau)
        return self


    def _init_each_stage(self):
        for c in self.clusters:
            self.clusters[c].checks = {i:False for i in self.clusters[c].users}
            self.clusters[c].checked = False

    @staticmethod
    def _refresh(cluster):
        """Recompute ``invV`` and ``theta`` of ``cluster`` from its ``V`` and ``b``."""
        cluster.invV = np.linalg.inv(cluster.V)
        cluster.theta = cluster.invV @ cluster.b

    def _split_or_merge(self, theta, N1, N2, split=True):
        if split:
            return np.linalg.norm(theta) >  (self.factor(N1) + self.factor(N2))
        else:
            return np.linalg.norm(theta) <  (self.factor(N1) + self.factor(N2)) / 2

    def _cluster_avg_freq(self, c, t):
        return self.clusters[c].N / (len(self.clusters[c].users) * t)

    def _split_or_merge_p(self, p1, p2, t, split=True):
        if split:
            return np.abs(p1-p2) > sqrt(2) * self.factor(t)
        else:
            return np.abs(p1-p2) < sqrt(2) * self.factor(t) / 2

    def _find_available_index(self):
        """Smallest non-negative integer not yet used as a cluster index."""
        cmax = max(self.clusters)
        for c1 in range(cmax + 1):
            if c1 not in self.clusters:
                return c1
        return cmax + 1

    def split(self, i, t):
        """Move user ``i`` to its own cluster if it no longer fits its cluster."""
        c = self.cluster_inds[i]
        cluster = self.clusters[c]
        cluster.update_check(i)
        # a singleton cannot be split: it would leave an empty cluster behind
        if len(cluster.users) <= 1:
            return

        condition1 = self._split_or_merge_p(self.N[i]/(t+1), self._cluster_avg_freq(c, t+1), t+1, split=True)
        condition2 = self._split_or_merge(self.theta[i] - cluster.theta, self.N[i], cluster.N, split=True)
        if condition1 or condition2:
            cnew = self._find_available_index()
            # copies: self.V[i] / self.b[i] are views into the per-user arrays, and the
            # cluster statistics are updated in place by BaseCLUB.update
            self.clusters[cnew] = Cluster(users=[i],V=self.V[i].copy(),b=self.b[i].copy(),N=self.N[i],checks={i:True})
            self.cluster_inds[i] = cnew

            cluster.users.remove(i)
            # each user's V carries the prior l*I; the cluster must keep one copy
            cluster.V = cluster.V - self.V[i] + self.l * np.eye(self.d)
            cluster.b = cluster.b - self.b[i]
            cluster.N = cluster.N - self.N[i]
            del cluster.checks[i]
            self._refresh(cluster)

    def merge(self, t):
        """Merge every pair of checked clusters whose estimates and frequencies are close."""
        cmax = max(self.clusters)
        reg = self.l * np.eye(self.d)

        for c1 in range(cmax + 1):
            if c1 not in self.clusters or self.clusters[c1].checked == False:
                continue

            for c2 in range(c1 + 1, cmax + 1):
                if c2 not in self.clusters or self.clusters[c2].checked == False:
                    continue

                first, second = self.clusters[c1], self.clusters[c2]
                if self._split_or_merge(first.theta - second.theta, first.N, second.N, split=False) and self._split_or_merge_p(self._cluster_avg_freq(c1, t+1), self._cluster_avg_freq(c2, t+1), t+1, split=False):

                    for i in second.users:
                        self.cluster_inds[i] = c1

                    first.users = first.users + second.users
                    # both V's contain the prior l*I; keep a single copy
                    first.V = first.V + second.V - reg
                    first.b = first.b + second.b
                    first.N = first.N + second.N
                    first.checks = {**first.checks, **second.checks}
                    self._refresh(first)

                    del self.clusters[c2]


class NetworkLassoAgent():
    """ Class of the network lasso policy

    At every update the per-user parameter matrix ``theta`` is re-estimated by
    ``solver`` with a total-variation penalty on the graph. The penalty strength
    is recomputed by ``_update_alpha``, and ``play_arm`` acts greedily on ``theta``.
    """

    _K_NORMS = ("F", "F_bound")

    def __init__(self, adjacency, solver= None, first_term= None,
              l_net=1.0, beta=2.0,
              weighting= None, dtype='float32', K_norm='F'):
        """ initialization class

        Parameters
        ----------
        adjacency : array_like
            graph adjacency matrix
        solver : object, optional
            solver used to solve the optimization problem at each time step; a
            fresh ``nl.SolverPrimalDual()`` per agent when omitted
        first_term : object, optional
            data-fit term passed to the solver; a fresh ``nl.LstsqPrimalDual()``
            per agent when omitted
        l_net : float, optional
            multiplier of the total variation penalty, by default 1.0 # TODO set to theory inspired default value
        beta : float, optional
            the confidence term uses log(1/delta) = beta * log(t), by default 2.0
        weighting : {None, "degree", "centrality"}, optional
            edge re-weighting applied to ``adjacency``, by default None
        dtype : str, optional
            dtype of the internal arrays, by default 'float32'
        K_norm : {"F", "F_bound"}, optional
            which bound is used to set the penalty in ``_update_alpha``, by default 'F'

        Raises
        ------
        ValueError
            if ``weighting`` or ``K_norm`` is not one of the supported values
        """
        if weighting == "degree":
            weights = 1/np.sqrt(np.sum(adjacency, axis=0))
            self.adjacency = weights * adjacency * weights[:,None]
        elif weighting == "centrality":
            laplacian = np.diag(adjacency.sum(0)) - adjacency
            weights = np.sqrt(np.diag(np.linalg.pinv(laplacian)))
            weights /= max(weights)
            self.adjacency = weights * adjacency * weights[:,None]
        elif weighting is None:
            self.adjacency = adjacency.astype(dtype)
        else:
            raise ValueError(f"unknown weighting {weighting!r}; expected None, 'degree' or 'centrality'")

        if K_norm not in self._K_NORMS:
            raise ValueError(f"unknown K_norm {K_norm!r}; expected one of {self._K_NORMS}")

        self.l_net = l_net

        # Build the defaults per instance: a default-argument instance would be
        # created once and shared (with its state) by every agent.
        self.solver = nl.SolverPrimalDual() if solver is None else solver
        self.first_term = nl.LstsqPrimalDual() if first_term is None else first_term

        self.dtype = dtype
        self.beta = beta
        self.K_norm = K_norm
        self.counts = np.zeros(self.adjacency.shape[0], dtype= 'int32')

    def play_arm(self, cxts, u, t):
        """Greedy choice: row of ``cxts`` maximizing ``x^T theta_u``."""
        return np.argmax(cxts @ self.theta[u])

    def get_theta(self):
        return self.theta

    def initialize(self, bandit):
        """ get the needed information from the bandit

        Parameters
        ----------
        bandit : instance of the MultiTaskBandit class
            the bandit with which the agent will interact

        Returns
        -------
        NetworkLassoAgent
            ``self``
        """
        self.d = bandit.dim
        self.std = bandit.std

        self.theta = np.zeros(
            (self.adjacency.shape[0], self.d), dtype=self.dtype)

        self.solver.get_external_params(self.adjacency, self.theta)
        self.first_term.get_external_params(self.solver)
        # per-user Gram matrices
        self.A = np.zeros(
            (self.adjacency.shape[0], self.d, self.d), dtype=self.dtype)
        # squared Frobenius norm of X
        self.X_norm_F_sqr = np.zeros(self.adjacency.shape[0])
        self.B_pinv = utils.fast_pinv(self.solver.B.toarray())
        self.normB_pinv = np.max(np.abs(self.B_pinv))

        self.u_estim = np.zeros(
            (self.solver.B.shape[0], self.d), dtype=self.dtype)
        return self

    def update(self, x, y, u, t):
        """Record observation ``(x, y)`` of user ``u`` and re-solve the network lasso."""
        self.counts[u] += 1
        # TODO solver update parameters
        self.first_term.update(u, x, y)
        self.A[u] += np.outer(x, x)
        self.X_norm_F_sqr[u] += np.dot(x, x)

        l_net_t = self._update_alpha(t)

        # warm start from the previous primal/dual estimates
        self.theta, self.u_estim = self.solver.solve(
            first_term= self.first_term, l_net=l_net_t,
            theta0= self.theta, U0= self.u_estim)
        return self

    def _update_alpha(self, t):
        """Penalty strength at round ``t`` (``t >= 1``): ``l_net * std * sqrt(bound)``.

        The bound has the Hsu-Kakade-Zhang form
        tr + 2 * ||.||_F * sqrt(log(1/delta)) + 2 * ||.||_op * log(1/delta),
        with log(1/delta) = beta * log(t).
        """
        loginvdelta = self.beta*log(t)
        if self.K_norm == "F":
            A_norm_2 = np.linalg.norm(self.A, ord=2, axis=(1, 2))
            # sum of square F norms
            A_norm_F_sqr = np.einsum("uij,uij -> u", self.A, self.A)
            l_net_t = utils.sum_Hsu(self.X_norm_F_sqr.sum(), A_norm_F_sqr.sum(),
                                    np.max(A_norm_2), loginvdelta)
        elif self.K_norm == "F_bound":
            # parenthesized so the operator-norm term is part of the sum
            l_net_t = (t + 2*np.linalg.norm(self.counts)*sqrt(loginvdelta)
                       + 2*max(self.counts)*loginvdelta)
        else:
            raise ValueError(f"unknown K_norm {self.K_norm!r}; expected one of {self._K_NORMS}")

        # Prototyped variants, not implemented:
        # bound on norm ||K||_{2,1}^2$
        # if self.K_norm == "21":
        #     loginvdelta = self.beta*log(t) + log(self.adjacency.shape[0])
        #     l_net_t = utils.sum_Hsu(self.X_norm_F_sqr, A_norm_F_sqr,
        #                             A_norm_2, loginvdelta).max()
        #
        # if self.K_norm == "21_bound":
        #     loginvdelta = self.beta*log(t) + log(self.adjacency.shape[0])
        #     l_net_t = utils.sum_Hsu(self.X_norm_F_sqr, A_norm_F_sqr,
        #                             A_norm_2, loginvdelta).max()
        #
        # if self.K_norm == "2inf":
        #     loginvdelta = self.beta*log(t) + log(self.adjacency.shape[0])
        #     l_net_t = utils.sum_Hsu(self.X_norm_F_sqr, A_norm_F_sqr,
        #                             A_norm_2, loginvdelta).sum()
        #
        # if self.K_norm == "graph" (and "graph_bound"):
        #     A_all = np.sum(self.A, axis=0)
        #     A_all_norm_2 = np.linalg.norm(A_all, ord=2)
        #     A_all_norm_F_sqr = np.einsum("ij,ij", A_all, A_all)
        #     X_all_norm_F_sqr = np.sum(self.X_norm_F_sqr)
        #
        #     loginvdelta = self.beta*log(t)
        #     l_net_1 = utils.sum_Hsu(X_all_norm_F_sqr, A_all_norm_F_sqr,
        #                             A_all_norm_2, loginvdelta)
        #     l_net_1 = sqrt(l_net_1/self.adjacency.shape[0])
        #
        #     loginvdelta = self.beta*log(t) + log(self.B_pinv.shape[0])
        #     l_net_2 = utils.sum_Hsu(X_all_norm_F_sqr, A_all_norm_F_sqr,
        #                             A_all_norm_2, loginvdelta)
        #     l_net_2 = self.normB_pinv * sqrt(l_net_2)
        #
        #     l_net_t = max(l_net_1, l_net_2)

        return self.l_net * self.std * sqrt(l_net_t)
