import numpy as np

class SyntheticPreferenceEnv:
    """
    Synthetic data environment driven by a per-arm value vector ``mu_on``:
      - play_duel(c, d): arm c beats arm d with the logistic probability
        p_cd = 1 / (1 + exp(mu_on[d] - mu_on[c])); the returned outcome is a
        single draw from Bernoulli(p_cd), so 1 means "c won" and 0 "d won".
      - sample_reward(i): one draw from N(mu_on[i], 1) for arm i. The value is
        returned as sampled (it is not clipped to any range).
    """

    def __init__(self, mu_on):
        # Indexable collection holding one value per arm; used both as the
        # logistic score in duels and as the mean of the reward distribution.
        self.mu_on = mu_on

    def play_duel(self, c, d):
        """Return 1 if arm ``c`` wins a simulated duel against arm ``d``, else 0."""
        gap = self.mu_on[d] - self.mu_on[c]
        win_prob = 1.0 / (1.0 + np.exp(gap))
        return np.random.binomial(1, win_prob)

    def sample_reward(self, i):
        """Return one Gaussian reward sample (unit standard deviation) for arm ``i``."""
        return np.random.normal(loc=self.mu_on[i], scale=1)

class RealDataPreferenceEnv:
    """
    Real-world data environment backed by per-movie rating lists.

      - play_duel(c, d): returns 1 if arm c wins against arm d, otherwise 0.
        Two feedback modes are supported:
            * "data" mode: randomly sample ``sample_size`` ratings (default 3) for
              each arm and compare the sample means. Equal means are resolved by a
              fair coin flip. If either movie has no entry in the rating dict, a
              notice is printed and arm c is declared the winner (returns 1).
            * "bernoulli" mode: use the given mu_on to compute the logistic winning
              probability p_cd = 1 / (1 + exp(mu_on[d] - mu_on[c])) and sample the
              outcome from Bernoulli(p_cd).
      - sample_reward(i): randomly sample one rating from the corresponding movie's
        rating list, or return 0.5 when the movie has no entry.

    Args:
        arm_ids: indexable mapping from arm index to movie id.
        movie_rewards_dict: dict from movie id to a sequence of ratings.
        mu_on: indexable per-arm scores; required only in "bernoulli" mode.
        feedback_mode: either "data" or "bernoulli".
        sample_size: number of ratings drawn per arm in "data" mode.

    Raises:
        ValueError: if ``sample_size`` is smaller than 1.
    """

    def __init__(self, arm_ids, movie_rewards_dict, mu_on, feedback_mode, sample_size=3):
        if sample_size < 1:
            # size=0 would yield an empty sample whose mean is NaN.
            raise ValueError("sample_size must be at least 1.")
        self.arm_ids = arm_ids
        self.movie_rewards_dict = movie_rewards_dict
        self.mu_on = mu_on
        self.feedback_mode = feedback_mode
        self.sample_size = sample_size

    def _sample_ratings(self, ratings):
        """Draw ``sample_size`` ratings; fall back to sampling with replacement for short lists."""
        has_enough = len(ratings) >= self.sample_size
        return np.random.choice(ratings, size=self.sample_size, replace=not has_enough)

    def _duel_from_data(self, mid_c, mid_d):
        """Decide a duel by comparing the means of randomly sampled ratings."""
        if mid_c not in self.movie_rewards_dict or mid_d not in self.movie_rewards_dict:
            print(f"data mode: mid_c: {mid_c}, mid_d: {mid_d}, missing data, using default 0.5------------------------")
            # Without data for both movies the first arm is always reported as winner.
            return 1

        # Sample c first, then d, so the random stream is consumed in a fixed order.
        avg_rating_c = np.mean(self._sample_ratings(self.movie_rewards_dict[mid_c]))
        avg_rating_d = np.mean(self._sample_ratings(self.movie_rewards_dict[mid_d]))

        if avg_rating_c == avg_rating_d:
            print(f"data mode: mid_c: {mid_c}, mid_d: {mid_d}, avg_rating_c: {avg_rating_c:.4f}, avg_rating_d: {avg_rating_d:.4f}, equal means, returning random outcome")
            return np.random.choice([0, 1])

        return 1 if avg_rating_c > avg_rating_d else 0

    def _duel_bernoulli(self, c, d):
        """Decide a duel with a Bernoulli draw from the logistic win probability."""
        if self.mu_on is None:
            raise ValueError("mu_on must be provided when using bernoulli feedback mode.")
        p_cd = 1.0 / (1.0 + np.exp(self.mu_on[d] - self.mu_on[c]))
        outcome = np.random.binomial(1, p_cd)
        print(f"bernoulli mode: mu_on[{c}]: {self.mu_on[c]}, mu_on[{d}]: {self.mu_on[d]}, p_cd: {p_cd}, outcome: {outcome}")
        return outcome

    def play_duel(self, c, d):
        """
        Play one duel between arms ``c`` and ``d``.

        Returns:
            1 if arm ``c`` wins, 0 if arm ``d`` wins.

        Raises:
            ValueError: if ``feedback_mode`` is neither "data" nor "bernoulli", or
                if "bernoulli" mode is used while ``mu_on`` is None.
        """
        # Resolve both movie ids up front so a bad arm index fails in every mode.
        mid_c = self.arm_ids[c]
        mid_d = self.arm_ids[d]

        if self.feedback_mode == "data":
            return self._duel_from_data(mid_c, mid_d)
        if self.feedback_mode == "bernoulli":
            return self._duel_bernoulli(c, d)
        raise ValueError("Invalid feedback_mode: choose 'data' or 'bernoulli'")

    def sample_reward(self, i):
        """Return one randomly chosen rating for arm ``i``, or 0.5 if its movie has no ratings entry."""
        mid = self.arm_ids[i]
        if mid not in self.movie_rewards_dict:
            return 0.5
        return np.random.choice(self.movie_rewards_dict[mid])
