# model
# Number of output classes; shared by every network and head configured here.
num_classes = 65

# Network config named for the source side (ResNet-50 sized classifier).
src_model = {'type': 'ResNet', 'depth': 50, 'num_classes': num_classes}
# A source head is not configured; the disabled definition is kept for reference:
# src_head = dict(type='BottleNeckMLP', feature_dim=2048, bottleneck_dim=256, num_classes=num_classes,
#                 type1='bn', type2='wn')

# Network config named for the target side; unlike src_model it also sets
# pretrained=True.
tgt_model = {
    'type': 'ResNet',
    'depth': 50,
    'num_classes': num_classes,
    'pretrained': True,
}
# Head paired with tgt_model: 2048-d features -> 256-d bottleneck -> classes.
# NOTE(review): 'bn' / 'wn' are interpreted by BottleNeckMLP (presumably
# batch-norm / weight-norm) -- confirm against that class.
tgt_head = {
    'type': 'BottleNeckMLP',
    'feature_dim': 2048,
    'bottleneck_dim': 256,
    'num_classes': num_classes,
    'type1': 'bn',
    'type2': 'wn',
}

# DivideMix hyper-parameters
# lam_u > 0 selects SemiLoss for training in the `loss` config of this file;
# at 0 the training loss is SmoothCE.
lam_u = 0
# NOTE(review): presumably the weight of a companion loss term -- confirm
# against the trainer.
lam_p = 1.0
# NOTE(review): presumably the temperature used for label sharpening -- confirm.
T_sharpen = 0.5
# NOTE(review): presumably the mixing distribution parameter -- confirm.
alpha = 1.0
# NOTE(review): presumably a probability threshold -- confirm.
tau_p = 0.8

# DINE hyper-parameters
# These values are only defined here; how they are consumed is decided by the
# training code that loads this config.
# NOTE(review): presumably the momentum of an exponential moving average -- confirm.
ema = 0.6
# NOTE(review): presumably the weight of a mix-based loss term -- confirm.
lam_mix = 1.0
# NOTE(review): presumably how many top predictions are retained -- confirm.
topk = 1

# Loss configuration.
# The training loss type depends on lam_u: SemiLoss when lam_u > 0, otherwise
# SmoothCE. The test loss is always CrossEntropyLoss.
loss = {
    'train': {'type': 'SemiLoss' if lam_u > 0 else 'SmoothCE'},
    'test': {'type': 'CrossEntropyLoss'},
}

# data
# Domain codes -- A: Art, C: Clipart, P: Product, R: Real_World
src = 'A'
tgt = 'P'
# Sample list of the target domain; every split below reads this same file.
sample_file = f'./data/office_home/{tgt}_list.txt'
# NOTE(review): presumably consumed by the data loaders -- confirm.
batch_size = 64
num_workers = 4
# Per-channel normalisation statistics handed to every transform below
# (the values match the commonly used ImageNet RGB statistics).
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# One entry per split. Each entry pairs a dataset config (`ds_dict`) with a
# transform config (`trans_dict`). The splits differ only in the dataset
# `mode` and the transform `views` code.
# NOTE(review): the meaning of the 'w' / 't' / 'ws' view codes is defined by
# OHMultiView -- confirm there.
data = {
    'warmup': {
        'ds_dict': {
            'type': 'SubOfficeHome',
            'sample_file': sample_file,
            'mode': 'warmup',
        },
        'trans_dict': {
            'type': 'OHMultiView',
            'views': 'w',
            'mean': mean,
            'std': std,
        },
    },
    'eval_train': {
        'ds_dict': {
            'type': 'SubOfficeHome',
            'sample_file': sample_file,
            'mode': 'eval_train',
        },
        'trans_dict': {
            'type': 'OHMultiView',
            'views': 't',
            'mean': mean,
            'std': std,
        },
    },
    'label': {
        'ds_dict': {
            'type': 'SubOfficeHome',
            'sample_file': sample_file,
            'mode': 'label',
        },
        'trans_dict': {
            'type': 'OHMultiView',
            'views': 'ws',
            'mean': mean,
            'std': std,
        },
    },
    'unlabel': {
        'ds_dict': {
            'type': 'SubOfficeHome',
            'sample_file': sample_file,
            'mode': 'unlabel',
        },
        'trans_dict': {
            'type': 'OHMultiView',
            'views': 'ws',
            'mean': mean,
            'std': std,
        },
    },
    'test': {
        'ds_dict': {
            'type': 'SubOfficeHome',
            'sample_file': sample_file,
            'mode': 'test',
        },
        'trans_dict': {
            'type': 'OHMultiView',
            'views': 't',
            'mean': mean,
            'std': std,
        },
    },
}

# training optimizer & scheduler
# Schedule lengths, all expressed in epochs.
epochs = 50
warmup_epochs = 0
rampup_epochs = 5

# Base learning rate; also forwarded to the optimizer config below.
lr = 0.01
# SGD with Nesterov momentum and weight decay.
optimizer = {
    'type': 'SGD',
    'lr': lr,
    'momentum': 0.9,
    'weight_decay': 1e-3,
    'nesterov': True,
}

# log & save
# NOTE(review): the units of the intervals below (iterations vs. epochs) are
# decided by the training code -- confirm there.
log_interval = 50
test_interval = 1
# Hard-coded to the value of epochs // 10 (epochs = 50 in this file); update
# it by hand if `epochs` changes.
pred_interval = 5  # epochs // 10
work_dir = None  # rewritten by args
# No checkpoint to resume from by default.
resume = None
# Checkpoint path derived from the source domain code `src`.
# NOTE(review): presumably the source-trained weights used for
# initialisation -- confirm against the training script.
load = f'./checkpoints/office_home/src_{src}/train_src_{src}/best_val.pth'
# NOTE(review): presumably the port used for distributed initialisation -- confirm.
port = 10001
