Trainer
class train.Trainer(experiment_id, train_loader, test_loader, model, loss, optimizer, epochs)
- Description:
This is the main class responsible for training the models. It does so through two methods:
- train:
- A method responsible for the training and testing loop. It uses mini-batch training.
- zip_results:
- A method responsible for zipping the model's outputs and the corresponding statistics and uploading them to the WandB servers.
- Args:
experiment_id: An experiment ID used to distinguish the result files of each experiment.
train_loader: A dataloader for the training data.
test_loader: A dataloader for the testing data.
model: The model to be trained.
loss: A loss function that measures the model's performance.
optimizer: An optimizer that updates the model parameters based on the loss function.
epochs: Number of training epochs.
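The mini-batch train/test cycle described above can be sketched without any framework. This is a minimal sketch, not the `train.Trainer` implementation: it assumes a 1-D linear model fitted with plain SGD, and `make_batches`, `batch_mse` and `train_minibatch` are hypothetical names, with `make_batches` standing in for a PyTorch DataLoader.

```python
import random
from typing import List, Tuple

Sample = Tuple[float, float]  # (x, y) pair
Batch = List[Sample]


def make_batches(data: List[Sample], batch_size: int, shuffle: bool, seed: int = 0) -> List[Batch]:
    """Split data into mini-batches; a hypothetical stand-in for a DataLoader."""
    items = list(data)
    if shuffle:
        random.Random(seed).shuffle(items)
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def batch_mse(batch: Batch, w: float, b: float) -> float:
    """Mean squared error of the linear model y = w * x + b on one batch."""
    return sum((w * x + b - y) ** 2 for x, y in batch) / len(batch)


def train_minibatch(train_data: List[Sample], test_data: List[Sample],
                    epochs: int, batch_size: int, lr: float):
    """Hypothetical train/test loop: SGD on mini-batches, then a test pass per epoch."""
    w, b = 0.0, 0.0  # parameters of the linear model
    test_history = []
    for epoch in range(epochs):
        # Training phase: one gradient step per shuffled mini-batch.
        for batch in make_batches(train_data, batch_size, shuffle=True, seed=epoch):
            n = len(batch)
            grad_w = sum(2 * (w * x + b - y) * x for x, y in batch) / n
            grad_b = sum(2 * (w * x + b - y) for x, y in batch) / n
            w -= lr * grad_w
            b -= lr * grad_b
        # Testing phase: evaluate on held-out batches without updating parameters.
        test_batches = make_batches(test_data, batch_size, shuffle=False)
        total = sum(batch_mse(batch, w, b) * len(batch) for batch in test_batches)
        test_history.append(total / len(test_data))
    return (w, b), test_history
```

For example, fitting noise-free samples of y = 2x + 1 should drive `w` toward 2 and `b` toward 1, with the recorded test loss shrinking over the epochs.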
zip_results(files)
- Description:
- A method to zip the results and upload them to the WandB servers.
- Return:
- 0 on success, otherwise -1.
- Return type:
- int
- Args:
files: A list of training and testing results (predictions and losses).
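The zipping half of this method, with its 0/-1 return convention, might look like the following. This is a minimal sketch, assuming `files` holds paths on disk; the `archive_path` parameter is hypothetical (the documented signature takes only `files`). The WandB upload is deliberately left out: one option would be passing the archive path to `wandb.save`, but how this package uploads is not stated here.

```python
import os
import zipfile
from typing import List


def zip_results(files: List[str], archive_path: str) -> int:
    """Zip result files into archive_path; return 0 on success, -1 on failure."""
    try:
        with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            for path in files:
                # Store each file under its base name so the archive is flat.
                zf.write(path, arcname=os.path.basename(path))
    except OSError:
        # A missing or unreadable file aborts the whole archive.
        if os.path.exists(archive_path):
            os.remove(archive_path)
        return -1
    return 0
```

Returning an integer status mirrors the documented contract; raising an exception would be the more idiomatic Python choice if the caller can handle it.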
train(loss_type='default')
- Description:
A method to train the models included in this baseline. It supports three training settings:
- Baseline: Train the model on the noise-free labels using the MSE loss.
- Cutoff: Train the model on noisy labels that are filtered using the CutoffMSE loss.
- BIV: Train the model on noisy labels using the BIV loss.
- Return:
- Trained model.
- Return type:
- torch.nn.Module
- Args:
loss_type: Type of the loss function used to train the model.
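The three settings listed for `train` differ mainly in the loss. Below is a plain-Python sketch of what those losses could look like, assuming each training sample comes with a known label-noise variance. The CutoffMSE and BIV formulas are assumptions, not this package's implementation: CutoffMSE is taken as MSE over samples whose variance is at most a threshold, and BIV as squared errors weighted by 1/(σ² + ε) and normalized by the sum of those weights, as in batch inverse-variance weighting. The function names and the `cutoff` and `eps` parameters are hypothetical.

```python
from typing import Sequence


def mse_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    """Baseline setting: plain mean squared error."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def cutoff_mse_loss(pred: Sequence[float], target: Sequence[float],
                    noise_var: Sequence[float], cutoff: float) -> float:
    """Cutoff setting (assumed form): MSE over samples with noise variance <= cutoff."""
    kept = [(p, t) for p, t, v in zip(pred, target, noise_var) if v <= cutoff]
    if not kept:
        return 0.0  # every sample was filtered out of this batch
    return sum((p - t) ** 2 for p, t in kept) / len(kept)


def biv_loss(pred: Sequence[float], target: Sequence[float],
             noise_var: Sequence[float], eps: float) -> float:
    """BIV setting (assumed form): inverse-variance weighted squared error."""
    weights = [1.0 / (v + eps) for v in noise_var]
    weighted = sum(w * (p - t) ** 2 for w, p, t in zip(weights, pred, target))
    return weighted / sum(weights)
```

With equal variances, the BIV weights are all the same and the loss reduces to plain MSE; a sample with a large variance contributes little. In PyTorch the same arithmetic would run on tensors so that gradients flow back to the model.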