from .hparams import SI_Linear_HParams, SVAG_HParams
from attr import evolve

__all__ = ['SI_Linear_HParams_Dict']


# Mini-batch SGD baseline: D=100 `si_linear_net` trained on `linear_loss`
# with batch size 100 over 10k training samples.
# `schedule`, `lr_decay_factors` and `wd_decay_factors` are parallel lists
# (one factor per milestone step). Presumably the lr is scaled by 0.1 at each
# milestone while weight decay is held constant (factor 1.0). Confirm against
# the trainer.
sgd = SI_Linear_HParams(
    arch='si_linear_net',
    D=100,
    loss='linear_loss',
    lr=1,
    weight_decay=0.01,
    momentum=0,
    train_batch_size=100,
    test_batch_size=100,
    train_size=10000,
    test_size=10000,
    task_type='regression',
    schedule=[10000, 20000, 50000],
    lr_decay_factors=[0.1, 0.1, 0.1],
    wd_decay_factors=[1.0, 1.0, 1.0],
    max_steps=60000,
    check_val_every_n_epoch=100,
    random_flip=False,
    check_rank=False,
    gauss=False,
)

# Public registry mapping a config name to its hyper-parameter object; the
# statements below add further entries by key.
SI_Linear_HParams_Dict = dict(sgd=sgd)

# D=100 config with L2 loss where train_batch_size == train_size == 98,
# so presumably every step sees the whole training set (full-batch).
# `noise=0.1` is interpreted by the trainer; confirm its meaning there.
fgd = SI_Linear_HParams(
    arch='si_linear_net',
    D=100,
    loss='l2_loss',
    lr=0.01,
    weight_decay=0.01,
    momentum=0,
    train_batch_size=98,
    test_batch_size=1000,
    train_size=98,
    test_size=1000,
    task_type='regression',
    max_steps=60000,
    check_val_every_n_epoch=100,
    random_flip=False,
    check_rank=False,
    gauss=False,
    cache=1,
    noise=0.1,
)
SI_Linear_HParams_Dict['fgd'] = fgd

# Derived variants, each evolved from the entry registered just before it.
# Key suffixes read <weight_decay>_<lr>[_<max_steps>].
SI_Linear_HParams_Dict['fgd_0.1_0.02'] = evolve(
    fgd,
    lr=0.02,
    weight_decay=0.1,
)
SI_Linear_HParams_Dict['fgd_0.1_0.02_6k'] = evolve(
    SI_Linear_HParams_Dict['fgd_0.1_0.02'],
    max_steps=6000,
)
SI_Linear_HParams_Dict['fgd_0.1_0.01_12k'] = evolve(
    SI_Linear_HParams_Dict['fgd_0.1_0.02_6k'],
    max_steps=12000,
    lr=0.01,
)


# NOTE(review): `fgd` already sets noise=0.1, and none of the evolves above
# change it. This override is therefore a no-op, and the entry duplicates
# 'fgd_0.1_0.02_6k'. Confirm whether a lower noise level was intended.
SI_Linear_HParams_Dict['lownoise_0.1_0.02_6k'] = evolve(
    SI_Linear_HParams_Dict['fgd_0.1_0.02_6k'],
    noise=0.1,
)

# Small D=10 config with L2 loss. train_batch_size == train_size == 8,
# so presumably one optimizer step per epoch; confirm with the trainer.
fgd10 = SI_Linear_HParams(
    arch='si_linear_net',
    D=10,
    loss='l2_loss',
    lr=0.01,
    weight_decay=0.05,
    momentum=0,
    train_batch_size=8,
    test_batch_size=1000,
    train_size=8,
    test_size=1000,
    task_type='regression',
    max_steps=60000,
    check_val_every_n_epoch=1000,
    random_flip=False,
    check_rank=False,
    gauss=False,
    cache=1,
    noise=0.01,
)
SI_Linear_HParams_Dict['fgd10'] = fgd10

# Derived variants. Key fragments: `ns<x>` = noise, a bare number after
# 'fgd10_' = lr, trailing `<n>k`/`<n>M` = max_steps.
SI_Linear_HParams_Dict['fgd10_15k'] = evolve(
    fgd10,
    max_steps=15000,
)
SI_Linear_HParams_Dict['fgd10_ns0.1_15k'] = evolve(
    SI_Linear_HParams_Dict['fgd10_15k'],
    noise=0.1,
)
SI_Linear_HParams_Dict['fgd10_ns0.1_20k'] = evolve(
    SI_Linear_HParams_Dict['fgd10_ns0.1_15k'],
    max_steps=20000,
)
# From here on the validation interval drops to 250 epochs; the noise=0.3
# entry below inherits it.
SI_Linear_HParams_Dict['fgd10_ns1_20k'] = evolve(
    SI_Linear_HParams_Dict['fgd10_ns0.1_20k'],
    check_val_every_n_epoch=250,
    noise=1,
)
SI_Linear_HParams_Dict['fgd10_ns0.3_20k'] = evolve(
    SI_Linear_HParams_Dict['fgd10_ns1_20k'],
    noise=0.3,
)

# Smaller-lr runs, all branched from 'fgd10_ns0.3_20k' with a longer horizon.
SI_Linear_HParams_Dict['fgd10_0.005_ns0.3_40k'] = evolve(
    SI_Linear_HParams_Dict['fgd10_ns0.3_20k'],
    check_val_every_n_epoch=500,
    lr=0.005,
    max_steps=40000,
)
SI_Linear_HParams_Dict['fgd10_0.001_ns0.3_250k'] = evolve(
    SI_Linear_HParams_Dict['fgd10_ns0.3_20k'],
    check_val_every_n_epoch=1250,
    lr=0.001,
    max_steps=250000,
)
# NOTE(review): this entry keeps the 1250-epoch validation interval of the
# 250k run although max_steps is 12x larger (2400 vs 200 validation passes).
# Confirm this is intended.
SI_Linear_HParams_Dict['fgd10_0.0001_ns0.3_3M'] = evolve(
    SI_Linear_HParams_Dict['fgd10_ns0.3_20k'],
    check_val_every_n_epoch=1250,
    lr=0.0001,
    max_steps=3000000,
)