import torch
import torch.nn as nn
from torch.nn import init
import timm
from transformers import CLIPModel
import clip
from transformers import AutoImageProcessor, AutoModel

try:
    from .combined import Noise_Tracker, Dino_Tracker, Dino_Baseline, Dino_Cosine_Tracker, Dino_Pure_Tracker, Dino_Single_Tracker, Dino_Gram_Tracker, Dino_Lora, Dino_Lora_Seg, Dino_Lora_Mid, Dino_Lora_Seg_CNN, Dino_Clip, Clip_Lora_Mid, Clip_Lora, Clip_Lora_QKV, Clip_Lora_Seg_QKV, Clip_Lora_Seg
except:
    from combined import Noise_Tracker, Dino_Tracker, Dino_Baseline, Dino_Cosine_Tracker, Dino_Pure_Tracker, Dino_Single_Tracker, Dino_Gram_Tracker, Dino_Lora, Dino_Lora_Seg, Dino_Lora_Mid, Dino_Lora_Seg_CNN, Dino_Clip, Clip_Lora_Mid, Clip_Lora, Clip_Lora_QKV, Clip_Lora_Seg_QKV, Clip_Lora_Seg


# fc layer weight init
def weights_init_kaiming(m):
    """Initialise the parameters of a single module ``m`` in place.

    The module kind is decided by substring match on the class name, checked
    in this order:

    * name contains ``Conv``        -> weight: Kaiming normal (``fan_in``);
      the bias is left untouched.
    * name contains ``Linear``      -> weight: Kaiming normal (``fan_out``);
      bias: zeros.
    * name contains ``BatchNorm1d`` -> weight: N(1.0, 0.02); bias: zeros.

    Any other module is left unchanged. A module that has no weight or no
    bias tensor (e.g. ``bias=False`` or ``affine=False``) simply has that
    part skipped instead of raising ``AttributeError``.

    :param m: the module to initialise.
    :return: None
    """
    classname = m.__class__.__name__
    # Either tensor may be absent (attribute missing or set to None).
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            init.kaiming_normal_(weight.data, a=0, mode='fan_in')
    elif 'Linear' in classname:
        if weight is not None:
            init.kaiming_normal_(weight.data, a=0, mode='fan_out')
        if bias is not None:
            init.constant_(bias.data, 0.0)
    elif 'BatchNorm1d' in classname:
        if weight is not None:
            init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            init.constant_(bias.data, 0.0)


def init_imagenet_weight(_conv_stem_weight, input_channel=3):
    """Repeat a stem-conv weight along dim 1 and wrap it as a Parameter.

    The weight is concatenated with itself ``input_channel // 3`` times along
    dimension 1, so a remainder of ``input_channel`` modulo 3 is dropped.

    :param _conv_stem_weight: weight tensor to replicate.
    :param input_channel: requested number of input channels; must be >= 3.
    :return: ``torch.nn.Parameter`` holding the replicated weight.
    :raises ValueError: if ``input_channel`` is smaller than 3 (there would be
        nothing to replicate).
    """
    copies = input_channel // 3
    if copies < 1:
        raise ValueError(f'input_channel must be at least 3, got {input_channel}')

    if copies == 1:
        # A single copy needs no concatenation; wrap the tensor as it is.
        tiled = _conv_stem_weight
    else:
        tiled = torch.cat([_conv_stem_weight] * copies, dim=1)

    return torch.nn.Parameter(tiled)


class C2P_CLIP(nn.Module):
    """Frozen CLIP image branch followed by a trainable linear head.

    The text-side modules of the loaded ``CLIPModel`` are deleted. The vision
    model and the visual projection have ``requires_grad`` switched off; the
    linear head is stored as ``self.model.fc``.

    :param name: checkpoint identifier passed to ``CLIPModel.from_pretrained``.
    :param num_classes: output width of the linear head.
    """

    def __init__(self, name='openai/clip-vit-large-patch14', num_classes=1):
        super().__init__()
        self.model = CLIPModel.from_pretrained(name)
        # Only the image branch is used below, so drop the text-side modules.
        del self.model.text_model
        del self.model.text_projection
        del self.model.logit_scale

        self.model.vision_model.requires_grad_(False)
        self.model.visual_projection.requires_grad_(False)
        # The head consumes the output of visual_projection, so its input
        # width is the checkpoint's projection size rather than a fixed 768
        # (the two coincide for the default checkpoint).
        self.model.fc = nn.Linear(self.model.config.projection_dim, num_classes)
        torch.nn.init.normal_(self.model.fc.weight.data, 0.0, 0.02)

    def encode_image(self, img):
        """Return the projected image features for ``img``.

        :param img: passed to the vision model as ``pixel_values``.
        :return: ``visual_projection`` applied to the pooled vision output.
        """
        vision_outputs = self.model.vision_model(
            pixel_values=img,
            output_attentions=self.model.config.output_attentions,
            output_hidden_states=self.model.config.output_hidden_states,
            return_dict=self.model.config.use_return_dict,
        )
        pooled_output = vision_outputs[1]  # pooled_output
        return self.model.visual_projection(pooled_output)

    def forward(self, img):
        """Return head outputs computed on L2-normalised image features."""
        image_embeds = self.encode_image(img)
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        return self.model.fc(image_embeds)

class CLIPVisual(nn.Module):
    """CLIP vision model with a linear classification head on top.

    :param model_name: accepted for interface compatibility.
        NOTE(review): it is not used; the checkpoint path below is fixed.
    :param num_classes: output width of the linear head.
    :param freeze_extractor: when true, gradients are disabled for every
        parameter of the vision model.
    """

    def __init__(self, model_name, num_classes=2, freeze_extractor=True):
        super().__init__()
        clip_model = CLIPModel.from_pretrained('./pretrained/openai/clip-vit-large-patch14')
        print(f'Successfully loaded CLIP!')
        self.visual_model = clip_model.vision_model
        if freeze_extractor:
            self.freeze(self.visual_model)
        self.fc = nn.Linear(
            in_features=clip_model.vision_embed_dim,
            out_features=num_classes,
            bias=True,
        )

    def forward(self, x):
        """Classify ``x`` from element 1 of the vision model's output."""
        vision_out = self.visual_model(x)
        return self.fc(vision_out[1])

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for weight in model.parameters():
            weight.requires_grad_(False)


class CLIPModelV2(nn.Module):
    """Image encoder loaded through the ``clip`` package plus a linear head.

    :param name: model name such as ``clip-RN50`` or ``clip-ViT-L-14``; the
        ``clip-`` prefix is stripped and ``L-`` / ``B-`` become ``L/`` / ``B/``
        before the name is handed to ``clip.load``.
    :param num_classes: output width of the linear head.
    :param freeze_extractor: when true, gradients are disabled for every
        parameter of the loaded model.
    """

    # Feature width used as the head's input size, keyed by normalised name.
    CHANNELS = {
        "RN50": 1024,
        "ViT-B/32": 512,
        "ViT-L/14": 768,
    }

    def __init__(self, name='clip-RN50', num_classes=2, freeze_extractor=False):
        super().__init__()
        for old, new in (('clip-', ''), ('L-', 'L/'), ('B-', 'B/')):
            name = name.replace(old, new)
        # self.preprocess is not used during training; that is handled in the
        # Dataset class.
        self.model, self.preprocess = clip.load(name, device="cpu")
        print(f'Successfully loaded CLIP!')
        if freeze_extractor:
            self.freeze(self.model)
            print(f'Freezing the feature extractors!')

        self.fc = nn.Linear(self.CHANNELS[name], num_classes)

    def forward(self, x, return_feature=False):
        """Return the head output, or the raw image features on request."""
        encoded = self.model.encode_image(x)
        return encoded if return_feature else self.fc(encoded)

    @staticmethod
    def freeze(model):
        """Disable gradients for every parameter of ``model``."""
        for weight in model.parameters():
            weight.requires_grad_(False)


class ContrastiveModels(nn.Module):
    """Backbone producing an embedding, followed by a linear classifier.

    The backbone is built by ``get_models`` with ``num_classes`` set to
    ``embedding_size``, so its output is the embedding that ``self.fc`` maps
    to ``num_classes`` outputs.

    :param model_name: backbone name forwarded to ``get_models``.
    :param num_classes: output width of the classifier.
    :param pretrained: forwarded to ``get_models``.
    :param embedding_size: width of the backbone output / classifier input.
    :param freeze_extractor: forwarded to ``get_models``.
    """

    def __init__(self, model_name, num_classes=2, pretrained=True, embedding_size=1024,
                 freeze_extractor=False):
        super().__init__()
        self.model_name = model_name
        self.embedding_size = embedding_size
        self.model = get_models(
            model_name=model_name,
            pretrained=pretrained,
            num_classes=embedding_size,
            freeze_extractor=freeze_extractor,
        )
        self.fc = nn.Linear(embedding_size, num_classes)

    def forward(self, x, return_feature=False):
        """Return the prediction, or ``(prediction, embedding)`` on request."""
        embedding = self.model(x)
        logits = self.fc(embedding)
        return (logits, embedding) if return_feature else logits

    def extract_feature(self, x):
        """Return only the backbone embedding for ``x``."""
        return self.model(x)


def get_efficientnet_ns(model_name='tf_efficientnet_b3_ns', pretrained=True, num_classes=2, start_down=True):
    """
    Build a timm EfficientNet and replace its classifier with a new linear layer.

     # Coefficients:   width,depth,res,dropout
        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
        'efficientnet-b1': (1.0, 1.1, 240, 0.2),
        'efficientnet-b2': (1.1, 1.2, 260, 0.3),
        'efficientnet-b3': (1.2, 1.4, 300, 0.3),
        'efficientnet-b4': (1.4, 1.8, 380, 0.4),
        'efficientnet-b5': (1.6, 2.2, 456, 0.4),
        'efficientnet-b6': (1.8, 2.6, 528, 0.5),
        'efficientnet-b7': (2.0, 3.1, 600, 0.5),
        'efficientnet-b8': (2.2, 3.6, 672, 0.5),
        'efficientnet-l2': (4.3, 5.3, 800, 0.5),
    :param model_name: name passed to ``timm.create_model``.
    :param pretrained: passed to ``timm.create_model``.
    :param num_classes: output width of the replacement classifier.
    :param start_down: when false, the stem convolution's stride is set to (1, 1).
    :return: the modified network.
    """
    backbone = timm.create_model(model_name, pretrained=pretrained)
    if not start_down:
        backbone.conv_stem.stride = (1, 1)
    # Swap the classifier for one with the requested number of outputs.
    backbone.classifier = nn.Linear(backbone.classifier.in_features, num_classes)

    return backbone


def get_swin_transformers(model_name='swin_base_patch4_window7_224', pretrained=True, num_classes=2):
    """
    Build a timm Swin Transformer and replace its head with a new linear layer.

    :param model_name: e.g. swin_base_patch4_window12_384, swin_base_patch4_window7_224,
        swin_base_patch4_window7_224_in22k
    :param pretrained: passed to ``timm.create_model``.
    :param num_classes: output width of the replacement head.
    :return: the modified network.
    """
    backbone = timm.create_model(model_name, pretrained=pretrained)
    # The new head reuses the input width of the head it replaces.
    backbone.head = nn.Linear(backbone.head.in_features, num_classes)

    return backbone


def get_convnext(model_name='convnext_base_in22k', pretrained=True, num_classes=2, in_channel=3):
    """
    Build a timm ConvNeXt, replace ``head.fc`` and optionally widen the stem.

    :param model_name: e.g. convnext_base_384_in22ft1k, convnext_base_in22k
    :param pretrained: passed to ``timm.create_model``.
    :param num_classes: output width of the replacement ``head.fc``.
    :param in_channel: number of input channels; any value other than 3
        replaces the first stem convolution.
    :return: the modified network.
    """
    model = timm.create_model(model_name, pretrained=pretrained)
    model.head.fc = nn.Linear(model.head.fc.in_features, num_classes)

    if in_channel == 3:
        return model

    # Replace the first stem conv with one taking `in_channel` inputs; its
    # weight is derived from the existing stem weight.
    # NOTE(review): kernel size and stride are fixed at 4, and the new conv
    # keeps its freshly created bias - confirm both are intended.
    old_stem = model.stem[0]
    new_stem = nn.Conv2d(in_channel, old_stem.out_channels, kernel_size=4, stride=4)
    new_stem.weight = init_imagenet_weight(old_stem.weight, input_channel=in_channel)
    model.stem[0] = new_stem

    return model


def get_combined(model_name='resnet50+time', pretrained=True, num_classes=2):
    """
    Build a ``Noise_Tracker`` from a ResNet backbone and a ``TimeTransformer``.

    :param model_name: ``"<backbone>+<head>"``; the backbone part is passed to
        ``timm.create_model`` and the head part must contain ``time``.
    :param pretrained: accepted for interface compatibility.
        NOTE(review): it is not used; weights always come from the fixed local
        resnet50 checkpoint below.
    :param num_classes: number of classes passed to the ``TimeTransformer``.
    :return: the assembled ``Noise_Tracker``.
    :raises NotImplementedError: if the head part is not supported.
    """
    aname, bname = model_name.split('+')
    # Reject unsupported heads before the checkpoint is read, rather than
    # failing afterwards on an unassigned result variable.
    if 'time' not in bname:
        raise NotImplementedError(model_name)

    local_weight_path = './pretrained/resnet/resnet50-19c8e357.pth'
    resnet = timm.create_model(aname, pretrained=False)
    resnet.load_state_dict(torch.load(local_weight_path, map_location='cpu'))
    print("Successfully loaded pretrained resnet")

    default_params = dict(
        dim=1000, depth=3, heads=16, mlp_dim=512, dropout=0.1, emb_dropout=0.1,
    )
    transformer = TimeTransformer(num_patches=11, num_classes=num_classes, **default_params)

    return Noise_Tracker(resnet, transformer)

def _build_time_transformer(dim, num_patches, num_classes):
    """Create the TimeTransformer configuration shared by all tracker heads."""
    default_params = dict(
        dim=dim, depth=3, heads=16, mlp_dim=512, dropout=0.1, emb_dropout=0.1,
    )
    return TimeTransformer(num_patches=num_patches, num_classes=num_classes, **default_params)


def get_dino(model_name='dino+time', pretrained=True, num_classes=2):
    """
    Build a tracker / baseline / LoRA model on a DINOv2 or CLIP vision backbone.

    :param model_name: ``"<backbone>+<head>"``. If the backbone part contains
        ``not`` the CLIP vision model is loaded, otherwise the DINOv2 model;
        both come from fixed local paths. The head is the first keyword, in
        the order listed in the code below, that occurs in the head part.
    :param pretrained: accepted for interface compatibility.
        NOTE(review): it is not used; weights are always loaded.
    :param num_classes: number of classes passed to the selected head.
    :return: the assembled model.
    :raises NotImplementedError: if the head part matches no known keyword.
    """
    aname, bname = model_name.split('+')
    use_clip = 'not' in aname

    # Keywords are tested in order and the first substring hit wins, so the
    # longer names must stay ahead of their prefixes (e.g. 'loraseg' < 'lora').
    if use_clip:
        keywords = ('time', 'cos', 'pure', 'one', 'gram', 'base',
                    'loramid', 'lorasegqkv', 'loraseg', 'loraqkv', 'lora')
    else:
        keywords = ('time', 'cos', 'pure', 'one', 'gram', 'base', 'clip',
                    'loramid', 'lorasegcnn', 'loraseg', 'lora')
    variant = next((key for key in keywords if key in bname), None)
    if variant is None:
        # Fail before loading any weights instead of crashing later on an
        # unassigned result variable.
        raise NotImplementedError(model_name)

    if use_clip:
        model = CLIPModel.from_pretrained('/src/pretrained/openai/clip-vit-large-patch14').to('cpu')
        backbone = model.vision_model
        print("Successfully loaded pretrained clip")
        dim = model.vision_embed_dim
    else:
        backbone = AutoModel.from_pretrained('./pretrained/dinov2-large',
                                             local_files_only=True).to('cpu')
        print("Successfully loaded pretrained dino")
        dim = 1024

    # Tracker heads: a TimeTransformer is built first, then wrapped together
    # with the backbone.
    tracker_patches = {'time': 9, 'cos': 9, 'pure': 10, 'one': 9}
    if variant in tracker_patches:
        tracker_cls = {
            'time': Dino_Tracker,
            'cos': Dino_Cosine_Tracker,
            'pure': Dino_Pure_Tracker,
            'one': Dino_Single_Tracker,
        }[variant]
        transformer = _build_time_transformer(dim, tracker_patches[variant], num_classes)
        return tracker_cls(backbone, transformer)

    if variant == 'gram':
        return Dino_Gram_Tracker(backbone)

    if variant == 'base':
        if use_clip:
            # NOTE(review): the CLIP baseline is built without num_classes,
            # unlike the DINO one - confirm this is intended.
            return Dino_Baseline(backbone)
        return Dino_Baseline(backbone, num_classes=num_classes)

    # Remaining heads all take (backbone, num_classes=...).
    if use_clip:
        head_cls = {
            'loramid': Clip_Lora_Mid,
            'lorasegqkv': Clip_Lora_Seg_QKV,
            'loraseg': Clip_Lora_Seg,
            'loraqkv': Clip_Lora_QKV,
            'lora': Clip_Lora,
        }[variant]
    else:
        head_cls = {
            'clip': Dino_Clip,
            'loramid': Dino_Lora_Mid,
            'lorasegcnn': Dino_Lora_Seg_CNN,
            'loraseg': Dino_Lora_Seg,
            'lora': Dino_Lora,
        }[variant]
    return head_cls(backbone, num_classes=num_classes)


def get_resnet(model_name='resnet200d', pretrained=True, num_classes=2):
    """
    Build a timm ResNet, load local weights and replace its ``fc`` layer.

    :param model_name: name passed to ``timm.create_model``.
        NOTE(review): the checkpoint path below is a fixed resnet50 file
        regardless of this name - confirm against callers.
    :param pretrained: accepted for interface compatibility.
        NOTE(review): it is not used; the local checkpoint is always loaded.
    :param num_classes: output width of the replacement ``fc`` layer.
    :return: the modified network.
    """
    backbone = timm.create_model(model_name, pretrained=False)
    state = torch.load('./pretrained/resnet/resnet50-19c8e357.pth', map_location='cpu')
    backbone.load_state_dict(state)
    print("Successfully loaded pretrained resnet")
    # The new layer reuses the input width of the layer it replaces.
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)

    return backbone

def get_clip_visual_model(model_name="openai/clip-vit-base-patch32", num_classes=2, pretrained=True,
                          freeze_extractor=False):
    """
    Pick the CLIP wrapper that matches ``model_name``.

    :param model_name: names containing ``openai/clip`` select ``CLIPVisual``;
        anything else (e.g. clip-RN50, clip-ViT-L/14) selects ``CLIPModelV2``.
    :param num_classes: output width of the classification head.
    :param pretrained: accepted for interface compatibility; not used here.
    :param freeze_extractor: forwarded to ``CLIPModelV2`` only.
        NOTE(review): ``CLIPVisual`` does not receive it and so runs with its
        own default - confirm this is intended.
    :return: the constructed model.
    """
    if 'openai/clip' not in model_name:
        # 'clip-' + 'name', clip-RN50, clip-ViT-L/14
        return CLIPModelV2(name=model_name, num_classes=num_classes, freeze_extractor=freeze_extractor)

    return CLIPVisual(model_name=model_name, num_classes=num_classes)


def get_models(model_name='tf_efficientnet_b3_ns', pretrained=True, num_classes=2,
               in_channel=3, freeze_extractor=False, embedding_size=None):
    """
    Build a model by name; the single entry point for all factories here.

    A positive integer ``embedding_size`` takes precedence over everything
    else and wraps the backbone in ``ContrastiveModels``. Otherwise the first
    matching substring of ``model_name`` decides, in the order of the branches
    below - so e.g. a name containing both ``dino`` and ``gram`` is handled by
    ``get_dino``.

    :param model_name: model identifier.
    :param pretrained: forwarded to the selected factory.
    :param num_classes: number of outputs of the model.
    :param in_channel: number of input channels; only used by ``get_convnext``.
    :param freeze_extractor: forwarded to the CLIP and contrastive factories.
    :param embedding_size: when a positive int, build a ``ContrastiveModels``
        with this embedding width.
    :return: the constructed model.
    :raises NotImplementedError: if ``model_name`` matches no known family.
    """
    if isinstance(embedding_size, int) and embedding_size > 0:
        model = ContrastiveModels(model_name, num_classes, pretrained, embedding_size, freeze_extractor)
    elif 'dino' in model_name:
        model = get_dino(model_name, pretrained, num_classes)
    elif 'efficientnet' in model_name:
        model = get_efficientnet_ns(model_name, pretrained, num_classes)
    elif 'convnext' in model_name:
        model = get_convnext(model_name, pretrained, num_classes, in_channel=in_channel)
    elif 'swin' in model_name:
        model = get_swin_transformers(model_name, pretrained, num_classes)
    elif 'clip' in model_name:
        model = get_clip_visual_model(model_name, num_classes, freeze_extractor=freeze_extractor)
    elif 'gram' in model_name:  # gram_resnet18
        # NOTE(review): get_GramNet is not defined or imported in the visible
        # part of this file - confirm where it comes from.
        model = get_GramNet(model_name.replace('gram_', ''))
    elif 'resnet' in model_name:
        model = get_resnet(model_name, pretrained, num_classes)
    elif model_name == 'f3net':
        # NOTE(review): F3Net is not defined or imported in the visible part
        # of this file - confirm where it comes from.
        model = F3Net(num_classes=num_classes, img_width=299, img_height=299, pretrained=pretrained)
    else:
        raise NotImplementedError(model_name)

    return model


if __name__ == '__main__':
    # Smoke test: build one model, run a random image through it on the CPU,
    # then print the output shape(s) and the average time per forward pass.
    import time

    image_size = 224
    model = get_models(
        model_name='clip-ViT-L-14',  # clip-ViT-L-14
        num_classes=2,
        pretrained=False,
        embedding_size=512,
    )
    print(model)
    # print(model.default_cfg)
    model = model.to(torch.device('cpu'))
    img = torch.randn(1, 3, image_size, image_size)  # your high resolution picture
    start = time.time()
    times = 1
    for _ in range(times):
        out = model(img)
        shapes = [o.shape for o in out] if isinstance(out, tuple) else out.shape
        print(shapes)
    print((time.time() - start) / times)

    # Optional layer summary:
    # from torchsummary import summary
    # input_s = (3, image_size, image_size)
    # print(summary(model, input_s, device='cpu'))
