import torch
from trainer.utils import *
from utils.util import *
def reweight_ntk(self, weight, X_tr, y_tr, X_val, y_val):
    """Placeholder for a kernel-based variant of ``reweight_feature``.

    Not implemented: every argument is ignored and ``None`` is returned.
    Presumably intended to use NTK features instead of ``get_features``.
    TODO confirm the intended contract before relying on it.
    """
    return None


def reweight_feature(self, weight, X_tr, y_tr, X_val, y_val,
                     alpha0=1, recompute=False, ratio=None, max_clip=1, visualize=False, val_feature=False):
    """Update training-sample weights from feature similarity to a validation set.

    Training and validation features are centred on a mean computed by
    ``weighted_mean_features`` from the validation features and labels. They
    are compared with a dot-product kernel, post-processed by
    ``transform_Kmat``, and each kernel entry is signed: +1 when the
    training and validation labels match, ``-alpha0`` otherwise. The signed
    entries are summed per training row, and that sum is passed to
    ``update_weight``.

    Args:
        self: object exposing ``get_features``.
        weight: current weight tensor. When ``recompute`` is True only its
            shape/dtype/device are used (a 0.5-filled tensor replaces it).
        X_tr: training inputs, passed to ``self.get_features``.
        y_tr: training labels, compared element-wise against ``y_val``.
        X_val: validation inputs, or a precomputed feature matrix when
            ``val_feature`` is True.
        y_val: validation labels; ``int(y_val.max()) + 1`` is used as the
            class count, which presumes integer ids starting at 0 (TODO confirm).
        alpha0: magnitude of the negative sign applied to label-mismatched
            kernel entries.
        recompute: if True, restart from a tensor of 0.5 and call
            ``update_weight`` without the ``eta=0.01`` override.
        ratio: forwarded to ``update_weight``.
        max_clip: forwarded to ``update_weight`` as ``clipmax``.
        visualize: if True, pass the validation Gram matrix to ``visualize_Kmat``.
        val_feature: if True, skip ``get_features`` for the validation side.

    Returns:
        Tuple ``(weight, Kmat, y_tr, Kmat_raw, w_direction)``:
        - ``weight`` is the ``update_weight`` result, with every entry <= 0.5
          set to 0 in place.
        - ``Kmat`` and ``Kmat_raw`` come from ``transform_Kmat``.
        - ``w_direction`` is the per-training-row signed kernel sum.
    """
    n_cls = int(y_val.max()) + 1

    feat_tr = self.get_features(X_tr)  # presumably (N_tr, D)
    # With val_feature set, X_val is treated as already being features
    # (could be swapped for NTK features if that path is implemented).
    feat_val = X_val if val_feature else self.get_features(X_val)

    # Centre both sides on the same validation-derived mean before the kernel.
    centre = weighted_mean_features(feat_val, y_val)
    feat_tr, feat_val = feat_tr - centre, feat_val - centre

    kernel = feat_tr @ feat_val.T  # (N_tr, N_val) before transform_Kmat
    kernel, y_val, kernel_raw = transform_Kmat(kernel, y_val, num_class=n_cls)

    if visualize:
        # NOTE(review): y_val here is the value returned by transform_Kmat,
        # while the Gram matrix uses the untransformed validation features;
        # confirm the two remain aligned.
        visualize_Kmat(feat_val @ feat_val.T, y_val)

    dev = kernel.device
    same_label = y_tr[:, None] == y_val[None, :]
    signs = torch.where(same_label,
                        torch.tensor(1.0, device=dev),
                        torch.tensor(-alpha0, device=dev))
    w_direction = (kernel * signs).sum(dim=1)

    if recompute:
        # Fresh start: uniform 0.5 weights and update_weight's default eta.
        start = torch.ones_like(weight) / 2
        new_weight = update_weight(start, w_direction, clipmax=max_clip, ratio=ratio)
    else:
        new_weight = update_weight(weight, w_direction, clipmax=max_clip, ratio=ratio, eta=0.01)

    # Hard threshold applied in place on the tensor update_weight returned.
    new_weight[new_weight <= 0.5] = 0
    return new_weight, kernel, y_tr, kernel_raw, w_direction




    
