from torch import nn

def get_ml_loss_fn(args, ptoSolver, conf):
    """Return the (uninstantiated) class for a plain ML loss named by ``args.opt_model``.

    Supported names are "mse", "msesum", "ce", "bce" and "mae"; each class is
    imported from ``method.Models.MSE`` only when it is selected.

    Args:
        args: namespace providing ``opt_model``.
        ptoSolver: unused here (the signature mirrors the other factory
            functions in this module).
        conf: unused here (same reason as ``ptoSolver``).

    Returns:
        The selected model class itself, not an instance.

    Raises:
        LookupError: if ``args.opt_model`` is not one of the supported names.
    """
    name = args.opt_model
    if name == "mse":
        from method.Models.MSE import MSE as model_cls
    elif name == "msesum":
        from method.Models.MSE import MSE_Sum as model_cls
    elif name == "ce":
        from method.Models.MSE import CE as model_cls
    elif name == "bce":
        from method.Models.MSE import BCE as model_cls
    elif name == "mae":
        from method.Models.MSE import MAE as model_cls
    else:
        # Previously fell through and failed with UnboundLocalError.
        raise LookupError(f"Unknown ML loss model: {name!r}")
    return model_cls


def get_smooth_loss_fn(args, ptoSolver, conf):
    """Instantiate the smoothing-term model selected by ``args.smooth_term``.

    Supported names are "mse", "huber", "logcosh", "smoothmae" and "none".
    "none" selects the ``NONE`` class from ``method.Models.MSE``, i.e. no
    smoothing term.

    Args:
        args: namespace providing ``smooth_term``, ``opt_model``, ``log_dir``
            and ``loss_path``.
        ptoSolver: passed as the first positional argument to the model class.
        conf: configuration mapping; ``conf["models"][args.opt_model]`` is
            expanded into keyword arguments for the model class.

    Returns:
        An instance of the selected smoothing-term class.

    Raises:
        LookupError: if ``args.smooth_term`` is not a supported name.
    """
    name = args.smooth_term
    known_terms = ("mse", "huber", "logcosh", "smoothmae", "none")
    # Validate before importing so a bad name fails with a clear message
    # (previously fell through and failed with UnboundLocalError).
    if name not in known_terms:
        raise LookupError(
            f"Unknown smooth_term: {name!r}; expected one of {known_terms}"
        )

    from method.Models.MSE import MSE, Huber, LogCosh, SmoothMAE, NONE

    model_classes = {
        "mse": MSE,
        "huber": Huber,
        "logcosh": LogCosh,
        "smoothmae": SmoothMAE,
        # Only "none" disables smoothing; callers that need that flag can
        # test ``args.smooth_term != "none"``.
        "none": NONE,
    }
    model_cls = model_classes[name]

    # NOTE(review): the smooth term is configured from the *opt_model*
    # section of conf, not a section keyed by smooth_term — confirm this is
    # intended.
    loss_dict = {
        **conf["models"][args.opt_model],
        "log_dir": args.log_dir,
        "loss_path": args.loss_path,
    }

    return model_cls(ptoSolver, **loss_dict)


def get_loss_fn(args, ptoSolver, conf):
    """Instantiate the loss/model object selected by ``args.opt_model``.

    Plain ML losses ("mse", "msesum", "ce", "bce", "mae") are resolved by
    ``get_ml_loss_fn``. Every other supported name is looked up in a table
    of ``(module path, class name)`` pairs. Only the selected module is
    imported.

    Args:
        args: namespace providing ``opt_model``, ``log_dir`` and ``loss_path``.
        ptoSolver: passed as the first positional argument to the model class.
        conf: configuration mapping; ``conf["models"][args.opt_model]`` is
            expanded into keyword arguments for the model class.

    Returns:
        An instance of the selected model class.

    Raises:
        LookupError: if ``args.opt_model`` is not a known model name.
        NotImplementedError: for "intopt", which is recognised but has no
            class wired up.
    """
    import importlib

    # name -> (module path, class name) for every non-ML model.
    model_registry = {
        "dfl": ("method.Models.MSE", "DFL"),
        "spo": ("method.Models.SPO", "SPO"),
        "pointLTR": ("method.Models.LTR", "pointwiseLTR"),
        "pairLTR": ("method.Models.LTR", "pairwiseLTR"),
        "listLTR": ("method.Models.LTR", "listwiseLTR"),
        "qptl": ("method.Models.QPTL", "QPTL"),
        "nce": ("method.Models.NCE", "NCE"),
        "blackboxSolver": ("method.Models.Blackbox", "blackboxSolver"),
        "blackbox": ("method.Models.Blackbox", "subopt_blackbox"),
        "identitySolver": ("method.Models.Identity", "IdentitySolver"),
        "identity": ("method.Models.Identity", "subopt_Identity"),
        "lodl": ("method.Models.LODLs", "LODL"),
        "perturb": ("method.Models.perturbed", "perturbed"),
        "cpLayer": ("method.Models.cpLayer", "cpLayer"),
    }

    name = args.opt_model
    if name in ("mse", "msesum", "ce", "bce", "mae"):
        model_cls = get_ml_loss_fn(args, ptoSolver, conf)
    elif name == "intopt":
        # The Intopt import (openpto.method.Models.Intopt) is disabled, so
        # there is no class to call; fail clearly instead of calling None.
        raise NotImplementedError(
            "opt_model 'intopt' is not available: its implementation import is disabled"
        )
    elif name in model_registry:
        module_path, class_name = model_registry[name]
        model_cls = getattr(importlib.import_module(module_path), class_name)
    else:
        raise LookupError(f"Unknown opt_model: {name!r}")

    loss_dict = {
        **conf["models"][args.opt_model],
        "log_dir": args.log_dir,
        "loss_path": args.loss_path,
    }

    return model_cls(ptoSolver, **loss_dict)
