# -*- coding: utf-8 -*-
# final version

import datetime

import networkx as nx
import numpy as np

import env.Environment as Envi
from algorithms import Base
from algorithms.CLUB import CLUB
from algorithms.LinUCB import LinUCB


class UniCorn(CLUB):
    """CLUB variant with an initial uniform-random phase.

    ``run`` picks items uniformly at random until ``check_switch`` reports
    that every user's confidence width is small enough relative to
    ``self.distances``; from then on items come from ``self.recommend``.

    NOTE(review): ``self.d``, ``self.rounds``, ``self.usernum``,
    ``self.users``, ``self.clusters``, ``self.reward``, ``self.best_reward``,
    ``self.regret``, ``locate_user_index``, ``recommend`` and ``update`` are
    not defined here; they are presumably provided by ``CLUB`` -- confirm.
    """

    def __init__(self, user_num, d, T, dist=5):
        """Initialise the base algorithm and the UniCorn-specific constants.

        Args:
            user_num: Forwarded unchanged to ``CLUB.__init__``.
            d: Forwarded unchanged to ``CLUB.__init__``.
            T: Forwarded unchanged to ``CLUB.__init__``.
            dist: Stored as ``self.distances``; ``check_switch`` compares each
                user's confidence bound against ``dist / 4``.
        """
        super().__init__(user_num, d, T)
        self.distances = dist
        # Stored for callers; nothing in this class reads it back (run()
        # decides when to switch purely through check_switch()).
        self.exploration_time = int(1 * self.d / (0.5 * 5.2**2) * np.log(self.rounds * self.d * self.usernum))

        # Scaling constants: alpha_unicorn is used by check_switch and
        # alpha_cluster by if_add; alpha_switch is not read in this class.
        self.alpha_unicorn = 1.4
        self.alpha_cluster = 1.75
        self.alpha_switch = 1.1

    def set_exploration_time(self, exp_time):
        """Overwrite ``self.exploration_time`` with ``exp_time``."""
        self.exploration_time = exp_time

    def set_exploratrion_time(self, exp_time):
        """Misspelled alias of ``set_exploration_time``, kept for existing callers."""
        self.set_exploration_time(exp_time)

    @staticmethod
    def _confidence_width(t):
        """Return ``sqrt((1 + log(1 + t)) / (1 + t))`` for a count ``t``."""
        return np.sqrt((1 + np.log(1 + t)) / (1 + t))

    def if_add(self, user_index1, user_index2) -> bool:
        """Compare two users' normalised estimates against a joint width.

        Args:
            user_index1: Key of the first user in ``self.users``.
            user_index2: Key of the second user in ``self.users``.

        Returns:
            ``True`` when the distance between the two unit-normalised
            ``theta`` vectors is at least
            ``alpha_cluster / 2 * (width(t1) + width(t2))``, else ``False``.
            ``False`` is also returned when either ``theta`` has zero norm.
        """
        user1 = self.users[user_index1]
        user2 = self.users[user_index2]

        norm1 = np.linalg.norm(user1.theta)
        norm2 = np.linalg.norm(user2.theta)
        # A zero-norm estimate cannot be normalised. Dividing by zero would
        # give NaN (plus a RuntimeWarning), and NaN never satisfies ">=", so
        # returning False here keeps the outcome without the invalid division.
        if norm1 == 0 or norm2 == 0:
            return False

        direction1 = user1.theta / norm1
        direction2 = user2.theta / norm2
        threshold = self.alpha_cluster / 2 * (
            self._confidence_width(user1.t) + self._confidence_width(user2.t)
        )
        return bool(np.linalg.norm(direction1 - direction2) >= threshold)

    def check_switch(self, t) -> bool:
        """Report whether the random phase can end.

        Users are visited in ``self.users`` order and the scan stops at the
        first user that fails, so diagnostics are printed only for users
        reached before that point.

        Args:
            t: Value echoed in the "switch at time" message; not used in the
                decision itself.

        Returns:
            ``False`` as soon as one user has a zero count (third element of
            ``user.get_info()``) or a confidence bound above
            ``self.distances / 4``; ``True`` otherwise.
        """
        for idx, user in self.users.items():
            _, _, picked_t = user.get_info()
            if picked_t == 0:
                return False
            CB = self.alpha_unicorn * self._confidence_width(picked_t)
            print("user idx: ", idx, "picked_t: ", picked_t, "CB: ", CB)
            if CB > self.distances / 4:
                return False
        print("switch at time: ", t)
        return True

    def run(self, envir):
        """Play ``self.rounds`` rounds against ``envir`` and collect results.

        Args:
            envir: Environment object. This method calls
                ``envir.generate_users()``, ``envir.get_items()`` and
                ``envir.feedback(items=..., i=..., k=...)``, the last of which
                must return a 3-tuple unpacked as (reward, instant reward,
                best reward).

        Returns:
            Dict with keys ``runtime``, ``regret``, ``theta``, ``reward``,
            ``items_all_rounds``, ``reward_all_rounds`` and ``cluster_num``.
        """
        reward_all_rounds = []
        items_all_rounds = []
        self.starttime = datetime.datetime.now()
        switch_flag = False
        for i in range(1, self.rounds + 1):
            if i % 1000 == 0:
                print("round: ", i)
            # get the user index
            user_index = envir.generate_users()
            items = envir.get_items()
            cluster_index = self.locate_user_index(user_index)
            if not switch_flag:
                # Random phase: pick uniformly among the offered items.
                r_item_index = np.random.choice(range(len(items)))
            else:
                r_item_index = self.recommend(cluster_index, items)
            selected_item = items[r_item_index]
            items_all_rounds.append(selected_item)
            self.reward[i - 1], instant_rwd, self.best_reward[i - 1] = envir.feedback(items=items, i=user_index, k=r_item_index)
            reward_all_rounds.append(instant_rwd)
            # update the user's information
            self.users[user_index].store_info(selected_item, instant_rwd, i - 1, self.reward[i - 1], self.best_reward[i - 1])
            self.clusters[cluster_index].store_info(selected_item, instant_rwd, i - 1, self.reward[i - 1], self.best_reward[i - 1])
            self.regret[i - 1] = self.best_reward[i - 1] - self.reward[i - 1]
            self.update(user_index, i - 1)
            # Once the switch has happened it is permanent; only keep testing
            # while still in the random phase.
            if not switch_flag:
                switch_flag = self.check_switch(i - 1)

        # Snapshot taken after the loop: same state as at the end of the last
        # round, and still well-defined when self.rounds == 0 (previously
        # cluster_num was unbound in that case).
        cluster_num = len(self.clusters)
        theta_hat = {user_idx: self.users[user_idx].theta for user_idx in range(self.usernum)}

        self.endtime = datetime.datetime.now()
        self.runtime = self.endtime - self.starttime
        final_results = {
            "runtime": self.runtime,
            "regret": self.regret,
            "theta": theta_hat,
            "reward": self.reward,
            "items_all_rounds": items_all_rounds,
            "reward_all_rounds": reward_all_rounds,
            "cluster_num": cluster_num,
        }

        print("final cluster number: ", cluster_num)
        print({cidx: list(cluster.users) for cidx, cluster in self.clusters.items()})
        return final_results
