# -*- coding: utf-8 -*-
# final version

import copy

import numpy as np


class User:
    """Running statistics kept for a single user.

    The record accumulates a d x d matrix ``V`` (sum of feature outer
    products), a length-d vector ``b`` (sum of feedback-weighted features),
    per-round reward logs, and the estimate ``theta = (I + V)^{-1} b``.
    """

    def __init__(self, d, id_index, T):
        """Create an empty record.

        Args:
            d: Feature dimension; sizes ``b`` and ``theta`` (length d) and
                ``V`` (d x d).
            id_index: Unique index of the user, stored as ``index``.
            T: Total number of rounds; sizes the two reward logs.
        """
        self.d = d
        self.index = id_index
        self.t = 0  # number of times store_info has been called for this user
        self.b = np.zeros(d)
        self.V = np.zeros((d, d))
        self.rewards = np.zeros(T)  # one slot per round
        self.best_rewards = np.zeros(T)  # one slot per round
        self.theta = np.zeros(d)

    def store_info(self, x, y, t, r, br):
        """Fold one observation into the statistics and refresh ``theta``.

        Args:
            x: Feature vector of the observation.
            y: Scalar feedback that weights ``x`` when added to ``b``.
            t: Index of the slot in the reward logs to add to.
            r: Amount added to ``rewards[t]``.
            br: Amount added to ``best_rewards[t]``.
        """
        self.t += 1

        # V and b are rebound to fresh arrays rather than updated in place,
        # so arrays previously handed out by get_info() are left untouched.
        feature_outer = np.outer(x, x)
        self.V = self.V + feature_outer
        weighted_feature = y * x
        self.b = self.b + weighted_feature

        self.rewards[t] += r
        self.best_rewards[t] += br

        # theta = (I + V)^{-1} b
        regularised_gram = np.eye(self.d) + self.V
        self.theta = np.linalg.inv(regularised_gram) @ self.b

    def get_info(self):
        """Return the tuple ``(V, b, t)`` (the arrays are not copied)."""
        snapshot = (self.V, self.b, self.t)
        return snapshot


# Base cluster in CLUB/UniCLUB
class Cluster(User):
    """A group of users whose pooled statistics are handled like one user.

    The aggregate ``b``, ``V``, ``t`` and reward logs are supplied by the
    caller and stored as given (no copies are made); ``theta`` is computed
    from them as ``(I + V)^{-1} b``.
    """

    def __init__(self, d, id_index, T, b, V, t, user_num, users, rewards, best_rewards):
        """Build a cluster from already-aggregated statistics.

        Args:
            d: Feature dimension.
            id_index: Index of the cluster, stored as ``index`` by the base class.
            T: Total number of rounds (forwarded to the base initialiser).
            b: Aggregated length-d vector for the cluster.
            V: Aggregated d x d matrix for the cluster.
            t: Sum of the pick counts of all users in this cluster.
            user_num: Number of users in the cluster.
            users: List of member user indices, or ``None`` to default to
                ``[0, 1, ..., user_num - 1]``.
            rewards: Per-round reward log of the cluster.
            best_rewards: Per-round best-reward log of the cluster.
        """
        super(Cluster, self).__init__(d, id_index, T)
        # The base initialiser zero-fills the statistics; replace them with
        # the aggregates supplied by the caller.
        self.b = b
        self.t = t
        self.V = V
        self.user_num = user_num
        if users is None:
            # Must be a real list: ``list[range(n)]`` would subscript the
            # ``list`` type instead of building one, and subclasses call
            # len()/remove()/extend() on ``self.users``.
            self.users = list(range(user_num))
        else:
            self.users = users
        self.rewards = rewards
        self.best_rewards = best_rewards
        self.theta = np.matmul(np.linalg.inv(np.eye(self.d) + self.V), self.b)


# cluster in SCLUB/UniSCLUB
class sclub_Cluster(Cluster):
    """Cluster with phase snapshots and per-user check flags.

    On top of the aggregate statistics of ``Cluster`` it keeps:

    * ``T_phase`` / ``theta_phase``: values of ``t`` / ``theta`` recorded by
      the last ``phase_update`` call (or the constructor arguments).
    * ``checks``: dict mapping each member's index to a boolean flag.
    * ``checked``: True when the number of True flags equals the number of
      members.

    NOTE(review): ``remove_user`` and ``merge_cluster`` use in-place
    operators on ``V``, ``b``, ``rewards`` and ``best_rewards``. ``Cluster``
    stores those constructor arguments without copying, so if the caller
    still holds the same arrays they are modified as well -- confirm that
    callers pass copies.
    """

    def __init__(self, d, id_index, T, b, V, t, user_num, users, rewards, best_rewards, theta_phase, T_phase=0):
        """Build the cluster; every member starts with its flag set to False.

        Args:
            d, id_index, T, b, V, t, user_num, users, rewards, best_rewards:
                Forwarded unchanged to ``Cluster.__init__``.
            theta_phase: Initial value of the phase snapshot of ``theta``.
            T_phase: Initial value of the phase snapshot of ``t``.
        """
        super().__init__(d, id_index, T, b, V, t, user_num, users, rewards, best_rewards)
        self.T_phase = T_phase
        self.theta_phase = theta_phase
        self.checks = dict.fromkeys(self.users, False)
        self._refresh_checked()

    def _refresh_checked(self):
        """Recompute ``checked`` from the current members and flags."""
        flagged = sum(self.checks.values())
        self.checked = len(self.users) == flagged

    def _recompute_theta(self):
        """Recompute ``theta = (I + V)^{-1} b`` from the current aggregates."""
        regularised_gram = np.eye(self.d) + self.V
        self.theta = np.linalg.inv(regularised_gram) @ self.b

    def phase_update(self):
        """Record the current ``t`` and ``theta`` as the phase snapshot."""
        self.T_phase, self.theta_phase = self.t, self.theta

    def update_check(self, i):
        """Set the flag of user ``i`` to True and refresh ``checked``."""
        self.checks[i] = True
        self._refresh_checked()

    def remove_user(self, user_idx, user):
        """Remove one member and subtract its statistics from the aggregates.

        Args:
            user_idx: Index of the member; must be present in ``users``
                (``ValueError`` otherwise) and in ``checks`` (``KeyError``
                otherwise).
            user: Object providing the ``t``, ``V``, ``b``, ``rewards`` and
                ``best_rewards`` to subtract.
        """
        # Membership bookkeeping first; list.remove raises before any state
        # is touched if the index is unknown.
        self.users.remove(user_idx)
        del self.checks[user_idx]
        self._refresh_checked()
        self.user_num -= 1

        # Take the user's contribution out of the aggregates (in place).
        self.t -= user.t
        self.V -= user.V
        self.b -= user.b
        self.rewards -= user.rewards
        self.best_rewards -= user.best_rewards

        self._recompute_theta()

    def merge_cluster(self, cluster):
        """Absorb the members, flags and statistics of ``cluster`` into this one.

        Args:
            cluster: Another cluster exposing ``users``, ``checks``,
                ``user_num``, ``t``, ``V``, ``b``, ``rewards`` and
                ``best_rewards``. It is read only, not modified.
        """
        # Membership and flags.
        self.users.extend(cluster.users)
        self.checks.update(cluster.checks)
        self.user_num += cluster.user_num

        # Add the other cluster's statistics to the aggregates (in place).
        self.t += cluster.t
        self.V += cluster.V
        self.b += cluster.b
        self.rewards += cluster.rewards
        self.best_rewards += cluster.best_rewards

        self._recompute_theta()
        self._refresh_checked()
