from abc import ABC, abstractmethod

import numpy as np
from .projections import euclidean_proj_simplex

class Agent(ABC):
    """Interface for agents that pick actions and learn from feedback.

    Subclasses must implement ``choose_action``. ``update`` defaults to a
    no-op so that agents which do not learn need not override it.
    """

    def __init__(self, N, K):
        # NOTE(review): N and K are accepted but not stored here; subclasses
        # in this module define their own __init__ and do not call this one.
        pass

    @abstractmethod
    def choose_action(self, e):
        """Return the action to play in round ``e``."""

    def update(self, e, a, feedback):
        """Incorporate ``feedback`` received after playing action ``a`` in round ``e``.

        The default implementation does nothing.
        """
    

class AgentPAFullFeedback(Agent):
    """Projected-ascent agent that observes feedback for every action each round.

    The policy starts uniform over ``K`` actions. After round ``e`` it is
    replaced by ``euclidean_proj_simplex((1 - eta * tau) * policy + eta * feedback)``
    where ``eta = 1 / (tau * (e + 1))``.

    Args:
        K: number of actions.
        tau: regularization weight; also sets the base step size ``1 / tau``.
    """

    def __init__(self, K, tau):
        self.K = K
        self.tau = tau
        self.eta0 = 1 / tau  # base step size, decayed as eta0 / (e + 1)

        # Uniform initial mixed strategy over the K actions.
        self.policy = np.ones(K) / K

    def choose_action(self, e):
        """Sample an action index in ``[0, K)`` according to ``self.policy``.

        The round index ``e`` is not used.
        """
        action_ids = np.arange(self.K)
        return np.random.choice(action_ids, p=self.policy)

    def update(self, e, a, feedback):
        """Apply one regularized ascent step to the policy, then project it.

        Args:
            e: round index; the step size is ``eta0 / (e + 1)``.
            a: action that was played (unused: feedback covers every action).
            feedback: per-action feedback added to the policy; presumably a
                length-``K`` array -- TODO confirm with callers.
        """
        step = self.eta0 / (e + 1)
        decayed = (1 - step * self.tau) * self.policy
        # NOTE(review): euclidean_proj_simplex presumably maps the vector back
        # onto the probability simplex, so choose_action can sample from it.
        self.policy = euclidean_proj_simplex(decayed + step * feedback)
    
    
class AgentPABanditFeedback(Agent):
    """Epsilon-exploring projected-ascent agent that learns in epochs.

    On each round the agent draws ``u = np.random.rand()``. It plays a
    uniformly random action when ``u`` does not exceed ``epsilon``, and
    otherwise samples from ``self.policy``. Feedback from the uniformly random
    rounds is averaged, scaled by ``K``, into ``operator_estimate``.

    Epoch ``k`` (zero-based) lasts ``int(log(k + 1) / epsilon) + 1`` calls to
    ``update``. When it ends, the policy takes one step
    ``euclidean_proj_simplex((1 - eta * tau) * policy + eta * operator_estimate)``
    with ``eta = 1 / (tau * (k + 1))``, and the per-epoch statistics reset.

    Args:
        K: number of actions.
        tau: regularization weight; also sets the base step size ``1 / tau``.
        epsilon: exploration threshold; also stretches the epoch lengths.
    """

    def __init__(self, K, tau, epsilon):
        self.K = K
        self.tau = tau
        self.eta0 = 1 / tau  # base step size, decayed per epoch

        self.epsilon = epsilon

        # Exploration bookkeeping for the current epoch.
        self.exploring = False  # set by choose_action, read by update
        self.explore_count = 0
        self.operator_estimate = np.zeros(K)

        # Uniform initial mixed strategy over the K actions.
        self.policy = np.ones(K) / K

        # Epoch bookkeeping.
        self.time_since_epoch_start = 0
        self.epoch = 0

    def choose_action(self, e):
        """Pick an action index in ``[0, K)``, recording whether it was exploratory.

        The round index ``e`` is not used.
        """
        draw = np.random.rand()
        self.exploring = not (draw > self.epsilon)
        action_ids = np.arange(self.K)
        if self.exploring:
            return np.random.choice(action_ids)
        return np.random.choice(action_ids, p=self.policy)

    def update(self, e, a, feedback):
        """Record one round of feedback and, at the end of an epoch, update the policy.

        Args:
            e: round index (unused; step sizes depend on the epoch counter).
            a: action that was played (unused here).
            feedback: value multiplied by ``K`` and averaged into
                ``operator_estimate`` on exploration rounds.
        """
        # NOTE(review): feedback is broadcast onto a length-K array. If callers
        # pass a scalar, every action gets the same estimate. Confirm that
        # feedback is the per-action vector (e.g. reward placed at index a).
        self.time_since_epoch_start += 1

        if self.exploring:
            self.explore_count += 1
            n = self.explore_count
            # Incremental mean of K * feedback over this epoch's exploration rounds.
            self.operator_estimate = ((n - 1) / n) * self.operator_estimate + (1 / n) * feedback * self.K

        epoch_length = int(np.log(self.epoch + 1) / self.epsilon) + 1
        if self.time_since_epoch_start < epoch_length:
            return

        # Epoch finished: one regularized, projected ascent step.
        step = self.eta0 / (self.epoch + 1)
        decayed = (1 - step * self.tau) * self.policy
        self.policy = euclidean_proj_simplex(decayed + step * self.operator_estimate)

        # Reset per-epoch statistics and advance to the next epoch.
        self.operator_estimate = np.zeros(self.K)
        self.explore_count = 0
        self.epoch += 1
        self.time_since_epoch_start = 0