
import interference.Denoiser
import anomalie_score
import anomalie_score.Reconstrac_MSE

def get_interference(config, device):
    """Select and build the interference routine for the configured model.

    Only ``config["model_type"] == "Denoiser"`` is supported. For that model
    type, ``config["backbone_model"]`` decides which factory of
    ``interference.Denoiser`` is called: ``interference_fn2`` for the
    backbones listed in the body, ``interference_fn`` for any other value.

    Args:
        config: Mapping providing ``"model_type"`` and, for the Denoiser
            model, ``"backbone_model"``. It is forwarded unchanged to the
            selected factory.
        device: Forwarded unchanged to the selected factory.

    Returns:
        Whatever the selected ``interference.Denoiser`` factory returns.

    Raises:
        Exception: If ``config["model_type"]`` is not ``"Denoiser"``.
        KeyError: If a required key is missing (assuming ``config`` is a
            plain dict).
    """
    is_denoiser = config["model_type"] == "Denoiser"
    if not is_denoiser:
        raise Exception("Interfernece not defined for choosen model")

    # Backbones that are routed to the second factory variant.
    fn2_backbones = ("TabNet", "TabDiff", "MambaTab", "TabM", "DDPM")
    wants_fn2 = config["backbone_model"] in fn2_backbones

    # Resolve the module only after the config lookups, as the original did.
    denoiser = interference.Denoiser
    factory = denoiser.interference_fn2 if wants_fn2 else denoiser.interference_fn
    return factory(config, device)
    

def get_anomalie_score(config):
    """Build the anomaly-score routine named by ``config["anomalie_score"]``.

    The configured name is matched against a fixed set of keys, and the
    corresponding zero-argument factory of ``anomalie_score.Reconstrac_MSE``
    is called:

    * ``"mse"``         -> ``anomalie_score_mean_fn``
    * ``"mean_nodiff"`` -> ``anomalie_score_mean_nodiff_fn``
    * ``"mse_sum"``     -> ``anomalie_score_sum_fn``
    * ``"raw"``         -> ``anomalie_score_mean_raw_fn``

    Args:
        config: Mapping providing the ``"anomalie_score"`` entry.

    Returns:
        Whatever the matched factory returns.

    Raises:
        Exception: If the configured name matches none of the keys above.
        KeyError: If ``"anomalie_score"`` is missing (assuming ``config`` is
            a plain dict).
    """
    requested = config["anomalie_score"]

    # (config value, factory attribute name) pairs, checked in order.
    factory_names = (
        ("mse", "anomalie_score_mean_fn"),
        ("mean_nodiff", "anomalie_score_mean_nodiff_fn"),
        ("mse_sum", "anomalie_score_sum_fn"),
        ("raw", "anomalie_score_mean_raw_fn"),
    )
    for key, attr_name in factory_names:
        if requested == key:
            # Look the factory up only once a key matches, so no other
            # factory attribute is touched.
            return getattr(anomalie_score.Reconstrac_MSE, attr_name)()

    raise Exception("Anomalie score not defined for choosen metric")