from .hparams import CIFAR10_HParams
from attr import evolve

import math
# Star-imports expose only the preset registry; individual presets remain
# reachable as ordinary module attributes.
__all__ = ["CIFAR10_HParams_Dict_Testing"]

# Baseline PreResNet-32 preset: three-milestone step schedule, 1000 epochs.
preresnet32 = CIFAR10_HParams(
    # Model architecture.
    arch='preresnet',
    depth=32,
    widths=[16, 32, 64],
    widen_factor=1,
    norm_method='BN',
    bn_affine=False,
    fix_last_layer=True,
    homo=True,  # NOTE(review): meaning is defined by CIFAR10_HParams / model code
    # Optimizer settings (momentum disabled).
    lr=0.8,
    momentum=0,
    weight_decay=0.0005,
    # Step schedule: one lr / weight-decay multiplier per milestone epoch
    # (presumably applied when training reaches each entry of `schedule`).
    schedule=[80, 250, 500],
    lr_decay_factors=[0.1, 0.1, 0.1],
    wd_decay_factors=[1.0, 1.0, 1.0],
    max_epochs=1000,
    # Data loading.
    train_batch_size=128,
    test_batch_size=256,
    drop_last_batch=True,
    # Experiment-tracking destination (wb_* presumably Weights & Biases).
    wb_project='cifar10-testing',
    wb_entity='si-limit-diffusion',
)

# "Simple" PreResNet-32 preset: a single milestone over a 300-epoch run.
# NOTE(review): the only milestone (300) equals max_epochs (300), so the lr
# decay likely never takes effect — confirm this constant-lr run is intended.
preresnet32_simple = CIFAR10_HParams(
    # Model architecture.
    arch='preresnet',
    depth=32,
    widths=[16, 32, 64],
    widen_factor=1,
    norm_method='BN',
    bn_affine=False,
    fix_last_layer=True,
    homo=True,  # NOTE(review): meaning is defined by CIFAR10_HParams / model code
    # Optimizer settings (momentum disabled).
    lr=0.8,
    momentum=0,
    weight_decay=0.0005,
    # Step schedule: one lr / weight-decay multiplier per milestone epoch.
    schedule=[300],
    lr_decay_factors=[0.1],
    wd_decay_factors=[1.0],
    max_epochs=300,
    # Data loading.
    train_batch_size=128,
    test_batch_size=256,
    drop_last_batch=True,
    # Experiment-tracking destination (wb_* presumably Weights & Biases).
    wb_project='cifar10-testing',
    wb_entity='si-limit-diffusion',
)

# "Full" PreResNet-32 preset: two milestones (200, 500) over a 600-epoch run.
preresnet32_full = CIFAR10_HParams(
    # Model architecture.
    arch='preresnet',
    depth=32,
    widths=[16, 32, 64],
    widen_factor=1,
    norm_method='BN',
    bn_affine=False,
    fix_last_layer=True,
    homo=True,  # NOTE(review): meaning is defined by CIFAR10_HParams / model code
    # Optimizer settings (momentum disabled).
    lr=0.8,
    momentum=0,
    weight_decay=0.0005,
    # Step schedule: one lr / weight-decay multiplier per milestone epoch.
    schedule=[200, 500],
    lr_decay_factors=[0.1, 0.1],
    wd_decay_factors=[1.0, 1.0],
    max_epochs=600,
    # Data loading.
    train_batch_size=128,
    test_batch_size=256,
    drop_last_batch=True,
    # Experiment-tracking destination (wb_* presumably Weights & Biases).
    wb_project='cifar10-testing',
    wb_entity='si-limit-diffusion',
)

# Intermediate presets that serve as bases for further variants are also kept
# as module-level names. `evolve` returns a copy with the given fields replaced.
preresnet32_gf = evolve(
    preresnet32_simple, lr=0.0001, max_epochs=3000, schedule=[3000],
)
# hori_flip / crop are not set in the base presets, so the base values come
# from CIFAR10_HParams defaults; here both are switched off.
preresnet32_simple_no_aug = evolve(preresnet32_simple, crop=False, hori_flip=False)

# Registry of preset name -> CIFAR10_HParams instance (insertion order kept).
CIFAR10_HParams_Dict_Testing = {
    # Base presets.
    'preresnet32': preresnet32,
    'preresnet32_simple': preresnet32_simple,
    'preresnet32_full': preresnet32_full,
    # "gf" presets: much smaller lr, 3000-epoch run.
    'preresnet32_gf': preresnet32_gf,
    'preresnet32_gf_1e-3': evolve(preresnet32_gf, lr=0.001),
    # Augmentation-free variants.
    'preresnet32_simple_no_aug': preresnet32_simple_no_aug,
    'preresnet32_gf_no_aug_1e-3': evolve(
        preresnet32_simple_no_aug, lr=1e-3, max_epochs=3000, schedule=[3000],
    ),
    'preresnet32_simple_no_aug_3000': evolve(
        preresnet32_simple_no_aug, max_epochs=3000, schedule=[3000], save_top_k=-1,
    ),
    # Long-run variants.
    'preresnet32_simple_3000': evolve(
        preresnet32_simple, max_epochs=3000, schedule=[3000], save_top_k=-1,
    ),
    'preresnet32_gf_1e-3_2000': evolve(
        preresnet32_gf, lr=0.001, max_epochs=2000, schedule=[2000],
    ),
    'preresnet32_no_aug': evolve(preresnet32, crop=False, hori_flip=False),
    # L2-loss variants.
    'preresnet32_simple_3000_lr0.08_l2': evolve(
        preresnet32_simple,
        lr=0.08,
        loss='l2_loss',
        max_epochs=3000,
        schedule=[3000],
        save_top_k=-1,
    ),
    # NOTE(review): the key says "lr0.08" but lr is 0.001, and save_top_k is 1
    # whereas the sibling 3000-epoch presets use -1. Confirm which is intended
    # before relying on this preset; the key is kept as-is for existing callers.
    'preresnet32_gf_lr0.08_l2': evolve(
        preresnet32_simple,
        lr=0.001,
        loss='l2_loss',
        max_epochs=3000,
        schedule=[3000],
        save_top_k=1,
    ),
}

