from utils import *
from NGDOracle import NGDOracle

 

def clip(z, tau=1):
    """Rescale ``z`` so that its overall norm does not exceed ``tau``.

    The input is promoted to at least one dimension and its norm is taken
    over *all* entries at once (``np.linalg.norm`` with default arguments).
    When that norm is already within ``tau`` the promoted array is handed
    back untouched; otherwise every entry is scaled by ``tau / norm``.

    Parameters
    ----------
    z : array-like
        Values to clip; scalars become length-1 arrays.
    tau : float, default 1
        Largest norm allowed for the result.

    Returns
    -------
    ndarray
        ``z`` itself (after ``np.atleast_1d``) or a rescaled copy.
    """
    arr = np.atleast_1d(z)
    magnitude = np.linalg.norm(arr)
    if magnitude <= tau:
        # Already inside the ball of radius tau: nothing to do.
        return arr
    # Project back onto the sphere of radius tau along the same direction.
    return arr / magnitude * tau
    

def clipped_gossip(A, vec, delta_max=0.9):
    """Run one round of clipped-gossip averaging over the workers' vectors.

    ``vec`` stacks one length-``n_features`` vector per worker.  Worker ``i``
    looks at every worker ``j`` with ``A[i, j] > 0``, clips each difference
    ``x_j - x_i`` *individually* to Euclidean norm at most ``tau_i`` and
    replaces its vector by the mean of ``x_i + clipped difference`` over
    those neighbours.

    ``tau_i`` is chosen adaptively: the squared distances from worker ``i``
    to its neighbours are sorted, the
    ``max(floor(n_neighbours * (1 - delta_max)), 1)`` smallest are kept, and
    ``tau_i`` is the square root of their mean.

    Parameters
    ----------
    A : ndarray of shape (n_workers, n_workers)
        Adjacency matrix; ``A[i, j] > 0`` means worker ``j`` transmits to
        worker ``i``.  A worker counts itself as a neighbour only if the
        diagonal entry is positive.
    vec : ndarray whose first dimension has length n_workers * n_features
        Stacked worker vectors, worker after worker.
    delta_max : float, default 0.9
        Fraction of the neighbours (those with the largest distances) that
        is left out when computing ``tau_i``.

    Returns
    -------
    ndarray of shape (n_workers * n_features, 1)
        The stacked vectors after the gossip round, always floating point.
    """
    n_workers = A.shape[0]
    n_features = vec.shape[0] // n_workers

    # One row per worker.
    nodes = vec.reshape(n_workers, n_features)

    # Averages are fractional in general, so never store them in an
    # integer buffer (that would silently truncate the result).
    if np.issubdtype(nodes.dtype, np.inexact):
        out_dtype = nodes.dtype
    else:
        out_dtype = np.float64
    new_vec = np.zeros(nodes.shape, dtype=out_dtype)

    for i in range(n_workers):
        own = nodes[i, :]

        # Workers that can transmit data to worker i.
        neighbors = np.where(A[i, :] > 0)[0]
        if neighbors.size == 0:
            # Nothing to average with: keep the current value instead of
            # producing NaNs from the mean of an empty set.
            new_vec[i, :] = own
            continue

        diffs = nodes[neighbors, :] - own
        sq_dists = np.sum(diffs ** 2, axis=1)

        # Adaptive radius from the closest neighbours only.
        n_keep = max(int(np.floor(len(neighbors) * (1 - delta_max))), 1)
        tau_i = np.sqrt(np.mean(np.sort(sq_dists)[:n_keep]))

        # Clip every neighbour's difference on its own norm.  Clipping the
        # whole matrix by a single joint norm would shrink all neighbours by
        # the same factor and let one far-away neighbour dominate the mean.
        norms = np.sqrt(sq_dists)
        too_far = norms > tau_i
        clipped = diffs.astype(out_dtype, copy=True)
        clipped[too_far] = diffs[too_far] / norms[too_far][:, None] * tau_i

        new_vec[i, :] = np.mean(clipped + own, axis=0)

    # Back to a stacked column vector.
    return new_vec.reshape(n_workers * n_features, 1)

    


class ClippedGossip(NGDOracle):
    """Momentum gradient descent whose iterates are mixed by clipped gossip.

    Every iteration (see ``_update``) forms a moving average of the gradient,
    takes a step of size ``lr_constant`` along it and passes the result
    through ``clipped_gossip``.

    Parameters
    ----------
    W : array-like
        Handed to ``get_adjacency_mat`` to obtain the adjacency matrix used
        for gossip (presumably the network weight matrix -- confirm against
        ``utils``).
    lr_constant : float
        Constant step size multiplying the momentum term.
    n_workers, max_niter, tol, coefs_init, coefs_true, model_typ, byz_labels, random_state
        Forwarded unchanged to ``NGDOracle.__init__`` (which also receives
        ``lr=None``).
    delta_max : float, default 0.9
        Forwarded to ``clipped_gossip`` on every update.
    alpha : float, default 0.5
        Weight of the fresh gradient in the momentum average; the previous
        momentum gets weight ``1 - alpha``.
    """

    def __init__(self, W, lr_constant, n_workers=10, max_niter=1000, tol=1e-6,
                 coefs_init=None, coefs_true=None, model_typ='logit', byz_labels=None, random_state=None, delta_max=0.9, alpha=0.5):
        super().__init__(W=W, lr=None, n_workers=n_workers, max_niter=max_niter,
                         tol=tol, coefs_init=coefs_init, coefs_true=coefs_true,
                         model_typ=model_typ, byz_labels=byz_labels, random_state=random_state)

        self.lr_constant = lr_constant
        self.delta_max = delta_max
        self.alpha = alpha

    def _initialize_parameters(self, X_star, y_star):
        """Set the starting coefficients and the stacked reference vector.

        ``y_star`` is accepted but not used.  Returns ``self``.
        """
        # Per-worker feature count: columns of X_star divided by n_workers.
        _, q = X_star.shape
        n_features = int(q / self.n_workers)

        # Seed the global NumPy RNG so every run starts from the same draw.
        np.random.seed(self.random_state)
        if self.coefs_init is None:
            coefs_temp = np.random.normal(size=(n_features, 1))
            self.coefs_, _, _ = get_matform(coefs=coefs_temp, n_workers=self.n_workers)
        # NOTE(review): when coefs_init is supplied, coefs_ is not assigned
        # in this method -- presumably the base class takes care of it;
        # confirm against NGDOracle.

        if self.coefs_true is not None:
            self.coefs_true_star, _, _ = get_matform(coefs=self.coefs_true, n_workers=self.n_workers)
        return self

    def _compute_grad(self, X_star, y_star, coefs_star):
        """Return ``-(2 / nn_samples) * X_star.T @ (y_star - prob)``.

        ``prob`` is ``self._get_prob(X_star @ coefs_star)`` and ``nn_samples``
        is ``len(y_star) / n_workers`` truncated to an integer.
        """
        nn_samples = int(len(y_star) / self.n_workers)
        prob = self._get_prob(X_star @ coefs_star)
        residual = y_star - prob
        return -(2 / nn_samples) * (X_star.T @ residual)

    def _update(self, X_star, y_star):
        """Perform one momentum step followed by one clipped-gossip round.

        Updates ``A``, ``momentum``, ``coefs_half`` (the iterate before
        gossip) and ``coefs_`` (the iterate after gossip).  Returns ``self``.
        """
        # The adjacency matrix is rebuilt from W, with its diagonal set, on
        # every call.
        self.A = get_adjacency_mat(self.W, set_diag=True)
        grad = self._compute_grad(X_star, y_star, self.coefs_)
        self.momentum = (1 - self.alpha) * self.momentum + self.alpha * grad
        self.coefs_half = self.coefs_ - self.lr_constant * self.momentum
        self.coefs_ = clipped_gossip(self.A, self.coefs_half, delta_max=self.delta_max)
        return self

    def fit(self, X_star, y_star):
        """Iterate ``_update`` until convergence or ``max_niter`` iterations.

        Convergence is declared when the largest absolute change of
        ``coefs_`` between two consecutive iterations falls below ``tol``.
        After each iteration the values returned by ``_get_mse`` for
        ``coefs_half`` (without and with ``byz_labels``) are appended to
        ``history_score`` and ``history_score_nbyz``.  When no reference
        vector is available (``coefs_true`` was not given) both histories
        stay empty instead of raising ``AttributeError``.

        Returns ``self``.
        """
        self._initialize_parameters(X_star, y_star)
        self.n_iter_ = 0
        self.history_score = []
        self.history_score_nbyz = []
        self.coefs_half = self.coefs_ * 1.0
        # The momentum starts at the gradient of the initial point.
        self.momentum = self._compute_grad(X_star, y_star, self.coefs_)

        # coefs_true_star is only created in _initialize_parameters when
        # coefs_true was provided, so look it up defensively.
        reference = getattr(self, "coefs_true_star", None)

        for _ in range(self.max_niter):
            prev_coefs = self.coefs_ * 1.0
            self._update(X_star, y_star)
            dist = np.max(np.abs(prev_coefs - self.coefs_))
            if reference is not None:
                self.history_score.append(self._get_mse(self.coefs_half, reference))
                self.history_score_nbyz.append(
                    self._get_mse(self.coefs_half, reference, self.byz_labels))
            self.n_iter_ += 1
            if dist < self.tol:
                break
        return self

