# coding=utf-8
# Copyright 2023 The Uncertainty Baselines Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Uncertainty baseline training models."""

from absl import logging
import tensorflow as tf

# ==============================================================================
# Add Vision Transformer, Segmenter, BERT, edward2.experimental.mimo (MIMO), and
# PyTorch models to their corresponding try/except blocks below these main
# imports; otherwise you will break the external build.
# ==============================================================================
from uncertainty_baselines.models import efficientnet_utils
from uncertainty_baselines.models.clip import clip
from uncertainty_baselines.models.criteo_mlp import criteo_mlp
from uncertainty_baselines.models.efficientnet import efficientnet
from uncertainty_baselines.models.efficientnet_batch_ensemble import efficientnet_batch_ensemble
from uncertainty_baselines.models.gat import gat
from uncertainty_baselines.models.genomics_cnn import genomics_cnn
from uncertainty_baselines.models.movielens import movielens
from uncertainty_baselines.models.mpnn import mpnn
from uncertainty_baselines.models.resnet20 import resnet20
from uncertainty_baselines.models.resnet50_batchensemble import resnet101_batchensemble
from uncertainty_baselines.models.resnet50_batchensemble import resnet50_batchensemble
from uncertainty_baselines.models.resnet50_batchensemble import resnet_batchensemble
from uncertainty_baselines.models.resnet50_deterministic import resnet50_deterministic
from uncertainty_baselines.models.resnet50_dropout import resnet50_dropout
from uncertainty_baselines.models.resnet50_fsvi import resnet50_fsvi
from uncertainty_baselines.models.resnet50_het_mimo import resnet50_het_mimo
from uncertainty_baselines.models.resnet50_het_rank1 import resnet50_het_rank1
from uncertainty_baselines.models.resnet50_heteroscedastic import resnet50_heteroscedastic
from uncertainty_baselines.models.resnet50_hetsngp import resnet50_hetsngp
from uncertainty_baselines.models.resnet50_hetsngp import resnet50_hetsngp_add_last_layer
from uncertainty_baselines.models.resnet50_radial import resnet50_radial
from uncertainty_baselines.models.resnet50_rank1 import resnet50_rank1
from uncertainty_baselines.models.resnet50_resizable_width import resnet50_resizable_width
from uncertainty_baselines.models.resnet50_sngp import resnet50_sngp
from uncertainty_baselines.models.resnet50_sngp import resnet50_sngp_add_last_layer
from uncertainty_baselines.models.resnet50_sngp_be import resnet50_sngp_be
from uncertainty_baselines.models.resnet50_tram import resnet50_tram
from uncertainty_baselines.models.resnet50_variational import resnet50_variational
from uncertainty_baselines.models.textcnn import textcnn
from uncertainty_baselines.models.unet import unet
from uncertainty_baselines.models.wide_resnet import get_wide_resnet_hp_keys
from uncertainty_baselines.models.wide_resnet import wide_resnet
from uncertainty_baselines.models.wide_resnet_batchensemble import wide_resnet_batchensemble
from uncertainty_baselines.models.wide_resnet_condconv import wide_resnet_condconv
from uncertainty_baselines.models.wide_resnet_dropout import wide_resnet_dropout
from uncertainty_baselines.models.wide_resnet_heteroscedastic import wide_resnet_heteroscedastic
from uncertainty_baselines.models.wide_resnet_hetsngp import wide_resnet_hetsngp
from uncertainty_baselines.models.wide_resnet_hyperbatchensemble import e_factory as hyperbatchensemble_e_factory
from uncertainty_baselines.models.wide_resnet_hyperbatchensemble import LambdaConfig as HyperBatchEnsembleLambdaConfig
from uncertainty_baselines.models.wide_resnet_hyperbatchensemble import wide_resnet_hyperbatchensemble
from uncertainty_baselines.models.wide_resnet_posterior_network import wide_resnet_posterior_network
from uncertainty_baselines.models.wide_resnet_rank1 import wide_resnet_rank1
from uncertainty_baselines.models.wide_resnet_sngp import wide_resnet_sngp
from uncertainty_baselines.models.wide_resnet_sngp_be import wide_resnet_sngp_be
from uncertainty_baselines.models.wide_resnet_tram import wide_resnet_tram
from uncertainty_baselines.models.wide_resnet_variational import wide_resnet_variational


# pylint: disable=g-import-not-at-top
try:
  # Try to import ViT models. If any import below fails, the error is logged
  # with its traceback and every name from the failing line onward stays
  # undefined on this package; names imported before the failure remain bound.
  from uncertainty_baselines.models import vit_batchensemble
  from uncertainty_baselines.models import vit_batchensemble_gp
  from uncertainty_baselines.models.bit_resnet import bit_resnet
  from uncertainty_baselines.models.bit_resnet_heteroscedastic import bit_resnet_heteroscedastic
  from uncertainty_baselines.models.vit import vision_transformer
  from uncertainty_baselines.models.vit_batchensemble import vision_transformer_be
  from uncertainty_baselines.models.vit_batchensemble_gp import vision_transformer_be_gp
  from uncertainty_baselines.models.vit_gp import vision_transformer_gp
  from uncertainty_baselines.models.vit_hetgpbe import vision_transformer_het_gp_be
  from uncertainty_baselines.models.vit_hetgp import vision_transformer_hetgp
  from uncertainty_baselines.models.vit_mimo import vision_transformer_mimo
  from uncertainty_baselines.models.vit_heteroscedastic import vision_transformer_het
  from uncertainty_baselines.models.vit_tram import vision_transformer_tram
except ImportError:
  # A dependency of the ViT modules is unavailable; skip them instead of
  # failing the whole package import.
  logging.warning('Skipped ViT models due to ImportError.', exc_info=True)
except tf.errors.NotFoundError:
  # NOTE(review): presumably raised when a TensorFlow resource these modules
  # load at import time cannot be found — confirm.
  logging.warning('Skipped ViT models due to NotFoundError.', exc_info=True)

try:
  # Try to import Segmenter models. If any import below fails, the error is
  # logged with its traceback and every name from the failing line onward stays
  # undefined on this package.
  from uncertainty_baselines.models.segmenter import SegVit
  from uncertainty_baselines.models.segmenter_be import SegVitBE
  from uncertainty_baselines.models.segmenter_gp import SegVitGP
  from uncertainty_baselines.models.segmenter_heteroscedastic import SegVitHet

except ImportError:
  # A dependency of the Segmenter modules is unavailable; skip them instead of
  # failing the whole package import.
  logging.warning('Skipped Segmenter models due to ImportError.', exc_info=True)
except tf.errors.NotFoundError:
  # NOTE(review): presumably raised when a TensorFlow resource these modules
  # load at import time cannot be found — confirm.
  logging.warning('Skipped Segmenter models due to NotFoundError.',
                  exc_info=True)

try:
  # Try to import models depending on tensorflow_models.official.nlp.
  # If any import below fails, the error is logged with its traceback and every
  # name from the failing line onward stays undefined on this package.
  from uncertainty_baselines.models import bert
  from uncertainty_baselines.models.bert import bert_model
  from uncertainty_baselines.models import bert_dropout
  from uncertainty_baselines.models.bert_dropout import bert_dropout_model
  from uncertainty_baselines.models import bert_sngp
  from uncertainty_baselines.models.bert_sngp import bert_sngp_model
except ImportError:
  # A dependency of the BERT modules is unavailable; skip them instead of
  # failing the whole package import.
  logging.warning('Skipped BERT models due to ImportError.', exc_info=True)
except tf.errors.NotFoundError:
  # NOTE(review): presumably raised when a TensorFlow resource these modules
  # load at import time cannot be found — confirm.
  logging.warning('Skipped BERT models due to NotFoundError.', exc_info=True)

try:
  # Try to import models depending on edward2.experimental.mimo.
  # If either import fails, the error is logged with its traceback and the
  # names from the failing line onward stay undefined on this package.
  from uncertainty_baselines.models.resnet50_mimo import resnet50_mimo
  from uncertainty_baselines.models.wide_resnet_mimo import wide_resnet_mimo
except ImportError:
  # A dependency of the MIMO modules is unavailable; skip them instead of
  # failing the whole package import.
  logging.warning('Skipped MIMO models due to ImportError.', exc_info=True)
except tf.errors.NotFoundError:
  # NOTE(review): presumably raised when a TensorFlow resource these modules
  # load at import time cannot be found — confirm.
  logging.warning('Skipped MIMO models due to NotFoundError.', exc_info=True)

# pylint: disable=line-too-long
try:
  # Try to import the PyTorch ResNet-50 models. Only ImportError is handled
  # here (presumably raised when torch is not installed). If either import
  # fails, the names from the failing line onward stay undefined on this
  # package.
  from uncertainty_baselines.models.resnet50_dropout_torch import resnet50_dropout_torch
  from uncertainty_baselines.models.resnet50_torch import resnet50_torch
except ImportError:
  # The message names both models imported above, since a failure here skips
  # either or both of them.
  logging.warning(
      'Skipped Torch ResNet-50 and ResNet-50 Dropout models due to '
      'ImportError.',
      exc_info=True)
# pylint: enable=line-too-long
# pylint: enable=g-import-not-at-top
