from utils import *
from NGDOracle import NGDOracle


 

class BRIDGE(NGDOracle):
    """Decentralized gradient method with robust neighbor aggregation.

    Each iteration replaces every worker's coefficients with a robust
    aggregate over its neighborhood (``aggregate_robust``, median by default)
    and then subtracts a constant-step gradient computed at the
    pre-aggregation coefficients.
    """

    def __init__(self, W, lr_constant, n_workers=10, max_niter=1000, tol=1e-6,
                 coefs_init=None, coefs_true=None, model_typ='logit', byz_labels=None,
                 random_state=None, agg_method='median', proportiontocut=0.1):
        """Configure the estimator.

        Parameters
        ----------
        W : array-like
            Network matrix over the workers. It is turned into an adjacency
            matrix (diagonal set) by ``get_adjacency_mat`` when fitting.
        lr_constant : float
            Constant step size applied to the gradient in every update.
        agg_method : str, default 'median'
            Robust aggregation method forwarded to ``aggregate_robust``.
        proportiontocut : float, default 0.1
            Forwarded to ``aggregate_robust``. Presumably the fraction trimmed
            by a trimmed-mean style method (TODO confirm against utils).

        All remaining arguments are forwarded unchanged to ``NGDOracle``.
        ``lr=None`` is passed because this class uses ``lr_constant`` instead.
        """
        super().__init__(W=W, lr=None, n_workers=n_workers, max_niter=max_niter,
                         tol=tol, coefs_init=coefs_init, coefs_true=coefs_true,
                         model_typ=model_typ, byz_labels=byz_labels,
                         random_state=random_state)
        self.lr_constant = lr_constant
        self.agg_method = agg_method
        self.proportiontocut = proportiontocut

    def _aggregate(self, coefs):
        """Return the robust neighborhood aggregate of ``coefs`` using ``self.A``."""
        return aggregate_robust(self.A, coefs, method=self.agg_method,
                                proportiontocut=self.proportiontocut)

    def _initialize_parameters(self, X_star, y_star):
        """Set up the starting coefficients and the stacked true coefficients.

        Seeds NumPy's global RNG with ``self.random_state``. Then, if no
        ``coefs_init`` was given, draws one standard-normal vector of length
        ``q / n_workers`` (``q`` = number of columns of ``X_star``) and
        expands it via ``get_matform``. This keeps the initial value identical
        across runs that share the same seed.

        ``y_star`` is unused and kept for interface compatibility.
        Returns ``self``.
        """
        _, q = X_star.shape
        n_features = int(q / self.n_workers)
        np.random.seed(self.random_state)
        if self.coefs_init is None:
            coefs_temp = np.random.normal(size=(n_features, 1))
            self.coefs_, _, _ = get_matform(coefs=coefs_temp, n_workers=self.n_workers)
        # NOTE(review): when coefs_init is given, self.coefs_ is not set here.
        # Presumably NGDOracle sets it already — TODO confirm.
        if self.coefs_true is not None:
            self.coefs_true_star, _, _ = get_matform(coefs=self.coefs_true,
                                                     n_workers=self.n_workers)
        return self

    def _compute_grad(self, X_star, y_star, coefs_star):
        """Return ``-(2 / nn_samples) * X_star.T @ (y_star - prob)``.

        ``prob`` comes from ``self._get_prob(X_star @ coefs_star)``, and
        ``nn_samples`` is ``len(y_star) / n_workers``, i.e. presumably the
        per-worker sample count.
        """
        nn_samples = int(len(y_star) / self.n_workers)
        prob = self._get_prob(X_star @ coefs_star)
        residual = y_star - prob
        return -(2 / nn_samples) * (X_star.T @ residual)

    def _update(self, X_star, y_star, WK):
        """Perform one robust-aggregation + gradient step in place; returns ``self``.

        ``WK`` is accepted for interface compatibility but is not used here.
        """
        self.A = get_adjacency_mat(self.W, set_diag=True)
        self.coefs_half = self._aggregate(self.coefs_)
        # The gradient is evaluated at the pre-aggregation iterate (coefs_),
        # not at coefs_half.
        grad = self._compute_grad(X_star, y_star, self.coefs_)
        self.coefs_ = self.coefs_half - self.lr_constant * grad
        return self

    def fit(self, X_star, y_star):
        """Iterate until the max-abs coefficient change drops below ``tol``.

        Stops early on convergence, otherwise after ``max_niter`` iterations.
        After each iteration, records the MSE of ``coefs_half`` against the
        true coefficients in ``history_score`` and, with ``byz_labels``, in
        ``history_score_nbyz``. Finishes with one last robust aggregation of
        ``coefs_``. Returns ``self``.
        """
        self._initialize_parameters(X_star, y_star)
        self.n_iter_ = 0
        self.history_score = []
        self.history_score_nbyz = []
        n_features = int(len(self.coefs_) / self.n_workers)
        WK = np.kron(self.W, np.eye(n_features))
        # Build the adjacency matrix up front so the final aggregation below
        # also works when max_niter == 0 and _update never runs.
        self.A = get_adjacency_mat(self.W, set_diag=True)
        self.coefs_half = self.coefs_ * 1.0

        for _ in range(self.max_niter):
            prev_coefs = self.coefs_ * 1.0
            self._update(X_star, y_star, WK)
            dist = np.max(np.abs(prev_coefs - self.coefs_))
            self.history_score.append(
                self._get_mse(self.coefs_half, self.coefs_true_star))
            self.history_score_nbyz.append(
                self._get_mse(self.coefs_half, self.coefs_true_star, self.byz_labels))
            self.n_iter_ += 1
            if dist < self.tol:
                break

        # Converged and exhausted runs both end with one robust aggregation.
        self.coefs_ = self._aggregate(self.coefs_)
        return self