import numpy as np

class ExperimentPMABFullFeedback:
    """Repeatedly play a population game where every agent sees every action's reward.

    Each epoch:

    1. Every agent picks an action.
    2. The empirical action frequencies are passed to ``operator`` to obtain the
       mean reward of each action.
    3. Each agent receives its own noisy copy of the whole reward vector
       (full-information feedback).

    Args:
        agents: Sequence of agent objects. Each must provide
            ``choose_action(epoch)``, returning a non-negative integer action
            index (presumably in ``[0, K)`` — TODO confirm), and
            ``update(epoch, action, feedback)``.
        K: Number of actions.
        operator: Callable mapping the length-``K`` action-frequency vector to
            the per-action mean rewards. Its result is added to a length-``K``
            noise vector, so it presumably has length ``K`` as well.
        sigma: Standard deviation of the Gaussian noise added to each reward.
    """

    def __init__(self, agents, K, operator, sigma):
        self.agents = agents
        self.N = len(agents)
        self.K = K
        self.operator = operator
        self.sigma = sigma
        self.e = 0  # index of the next epoch to be run

    def run_epoch(self):
        """Run a single epoch and advance the epoch counter.

        Returns:
            Tuple ``(actions, mean_rewards, feedbacks)``:

            - the list of actions chosen by the agents (in agent order);
            - the mean rewards returned by ``operator``;
            - the list of per-agent feedback vectors, where each vector is
              ``mean_rewards`` plus independent ``N(0, sigma**2)`` noise.
        """
        actions = [a.choose_action(self.e) for a in self.agents]
        # minlength keeps actions nobody chose in the histogram with count 0.
        action_freq = np.bincount(actions, minlength=self.K) / self.N
        mean_rewards = self.operator(action_freq)

        # Every agent observes the full reward vector with its own noise draw.
        # The unit-variance draws are scaled by sigma, the configured noise level.
        feedbacks = [
            mean_rewards + self.sigma * np.random.normal(size=self.K)
            for _ in range(self.N)
        ]

        # All noise is drawn before any agent is updated.
        for agent, action, feedback in zip(self.agents, actions, feedbacks):
            agent.update(self.e, action, feedback)

        self.e += 1

        return actions, mean_rewards, feedbacks


class ExperimentPMABBanditFeedback:
    """Repeatedly play a population game where each agent sees only its own reward.

    In every epoch the agents pick actions, and the empirical action
    frequencies are handed to ``operator`` to get the mean reward per action.
    Each agent then receives a length-``K`` feedback vector that is masked to
    zero everywhere except at its chosen action. At that entry it holds the
    action's mean reward plus ``sigma``-scaled Gaussian noise (bandit feedback).

    Args:
        agents: Sequence of agent objects offering ``choose_action(epoch)``
            (an action index; presumably in ``[0, K)`` — TODO confirm) and
            ``update(epoch, action, feedback)``.
        K: Number of actions.
        operator: Callable from the length-``K`` frequency vector to the
            per-action mean rewards.
        sigma: Standard deviation of the Gaussian reward noise.
    """

    def __init__(self, agents, K, operator, sigma):
        self.agents = agents
        self.N = len(agents)
        self.K = K
        self.operator = operator
        self.sigma = sigma
        self.e = 0  # index of the next epoch to be run

        # Column k of the identity is the one-hot indicator of action k.
        self._basis = np.eye(K)

    def run_epoch(self):
        """Play one epoch, notify every agent, and advance the epoch counter.

        Returns:
            ``(actions, mean_rewards, feedbacks)``: the agents' chosen actions
            (a list, in agent order), the operator's mean rewards, and the
            list of one-hot-masked noisy feedback vectors, one per agent.
        """
        actions = []
        for agent in self.agents:
            actions.append(agent.choose_action(self.e))

        counts = np.bincount(actions, minlength=self.K)
        mean_rewards = self.operator(counts / self.N)

        # Draw every agent's noise before any update so the RNG order is fixed.
        feedbacks = []
        for chosen in actions:
            one_hot = self._basis[:, chosen]
            noise = self.sigma * np.random.normal()
            feedbacks.append(mean_rewards * one_hot + noise * one_hot)

        for agent, chosen, feedback in zip(self.agents, actions, feedbacks):
            agent.update(self.e, chosen, feedback)

        self.e += 1
        return actions, mean_rewards, feedbacks

