from collections import defaultdict
import numpy as np
import random
import sys 
import networkx as nx



class Cluster:
    def __init__(self, users, S, b, N):
        self.users = set(users) # a list/array of users
        self.S = S
        self.b = b
        self.N = N
        self.Sinv = np.linalg.inv(self.S)
        self.theta = np.matmul(self.Sinv, self.b)


class LOCB:
    # each user is an independent LinUCB
    def __init__(self, nu, d, gamma, num_seeds, delta, detect_cluster):
        self.S = {i:np.eye(d) for i in range(nu)}
        self.b = {i:np.zeros(d) for i in range(nu)}
        self.Sinv = {i:np.eye(d) for i in range(nu)}
        self.theta = {i:np.zeros(d) for i in range(nu)}
        self.users = range(nu)
        
        self.seeds = np.random.choice(self.users, num_seeds)
        self.seed_state = {}
        for seed in self.seeds:
            self.seed_state[seed] = 0
        self.clusters = {}
        for seed in self.seeds: 
            self.clusters[seed] = Cluster(users=self.users, S=np.eye(d), b=np.zeros(d), N=1)
            
        self.N = np.zeros(nu)
        self.gamma = gamma
        self.results = []
        self.fin = 0
        self.cluster_inds = {i:[] for i in range(nu)}
        for i in self.users:
            for seed in self.seeds:
                if i in self.clusters[seed].users:
                    self.cluster_inds[i].append(seed)
                    
        self.d = d
        self.n = nu
        self.selected_cluster = 0 
        self.delta = delta
        self.if_d = detect_cluster
        
    def _beta(self, N, t):
        return np.sqrt(self.d * np.log(1 + N / self.d) + 4 * np.log(t) + np.log(2)) + 1

    def _select_item_ucb(self, S, Sinv, theta, items, N, t):
        return np.argmax(np.dot(items, theta) + self._beta(N, t) * (np.matmul(items, Sinv) * items).sum(axis = 1))


    def _update_inverse(self, S, b, Sinv, x, t):
        Sinv = np.linalg.inv(S)
        theta = np.matmul(Sinv, b)
        return Sinv, theta
    

    def recommend(self, i, items, t):
        cls = self.cluster_inds[i]
        if (len(cls)>0) and (t <40000):
            res = []
            for c in cls:
                cluster = self.clusters[c]
                res_sin = self._select_item_ucb(cluster.S,cluster.Sinv, cluster.theta, items, cluster.N, t)
                res.append(res_sin)
            best_cluster = max(res)
            return best_cluster[1]
        else:
            no_cluster = self._select_item_ucb(self.S[i], self.Sinv[i], self.theta[i], items, self.N[i], t)
            return no_cluster[1]
  
        
    def _select_item_ucb(self, S, Sinv, theta, items, N, t):
        ucbs = np.dot(items, theta) + self._beta(N, t) * (np.matmul(items, Sinv) * items).sum(axis = 1)
        res = max(ucbs)
        it = np.argmax(ucbs)
        return (res, it)
        

    def store_info(self, i, x, y, t):
        
        self.S[i] += np.outer(x, x)
        self.b[i] += y * x
        self.N[i] += 1

        self.Sinv[i], self.theta[i] = self._update_inverse(self.S[i], self.b[i], self.Sinv[i], x, self.N[i])
        
        for c in self.cluster_inds[i]:
            self.clusters[c].S += np.outer(x, x)
            self.clusters[c].b += y * x
            self.clusters[c].N += 1
            self.clusters[c].Sinv = np.linalg.inv(self.clusters[c].S)
            self.clusters[c].theta = np.matmul(self.clusters[c].Sinv, self.clusters[c].b)
            

        
    
    def update(self, i, t):
        def _factT(m):
            if self.if_d:
                delta = self.delta / self.n
                nu = np.sqrt(2*self.d*np.log(1 + t) + 2*np.log(2/delta)) +1
                de = np.sqrt(1+m/4)*np.power(self.n, 1/3)
                return nu/de
            else:
                return np.sqrt((1 + np.log(1 + m)) / (1 + m))
        
        if not self.fin:

              
            for seed in self.seeds:
                if not self.seed_state[seed]:
                    if i in self.clusters[seed].users:
                        diff = self.theta[i] - self.theta[seed]
                        if np.linalg.norm(diff) > _factT(self.N[i]) + _factT(self.N[seed]):
                            self.clusters[seed].users.remove(i)
                            self.cluster_inds[i].remove(seed)                            
                            self.clusters[seed].S = self.clusters[seed].S - self.S[i] + np.eye(self.d)
                            self.clusters[seed].b = self.clusters[seed].b - self.b[i]
                            self.clusters[seed].N = self.clusters[seed].N - self.N[i]
                            
                    else:
                        diff = self.theta[i] - self.theta[seed]
                        if np.linalg.norm(diff) < _factT(self.N[i]) + _factT(self.N[seed]):
                            self.clusters[seed].users.add(i)
                            self.cluster_inds[i].append(seed)
                            self.clusters[seed].S = self.clusters[seed].S + self.S[i] - np.eye(self.d)
                            self.clusters[seed].b = self.clusters[seed].b + self.b[i]
                            self.clusters[seed].N = self.clusters[seed].N + self.N[i]
                
                    if self.if_d: thre = self.gamma 
                    else: thre = self.gamma/4
                        
                    if _factT(self.N[seed]) <= thre:
                        self.seed_state[seed] = 1
                        self.results.append({seed:list(self.clusters[seed].users)}) 

            finished = 1
            for i in self.seed_state.values():
                if i ==0:
                    finished =0
                    
            if finished: 
                if self.if_d:
                    np.save('./results/clusters', self.results)
                    print('Clustering finished! Round:', t)
                    self.stop = 1
                self.fin = 1