from utils import *
 
class NGDOracle:
    """Gradient-descent estimator with coefficient mixing across workers.

    Every worker owns one block of ``n_features`` coefficients; the blocks are
    stacked into a single column vector ``coefs_``.  One iteration computes

        coefs_ <- kron(W, I) @ coefs_ - lr_star * grad

    where ``grad`` is evaluated at the mixed coefficients (see
    ``_compute_grad``).

    Parameters
    ----------
    W : array-like, shape (n_workers, n_workers)
        Mixing matrix; expanded with ``np.kron(W, I)`` inside ``fit``.
    lr : scalar or sequence with one entry per worker
        Learning rate(s).  A scalar is shared by all workers.
    n_workers : int
        Number of workers (coefficient blocks).
    max_niter : int
        Maximum number of iterations.
    tol : float
        ``fit`` stops once the largest absolute coefficient change in one
        iteration is below this value.
    coefs_init : array-like or None
        Starting coefficients.  ``None`` draws them from a standard normal.
        A vector that already has ``X_star.shape[1]`` entries is used as the
        stacked vector directly; anything else is expanded with
        ``get_matform`` exactly like ``coefs_true``.
    coefs_true : array-like or None
        Reference coefficients used only for the error histories.  With
        ``None`` the histories stay empty.
    model_typ : str
        Forwarded unchanged to ``get_prob``.
    byz_labels : sequence or None
        One flag per worker; workers with a truthy flag are left out of
        ``history_score_nbyz`` (presumably the Byzantine workers).
    random_state : int or None
        Seed handed to ``np.random.seed`` before initialisation.
    """

    def __init__(self, W, lr, n_workers=10, max_niter=1000, tol=1e-6,
                 coefs_init=None, coefs_true=None, model_typ='logit', byz_labels=None,
                 random_state=None):
        self.W = W
        self.lr = lr
        self.n_workers = n_workers
        self.max_niter = max_niter
        self.tol = tol
        self.model_typ = model_typ
        self.coefs_init = coefs_init
        self.coefs_true = coefs_true
        self.random_state = random_state
        self.byz_labels = byz_labels

    def _initialize_parameters(self, X_star, y_star):
        """Set ``coefs_`` and ``coefs_true_star`` before iterating.

        ``y_star`` is accepted for signature compatibility and is not used.
        Returns ``self``.
        """
        q = X_star.shape[1]
        n_features = q // self.n_workers
        # Seeds NumPy's global generator so repeated fits start identically.
        np.random.seed(self.random_state)
        if self.coefs_init is None:
            coefs_temp = np.random.normal(size=(n_features, 1))
            self.coefs_, _, _ = get_matform(coefs=coefs_temp, n_workers=self.n_workers)
        else:
            # Previously a user-supplied start left ``coefs_`` undefined.
            init = np.array(self.coefs_init, dtype=float)
            if init.size == q:
                # Already one block per worker: use as the stacked vector.
                # NOTE(review): assumes get_matform would leave such a vector
                # unchanged when n_workers == 1 -- confirm against utils.
                self.coefs_ = init.reshape(-1, 1)
            else:
                self.coefs_, _, _ = get_matform(coefs=self.coefs_init,
                                                n_workers=self.n_workers)
        # Always define the attribute so ``fit`` can test for it.
        self.coefs_true_star = None
        if self.coefs_true is not None:
            self.coefs_true_star, _, _ = get_matform(coefs=self.coefs_true,
                                                     n_workers=self.n_workers)
        return self

    def _initialize_lr(self, X_star=None, y_star=None):
        """Expand ``lr`` to one step size per stacked coefficient.

        Each worker's rate is repeated ``n_features`` times and the result is
        stored as the column vector ``lr_star``.  Both arguments are unused.
        Returns ``self``.
        """
        n_features = len(self.coefs_) // self.n_workers
        lr = np.asarray(self.lr)
        if lr.ndim == 0:
            # A scalar rate is shared by every worker.
            lr = np.repeat(lr, self.n_workers)
        self.lr_star = np.repeat(lr, n_features).reshape(-1, 1)
        return self

    def fit(self, X_star, y_star):
        """Iterate ``_update`` until convergence or ``max_niter``.

        Sets ``n_iter_`` and, when ``coefs_true`` was supplied, appends one
        entry per iteration to ``history_score`` (all workers) and
        ``history_score_nbyz`` (workers not flagged in ``byz_labels``).
        Returns ``self``.
        """
        self._initialize_parameters(X_star, y_star)
        self._initialize_lr(X_star=X_star, y_star=y_star)
        self.n_iter_ = 0
        self.history_score = []
        self.history_score_nbyz = []
        n_features = len(self.coefs_) // self.n_workers
        # Loop invariants: block-expanded mixing matrix and mixed design.
        WK = np.kron(self.W, np.eye(n_features))
        XWI = X_star @ WK
        track_error = self.coefs_true_star is not None
        for _ in range(self.max_niter):
            prev_coefs = self.coefs_.copy()
            self._update(X_star, y_star, WK, XWI)  # update step
            dist = np.max(np.abs(prev_coefs - self.coefs_))
            if track_error:
                self.history_score.append(
                    self._get_mse(self.coefs_, self.coefs_true_star))
                self.history_score_nbyz.append(
                    self._get_mse(self.coefs_, self.coefs_true_star, self.byz_labels))
            self.n_iter_ += 1
            if dist < self.tol:
                break
        return self

    def _compute_grad(self, X_star, y_star, XWI, coefs_star):
        """Return ``-(2 / n) * X_star.T @ (y_star - prob)``.

        ``prob`` is ``get_prob`` applied to ``XWI @ coefs_star`` and ``n`` is
        ``len(y_star) // n_workers``.
        """
        nn_samples = len(y_star) // self.n_workers
        prob = self._get_prob(XWI @ coefs_star)
        residual = y_star - prob
        return -(2 / nn_samples) * (X_star.T @ residual)

    def _update(self, X_star, y_star, WK, XWI):
        """Perform one step: mix the coefficients, then move against the gradient."""
        grad = self._compute_grad(X_star, y_star, XWI, self.coefs_)
        self.coefs_ = WK @ self.coefs_ - self.lr_star * grad
        return self

    def _get_prob(self, z):
        """Delegate to ``get_prob`` with this estimator's ``model_typ``."""
        return get_prob(z, model_typ=self.model_typ)

    def _get_mse(self, coefs_est, coefs_true, byz_labels=None):
        """Return the mean squared coefficient error per worker.

        Without ``byz_labels`` this is the squared norm of the whole stacked
        difference divided by ``n_workers``.  With ``byz_labels`` the vectors
        are split into ``n_workers`` equal blocks and only blocks whose label
        is falsy are averaged; nan is returned when no block remains.
        """
        if byz_labels is None:
            return np.linalg.norm(coefs_est - coefs_true) ** 2 / self.n_workers

        est_blocks = np.split(coefs_est, self.n_workers)
        true_blocks = np.split(coefs_true, self.n_workers)
        kept_errors = [
            np.linalg.norm(est - ref) ** 2
            for est, ref, flagged in zip(est_blocks, true_blocks, byz_labels)
            if not flagged
        ]
        if not kept_errors:
            # Every worker is flagged: nothing to average.
            return np.nan
        return np.mean(kept_errors)
