syntax = "proto2";

package hierarchical_ood.protos;

// Root configuration message.
//
// Bundles the model choice, file locations, training parameters and three
// oneof groups (optimizer, loss, model_config); each oneof can hold at most
// one of its members at a time.
// NOTE(review): the per-field descriptions below are inferred from names and
// defaults only; confirm against the code that reads this message.
message Main {
    // Available model variants. Numbering starts at 1 (no zero value).
    enum Model {
        SOFTMAX = 1;
        ILR = 2;
        HILR = 3;
        CASCADE = 4;
        CASCADEFCHEAD = 5;
        SOFTMAXFCHEAD = 6;
        MOS = 7;
        AMSOFTMAX = 8;
        AMCASCADE = 9;
    }

    // Selected model variant; SOFTMAX when unset.
    optional Model model = 1 [default = SOFTMAX];

    // Presumably the dataset directory; defaults to 'data/coarse'.
    optional string data_dir = 2 [default = 'data/coarse'];

    // Presumably the hierarchy file name; defaults to 'pruned-wn.pth'.
    optional string hierarchy_fn = 3 [default = 'pruned-wn.pth'];

    // Float factor defaulting to 0.05; its use is not visible in this file.
    optional float min_norm_factor = 4 [default = 0.05];

    // Nested training settings (see TrainParams).
    optional TrainParams train_params = 5;

    // Optimizer settings: either SGD or Adam parameters.
    oneof optimizer {
        SGDParams sgd = 6;
        AdamParams adam = 7;
    }

    // Loss settings: exactly one loss message may be populated.
    oneof loss {
        CrossEntropy ce = 8;
        BinaryCrossEntropy bce = 9;
        HierarchicalBinaryCrossEntropy hbce = 10;
        HierarchicalLoss hl = 11;
        MOSLoss ml = 12;
        AMSoftmaxLoss amsl = 13;
        AMCascadeLoss amcl = 14;
    }

    // Presumably the random seed; defaults to 1234567.
    optional int32 seed = 15 [default = 1234567];

    // Integer defaulting to 1; presumably how many times to repeat a run.
    optional int32 repeat_iters = 16 [default = 1];

    // Off by default; presumably disables saving outputs.
    optional bool no_save = 17 [default = false];

    // Off by default; presumably enables verbose output.
    optional bool verbose = 18 [default = false];

    // Presumably the output directory; empty string by default.
    optional string savedir = 19 [default = ''];

    // Backbone identifier string; defaults to 'resnet50'.
    optional string backbone = 20 [default = 'resnet50'];

    // Per-model configuration: one message type per model variant.
    oneof model_config {
        Softmax softmax_mc = 21;
        IndependentLogisticRegressors ilr_mc = 22;
        HierarchicalILR hilr_mc = 23;
        Cascade cascade_mc = 24;
        CascadeFCHead cascadefchead_mc = 25;
        SoftmaxFCHead softmaxfchead_mc = 26;
        MOSConf mos_mc = 27;
        AMSoftmax ams_mc = 28;
        AMCascade amc_mc = 29;
    }

    // List of strings; presumably names of far-OOD datasets.
    repeated string far_ood_dsets = 30;

    // String selecting a distribution strategy; empty by default.
    optional string distribution_strategy = 31 [default = ''];

    // Off by default; presumably resumes from an existing checkpoint.
    optional bool resume_from_ckpt = 32 [default = false];

    // Presumably a checkpoint path to fine-tune from; empty by default.
    optional string finetune_from_ckpt = 33 [default = ''];

    // Off by default; presumably toggles an embedding layer.
    optional bool embed_layer = 34 [default = false];
}

// Training settings; used as the type of Main.train_params.
// NOTE(review): field meanings are inferred from names; confirm with the
// code that consumes them.
message TrainParams {
    // Epoch count; defaults to 90.
    optional int32 epochs = 1 [default = 90];

    // Batch size; defaults to 64.
    optional int32 batch_size = 2 [default = 64];

    // Presumably the checkpoint file name; empty by default.
    optional string checkpoint_fn = 3 [default = ''];

    // Presumably the log file name; empty by default.
    optional string log_fn = 4 [default = ''];

    // Off by default; presumably freezes the backbone ("bb") - TODO confirm.
    optional bool freeze_bb = 5 [default = false];

    // Off by default; presumably freezes backbone batch-norm - TODO confirm.
    optional bool freeze_bb_bn = 6 [default = false];

    // Off by default; presumably loads pretrained backbone weights.
    optional bool bb_pretrained = 7 [default = false];
}

// SGD optimizer settings; the `sgd` member of Main's optimizer oneof.
message SGDParams {
    // Learning rate; defaults to 0.1.
    optional float learning_rate = 1 [default = 0.1];

    // Momentum; defaults to 0.9.
    optional float momentum = 2 [default = 0.9];

    // Weight decay; defaults to 1e-4.
    optional float weight_decay = 3 [default = 1e-4];

    // Nesterov flag; off by default.
    optional bool nesterov = 4 [default = false];

    // Warm-up length; defaults to 5.
    optional int32 warmup_iters = 5 [default = 5];

    // Warm-up factor; defaults to 0.1.
    optional float warmup_factor = 6 [default = 0.1];

    // Learning-rate decay factor; defaults to 0.1.
    optional float lr_decay_factor = 7 [default = 0.1];

    // Learning-rate steps, each given relative to num_epochs.
    // The step position is computed as:
    //     int(num_epochs * lr_step) - warmup_iters
    // When left empty, 3 evenly divided steps are used.
    // To disable steps, set a value < 0.
    // NOTE(review): AdamParams.lr_step documents the disabling value as
    // "> 1" instead; confirm which sentinel the consumer actually checks.
    repeated float lr_step = 8 [packed = true];
}

// Adam optimizer settings; the `adam` member of Main's optimizer oneof.
message AdamParams {
    // Learning rate; defaults to 0.1.
    optional float learning_rate = 1 [default = 0.1];

    // Weight decay; defaults to 1e-4.
    optional float weight_decay = 2 [default = 1e-4];

    // Warm-up length; defaults to 5.
    optional int32 warmup_iters = 3 [default = 5];

    // Warm-up factor; defaults to 0.1.
    optional float warmup_factor = 4 [default = 0.1];

    // Learning-rate decay factor; defaults to 0.1.
    optional float lr_decay_factor = 5 [default = 0.1];

    // Learning-rate steps, each given relative to num_epochs.
    // The step position is computed as:
    //     int(num_epochs * lr_step) - warmup_iters
    // When left empty, 3 evenly divided steps are used.
    // To disable steps, set a value > 1.
    // NOTE(review): SGDParams.lr_step documents the disabling value as
    // "< 0" instead; confirm which sentinel the consumer actually checks.
    repeated float lr_step = 6 [packed = true];
}

// Field-less config message; type of Main.softmax_mc in the model_config oneof.
message Softmax {}
// Config message behind Main.softmaxfchead_mc.
message SoftmaxFCHead {
    // Integer sizes for the FC head, packed on the wire.
    // NOTE(review): presumably one entry per layer - confirm with consumer.
    repeated int32 fc_head_sizes = 1 [packed = true];
}
// Field-less config message; type of Main.ilr_mc in the model_config oneof.
message IndependentLogisticRegressors {}
// Field-less config message; type of Main.hilr_mc in the model_config oneof.
message HierarchicalILR {}
// Field-less config message; type of Main.cascade_mc in the model_config oneof.
message Cascade {}
// Config message behind Main.cascadefchead_mc.
message CascadeFCHead {
    // Integer sizes for the FC head, packed on the wire.
    // NOTE(review): presumably one entry per layer - confirm with consumer.
    repeated int32 fc_head_sizes = 1 [packed = true];

    // On by default. Note that AMCascade.split_fchead_layers defaults to
    // false, so the two messages differ when the field is left unset.
    optional bool split_fchead_layers = 2 [default = true];
}
// Config message behind Main.mos_mc.
message MOSConf {
    // Integer sizes for the FC head, packed on the wire.
    // NOTE(review): presumably one entry per layer - confirm with consumer.
    repeated int32 fc_head_sizes = 1 [packed = true];
}

// Config message behind Main.ams_mc.
message AMSoftmax {
    // Integer sizes for the FC head, packed on the wire.
    // NOTE(review): presumably one entry per layer - confirm with consumer.
    repeated int32 fc_head_sizes = 1 [packed = true];

    // On by default; presumably enables feature normalization.
    optional bool feature_norm = 2 [default = true];
}

// Config message behind Main.amc_mc.
message AMCascade {
    // Integer sizes for the FC head, packed on the wire.
    // NOTE(review): presumably one entry per layer - confirm with consumer.
    repeated int32 fc_head_sizes = 1 [packed = true];

    // On by default; presumably enables feature normalization.
    optional bool feature_norm = 2 [default = true];

    // Off by default. Note that CascadeFCHead.split_fchead_layers defaults
    // to true, so the two messages differ when the field is left unset.
    optional bool split_fchead_layers = 3 [default = false];
}

// A start/end pair of floats; used by the *_range fields of HierarchicalLoss.
// NOTE(review): units and inclusivity of the bounds are not stated in this
// file - confirm with the consumer.
message LossRange {
    // Lower bound; defaults to 0.
    optional float start = 1 [default = 0.];

    // Upper bound; defaults to 0.
    optional float end = 2 [default = 0.];
}
// Field-less loss message; type of Main.ce in the loss oneof.
message CrossEntropy {}
// Field-less loss message; type of Main.bce in the loss oneof.
message BinaryCrossEntropy {}
// Field-less loss message; type of Main.hbce in the loss oneof.
message HierarchicalBinaryCrossEntropy {}
// Loss settings behind Main.hl: three boolean toggles, a LossRange for each
// of them, and three further boolean options.
// NOTE(review): how each toggle and range is applied is not visible in this
// file; confirm with the loss implementation.
message HierarchicalLoss {
    // Loss-term toggles. Only synsetce_loss is on by default.
    optional bool synsetce_loss = 1 [default = true];
    optional bool softpred_loss = 2 [default = false];
    optional bool outlier_exposure = 3 [default = false];

    // One LossRange per toggle above (names match; linkage presumed).
    optional LossRange synsetce_range = 4;
    optional LossRange softpred_range = 5;
    optional LossRange outlier_exposure_range = 6;

    // Additional options; all off by default.
    optional bool weight_ce = 7 [default = false];
    optional bool focal_loss = 8 [default = false];
    optional bool depth_weight = 9 [default = false];
}
// Field-less loss message; type of Main.ml in the loss oneof.
message MOSLoss {}
// Loss settings behind Main.amsl.
// NOTE(review): `s` and `m` are presumably the scale and margin of an
// additive-margin softmax loss - confirm with the loss implementation.
message AMSoftmaxLoss {
    // Defaults to 30.0.
    optional float s = 1 [default = 30.0];

    // Defaults to 0.4.
    optional float m = 2 [default = 0.4];
}
// Loss settings behind Main.amcl.
// NOTE(review): `s` and `m` are presumably a scale and margin, and the
// *_factor fields presumably adjust them - confirm with the implementation.
message AMCascadeLoss {
    // Defaults to 30.0.
    optional float s = 1 [default = 30.0];

    // Defaults to 0.5 (AMSoftmaxLoss.m defaults to 0.4).
    optional float m = 2 [default = 0.5];

    // Defaults to 1.0.
    optional float s_factor = 3 [default = 1.0];

    // Defaults to 0.4.
    optional float m_factor = 4 [default = 0.4];
}
// vim:set foldmethod=marker:
