'''
Simple classes that generate labelled data. Each generator can be
queried for the label of any given point and knows the correct answer
(its true model). A base class is defined first, followed by a specific
'''
import numpy as np

class data_gen:
    '''
    A base class that specific data-generating examples inherit from.

    Subclasses build the ground-truth model in generate_model (stored in
    self.true_model) and answer label queries from it in query_label.
    '''
    def __init__(self):
        # No ground-truth model exists until generate_model is called.
        self.true_model = None

    def generate_model(self):
        '''
        Build the true model and store it in self.true_model, e.g. a vector
        in {0,1}^n or {-1,+1}^n. Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        '''
        raise NotImplementedError('generate_model method not implemented')

    def query_label(self, x):
        '''
        Use the true model to give the label of point x. E.g. if true_model
        is a {0,1}^n vector, return entry x; if the model is a half space,
        return the sign of the inner product with x.

        Raises:
            NotImplementedError: always, in the base class.
        '''
        raise NotImplementedError('query_label method not implemented')

    def is_correct(self, candidate):
        '''
        Return True if candidate matches the true model.

        The candidate must have exactly the same shape as self.true_model and
        be elementwise close to it (numpy.isclose default tolerances). Returns
        False when no model has been generated yet or when the shapes differ;
        a scalar is not broadcast against a constant model.
        '''
        if self.true_model is None:
            return False
        truth = np.asarray(self.true_model)
        guess = np.asarray(candidate)
        # Without this check np.isclose would broadcast (e.g. a scalar would
        # "match" a constant vector) or raise on incompatible shapes.
        if guess.shape != truth.shape:
            return False
        return bool(np.all(np.isclose(guess, truth)))

class thresholds(data_gen):
    '''
    A thresholds example built on data_gen.

    Index i in 0..d-1 is labelled -1 when i < cutoff and +1 otherwise, so
    the true model has the form [-1, ..., -1, +1, ..., +1].
    '''
    def __init__(self, d, cutoff):
        '''
        d: length of the model vector (number of queryable points).
        cutoff: first index labelled +1; indices below it are labelled -1.
            Converted with int(). A cutoff <= 0 gives an all +1 model and
            a cutoff >= d gives an all -1 model.
        '''
        super().__init__()
        self.d = d
        self.cutoff = int(cutoff)
        self.generate_model()

    def generate_model(self):
        '''
        Store in self.true_model a float vector of length d holding -1.0 at
        indices below cutoff and +1.0 elsewhere.
        '''
        # Comparing indices against cutoff (instead of slicing [:cutoff])
        # stops a negative cutoff from wrapping around to the end.
        self.true_model = np.where(np.arange(self.d) < self.cutoff, -1.0, 1.0)

    def query_label(self, x):
        '''
        Return the true label (-1.0 or +1.0) at index x. If no model has
        been generated yet, build it first.
        '''
        # Identity check against None: rebuild only when no model exists,
        # so repeated queries reuse the same vector.
        if self.true_model is None:
            self.generate_model()
        return self.true_model[x]




