import os
import numpy as np
import torch
from scipy import interpolate

from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

#from timm.models import create_model
from .protonet import ProtoNet
from .deploy import ProtoNet_AdaTok, ProtoNet_AdaTok_EntMin
from .niwmeta import NIWMeta

def _load_dino_vit(variant, checkpoint):
    """Build a DINO ViT and load its pretrained checkpoint.

    Args:
        variant: Name of the constructor looked up in the local
            ``vision_transformer`` module (e.g. ``'vit_base'``).
        checkpoint: Checkpoint path relative to the DINO download root.

    Returns:
        The model with the downloaded state dict loaded (``strict=True``).
    """
    from . import vision_transformer as vit

    model = vit.__dict__[variant](patch_size=16, num_classes=0)
    state_dict = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/dino/" + checkpoint)

    model.load_state_dict(state_dict, strict=True)
    print('Pretrained weights found at {}'.format(checkpoint))
    return model


def get_backbone(args):
    """Create the pretrained backbone selected by ``args.arch``.

    Args:
        args: Object exposing an ``arch`` attribute. For the BEiT
            architecture the whole object is forwarded to
            ``default_pretrained_model``.

    Returns:
        The instantiated backbone model.

    Raises:
        ValueError: If ``args.arch`` is not one of the supported names.
    """
    if args.arch == 'vit_base_patch16_224_in21k':
        from .vit_google import VisionTransformer, CONFIGS

        config = CONFIGS['ViT-B_16']
        model = VisionTransformer(config, 224)

        url = 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz'
        pretrained_weights = 'pretrained_ckpts/vit_base_patch16_224_in21k.npz'
        # Only hit the network when the checkpoint is not already on disk;
        # otherwise every call would trigger a fresh download.
        if not os.path.isfile(pretrained_weights):
            try:
                import wget
                os.makedirs('pretrained_ckpts', exist_ok=True)
                wget.download(url, pretrained_weights)
            except Exception as err:
                # Best effort: report the reason and let np.load below fail
                # loudly if the file really is missing.
                print(f'Cannot download pretrained weights from {url}: {err}')

        model.load_from(np.load(pretrained_weights))
        print('Pretrained weights found at {}'.format(pretrained_weights))

    elif args.arch == 'dino_base_patch16':
        model = _load_dino_vit(
            'vit_base', "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth")

    elif args.arch == 'dino_small_patch16':
        model = _load_dino_vit(
            'vit_small', "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth")

    elif args.arch == 'beit_base_patch16_224_pt22k':
        from .beit import default_pretrained_model
        model = default_pretrained_model(args)
        print('Pretrained BEiT loaded')

    elif args.arch == 'clip_base_patch16_224':
        from . import clip
        # clip.load returns a pair; only the first element (the model) is used.
        model, _ = clip.load('ViT-B/16', 'cpu')

    elif args.arch in ('dino_xcit_medium_24_p16', 'dino_xcit_medium_24_p8'):
        # The torch.hub entry point has the same name as the arch string.
        model = torch.hub.load('facebookresearch/dino:main', args.arch)

    else:
        raise ValueError(f'{args.arch} is not conisdered in the current code.')

    return model

def get_model(args):
    """Build the few-shot model selected by ``args.deploy`` on top of a backbone.

    Args:
        args: Object exposing ``deploy`` plus whatever ``get_backbone`` reads.
            Depending on the mode it must also expose:
            * ``'niwmeta'``: ``mixup``, ``smoothing``, ``use_gami``,
              ``n0_init``, ``gam0_init``, ``steps``, ``burnin``, ``alp``,
              ``ai_max``.
            * ``'ada_tokens'`` / ``'ada_tokens_entmin'``: ``num_adapters``,
              ``ada_steps``, ``ada_lr``.

    Returns:
        The model wrapping the backbone for the requested deploy mode.

    Raises:
        ValueError: If ``args.deploy`` is not a supported mode.
    """
    deploy_modes = ('vanilla', 'niwmeta', 'ada_tokens', 'ada_tokens_entmin')
    # Validate before building the backbone so a typo in the mode fails
    # immediately instead of after loading pretrained weights.
    if args.deploy not in deploy_modes:
        raise ValueError(
            f'Unknown deploy mode {args.deploy!r}; expected one of {deploy_modes}.')

    backbone = get_backbone(args)

    if args.deploy == 'vanilla':
        return ProtoNet(backbone)

    if args.deploy == 'niwmeta':
        if args.mixup > 0.:  # smoothing is handled with mixup label transform
            lossfun = SoftTargetCrossEntropy()
        elif args.smoothing:
            lossfun = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
        else:
            lossfun = torch.nn.CrossEntropyLoss()
        return NIWMeta(backbone, use_gami=args.use_gami, n0_init=args.n0_init,
                       gam0_init=args.gam0_init, steps=args.steps,
                       burnin=args.burnin, alp=args.alp, ai_max=args.ai_max,
                       lossfun=lossfun)

    if args.deploy == 'ada_tokens':
        return ProtoNet_AdaTok(backbone, args.num_adapters,
                               args.ada_steps, args.ada_lr)

    # Only 'ada_tokens_entmin' remains after the validation above.
    return ProtoNet_AdaTok_EntMin(backbone, args.num_adapters,
                                  args.ada_steps, args.ada_lr)