'''
Code to implement a base class and then a brute force oracle
'''
import numpy as np

class binary_oracle:
    '''
    Base class for a generic query oracle that chooses which point to
    sample next.

    Attributes:
        chosen: index of the most recently chosen query point
                (initialised to 0).
    '''
    def __init__(self):
        self.chosen = 0

    def choose(self, *args, **kwargs):
        '''
        Generic function for how the oracle will choose the next query point.

        Concrete oracles must override this method. Arguments are accepted
        (and ignored) so that calling the base method with a subclass-style
        argument list still reports the missing implementation rather than
        an unrelated TypeError.

        Raises:
            NotImplementedError: always; the base class has no strategy.
        '''
        # NotImplementedError is still an Exception subclass, so callers
        # that catch Exception keep working.
        raise NotImplementedError('choose method not implemented')


class gaussian_width_oracle(binary_oracle):
    '''
    Implementation of a Gaussian width oracle to query the next point.

    For every point that has not been queried yet, the oracle estimates
    (by Monte Carlo) the Gaussian width of the version space that would
    remain under each possible label, and picks the point whose smaller
    of the two estimates is largest.
    '''
    def __init__(self, num_samples=100):
        '''
        num_samples: number of random Gaussian directions drawn for each
                     Gaussian width estimate.
        '''
        super().__init__()
        self.num_samples = num_samples

    def choose(self, Z, past):
        '''
        Takes in version space Z and 'past' which records which queries have
        already been asked, and returns the index of the next point to query.

        Z is a 2-D array of shape [num_models X num_points]; 'past' is any
        container supporting `in` that holds already-queried column indices.

        The result is also stored in self.chosen. If no unqueried point
        has a finite score (every remaining point has at least one label
        that no model in Z predicts), the first unqueried index is returned
        instead of a stale, possibly already-queried one. If every point
        has been queried, self.chosen is left unchanged and returned.
        '''
        self.n, self.d = Z.shape    # Z is [num_models X num_points]
        best_val = -np.inf
        best_idx = None
        first_unqueried = None
        # Search over possible points to query
        for i in range(self.d):
            if i in past:   # already queried
                continue
            if first_unqueried is None:
                first_unqueried = i
            val_1 = self.compute_single(Z, i, label=1)       # if true label = 1
            val_neg1 = self.compute_single(Z, i, label=-1)   # if true label = -1
            worst_case = min(val_1, val_neg1)
            # Strict '>' keeps the earliest index among ties.
            if worst_case > best_val:
                best_val = worst_case
                best_idx = i
        if best_idx is None:
            # No candidate beat -inf: fall back to any unqueried point.
            best_idx = first_unqueried
        if best_idx is not None:
            self.chosen = best_idx
        return self.chosen

    def make_eta(self, d=None):
        '''
        Draw one standard Gaussian vector of length d.

        If d is omitted, self.d (set by the most recent call to choose) is
        used.
        '''
        if d is None:
            d = self.d
        # Same N(0, I) distribution as multivariate_normal with an identity
        # covariance, without factorising a d x d matrix on every draw.
        return np.random.standard_normal(d)

    def compute_single(self, Z, idx, label=1):
        '''
        Compute the estimated Gaussian width for a single example in Z if
        the true label is 'label = {+- 1}'.

        Only the rows of Z whose entry at column idx equals label are kept.
        Returns -np.inf when no row matches; otherwise the mean, over
        num_samples random directions eta, of the maximum of (row @ eta)
        across the kept rows.
        '''
        # Restrict the version space to models consistent with the label.
        Z_consistent = Z[Z[:, idx] == label]
        if not len(Z_consistent):
            return -np.inf
        # Take the dimension from Z itself so this works without choose().
        d = Z.shape[1]
        # Values of the suprema used to estimate the Gaussian width.
        sup_vals = [np.max(Z_consistent @ self.make_eta(d))
                    for _ in range(self.num_samples)]
        return np.mean(sup_vals)    # return estimated G. width
                











