import numpy as np

class ZSPO:
    """Sign-based zeroth-order optimizer for a tabular softmax policy.

    The policy is a ``(state_size, action_size)`` weight table; actions are
    sampled from a softmax over the row of the current state.  A round is:
    :meth:`perturb_weight` (Gaussian perturbation of the whole table),
    :meth:`add_to_preference` (collect preference probabilities), then
    :meth:`train`, which moves the weights by ``lr`` along the perturbation
    direction, with the sign given by ``sign(sum(p - 0.5))`` over the
    collected probabilities ``p``.
    """

    def __init__(self, time_horizon=None, state_size=None, action_size=None,
                 lr=None, batch_size=None, perturbation_dist=None):
        """
        :param time_horizon: num of steps (default 100); stored only, not used by this class
        :param state_size: num of states (default 11)
        :param action_size: num of actions (default 2)
        :param lr: learning rate (default 0.001)
        :param batch_size: batch size before one update (default 10); stored only, not used by this class
        :param perturbation_dist: scale of the Gaussian perturbation (default 0.01)
        """
        # Identity checks (`is None`) instead of `== None`: correct for any
        # argument type, including numpy values that overload `==`.
        if time_horizon is None:
            time_horizon = 100
        if lr is None:
            lr = 0.001
        if batch_size is None:
            batch_size = 10
        if perturbation_dist is None:
            perturbation_dist = 0.01
        if state_size is None:
            state_size = 11
        if action_size is None:
            action_size = 2

        self.state_space = np.arange(state_size)
        self.state_size = len(self.state_space)
        self.action_space = np.arange(action_size)
        self.action_size = len(self.action_space)
        self.time_horizon = time_horizon

        # Policy table, randomly initialised; one row of logits per state.
        self.weight = np.random.randn(self.state_size, self.action_size)
        self.weight_perturbed = self.weight.copy()

        self.learning_rate = lr
        self.batch_size = batch_size
        self.perturbation_dist = perturbation_dist

        # Appended to by add_to_preference(); never cleared inside this class.
        self.preference_prob = []
        # Set by perturb_weight(); must be set before train() is called.
        self.perturbation_vec = None
        # Number of train() calls so far; drives the decay schedule.
        self.count = 0

    def select_action(self, state, perturb):
        """
        Sample an action from the softmax policy of ``state``.

        :param state: state index used as a row of the weight table
        :param perturb: 0 (original weight) or anything else (perturbed weight)
        :return: action: sampled action from ``action_space``
        :return: action_prob: probability of each action
        """
        weight_state = self.weight[state, :] if perturb == 0 else self.weight_perturbed[state, :]
        # Subtract the max before exponentiating for numerical stability.
        exp_weight = np.exp(weight_state - np.max(weight_state))
        action_prob = exp_weight / np.sum(exp_weight)
        action = np.random.choice(self.action_space, p=action_prob)
        return action, action_prob

    def perturb_weight(self):
        """
        Draw a standard-normal direction and build the perturbed weights.

        ``weight_perturbed = weight + perturbation_dist * perturbation_vec``;
        the direction is NOT normalised in this class.

        :return: the perturbation direction, shape ``(state_size, action_size)``
        """
        self.perturbation_vec = np.random.randn(self.state_size, self.action_size)
        self.weight_perturbed = self.weight + self.perturbation_dist * self.perturbation_vec
        return self.perturbation_vec

    def add_to_preference(self, prob):
        """
        Record one preference probability for the next :meth:`train` call.

        :param prob: preference probability; presumably P(perturbed policy
            preferred) — TODO confirm with the caller
        """
        self.preference_prob.append(prob)

    def train(self):
        """
        Take one sign-based step along the last perturbation direction.

        The update is ``weight += lr * sign(sum(p - 0.5)) * perturbation_vec``.
        An empty or exactly balanced batch gives sign 0, so the weights do not
        change.  After the update, whenever ``count + 1 >= 200`` and
        ``count + 1`` is a multiple of 10, both the learning rate and the
        perturbation distance are multiplied by 0.99.

        NOTE(review): ``preference_prob`` is not cleared here; presumably the
        caller resets it between batches — confirm.

        :return: None
        """
        self.count += 1
        prob = np.array(self.preference_prob)
        majority = np.sign(np.sum(prob - 0.5))
        self.weight = self.weight + self.learning_rate * majority * self.perturbation_vec

        step = self.count + 1
        if step >= 200 and step % 10 == 0:
            self.learning_rate = 0.99 * self.learning_rate
            self.perturbation_dist = 0.99 * self.perturbation_dist



class ZPG:
    """Zeroth-order policy gradient for a tabular softmax policy from preferences.

    Each round the whole weight table is perturbed along a random unit-norm
    direction.  Collected preference probabilities are mapped back to reward
    differences through the inverse of the link function ('BT' or 'WB').
    :meth:`train` then steps the weights along the perturbation direction,
    scaled by ``d / perturbation_dist`` with ``d = state_size * action_size``.
    """

    def __init__(self, time_horizon=None, state_size=None, action_size=None,
                 lr=None, batch_size=None, perturbation_dist=None,
                 link_function=None, epsilon=None):
        """
        :param time_horizon: num of steps (default 100); stored only, not used by this class
        :param state_size: num of states (default 11)
        :param action_size: num of actions (default 2)
        :param lr: learning rate (default 0.001)
        :param batch_size: batch size before one update (default 10); stored only, not used by this class
        :param perturbation_dist: distance of the perturbed weights from the current ones (default 0.01)
        :param link_function: 'BT' or 'WB' (default 'BT'; any value other than 'WB' is treated as 'BT')
        :param epsilon: clipping margin for preference probabilities (default 1e-3)
        """
        # Identity checks (`is None`) instead of `== None`.
        if time_horizon is None:
            time_horizon = 100
        if lr is None:
            lr = 0.001
        if batch_size is None:
            batch_size = 10
        if perturbation_dist is None:
            perturbation_dist = 0.01
        if state_size is None:
            state_size = 11
        if action_size is None:
            action_size = 2
        if link_function is None:
            link_function = 'BT'
        if epsilon is None:
            epsilon = 1e-3

        self.state_space = np.arange(state_size)
        self.state_size = len(self.state_space)
        self.action_space = np.arange(action_size)
        self.action_size = len(self.action_space)
        self.time_horizon = time_horizon

        # Policy table, randomly initialised; one row of logits per state.
        self.weight = np.random.randn(self.state_size, self.action_size)
        self.weight_perturbed = self.weight.copy()

        self.learning_rate = lr
        self.batch_size = batch_size
        self.perturbation_dist = perturbation_dist
        self.epsilon = epsilon

        # Appended to by add_to_preference(); never cleared inside this class.
        self.preference_prob = []
        # Set by perturb_weight(); must be set before a non-empty train() call.
        self.perturbation_vec = None

        self.link_function = link_function

    def select_action(self, state, perturb):
        """
        Sample an action from the softmax policy of ``state``.

        :param state: state index used as a row of the weight table
        :param perturb: 0 (original weight) or anything else (perturbed weight)
        :return: action: sampled action from ``action_space``
        :return: action_prob: probability of each action
        """
        weight_state = self.weight[state, :] if perturb == 0 else self.weight_perturbed[state, :]
        # Subtract the max before exponentiating for numerical stability.
        exp_weight = np.exp(weight_state - np.max(weight_state))
        action_prob = exp_weight / np.sum(exp_weight)
        action = np.random.choice(self.action_space, p=action_prob)
        return action, action_prob

    def perturb_weight(self):
        """
        Draw a random unit-norm direction and build the perturbed weights.

        :return: the unit-norm perturbation direction, shape ``(state_size, action_size)``
        """
        direction = np.random.randn(self.state_size, self.action_size)
        self.perturbation_vec = direction / np.sqrt(np.sum(direction ** 2))
        self.weight_perturbed = self.weight + self.perturbation_dist * self.perturbation_vec
        return self.perturbation_vec

    def add_to_preference(self, prob):
        """
        Record one preference probability for the next :meth:`train` call.

        :param prob: preference probability; presumably P(perturbed policy
            preferred) — TODO confirm with the caller
        """
        self.preference_prob.append(prob)

    def _reward_difference(self, prob):
        """Invert the link function, mapping clipped probabilities to reward differences."""
        if self.link_function == 'WB':
            # The upper clip 2**(1/log2(eps)) maps to exactly the negation of
            # the value at the lower clip eps, so the trim is symmetric and
            # introduces no bias.
            upper = np.exp2(1 / np.log2(self.epsilon))
            trimmed = np.maximum(np.minimum(prob, upper), self.epsilon)
            return - np.log2(- np.log2(trimmed))  # For Weibull
        trimmed = np.maximum(np.minimum(prob, 1 - self.epsilon), self.epsilon)
        return np.log(trimmed / (1 - trimmed))  # For BT (logit)

    def train(self):
        """
        Take one gradient step along the last perturbation direction.

        The update is ``weight += lr * mean(r) * d / perturbation_dist *
        perturbation_vec``.  Here ``r`` are the reward differences recovered
        from the stored preference probabilities and
        ``d = state_size * action_size``.  If no preferences were recorded,
        the weights are left unchanged.

        NOTE(review): ``preference_prob`` is not cleared here; presumably the
        caller resets it between batches — confirm.

        :return: None
        """
        prob = np.array(self.preference_prob)
        if prob.size == 0:
            # np.average of an empty array is NaN, which would poison every weight.
            return

        value_difference = np.average(self._reward_difference(prob))
        intensity = value_difference * self.action_size * self.state_size / self.perturbation_dist
        self.weight = self.weight + self.learning_rate * intensity * self.perturbation_vec


class ZBCPG:
    """Block-wise zeroth-order policy gradient for a tabular softmax policy.

    Like a full-table zeroth-order update, except :meth:`perturb_weight`
    perturbs only one block of consecutive state rows at a time with a
    normalised +/-1 direction.  Preference probabilities are mapped back to
    reward differences through the inverse link function ('BT' or 'WB'), and
    :meth:`train` steps the weights along the perturbation direction.
    """

    def __init__(self, time_horizon=None, state_size=None,
                 action_size=None, lr=None, batch_size=None,
                 perturbation_dist=None, link_function=None, epsilon=None):
        """
        :param time_horizon: num of steps (default 100); stored only, not used by this class
        :param state_size: num of states (default 11)
        :param action_size: num of actions (default 2)
        :param lr: learning rate (default 0.001)
        :param batch_size: batch size before one update (default 10); stored only, not used by this class
        :param perturbation_dist: distance of the perturbed weights from the current ones (default 0.01)
        :param link_function: 'BT' or 'WB' (default 'BT'; any value other than 'WB' is treated as 'BT')
        :param epsilon: clipping margin for preference probabilities (default 1e-3)
        """
        # Identity checks (`is None`) instead of `== None`.
        if time_horizon is None:
            time_horizon = 100
        if lr is None:
            lr = 0.001
        if batch_size is None:
            batch_size = 10
        if perturbation_dist is None:
            perturbation_dist = 0.01
        if state_size is None:
            state_size = 11
        if action_size is None:
            action_size = 2
        if link_function is None:
            link_function = 'BT'
        if epsilon is None:
            epsilon = 1e-3

        self.state_space = np.arange(state_size)
        self.state_size = len(self.state_space)
        self.action_space = np.arange(action_size)
        self.action_size = len(self.action_space)
        self.time_horizon = time_horizon

        # Policy table, randomly initialised; one row of logits per state.
        self.weight = np.random.randn(self.state_size, self.action_size)
        self.weight_perturbed = self.weight.copy()

        self.learning_rate = lr
        self.batch_size = batch_size
        self.perturbation_dist = perturbation_dist
        self.epsilon = epsilon

        # Appended to by add_to_preference(); never cleared inside this class.
        self.preference_prob = []
        # Set by perturb_weight(); must be set before a non-empty train() call.
        self.perturbation_vec = None

        self.link_function = link_function

    def select_action(self, state, perturb):
        """
        Sample an action from the softmax policy of ``state``.

        :param state: state index used as a row of the weight table
        :param perturb: 0 (original weight) or anything else (perturbed weight)
        :return: action: sampled action from ``action_space``
        :return: action_prob: probability of each action
        """
        weight_state = self.weight[state, :] if perturb == 0 else self.weight_perturbed[state, :]
        # Subtract the max before exponentiating for numerical stability.
        exp_weight = np.exp(weight_state - np.max(weight_state))
        action_prob = exp_weight / np.sum(exp_weight)
        action = np.random.choice(self.action_space, p=action_prob)
        return action, action_prob

    def perturb_weight(self, block=None, block_size=None):
        """
        Draw a random +/-1 direction restricted to one block of state rows.

        Only rows ``block * block_size`` up to ``block * block_size +
        block_size`` (exclusive) get entries; the end is clipped to
        ``state_size`` so a final, partial block works.  Those rows receive
        i.i.d. +/-1 entries and all other rows are zero.  The vector is then
        normalised to unit norm and ``weight_perturbed`` is rebuilt from it.

        :param block: block index (default 0)
        :param block_size: number of state rows per block (default 5)
        :return: the unit-norm perturbation direction, shape ``(state_size, action_size)``
        :raises ValueError: if the block selects no rows (negative start,
            non-positive ``block_size``, or a start at/after ``state_size``)
        """
        if block_size is None:
            block_size = 5
        if block is None:
            block = 0

        start = block * block_size
        stop = min(start + block_size, self.state_size)
        if start < 0 or start >= stop:
            # An empty block would give a zero vector and a 0/0 normalisation.
            raise ValueError(
                f"block {block} with block_size {block_size} selects no rows "
                f"of a {self.state_size}-state table")

        self.perturbation_vec = np.zeros((self.state_size, self.action_size))
        # Draw exactly as many rows as the (possibly partial) block covers.
        self.perturbation_vec[start:stop, :] = np.random.choice(
            [-1, 1], p=[0.5, 0.5], size=(stop - start, self.action_size))
        perturb_vec_norm = np.sqrt(np.sum(self.perturbation_vec ** 2))
        self.perturbation_vec = self.perturbation_vec / perturb_vec_norm
        self.weight_perturbed = self.weight + self.perturbation_dist * self.perturbation_vec
        return self.perturbation_vec

    def add_to_preference(self, prob):
        """
        Record one preference probability for the next :meth:`train` call.

        :param prob: preference probability; presumably P(perturbed policy
            preferred) — TODO confirm with the caller
        """
        self.preference_prob.append(prob)

    def _reward_difference(self, prob):
        """Invert the link function, mapping clipped probabilities to reward differences."""
        if self.link_function == 'WB':
            # The upper clip 2**(1/log2(eps)) maps to exactly the negation of
            # the value at the lower clip eps, so the trim is symmetric and
            # introduces no bias.
            upper = np.exp2(1 / np.log2(self.epsilon))
            trimmed = np.maximum(np.minimum(prob, upper), self.epsilon)
            return - np.log2(- np.log2(trimmed))  # For Weibull
        trimmed = np.maximum(np.minimum(prob, 1 - self.epsilon), self.epsilon)
        return np.log(trimmed / (1 - trimmed))  # For BT (logit)

    def train(self):
        """
        Take one gradient step along the last (block) perturbation direction.

        The update is ``weight += lr * mean(r) * d / perturbation_dist *
        perturbation_vec``.  Here ``r`` are the reward differences recovered
        from the stored preference probabilities and
        ``d = state_size * action_size``.  If no preferences were recorded,
        the weights are left unchanged.

        NOTE(review): ``preference_prob`` is not cleared here; presumably the
        caller resets it between batches — confirm.

        :return: None
        """
        prob = np.array(self.preference_prob)
        if prob.size == 0:
            # np.average of an empty array is NaN, which would poison every weight.
            return

        value_difference = np.average(self._reward_difference(prob))
        # NOTE(review): scaled by the full table dimension even though only one
        # block is perturbed; this acts as a larger effective learning rate —
        # confirm whether the block dimension was intended.
        intensity = value_difference * self.action_size * self.state_size / self.perturbation_dist
        self.weight = self.weight + self.learning_rate * intensity * self.perturbation_vec
