# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Component identity and human-readable summary.
name: Model Fairness Check
# Folded scalar: the two lines below are joined with a single space and keep
# one trailing newline.
description: >
  Perform a fairness check on a certain attribute using AIF360
  to make sure the model is fair and ethical
metadata:
  annotations:
    platform: OpenSource
# Component inputs. Each one is forwarded to the container as the value of the
# same-named `--<name>` flag via an `inputValue` placeholder in
# implementation.container.args.
# NOTE(review): the model_* inputs are described as 'Required.' but also carry
# a `default` - confirm which is intended.
inputs:
  # --- Model identification ---
  - name: model_id
    type: String
    description: 'Required. Training model ID'
    default: 'training-dummy'
  - name: model_class_file
    type: String
    description: 'Required. pytorch model class file'
    default: 'PyTorchModel.py'
  - name: model_class_name
    type: String
    description: 'Required. pytorch model class name'
    default: 'PyTorchModel'

  # --- Test dataset locations (paths inside the data bucket) ---
  - name: feature_testset_path
    type: String
    description: 'Required. Feature test dataset path in the data bucket'
  - name: label_testset_path
    type: String
    description: 'Required. Label test dataset path in the data bucket'
  - name: protected_label_testset_path
    type: String
    description: 'Required. Protected label test dataset path in the data bucket'

  # --- Fairness configuration ---
  - name: favorable_label
    type: String
    description: 'Required. Favorable label for this model predictions'
  - name: unfavorable_label
    type: String
    description: 'Required. Unfavorable label for this model predictions'
  - name: privileged_groups
    type: String
    description: 'Required. Privileged feature groups within this model'
  - name: unprivileged_groups
    type: String
    description: 'Required. Unprivileged feature groups within this model'

  # --- Storage buckets ---
  - name: data_bucket_name
    type: String
    description: 'Optional. Bucket that has the processed data'
    default: 'training-data'
  - name: result_bucket_name
    type: String
    description: 'Optional. Bucket that has the training results'
    default: 'training-result'
# Component outputs. `metric_path` is handed to the container through an
# `outputPath` placeholder on the `--metric_path` flag.
outputs:
  - name: metric_path
    type: String
    description: 'Path for fairness check output'
# Container invocation: runs `python -u fairness_check.py` in the image below,
# followed by one `--flag value` pair per declared input/output.
implementation:
  container:
    image: 'aipipeline/bias-detector:pytorch'
    command:
      - python
    args:
      # `-u` is the Python interpreter's unbuffered-output switch.
      - '-u'
      - fairness_check.py
      # Model identification.
      - '--model_id'
      - {inputValue: model_id}
      - '--model_class_file'
      - {inputValue: model_class_file}
      - '--model_class_name'
      - {inputValue: model_class_name}
      # Test dataset locations.
      - '--feature_testset_path'
      - {inputValue: feature_testset_path}
      - '--label_testset_path'
      - {inputValue: label_testset_path}
      - '--protected_label_testset_path'
      - {inputValue: protected_label_testset_path}
      # Fairness configuration.
      - '--favorable_label'
      - {inputValue: favorable_label}
      - '--unfavorable_label'
      - {inputValue: unfavorable_label}
      - '--privileged_groups'
      - {inputValue: privileged_groups}
      - '--unprivileged_groups'
      - {inputValue: unprivileged_groups}
      # Output location; the only `outputPath` placeholder in this component.
      - '--metric_path'
      - {outputPath: metric_path}
      # Storage buckets.
      - '--data_bucket_name'
      - {inputValue: data_bucket_name}
      - '--result_bucket_name'
      - {inputValue: result_bucket_name}
