# -*- coding: utf-8 -*-
# final version

import cmath
import sys

import numpy as np

# ---------------------------------- Environment: generate user, item and feedback ----------------------------------------- #


class Environment:
    """Simulated environment that serves users, candidate items and rewards.

    On construction the users are partitioned into contiguous index ranges
    ("clusters") and every user in a cluster has its ``theta`` replaced by the
    cluster's mean ``theta`` (see :meth:`init_users`).

    Attributes:
        L: Number of items returned by each :meth:`get_items` call.
        d: Feature dimension.
        user_num: Number of users.
        users: Indexable collection of user objects exposing ``theta``.
        all_arms: Indexable collection of arm objects exposing ``fv``.
        type: Environment mode; ``"Stochastic"`` serves the arm features
            unchanged, any other value serves a noisy copy of them.
    """

    def __init__(self, d, num_users, num_items=50, arms=None, users=None, type="Stochastic"):
        """Store the configuration and cluster the users.

        Args:
            d: Feature dimension.
            num_users: Number of users in ``users``.
            num_items: Number of items drawn per round.
            arms: Pool of arm objects; each must expose ``fv``.
            users: User objects; each must expose ``theta``.
            type: Environment mode (name kept for backward compatibility even
                though it shadows the builtin).
        """
        self.L = num_items
        self.d = d
        self.user_num = num_users
        self.users = users
        self.all_arms = arms
        self.type = type
        self.init_users()

    def init_users(self, cluster_num=10) -> None:
        """Replace every user's ``theta`` by the mean ``theta`` of its cluster.

        Users are split by index into ``cluster_num`` contiguous ranges of
        ``user_num // cluster_num`` users each; the last range also absorbs
        the remainder. All users of a cluster end up referencing the same
        averaged array.

        Does nothing when no users were supplied.

        Args:
            cluster_num: Number of clusters (default 10).

        Raises:
            ValueError: If ``cluster_num`` is smaller than 1.
        """
        if cluster_num < 1:
            raise ValueError("cluster_num must be at least 1")
        if self.users is None:
            return

        users_per_cluster = self.user_num // cluster_num
        for cluster in range(cluster_num):
            start = cluster * users_per_cluster
            is_last = cluster == cluster_num - 1
            end = self.user_num if is_last else start + users_per_cluster
            if end <= start:
                # Empty cluster (fewer users than clusters): averaging it
                # would divide zero by zero, so there is nothing to do.
                continue

            total = np.zeros((self.d, 1))
            for j in range(start, end):
                total += self.users[j].theta
            average_theta = total / (end - start)
            for j in range(start, end):
                self.users[j].theta = average_theta

    def add_noise(self, fv, noise_level):
        """Add truncated Gaussian noise to the first ``d`` entries of ``fv``.

        Each entry gets an independent ``N(0, noise_level)`` sample that is
        re-drawn until its absolute value is below 1.

        Note:
            ``fv`` is modified in place and the same object is returned; pass
            a copy if the original must stay intact.

        Args:
            fv: Mutable feature vector, indexable by ``range(d)``.
            noise_level: Standard deviation of the Gaussian noise.

        Returns:
            The (mutated) ``fv``.
        """
        for i in range(self.d):
            noise = np.random.normal(0, noise_level)
            while abs(noise) >= 1:  # rejection sampling keeps |noise| < 1
                noise = np.random.normal(0, noise_level)
            fv[i] = fv[i] + noise
        return fv

    def get_items(self) -> np.ndarray:
        """Draw ``L`` distinct arms at random and stack their features.

        In ``"Stochastic"`` mode the arm features are used unchanged; in any
        other mode a noisy copy (noise level 0.1) of each feature vector is
        used, leaving the arm pool itself untouched.

        Returns:
            Array of shape ``(L, d)``; row ``i`` holds the features of the
            i-th selected arm.

        Raises:
            ValueError: Propagated from ``np.random.choice`` when ``L``
                exceeds the number of available arms.
        """
        selected_index = np.random.choice(len(self.all_arms), self.L, replace=False)
        items = np.zeros((self.L, self.d))
        stochastic = self.type == "Stochastic"
        for row, arm_idx in enumerate(selected_index):
            fv = self.all_arms[arm_idx].fv
            if not stochastic:
                # Perturb a float copy: add_noise writes into its argument,
                # and the shared arm must not accumulate noise across rounds.
                fv = self.add_noise(np.array(fv, dtype=float), 0.1)
            items[row, :] = fv.T
        return items

    def feedback(self, items, i, k):
        """Compute the reward signals for user ``i`` choosing item ``k``.

        Args:
            items: Item matrix with one item per row.
            i: Index of the user.
            k: Row index of the chosen item in ``items``.

        Returns:
            Tuple ``(mean_reward, instant_reward, best_reward)`` where
            ``mean_reward`` is the dot product of the chosen item and the
            user's ``theta``, ``instant_reward`` is ``mean_reward`` plus one
            ``N(0, 1)`` sample, and ``best_reward`` is the largest mean reward
            over all rows of ``items``.
        """
        theta = self.users[i].theta
        mean_reward = np.dot(items[k, :], theta)
        instant_reward = mean_reward + np.random.normal(0, scale=1)
        best_reward = np.max(np.dot(items, theta))
        return mean_reward, instant_reward, best_reward

    def generate_users(self):
        """Return a user index drawn uniformly from ``range(user_num)``."""
        return np.random.choice(range(self.user_num))
