"""
Provides implementations of semi-bandit algorithms.
"""
import numpy as np
import random

# Numerical tolerance used by Bandit.split_sample (invariant assert and the
# "coordinate exhausted" tests). Note that 10E-12 == 1e-11.
# NOTE(review): confirm whether 1e-12 was intended.
EPS = 10E-12


class Bandit:
    """
    Interface for semi-bandit algorithms.

    Subclasses implement :meth:`next`, :meth:`update` and :meth:`reset`. The
    base class validates the action set and provides :meth:`sample_action`
    for drawing a combinatorial action from marginal probabilities.
    """

    def __init__(self, action_set, dim, m_size):
        """
        :param action_set: ``"full"`` (unconstrained subsets) or ``"m-set"``
            (actions of ``m_size`` arms)
        :param dim: number of arms
        :param m_size: number of arms per action for ``"m-set"``
        :raises ValueError: if ``action_set`` is not recognised (ValueError is
            a subclass of ``Exception``, so ``except Exception`` still catches it)
        """
        self.dim = dim
        self.m_size = m_size
        if action_set == "full":
            self.unconstrained = True
        elif action_set == "m-set":
            self.unconstrained = False
        else:
            raise ValueError("Invalid action set %s for OSMD, abort." % action_set)

    def next(self):
        """Return the action to play in the current round."""
        raise NotImplementedError

    def reset(self) -> object:
        """Restore the algorithm to its initial state."""
        raise NotImplementedError

    def update(self, action, feedback):
        """Incorporate ``feedback[i]`` observed for arm ``action[i]``."""
        raise NotImplementedError

    def sample_action(self, x):
        """
        Draw a random combinatorial action from the marginal probabilities ``x``.

        For ``"full"`` each coordinate ``i`` is included independently with
        probability ``x[i]``. For ``"m-set"`` the marginals are decomposed by
        :meth:`split_sample` into a mixture of simple samplers (presumably
        ``sum(x)`` is expected to equal ``m_size`` -- TODO confirm).

        :param x: List[Float] or 1-D array, marginal probabilities
        :return: combinatorial action (list of coordinate indices)
        """
        if self.unconstrained:
            return [i for i, val in enumerate(x) if random.random() < val]

        # m-set problem. Accept plain lists as documented: ``-x`` and fancy
        # indexing below require an ndarray.
        x = np.asarray(x, dtype=float)
        order = np.argsort(-x)  # coordinates by decreasing marginal
        included = np.copy(x[order])  # split_sample mutates this in place
        remaining = 1.0 - included
        outer_samples = list(self.split_sample(included, remaining))
        # Round-off in split_sample can leave a weight slightly below zero or a
        # total slightly off 1, either of which makes np.random.choice raise.
        weights = np.clip([w for w, _, _ in outer_samples], 0.0, None)
        weights /= weights.sum()
        _, left, right = outer_samples[np.random.choice(len(outer_samples), p=weights)]
        if left == right - 1:
            sample = range(self.m_size)
        else:
            # Always take the first ``left`` coordinates, then fill up to
            # ``m_size`` uniformly at random from [left, right).
            candidates = list(range(left, right))
            random.shuffle(candidates)
            sample = list(range(left)) + candidates[:self.m_size - left]
        return [order[i] for i in sample]

    def split_sample(self, included, remaining):
        """
        Decompose sorted marginals into a mixture of simple m-set samplers.

        Each yielded ``(weight, left, right)`` is interpreted by
        :meth:`sample_action` as: take the first ``left`` coordinates and
        ``m_size - left`` coordinates uniformly at random from ``[left, right)``,
        with mixture probability ``weight``. The fallback
        ``(prop, m_size, m_size + 1)`` means "take the first ``m_size``
        coordinates".

        NOTE: ``included`` and ``remaining`` are modified in place.

        :param included: remaining marginal probabilities of sampling a
            coordinate, sorted in decreasing order (as ``sample_action`` does)
        :param remaining: remaining marginal probabilities of not sampling a coordinate
        :return: generator of ``(weight, left, right)`` tuples
        """
        prop = 1.0  # probability mass not yet assigned to a sampler
        left, right = 0, self.dim
        i = self.dim
        while left < right:
            i -= 1
            # Per-coordinate inclusion probability of the uniform part on [left, right).
            active = (self.m_size - left) / (right - left)
            inactive = 1.0 - active
            if active == 0 or inactive == 0:
                # Deterministic sampler: assign all remaining mass to it.
                yield (prop, left, right)
                return
            # Largest weight that keeps included[right - 1] and remaining[left]
            # non-negative after the subtraction below.
            weight = min(included[right - 1] / active, remaining[left] / inactive)
            yield weight, left, right
            prop -= weight
            assert prop >= -EPS
            # Whole arrays are updated, but only included[right - 1] and
            # remaining[left] are read back afterwards.
            included -= weight * active
            remaining -= weight * inactive
            # Shrink the window past coordinates whose mass is exhausted.
            while right > 0 and included[right - 1] <= EPS:
                right -= 1
            while left < self.dim and remaining[left] <= EPS:
                left += 1
            assert right - left <= i
        if prop > 0.0:
            yield (prop, self.m_size, self.m_size + 1)


class SortUCB(Bandit):
    """
    SortUCB: learn a ranking of the last ``m_size`` arms, then run UCB over
    the super arms built from that ranking.

    :meth:`next` dispatches on ``len(self.order)``:

    1. ``len(order) < m_size - num_sets`` (ranking phase): pull the
       ``num_sets`` least-observed unranked arms of the last block. The
       unranked arm with the smallest UCB is appended to ``order`` once its
       UCB is at most the k-th smallest unranked LCB
       (``k = m_size - num_sets - len(order)``, 0-based).
    2. ``len(order) < m_size`` (finishing phase): pull the arms that were
       still unranked when this phase began (``first_l_arms``). The unranked
       arm with the smallest UCB is appended once its UCB is at most every
       other unranked LCB. When the ranking is complete, per-arm statistics
       are cleared and ``order`` is reversed.
    3. Otherwise (exploitation phase): choose a super arm with :meth:`oracle`.

    Arm indices in ``order`` and ``first_l_arms`` are relative to the last
    block: relative arm ``k`` is global arm ``dim - m_size + k``.
    """

    def __init__(self, dim, action_set, m_size):
        super().__init__(action_set, dim, m_size)
        self.t = 0  # number of calls to next()
        self.st = 0  # number of exploitation-phase rounds
        self.te = np.zeros(self.dim)  # the T(e) in the paper: number of observations of arm e
        # T(s): per-super-arm counter; created here and in reset(), not read in this class.
        self.ts = np.zeros(self.dim // m_size)
        self.emp_sum = np.zeros(self.dim)  # running sum of feedback per arm
        self.num_sets = self.dim // m_size  # number of sets
        self.order = []  # learned ranking (relative indices)
        self.first_l_arms = []  # relative arms pulled in the finishing phase

    def next(self):
        """
        Return the arms (global indices) to pull in this round.

        :return: ``range`` in the first round, ndarray in the ranking and
            exploitation phases, list in the finishing phase
        """
        self.t += 1
        if self.t == 1:
            # First round: observe every arm of the last block once.
            return range(self.dim - self.m_size, self.dim)
        if len(self.order) < self.m_size - self.num_sets:
            return self._ranking_round()
        if len(self.order) < self.m_size:
            return self._finishing_round()

        self.st += 1
        # np.where evaluates both branches; silence the 0/0 warnings for
        # unobserved arms, whose value is replaced by 0 anyway.
        with np.errstate(divide="ignore", invalid="ignore"):
            emp_avg = np.where(self.te == 0, 0, np.divide(self.emp_sum, self.te))
        return self.oracle(emp_avg)

    def _last_block_bounds(self):
        """Return ``(upper_conf, lower_conf)`` for the last ``m_size`` arms."""
        last = slice(self.dim - self.m_size, self.dim)
        conf_width = np.sqrt(np.divide(0.5 * np.log(self.t), self.te[last]))
        emp_avg = np.divide(self.emp_sum[last], self.te[last])
        return emp_avg + conf_width, emp_avg - conf_width

    @staticmethod
    def _drop_ranked(ordering, ranked):
        """Return ``ordering`` without the arms in ``ranked``, keeping its order."""
        return ordering[~np.isin(ordering, ranked)]

    def _ranking_round(self):
        """Phase 1: try to rank one more arm; pull the least-observed unranked arms."""
        upper_conf, lower_conf = self._last_block_bounds()
        # Ascending orderings, with already-ranked arms removed *by value* so
        # that all three orderings cover the same set of arms.
        order_ucb = self._drop_ranked(np.argsort(upper_conf), self.order)
        order_lcb = self._drop_ranked(np.argsort(lower_conf), self.order)
        order_pulled_slots = self._drop_ranked(
            np.argsort(self.te[self.dim - self.m_size:self.dim]), self.order)

        k = self.m_size - self.num_sets - len(self.order)
        if upper_conf[order_ucb[0]] <= lower_conf[order_lcb[k]]:
            self.order.append(order_ucb[0])  # this arm's rank is now known
            print("$$$", self.order)

        # Pull the num_sets least-observed unranked arms (as global indices).
        adjusted_indices = order_pulled_slots[:self.num_sets] + self.dim - self.m_size
        print(adjusted_indices)
        return adjusted_indices

    def _finishing_round(self):
        """Phase 2: resolve the remaining order; reset statistics once complete."""
        upper_conf, lower_conf = self._last_block_bounds()
        if len(self.order) == self.m_size - self.num_sets:
            # No arm ranked in this phase yet: record the arms still unranked.
            ranked = set(self.order)
            self.first_l_arms = [i for i in range(self.m_size) if i not in ranked]

        order_ucb = self._drop_ranked(np.argsort(upper_conf), self.order)
        if all(upper_conf[order_ucb[0]] <= lcb for lcb in lower_conf[order_ucb[1:]]):
            self.order.append(order_ucb[0])

        if len(self.order) == self.m_size:
            # Ranking complete: start the exploitation phase with fresh statistics.
            self.te = np.zeros(self.dim)
            self.emp_sum = np.zeros(self.dim)
            self.order = self.order[::-1]

        offset = self.dim - self.m_size
        pulled = [x + offset for x in self.first_l_arms]
        print(pulled)
        return pulled

    def oracle(self, emp_avg):
        """
        Pick the super arm with the largest UCB index.

        Super arm ``l`` applies the learned ranking to block ``l``
        (``order + l * m_size``) and keeps its first ``l + 1`` arms. Its index is
        100 if the first of those arms is unobserved, otherwise the sum of their
        empirical means plus ``sqrt(0.5 * (l + 1) * log(t) / T(first arm))``.
        Ties go to the larger ``l``.

        :param emp_avg: per-arm empirical means, length ``dim``
        :return: ndarray of global arm indices of the chosen super arm
        """
        ranking = np.array(self.order)
        l_max = 0
        ucb_max = -np.inf  # true argmax even when every index is negative
        for l in range(self.num_sets):
            order_l = ranking + l * self.m_size
            pulls = self.te[order_l[0]]
            if pulls == 0:
                ucb_l = 100
            else:
                ucb_l = (emp_avg[order_l[0:l + 1]].sum()
                         + np.sqrt(0.5 * (l + 1) * np.log(self.t) / pulls))
            if ucb_l >= ucb_max:
                ucb_max = ucb_l
                l_max = l

        action = (ranking + l_max * self.m_size)[0:l_max + 1]
        print("return:", action)
        return action

    def update(self, action, feedback):
        """Add ``feedback[i]`` to arm ``action[i]``'s statistics."""
        for i, arm in enumerate(action):
            self.emp_sum[arm] += feedback[i]
            self.te[arm] += 1

    def reset(self):
        """Restore the state created by ``__init__``."""
        self.t = 0
        self.st = 0
        self.te = np.zeros(self.dim)  # the T(e) in the paper: number of observations of arm e
        self.ts = np.zeros(self.dim // self.m_size)  # the T(s): number of observations of super arm after ranking.
        self.emp_sum = np.zeros(self.dim)
        self.num_sets = self.dim // self.m_size  # number of sets.
        self.order = []
        self.first_l_arms = []



class CombUCB(Bandit):
    """
    Implementation of CombUCB as a baseline algorithm.

    Each round every arm gets the optimistic index
    ``emp_avg + sqrt(log(t) / T(e))`` (100 for never-observed arms). Then
    :meth:`oracle` scores set ``l`` by the sum of the ``l + 1`` largest indices
    inside block ``l`` and plays those arms for the best-scoring set.
    """

    def __init__(self, dim, action_set, m_size):
        super().__init__(action_set, dim, m_size)
        self.t = 0  # number of calls to next()
        self.te = np.zeros(self.dim)  # the T(e): number of observations of arm e
        self.emp_sum = np.zeros(self.dim)  # running sum of feedback per arm
        self.num_sets = self.dim // self.m_size  # number of sets.
        print(self.num_sets)

    def next(self):
        """Return the global arm indices (ndarray) to pull in this round."""
        self.t += 1
        # np.where evaluates both branches; silence the x/0 and 0/0 warnings
        # for unobserved arms, whose values are replaced (100 / 0) anyway.
        with np.errstate(divide="ignore", invalid="ignore"):
            conf_width = np.where(self.te == 0, 100, np.sqrt(np.log(self.t) / self.te))
            emp_avg = np.where(self.te == 0, 0, np.divide(self.emp_sum, self.te))
        upper_conf = emp_avg + conf_width
        return self.oracle(upper_conf)

    def oracle(self, lower_conf):
        """
        Choose the set whose top-``l + 1`` arms have the largest summed index.

        :param lower_conf: per-arm index, length ``dim``. Despite the name,
            :meth:`next` passes upper confidence bounds.
        :return: ndarray of global arm indices: the ``l_max + 1`` arms of block
            ``l_max`` with the largest index, in decreasing order of index.
            Ties between sets go to the smaller ``l``.
        """
        l_max = 0
        rewards_max = -np.inf  # true argmax even when every sum is <= 0
        for l in range(self.num_sets):
            segment = lower_conf[l * self.m_size:(l + 1) * self.m_size]
            top = np.argsort(segment)[::-1][0:l + 1]
            rewards_sum = sum(segment[top])  # sequential sum, as before
            if rewards_sum > rewards_max:
                rewards_max = rewards_sum
                l_max = l

        segment = lower_conf[l_max * self.m_size:(l_max + 1) * self.m_size]
        action = np.argsort(segment)[::-1][0:l_max + 1] + l_max * self.m_size
        print(action)
        return action

    def update(self, action, feedback):
        """Add ``feedback[i]`` to arm ``action[i]``'s statistics."""
        for i, arm in enumerate(action):
            self.emp_sum[arm] += feedback[i]
            self.te[arm] += 1

    def reset(self):
        """Restore the state created by ``__init__`` (without printing)."""
        self.t = 0
        self.te = np.zeros(self.dim)  # the T(e): number of observations of arm e
        self.emp_sum = np.zeros(self.dim)
        self.num_sets = self.dim // self.m_size  # number of sets.