from utils import *
from NGDOracle import NGDOracle
import itertools

 

def grid_search(X_star, y_star, X_val_star, y_val_star,
                base_model, coefs_star,
                param_grid, refit=True):
    """Exhaustive grid search over ``param_grid`` using a validation loss.

    For every combination in the Cartesian product of the ``param_grid``
    values, the combination is written onto ``base_model`` with ``setattr``,
    the model is fitted on ``(X_star, y_star)`` and ``base_model.fit_loss`` is
    evaluated on ``(X_val_star, y_val_star)``. The combination whose squeezed
    ``base_model.loss_`` has the smallest minimum entry wins (the first one
    wins on ties).

    Parameters
    ----------
    X_star, y_star : training data passed to ``base_model.fit``.
    X_val_star, y_val_star : validation data passed to ``base_model.fit_loss``.
    base_model : model object; it is mutated in place (attributes, fit state).
    coefs_star : unused; kept for backward compatibility of the signature.
    param_grid : dict mapping attribute name -> iterable of candidate values.
    refit : if True, set the best parameters on ``base_model`` before
        returning. The model is NOT re-fitted here; call ``fit`` again.
        If False, ``base_model`` keeps the last combination tried.

    Returns
    -------
    losses : list of squeezed ``loss_`` arrays, one per combination.
    scores : list of ``base_model.history_score[-1]``, one per combination.
    best_params : dict holding the best combination.
    base_model : the same object that was passed in.

    Raises
    ------
    ValueError
        If the grid yields no combinations, or no combination produced a
        minimum loss below ``inf`` (e.g. every loss was NaN).
    """
    best_loss = np.inf
    best_params = None
    losses = []
    scores = []
    keys = list(param_grid.keys())

    for values in itertools.product(*param_grid.values()):
        param_dict = dict(zip(keys, values))
        # Dynamically set the candidate parameters directly on the model.
        for key, value in param_dict.items():
            setattr(base_model, key, value)
        base_model.fit(X_star, y_star)
        base_model.fit_loss(X_val_star, y_val_star)
        loss = base_model.loss_.squeeze()
        losses.append(loss)
        scores.append(base_model.history_score[-1])
        min_loss = np.min(loss)
        # NaN never compares less than best_loss, so NaN losses are skipped.
        if min_loss < best_loss:
            best_loss = min_loss
            best_params = param_dict.copy()

    if best_params is None:
        raise ValueError(
            "grid_search found no valid parameter combination: the grid is "
            "empty or every validation loss was NaN/inf")

    if refit:
        for key, value in best_params.items():
            setattr(base_model, key, value)
    return losses, scores, best_params, base_model


def get_loss(X_star, y_star, coefs, n_workers=10, model_typ='linear'):
    """Group-averaged loss of the predictor ``get_prob(X_star @ coefs)``.

    Parameters
    ----------
    X_star, y_star : design matrix and responses.
    coefs : coefficients multiplied with ``X_star``.
    n_workers : number of groups passed to ``grouped_mean``.
    model_typ : ``'linear'`` (squared error) or ``'logit'`` (binary
        cross-entropy / negative log-likelihood).

    Returns
    -------
    The output of ``grouped_mean(loss, n_workers)`` reshaped to a column
    (``reshape([-1, 1])``) — presumably one entry per worker.

    Raises
    ------
    ValueError
        If ``model_typ`` is not ``'linear'`` or ``'logit'`` (previously this
        surfaced as an ``UnboundLocalError``).
    """
    if model_typ not in ('linear', 'logit'):
        raise ValueError(
            f"model_typ must be 'linear' or 'logit', got {model_typ!r}")

    y_hat = get_prob(X_star @ coefs, model_typ=model_typ)
    if model_typ == 'linear':
        loss = (y_star - y_hat) ** 2
    else:
        # If get_prob returns exactly 0.0 or 1.0 (e.g. a saturated sigmoid),
        # log(0) makes the loss inf, or nan via 0 * -inf. Clip the
        # probabilities to stay strictly inside (0, 1).
        eps = 1e-15
        y_hat = np.clip(y_hat, eps, 1 - eps)
        loss = -(y_star * np.log(y_hat) + (1 - y_star) * np.log(1 - y_hat))

    avg_loss = grouped_mean(loss, n_workers)
    return avg_loss.reshape([-1, 1])


class NGD(NGDOracle):
    """Network gradient descent with a constant learning rate.

    Thin subclass of ``NGDOracle``. It passes ``lr=None`` to the parent and
    instead keeps ``lr_constant``. ``_initialize_lr`` copies that value into
    ``self.lr_star``, which is presumably the step size consumed by
    ``NGDOracle``. ``lr_constant`` is used as-is, so a scalar gives the same
    step size for every coordinate.
    """

    def __init__(self, W, lr_constant, n_workers=10, max_niter=1000, tol=1e-6,
                 coefs_init=None, coefs_true=None, model_typ='logit',
                 byz_labels=None, random_state=None):
        super().__init__(W=W, lr=None, n_workers=n_workers, max_niter=max_niter,
                         tol=tol, coefs_init=coefs_init, coefs_true=coefs_true,
                         model_typ=model_typ, byz_labels=byz_labels,
                         random_state=random_state)
        # Step size copied into lr_star by _initialize_lr.
        self.lr_constant = lr_constant

    def _initialize_lr(self, X_star=None, y_star=None):
        """Set ``self.lr_star`` from ``lr_constant`` and return ``self``.

        ``X_star`` / ``y_star`` are ignored here. They keep the signature
        compatible with subclasses that derive the learning rate from data.
        """
        # `* 1.0` yields a float value (a new array for array input) rather
        # than an alias of lr_constant.
        self.lr_star = self.lr_constant * 1.0
        return self


class AdaptiveNGD(NGD):
    """NGD variant whose learning rates are derived from gradient-based weights.

    ``_initialize_lr`` first fits a pilot ``NGD`` (constant learning rate).
    ``_compute_lr`` then turns the gradient at the pilot coefficients into a
    weight vector via ``map_row(grad, process_row, ...)`` and rescales it so
    the largest weight is 1. It sets ``lr_star = lr_constant * weight``, each
    weight repeated ``n_features`` times (presumably one weight per node),
    as a column vector. ``refit`` optionally recomputes the weights at the
    current coefficients and re-runs the updates, ``max_nrefit`` times.

    Extra parameters (beyond ``NGD``)
    ---------------------------------
    max_nrefit : number of recompute-weights / re-optimise rounds in ``refit``
        (0 disables refitting).
    cn : forwarded to ``process_row`` when computing the weights.
    coefs_init_star : optional initial coefficients, copied as-is. When None,
        a standard-normal vector seeded by ``random_state`` is drawn and
        passed through ``get_matform``.
    """

    def __init__(self, W, lr_constant, n_workers=10,
                 max_niter=1000, max_nrefit=0, tol=1e-6,
                 coefs_init=None, coefs_true=None, model_typ='logit', byz_labels=None,
                 random_state=None,
                 cn=1, coefs_init_star=None):
        super().__init__(W=W, lr_constant=lr_constant, n_workers=n_workers,
                         max_niter=max_niter, tol=tol, coefs_init=coefs_init,
                         coefs_true=coefs_true, model_typ=model_typ,
                         byz_labels=byz_labels, random_state=random_state)
        self.max_nrefit = max_nrefit
        self.cn = cn
        self.coefs_init_star = coefs_init_star

    def _initialize_parameters(self, X_star, y_star):
        """Initialise ``self.coefs_`` (and ``self.coefs_true_star`` if possible).

        Seeds NumPy's global RNG with ``random_state``. If ``coefs_init_star``
        is None, draws an ``(n_features, 1)`` standard-normal vector and passes
        it through ``get_matform``; otherwise uses a float copy of
        ``coefs_init_star``. When ``coefs_true`` is given, it is also passed
        through ``get_matform`` into ``coefs_true_star``.
        """
        _, q = X_star.shape
        n_features = int(q / self.n_workers)
        np.random.seed(self.random_state)
        if self.coefs_init_star is None:
            coefs_temp = np.random.normal(size=(n_features, 1))  # ensure the same initial value
            self.coefs_, _, _ = get_matform(coefs=coefs_temp, n_workers=self.n_workers)
        else:
            self.coefs_ = self.coefs_init_star * 1.0
        if self.coefs_true is not None:
            self.coefs_true_star, _, _ = get_matform(coefs=self.coefs_true,
                                                     n_workers=self.n_workers)
        return self

    def _compute_lr(self, X_star, y_star, coefs_star_init):
        """Derive weights and the vector learning rate ``lr_star``.

        The gradient at ``coefs_star_init`` is mapped row-wise through
        ``process_row`` and normalised so its maximum is 1. Sets
        ``self.weights`` (column), ``self.mean_weights``, ``self.max_weights``
        and ``self.lr_star`` (column, each weight repeated ``n_features``
        times and scaled by ``lr_constant``).
        """
        n_workers = self.n_workers
        _, q = X_star.shape
        n_features = int(q / n_workers)

        grad = self._compute_grad(X_star, y_star, X_star, coefs_star_init)
        weights = map_row(grad, process_row, n_features=n_features, cn=self.cn)
        # Rescale so the largest weight equals 1.
        # NOTE(review): a zero maximum would produce inf/nan here.
        weights = weights / np.max(weights)
        self.weights = weights.reshape(-1, 1)
        self.mean_weights = np.mean(weights)
        self.max_weights = np.max(weights)
        self.lr_star = np.repeat(self.lr_constant * weights, n_features).reshape(-1, 1)
        return self

    def _initialize_lr(self, X_star, y_star):
        """Fit a pilot constant-rate ``NGD`` and derive ``lr_star`` from it."""
        # NOTE(review): byz_labels is not forwarded to the pilot NGD, unlike
        # every other constructor argument. Confirm this is intended.
        ngd = NGD(W=self.W, lr_constant=self.lr_constant,
                  n_workers=self.n_workers, max_niter=self.max_niter,
                  tol=self.tol, coefs_init=self.coefs_init, coefs_true=self.coefs_true,
                  model_typ=self.model_typ, random_state=self.random_state).fit(X_star, y_star)
        coefs_star_init = ngd.coefs_ * 1.0
        self._compute_lr(X_star, y_star, coefs_star_init)
        return self

    def fit_loss(self, X_star, y_star):
        """Compute a network-mixed weighted loss plus an upper-bound term.

        Each node's loss from ``get_loss`` is multiplied by its weight. The
        weighted losses, the weights and the squared weights are then
        repeatedly multiplied by the row-normalised adjacency ``W_self``
        until the loss changes by less than ``tol`` (at most ``max_niter``
        rounds). The result is ``loss / weights_mean + ub``, where
        ``ub = 1.64 * sqrt(mean(w^2) / mean(w)^2 / nn_samples)``. Stores the
        result in ``self.loss_`` and returns ``self``.
        """
        self.loss_ = self.weights * get_loss(X_star, y_star, self.coefs_,
                                             n_workers=self.n_workers,
                                             model_typ=self.model_typ)
        self.weights_mean = self.weights * 1.0
        self.weights_squared = (self.weights ** 2) * 1.0
        self.n_iter_loss = 0
        self.A = get_adjacency_mat(self.W, set_diag=True)
        # Row-normalise the adjacency including the diagonal (set_diag=True),
        # so each node also mixes in its own loss. The original author notes
        # this avoids periodic (oscillating) behaviour.
        self.W_self = self.A / self.A.sum(axis=1, keepdims=True)
        for _ in range(self.max_niter):
            prev_loss = self.loss_ * 1.0
            prev_w_mean = self.weights_mean * 1.0
            prev_w_squared = self.weights_squared * 1.0
            self.loss_ = self.W_self @ prev_loss
            self.weights_mean = self.W_self @ prev_w_mean
            self.weights_squared = self.W_self @ prev_w_squared
            dist = np.max(np.abs(prev_loss - self.loss_))
            self.n_iter_loss += 1
            # Convergence is judged on the loss only.
            if dist < self.tol:
                break
        nn_samples = int(X_star.shape[0] / self.n_workers)
        weights_squared = self.weights_squared / (self.weights_mean ** 2)
        # NOTE(review): 1.64 is presumably the one-sided 95% normal quantile.
        ub = 1.64 * np.sqrt(weights_squared / nn_samples)
        self.loss_ /= self.weights_mean
        self.loss_ += ub
        return self

    def refit(self, X_star, y_star):
        """Run ``max_nrefit`` rounds of: recompute ``lr_star``, then re-optimise.

        Each round recomputes the weights at the current ``coefs_`` and then
        calls ``_update`` until the max coefficient change drops below ``tol``
        or ``max_niter`` iterations pass. ``history_score``,
        ``history_score_nbyz`` and ``n_iter_`` are updated for every
        non-converging iteration.
        """
        n_features = int(len(self.coefs_) / self.n_workers)
        WK = np.kron(self.W, np.eye(n_features))
        XWI = X_star @ WK
        for _refit_round in range(self.max_nrefit):
            self._compute_lr(X_star, y_star, self.coefs_ * 1.0)
            for _n_iter in range(self.max_niter):
                prev_coefs = self.coefs_ * 1.0
                self._update(X_star, y_star, WK, XWI)
                dist = np.max(np.abs(prev_coefs - self.coefs_))
                # NOTE(review): the break precedes the history bookkeeping, so
                # the converging iteration is not recorded. Confirm intended.
                if dist < self.tol:
                    break
                self.history_score.append(self._get_mse(self.coefs_, self.coefs_true_star))
                self.history_score_nbyz.append(
                    self._get_mse(self.coefs_, self.coefs_true_star, self.byz_labels))
                self.n_iter_ += 1
        return self
    
