import math
from .base import get_delta, ScaledWeightedLoss, ShiftedWeightedLoss
from .weighted import *
from .canonical import *
from .regularizer import *

def _base_loss_params():
    """Return a freshly built per-loss hyper-parameter table.

    This table is used verbatim for WEBVISION and CLOTHING1M, and for MNIST
    with one extra ``"NCES"`` entry appended after it. Every call constructs
    brand-new inner dicts, so no two dataset entries share a mutable object.
    """
    return {
        "CE": {},
        "FL": {"q": 0.1},
        "MAE": {},
        "AGCE": {"a": 3, "q": 1.2},
        "AUL": {"a": 7, "q": 0.5},
        "AEL": {},
        "GCE": {"q": 0.9},
        "TCE": {"q": 2},
        "SCE": {"q": 0.95},
    }


# Hyper-parameter table indexed as params[dataset_name][loss_name].
# Each leaf is a dict of keyword values (``a`` and/or ``q``); an empty dict
# supplies no keywords. NOTE(review): presumably these are forwarded as
# keyword arguments to the matching loss class -- confirm at the call site.
params = {
    "MNIST": {
        **_base_loss_params(),
        # NOTE(review): 0.1 / log(10) looks like a 1/log(num_classes) scaling
        # for a 10-class dataset -- confirm.
        "NCES": {"q": 0.1 / math.log(10)},
    },
    "CIFAR10": {
        "CE": {},
        "FL": {"q": 3},
        "MAE": {},
        "NCEAGCE": {"a": 9.4, "q": 1.2},
        "AGCE": {"a": 5.4, "q": 1.5},
        "AUL": {"a": 6.1, "q": 4.8},
        "AEL": {"q": 5},
        "GCE": {"q": 0.9},
        "TCE": {"q": 2},
        "SCE": {"q": 0.9},
        "NCES": {"q": 0.1 / math.log(10)},
    },
    "CIFAR100": {
        "CE": {},
        "FL": {"q": 3},
        "MAE": {},
        "AGCE": {"a": 0.1, "q": 0.1},
        "AUL": {"a": 2, "q": 8.7},
        "AEL": {"q": 0.2},
        "TCE": {"q": 6},
        "GCE": {"q": 0.7},
        "SCE": {"q": 0.15},
        # NOTE(review): denominator log(100) presumably tracks the class
        # count of this dataset -- confirm.
        "NCES": {"q": 0.01 / math.log(100)},
        "NCEAGCE": {"a": 0.1, "q": 0.1},
    },
    # WEBVISION and CLOTHING1M share identical values but get separate
    # dict objects from independent builder calls.
    "WEBVISION": _base_loss_params(),
    "CLOTHING1M": _base_loss_params(),
}
