import random
import numpy as np
import params
from sklearn.base import BaseEstimator

class AdversarialLearner(BaseEstimator):
    """Adversarial weak learner over the finite universe ``{0, ..., u-1}``.

    Points are integer indices into the universe, and every hypothesis is a
    0/1 label vector of length ``u`` (0 plays the role of the "-1" label).
    ``hypotheses[0]`` is the fixed hypothesis h0, which labels the last ``t``
    points 0 and all other points 1. The remaining ``2**VC - 1`` hypotheses
    are uniformly random labelings drawn with :mod:`random` in ``__init__``.

    ``fit`` considers only the hypotheses that label at least a
    ``1/2 + gamma`` weighted fraction of the training sample with 1. Among
    those, it picks the one that labels the smallest fraction of ``s`` unseen
    points with 1, provided that fraction is strictly below ``1/2 - gamma``.
    Otherwise it falls back to h0.

    NOTE(review): hypotheses are generated in ``__init__``. sklearn's
    ``clone`` re-invokes ``__init__``, so clones presumably get fresh random
    hypotheses. Confirm that this is intended.
    """

    def __init__(self, u=params.u):
        self.u = u                  # Universe size
        self.VC = params.VC         # VC dimension; 2**VC hypotheses are generated
        self.t = params.t           # Number of points h0 labels 0 (the "-1" label)
        self.s = params.s           # Number of points outside sample that AdversarialLearner focus on
        self.gamma = params.gamma   # Advantage of weak learner
        self.hypothesis = None      # Hypothesis selected by fit()
        self.classes_ = np.array([0, 1])
        # h0: label 1 on the first u - t points, 0 on the last t points.
        h0 = [1] * (u - self.t) + [0] * self.t
        # Uniformly random labelings. random.choice is called u times per
        # hypothesis in the same order as before, so seeded runs reproduce
        # identical hypotheses.
        random_hs = [
            [random.choice([0, 1]) for _ in range(self.u)]
            for _ in range(1, 2**self.VC)
        ]
        self.hypotheses = np.array([h0] + random_hs)

    def fit(self, X, y, sample_weight=None):
        """Adversarially select a hypothesis for the weighted sample ``X``.

        Parameters
        ----------
        X : array-like of universe indices
            Flattened and cast to ``int``. Each entry indexes a hypothesis.
        y : array-like
            Ignored. NOTE(review): the weighted check below measures the
            fraction of sample points labelled 1. That equals weighted
            accuracy only if every label is 1, so presumably the target
            concept is all ones. Confirm.
        sample_weight : array-like, optional
            One weight per entry of ``X``. Defaults to uniform ``1/len(X)``.

        Returns
        -------
        self
        """
        X = np.asarray(X).flatten().astype(int)
        if sample_weight is None:
            # Guard the empty sample, which would otherwise divide by zero.
            sample_weight = np.full(X.shape, 1 / X.size) if X.size else np.zeros(0)
        sample_weight = np.asarray(sample_weight)
        # The first s universe points (in ascending order) absent from the sample.
        outside = np.setdiff1d(np.arange(self.u), X)[:self.s]
        # h0 is the fallback when no random hypothesis qualifies.
        self.hypothesis = self.hypotheses[0]
        if outside.size == 0:
            # No unseen points to focus on. The old code divided by zero here
            # as soon as any hypothesis passed the weak-learning check.
            return self
        threshold = 1 / 2 + self.gamma
        # A hypothesis must push the outside 1-fraction strictly below this.
        best_outside = 1 / 2 - self.gamma
        for h in self.hypotheses[1:]:
            # Weighted fraction of sample points that h labels 1.
            if h[X] @ sample_weight >= threshold:
                # Fraction of the focus points outside the sample labelled 1.
                outside_ones = h[outside].mean()
                if outside_ones < best_outside:
                    self.hypothesis = h
                    best_outside = outside_ones
        return self

    def predict(self, X):
        """Return the selected hypothesis' labels at the universe indices in ``X``.

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If called before ``fit``.
        """
        if self.hypothesis is None:
            from sklearn.exceptions import NotFittedError
            raise NotFittedError(
                "This AdversarialLearner instance is not fitted yet; call fit() first."
            )
        return self.hypothesis[np.asarray(X).flatten().astype(int)]

    def score(self, X, y):
        """Return the mean accuracy of ``predict(X)`` against ``y``."""
        return np.mean(self.predict(X) == y)