# Run config for the '10+10' task setting; anything not overridden below is
# inherited from the '10+10.py' base config.
_base_ = ['10+10.py']

# Model override: only the pseudo-label options are set here; this dict is
# merged into the `model` dict coming from the _base_ config.
model = dict(
    pseudo_label_setting=dict(use_mg_pseudo=True, alpha=1.0),
)

# Training-loader overrides; remaining loader settings come from _base_.
train_dataloader = dict(batch_size=16, num_workers=16)

# ---- Experiment switches ------------------------------------------------------
# Presumably read as config values by the training code (mmengine-style config);
# the comments below only describe how they are used within this file.
train_stage = 2         # stages 2 and 3 resume from a work_dir checkpoint below
co_run_after = False    # also forces the resume path and turns on ewc_grad
a = 20                  # NOTE(review): not used in this file -- confirm consumer/meaning
a_EWC = False           # when resuming, pick the EWC checkpoints
a_gpr = False           # NOTE(review): not used in this file -- confirm consumer
a_be_ewc = False        # when resuming, pick the be_EWC checkpoint (a_EWC wins if both set)
task = '10+10'          # task split; keys the EWC_final/ and work_dir/ paths
task_num = 1            # task index in the EWC file names; previous task is task_num - 1

# Pickle files under EWC_final/<task>/ named by task index.
old_task_ewc_path = f"EWC_final/{task}/grad_squared_mean_prototype_task_{task_num-1}.pkl"
run_before_ewc_path = f"EWC_final/{task}/grad_squared_mean_only_share_nonew_task_{task_num}.pkl"
run_after_ewc_path = f"EWC_final/{task}/grad_squared_mean_prototype_task_{task_num}.pkl"
# Initial weights used when this run does not resume a stage-2/3 checkpoint.
load_from_weight = './temp_cheakpoints/10_10.pth'

# ---- Checkpoint selection -----------------------------------------------------
# Stages 2/3 (or co_run_after) continue from an earlier work_dir checkpoint
# chosen by the a_EWC / a_be_ewc switches; every other stage starts from
# load_from_weight and sets ewc_increase.
if train_stage in (2, 3) or co_run_after:
    ewc_increase = False
    if a_EWC:
        if train_stage == 3:
            load_from = f"work_dir/EWC_stage2/{task}/fc_cls_transfer_100/epoch_2.pth"
        else:
            load_from = f"work_dir/EWC/{task}/only_pse/epoch_4.pth"
    elif a_be_ewc:
        # Same checkpoint for stage 2 and stage 3.
        # NOTE(review): hard-coded to the 5+5 split although `task` may differ
        # (currently '10+10') -- confirm this path is intentional.
        load_from = "work_dir/be_EWC/5+5/task_1/EWC_100_lrbackbone=1/epoch_5.pth"
    elif train_stage == 3:
        load_from = f"work_dir/EWC/stage2/{task}/only_pse/epoch_4.pth"
    else:
        load_from = f"work_dir/EWC/{task}/only_pse/epoch_4.pth"
else:
    ewc_increase = True
    load_from = load_from_weight

# co_run_after always takes the resume branch above, so ewc_grad is on exactly
# when co_run_after is set; resume is off on every path.
ewc_grad = bool(co_run_after)
resume = False
# The co-run-after pass trains with the EWC-enabled epoch loop; otherwise
# train_cfg is intentionally left undefined here so the _base_ value is kept.
if co_run_after:
    train_cfg = dict(
        type='EpochBasedTrainLoop',
        max_epochs=15,
        val_interval=1,
        EWC=True,
    )
# Validation loop; first_task_cls_num is passed straight through to 'ValLoop'.
val_cfg = dict(
    type='ValLoop',
    first_task_cls_num=10,
)
# LR schedule: linear warm-up from 0.001x the base lr over the first 500
# iterations (by_epoch=False), then x0.1 step decay at epochs 3 and 8.
param_scheduler = [
    dict(type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
    dict(
        type='MultiStepLR',
        begin=0,
        end=50,
        by_epoch=True,
        milestones=[3, 8],
        gamma=0.1,
    ),
]

# SGD with momentum. The backbone keeps the base lr and weight decay
# (multipliers of 1.0); norm_decay_mult=0.0 zeroes weight decay for
# normalization-layer parameters.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.03, momentum=0.9, weight_decay=0.0001),
    paramwise_cfg=dict(
        custom_keys={'backbone': dict(lr_mult=1.0, decay_mult=1.0)},
        norm_decay_mult=0.0,
    ),
)