# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
import torch.nn as nn
from functools import partial

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
from timm.models.layers.helpers import to_2tuple

import clip



# Names exported by ``from <module> import *``. Keep this a list: timm's
# @register_model (applied to the factory below) may append registered names to it.
__all__ = ['clip_vit_base_patch32']

class ClipHierVisionTransformer(nn.Module):
    """Hierarchical multi-head classifier on top of an OpenAI CLIP vision transformer.

    Three linear heads read the token-averaged output of the CLIP residual
    blocks at configurable depths. ``feats_layer`` projects the token-averaged,
    ``ln_post``-normalised final tokens to a feature vector. The backbone's own
    ``proj`` matrix is not used.

    Args:
        original_model: The ``visual`` module of an OpenAI CLIP model. It must
            provide ``conv1``, ``class_embedding``, ``positional_embedding``,
            ``ln_pre``, ``transformer.resblocks`` and ``ln_post``. The residual
            blocks must take ``(tokens, batch, width)`` input, as OpenAI CLIP's
            blocks do.
        nb_classes: Head output sizes ``(classes, families, manufacturers)``.
            Only the first three entries are used.
        class_layer: Index of the block whose output feeds ``head``. Negative
            values count from the last block. The defaults (11/9/7) are the
            original hard-coded indices, presumably chosen for a 12-block
            ViT-B backbone.
        family_layer: Index of the block whose output feeds ``family_head``.
        manufacturer_layer: Index of the block whose output feeds
            ``manufacturer_head``.
        feat_dim: Output width of ``feats_layer``.

    Raises:
        ValueError: If a head's block index is outside the backbone's blocks.
    """

    def __init__(self, original_model, nb_classes=(200, 38, 13), *,
                 class_layer=11, family_layer=9, manufacturer_layer=7,
                 feat_dim=512):
        super().__init__()
        self.original_model = original_model
        # Token width of the backbone (channels produced by the patch conv).
        self.embed_dim = original_model.conv1.out_channels

        self.len_classes = len(nb_classes)
        self.num_classes = nb_classes[0]
        self.num_family = nb_classes[1]
        self.num_manufacturer = nb_classes[2]

        # Validate up front. Otherwise an index past the last block leaves a
        # head output unset and forward() fails much later.
        depth = len(original_model.transformer.resblocks)
        self.class_layer = self._resolve_layer(class_layer, depth, "class_layer")
        self.family_layer = self._resolve_layer(family_layer, depth, "family_layer")
        self.manufacturer_layer = self._resolve_layer(
            manufacturer_layer, depth, "manufacturer_layer")

        self.head = nn.Linear(self.embed_dim, self.num_classes)
        self.family_head = nn.Linear(self.embed_dim, self.num_family)
        self.manufacturer_head = nn.Linear(self.embed_dim, self.num_manufacturer)
        self.feats_layer = nn.Linear(self.embed_dim, feat_dim)

    @staticmethod
    def _resolve_layer(index, depth, name):
        """Return ``index`` as a non-negative block index, or raise ValueError."""
        if not -depth <= index < depth:
            raise ValueError(
                f"{name}={index} is out of range for a backbone with {depth} blocks")
        return index % depth

    def forward(self, pixel_values):
        """Run the CLIP backbone and return the hierarchical head outputs.

        Args:
            pixel_values: Image batch passed to ``conv1``, presumably
                ``(B, 3, H, W)``. The number of patches plus one must equal
                the row count of the backbone's ``positional_embedding``,
                i.e. the input resolution must match the backbone's.

        Returns:
            Tuple ``(class_out, family_out, manufacturer_out, feats)``. Each
            element is a Linear head applied to the mean over tokens, shaped
            ``(B, head_out_features)``.
        """
        visual = self.original_model

        # Patch embedding: (B, width, gh, gw) -> (B, gh*gw, width).
        x = visual.conv1(pixel_values)
        x = x.flatten(2).transpose(1, 2)

        # Prepend the learned class token and add positional embeddings, as
        # OpenAI CLIP's VisionTransformer.forward does before ln_pre.
        cls_token = visual.class_embedding.to(x.dtype).expand(x.shape[0], 1, -1)
        x = torch.cat([cls_token, x], dim=1)  # (B, 1 + gh*gw, width)
        x = x + visual.positional_embedding.to(x.dtype)
        x = visual.ln_pre(x)

        # CLIP's blocks use nn.MultiheadAttention with batch_first=False, so
        # they expect (tokens, batch, width). Feeding (batch, tokens, width)
        # would make each image attend to the other images in the batch.
        x = x.permute(1, 0, 2)

        out = family_out = manu_out = None
        for i, block in enumerate(visual.transformer.resblocks):
            x = block(x)
            # Average over tokens (dim 0 in this layout) before each head. A
            # Linear is affine, so this equals applying it per token, then averaging.
            if i == self.manufacturer_layer:
                manu_out = self.manufacturer_head(x.mean(dim=0))
            if i == self.family_layer:
                family_out = self.family_head(x.mean(dim=0))
            if i == self.class_layer:
                out = self.head(x.mean(dim=0))

        feats = self.feats_layer(visual.ln_post(x).mean(dim=0))
        return out, family_out, manu_out, feats



@register_model
def clip_vit_base_patch32(clip_model='ViT-B/32', nb_classes=(200, 38, 13), **kwargs):
    """Build a ClipHierVisionTransformer on a pretrained OpenAI CLIP image encoder.

    Args:
        clip_model: Model name or checkpoint path passed to ``clip.load``.
        nb_classes: Head output sizes ``(classes, families, manufacturers)``.
        **kwargs: Only the keywords that ``timm.create_model`` forwards to every
            registered factory (``pretrained``, ``pretrained_cfg``,
            ``pretrained_cfg_overlay``) are accepted, and they are ignored:
            the CLIP weights are always loaded.

    Returns:
        The wrapper converted to float32 and placed on CUDA when available,
        otherwise on CPU.

    Raises:
        TypeError: For any other keyword argument.
    """
    # Accept timm's infrastructure kwargs so create_model() can call this
    # registered factory; reject anything else, as the plain signature would.
    for key in ('pretrained', 'pretrained_cfg', 'pretrained_cfg_overlay'):
        kwargs.pop(key, None)
    if kwargs:
        raise TypeError(
            f"clip_vit_base_patch32() got unexpected keyword arguments: {sorted(kwargs)}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, _ = clip.load(clip_model, device)

    visionmodel = ClipHierVisionTransformer(model.visual, nb_classes)
    # Drop the local reference to the full CLIP model; its `visual` submodule
    # remains referenced by visionmodel.
    del model
    # The new heads are created on CPU while the backbone was loaded on
    # `device`. Move the whole wrapper there so it is on a single device.
    return visionmodel.float().to(device)

