import torch
import torch.nn as nn

from torchvision.models import resnet50
from model.timm.models import create_model

from model.trn.crn import TransformerCRN, TransformerCAT

import os    

from model.ge import GenreCrossAttention, Block_CA


def get_shot_encoder(cfg):
    """Build the shot encoder selected by ``cfg.MODEL.shot_encoder.name``.

    Supported names are ``"resnet"`` (depth 50 only), ``"vit"`` and
    ``"vit_x_ge"``. The ResNet encoder is returned exactly as built by
    ``get_resnet50``; the steps below apply only to the ViT variants:

    1. Optional ``ge_fusion``: the backbone is wrapped as
       ``nn.Sequential(backbone, Block_CA(dim, ge_path))``, where ``ge_path``
       is resolved relative to ``cfg.PROJ_ROOT`` and must be an existing file.
    2. Optional ``freeze``: ``requires_grad`` is turned off for every
       parameter whose name does not contain ``"cross"``.
    3. A fresh ``nn.Linear(dim, dim)`` is assigned to ``.head``; it is created
       after the freeze step and therefore stays trainable.

    Args:
        cfg: Configuration object exposing ``MODEL.shot_encoder`` (with
            ``name``, a per-name argument section, and optional
            ``ge_fusion`` / ``ge_path`` / ``freeze`` entries read via
            ``.get``) and ``PROJ_ROOT``.

    Returns:
        The constructed encoder module.

    Raises:
        AssertionError: If the encoder name is unsupported, or ``ge_fusion``
            is enabled and the resolved ``ge_path`` is not a file.
        NotImplementedError: For a ResNet depth other than 50.
    """
    name = cfg.MODEL.shot_encoder.name
    shot_encoder_args = cfg.MODEL.shot_encoder[name]

    assert name in ["resnet", "vit", "vit_x_ge"]

    if name == "resnet":
        if shot_encoder_args["depth"] != 50:
            raise NotImplementedError
        # The ResNet path returns here: no ge_fusion, freeze or head handling.
        return get_resnet50(cfg, shot_encoder_args)

    if name == "vit":
        shot_encoder = get_vit(cfg, shot_encoder_args)
    elif name == "vit_x_ge":
        shot_encoder = get_vit_x_ge(cfg, shot_encoder_args)
    else:
        # Only reachable when asserts are disabled (python -O).
        raise NotImplementedError

    # Read the embedding width from the bare backbone. Once the backbone is
    # wrapped in nn.Sequential below, the wrapper no longer exposes
    # ``embed_dim`` and a later lookup would raise AttributeError.
    dim = shot_encoder.embed_dim

    ge_fusion = cfg.MODEL.shot_encoder.get("ge_fusion", False)
    ge_path = cfg.MODEL.shot_encoder.get("ge_path", None)

    if ge_fusion:
        ge_path = os.path.join(cfg.PROJ_ROOT, ge_path)
        assert os.path.isfile(ge_path)
        shot_encoder = nn.Sequential(shot_encoder, Block_CA(dim=dim, ge_path=ge_path))

    if cfg.MODEL.shot_encoder.get("freeze", False):
        # Freeze everything except parameters whose name contains "cross".
        for param_name, param in shot_encoder.named_parameters():
            param.requires_grad = "cross" in param_name

    # Attached after the freeze step so the new head is always trainable.
    # NOTE(review): when ge_fusion wrapped the encoder in nn.Sequential, this
    # assignment registers ``head`` as an additional child of the Sequential,
    # so it appears to run last in the forward pass -- confirm this placement
    # is intended.
    shot_encoder.head = nn.Linear(dim, dim)

    return shot_encoder


def get_contextual_relation_network(cfg):
    """Build the contextual relation network (CRN) described by the config.

    Reads ``cfg.MODEL.contextual_relation_network``: ``name`` selects the
    architecture (``"trn"`` or ``"cat"``), ``params[name]`` holds its
    arguments, and the optional ``attention_mask_type`` entry defaults to
    ``"default"``.

    Returns:
        The module built by ``get_transform_crn`` / ``get_CAT``, or ``None``
        when ``enabled`` is falsy.

    Raises:
        AssertionError: If ``name`` is neither ``"trn"`` nor ``"cat"``.
    """
    crn_cfg = cfg.MODEL.contextual_relation_network

    name = crn_cfg.name
    crn_args = crn_cfg.params[name]

    assert name in ["trn", "cat"]

    attention_mask_type = crn_cfg.get("attention_mask_type", "default")

    # Guard clause: a disabled CRN yields no module at all.
    if not crn_cfg.enabled:
        return None

    if name == "trn":
        return get_transform_crn(cfg, crn_args, attention_mask_type)
    if name == "cat":
        return get_CAT(cfg, crn_args, attention_mask_type)
    return None


################################################################################################


def get_resnet50(cfg, shot_encoder_args):
    """Create a torchvision ResNet-50 with its ``fc`` layer replaced by identity.

    When ``cfg.MODEL.shot_encoder.pretrained`` is truthy,
    ``shot_encoder_args["weights"]`` must be ``"IMAGENET1K_V1"`` or
    ``"IMAGENET1K_V2"`` and is passed to ``resnet50`` as ``weights``.
    ``shot_encoder_args["params"]`` is always forwarded as extra keyword
    arguments.

    References:
        https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
        https://pytorch.org/vision/stable/models.html#initializing-pre-trained-models
    """
    weight_kwargs = {}
    if cfg.MODEL.shot_encoder.pretrained:
        weights = shot_encoder_args["weights"]
        assert weights in ("IMAGENET1K_V1", "IMAGENET1K_V2")
        weight_kwargs["weights"] = weights

    encoder = resnet50(**weight_kwargs, **shot_encoder_args["params"])

    # Drop the classification layer so the network outputs pooled features.
    encoder.fc = nn.Identity()
    return encoder


def get_vit(cfg, shot_encoder_args):
    """Create a ViT shot encoder through ``create_model``.

    ``shot_encoder_args["weights"]`` names the model variant and must be one
    of the supported identifiers listed below;
    ``cfg.MODEL.shot_encoder.pretrained`` is forwarded as ``pretrained``.
    The model's ``head`` is replaced by ``nn.Identity`` for every variant
    except ``"vit_base_patch32_clip_224.laion2b"``.

    Raises:
        AssertionError: If the requested variant is not supported.
    """
    # vit_small_patch32_224 : https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L2089
    # vit_base_patch32_clip_224 : https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L2402
    supported_weights = (
        "vit_small_patch32_224",
        "vit_base_patch32_224",
        "vit_base_patch32_clip_224",
        "vit_base_patch32_clip_224.laion2b",
    )
    weights = shot_encoder_args["weights"]
    assert weights in supported_weights

    # CLIP preprocessing reference values
    # (https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/constants.py,
    #  https://github.com/openai/CLIP/issues/20):
    #   OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    #   OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
    #   DEFAULT_CROP_PCT = 0.875 (additional)
    #   scale = 0.8, see
    #   https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/transforms_factory.py#L126

    shot_encoder = create_model(weights, pretrained=cfg.MODEL.shot_encoder.pretrained)

    keeps_original_head = weights == "vit_base_patch32_clip_224.laion2b"
    if not keeps_original_head:
        shot_encoder.head = nn.Identity()

    return shot_encoder


def get_vit_x_ge(cfg, shot_encoder_args):
    """Create a ViT shot encoder that additionally receives GE options.

    Reads from ``shot_encoder_args``:
        * ``weights`` -- model variant; must be one of the supported names.
        * ``ge_path`` -- path joined onto ``cfg.PROJ_ROOT`` (its existence is
          not checked here).
        * ``wkv`` -- ``"linear"`` (default) or ``"direct"``.
        * ``ge_type`` -- ``"cross_attn"`` (default) or ``"concat"``.

    All of these, together with ``cfg.MODEL.shot_encoder.pretrained``, are
    forwarded to ``create_model``. The model's ``head`` is replaced by
    ``nn.Identity`` for every variant except
    ``"vit_base_patch32_clip_224.laion2b"``.

    Raises:
        AssertionError: If ``weights``, ``wkv`` or ``ge_type`` is unsupported.
    """
    supported_weights = (
        "vit_base_patch32_clip_224",
        "vit_base_patch32_clip_224.laion2b",
        "vit_base_patch32_224",
        "vit_small_patch32_224",
    )
    weights = shot_encoder_args["weights"]
    assert weights in supported_weights

    ge_path = os.path.join(cfg.PROJ_ROOT, shot_encoder_args["ge_path"])

    wkv = shot_encoder_args.get("wkv", "linear")
    assert wkv in ("linear", "direct")

    pretrained = cfg.MODEL.shot_encoder.pretrained

    ge_type = shot_encoder_args.get("ge_type", "cross_attn")
    assert ge_type in ("cross_attn", "concat")

    model_kwargs = dict(pretrained=pretrained, ge_path=ge_path, wkv=wkv, ge_type=ge_type)
    shot_encoder = create_model(weights, **model_kwargs)

    if weights != "vit_base_patch32_clip_224.laion2b":
        shot_encoder.head = nn.Identity()

    return shot_encoder


def get_TranS4mer(cfg, shot_encoder_args):
    """Placeholder for a TranS4mer shot encoder; not implemented.

    Both arguments are ignored and ``None`` is returned.
    """
    return None


################################################################################################

def get_transform_crn(cfg, crn_args, attention_mask_type):
    """Instantiate a ``TransformerCRN`` sized from the sampling settings.

    The neighbor count is taken from ``cfg.LOSS.sampling_method``: for the
    ``"asymmetric"`` method it is ``neighbor_left + neighbor_right``; for any
    other method it is twice ``neighbor_size``. The value is written into
    ``crn_args["neighbor_size"]`` (``crn_args`` is modified in place) before
    the network is constructed.
    """
    sampling_cfg = cfg.LOSS.sampling_method
    method = sampling_cfg.name
    method_params = sampling_cfg.params[method]

    if method == "asymmetric":
        window = method_params["neighbor_left"] + method_params["neighbor_right"]
    else:
        window = 2 * method_params["neighbor_size"]

    crn_args["neighbor_size"] = window

    return TransformerCRN(crn_args, attention_mask_type)
    

def get_CAT(cfg, crn_args, attention_mask_type):
    """Instantiate a ``TransformerCAT`` sized from the sampling settings.

    ``crn_args["neighbor_size"]`` is overwritten in place with the total
    neighbor count derived from ``cfg.LOSS.sampling_method``: the sum of
    ``neighbor_left`` and ``neighbor_right`` for the ``"asymmetric"`` method,
    otherwise ``2 * neighbor_size``.
    """
    sampler = cfg.LOSS.sampling_method
    sampler_name = sampler.name
    is_asymmetric = sampler_name == "asymmetric"
    settings = sampler.params[sampler_name]

    crn_args["neighbor_size"] = (
        settings["neighbor_left"] + settings["neighbor_right"]
        if is_asymmetric
        else 2 * settings["neighbor_size"]
    )

    return TransformerCAT(crn_args, attention_mask_type)

# Public API of this module: only the two top-level factory functions are
# exported; the per-architecture builder functions are not listed.
__all__ = ["get_shot_encoder", "get_contextual_relation_network"]