# -*- coding: utf-8 -*-
# final version

import datetime

import numpy as np

import algorithms.Base as Base
import env.Environment as Envi
from algorithms.SCLUB import SCLUB


class UniSCorn(SCLUB):
    """SCLUB variant that recommends uniformly at random until a switch test passes.

    ``run`` starts in a pure-exploration mode in which the recommended item
    is drawn uniformly at random from the offered items.  After every round
    ``check_switch`` is evaluated; once it returns True the algorithm uses
    the inherited ``recommend`` method for all remaining rounds and never
    switches back.

    NOTE(review): ``self.d``, ``self.rounds``, ``self.usernum``,
    ``self.users``, ``self.clusters``, ``self.phase_cardinality``,
    ``self.reward``, ``self.best_reward``, ``self.regret`` and
    ``self.num_clusters`` are not defined in this class; they are presumably
    set up by ``SCLUB.__init__`` -- confirm against the base class.
    """

    def __init__(self, user_num, d, T, dist=5):
        """Initialise the base algorithm and the UniSCorn-specific constants.

        Args:
            user_num: forwarded unchanged to ``SCLUB.__init__``.
            d: forwarded unchanged to ``SCLUB.__init__``.
            T: forwarded unchanged to ``SCLUB.__init__``.
            dist: stored as ``self.distances``; a quarter of it is the
                confidence-width threshold used by ``check_switch``.
        """
        super().__init__(user_num, d, T)
        self.distances = dist
        # NOTE(review): the constant 5.2 is hard-coded and does not depend on
        # `dist` -- confirm this is intended.  `exploration_time` is not read
        # anywhere in this class; it may be used by the base class or by
        # callers.
        self.exploration_time = int(1 * self.d / (0.5 * 5.2**2) * np.log(self.rounds * self.d * self.usernum))

        self.alpha_uniscorn = 1.3  # scale of the confidence width in check_switch
        # The two values below are not read in this class; presumably they
        # are consumed by the base class -- verify.
        self.alpha_cluster = 1.75
        self.alpha_switch = 1.1

    def set_exploration_time(self, exp_time):
        """Override the value of ``self.exploration_time``."""
        self.exploration_time = exp_time

    def set_exploratrion_time(self, exp_time):
        """Misspelled alias of ``set_exploration_time``, kept for existing callers."""
        self.set_exploration_time(exp_time)

    def check_switch(self, t):
        """Decide whether to leave the uniform-exploration mode.

        For every user, the third value returned by ``user.get_info()`` is
        read as a count (presumably the number of rounds in which that user
        was served -- verify against the user class).  The switch is allowed
        only if every user has a non-zero count and a confidence width
        ``alpha_uniscorn * sqrt((1 + log(1 + n)) / (1 + n))`` that does not
        exceed ``self.distances / 4``.

        Args:
            t: time index, used only in the message printed on a switch.

        Returns:
            True if all users satisfy the condition (including the case of
            no users at all), False otherwise.
        """
        threshold = self.distances / 4
        for user in self.users.values():
            _, _, picked_t = user.get_info()
            if picked_t == 0:
                # A user without any observation blocks the switch.
                return False
            CB = self.alpha_uniscorn * np.sqrt((1 + np.log(1 + picked_t)) / (1 + picked_t))
            if CB > threshold:
                return False
        print("switch at time: ", t)
        return True

    def run(self, envir):
        """Run the algorithm against ``envir`` and return the collected results.

        The horizon is processed in phases: phase ``s`` lasts
        ``self.phase_cardinality ** s`` rounds and begins with
        ``self.init_each_stage()``.  The run stops as soon as ``self.rounds``
        rounds have been played or all phases are exhausted, whichever comes
        first.

        Args:
            envir: environment object providing ``generate_users()``,
                ``get_items()`` and ``feedback(items=, i=, k=)``.

        Returns:
            dict with the keys ``"runtime"``, ``"regret"``, ``"theta"``,
            ``"reward"``, ``"items_all_rounds"``, ``"reward_all_rounds"``
            and ``"cluster_num"``.
        """
        reward_all_rounds = []  # to save feedback in each round
        items_all_rounds = []  # to save the recommended item in each round
        theta_hat = {}  # to save the users' final estimate theta information
        # Initialised up front so the result dict can always be built, even
        # when no round is played at all.
        cluster_num = 0
        switch_flag = False
        total_phases = np.int64(np.log(self.rounds) / np.log(self.phase_cardinality))
        t = 0
        self.starttime = datetime.datetime.now()
        for s in range(1, total_phases + 1):
            self.init_each_stage()
            for _ in range(1, self.phase_cardinality**s + 1):
                t += 1
                assert t <= self.rounds, f" {t} wrong time rounds"
                if t % 1000 == 0:
                    print("t = %d" % t)
                user_index = envir.generate_users()  # random user arrives
                cluster_index = self.locate_user_index(user_index)
                current_cluster = self.clusters[cluster_index]
                items = envir.get_items()
                if switch_flag:
                    r_item_index = self.recommend(cluster_index, items)
                else:
                    # Exploration mode: pick an item uniformly at random.
                    r_item_index = np.random.choice(range(len(items)))
                selected_item = items[r_item_index]
                items_all_rounds.append(selected_item)
                self.reward[t - 1], instant_rwd, self.best_reward[t - 1] = envir.feedback(items=items, i=user_index, k=r_item_index)
                reward_all_rounds.append(instant_rwd)
                self.regret[t - 1] = self.best_reward[t - 1] - self.reward[t - 1]
                # update the user's information
                self.users[user_index].store_info(selected_item, instant_rwd, t - 1, self.reward[t - 1], self.best_reward[t - 1])
                # update the cluster's information
                current_cluster.store_info(selected_item, instant_rwd, t - 1, self.reward[t - 1], self.best_reward[t - 1])
                # check update
                self.split(user_index, t)
                # check merge
                self.merge(t)
                if not switch_flag:
                    # The switch is one-way: once set, it is never re-evaluated.
                    switch_flag = self.check_switch(t - 1)

                if t == self.rounds:
                    break
            if t == self.rounds:
                break

        # Finalise once, after the loops, so that the runtime and the summary
        # are available even if the phase schedule ended before `self.rounds`
        # rounds were played.
        if t > 0:
            cluster_num = self.num_clusters[t - 1]
            for user_idx in range(self.usernum):
                theta_hat[user_idx] = self.users[user_idx].theta
        print("test time rounds", t, self.rounds)
        self.endtime = datetime.datetime.now()
        self.runtime = self.endtime - self.starttime

        final_results = {
            "runtime": self.runtime,
            "regret": self.regret,
            "theta": theta_hat,
            "reward": self.reward,
            "items_all_rounds": items_all_rounds,
            "reward_all_rounds": reward_all_rounds,
            "cluster_num": cluster_num,
        }
        print("final cluster number: ", cluster_num)
        print({cidx: list(cluster.users) for cidx, cluster in self.clusters.items()})
        print({cidx: cluster.t for cidx, cluster in self.clusters.items()})
        return final_results
