"""
Provides implementations of semi-bandit algorithms.
"""
import numpy as np
import random

# Numerical tolerance used by Bandit.split_sample: probability mass at or below
# this value is treated as exhausted. Note that 10E-12 equals 1e-11.
EPS = 10E-12



class Bandit:
    """
    Interface for semi-bandit algorithms.

    Stores the problem sizes and provides ``sample_action``, which draws a
    combinatorial action from a vector of marginal probabilities. Subclasses
    implement ``next``, ``update`` and ``reset``.
    """

    def __init__(self, action_set, dim, m_size, l_size=None):
        """
        :param action_set: "full" (every arm is sampled independently) or
            "m-set" (actions consist of exactly ``m_size`` arms)
        :param dim: number of base arms
        :param m_size: number of arms per action in the m-set problem
        :param l_size: optional layer count; only stored here. It defaults to
            None so that subclasses without a layer structure can omit it.
        :raises ValueError: if ``action_set`` is neither "full" nor "m-set"
        """
        self.dim = dim
        self.m_size = m_size
        self.l_size = l_size
        if action_set == "full":
            self.unconstrained = True
        elif action_set == "m-set":
            self.unconstrained = False
        else:
            # ValueError is a subclass of Exception, so existing handlers still match.
            raise ValueError("Invalid action set %s for OSMD, abort." % action_set)

    def next(self):
        """Return the action to play in the next round (implemented by subclasses)."""
        raise NotImplementedError

    def reset(self) -> None:
        """Restore the initial state of the algorithm (implemented by subclasses)."""
        raise NotImplementedError

    def update(self, action, feedback):
        """Incorporate the feedback observed for ``action`` (implemented by subclasses)."""
        raise NotImplementedError

    def sample_action(self, x):
        """
        Sample a combinatorial action from marginal probabilities.

        :param x: List[Float] or array, marginal probabilities of each arm
        :return: combinatorial action (list of arm indices)
        """
        if self.unconstrained:
            # Each arm is included independently with its own probability.
            return [i for i, val in enumerate(x) if random.random() < val]

        # m-set problem: decompose x into a mixture of simple distributions,
        # pick one component, then sample from it.
        x = np.asarray(x, dtype=float)  # accept plain lists as documented
        order = np.argsort(-x)  # arms by decreasing marginal probability
        included = np.copy(x[order])  # working copies, consumed by split_sample
        remaining = 1.0 - included
        outer_samples = list(self.split_sample(included, remaining))
        weights = [component[0] for component in outer_samples]
        _, left, right = outer_samples[np.random.choice(len(outer_samples), p=weights)]
        if left == right - 1:
            # Degenerate component: take the m_size most probable arms.
            sample = range(self.m_size)
        else:
            # Positions below `left` are always taken; the rest of the action
            # is drawn uniformly from positions [left, right).
            candidates = list(range(left, right))
            random.shuffle(candidates)
            sample = list(range(left)) + candidates[:self.m_size - left]
        return [order[i] for i in sample]

    def split_sample(self, included, remaining):
        """
        Generate the mixture components used by ``sample_action``.

        Both arrays are modified in place while the generator runs.

        :param included: remaining marginal probabilities of sampling a coordinate
        :param remaining: remaining marginal probabilities of not sampling a coordinate
        :return: generator of ``(weight, left, right)`` triples, the remaining
            sampling distributions
        """
        prop = 1.0  # probability mass not yet assigned to a component
        left, right = 0, self.dim
        i = self.dim  # upper bound on the window width, checked each iteration
        while left < right:
            i -= 1
            # Probability that a position in [left, right) is chosen by a
            # component that takes positions < left and fills up uniformly.
            active = (self.m_size - left) / (right - left)
            inactive = 1.0 - active
            if active == 0 or inactive == 0:
                # Deterministic component: it absorbs all remaining mass.
                yield (prop, left, right)
                return
            weight = min(included[right - 1] / active, remaining[left] / inactive)
            yield weight, left, right
            prop -= weight
            assert prop >= -EPS
            included -= weight * active
            remaining -= weight * inactive
            # Shrink the window past positions whose mass is used up.
            while right > 0 and included[right - 1] <= EPS:
                right -= 1
            while left < self.dim and remaining[left] <= EPS:
                left += 1
            assert right - left <= i
        if prop > 0.0:
            # Leftover mass goes to the component that sample_action maps to
            # "take the first m_size positions" (left == right - 1).
            yield (prop, self.m_size, self.m_size + 1)



class SortUCB(Bandit):
    """
    Two-phase UCB algorithm over ``l_size`` layers of ``m_size`` arms each.

    Phase 1 repeatedly plays every arm of the last layer (indices
    ``index_head`` .. ``index_end - 1``) and ranks them using confidence
    bounds. Phase 2 first plays all arms in blocks of ``m_size`` and then
    lets ``oracle`` choose a prefix of the ranking inside one layer.
    """

    def __init__(self, dim, action_set, m_size, l_size):
        """
        :param dim: number of base arms
        :param action_set: "full" or "m-set", forwarded to ``Bandit``
        :param m_size: number of arms per layer (M)
        :param l_size: number of layers (L)
        """
        super().__init__(action_set, dim, m_size, l_size)
        self.t = 0  # phase-1 time counter, advanced by ceil(M / L) per call
        self.st = 0  # phase-2 round counter
        self.te = np.zeros(self.dim)  # the T(e) in the paper: number of observations of the base arm
        self.emp_sum = np.zeros(self.dim)  # sum of observed feedback per arm
        self.order = []  # layer-relative arm indices in ranking order
        # Index window of the layer played during phase 1 (the last layer).
        self.index_head = (self.l_size - 1) * self.m_size
        self.index_end = self.l_size * self.m_size

    def next(self):
        """
        Return the arms to play in the next round.

        :return: a range of arm indices, or the array returned by ``oracle``
        """
        ranking_layer = range(self.index_head, self.index_end)

        if self.t == 0:
            # Very first round: nothing has been observed yet, so confidence
            # bounds are undefined. Just play the whole ranking layer.
            self.t += np.ceil(self.m_size / self.l_size)
            return ranking_layer

        if len(self.order) != self.m_size:
            # Phase 1: ranking of the arms inside the ranking layer.
            self.t += np.ceil(self.m_size / self.l_size)

            counts = self.te[self.index_head:self.index_end]
            conf_width = np.sqrt(np.divide(1.5 * np.log(self.t), counts))
            emp_avg = np.divide(self.emp_sum[self.index_head:self.index_end], counts)
            upper_conf = emp_avg + conf_width
            lower_conf = emp_avg - conf_width

            # Unranked arms by increasing upper confidence bound.
            order_ucb = np.argsort(upper_conf)
            for ranked in self.order:
                order_ucb = order_ucb[order_ucb != ranked]

            # The arm with the smallest UCB is ranked once its UCB does not
            # exceed the LCB of any other unranked arm.
            if all(upper_conf[order_ucb[0]] <= i for i in lower_conf[order_ucb[1:]]):
                self.order.append(order_ucb[0])

            if len(self.order) == self.m_size:
                # Ranking complete: restart the statistics for phase 2 and
                # reverse the order so the arm ranked last comes first.
                self.te = np.zeros(self.dim)
                self.emp_sum = np.zeros(self.dim)
                self.order = self.order[::-1]

            return ranking_layer

        # Phase 2
        # Need to modified init period for both CombUCB and SortUCB
        self.st += 1
        init_rounds = int(self.dim + self.m_size - 1) // int(self.m_size)  # ceil(d / m)
        if self.st <= init_rounds:
            # Explore all arms once, m_size at a time.
            if self.st * self.m_size <= self.dim:
                return range((self.st - 1) * self.m_size, self.st * self.m_size)
            # Last, partial block: negative indices address the final m_size arms.
            return range(-self.m_size, 0)

        emp_avg = np.divide(self.emp_sum, self.te)
        return self.oracle(emp_avg)

    def oracle(self, emp_avg):
        """
        Pick the layer whose ranked prefix has the largest optimistic value.

        For layer ``l`` the candidate is the first ``l + 1`` arms of the
        ranking shifted into that layer; its value is the sum of their
        empirical means plus a bonus based on the observation count of the
        first-ranked arm.

        :param emp_avg: empirical mean of every base arm
        :return: array of arm indices of the chosen candidate
        """
        l_max = 0
        ucb_max = 0

        for l in range(self.l_size):
            order_l = np.array(self.order) + l * self.m_size
            ucb_l = emp_avg[order_l[0:l + 1]].sum() + np.sqrt(
                np.divide(1.5 * np.log(self.t), self.te[order_l[0]]))

            if ucb_l >= ucb_max:
                ucb_max = ucb_l
                l_max = l

        chosen = (np.array(self.order) + l_max * self.m_size)[0:l_max + 1]
        print(chosen)
        return chosen

    def update(self, action, feedback):
        """
        Record the feedback of the played arms.

        :param action: sequence of played arm indices
        :param feedback: sequence of observed values, aligned with ``action``
        """
        for i in range(len(action)):
            arm = action[i]
            self.emp_sum[arm] += feedback[i]
            self.te[arm] += 1

    def reset(self):
        """Reset counters, statistics and the learned ranking."""
        self.t = 0
        self.st = 0
        self.te = np.zeros(self.dim)  # the T(e) in the paper: number of observations of arm e
        self.ts = np.zeros(int(self.dim / self.m_size))  # the T(s) : number of observations of super arm after ranking.
        self.emp_sum = np.zeros(self.dim)
        self.order = []



class CombUCB(Bandit):
    """
    Implementation of CombUCB as a baseline algorithm

    The arms are treated as ``num_sets`` consecutive sets of ``m_size`` arms;
    arms beyond ``num_sets * m_size`` are never selected by ``oracle``.
    """

    def __init__(self, dim, action_set, m_size):
        """
        :param dim: number of base arms
        :param action_set: "full" or "m-set", forwarded to ``Bandit``
        :param m_size: number of arms per set
        """
        # Bandit.__init__ also takes a layer count; the number of sets plays
        # that role here (it bounds the same loop in the oracle).
        super().__init__(action_set, dim, m_size, int(dim / m_size))
        self.t = 0
        self.te = np.zeros(self.dim)  # the T(e): number of observations of arm e
        self.emp_sum = np.zeros(self.dim)  # sum of observed feedback per arm
        self.num_sets = int(self.dim / self.m_size)  # number of sets.
        print(self.num_sets)

    def next(self):
        """
        Return the arms to play in the next round.

        :return: array of arm indices chosen by ``oracle``
        """
        self.t += 1

        # There is no separate initial exploration phase: arms that were never
        # observed get a fixed bonus of 100 and an empirical mean of 0.
        # The division is evaluated for every arm before np.where masks the
        # unobserved ones, so silence the resulting 0-division warnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            conf_width = np.where(self.te == 0, 100, np.sqrt(0.5 * np.log(self.t) / self.te))
            emp_avg = np.where(self.te == 0, 0, np.divide(self.emp_sum, self.te))
        # Despite the name, this is the optimistic (upper) index: mean + width.
        lower_conf = emp_avg + conf_width
        return self.oracle(lower_conf)

    def oracle(self, lower_conf):
        """
        Pick the set whose best arms have the largest total index.

        For set ``l`` the candidate consists of its ``l + 1`` arms with the
        largest index values.

        :param lower_conf: per-arm index values (mean plus confidence width)
        :return: array of arm indices of the chosen candidate
        """
        l_max = 0
        rewards_max = 0
        for l in range(self.num_sets):
            offset = l * self.m_size
            top_arms = np.argsort(lower_conf[offset:offset + self.m_size])[::-1][0:l + 1]
            rewards_sum = 0
            for o in top_arms:
                rewards_sum = rewards_sum + lower_conf[offset + o]
            if rewards_sum > rewards_max:
                rewards_max = rewards_sum
                l_max = l

        offset = l_max * self.m_size
        chosen = np.argsort(lower_conf[offset:offset + self.m_size])[::-1][0:l_max + 1] + offset
        print(chosen)

        return chosen

    def update(self, action, feedback):
        """
        Record the feedback of the played arms.

        :param action: sequence of played arm indices
        :param feedback: sequence of observed values, aligned with ``action``
        """
        for i in range(len(action)):
            arm = action[i]
            self.emp_sum[arm] += feedback[i]
            self.te[arm] += 1

    def reset(self):
        """Reset the round counter and all per-arm statistics."""
        self.t = 0
        self.te = np.zeros(self.dim)  # the T(e): number of observations of arm e
        self.emp_sum = np.zeros(self.dim)
        self.num_sets = int(self.dim / self.m_size)  # number of sets.


class SortUCBSD(Bandit):
    """
    Variant of SortUCB that ranks the arms of one fixed layer while playing
    only ``layer`` arms per round, then lets ``oracle`` choose a prefix of the
    ranking inside one of the ``l_size`` layers.
    """

    # The initial exploration is hard-coded: 3 rounds of 2 arms each, i.e. the
    # first 6 arms of the ranking layer.
    # NOTE(review): this looks tuned for m_size == 6 - confirm before using
    # other sizes.
    _INIT_ROUNDS = 3
    _INIT_ARMS_PER_ROUND = 2

    def __init__(self, dim, action_set, m_size, l_size):
        """
        :param dim: number of base arms
        :param action_set: "full" or "m-set", forwarded to ``Bandit``
        :param m_size: number of arms per layer (M)
        :param l_size: number of layers (L)
        """
        super().__init__(action_set, dim, m_size, l_size)
        self.t = 0  # round counter of the ranking phase
        self.st = 0  # round counter of the second phase
        self.te = np.zeros(self.dim)  # the T(e) in the paper: number of observations of the base arm
        self.emp_sum = np.zeros(self.dim)  # sum of observed feedback per arm
        self.order = []  # layer-relative arm indices in ranking order
        # CCC
        self.layer = 2  # 1-based layer used for ranking; also arms played per round
        self.index_head = (self.layer - 1) * self.m_size
        self.index_end = self.layer * self.m_size
        self.first_l_arms = []  # arms still unranked when m_size - layer arms are ranked

    def _layer_bounds(self):
        """Return (upper, lower) confidence bounds of the ranking layer's arms."""
        counts = self.te[self.index_head:self.index_end]
        conf_width = np.sqrt(np.divide(0.4 * np.log(self.t), counts))
        emp_avg = np.divide(self.emp_sum[self.index_head:self.index_end], counts)
        return emp_avg + conf_width, emp_avg - conf_width

    def _drop_ranked(self, ordering):
        """Return ``ordering`` without the arms that are already ranked."""
        # Filter by value: an arm sits at different positions in the UCB and
        # LCB orderings, so positions from one must not be reused for the other.
        for ranked in self.order:
            ordering = ordering[ordering != ranked]
        return ordering

    def _least_observed(self, order_ucb):
        """Return global indices of the ``layer`` least-observed unranked arms."""
        counts = self.te[order_ucb + self.index_head]
        return (order_ucb[np.argsort(counts)] + self.index_head)[:self.layer]

    def next(self):
        """
        Return the arms to play in the next round.

        :return: list or array of arm indices
        """
        # CCC
        if self.t < self._INIT_ROUNDS:
            # Initial exploration of the ranking layer, two fresh arms per round.
            start = self.m_size + self._INIT_ARMS_PER_ROUND * self.t
            self.t += 1
            arms = [start, start + 1]
            print(arms)
            return arms

        if len(self.order) < self.m_size - self.l_size:
            # Ranking, first stage: eliminate arms until only l_size remain unranked.
            self.t += 1

            upper_conf, lower_conf = self._layer_bounds()
            # Unranked arms by increasing UCB and by increasing LCB.
            order_ucb = self._drop_ranked(np.argsort(upper_conf))
            order_lcb = self._drop_ranked(np.argsort(lower_conf))

            # CCC
            adjusted_indices = self._least_observed(order_ucb)

            # Rank the arm with the smallest UCB once that UCB is no larger
            # than the LCB at this position of the LCB ordering.
            if upper_conf[order_ucb[0]] <= lower_conf[order_lcb[self.m_size - self.l_size - len(self.order)]]:
                self.order.append(order_ucb[0])  # position learned, add to order
                print("$$$", self.order)

            print(adjusted_indices)
            return adjusted_indices

        if len(self.order) < self.m_size:
            # Ranking, second stage: order the remaining arms among themselves.
            # NOTE(review): unlike the first stage, self.t is not advanced here.
            upper_conf, lower_conf = self._layer_bounds()
            order_ucb = np.argsort(upper_conf)  # arms by increasing UCB
            bad_arms = set(self.order)

            # CCC
            if len(self.order) == self.m_size - self.layer:
                # Remember the arms that are still unranked at this point.
                self.first_l_arms = [i for i in range(self.m_size) if i not in bad_arms]

            # Remove the arms whose position has already been learned.
            order_ucb = self._drop_ranked(order_ucb)
            # CCC
            adjusted_indices = self._least_observed(order_ucb)

            if all(upper_conf[order_ucb[0]] <= i for i in lower_conf[order_ucb[1:]]):
                self.order.append(order_ucb[0])

            if len(self.order) == self.m_size:
                # Ranking complete: restart the statistics for the second
                # phase and reverse the order so the arm ranked last comes first.
                self.te = np.zeros(self.dim)
                self.emp_sum = np.zeros(self.dim)
                self.order = self.order[::-1]

            # CCC
            print("111", len(self.order))
            if len(self.order) >= self.m_size - self.layer:
                arms = [x + self.m_size for x in self.first_l_arms]
                print(arms)
                return arms
            print(adjusted_indices)
            return adjusted_indices

        # Period 2
        # Need to modified init period for both CombUCB and SortUCB
        self.st += 1
        # The division is evaluated for every arm before np.where masks the
        # unobserved ones, so silence the resulting 0-division warnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            emp_avg = np.where(self.te == 0, 0, np.divide(self.emp_sum, self.te))
        return self.oracle(emp_avg)

    def oracle(self, emp_avg):
        """
        Pick the layer whose ranked prefix has the largest optimistic value.

        For layer ``l`` the candidate is the first ``l + 1`` arms of the
        ranking shifted into that layer. If the first-ranked arm of the layer
        is unobserved the value is 100; otherwise it is the sum of the
        candidates' empirical means plus a confidence bonus.

        :param emp_avg: empirical mean of every base arm
        :return: array of arm indices of the chosen candidate
        """
        l_max = 0
        ucb_max = 0
        for l in range(self.l_size):
            order_l = np.array(self.order) + l * self.m_size
            lead_count = self.te[order_l[0]]
            if lead_count == 0:
                ucb_l = 100
            else:
                ucb_l = emp_avg[order_l[0: l + 1]].sum() + np.sqrt(
                    np.divide(0.4 * (l + 1) * np.log(self.t), lead_count))

            if ucb_l >= ucb_max:
                ucb_max = ucb_l
                l_max = l

        chosen = (np.array(self.order) + l_max * self.m_size)[0:l_max + 1]
        print("return:", chosen)

        return chosen

    def update(self, action, feedback):
        """
        Record the feedback of the played arms.

        :param action: sequence of played arm indices
        :param feedback: sequence of observed values, aligned with ``action``
        """
        for i in range(len(action)):
            arm = action[i]
            self.emp_sum[arm] += feedback[i]
            self.te[arm] += 1

    def reset(self):
        """Reset counters, statistics and everything learned during ranking."""
        self.t = 0
        self.st = 0
        self.te = np.zeros(self.dim)  # the T(e) in the paper: number of observations of arm e
        self.ts = np.zeros(int(self.dim / self.m_size))  # the T(s) : number of observations of super arm after ranking.
        self.emp_sum = np.zeros(self.dim)
        self.order = []
        self.first_l_arms = []  # otherwise arms from the previous run would leak into the next