from utils import *
from NGDOracle import NGDOracle


 
class GradientTrack(NGDOracle):
    """Decentralised gradient-tracking solver built on top of ``NGDOracle``.

    The per-worker coefficient blocks are kept stacked in one vector
    (``coefs_``).  Every iteration mixes the stacked iterate with
    ``kron(W, I)`` and maintains a tracked gradient ``G_star`` that is updated
    as ``mix(G_star + grad_new - grad_old)``; the next half-step is
    ``coefs_ - lr_constant * G_star``.

    Parameters
    ----------
    W :
        Mixing matrix, expanded with ``np.kron(W, np.eye(n_features))`` in
        ``fit``.  Presumably of shape (n_workers, n_workers) -- TODO confirm.
    lr_constant :
        Constant step size applied to the tracked gradient.
    n_workers :
        Number of workers; used to split the columns of ``X_star`` and the
        rows of ``y_star`` into per-worker sizes.
    max_niter :
        Upper bound on the number of iterations performed by ``fit``.
    tol :
        ``fit`` stops early once the max-abs change it monitors drops below
        this value.
    coefs_init :
        When ``None`` a random start is drawn in ``_initialize_parameters``.
    coefs_true :
        Optional ground truth; when given, an error score is recorded at
        every iteration.
    model_typ, byz_labels :
        Forwarded to the base class.  ``byz_labels`` is also passed to
        ``_get_mse`` when filling ``history_score_nbyz``.
    random_state :
        Seed passed to ``np.random.seed`` before the random start is drawn.
    agg_method :
        The aggregation uses a robust estimation method, defaulting to the
        median.  It is only stored here; this class itself mixes linearly
        through ``W``.
    """

    def __init__(self, W, lr_constant, n_workers=10, max_niter=1000, tol=1e-6,
                 coefs_init=None, coefs_true=None, model_typ='logit', byz_labels=None,
                 random_state=None, agg_method='median'):
        # The base class takes an ``lr`` argument; this solver steps with its
        # own ``lr_constant`` instead, so ``None`` is forwarded.
        super().__init__(W=W, lr=None, n_workers=n_workers, max_niter=max_niter,
                         tol=tol, coefs_init=coefs_init, coefs_true=coefs_true,
                         model_typ=model_typ, byz_labels=byz_labels,
                         random_state=random_state)
        self.lr_constant = lr_constant
        self.agg_method = agg_method

    def _initialize_parameters(self, X_star, y_star):
        """Create the stacked start vector and, if available, the stacked truth.

        ``y_star`` is accepted for signature compatibility but not read here.
        Returns ``self``.
        """
        _, q = X_star.shape
        # X_star carries one column group per worker, so q / n_workers is the
        # per-worker feature count.
        n_features = q // self.n_workers

        # Seed the global NumPy RNG so every run with the same random_state
        # draws the same initial value.
        np.random.seed(self.random_state)
        if self.coefs_init is None:
            coefs_temp = np.random.normal(size=(n_features, 1))
            # get_matform returns a 3-tuple; only its first element is kept.
            # Presumably it stacks the vector once per worker -- TODO confirm.
            self.coefs_, _, _ = get_matform(coefs=coefs_temp,
                                            n_workers=self.n_workers)
        # NOTE(review): when coefs_init is given, coefs_ is not assigned here;
        # presumably the base class has already set it -- confirm.

        if self.coefs_true is not None:
            self.coefs_true_star, _, _ = get_matform(coefs=self.coefs_true,
                                                     n_workers=self.n_workers)
        return self

    def _compute_grad(self, X_star, y_star, coefs_star):
        """Return ``-(2 / n_per_worker) * X_star.T @ (y_star - prob)``.

        ``prob`` comes from ``self._get_prob`` applied to the linear predictor
        ``X_star @ coefs_star``; ``n_per_worker`` is ``len(y_star)`` divided by
        the number of workers.
        """
        nn_samples = len(y_star) // self.n_workers
        prob = self._get_prob(X_star @ coefs_star)
        residual = y_star - prob
        return -(2 / nn_samples) * (X_star.T @ residual)

    def _update(self, X_star, y_star, WK):
        """Run one gradient-tracking iteration in place and return ``self``.

        ``WK`` is the expanded mixing matrix built in ``fit``.
        """
        # Consensus step on the previous half-iterate.
        self.coefs_ = WK @ self.coefs_half
        grad_current = self._compute_grad(X_star, y_star, self.coefs_)
        # Tracking step: add the gradient innovation, then mix.
        grad = self.G_star + grad_current - self.grad_past
        self.G_star = WK @ grad
        self.coefs_half = self.coefs_ - self.lr_constant * self.G_star
        # ``* 1.0`` yields a fresh array rather than an alias.
        self.grad_past = grad_current * 1.0
        return self

    def fit(self, X_star, y_star):
        """Iterate ``_update`` until ``tol`` or ``max_niter`` is reached.

        Sets ``n_iter_`` and fills ``history_score`` / ``history_score_nbyz``
        with one ``_get_mse`` value per iteration.  When no stacked ground
        truth exists (``coefs_true`` was not supplied) the two history lists
        stay empty instead of the loop failing on a missing attribute.
        Returns ``self``.
        """
        self._initialize_parameters(X_star, y_star)
        self.n_iter_ = 0
        self.history_score = []
        self.history_score_nbyz = []

        n_features = len(self.coefs_) // self.n_workers
        WK = np.kron(self.W, np.eye(n_features))

        self.coefs_half = self.coefs_ * 1.0
        self.G_star = self._compute_grad(X_star, y_star, self.coefs_)
        self.grad_past = self.G_star * 1.0

        # coefs_true_star is created by _initialize_parameters only when a
        # ground truth was passed (the base class may define it too -- not
        # visible from here).  Scoring is skipped only when it is absent.
        track_error = hasattr(self, "coefs_true_star")

        for _ in range(self.max_niter):
            prev_coefs = self.coefs_ * 1.0
            self._update(X_star, y_star, WK)
            # NOTE(review): this compares the pre-update coefs_ with the
            # post-update coefs_half; kept exactly as in the original.
            dist = np.max(np.abs(prev_coefs - self.coefs_half))
            if track_error:
                self.history_score.append(
                    self._get_mse(self.coefs_, self.coefs_true_star))
                self.history_score_nbyz.append(
                    self._get_mse(self.coefs_, self.coefs_true_star,
                                  self.byz_labels))
            self.n_iter_ += 1
            if dist < self.tol:
                break
        return self
    
    
    
class BTGradientTrack(GradientTrack):
    """Gradient-tracking variant whose mixing steps use ``aggregate_robust``.

    Both the iterate and the tracked gradient are combined through
    ``aggregate_robust`` on the adjacency structure derived from ``W``,
    instead of being multiplied by the expanded mixing matrix.

    ``proportiontocut`` is forwarded unchanged to ``aggregate_robust``
    (presumably the trimming fraction of the chosen ``agg_method`` -- TODO
    confirm).  All other arguments are handed to ``GradientTrack``.
    """

    def __init__(self, W, lr_constant, n_workers=10, max_niter=1000, tol=1e-6,
                 coefs_init=None, coefs_true=None, model_typ='logit', byz_labels=None,
                 random_state=None, agg_method='median', proportiontocut=0.1):
        base_kwargs = {
            "W": W,
            "lr_constant": lr_constant,
            "n_workers": n_workers,
            "max_niter": max_niter,
            "tol": tol,
            "coefs_init": coefs_init,
            "coefs_true": coefs_true,
            "model_typ": model_typ,
            "byz_labels": byz_labels,
            "random_state": random_state,
            "agg_method": agg_method,
        }
        super().__init__(**base_kwargs)
        self.proportiontocut = proportiontocut

    def _robust_combine(self, stacked):
        """Aggregate ``stacked`` over ``self.A`` with the configured method."""
        return aggregate_robust(self.A, stacked,
                                method=self.agg_method,
                                proportiontocut=self.proportiontocut)

    def _update(self, X_star, y_star, WK):
        """Run one robust gradient-tracking iteration and return ``self``.

        ``WK`` is accepted for signature compatibility with the parent class
        but is not used by this variant.
        """
        # set_diag=True: per the original note, each worker uses its own data.
        # NOTE(review): A is rebuilt on every call although only self.W feeds
        # it; kept as-is because the helpers are not visible from here.
        self.A = get_adjacency_mat(self.W, set_diag=True)

        # Consensus step on the previous half-iterate.
        self.coefs_ = self._robust_combine(self.coefs_half)

        fresh_grad = self._compute_grad(X_star, y_star, self.coefs_)
        tracked = self.G_star + fresh_grad - self.grad_past
        # Store a fresh copy of the newest local gradient for the next round.
        self.grad_past = fresh_grad * 1.0

        self.G_star = self._robust_combine(tracked)
        self.coefs_half = self.coefs_ - self.lr_constant * self.G_star
        return self