# Inherited base configs; the top-level assignments below override their fields.
_base_ = [
    # ResNet-18 model definition (per the filename).
    '../_base_/models/resnet18.py',
    # ImageNet-Sketch data pipeline (bs32 / PIL resize / AutoAugment, per the
    # filename); samples_per_gpu is overridden further down.
    '../_base_/datasets/imagenet_sketch_bs32_pil_resize_autoaug.py',
    # Cosine LR schedule sized for a 256 batch (per the filename).
    '../_base_/schedules/imagenet_bs256_coslr.py',
    # Presumably shared runtime defaults (hooks, dist settings) -- confirm.
    '../_base_/default_runtime.py',
]
# Classification-head overrides: 1000-way output trained with a soft-label loss.
model = dict(
    head=dict(
        num_classes=1000,
        # Presumably consumes the soft targets loaded from
        # data.train.soft_file -- confirm against the loss implementation.
        loss=dict(type='SoftCrossEntropyLoss'),
        # Evaluate top-1 only.
        topk=(1, ),
    ),
)
data = dict(
    # Total batch = 256 per GPU x 1 GPU (the original "256*1" note),
    # presumably to match the bs256 schedule -- confirm the GPU count.
    samples_per_gpu=256,
    train=dict(
        # Soft labels, presumably dumped by a ResNet-50 + TENT run on
        # ImageNet-Sketch (per the filename) -- TODO confirm provenance.
        soft_file='work_dirs/r50_tent_imagenet_sketch.pkl',
    ),
)
# Fine-tuning schedule: small SGD learning rate over a short 10-epoch run.
optimizer = dict(
    type='SGD',
    lr=1e-3,
    momentum=0.9,
    weight_decay=1e-4,
)
runner = dict(
    max_epochs=10,
)
# interval == max_epochs, so (with an epoch-based runner) only the final
# epoch's checkpoint is written.
checkpoint_config = dict(
    interval=10,
)
# Log every 50 iterations, both to text and to Weights & Biases;
# init_kwargs are forwarded to wandb.init() by the mmcv hook.
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(
            type='WandbLoggerHook',
            init_kwargs=dict(project='OTA_DEV', name='xent_phase1'),
        ),
    ],
)
# Initialise from (presumably ImageNet-trained) ResNet-18 weights. This loads
# weights only; unlike resume_from, no optimizer state or epoch is restored.
load_from = 'work_dirs/resnet18_imagenet.pth'