# metrics
from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE, multiclass_logits_to_pred, binary_logits_to_pred, MultiTaskAveragePrecision
from models.variable_width_resnet import __all__ as vwmodels
import torchvision.models as models

# Metric objects available for algorithm-level logging, keyed by metric name
# (presumably the name selected in the experiment config — confirm against
# callers). A trailing ``None`` key mapping to ``None`` is added below so that
# "no metric" is also a valid lookup.
algo_log_metrics = {
    # Plain accuracy, decoding outputs with multiclass_logits_to_pred.
    'accuracy': Accuracy(prediction_fn=multiclass_logits_to_pred),
    'mse': MSE(),
    # Multi-task accuracy with the two different output decoders.
    'multitask_accuracy': MultiTaskAccuracy(prediction_fn=multiclass_logits_to_pred),
    'multitask_binary_accuracy': MultiTaskAccuracy(prediction_fn=binary_logits_to_pred),
    # Constructed with prediction_fn=None, i.e. no decoder is supplied.
    'multitask_avgprec': MultiTaskAveragePrecision(prediction_fn=None),
}
algo_log_metrics[None] = None

# Maps the string name of an output-processing function to the callable
# itself; the ``None`` key maps to ``None`` for the "no processing" choice.
process_outputs_functions = dict(
    binary_logits_to_pred=binary_logits_to_pred,
    multiclass_logits_to_pred=multiclass_logits_to_pred,
)
process_outputs_functions[None] = None

# Names of the supported transform pipelines. See transforms.py for what each
# name builds.
transforms = [
    'image_base',
    'image_resize_and_center_crop',
    'poverty',
    'rxrx1',
]

# Sorted list of "cifar_"-prefixed names, one per attribute of the
# torchvision.models module (``models`` at this point in the file) that is
# lowercase, not a dunder, and callable. Non-callable entries such as
# submodules are skipped by the callable() check.
cifar_model_names = sorted(
    'cifar_' + attr_name
    for attr_name, attr_value in vars(models).items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr_value)
)

# Every selectable model name (see models/initializer.py for how each is
# built): the ResNet / WideResNet names, the names exported via
# models.variable_width_resnet.__all__, the cifar_-prefixed names computed
# above, and 'mlp', 'convnet' and 'lenet'.
# NOTE: this rebinds ``models``. Above this line the name refers to the
# torchvision.models module imported at the top of the file; from here on it
# is this list of strings.
models = (
    ['resnet18', 'resnet34', 'resnet50', 'wideresnet50']
    + vwmodels
    + cifar_model_names
    + ['mlp', 'convnet', 'lenet']
)

# Supported training-algorithm names, with the SSL algorithm names kept in a
# separate list. See algorithms/initializer.py.
algorithms = [
    'ERM',
    'groupDRO',
]
ssl_algorithms = [
    'simclr',
    'simsiam',
]

# Supported optimizer names. See optimizer.py.
optimizers = [
    'SGD',
    'Adam',
]

# Supported learning-rate scheduler names. See scheduler.py.
schedulers = [
    'linear_schedule_with_warmup',
    'cosine_schedule_with_warmup',
    'ReduceLROnPlateau',
    'StepLR',
    'MultiStepLR',
    'CosineAnnealingLR',
]

# Supported loss-function names. See losses.py.
losses = [
    'cross_entropy',
    'lm_cross_entropy',
    'MSE',
    'multitask_bce',
    'fasterrcnn_criterion',
]
