
import torch
from .base_weights import BaseWeights

class FixedWeights(BaseWeights):
    """Weights initialised to a uniform value, with the search hooks disabled.

    ``self.weights`` is filled with 1 and divided by its own length, and the
    per-index reset and both search methods are overridden to do nothing.
    """

    def __init__(self, size, normalize=False):
        """Build the base weights, set them uniformly, and store ``normalize``.

        Args:
            size: Forwarded to ``BaseWeights.__init__`` as its first argument.
            normalize: Stored on the instance as ``self.normalize``; it is not
                read anywhere inside this class.
        """
        # NOTE(review): the meaning of the literal 0 is defined by
        # BaseWeights.__init__, which is not visible here -- confirm.
        super(FixedWeights, self).__init__(size, 0)
        self.reset_weights()
        self.normalize = normalize

    def reset_weights(self):
        """Fill ``self.weights`` with 1, then divide it by its length, in place."""
        with torch.no_grad():
            tensor = self.weights
            # Both in-place ops return the tensor itself, so they can be chained.
            tensor.fill_(1).div_(len(tensor))

    def predict_weights(self, inp, h):
        """Return the base class's prediction viewed as ``(-1, 1, 1, 1)``.

        Args:
            inp: Forwarded unchanged to ``BaseWeights.predict_weights``.
            h: Forwarded unchanged to ``BaseWeights.predict_weights``.

        Returns:
            A ``(weights, h)`` tuple: the first value returned by the base
            class after ``view(-1, 1, 1, 1)``, and the second value as is.
        """
        predicted, state = super(FixedWeights, self).predict_weights(inp, h)
        return predicted.view(-1, 1, 1, 1), state

    def reset_weight(self, idx):
        """Do nothing; ``idx`` is ignored."""

    def _iterate_search(self, inp, get_loss, it):
        """Do nothing; all arguments are ignored."""

    def _initialize_search(self, inp, get_loss):
        """Do nothing; all arguments are ignored."""
