# Base configs this file inherits from: ResNet-50 model, ImageNet-LT
# dataset pipeline, ImageNet-LT cosine-LR schedule, and default runtime.
# Fields assigned below are layered on top of these bases.
_base_ = [
    '../_base_/models/resnet50.py',
    '../_base_/datasets/imagenetlt_val_bs32_pil_resize.py',
    '../_base_/schedules/imagenetlt_bs256_coslr.py',
    '../_base_/default_runtime.py',
]
# Model overrides: freeze backbone stages and swap in a
# TauNormLinearClsHead trained with SoftKLDivLoss.
model = dict(
    # frozen_stages=4 presumably freezes the stem plus all four ResNet
    # stages, leaving only the head trainable -- confirm against backbone.
    backbone=dict(frozen_stages=4),
    head=dict(
        type='TauNormLinearClsHead',
        num_classes=1000,
        topk=(1, 5),  # top-k values used for accuracy reporting
        # NOTE(review): 0.03 scaling of the KL term is an unexplained
        # tuning constant -- document its origin.
        loss=dict(type='SoftKLDivLoss', loss_weight=0.03),
    ),
)
# Data loading overrides.
data = dict(
    # 512 = 128 * 4; presumably 128 categories per batch times
    # num_samples_per_category=4 from the class-aware sampler -- confirm.
    samples_per_gpu=512,
    workers_per_gpu=8,
    # Pickle written by the phase-1 run; presumably holds the soft targets
    # consumed by SoftKLDivLoss -- verify against the dataset class.
    train=dict(
        soft_file='work_dirs/LT_uni90_resnet50_xkx_imagenetlt_1gpu_b512_lws_cas_phase1_xent/latest.pkl',
    ),
    sampler=dict(
        train=dict(
            type='class_aware',
            num_samples_per_category=4,
            # NOTE(review): meaning of this flag is defined by the sampler
            # implementation, not visible here -- confirm.
            soft_file=True,
        ),
    ),
)
# Checkpoint and evaluate (top-k accuracy) every epoch.
checkpoint_config = dict(interval=1)
evaluation = dict(metric='accuracy', interval=1)
# Logging: plain text logger plus a Weights & Biases logger
# (project 'OTA_DEV', run 'kldiv_phase2'). init_kwargs is presumably
# forwarded to wandb.init -- confirm against the hook.
log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='WandbLoggerHook',
             init_kwargs=dict(project='OTA_DEV', name='kldiv_phase2')),
    ],
)
# NOTE(review): max_epochs=1 overrides any epoch count from the inherited
# cosine-LR schedule -- confirm a single epoch is intended.
runner = dict(
    type='EpochBasedRunner',
    max_epochs=1,
)
# Initial weights; relative path, presumably resolved against the
# working directory at launch.
load_from = 'work_dirs/ImageNetLT_resnet50_uniform_e90.pth'