import torch

def get_data_limits(X):
    """Return the minimum and maximum of ``X`` along dim 0, paired up.

    The minima and maxima taken over ``dim=0`` are stacked (minima first)
    and the stacked tensor is transposed with ``.T``. For a 2-D ``X`` of
    shape ``(n, d)`` this gives a ``(d, 2)`` tensor whose rows are
    ``[min, max]`` for each column of ``X``.
    """
    # ``min``/``max`` with ``dim`` return a (values, indices) pair; only the
    # values are needed here.
    lower, _ = X.min(dim=0)
    upper, _ = X.max(dim=0)
    stacked = torch.stack((lower, upper))
    return stacked.T

class Temperature_Scheduler:
    """Step a temperature from ``config["start"]`` towards ``config["end"]``.

    Each call to :meth:`get_temperature` advances the schedule by one step
    and returns the new temperature, clamped so it never passes ``end``.
    Both decreasing (``start > end``) and increasing (``start < end``)
    schedules are supported.
    """

    def __init__(self, n_epochs, config):
        """Set up the schedule.

        Args:
            n_epochs: Number of steps over which to move from ``start`` to
                ``end``. Must be non-zero (it is used as a divisor).
            config: Mapping with the keys ``"start"``, ``"end"`` and
                ``"progress"``. Only ``progress == "linear"`` changes the
                temperature; any other value leaves it at ``start``.
        """
        self.start = config["start"]
        self.end = config["end"]
        self.progress = config["progress"]
        self.n_epochs = n_epochs
        # Signed per-call increment: negative when the temperature decreases,
        # positive when it increases.
        self.step = (self.end - self.start) / n_epochs
        self.current = self.start

    def get_temperature(self):
        """Advance the schedule by one step and return the temperature.

        Returns:
            The updated temperature, or ``end`` once the schedule has moved
            past it. ``self.current`` itself is not clamped.
        """
        # Only the linear schedule is implemented; other ``progress`` values
        # do not modify ``self.current``.
        if self.progress == "linear":
            self.current += self.step

        # Clamp in the direction of travel. Checking only ``current < end``
        # would make an increasing schedule return ``end`` from the very
        # first call and then overshoot it.
        if self.step < 0 and self.current < self.end:
            return self.end
        if self.step > 0 and self.current > self.end:
            return self.end
        return self.current

def replace_feature_names(rule,feature_names):
    """Replace ``X<i>`` placeholders in ``rule`` with ``feature_names[i]``.

    For example ``"X0<=1"`` becomes ``"age<=1"`` when
    ``feature_names[0] == "age"``.

    The substitution is done in a single pass over ``rule``, so text that was
    just inserted is never re-scanned: a feature name that itself contains
    something like ``"X0"`` is left intact. Placeholders whose index is not a
    valid position in ``feature_names`` are left unchanged.

    Args:
        rule: Rule string containing ``X<i>`` placeholders.
        feature_names: Indexable sequence of replacement strings.

    Returns:
        The rule with every in-range placeholder replaced.
    """
    import re

    n_features = len(feature_names)

    def _lookup(match):
        index = int(match.group(1))
        if index < n_features:
            return feature_names[index]
        return match.group(0)

    # The index is matched greedily without leading zeros, so "X10" is read
    # as index 10 and never as "X1" followed by "0".
    return re.sub(r"X(0|[1-9]\d*)", _lookup, rule)

def convert(df,var, val):
     """Convert the string ``val`` to a float.

     ``val`` is parsed with ``float()``, so signed values (``"-1.5"``) and
     scientific notation (``"1e-05"``) are handled. If parsing fails and the
     column ``var`` of ``df`` has dtype ``bool``, the result is ``1.0`` when
     ``val`` is exactly ``"True"`` and ``0.0`` otherwise.

     Args:
          df: Object exposing ``dtypes`` indexed by column name (presumably a
               pandas DataFrame -- confirm against callers).
          var: Column name looked up in ``df.dtypes``.
          val: Value to convert, normally a string.

     Returns:
          The converted float. When ``val`` is not numeric and the column is
          not boolean, nothing is returned by the code shown here (i.e.
          ``None``).
     """
     try:
          return float(val)
     except ValueError:
          # Not a numeric literal: only boolean columns are handled below,
          # anything else falls through.
          if str(df.dtypes[var]) == 'bool':
               return float(val=='True')