# -*- coding: utf-8 -*-
# final version

import datetime

import numpy as np

import algorithms.Base as Base
import env.Environment as Envi
from algorithms.LinUCB import LinUCB


class SCLUB(LinUCB):
    """Clustering-of-bandits learner that splits and merges user clusters.

    All users start in one cluster. After every round the served user may be
    split off into a cluster of its own (``split``) and pairs of checked
    clusters may be merged (``merge``). Rounds are grouped into phases; phase
    ``s`` lasts ``phase_cardinality ** s`` rounds and starts with
    ``init_each_stage``.

    Attributes used here but not assigned in this class (``rounds``, ``d``,
    ``usernum``, ``users``, ``lambda_``, ``alpha_``, ``reward``,
    ``best_reward``, ``regret``) are presumably set by ``LinUCB.__init__``
    -- TODO confirm against the base class.
    """

    def __init__(self, user_num, d, T):
        """Build the learner and place every user in a single cluster.

        Args:
            user_num: forwarded unchanged to ``LinUCB.__init__``.
            d: forwarded unchanged to ``LinUCB.__init__``.
            T: forwarded unchanged to ``LinUCB.__init__``.
        """
        super().__init__(user_num, d, T)
        self.clusters = {}  # key:cluster_index, value:Base.sclub_Cluster
        self.cluster_inds = {}  # key:user_index, value:cluster_index
        self.num_clusters = np.zeros(self.rounds)  # the total number of clusters in each round
        self.init_clusters()

        self.phase_cardinality = 2  # phase s lasts phase_cardinality ** s rounds
        self.T_phase = 0
        self.init_each_stage()
        self.alpha_cluster = 1.75  # scales the theta-distance threshold
        self.alpha_prob = 1.75  # scales the arrival-frequency threshold

    def init_clusters(self):
        """Create cluster 0 containing all users and map every user to it."""
        self.clusters[0] = Base.sclub_Cluster(d=self.d,
                                              id_index=0,
                                              T=self.rounds,
                                              b=np.zeros(self.d),
                                              V=np.zeros((self.d, self.d)),
                                              t=0,
                                              user_num=self.usernum,
                                              users=list(range(self.usernum)),
                                              rewards=np.zeros(self.rounds),
                                              best_rewards=np.zeros(self.rounds),
                                              theta_phase=np.zeros(self.d),
                                              T_phase=0)
        for user_inx in range(self.usernum):
            self.cluster_inds[user_inx] = 0
        self.num_clusters[0] = 1

    def init_each_stage(self):
        """Prepare every cluster for a new phase.

        Marks all users of each cluster as unchecked, clears the cluster's
        ``checked`` flag and calls its ``phase_update`` (defined in ``Base``;
        presumably refreshes ``T_phase`` / ``theta_phase`` -- verify there).
        """
        for cluster in self.clusters.values():
            cluster.checks = {j: False for j in cluster.users}
            cluster.checked = False
            cluster.phase_update()

    def locate_user_index(self, user_index):
        """Return the index of the cluster that currently holds ``user_index``."""
        return self.cluster_inds[user_index]

    def recommend(self, cluster_index, items):
        """Pick the item with the highest optimistic score for a cluster.

        Args:
            cluster_index: key into ``self.clusters``.
            items: candidate item features; used as a 2-D array whose rows
                are items (it is multiplied with a ``d x d`` matrix and
                summed over ``axis=1``).

        Returns:
            Index (row of ``items``) of the selected item.
        """
        cluster = self.clusters[cluster_index]
        V_t, b_t, _ = cluster.get_info()
        M_t = np.eye(self.d) * self.lambda_ + V_t
        Minv = np.linalg.inv(M_t)
        theta = np.dot(Minv, b_t)
        # Estimated reward plus alpha_ times the quadratic form x^T M^-1 x per item.
        estimated = np.dot(items, theta)
        bonus = (np.matmul(items, Minv) * items).sum(axis=1)
        return np.argmax(estimated + self.alpha_ * bonus)

    @staticmethod
    def _confidence_factor(n):
        """Return ``sqrt((1 + log(1 + n)) / (1 + n))`` for a sample count ``n``."""
        return np.sqrt((1 + np.log(1 + n)) / (1 + n))

    @staticmethod
    def _normalized_theta_gap(theta_a, theta_b):
        """Return the distance between two estimates after scaling each to unit norm.

        Returns ``None`` when either vector has zero norm. Dividing by that
        norm would give NaN (and a NumPy RuntimeWarning); every comparison
        against NaN is False, so callers treat ``None`` as "threshold not
        exceeded", which yields the same decisions without the warning.
        """
        norm_a = np.linalg.norm(theta_a)
        norm_b = np.linalg.norm(theta_b)
        if norm_a == 0 or norm_b == 0:
            return None
        return np.linalg.norm(theta_a / norm_a - theta_b / norm_b)

    def if_split(self, user_idx, cluster_idx, t):
        """Decide whether a user should leave its cluster.

        The user is split off when its normalized ``theta`` is further from
        the cluster's than ``alpha_cluster`` times the sum of both confidence
        factors, or when its arrival frequency ``user.t / t`` differs from
        that of any other cluster member by more than
        ``alpha_prob * 2 * factor(t)``.

        Args:
            user_idx: key into ``self.users``.
            cluster_idx: key into ``self.clusters``.
            t: current round, 1-based (used as a divisor).

        Returns:
            ``True`` if the user should be split off, else ``False``.
        """
        user = self.users[user_idx]
        cluster = self.clusters[cluster_idx]
        t1 = user.t
        t2 = cluster.t
        fact_T1 = self._confidence_factor(t1)
        fact_T2 = self._confidence_factor(t2)
        fact_t = self._confidence_factor(t)

        gap = self._normalized_theta_gap(user.theta, cluster.theta)
        if gap is not None and gap > self.alpha_cluster * (fact_T1 + fact_T2):
            return True

        p1 = t1 / t
        freq_threshold = self.alpha_prob * 2 * fact_t
        return any(
            np.abs(p1 - self.users[user2_idx].t / t) > freq_threshold
            for user2_idx in cluster.users
            if user2_idx != user_idx
        )

    def if_merge(self, c1_idx, c2_idx, t):
        """Decide whether two clusters are close enough to be merged.

        Merging is rejected when the normalized ``theta`` estimates are at
        least ``alpha_cluster`` times the sum of both confidence factors
        apart, or when the average arrival frequencies differ by at least
        ``alpha_prob * 0.5 * factor(t)``. If either estimate has zero norm
        the theta test cannot reject and only the frequency test applies.

        Args:
            c1_idx: key of the first cluster in ``self.clusters``.
            c2_idx: key of the second cluster in ``self.clusters``.
            t: current round, 1-based.

        Returns:
            ``True`` if the clusters should be merged, else ``False``.
        """
        first = self.clusters[c1_idx]
        second = self.clusters[c2_idx]
        fact_T1 = self._confidence_factor(first.t)
        fact_T2 = self._confidence_factor(second.t)
        fact_t = self._confidence_factor(t)

        gap = self._normalized_theta_gap(first.theta, second.theta)
        if gap is not None and gap >= self.alpha_cluster * (fact_T1 + fact_T2):
            return False

        p1 = self.cluster_aver_freq(c1_idx, t)
        p2 = self.cluster_aver_freq(c2_idx, t)
        if np.abs(p1 - p2) >= self.alpha_prob * 0.5 * fact_t:
            return False
        return True

    def cluster_aver_freq(self, c_idx, t):
        """Return ``cluster.t / (number_of_users * t)``, or 0 for an empty cluster."""
        cluster = self.clusters[c_idx]
        member_count = len(cluster.users)
        if member_count == 0:
            return 0
        return cluster.t / (member_count * t)

    def find_available_index(self):
        """Return the smallest non-negative integer not used as a cluster index.

        Works on an empty ``self.clusters`` as well (returns 0), where taking
        ``max`` of the keys would raise ``ValueError``.
        """
        candidate = 0
        while candidate in self.clusters:
            candidate += 1
        return candidate

    def generate_new_cluster(self, cluster_idx, user_idx, T_phase, theta_phase):
        """Create a single-user cluster seeded with that user's statistics.

        Args:
            cluster_idx: index under which the new cluster is stored.
            user_idx: the user that forms the new cluster.
            T_phase: passed through to ``Base.sclub_Cluster``.
            theta_phase: passed through to ``Base.sclub_Cluster``.
        """
        user = self.users[user_idx]
        # NOTE(review): b, V, rewards and best_rewards are handed over by
        # reference, so the cluster and the user share these objects unless
        # Base.sclub_Cluster copies them -- confirm that in-place updates on
        # one side cannot leak into the other.
        tmp_cluster = Base.sclub_Cluster(d=self.d,
                                         id_index=cluster_idx,
                                         T=self.rounds,
                                         b=user.b,
                                         V=user.V,
                                         t=user.t,
                                         user_num=1,
                                         users=[user_idx],
                                         rewards=user.rewards,
                                         best_rewards=user.best_rewards,
                                         T_phase=T_phase,
                                         theta_phase=theta_phase)
        self.clusters[cluster_idx] = tmp_cluster
        self.cluster_inds[user_idx] = cluster_idx

    def split(self, user_idx, t):
        """Mark the user as checked and split it off if ``if_split`` says so.

        On a split the user gets a new cluster (inheriting the old cluster's
        ``T_phase`` and ``theta_phase``), is removed from the old one, and
        the old cluster is deleted if it became empty. The cluster count for
        round ``t`` is recorded in ``self.num_clusters[t - 1]``.

        Args:
            user_idx: the user served in this round.
            t: current round, 1-based.
        """
        cls_idx = self.cluster_inds[user_idx]
        old_cluster = self.clusters[cls_idx]
        old_cluster.update_check(user_idx)
        now_user = self.users[user_idx]
        if self.if_split(user_idx, cls_idx, t):
            new_cls_idx = self.find_available_index()
            self.generate_new_cluster(new_cls_idx, user_idx, old_cluster.T_phase, old_cluster.theta_phase)
            old_cluster.remove_user(user_idx, now_user)
            if len(old_cluster.users) == 0:
                del self.clusters[cls_idx]
        self.num_clusters[t - 1] = len(self.clusters)

    def merge(self, t):
        """Merge every pair of checked clusters accepted by ``if_merge``.

        Pairs are visited in increasing index order; the higher-indexed
        cluster is folded into the lower-indexed one and then deleted. The
        cluster count for round ``t`` is recorded in
        ``self.num_clusters[t - 1]``.

        Args:
            t: current round, 1-based.
        """
        # Snapshot of the ids: clusters are deleted while we iterate, so each
        # id is re-validated with a lookup before it is used.
        cluster_ids = sorted(self.clusters)
        for pos, c1_idx in enumerate(cluster_ids):
            keeper = self.clusters.get(c1_idx)
            if keeper is None or not keeper.checked:
                continue
            for c2_idx in cluster_ids[pos + 1:]:
                absorbed = self.clusters.get(c2_idx)
                if absorbed is None or not absorbed.checked:
                    continue
                if not self.if_merge(c1_idx, c2_idx, t):
                    continue
                for user_idx in absorbed.users:
                    self.cluster_inds[user_idx] = c1_idx
                keeper.merge_cluster(absorbed)
                del self.clusters[c2_idx]
        self.num_clusters[t - 1] = len(self.clusters)

    def run(self, envir):
        """Run the learner for ``self.rounds`` rounds against an environment.

        Phases are opened until the horizon is reached, and the last phase is
        cut short if necessary. Deriving a fixed phase count from a logarithm
        can stop before ``self.rounds`` (e.g. 7 rounds with phases of 2 and
        4), which would leave ``runtime``, ``theta`` and ``cluster_num``
        unset.

        Args:
            envir: environment object providing ``generate_users()``,
                ``get_items()`` and ``feedback(items=, i=, k=)``.

        Returns:
            dict with keys ``runtime``, ``regret``, ``theta``, ``reward``,
            ``items_all_rounds``, ``reward_all_rounds`` and ``cluster_num``.
        """
        reward_all_rounds = []  # to save feedback in each round
        items_all_rounds = []  # to save the recommended item in each round
        t = 0
        phase = 0
        self.starttime = datetime.datetime.now()
        while t < self.rounds:
            phase += 1
            self.init_each_stage()
            # Never run past the horizon inside the final phase.
            phase_length = min(self.phase_cardinality ** phase, self.rounds - t)
            for _ in range(phase_length):
                t += 1
                if t % 5000 == 0:
                    print("t = %d" % t)
                user_index = envir.generate_users()  # random user arrives
                cluster_index = self.locate_user_index(user_index)
                current_cluster = self.clusters[cluster_index]
                items = envir.get_items()
                r_item_index = self.recommend(cluster_index, items)
                selected_item = items[r_item_index]
                items_all_rounds.append(selected_item)
                self.reward[t - 1], instant_rwd, self.best_reward[t - 1] = envir.feedback(
                    items=items, i=user_index, k=r_item_index)
                reward_all_rounds.append(instant_rwd)
                self.regret[t - 1] = self.best_reward[t - 1] - self.reward[t - 1]
                # update the user's information
                self.users[user_index].store_info(
                    selected_item, instant_rwd, t - 1, self.reward[t - 1], self.best_reward[t - 1])
                # update the cluster's information (the cluster the user was in
                # when the item was recommended, before any split below)
                current_cluster.store_info(
                    selected_item, instant_rwd, t - 1, self.reward[t - 1], self.best_reward[t - 1])
                # check update
                self.split(user_index, t)
                # check merge
                self.merge(t)

        # Finalise unconditionally so the result dict is always complete.
        cluster_num = self.num_clusters[t - 1] if t > 0 else len(self.clusters)
        # the users' final estimate theta information
        theta_hat = {user_idx: self.users[user_idx].theta for user_idx in range(self.usernum)}
        print("test time rounds", t, self.rounds)
        self.endtime = datetime.datetime.now()
        self.runtime = self.endtime - self.starttime

        final_results = {
            "runtime": self.runtime,
            "regret": self.regret,
            "theta": theta_hat,
            "reward": self.reward,
            "items_all_rounds": items_all_rounds,
            "reward_all_rounds": reward_all_rounds,
            "cluster_num": cluster_num,
        }
        print("final cluster number: ", cluster_num)
        print({cidx: list(cluster.users) for cidx, cluster in self.clusters.items()})
        print({cidx: cluster.t for cidx, cluster in self.clusters.items()})
        return final_results
