# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.

# This program is free software; you can redistribute it and/or modify it under
# the terms of the MIT license.

# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the MIT License for more details.


import sys
import os
import numpy as np
import torch
from torch import nn
import random

def seed_everything(seed=1029):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator (default 1029).

    Side effects:
        * Sets the ``PYTHONHASHSEED`` environment variable. This only
          influences Python interpreters started after this call; the
          current process's hash randomization is already fixed.
        * Seeds ``random``, ``numpy.random``, the torch CPU generator and the
          torch generator of the current CUDA device.
        * Requests deterministic cuDNN kernels.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

def set_device(gpu=-1):
    """Choose the torch device to run on.

    Args:
        gpu: CUDA device index. Any negative value (default ``-1``) selects
            the CPU.

    Returns:
        ``torch.device("cuda:<gpu>")`` when ``gpu`` is non-negative and CUDA
        is available, otherwise ``torch.device("cpu")``.
    """
    # CUDA availability is only queried when a GPU index was actually requested.
    wants_cuda = gpu >= 0 and torch.cuda.is_available()
    target = "cuda:" + str(gpu) if wants_cuda else "cpu"
    return torch.device(target)

# Lower-case alias -> attribute name looked up on ``torch.optim``.
# NOTE(review): several targets (Adan, AdaBelief, TAdam, SophiaG, CAdam,
# SparseCAdam, Yogi, CYogi, CRAdam) are not part of stock ``torch.optim``;
# presumably they are registered onto it elsewhere in the project -- confirm,
# otherwise requesting them raises AttributeError.
_OPTIMIZER_ALIASES = {
    "adam": "Adam",
    "adan": "Adan",
    "adabelief": "AdaBelief",
    "tadam": "TAdam",
    "sophia": "SophiaG",
    "sgd": "SGD",
    "adagrad": "Adagrad",
    "adadelta": "Adadelta",
    "rmsprop": "RMSprop",
    "cadam": "CAdam",
    "sparse_cadam": "SparseCAdam",
    "sparse_adam": "SparseAdam",
    "yogi": "Yogi",
    "cyogi": "CYogi",
    "radam": "RAdam",
    "cradam": "CRAdam",
}


def set_optimizer(optimizer):
    """Resolve an optimizer specification to an optimizer class.

    Args:
        optimizer: Either a string or an already-resolved optimizer class.
            Strings matching a key of ``_OPTIMIZER_ALIASES`` (compared
            case-insensitively) are mapped to the aliased name; any other
            string is used verbatim as an attribute name on ``torch.optim``
            (e.g. ``"AdamW"``). Non-string values are returned unchanged.

    Returns:
        The optimizer class (not an instance).

    Raises:
        AttributeError: if the resolved name is not an attribute of
            ``torch.optim``.
    """
    if not isinstance(optimizer, str):
        # Pass through already-resolved optimizers, matching set_loss /
        # set_activation (previously this fell through and returned None).
        return optimizer
    name = _OPTIMIZER_ALIASES.get(optimizer.lower(), optimizer)
    return getattr(torch.optim, name)

def set_loss(loss):
    """Normalize a loss specification.

    Args:
        loss: A loss name or any non-string object (e.g. a callable), which
            is returned untouched.

    Returns:
        ``"binary_cross_entropy"`` for any of the accepted BCE spellings
        (``"bce"``, ``"binary_crossentropy"``, ``"binary_cross_entropy"``),
        otherwise the original non-string ``loss``.

    Raises:
        NotImplementedError: for any other string.
    """
    if not isinstance(loss, str):
        return loss
    bce_aliases = ("bce", "binary_crossentropy", "binary_cross_entropy")
    if loss not in bce_aliases:
        raise NotImplementedError("loss={} is not supported.".format(loss))
    return "binary_cross_entropy"

def set_regularizer(reg):
    """Parse a regularization spec into ``(p_norm, weight)`` pairs.

    Accepted forms:
        * a ``float`` ``w``             -> ``[(2, w)]`` (L2 penalty)
        * ``"l1(w)"`` / ``"l2(w)"``      -> ``[(1, w)]`` / ``[(2, w)]``
        * ``"l1_l2(w1, w2)"``            -> ``[(1, w1), (2, w2)]``

    Any other non-string value (e.g. ``None``) yields an empty list.
    NOTE(review): this includes ints such as ``1``, which are silently
    ignored -- confirm that is intended.

    Returns:
        A list of ``(p_norm, weight)`` tuples.

    Raises:
        NotImplementedError: if ``reg`` is a string that does not match one
            of the forms above or whose weights are not valid floats.
    """
    reg_pair = []  # list of (p_norm, weight) tuples
    if isinstance(reg, float):
        reg_pair.append((2, reg))
    elif isinstance(reg, str):
        error_msg = "regularizer={} is not supported.".format(reg)
        if not reg.startswith(("l1(", "l2(", "l1_l2")):
            raise NotImplementedError(error_msg)
        try:
            if reg.startswith(("l1(", "l2(")):
                # reg[1] is the norm digit; the weight is the text inside "(...)".
                reg_pair.append((int(reg[1]), float(reg.rstrip(")").split("(")[-1])))
            else:  # "l1_l2(w1, w2)"
                l1_reg, l2_reg = reg.rstrip(")").split("(")[-1].split(",")
                reg_pair.append((1, float(l1_reg)))
                reg_pair.append((2, float(l2_reg)))
        except ValueError as err:
            # Malformed weight, or wrong number of comma-separated weights.
            raise NotImplementedError(error_msg) from err
    return reg_pair

def set_activation(activation):
    """Turn an activation specification into an ``nn.Module`` instance.

    Args:
        activation: Either a module/callable, which is returned as-is, or a
            name. ``"relu"``, ``"sigmoid"`` and ``"tanh"`` are matched
            case-insensitively; any other name is looked up verbatim
            (case-sensitive) as an attribute of ``torch.nn`` and
            instantiated with no arguments.

    Returns:
        The activation module (or the original non-string object).

    Raises:
        AttributeError: if a non-builtin name is not found on ``torch.nn``.
    """
    if not isinstance(activation, str):
        return activation
    shortcuts = {"relu": nn.ReLU, "sigmoid": nn.Sigmoid, "tanh": nn.Tanh}
    factory = shortcuts.get(activation.lower())
    if factory is None:
        factory = getattr(nn, activation)
    return factory()




