import torch

from .factory import create_model_and_transforms
from .get_tokenizer_zh import get_tokenizer_chinese
from .pretrained import get_pretrained_cfg

# Public API of this module: only the high-level loader is exported by `import *`.
__all__ = ["load_YouCLIP"]

# Registry of supported YouCLIP variants, keyed by the public model name
# accepted by load_YouCLIP. Each entry holds:
#   model_name      - architecture name handed to the model factory
#   pretrained_tag  - tag used to look up the pretrained preprocessing config
#   tokenizer_size  - size argument for the Chinese tokenizer
#   context_length  - context length argument for the Chinese tokenizer
# Every entry is its own dict object, so mutating one never affects another.
CFG = {
    # Base variants
    'YouCLIP-Base': dict(
        model_name='ViT-B-16-SigLIP-CN', pretrained_tag='webli',
        tokenizer_size='base', context_length=77,
    ),
    'YouCLIP-Base-CN-ENG': dict(
        model_name='ViT-B-16-SigLIP-CN', pretrained_tag='webli',
        tokenizer_size='base', context_length=77,
    ),
    'YouCLIP-Base-512': dict(
        model_name='ViT-B-16-SigLIP-512-CN', pretrained_tag='webli',
        tokenizer_size='base', context_length=77,
    ),
    'YouCLIP-Base-512-CN-ENG': dict(
        model_name='ViT-B-16-SigLIP-512-CN', pretrained_tag='webli',
        tokenizer_size='base', context_length=77,
    ),

    # Large variants
    'YouCLIP-Large': dict(
        model_name='ViT-L-16-SigLIP-256-CN', pretrained_tag='webli',
        tokenizer_size='large', context_length=64,
    ),
    'YouCLIP-Large-CN-ENG': dict(
        model_name='ViT-L-16-SigLIP-256-CN', pretrained_tag='webli',
        tokenizer_size='large', context_length=64,
    ),

    # Huge variants
    'YouCLIP-Huge': dict(
        model_name='ViT-SO400M-14-SigLIP-384-CN', pretrained_tag='webli',
        tokenizer_size='huge', context_length=64,
    ),
    'YouCLIP-Huge-CN-ENG': dict(
        model_name='ViT-SO400M-14-SigLIP-384-CN', pretrained_tag='webli',
        tokenizer_size='huge', context_length=64,
    ),
}


def get_whole_model(open_clip_model_name, open_clip_pretrain_tag, model_state_dict_path):
    """
    Create a model together with its image preprocessing transform.

    The preprocessing settings (mean, std, interpolation, resize mode) are read
    from the pretrained config looked up by model name and pretrain tag, and
    forwarded to the model factory.

    Args:
        open_clip_model_name: architecture name passed to the model factory
        open_clip_pretrain_tag: tag used to look up the pretrained config
        model_state_dict_path: forwarded as the factory's ``pretrained`` argument
            (callers in this file pass None)

    Returns:
        a tuple (model, preprocess): the first and third values returned by
        ``create_model_and_transforms``; the second one is discarded
    """
    pretrained_cfg = get_pretrained_cfg(open_clip_model_name, open_clip_pretrain_tag)
    print('pretrained_cfg:{}'.format(pretrained_cfg))

    # Collect the image preprocessing options from the pretrained config.
    image_kwargs = {
        'image_mean': pretrained_cfg['mean'],
        'image_std': pretrained_cfg['std'],
        'image_interpolation': pretrained_cfg['interpolation'],
        'image_resize_mode': pretrained_cfg['resize_mode'],
    }

    model, _, preprocess = create_model_and_transforms(
        open_clip_model_name,
        pretrained=model_state_dict_path,
        **image_kwargs,
    )
    return model, preprocess


def load_CN_model(CN_model_name, open_clip_pretrain_tag, path_model_state_dict):
    """
    Build a model and load its weights from a local checkpoint file.

    The architecture is created without pretrained weights, then the checkpoint
    is loaded onto the CPU and applied with ``model.load_state_dict``.

    The checkpoint may be either a plain state dict or a dict that wraps the
    state dict under the ``'state_dict'`` key. In the wrapped case, a leading
    ``'module.'`` prefix on the parameter names (the kind of prefix typically
    left by DataParallel-style wrappers) is removed before loading.

    Args:
        CN_model_name: architecture name passed to ``get_whole_model``
        open_clip_pretrain_tag: tag used to look up the preprocessing config
        path_model_state_dict: path of the checkpoint file to load

    Returns:
        a tuple (model, preprocess)
    """
    # Weights come from the checkpoint below, so no pretrained source is given here.
    model, preprocess = get_whole_model(open_clip_model_name=CN_model_name,
                                        open_clip_pretrain_tag=open_clip_pretrain_tag,
                                        model_state_dict_path=None
                                        )
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from trusted sources (consider passing
    # weights_only=True where the installed torch version supports it).
    model_dict = torch.load(path_model_state_dict, map_location=torch.device('cpu'))
    if 'state_dict' in model_dict:
        model_dict = model_dict['state_dict']
        prefix = 'module.'
        # Test for the exact prefix that gets stripped (including the dot), guard
        # against an empty state dict, and only trim keys that really carry the
        # prefix so other keys are never truncated.
        if model_dict and next(iter(model_dict)).startswith(prefix):
            model_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in model_dict.items()
            }
    model.load_state_dict(model_dict)
    return model, preprocess


def load_YouCLIP(model_name, model_file_path):
    """
    Load a pretrained YouCLIP model together with its preprocess and tokenizer.

    Supported names are the keys of ``CFG``: 'YouCLIP-Base', 'YouCLIP-Base-CN-ENG',
    'YouCLIP-Base-512', 'YouCLIP-Base-512-CN-ENG', 'YouCLIP-Large',
    'YouCLIP-Large-CN-ENG', 'YouCLIP-Huge', 'YouCLIP-Huge-CN-ENG'.

    Args:
        model_name: one of the supported model names listed above
        model_file_path: pretrained model file path

    Returns:
        a tuple (model, preprocess, tokenizer_zh). model: CLIP model, already
        switched to eval mode, preprocess: preprocess for image input,
        tokenizer_zh: a tokenizer for text input

    Raises:
        AssertionError: if ``model_name`` is not a key of ``CFG``
    """
    if model_name not in CFG:
        # Raised explicitly (instead of via an ``assert`` statement) so the check
        # is not stripped under ``python -O``; the exception type is kept as
        # AssertionError for backward compatibility with existing callers.
        raise AssertionError(
            'Unknown model name: {}, only support:{}'.format(model_name, list(CFG.keys())))
    cur_cfg = CFG[model_name]
    print('load from {}'.format(model_file_path))
    print('model config: {}'.format(cur_cfg))
    # The second returned value (named vocab_size in the original code) is not needed here.
    tokenizer_zh, _ = get_tokenizer_chinese(context_length=cur_cfg['context_length'],
                                            size=cur_cfg['tokenizer_size'])
    model, preprocess = load_CN_model(CN_model_name=cur_cfg['model_name'],
                                      open_clip_pretrain_tag=cur_cfg['pretrained_tag'],
                                      path_model_state_dict=model_file_path)
    # Put the model in inference mode before handing it to the caller.
    model.eval()
    return model, preprocess, tokenizer_zh