# %%
import utils
from sklearn.linear_model import Lasso as LassoSklearn
from math import log, sqrt
import numpy as np
import network_lasso as nl
import networkx as nx
from scipy.linalg import sqrtm
from scipy.special import log1p
from sklearn.linear_model import Lasso as LassoSklearn
from tqdm import tqdm
import cvxpy as cp
# from pyproximal.optimization.primal import ProximalGradient


class LinUcbMulti():
    """
        Bank of LinUCB learners, one per task. Both the ITL agent and the
        cluster oracle inherit from this class.
    """
    def __init__(self, l_reg= 1.0, use_ucb = True, dtype= 'float32'):
        # The number of tasks is not known here: subclasses set ``n_tasks``
        # before ``initialize`` allocates the per-task statistics.
        self.n_tasks = None
        self.dtype = dtype
        self.use_ucb = use_ucb
        self.l = l_reg

    def initialize(self, bandit):
        """Read ``std`` and ``dim`` from ``bandit`` and reset the per-task statistics.

        Returns ``self``.
        """
        self.std = bandit.std
        self.d = bandit.dim
        if self.n_tasks is None:
            # nothing can be allocated until the number of tasks is fixed
            return self

        dim, n_tasks = bandit.dim, self.n_tasks
        self.invV = 1/self.l * np.tile(np.eye(dim), (n_tasks, 1, 1))
        self.b = np.zeros((n_tasks, dim))
        self.theta = np.zeros((n_tasks, dim))
        self.counts = np.zeros(n_tasks, dtype="int64")
        return self

    def play_arm(self, cxts, u, t):
        """Return the index of the row of ``cxts`` with the largest index for task ``u``."""
        # TODO make it data dependent
        # TODO take sigma into account
        scores = cxts @ self.theta[u]
        if not self.use_ucb:
            return np.argmax(scores)

        coef = self.compute_coef(self.counts[u], t)
        # quadratic form x^T invV x, one value per context row
        quad_forms = np.sum((cxts @ self.invV[u])*cxts, axis=1)
        return np.argmax(scores + self.std * coef * quad_forms)

    def update(self, x, y, u, t):
        """Fold the observation ``(x, y)`` into the statistics of task ``u``."""
        inv_u = self.invV[u]
        inv_u -= utils.rank1_update(inv_u, x)
        self.b[u] += y * x
        self.theta[u] = inv_u @ self.b[u]
        self.counts[u] += 1
        return self

    def compute_coef(self, count, t):
        """Exploration coefficient for a task observed ``count`` times at round ``t``."""
        time_term = 4*log(t)
        dim_term = self.d*log(1+count/(self.d*self.l))
        return sqrt(time_term + dim_term) + sqrt(self.l)


class LinUcbItl(LinUcbMulti):
    """
        Independent Task Learning with LinUCB: every user owns a separate
        model and nothing is shared between users.
    """
    def __init__(self, l_reg = 1.0, use_ucb= True):
        super().__init__(l_reg, use_ucb)

    def initialize(self, bandit):
        # One task per user; this has to be set before the parent class
        # allocates its per-task arrays.
        self.n_tasks = bandit.n_users
        super().initialize(bandit)
        return self


class LinUcbClusterOracle(LinUcbMulti):
    """LinUCB in which all users carrying the same given cluster label share one model."""

    def __init__(self, cluster_inds, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cluster_inds = cluster_inds
        # labels are used directly as model indices, hence max label + 1 models
        self.n_tasks = max(cluster_inds)+1

    def play_arm(self, cxts, u, t):
        # act with the model of the cluster user ``u`` is mapped to
        label = self.cluster_inds[u]
        return super().play_arm(cxts, label, t)

    def update(self, x, y, u, t):
        # the observation is credited to the cluster model, not to the user
        label = self.cluster_inds[u]
        return super().update(x, y, label, t)
  

class GOBLin():
    """
        GOBLin agent: a single ridge/UCB learner over the stacked user
        vectors, fed with contexts rescaled through (Laplacian + I)^(-1/2).
    """
    def __init__(self, adjacency, l_reg= 1.0, delta= 0.01, use_ucb = True, dtype= 'float32'):
        degrees = adjacency.sum(0)
        self.laplacian = np.diag(degrees)-adjacency
        self.l = l_reg
        self.delta = delta
        self.use_ucb = use_ucb
        self.dtype = dtype

    def initialize(self, bandit):
        """Reset all statistics for the problem described by ``bandit``; returns ``self``."""
        n_users = bandit.n_users
        full_dim = bandit.dim * n_users
        self.std = bandit.std
        self.invV = 1/self.l*np.eye(full_dim, dtype= self.dtype)
        # NOTE(review): for V = l*I of size dim*n_users the log-determinant is
        # dim*n_users*log(l); the value below matches it only when l_reg == 1
        # -- confirm the intent before changing.
        self.logdetV = log(self.l)
        self.b = np.zeros(full_dim, dtype= self.dtype)
        self.theta = self.invV @ self.b
        shifted_laplacian = self.laplacian + np.eye(n_users, dtype= self.dtype)
        self.inv_sqrt_lap_id = np.linalg.inv(sqrtm(shifted_laplacian))
        return self

    def play_arm(self, cxts, u, t):
        """Return the index of the best row of ``cxts`` for user ``u`` at round ``t``."""
        # row ``u`` of the graph operator spreads each context over all user blocks
        cxts_aug = np.kron(self.inv_sqrt_lap_id[u], cxts)
        indices = cxts_aug @ self.theta
        if self.use_ucb:
            quad_forms = np.sum((cxts_aug @ self.invV)*cxts_aug, axis=1)
            scale = 1 + self.std * sqrt(self.logdetV + 4*log(t))
            indices += scale*quad_forms
        return np.argmax(indices)

    def update(self, x, y, u, t):
        """Rank-one update of the inverse design matrix, its log-determinant and theta."""
        x_aug = np.kron(self.inv_sqrt_lap_id[u], x)
        self.b += y * x_aug
        invV_x = self.invV @ x_aug
        quad = x_aug @ invV_x
        # det(V + x x^T) = det(V) * (1 + x^T invV x)
        self.logdetV += log1p(quad)
        # Sherman-Morrison step, written in place to keep the array's dtype
        self.invV -= np.outer(invV_x, invV_x)/(1 + quad)
        self.theta = self.invV @ self.b
        return self
        

class GraphLinUCB():
    """LinUCB-type agent whose user vectors are coupled through a graph Laplacian.

    The joint estimate ``theta`` solves a ridge system whose regulariser is
    ``alpha * kron(L, I_d)`` with ``L = laplacian + 0.01 * I``; the confidence
    width ``beta`` is recomputed for the queried user before every decision.
    """
    # TODO store estimated Theta (with Laplacian) and use its rows for beta

    def __init__(self, laplacian, alpha=1.0, delta=0.01, state= 1, use_ucb = True, symmetrize= False, dtype= 'float32'):
        """Store the hyper-parameters and the (optionally symmetrised) Laplacian.

        Parameters
        ----------
        laplacian : ndarray of shape (n_users, n_users)
            graph Laplacian; it is cast to ``dtype``
        alpha : float, optional
            regularization parameter, by default 1.0
        delta : float, optional
            high probability parameter, by default 0.01
        state : int, optional
            1 (default) builds the bias term of ``beta`` from the per-user
            least-squares estimates; any other value uses ``self.true_theta``,
            which must then be assigned by the caller
        use_ucb : bool, optional
            whether to add the exploration bonus, by default True
        symmetrize : bool, optional
            whether to replace the Laplacian by ``(L + L.T) / 2``, by default False
        dtype : str, optional
            floating point type of the stored arrays, by default 'float32'
        """
        self.dtype= dtype
        self.delta = delta
        self.alpha = alpha
        self.lap = laplacian.astype(self.dtype)
        if symmetrize:
            # symmetrise the cast copy so that the requested dtype is kept
            self.lap = 0.5*(self.lap + self.lap.T)
        self.state = state
        self.user_num = len(self.lap)
        # small diagonal shift so that L (and kron(L, I)) is invertible
        self.L = self.lap + 0.01*np.eye(self.user_num, dtype = self.dtype)
        self.use_ucb = use_ucb

    def initialize(self, bandit):
        """Allocate all statistics for the problem described by ``bandit``.

        Reads ``bandit.dim`` and ``bandit.std``. Returns ``self`` like the
        other agents of this module, so the call can be chained.
        """
        self.d = bandit.dim
        n, d = self.user_num, self.d
        self.theta = np.zeros((n, d), dtype= self.dtype)

        self.A = np.kron(self.L, np.eye(d))
        self.A_inv = np.linalg.inv(self.A)  # XXX pinv --> inv
        self.XX = np.zeros((n*d, n*d), dtype= self.dtype)
        self.sigma = bandit.std
        self.cov = self.alpha*self.A
        self.bias = np.zeros((n*d), dtype= self.dtype)
        self.beta_list = []
        self.user_v = self.alpha * \
            np.tile(np.eye(d, dtype= self.dtype), (n, 1, 1))
        self.user_avg = np.zeros((n, d), dtype= self.dtype)
        self.user_ls = np.zeros((n, d), dtype= self.dtype)
        self.user_ridge = np.zeros((n, d), dtype= self.dtype)
        self.user_xx = 0.1 * \
            np.tile(np.eye(d, dtype= self.dtype), reps=(n, 1, 1))
        self.user_bias = np.zeros((n, d), dtype= self.dtype)
        self.user_counter = np.zeros(n, dtype= self.dtype)
        self.user_h= np.zeros((n, d, d), dtype= self.dtype)
        # row j lists every user index except j
        self.users_but_u = np.array([[i for i in range(n) if i != j]
                                     for j in range(n)], dtype="int32")
        return self

    def play_arm(self, cxts, u, t):
        """Return the index of the best row of ``cxts`` for user ``u``."""
        self.update_beta(u)
        # x^T H_u^{-1} x for every context row x
        X_norm_sqr = np.einsum("ij,ji->i", cxts,\
                         np.linalg.solve(self.user_h[u], cxts.T))
        index = np.dot(cxts, self.theta[u])
        if self.use_ucb:
            index = index + self.beta * np.sqrt(X_norm_sqr)
        return np.argmax(index)

    def get_theta(self):
        """Return the current joint estimate, one row per user."""
        return self.theta

    def update(self, x, y, u, t):
        """Fold the observation ``(x, y)`` of user ``u`` into all statistics."""
        xxT = np.outer(x, x)
        self.user_v[u] += xxT
        self.user_xx[u] += xxT
        self.user_bias[u] += y*x
        self.user_ls[u] = np.linalg.solve(self.user_xx[u], self.user_bias[u]) # XXX inv --> solve

        # The stacked feature vector is zero outside user u's block, so only
        # that d x d block (resp. d-slice) of the joint statistics changes;
        # updating it directly avoids building (n*d) x (n*d) outer products.
        blk = slice(u*self.d, (u+1)*self.d)
        x_blk = np.asarray(x, dtype=np.float64)
        blk_outer = np.outer(x_blk, x_blk)
        self.XX[blk, blk] += blk_outer
        self.bias[blk] += y*x_blk
        self.cov[blk, blk] += blk_outer
        self.theta = np.linalg.solve(self.cov, self.bias).reshape((self.user_num, self.d)) #XXX inv --> solve

        self.user_avg = self.L @ self.user_ls # XXX vectorized
        return self

    def update_beta(self, u):
        """Recompute ``user_h[u]`` and the confidence width ``beta`` for user ``u``."""
        u_range = self.users_but_u[u]
        sum_A = np.sum(np.einsum("k,kij-> kij", self.L[u,u_range]**2,
                                       np.linalg.inv(self.user_xx[u_range])), axis= 0)
        self.user_h[u]=self.user_xx[u]+self.alpha**2*sum_A\
                        +2*self.alpha*self.L[u, u]*np.eye(self.d)

        # log( sqrt(det V_u) / sqrt(det(alpha * I)) ), evaluated in the log
        # domain: the determinants themselves overflow quickly in single
        # precision as observations accumulate.
        _, logdet_v = np.linalg.slogdet(self.user_v[u])
        log_det_ratio = 0.5*logdet_v - 0.5*self.d*np.log(self.alpha)
        noise_term = self.sigma*np.sqrt(2*(log_det_ratio - np.log(self.delta)))

        if self.state == 1:
            reference = self.user_ls
        else:
            reference = getattr(self, "true_theta", None)
            if reference is None:
                raise AttributeError(
                    "GraphLinUCB with state != 1 requires the attribute "
                    "'true_theta' to be set before playing")
        self.user_avg[u] = np.dot(reference.T, self.L[u])

        bias_term = np.sqrt(self.alpha)*np.linalg.norm(self.user_avg[u])
        self.beta = bias_term + noise_term

        return self
    
class Cluster:
    """Sufficient statistics of one cluster of users.

    Parameters
    ----------
    users : iterable
        members of the cluster; stored as a list
    V : ndarray of shape (d, d)
        design matrix of the cluster (the matrix M in CLUB); must be invertible
    b : ndarray of shape (d,)
        reward-weighted sum of contexts
    N : number
        observation counter of the cluster
    checks : dict, optional
        maps a user to a flag telling whether it has been checked (only used
        by SCLUB). A fresh empty dict is created when omitted, so clusters
        never share this state.
    """
    def __init__(self, users, V, b, N, checks=None):
        self.users = list(users) # a list/array of users
        self.V = V # this is the matrix M in CLUB
        self.b = b
        self.N = N
        self.invV = np.linalg.inv(self.V)
        self.theta = np.matmul(self.invV, self.b) # this is the vector w in CLUB

        # only used by SCLUB
        # (a literal ``{}`` default would be one dict shared by every cluster
        # built without ``checks`` and would be mutated by ``update_check``)
        self.checks = {} if checks is None else checks
        # FIXME
        self.checked = len(self.users) == sum(self.checks.values())

    def update_check(self, u):
        """Flag user ``u`` as checked and refresh the cluster-level ``checked`` flag."""
        self.checks[u] = True
        # the cluster counts as checked once it has as many set flags as members
        self.checked = len(self.users) == sum(self.checks.values())

class BaseCLUB(LinUcbItl):
    """Shared machinery of the CLUB-type agents.

    On top of the per-user LinUCB statistics inherited from ``LinUcbItl`` the
    class tracks per-user design matrices ``V`` and counters ``N``, and keeps
    the statistics of the cluster a user currently belongs to up to date.
    Subclasses must create ``self.clusters`` (a dict of cluster objects) in
    their ``initialize``.
    """
    # TODO add sigma parameter

    def __init__(self, G, a=1.0, a2=0.01, *args, **kwargs):
        """Store the user graph ``G``; remaining arguments go to ``LinUcbItl``.

        ``a`` is stored as ``self.a``; ``a2`` is accepted but not used by this
        class.
        """
        super(BaseCLUB, self).__init__(*args, **kwargs)
        self.G = G
        self.a = a
        self.n_u = G.number_of_nodes()

    def initialize(self, bandit):
        """Reset the per-user statistics and put every user in cluster 0."""
        super(BaseCLUB, self).initialize(bandit)
        self.cluster_inds = np.zeros(self.n_u, dtype= 'int64')

        # per-user design matrices, each starting from the ridge term l * I
        self.V = self.l * np.tile(np.eye(self.d, dtype= self.dtype), reps= (self.n_u,1,1))

        self.N = np.zeros(self.n_u)

        return self

    def play_arm(self, cxts, u, t):
        """Return the index of the best row of ``cxts`` under the model of ``u``'s cluster."""
        cluster = self.clusters[self.cluster_inds[u]]
        ucb_inds = cxts @ cluster.theta
        if self.use_ucb:
            beta = self.compute_coef(cluster.N, t)
            ucb_inds += beta * np.sum((cxts @ cluster.invV) * cxts, axis = 1)
        return np.argmax(ucb_inds)

    def get_theta(self):
        """Return the per-user estimates, one row per user.

        The previous implementation reshaped ``self.theta_vec``, an attribute
        that is never assigned anywhere in this class hierarchy, so the call
        always raised ``AttributeError``. The per-user estimates live in
        ``self.theta`` (allocated by the parent ``initialize``).
        """
        return self.theta

    def update(self, x, y, u, t):
        """Fold ``(x, y)`` into the statistics of user ``u`` and of its cluster."""
        # per-user inverse, b, theta and counts are handled by the parent
        super(BaseCLUB, self).update(x, y, u ,t)
        self.V[u] += np.outer(x, x)
        self.N[u] += 1

        cluster = self.clusters[self.cluster_inds[u]]
        cluster.V += np.outer(x, x)
        cluster.invV = np.linalg.inv(cluster.V)
        cluster.b += y * x
        cluster.N += 1
        cluster.theta = cluster.invV @ cluster.b

        return self

    @staticmethod
    def factor(T):
        """Confidence width sqrt((1 + log(1 + T)) / (1 + T)) after ``T`` observations."""
        return sqrt((1 + log(1 + T)) / (1 + T))
    
    # def _if_split(self, theta, N1, N2, alpha= 1):
    #     rhs = alpha * (self.factor(N1) + self.factor(N2))
    #     return np.linalg.norm(theta) > rhs
    

class CLUB(BaseCLUB):
    """CLUB agent: clusters are the connected components of a user graph.

    Edges are deleted when the estimates of their two endpoints are too far
    apart. NOTE(review): edges are removed from ``self.G`` in place and
    ``initialize`` does not restore them -- confirm callers pass a fresh graph
    for every run.
    """

    def initialize(self, bandit):
        """Reset all statistics and start with one cluster holding every user."""
        super(CLUB, self).initialize(bandit)
        self.clusters = {0:Cluster(users=range(self.n_u), V= self.l*np.eye(self.d),
                             b=np.zeros(self.d), N=0)}
        return self

    def update(self, x, y, u, t):
        """Update the statistics with ``(x, y)`` for user ``u``, then re-cluster around ``u``."""
        super(CLUB, self).update(x, y, u, t)

        self.update_clusters_func(u)

        return self

    def _build_cluster(self, members):
        """Return a ``Cluster`` aggregating the per-user statistics of ``members``."""
        idx = np.array(list(members))
        # Every per-user matrix carries the ridge term l * I (see
        # ``BaseCLUB.initialize``): remove it from each member and add it back
        # once. The identity was previously hard-coded here, which is only
        # correct for l_reg == 1.
        ridge = self.l*np.eye(self.d)
        return Cluster(list(members),
                       V= np.sum(self.V[idx] - ridge, axis= 0) + ridge,
                       b= np.sum(self.b[idx], axis= 0),
                       N= np.sum(self.N[idx]))

    def update_clusters_func(self, u):
        """Drop the edges of ``u`` that fail the closeness test and split its cluster if needed."""
        c = self.cluster_inds[u]
        update_clusters = False

        # materialise the neighbours first: edges are removed while looping
        for v in list(self.G.neighbors(u)):
            if self.N[u] and self.N[v] and self._if_split(self.theta[u] - self.theta[v], self.N[u], self.N[v]):
                self.G.remove_edge(u, v)
                update_clusters = True

        if not update_clusters:
            return self

        C = nx.node_connected_component(self.G, u)
        if len(C) >= len(self.clusters[c].users):
            # the component of u still spans the whole cluster: nothing to split
            return self

        remain_users = set(self.clusters[c].users) - set(C)
        # the component containing u keeps the old cluster index ...
        self.clusters[c] = self._build_cluster(C)

        # ... and each remaining component receives a fresh index
        c = max(self.clusters) + 1
        while len(remain_users) > 0:
            v = np.random.choice(list(remain_users))
            C = nx.node_connected_component(self.G, v)
            self.clusters[c] = self._build_cluster(C)
            self.cluster_inds[np.array(list(C))] = c

            c += 1
            remain_users = remain_users - set(C)
        return self

    def _if_split(self, theta, N1, N2, alpha= 1):
        """Tell whether ``||theta||`` exceeds ``alpha`` times the summed confidence widths."""
        thd = alpha * (CLUB.factor(N1) + CLUB.factor(N2))
        return np.linalg.norm(theta) > thd
    
class SCLUB(BaseCLUB):
    """Set-based clustering of bandits: users are split from and clusters merged by explicit tests.

    The tests compare both the estimated vectors and the observation
    frequencies of users/clusters against confidence widths.
    """

    def __init__(self, horizon, *args, **kwargs):
        """Store ``horizon`` (added as a time offset); other arguments go to ``BaseCLUB``."""
        super(SCLUB, self).__init__(*args, **kwargs)
        self.horizon = horizon

    def initialize(self, bandit):
        """Reset all statistics and start with one unchecked cluster holding every user."""
        super(SCLUB, self).initialize(bandit)
        self.clusters = {0:Cluster(users= list(range(self.n_u)), V= self.l*np.eye(self.d),
                             b=np.zeros(self.d), N=0, checks = {i: False for i in range(self.n_u)})}
        return self

    def play_arm(self, cxts, u, t):
        """Reset the check flags, then act with the model of ``u``'s cluster."""
        # NOTE(review): the flags are cleared on every call, so a cluster can
        # only become ``checked`` within a single round -- confirm that this
        # is meant to run per round and not once per stage.
        self._init_each_stage()
        return super(SCLUB,self).play_arm(cxts, u, t + self.horizon - 1)

    def update(self, x, y, u, t):
        """Update the statistics with ``(x, y)``, then run the split and merge tests."""
        super(SCLUB, self).update(x, y, u, t)
        tau = t + self.horizon -1
        self.split(u, tau)
        self.merge(tau)
        return self

    def _init_each_stage(self):
        """Mark every user and every cluster as unchecked."""
        for cluster in self.clusters.values():
            cluster.checks = {i:False for i in cluster.users}
            cluster.checked = False

    def _refresh_cluster(self, cluster):
        """Recompute ``invV`` and ``theta`` of ``cluster`` from its current ``V`` and ``b``."""
        cluster.invV = np.linalg.inv(cluster.V)
        cluster.theta = cluster.invV @ cluster.b

    def _split_or_merge(self, theta, N1, N2, split=True):
        """Distance test on estimates: above the summed widths (split) or below half of it (merge)."""
        width = self.factor(N1) + self.factor(N2)
        if split:
            return np.linalg.norm(theta) > width
        return np.linalg.norm(theta) < width / 2

    def _cluster_avg_freq(self, c, t):
        """Average number of observations per member of cluster ``c`` and per round."""
        return self.clusters[c].N / (len(self.clusters[c].users) * t)

    def _split_or_merge_p(self, p1, p2, t, split=True):
        """Same test as ``_split_or_merge`` but on observation frequencies."""
        width = sqrt(2) * self.factor(t)
        if split:
            return np.abs(p1-p2) > width
        return np.abs(p1-p2) < width / 2

    def _find_available_index(self):
        """Return the smallest non-negative integer not used as a cluster key."""
        cmax = max(self.clusters)
        for c1 in range(cmax + 1):
            if c1 not in self.clusters:
                return c1
        return cmax + 1

    def split(self, i, t):
        """Mark user ``i`` as checked and move it to its own cluster if a split test fires."""
        c = self.cluster_inds[i]
        cluster = self.clusters[c]
        cluster.update_check(i)

        condition1 = self._split_or_merge_p(self.N[i]/(t+1), self._cluster_avg_freq(c, t+1), t+1, split=True)
        condition2 = self._split_or_merge(self.theta[i] - cluster.theta, self.N[i], cluster.N, split=True)
        if not (condition1 or condition2):
            return

        cnew = self._find_available_index()
        # Copy the per-user statistics: ``self.V[i]`` / ``self.b[i]`` are views,
        # and the cluster statistics are updated in place in ``BaseCLUB.update``.
        # Sharing the memory would add every later observation twice, to both
        # the user and its singleton cluster.
        self.clusters[cnew] = Cluster(users=[i], V=self.V[i].copy(), b=self.b[i].copy(),
                                      N=self.N[i], checks={i:True})
        self.cluster_inds[i] = cnew

        cluster.users.remove(i)
        # per-user matrices include the ridge term l * I, which must stay in
        # the cluster exactly once
        cluster.V = cluster.V - self.V[i] + self.l*np.eye(self.d)
        cluster.b = cluster.b - self.b[i]
        cluster.N = cluster.N - self.N[i]
        del cluster.checks[i]
        # keep the estimate consistent with the reduced statistics
        self._refresh_cluster(cluster)

    def merge(self, t):
        """Merge every pair of checked clusters that passes both closeness tests."""
        cmax = max(self.clusters)

        for c1 in range(cmax + 1):
            if c1 not in self.clusters or self.clusters[c1].checked == False:
                continue

            for c2 in range(c1 + 1, cmax + 1):
                if c2 not in self.clusters or self.clusters[c2].checked == False:
                    continue

                first, second = self.clusters[c1], self.clusters[c2]
                close_theta = self._split_or_merge(first.theta - second.theta, first.N, second.N, split=False)
                if not close_theta:
                    continue
                close_freq = self._split_or_merge_p(self._cluster_avg_freq(c1, t+1),
                                                    self._cluster_avg_freq(c2, t+1), t+1, split=False)
                if not close_freq:
                    continue

                for i in second.users:
                    self.cluster_inds[i] = c1

                first.users = first.users + second.users
                # both matrices carry the ridge term; keep a single copy
                first.V = first.V + second.V - self.l*np.eye(self.d)
                first.b = first.b + second.b
                first.N = first.N + second.N
                first.checks = {**first.checks, **second.checks}
                # later comparisons and decisions must see the merged model
                self._refresh_cluster(first)

                del self.clusters[c2]


class ClusterLOCB:
    """Statistics of one LOCB cluster: member set, design matrix, b vector and counter."""

    def __init__(self, users, V, b, N):
        # members are kept as a set (LOCB adds and removes users over time)
        self.users = set(users)
        self.V, self.b, self.N = V, b, N
        # cached inverse and the estimate derived from it
        inverse = np.linalg.inv(V)
        self.invV = inverse
        self.theta = inverse @ b

class LOCB:
    """LOCB agent: overlapping clusters grown around randomly drawn seed users.

    Each seed owns a cluster that initially contains every user. Users leave
    or re-join a cluster depending on how far their estimate is from the
    seed's estimate; a seed is frozen once its own confidence width is small
    enough.
    """

    def __init__(self, gamma= 0.2, num_seeds= 20, delta= 0.1, detect_cluster= False):
        self.num_seeds = num_seeds
        self.gamma = gamma

        self.selected_cluster = 0
        self.delta = delta
        self.if_d = detect_cluster
        self.fin = 0
        # frozen clusters, appended as {seed: members}
        self.results = []

    def initialize(self, bandit):
        """Reset all statistics and draw the seed users; returns ``self``."""
        self.d = bandit.dim
        self.n = bandit.n_users
        self.V = np.tile(np.eye(bandit.dim), (bandit.n_users,1,1))
        self.b = np.zeros((bandit.n_users, bandit.dim))
        self.invV = np.tile(np.eye(bandit.dim), (bandit.n_users,1,1))
        self.theta = np.zeros((bandit.n_users, bandit.dim))
        self.users = range(bandit.n_users)
        self.N = np.zeros(bandit.n_users)

        # clustering state is reset so that the agent can be re-initialised
        self.fin = 0
        self.results = []

        # Seeds must be distinct: a repeated seed would be listed twice in
        # ``cluster_inds`` and its cluster would receive every observation
        # twice. There cannot be more distinct seeds than users.
        n_seeds = min(self.num_seeds, self.n)
        self.seeds = np.random.choice(self.n, size=n_seeds, replace=False)
        self.seed_state = {seed: 0 for seed in self.seeds}
        self.clusters = {
            seed: ClusterLOCB(users=self.users, V=np.eye(self.d), b=np.zeros(self.d), N=1)
            for seed in self.seeds
        }
        # every cluster starts with all users, so every user is in every cluster
        self.cluster_inds = {i: list(self.seeds) for i in self.users}

        return self

    def _beta(self, N, t):
        """Exploration coefficient for a model with ``N`` observations at round ``t``."""
        return np.sqrt(self.d * np.log(1 + N / self.d) + 4 * np.log(t) + np.log(2)) + 1

    def _select_item_ucb(self, S, Sinv, theta, items, N, t):
        """Return ``(best index value, its position)`` among the rows of ``items``.

        ``S`` is accepted for signature compatibility and not used.
        """
        ucbs = np.dot(items, theta) + self._beta(N, t) * (np.matmul(items, Sinv) * items).sum(axis = 1)
        res = max(ucbs)
        it = np.argmax(ucbs)
        return (res, it)

    def play_arm(self, items, u, t):
        """Return the index of the item to play for user ``u`` at round ``t``."""
        cls = self.cluster_inds[u]
        # cluster models are only consulted while t < 40000 (hard-coded cut-off)
        if (len(cls)>0) and (t <40000):
            res = []
            for c in cls:
                cluster = self.clusters[c]
                res_sin = self._select_item_ucb(cluster.V,cluster.invV, cluster.theta, items, cluster.N, t)
                res.append(res_sin)
            # most optimistic cluster wins; tuples compare on the index value first
            best_cluster = max(res)
            return best_cluster[1]
        else:
            no_cluster = self._select_item_ucb(self.V[u], self.invV[u], self.theta[u], items, self.N[u], t)
            return no_cluster[1]

    def update(self, x,y,u,t):
        """Fold ``(x, y)`` into user ``u`` and its clusters, then revise memberships."""
        self.invV[u] -= utils.rank1_update(self.invV[u], x)
        self.V[u] += np.outer(x, x)
        self.b[u] += y * x
        self.theta[u] = self.invV[u] @ self.b[u]
        self.N[u] += 1

        for c in self.cluster_inds[u]:
            cluster = self.clusters[c]
            cluster.V += np.outer(x, x)
            cluster.b += y * x
            cluster.N += 1
            self._refresh_cluster(cluster)

        self.update_clusters(u,t)

        return self

    @staticmethod
    def _factT(m):
        """Confidence width sqrt((1 + log(1 + m)) / (1 + m)) after ``m`` observations."""
        return np.sqrt((1 + np.log(1 + m)) / (1 + m))

    @staticmethod
    def _refresh_cluster(cluster):
        """Recompute ``invV`` and ``theta`` of ``cluster`` from its current ``V`` and ``b``."""
        cluster.invV = np.linalg.inv(cluster.V)
        cluster.theta = cluster.invV @ cluster.b

    def update_clusters(self, i, t):
        """Revise the membership of user ``i`` in every active seed cluster.

        A seed is frozen (and its cluster recorded in ``self.results``) once
        its confidence width drops below the threshold; when all seeds are
        frozen ``self.fin`` is set and later calls do nothing.
        """
        if self.fin:
            return self

        thre = self.gamma if self.if_d else self.gamma/4
        for seed in self.seeds:
            if self.seed_state[seed]:
                continue

            cluster = self.clusters[seed]
            gap = np.linalg.norm(self.theta[i] - self.theta[seed])
            width = self._factT(self.N[i]) + self._factT(self.N[seed])
            if i in cluster.users:
                if gap > width:
                    cluster.users.remove(i)
                    self.cluster_inds[i].remove(seed)
                    # V[i] contains the identity ridge term; keep it in the cluster
                    cluster.V = cluster.V - self.V[i] + np.eye(self.d)
                    cluster.b = cluster.b - self.b[i]
                    cluster.N = cluster.N - self.N[i]
                    # keep the estimate consistent with the new statistics
                    self._refresh_cluster(cluster)
            elif gap < width:
                cluster.users.add(i)
                self.cluster_inds[i].append(seed)
                cluster.V = cluster.V + self.V[i] - np.eye(self.d)
                cluster.b = cluster.b + self.b[i]
                cluster.N = cluster.N + self.N[i]
                self._refresh_cluster(cluster)

            if self._factT(self.N[seed]) <= thre:
                self.seed_state[seed] = 1
                self.results.append({seed:list(cluster.users)})

        if all(self.seed_state.values()):
            if self.if_d:
                np.save('./results/clusters', self.results)
                print('Clustering finished! Round:', t)
                self.stop = 1
            self.fin = 1
        return self


    
class TraceNormAgent():
    """ Greedy agent with a nuclear-norm (trace-norm) regularised estimator.

    After every observation the users-by-features matrix ``theta`` is
    re-estimated by minimising
    ``sum_u (theta_u^T A_u theta_u - 2 b_u^T theta_u) + l_nuc_t * ||theta||_*``
    with an accelerated proximal-gradient method.
    """

    def __init__(self, l_nuc=1.0, delta= 0.01, dtype= 'float64'):
        self.l_nuc = l_nuc
        self.delta= delta
        self.dtype = dtype # CVXPY does not work with single precision

    def play_arm(self, cxts, u, t):
        """Return the index of the row of ``cxts`` with the largest estimated reward for ``u``."""
        return np.argmax(cxts @ self.theta[u])

    def get_theta(self):
        """Return the current estimate, one row per user."""
        return self.theta

    def initialize(self, bandit):
        """ Get the needed information from the bandit and reset the statistics.

        Parameters
        ----------
        bandit : object exposing ``dim``, ``std`` and ``n_users``
            the bandit with which the agent will interact

        Returns
        -------
        TraceNormAgent
            ``self``
        """
        self.d = bandit.dim
        self.std = bandit.std

        self.theta = np.zeros(
            (bandit.n_users, self.d), dtype=self.dtype)
        self. n_u = bandit.n_users

        self.A = np.zeros((self.n_u, self.d, self.d), dtype=self.dtype)
        self.b = np.zeros((self.n_u,self.d), dtype= self.dtype)

        return self

    def update(self, x, y, u, t):
        """Fold ``(x, y)`` into the statistics of user ``u`` and re-estimate ``theta``."""
        self.A[u] += np.outer(x,x)
        self.b[u] += y*x
        # equivalent number of rounds per user
        t_eq = t//self.n_u + 1
        self.l_nuc_t = t_eq*self.l_nuc*self.reg_func((self.n_u + self.d)/t_eq, log(2/self.delta)/t_eq)
        # Lipschitz constant of the gradient of the quadratic part
        lip = 2*np.max(np.linalg.norm(self.A, ord=2, axis= (1,2)))
        if lip == 0:
            # Every A_u is still zero (only zero contexts so far): there is
            # nothing to fit and 1/lip would be undefined, so keep theta.
            return self
        self.theta = self.solve_mult_reg_nuc(lip= lip)
        return self

    @staticmethod
    def reg_func(x1,x2):
        """Return ``max(x1 + x2, sqrt(x1) + sqrt(x2))``."""
        return max(x1+x2, sqrt(x1)+sqrt(x2))

    def set_pb_params(self, A,b, l_nuc_t):
        """Set the problem data directly (``b`` fixes ``n_u`` and ``d``); returns ``self``."""
        self.A = A
        self.b = b
        self.n_u, self.d = b.shape
        self.l_nuc_t = l_nuc_t
        return self

    def solve_mult_reg_nuc(self, lip, gamma= 1.01, max_iter= 500):
        """Accelerated proximal gradient with backtracking for the regularised problem.

        Parameters
        ----------
        lip : float
            initial (positive) estimate of the Lipschitz constant of the gradient
        gamma : float, optional
            factor by which ``lip`` is increased when a step is rejected
        max_iter : int, optional
            maximum number of outer iterations

        Returns
        -------
        ndarray of shape (n_u, d)
            the last iterate (the zero matrix when ``max_iter`` is 0)
        """
        alpha_prec = 1.0
        Theta_prec = np.zeros((self.n_u, self.d))
        Theta = Theta_prec
        Z = np.zeros_like(Theta_prec)

        for _ in range(max_iter):
            # Backtracking: enlarge ``lip`` only while the quadratic model at Z
            # fails to majorise the objective at the candidate point. An
            # accepted step leaves ``lip`` untouched, so the step size does
            # not shrink needlessly from one iteration to the next.
            while True:
                Z_grad_step = Z - 1/lip*self.l2sqr_grad(Z)
                Theta = self.prox_nuc(Z_grad_step, self.l_nuc_t/lip)
                step_gap = self.obj_val(Theta) - self.obj_quasi_lin_val(Theta, Z, lip)
                if not step_gap > 0:
                    break
                lip *= gamma

            alpha = (1+sqrt(1+4*alpha_prec**2))/2

            Theta_diff = Theta-Theta_prec

            if np.max(np.abs(Theta_diff))<= 1e-8:
                break

            # momentum (extrapolation) step
            Z = Theta + (alpha_prec-1)/alpha*Theta_diff

            Theta_prec = Theta
            alpha_prec = alpha

        return Theta

    @staticmethod
    def prox_nuc(M, l):
        """Proximal operator of ``l * ||.||_*``: soft-threshold the singular values of ``M``."""
        U, s, V_h = np.linalg.svd(M, full_matrices= False)
        s_shrink = np.maximum(0, s - l)
        return (U * s_shrink) @ V_h

    def l2sqr_val(self, Z):
        """Quadratic part of the objective: sum_u z_u^T A_u z_u - 2 b_u^T z_u."""
        quad = np.einsum('ij,ijk,ik', Z, self.A, Z)
        lin = -2*np.einsum('ij,ij', Z, self.b)
        return quad + lin

    def l2sqr_linearized_val(self, Z1, Z2, lip):
        """Quadratic upper model of ``l2sqr_val`` around ``Z2``, evaluated at ``Z1``."""
        lin_part = np.einsum("ij,ij", Z1-Z2, self.l2sqr_grad(Z2))
        F_norm_sqr = np.einsum("ij,ij", Z1-Z2, Z1-Z2)
        return self.l2sqr_val(Z2) + lin_part + lip/2*F_norm_sqr

    def obj_val(self, Z):
        """Full objective: quadratic part plus ``l_nuc_t`` times the nuclear norm."""
        reg = self.l_nuc_t * np.linalg.norm(Z, 'nuc')
        return self.l2sqr_val(Z) + reg

    def obj_quasi_lin_val(self, Z1, Z2, lip):
        """Upper model of the objective: linearised quadratic part plus the penalty at ``Z1``."""
        reg = self.l_nuc_t * np.linalg.norm(Z1, 'nuc')
        return self.l2sqr_linearized_val(Z1,Z2,lip) + reg

    def l2sqr_grad(self, Z):
        """Gradient of ``l2sqr_val`` at ``Z`` (row u: 2 A_u z_u - 2 b_u)."""
        quad_grad = 2*np.einsum("ijk,ik->ij", self.A, Z)
        lin_grad = -2*self.b
        return quad_grad + lin_grad
    
class NetworkLassoAgent():
    """ Class of the network lasso policy
    """

    def __init__(self, adjacency, solver= None, first_term= None,
              l_net=1.0, beta=2.0,
              weighting= None, dtype='float32', K_norm='F'):
        """ initialization class

        Parameters
        ----------
        adjacency : array_like
            graph adjacency matrix
        solver : optional
            instance of the solver class used to solve the optimization problem
            at each time step; a new ``nl.SolverPrimalDual()`` when omitted
        first_term : optional
            object handling the data-fitting term; a new
            ``nl.LstsqPrimalDual()`` when omitted
        l_net : float, optional
            regularization coefficient of the total variation penalty, by default 1.0 # TODO set to theory inspired default value
        beta : float, optional
            multiplies ``log(t)`` in the time-dependent regularisation, by default 2.0
        weighting : {None, "degree", "centrality"}, optional
            edge re-weighting scheme; ``None`` keeps the adjacency (cast to ``dtype``)
        dtype : str, optional
            floating point type of the stored arrays, by default 'float32'
        K_norm : {"F", "F_bound"}, optional
            selects the formula used in ``_update_alpha``, by default 'F'

        Raises
        ------
        ValueError
            if ``weighting`` is not one of the supported values
        """
        # Defaults are built per agent. Default-argument instances would be
        # created once, at definition time, and shared (together with their
        # internal state) by every agent relying on the defaults.
        if solver is None:
            solver = nl.SolverPrimalDual()
        if first_term is None:
            first_term = nl.LstsqPrimalDual()

        if weighting is None:
            self.adjacency = adjacency.astype(dtype)
        elif weighting == "degree":
            weights = 1/np.sqrt(np.sum(adjacency, axis=0))
            self.adjacency = weights * adjacency * weights[:,None]
        elif weighting == "centrality":
            laplacian = np.diag(adjacency.sum(0)) - adjacency
            weights = np.sqrt(np.diag(np.linalg.pinv(laplacian)))
            weights /= max(weights)
            self.adjacency = weights * adjacency * weights[:,None]
        else:
            raise ValueError(
                f"unknown weighting {weighting!r}; expected None, 'degree' or 'centrality'")

        self.l_net = l_net
        self.solver = solver
        self.first_term = first_term
        self.dtype = dtype
        self.beta = beta
        self.K_norm = K_norm
        # number of observations per user
        self.counts = np.zeros(self.adjacency.shape[0], dtype= 'int32')

    def play_arm(self, cxts, u, t):
        """Return the index of the row of ``cxts`` with the largest estimated reward for ``u``."""
        return np.argmax(cxts @ self.theta[u])

    def get_theta(self):
        """Return the current estimate, one row per user."""
        return self.theta

    def initialize(self, bandit):
        """ Get the needed information from the bandit and set up the solver.

        Parameters
        ----------
        bandit : instance of the MultiTaskBandit class
            the bandit with which the agent will interact

        Returns
        -------
        NetworkLassoAgent
            ``self``
        """
        self.d = bandit.dim
        self.std = bandit.std
        n_users = self.adjacency.shape[0]

        self.theta = np.zeros((n_users, self.d), dtype=self.dtype)

        self.solver.get_external_params(self.adjacency, self.theta)
        self.first_term.get_external_params(self.solver)
        self.A = np.zeros((n_users, self.d, self.d), dtype=self.dtype)
        # squared Frobenius norm of X
        self.X_norm_F_sqr = np.zeros(n_users)
        self.B_pinv = utils.fast_pinv(self.solver.B.toarray())
        self.normB_pinv = np.max(np.abs(self.B_pinv))

        # warm start handed back to the solver at every round; one row per row of B
        self.u_estim = np.zeros(
            (self.solver.B.shape[0], self.d), dtype=self.dtype)
        return self

    def update(self, x, y, u, t):
        """Fold ``(x, y)`` into the statistics of user ``u`` and re-solve for ``theta``."""
        self.counts[u] += 1
        # TODO solver update parameters
        self.first_term.update(u, x, y)
        self.A[u] += np.outer(x, x)
        self.X_norm_F_sqr[u] += np.dot(x, x)

        l_net_t = self._update_alpha(t)

        # warm-started from the previous primal and auxiliary estimates
        self.theta, self.u_estim = self.solver.solve(
            first_term= self.first_term, l_net=l_net_t,
            theta0= self.theta, U0= self.u_estim)
        return self

    def _update_alpha(self,t):
        """Return the regularisation coefficient to use at round ``t``.

        Raises
        ------
        ValueError
            if ``self.K_norm`` is not "F" or "F_bound"
        """
        if self.K_norm not in ("F", "F_bound"):
            raise ValueError(
                f"unknown K_norm {self.K_norm!r}; expected 'F' or 'F_bound'")

        loginvdelta = self.beta*log(t)
        if self.K_norm == "F":
            # spectral norm per user
            A_norm_2 = np.linalg.norm(self.A, ord=2, axis=(1, 2))
            # sum of square F norms
            A_norm_F_sqr = np.einsum("uij,uij -> u", self.A, self.A)
            l_net_t = utils.sum_Hsu(self.X_norm_F_sqr.sum(), A_norm_F_sqr.sum(),
                                    np.max(A_norm_2), loginvdelta)
        else:
            # The three terms belong to one expression; without the
            # parentheses the last one was a separate statement and dropped.
            l_net_t = (t + 2*np.linalg.norm(self.counts)*sqrt(loginvdelta)
                       + 2*max(self.counts)*loginvdelta)

        return self.l_net * self.std * sqrt(l_net_t)

# %%
