# -*- coding: utf-8 -*-
# final version

import copy
import datetime

import networkx as nx
import numpy as np

import algorithms.Base as Base
import env.Environment as Envi
from algorithms.LinUCB import LinUCB
from utils.utils import set_random_seed


class CLUB(LinUCB):
    """Clustering-of-bandits learner layered on top of ``LinUCB``.

    All users start in one cluster and are connected by a complete graph.
    After each round, edges between the served user and any neighbour whose
    normalised parameter estimate is too far away are removed. When that
    disconnects the served user's cluster, every resulting connected
    component becomes a cluster of its own.

    NOTE(review): ``self.users``, ``self.d``, ``self.rounds``, ``self.usernum``,
    ``self.lambda_``, ``self.alpha_``, ``self.reward``, ``self.best_reward`` and
    ``self.regret`` are not set in this class; presumably they are provided by
    ``LinUCB.__init__`` -- confirm against the base class.
    """

    def __init__(self, user_num, d, T):
        """Create the initial single cluster and the complete user graph.

        Args:
            user_num: Number of users; forwarded to ``LinUCB`` and used as the
                number of graph nodes.
            d: Forwarded to ``LinUCB`` (read back here as ``self.d``).
            T: Forwarded to ``LinUCB`` (presumably the number of rounds,
                read back here as ``self.rounds`` -- confirm).
        """
        super().__init__(user_num, d, T)
        self.clusters = {}  # key: cluster_index, value: Base.Cluster
        self.cluster_inds = {}  # key: user_index, value: cluster_index
        self.num_clusters = np.zeros(self.rounds)  # number of clusters after each round
        self.init_clusters()
        # every pair of users starts connected; edges are only ever removed
        self.G = nx.complete_graph(user_num)
        self.alpha_cluster = 1.75  # multiplier on the edge-deletion threshold

    def init_clusters(self):
        """Put every user into a single cluster with index 0."""
        self.clusters[0] = Base.Cluster(d=self.d,
                                        id_index=0,
                                        T=self.rounds,
                                        b=np.zeros(self.d),
                                        V=np.zeros((self.d, self.d)),
                                        t=0,
                                        user_num=self.usernum,
                                        users=list(range(self.usernum)),
                                        rewards=np.zeros(self.rounds),
                                        best_rewards=np.zeros(self.rounds))
        for user_inx in range(self.usernum):
            self.cluster_inds[user_inx] = 0
        self.num_clusters[0] = 1

    def locate_user_index(self, user_index):
        """Return the index of the cluster that ``user_index`` belongs to."""
        return self.cluster_inds[user_index]

    def recommend(self, cluster_index, items):
        """Pick the item with the highest upper-confidence score for a cluster.

        The score of an item ``x`` is ``x . theta + alpha_ * x^T M^-1 x`` with
        ``M = lambda_ * I + V`` and ``theta = M^-1 b``, where ``V`` and ``b``
        are the first two values returned by the cluster's ``get_info()``.

        Args:
            cluster_index: Key into ``self.clusters``.
            items: Candidate contexts; assumed to be a 2-D array with one item
                per row (the bonus is summed over ``axis=1``).

        Returns:
            Row index into ``items`` of the best-scoring item.
        """
        cluster = self.clusters[cluster_index]
        V_t, b_t, _ = cluster.get_info()
        M_t = np.eye(self.d) * self.lambda_ + V_t
        Minv = np.linalg.inv(M_t)
        theta = np.dot(Minv, b_t)
        estimated_reward = np.dot(items, theta)
        # row-wise quadratic form x^T M^-1 x
        exploration_bonus = (np.matmul(items, Minv) * items).sum(axis=1)
        return np.argmax(estimated_reward + self.alpha_ * exploration_bonus)

    @staticmethod
    def _confidence_width(t):
        """Return sqrt((1 + log(1 + t)) / (1 + t)) for an interaction count ``t``."""
        return np.sqrt((1 + np.log(1 + t)) / (1 + t))

    def if_delete(self, user_index1, user_index2):
        """Decide whether the edge between two users should be removed.

        Both users' ``theta`` vectors are scaled to unit norm; the edge is
        deleted when their distance exceeds ``alpha_cluster`` times the sum of
        the two users' confidence widths.

        Args:
            user_index1: Index of the first user in ``self.users``.
            user_index2: Index of the second user in ``self.users``.

        Returns:
            Truthy if the edge should be deleted, falsy otherwise.
        """
        user1 = self.users[user_index1]
        user2 = self.users[user_index2]
        theta1 = user1.theta
        theta2 = user2.theta
        norm1 = np.linalg.norm(theta1)
        norm2 = np.linalg.norm(theta2)
        # A zero estimate cannot be normalised. Dividing anyway yields NaN,
        # which compares False (keep the edge) but emits a RuntimeWarning;
        # return that same answer explicitly instead.
        if norm1 == 0 or norm2 == 0:
            return False
        distance = np.linalg.norm(theta1 / norm1 - theta2 / norm2)
        threshold = self.alpha_cluster * (
            self._confidence_width(user1.t) + self._confidence_width(user2.t))
        return distance > threshold

    def generate_new_cluster(self, cluster_idx, users_list):
        """Build a cluster from its members' statistics and register it.

        The cluster's ``b``, ``V``, ``t``, ``rewards`` and ``best_rewards`` are
        the sums of the corresponding per-user attributes. Any cluster already
        stored under ``cluster_idx`` is replaced.

        Args:
            cluster_idx: Key under which the cluster is stored in
                ``self.clusters``.
            users_list: Indices of the users that make up the cluster.
        """
        print("cluster index: ", cluster_idx)
        print("users_list: ", users_list)
        members = [self.users[k] for k in users_list]
        self.clusters[cluster_idx] = Base.Cluster(
            d=self.d,
            id_index=cluster_idx,
            T=self.rounds,
            b=sum([member.b for member in members]),
            V=sum([member.V for member in members]),
            t=sum([member.t for member in members]),
            user_num=len(users_list),
            users=users_list,
            rewards=sum([member.rewards for member in members]),
            best_rewards=sum([member.best_rewards for member in members]))

    def update(self, user_idx, t):
        """Prune edges around a user and split the user's cluster if needed.

        Args:
            user_idx: Index of the user that was just served.
            t: Zero-based round index; ``self.num_clusters[t]`` is set to the
                number of clusters after the update.
        """
        clst_idx = self.cluster_inds[user_idx]
        user1 = self.users[user_idx]
        edge_removed = False  # a split is only possible if an edge was removed

        # Snapshot the neighbours: edges are removed while we iterate.
        for user2_index in list(self.G.neighbors(user_idx)):
            user2 = self.users[user2_index]
            if user1.t != 0 and user2.t != 0 and self.if_delete(user_idx, user2_index):
                self.G.remove_edge(user_idx, user2_index)
                edge_removed = True

        if edge_removed:
            # Set of nodes still connected to the served user.
            component = nx.node_connected_component(self.G, user_idx)
            # Read the member list before the cluster object is replaced below.
            original_user_indices = self.clusters[clst_idx].users
            if len(component) < len(original_user_indices):
                # The served user's component keeps the original cluster index.
                self.generate_new_cluster(clst_idx, list(component))
                # Users that fell out of the component still need a cluster;
                # membership is tested against the set, not a list.
                remaining_users = [i for i in original_user_indices if i not in component]
                new_clst_idx = max(self.clusters) + 1
                while remaining_users:
                    j = remaining_users[0]
                    print("random user: ", j)
                    component = nx.node_connected_component(self.G, j)
                    self.generate_new_cluster(new_clst_idx, list(component))
                    for k in component:
                        self.cluster_inds[k] = new_clst_idx
                    new_clst_idx += 1
                    remaining_users = [i for i in remaining_users if i not in component]
        self.num_clusters[t] = len(self.clusters)

    def run(self, envir):
        """Run the algorithm for ``self.rounds`` rounds against ``envir``.

        Each round: a user is drawn from the environment, an item is
        recommended from that user's cluster, the feedback is stored on both
        the user and the cluster, and the cluster structure is updated.

        Args:
            envir: Environment object; this method calls its
                ``generate_users()``, ``get_items()`` and
                ``feedback(items=..., i=..., k=...)``.

        Returns:
            dict with keys ``"runtime"``, ``"regret"``, ``"theta"``,
            ``"reward"``, ``"items_all_rounds"``, ``"reward_all_rounds"`` and
            ``"cluster_num"``.
        """
        reward_all_rounds = []  # feedback observed in each round
        items_all_rounds = []  # item recommended in each round
        theta_hat = {}  # each user's final theta
        cluster_num = 0
        self.starttime = datetime.datetime.now()
        for i in range(1, self.rounds + 1):
            if i % 5000 == 0:
                print(i)
            step = i - 1  # zero-based index into the per-round arrays
            user_index = envir.generate_users()
            cluster_index = self.locate_user_index(user_index)
            # Fetch the cluster object now: update() may replace the entry in
            # self.clusters, but this round's feedback belongs to this one.
            current_cluster = self.clusters[cluster_index]
            items = envir.get_items()
            r_item_index = self.recommend(cluster_index, items)
            selected_item = items[r_item_index]
            items_all_rounds.append(selected_item)
            self.reward[step], instant_rwd, self.best_reward[step] = envir.feedback(
                items=items, i=user_index, k=r_item_index)
            reward_all_rounds.append(instant_rwd)

            # Record the observation on the user first, then on the cluster.
            self.users[user_index].store_info(
                selected_item, instant_rwd, step, self.reward[step], self.best_reward[step])
            current_cluster.store_info(
                selected_item, instant_rwd, step, self.reward[step], self.best_reward[step])
            self.update(user_index, step)
            self.regret[step] = self.best_reward[step] - self.reward[step]

        # Collect the final state once, instead of re-checking every round.
        if self.rounds > 0:
            cluster_num = self.num_clusters[self.rounds - 1]
            theta_hat = {user_idx: self.users[user_idx].theta
                         for user_idx in range(self.usernum)}
        self.endtime = datetime.datetime.now()
        self.runtime = self.endtime - self.starttime
        final_results = {
            "runtime": self.runtime,
            "regret": self.regret,
            "theta": theta_hat,
            "reward": self.reward,
            "items_all_rounds": items_all_rounds,
            "reward_all_rounds": reward_all_rounds,
            "cluster_num": cluster_num,
        }

        print("final cluster number: ", cluster_num)
        print({cidx: list(cluster.users) for cidx, cluster in self.clusters.items()})
        return final_results
