# Inherited base configs (mmcv `Config` `_base_` mechanism): model, dataset,
# optimisation schedule and runtime defaults. Every dict assigned below in
# this file is merged recursively on top of what these files define.
# Must remain a list: mmcv treats a non-list value as one single base file.
_base_ = [
    '../_base_/models/resnet50.py',
    '../_base_/datasets/imagenetlt_val_bs32_pil_resize.py',
    '../_base_/schedules/imagenetlt_bs256_coslr.py',
    '../_base_/default_runtime.py',
]
# Partial override of the inherited ResNet-50 classifier: only the keys below
# change; mmcv merges this dict recursively into the base `model`, so all
# other backbone/neck/head settings come from '../_base_/models/resnet50.py'.
model = dict(
    # In mmcls's ResNet, frozen_stages=4 freezes the stem plus all four
    # residual stages, i.e. the entire backbone is kept fixed.
    backbone=dict(frozen_stages=4),
    head=dict(
        # Project-specific head class (implementation not visible here);
        # presumably a tau-normalised linear classifier -- TODO confirm.
        type='TauNormLinearClsHead',
        num_classes=1000,  # size of the ImageNet-LT label space
        topk=(1, 5),  # accuracy is reported at top-1 and top-5
        # Project-specific loss; presumably consumes the soft targets from
        # `data.train.soft_file` -- verify against the loss implementation.
        loss=dict(type='SoftCrossEntropyLoss'),
    ),
)
# Data-loading overrides, merged recursively into the base ImageNet-LT
# dataset config (the base dataset type and pipelines are inherited).
data = dict(
    # Batch size per GPU. Presumably 128 classes x 4 images per class
    # (see `num_samples_per_category` in the sampler below) -- confirm.
    samples_per_gpu=512,
    # Extra key added to the inherited training dataset: a pickle written by
    # an earlier run (the path names a phase-2 KL-div experiment). Presumably
    # it holds per-sample soft targets -- TODO confirm with the dataset code.
    train=dict(
        soft_file='work_dirs/LT_uni90_resnet50_xkx_imagenetlt_1gpu_b512_lws_cas_phase2_kldiv/latest.pkl',
    ),
    workers_per_gpu=8,
    # Project-specific sampler settings; they are consumed by code outside
    # this file.
    sampler=dict(
        train=dict(
            type='class_aware',
            num_samples_per_category=4,
            # Presumably tells the sampler the dataset carries soft labels.
            soft_file=True,
        ),
    ),
)
# Save a checkpoint every 10 intervals and run accuracy evaluation every
# interval. Units depend on the runner from the base schedule -- presumably
# epochs (mmcv's by_epoch default); confirm against the schedule config.
checkpoint_config = dict(interval=10)
evaluation = dict(interval=1, metric='accuracy')

# Emit logs every 10 iterations (mmcv LoggerHook interval) to both the plain
# text logger and Weights & Biases.
log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook'),
        # init_kwargs are passed to wandb.init(): W&B project and run name.
        dict(
            type='WandbLoggerHook',
            init_kwargs=dict(project='OTA_DEV', name='xent_phase3'),
        ),
    ],
)

# Initialise model weights from this checkpoint. Unlike `resume_from`, mmcv's
# `load_from` does not restore optimizer state or the epoch counter. The
# filename suggests a 90-epoch uniformly-sampled ResNet-50 -- confirm.
load_from = 'work_dirs/ImageNetLT_resnet50_uniform_e90.pth'