# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# coding=utf-8
from alg.algs.ERM import ERM
from alg.algs.MMD import MMD
from alg.algs.CORAL import CORAL
from alg.algs.DANN import DANN
from alg.algs.RSC import RSC
from alg.algs.Mixup import Mixup
from alg.algs.MLDG import MLDG
from alg.algs.GroupDRO import GroupDRO
from alg.algs.ANDMask import ANDMask
from alg.algs.VREx import VREx
from alg.algs.DIFEX import DIFEX
from alg.algs.diversify import Diversify
from alg.algs.AdaRNN import AdaRNN
from alg.algs.IRM import IRM
from alg.algs.IIB import IIB
from alg.algs.IB_IRM import IB_IRM

# Registry of selectable algorithms: lookup name -> implementing class.
# `get_algorithm_class` resolves names against this mapping, so an algorithm
# must be listed here to be selectable. Keys are matched case-sensitively.
ALGORITHMS = dict(
    ERM=ERM,
    Mixup=Mixup,
    CORAL=CORAL,
    MMD=MMD,
    DANN=DANN,
    MLDG=MLDG,
    GroupDRO=GroupDRO,
    RSC=RSC,
    ANDMask=ANDMask,
    VREx=VREx,
    DIFEX=DIFEX,
    Diversify=Diversify,
    AdaRNN=AdaRNN,
    IRM=IRM,
    IIB=IIB,
    IB_IRM=IB_IRM,
)


def get_algorithm_class(algorithm_name):
    """Return the algorithm class registered under ``algorithm_name``.

    Args:
        algorithm_name: Key to look up in ``ALGORITHMS`` (case-sensitive,
            e.g. ``'ERM'`` or ``'Diversify'``).

    Returns:
        The class registered for that name (the class itself, not an
        instance).

    Raises:
        NotImplementedError: If no algorithm is registered under that name.
            The message lists the registered names so typos are easy to spot.
    """
    try:
        return ALGORITHMS[algorithm_name]
    except KeyError:
        available = ", ".join(sorted(ALGORITHMS))
        # NotImplementedError is kept (rather than KeyError/ValueError) so that
        # existing callers catching it keep working; `from None` hides the
        # internal KeyError, which adds nothing for the user.
        raise NotImplementedError(
            "Algorithm not found: {}. Available: {}".format(
                algorithm_name, available)) from None




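# NOTE: Earlier single-algorithm registry, kept commented out for reference
# only. It is inactive; the live `ALGORITHMS` dict and `get_algorithm_class`
# defined in this module are what callers use.
# ALGORITHMS = [
#     'diversify'
# ]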


# def get_algorithm_class(algorithm_name):
#     if algorithm_name not in ALGORITHMS:
#         raise NotImplementedError(
#             "Algorithm not found: {}".format(algorithm_name))
#     return Diversify
