import numpy as np
import torch
import random

# ---------------------------------------------------------------------------
# Global constants
# ---------------------------------------------------------------------------
SEED = 42  # one seed shared by the Python, NumPy and PyTorch RNGs seeded below
MAX_STEPS = 500  # NOTE(review): not referenced in this file; presumably a step cap used by callers

# NOTE(review): neither dimension is referenced in this file; presumably the
# state/action sizes of an environment defined elsewhere — confirm with callers.
STATE_DIM, ACTION_DIM = 4, 2

# Seed every random source at import time so sampled results are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

class Panel_Env_Reward:
    """Turn pairs of rewards into sampled binary preference votes.

    A link function maps the scaled reward difference ``expertise * (r2 - r1)``
    to a probability ``prob``. Then ``size`` votes are drawn, each equal to 1
    with probability ``prob`` and 0 otherwise.

    Args:
        link_function: Selects the link applied to ``x = expertise * diff``:
            ``'L'``  -> ``clip(x + 0.5, 0, 1)``;
            ``'WB'`` -> ``2 ** -(2 ** -x)``;
            ``'S'``  -> ``(sign(x) + 1) / 2`` (a step; exactly 0.5 at ``x == 0``);
            anything else, including the default ``'BT'``, uses the logistic
            ``1 / (1 + exp(-x))``.
        expertise: Multiplier applied to the reward difference before the link.
        size: Number of votes drawn per comparison.
    """

    def __init__(self, link_function='BT', expertise=0.01, size=100):
        self.size = size
        self.link_function = link_function
        self.expertise = expertise

    def _link_probability(self, score_diff):
        """Map a reward difference to the probability of drawing a 1-vote."""
        scaled = self.expertise * score_diff
        if self.link_function == 'L':
            return np.minimum(np.maximum(scaled + 0.5, 0), 1)
        if self.link_function == 'WB':
            return np.exp2(-np.exp2(-scaled))
        if self.link_function == 'S':
            return (np.sign(scaled) + 1) * 0.5
        # 'BT' and any unrecognized name fall through to the logistic link.
        return 1 / (1 + np.exp(-scaled))

    def batch_preference_from_reward(self, reward_1, reward_2):
        """Draw ``self.size`` votes comparing the mean rewards of two batches.

        Args:
            reward_1: Rewards of the first option (anything ``np.mean`` accepts).
            reward_2: Rewards of the second option.

        Returns:
            ``(results, prob)``. ``results`` is a length-``self.size`` array of
            0/1 votes, with 1 drawn with probability ``prob``. ``prob`` is the
            link output for ``mean(reward_2) - mean(reward_1)``.
        """
        batch_score_diff = np.mean(reward_2) - np.mean(reward_1)
        prob = self._link_probability(batch_score_diff)
        # np.random.choice is kept (rather than another Bernoulli sampler) so the
        # sequence of draws from the global NumPy RNG stays unchanged.
        results = np.random.choice([1, 0], p=[prob, 1 - prob], size=self.size)
        return results, prob

    def individual_preference_from_reward(self, reward_1, reward_2):
        """Draw ``self.size`` votes for each paired comparison of two reward sequences.

        Pairs are formed with ``zip``, so the longer input's extra elements are
        ignored. Any iterables work, including unsized ones such as generators.

        Returns:
            ``(results, probs)``: two lists with one entry per pair. Each
            ``results`` entry is an array of ``self.size`` 0/1 votes. Each
            ``probs`` entry is the link probability for ``score_2 - score_1``.
        """
        results = []
        probs = []
        for score_1, score_2 in zip(reward_1, reward_2):
            prob = self._link_probability(score_2 - score_1)
            probs.append(prob)
            results.append(np.random.choice([1, 0], p=[prob, 1 - prob], size=self.size))
        return results, probs



if __name__ == "__main__":
    # Script entry point: currently does nothing but emit a single blank line.
    print()