import importlib

import torch
import torch.nn as nn
import torch.distributed as dist
import yacs.config

from .model import *

def select_key(params):
    """Return the entries of ``params`` whose key contains ``'refine.'``.

    Args:
        params: Mapping from names to values (for example a state dict).

    Returns:
        A new ``dict`` with only the entries whose key contains the
        substring ``'refine.'``, in the iteration order of ``params``.
    """
    return {
        name: params[name]
        for name in params.keys()
        if 'refine.' in name
    }

def load_pretrain_model(model, model_path):
    """Load shape-compatible pretrained weights from a checkpoint into ``model``.

    The checkpoint is read onto the CPU and its ``'model'`` entry is used as
    the source state dict. For every key in ``model.state_dict()``:

    * if the key is present in the checkpoint with the same shape, the
      checkpoint tensor is used;
    * if the key is absent from the checkpoint, the model's current value
      is kept;
    * if the key is present but the shapes differ, loading is aborted.

    Checkpoint keys that the model does not have are ignored.

    Args:
        model: Module to receive the weights (modified in place).
        model_path: Path to a checkpoint whose top-level object has a
            ``'model'`` entry holding a state dict.

    Returns:
        The same ``model`` object, after ``load_state_dict``.

    Raises:
        AssertionError: If any key shared by the model and the checkpoint
            has a different shape. The message lists the offending keys.
    """
    # NOTE(security): torch.load unpickles the file, so only checkpoints from
    # trusted sources should be passed here.
    loaded_state_dict = torch.load(model_path, map_location='cpu')['model']

    # Snapshot once: every state_dict() call rebuilds the whole mapping, so
    # calling it per key inside the loop is needlessly quadratic.
    model_state = model.state_dict()

    toload = {}
    mismatched_shape_keys = []
    num_loaded = 0
    for key, current in model_state.items():
        if key not in loaded_state_dict:
            # Not covered by the checkpoint: keep the model's own value.
            toload[key] = current
            continue
        pretrained = loaded_state_dict[key]
        if current.shape != pretrained.shape:
            mismatched_shape_keys.append(key)
        else:
            toload[key] = pretrained
            num_loaded += 1

    # Raised explicitly (not via `assert`) so the check survives `python -O`;
    # the exception type is kept for callers that already catch it.
    if mismatched_shape_keys:
        raise AssertionError(
            'There are mismatched shape keys, check please! '
            f'{mismatched_shape_keys}'
        )

    model.load_state_dict(toload)
    print(f'-------------------load the pretrained model! {num_loaded} weight from: {model_path}')

    return model


def build_model(cfg, model):
    """Finalize ``model``, optionally initialising it from a checkpoint.

    When ``cfg.model.refine_fea.use_pair_weight`` is truthy, the weights at
    ``cfg.model.refine_fea.pair_weight_ckpt_path`` are loaded into ``model``
    through ``load_pretrain_model``. Otherwise ``model`` is returned as is.

    Args:
        cfg: Configuration object exposing ``model.refine_fea``.
        model: The model to finalize.

    Returns:
        The (possibly weight-initialised) model.
    """
    refine_cfg = cfg.model.refine_fea
    if not refine_cfg.use_pair_weight:
        return model
    return load_pretrain_model(model, refine_cfg.pair_weight_ckpt_path)


def apply_data_parallel_wrapper(config: yacs.config.CfgNode,
                                model: nn.Module) -> nn.Module:
    """Place ``model`` on its device, wrapping it in DDP when distributed.

    If ``torch.distributed`` is available and initialized, the model is
    (optionally) converted to use ``SyncBatchNorm``, moved to the CUDA device
    given by ``config.train.dist.local_rank`` and wrapped in
    ``DistributedDataParallel``. Otherwise it is moved to ``config.device``
    and returned unwrapped.

    Args:
        config: Configuration providing ``train.dist.local_rank``,
            ``train.dist.use_sync_bn`` and ``device``.
        model: The module to place or wrap.

    Returns:
        The ``DistributedDataParallel`` wrapper in the distributed case,
        otherwise the original module (moved in place).
    """
    local_rank = config.train.dist.local_rank
    print(f'local_rank---------------:{local_rank}')

    if not (dist.is_available() and dist.is_initialized()):
        # nn.Module.to() moves parameters in place, so `model` itself is
        # the object to return.
        model.to(config.device)
        return model

    if config.train.dist.use_sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Pin the module to the same device that DDP is told about. A bare
    # `.cuda()` uses the *current* device, which only matches `local_rank`
    # if the caller already ran torch.cuda.set_device(local_rank).
    model = nn.parallel.DistributedDataParallel(model.cuda(local_rank),
                                                device_ids=[local_rank],
                                                output_device=local_rank,
                                                find_unused_parameters=True)
    return model