import torch
#from .models import *
from torchvision.models.resnet import resnet18, resnet50, wide_resnet50_2, resnet101, wide_resnet101_2
import yaml
import os

def get_optimizer(net, lr, weight_decay=0.0005, momentum=0.9, freeze_level=0):
    """Build an optimizer for ``net`` according to ``freeze_level``.

    ``net`` must expose a ``fc`` sub-module; ``lr`` must be indexable.

    * ``freeze_level == 2``: Adam over ``net.fc`` parameters only, with
      learning rate ``lr[0]``. No other parameter is handed to the optimizer.
    * ``freeze_level == 1``: Adam with two parameter groups. Every parameter
      outside ``net.fc`` uses ``lr[1]``; the ``net.fc`` parameters use the
      optimizer default ``lr[0]``.
    * any other value: Nesterov SGD over all parameters with ``lr[0]``,
      ``weight_decay`` and ``momentum``.

    ``weight_decay`` and ``momentum`` are only used by the SGD branch.
    """
    if freeze_level == 2:
        return torch.optim.Adam(net.fc.parameters(), lr[0])

    if freeze_level == 1:
        # Identify head parameters by object identity so they can be
        # excluded from the backbone group.
        head_ids = {id(p) for p in net.fc.parameters()}
        backbone_params = (p for p in net.parameters() if id(p) not in head_ids)
        param_groups = [
            {'params': backbone_params, 'lr': lr[1]},
            {'params': net.fc.parameters()},
        ]
        return torch.optim.Adam(param_groups, lr=lr[0])

    return torch.optim.SGD(
        net.parameters(),
        lr[0],
        weight_decay=weight_decay,
        momentum=momentum,
        nesterov=True,
    )


def get_backbone(backbone):
    """Return the torchvision ResNet constructor registered under ``backbone``.

    Recognised names: 'resnet18', 'resnet50', 'wide_resnet50_2',
    'resnet101' and 'wide_resnet101_2'. Any other name yields ``None``.
    """
    constructors = {
        'resnet18': resnet18,
        'resnet50': resnet50,
        'resnet101': resnet101,
        'wide_resnet50_2': wide_resnet50_2,
        'wide_resnet101_2': wide_resnet101_2,
    }
    try:
        return constructors[backbone]
    except KeyError:
        return None


def read_yaml(conf_file, backbone, dataset):
    """Load a layered configuration from a YAML file, print it and return it.

    Settings are merged in increasing priority, later layers overwriting keys
    of earlier ones:

    1. ``file_data['default']['default']`` (required base settings),
    2. ``file_data[backbone]['default']`` (optional),
    3. ``file_data[backbone][dataset]`` (optional).

    Missing or empty optional sections are skipped.

    Args:
        conf_file: Path to the YAML configuration file (read as UTF-8).
        backbone: Name of the top-level section applied over the defaults.
        dataset: Name of the sub-section inside ``backbone`` applied last.

    Returns:
        dict: The merged configuration.

    Raises:
        KeyError: If the file has no ``default: default:`` section.
    """
    # NOTE(review): FullLoader can construct some Python-specific types
    # (e.g. tuples); only load configuration files from trusted sources.
    with open(conf_file, 'r', encoding='utf-8') as f:
        file_data = yaml.load(f, Loader=yaml.FullLoader)

    # Copy the base layer so overrides never alias the parsed document; an
    # empty ``default:`` mapping parses as None and is treated as no settings.
    conf = dict(file_data['default']['default'] or {})

    backbone_conf = file_data.get(backbone)
    if backbone_conf is not None:
        # Both override layers are optional: a backbone section may define
        # only dataset-specific settings, and an empty section parses as None.
        for layer in ('default', dataset):
            overrides = backbone_conf.get(layer)
            if overrides:
                conf.update(overrides)

    print('-------------conf load-------------')
    for key, item in conf.items():
        print(key, ': ', item)
    return conf

def load_state_dict(model, load):
    """Partially load checkpoint weights into ``model`` (in place).

    A checkpoint entry is copied only if ``model`` has a parameter/buffer with
    the same key and the same shape; all other entries are skipped. This
    allows, for example, reusing a backbone whose classifier head has a
    different size.

    Args:
        model: ``torch.nn.Module`` to receive the weights.
        load: Anything accepted by ``torch.load`` (typically a file path).
            The checkpoint may be a dict with a ``'state_dict'`` entry or a
            bare state dict.

    Returns:
        list[str]: Checkpoint keys that were not loaded, in checkpoint order.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(load, map_location='cpu')
    # Accept both {'state_dict': ...} wrappers and bare state dicts.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        pretrained_dict = checkpoint['state_dict']
    else:
        pretrained_dict = checkpoint
    model_dict = model.state_dict()

    skipped = []
    for k, v in pretrained_dict.items():
        model_v = model_dict.get(k)
        if model_v is None or model_v.shape != v.shape:
            skipped.append(k)
            continue
        model_dict[k] = v

    # model_dict still contains every key the model expects (unmatched ones
    # keep their current values), so the strict load below succeeds.
    model.load_state_dict(model_dict)
    return skipped


if __name__ == '__main__':
    # No command-line behaviour: running this file directly does nothing.
    ...
