from http.client import UnimplementedFileMode
from sys import implementation
from typing import Tuple, Union
import torch
from torch import nn, einsum
from modules.clip_model import CLIP, LayerNorm, Transformer, MultiframeIntegrationTransformer
from einops import rearrange, repeat
from torchvision.transforms import Normalize, Compose, Resize, CenterCrop
from modules.discretesupport import DiscreteSupport
import peft

class TimeRewarder(CLIP):
    """CLIP-based model that scores temporal order / progress between video clips.

    Each frame is encoded by ``self.visual``. A text feature then cross-attends to the
    frame's patch tokens (``order_attn``), the per-frame results are fused over time by
    ``mit``, and ``order_mlp`` scores concatenated features of clip pairs (or of single
    clips for the ``oneframe`` / ``progressor`` variants).

    Head layout, as used by the methods below: column 0 is read as "first clip comes
    before the second" (``get_score_neighbor``), and columns 0/1 are swapped for
    reversed pairs (``forward_3class``). Columns from 2 on hold the progress output:
    a scalar when ``bin_num == -1``, otherwise ``bin_num`` bins decoded by
    ``DiscreteSupport``.

    NOTE(review): every parameter is frozen at the end of ``__init__``; callers must
    re-enable whatever they intend to train. ``use_cache`` / ``two_head`` are not set
    here; they are read with defaults of ``False`` (presumably set by the caller).
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 # video
                 T=8,
                 mit_layers=1,
                 bin_num=-1,
                 progressor=False,
                 oneframe=False,
                 ):
        super().__init__(
            embed_dim,
            image_resolution, vision_layers, vision_width, vision_patch_size,
            context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
        )

        self.T = T
        vision_heads = vision_width // 64
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # Encoded text keyed by the number of prompts (see cache_text).
        self.cache_text_features = {}

        self.visual = VisionTransformer(
            input_resolution=image_resolution,
            patch_size=vision_patch_size,
            width=vision_width,
            layers=vision_layers,
            heads=vision_heads,
            output_dim=embed_dim
        )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        # The context is the ViT's un-projected patch tokens, whose width is
        # vision_width (768 for ViT-B), so the key/value input size must follow it.
        self.order_attn = CrossAttention(query_dim=embed_dim, context_dim=vision_width, heads=8, dim_head=64, dropout=0.0)

        self.order_ln = LayerNorm(embed_dim)

        self.mit = MultiframeIntegrationTransformer(T=T, embed_dim=embed_dim, layers=mit_layers,)

        self.bin_num = bin_num

        self.progressor = progressor

        self.oneframe = oneframe

        if self.progressor:
            # Scores (initial, current, goal) feature triplets.
            self.order_mlp = nn.Sequential(
                LayerNorm(3 * embed_dim),
                nn.Linear(3 * embed_dim, 3),
            )

        elif self.oneframe:
            self.order_mlp = nn.Sequential(
                LayerNorm(embed_dim),
                nn.Linear(embed_dim, bin_num + 2),
            )
            self.discrete_support = DiscreteSupport(bins=bin_num)

        elif bin_num == -1:
            self.order_mlp = nn.Sequential(
                LayerNorm(2 * embed_dim),
                nn.Linear(2 * embed_dim, 3),
            )
        else:
            self.order_mlp = nn.Sequential(
                LayerNorm(2 * embed_dim),
                nn.Linear(2 * embed_dim, bin_num + 2),
            )
            self.discrete_support = DiscreteSupport(bins=bin_num)

        self.initialize_parameters()

        self.img_norm = Compose([
            Normalize(mean=torch.tensor([0.485, 0.456, 0.406]),
                      std=torch.tensor([0.229, 0.224, 0.225]))
        ])

        for param in self.parameters():
            param.requires_grad = False

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Parameter-name keywords to exclude from weight decay."""
        return {'positional_embedding'}

    def encode_text(self, text):
        """Encode token ids ``(K, context_length)`` into ``(K, embed_dim)`` text features.

        The feature of each sequence is taken at its argmax token (the EOT token
        has the highest id in CLIP's vocabulary).
        """
        x = self.token_embedding(text)
        eos_indx = text.argmax(dim=-1)
        num_texts = x.shape[0]

        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)
        x = x[torch.arange(num_texts), eos_indx] @ self.text_projection
        return x.reshape(num_texts, -1)

    def encode_video_with_text(self, image, text, no_text=False):
        """Encode clips ``(B, t, c, h, w)`` conditioned on per-frame text features.

        Args:
            image: clips of ``t`` frames each.
            text: query features with one row per frame, ordered clip-major then
                time (``B * t`` rows). Each row holds ``num_text`` text features.
            no_text: return only the per-frame CLS features of ``self.visual``.

        Returns:
            The ``mit`` output over the ``(B * num_text, t, dim)`` attended frame
            sequence, ordered clip-major then text (per the original comments
            ``(B * num_text, 512)``).
        """
        b, t, c, h, w = image.size()
        cls_features, img_features = self.visual(image.reshape(-1, c, h, w))
        if no_text:
            return cls_features
        # (B * t, num_text, dim): each text feature attends to its frame's patch tokens.
        text_img_features = self.order_attn(text, context=img_features)
        text_img_features = text_img_features.reshape(b, t, text_img_features.shape[-2], text_img_features.shape[-1])
        # (B, num_text, t, dim) -> (B * num_text, t, dim) so mit fuses over time.
        text_img_features = text_img_features.permute(0, 2, 1, 3)
        text_img_features = text_img_features.reshape(-1, t, text_img_features.shape[-1])
        return self.mit(self.order_ln(text_img_features))

    def cache_text(self, text):
        """Return ``encode_text(text)``, memoized by the number of prompts.

        Encoding runs without gradients, in eval mode, in chunks of 2048 prompts.
        The module's previous train/eval mode is restored afterwards.

        NOTE(review): the cache key is only ``text.shape[0]``; two different prompt
        sets of the same size share one entry.
        """
        text_number = text.shape[0]
        if text_number not in self.cache_text_features:
            was_training = self.training
            self.eval()
            try:
                with torch.no_grad():
                    chunks = [self.encode_text(chunk) for chunk in text.split(2048)]
                self.cache_text_features[text_number] = torch.cat(chunks)
            finally:
                self.train(was_training)
        return self.cache_text_features[text_number]

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _stack_prev_frames(image):
        """Turn frames ``(n, c, h, w)`` into 3-frame windows ``(n, 3, c, h, w)``.

        Window ``i`` holds frames ``[i-2, i-1, i]``; indices below 0 are clamped
        to frame 0. Inputs that are not 4-D are returned unchanged.
        """
        if image.dim() != 4:
            return image
        idx = torch.arange(image.shape[0], device=image.device)
        prev_1 = image[(idx - 1).clamp(min=0)]
        prev_2 = image[(idx - 2).clamp(min=0)]
        return torch.stack((prev_2, prev_1, image), dim=1)

    def _encode_clips(self, image, text_feature):
        """Encode clips ``(n, t, c, h, w)`` with one shared text feature; returns ``(n, -1)``."""
        n, t = image.shape[:2]
        per_frame_text = text_feature.unsqueeze(0).expand(n * t, 1, -1)
        return self.encode_video_with_text(image, per_frame_text).reshape(n, -1)

    @staticmethod
    def _shifted_features(features, shift):
        """Return ``features`` delayed by ``shift`` rows; the first ``shift`` rows repeat row 0."""
        rolled = torch.roll(features, shifts=shift, dims=0)
        rolled[:shift] = features[0]
        return rolled

    @staticmethod
    def _power_shifts(n, limit):
        """Return ``[1, 2, 4, ...]`` with values ``<= n`` and at most ``limit`` entries."""
        shifts = []
        shift = 1
        while shift <= n and len(shifts) < limit:
            shifts.append(shift)
            shift *= 2
        return shifts

    @staticmethod
    def _pair_indices(num_clips):
        """Return CPU index tensors ``(i, j)`` for every clip pair with ``i < j``, row-major."""
        first, second = torch.triu_indices(num_clips, num_clips, offset=1)
        return first, second

    @staticmethod
    def _gather_pairs(features, index):
        """Select ``features[:, index]`` from ``(b, num_clips, ...)`` and flatten pair-major.

        The result is ``(len(index) * b, ...)``: all batch rows of pair 0, then of pair 1, ...
        """
        picked = features[:, index.to(features.device)]
        return picked.transpose(0, 1).reshape(-1, *features.shape[2:])

    @staticmethod
    def _swap_order_columns(logits):
        """Swap the two order columns (0, 1) and keep column 2; used for reversed pairs."""
        return logits.index_select(1, torch.tensor([1, 0, 2], device=logits.device))

    def _select_text_features(self, text, label_id_neg):
        """Encode ``text`` and gather ``(b, neg_num, dim)`` features by ``label_id_neg``."""
        if label_id_neg is None or getattr(self, 'two_head', False):
            raise NotImplementedError('label_id_neg is required and two_head model is not supported')
        if getattr(self, 'use_cache', False):
            text_features = self.cache_text(text)
        else:
            text_features = self.encode_text(text)
        b, neg_num = label_id_neg.shape
        return text_features[label_id_neg.reshape(-1)].view(b, neg_num, -1)

    def _score_with_wrong_text(self, batch_inorder, batch_inverse):
        """Score pair features ``(P, num_text, 2*dim)``; text 0 is the matching one.

        Returns ``(logits_order_wrong, logits_order)``. ``logits_order`` stacks the
        in-order logits on top of the inverse-order logits (columns 0/1 swapped), and
        ``logits_order_wrong`` holds all pairs scored with the wrong texts 1..num_text-1.
        """
        inorder_correct = batch_inorder[:, 0]
        inverse_correct = batch_inverse[:, 0]
        num_pairs = inorder_correct.shape[0]
        dim = batch_inorder.shape[-1]
        wrong = torch.cat((batch_inorder[:, 1:].reshape(-1, dim),
                           batch_inverse[:, 1:].reshape(-1, dim)), dim=0)
        logits = self.order_mlp(torch.cat((inorder_correct, inverse_correct, wrong), dim=0))
        logits_order = torch.cat((logits[:num_pairs],
                                  self._swap_order_columns(logits[num_pairs:2 * num_pairs])), dim=0)
        return logits[2 * num_pairs:], logits_order

    # ---------------------------------------------------------------- inference

    def get_score_neighbor(self, image, text_feature):
        """Probability (softmax column 0) that the clip 5 steps earlier precedes each clip.

        ``image`` is either frames ``(n, c, h, w)``, stacked into 3-frame windows, or
        clips ``(n, t, c, h, w)``. Returns a tensor of length ``n``.
        """
        video_features = self._encode_clips(self._stack_prev_frames(image), text_feature)
        n = video_features.shape[0]
        neighbor = self._shifted_features(video_features, 5)

        inorder_features = torch.cat((neighbor, video_features), dim=-1)
        inverse_features = torch.cat((video_features, neighbor), dim=-1)
        probs = self.order_mlp(torch.cat((inorder_features, inverse_features), dim=0)).softmax(dim=-1)
        return probs[:n, 0]

    def forward_diff(self, image, text_feature, maxnum=99):
        """Progress difference between each clip and its predecessor, measured from far anchors.

        For every shift ``s`` in ``2, 4, 8, ... <= n`` (at most ``maxnum`` of them), the
        head's column 2 is evaluated for pairs (clip shifted by ``s``, current clip)
        and (clip shifted by ``s``, previous clip).

        Returns:
            ``(ori_logits, reward)``. ``ori_logits`` is ``(3k, n)``: the ``k`` rows of
            anchor->current outputs, the ``k`` rows of anchor->previous outputs, then
            the ``k`` rows of their differences. ``reward`` is the differences summed
            over shifts, shape ``(n,)``.
        """
        video_features = self._encode_clips(self._stack_prev_frames(image), text_feature)
        n = video_features.shape[0]

        shifts = self._power_shifts(n, maxnum + 1)
        rolled_feature_num = len(shifts) - 1
        rolled = [self._shifted_features(video_features, s) for s in shifts]
        neighbor_features, far_features = rolled[0], rolled[1:]

        first_clip_features = torch.cat(far_features, dim=0).repeat(2, 1)
        second_clip_features = torch.cat((video_features.repeat(rolled_feature_num, 1),
                                          neighbor_features.repeat(rolled_feature_num, 1)), dim=0)
        logits = self.order_mlp(torch.cat((first_clip_features, second_clip_features), dim=-1))

        ori_logits = logits[:, 2].reshape(rolled_feature_num * 2, n)
        split = rolled_feature_num * n
        diff = (logits[:split, 2] - logits[split:, 2]).reshape(rolled_feature_num, n)

        return torch.cat((ori_logits, diff), dim=0), diff.sum(dim=0)

    def forward_power(self, image, text_feature, maxnum=99):
        """Antisymmetric progress from clips ``1, 2, 4, ...`` steps back, geometrically weighted.

        Returns:
            ``(logits, reward)``. ``logits`` is ``(k, n)`` holding column 2 of
            (shifted, current) minus column 2 of (current, shifted). ``reward`` is the
            sum over shifts weighted by ``1 / 2**i``.
        """
        video_features = self._encode_clips(self._stack_prev_frames(image), text_feature)
        n = video_features.shape[0]

        shifts = self._power_shifts(n, maxnum)
        rolled_feature_num = len(shifts)
        first_clip_features = torch.cat([self._shifted_features(video_features, s) for s in shifts], dim=0)
        second_clip_features = video_features.repeat(rolled_feature_num, 1)

        clip_pair_features = torch.cat((first_clip_features, second_clip_features), dim=-1)
        inverse_pair_features = torch.cat((second_clip_features, first_clip_features), dim=-1)
        logits = self.order_mlp(torch.cat((clip_pair_features, inverse_pair_features), dim=0))

        split = rolled_feature_num * n
        logits = (logits[:split, 2] - logits[split:, 2]).reshape(rolled_feature_num, n)
        weights = torch.tensor([1 / (2 ** i) for i in range(rolled_feature_num)], device=logits.device).unsqueeze(1)
        reward = torch.sum(logits * weights, dim=0)

        return logits, reward

    def get_image_features(self, image, text_feature):
        """Encode frames ``(n, c, h, w)`` (as 1-frame clips) or clips ``(n, t, c, h, w)``; returns ``(n, -1)``."""
        if image.dim() == 4:
            image = image.unsqueeze(1)
        return self._encode_clips(image, text_feature)

    def get_progressor_score(self, image_features, goal_feature):
        """Column 2 of the head for (first frame, each frame, goal) feature triplets."""
        init_feature = image_features[0].expand(image_features.shape[0], -1)
        goal_feature = goal_feature.expand(image_features.shape[0], -1)
        cat_feature = torch.cat((init_feature, image_features, goal_feature), dim=-1)
        return self.order_mlp(cat_feature)[:, 2]

    def forward(self, image, text_feature, progresspredict=True, oneframe=False):
        """Score every clip against three references: the first clip, the previous clip, and the clip at half its index.

        Args:
            image: frames ``(n, c, h, w)`` (stacked into 3-frame windows when ``T == 3``,
                otherwise treated as 1-frame clips) or clips ``(n, t, c, h, w)``.
            text_feature: a single text feature, shared by all frames.
            progresspredict: return the progress output (columns 2+, decoded to a
                scalar when ``bin_num > 0``). Otherwise return softmax column 0 of
                (reference, clip) minus that of (clip, reference).
            oneframe: score each clip alone with the single-clip head.

        Returns:
            A ``(3, n)`` tensor whose rows follow the reference order above.
        """
        if image.dim() == 4:
            image = self._stack_prev_frames(image) if self.T == 3 else image.unsqueeze(1)
        video_features = self._encode_clips(image, text_feature)
        n = video_features.shape[0]

        if oneframe:
            logits = self.order_mlp(video_features)
            logits = self.discrete_support.vector_to_scalar(logits[:, 2:])
            # Repeat to (3, n) to match the pairwise output layout.
            return logits[:, 0].unsqueeze(0).expand(3, -1)

        init_features = video_features[0].unsqueeze(0).expand(n, -1)
        half_features = torch.repeat_interleave(video_features[:(n + 1) // 2], 2, dim=0)[:n]
        neighbor_features = self._shifted_features(video_features, 1)

        first_clip_features = torch.cat((init_features, neighbor_features, half_features), dim=0)
        second_clip_features = video_features.repeat(3, 1)
        clip_pair_features = torch.cat((first_clip_features, second_clip_features), dim=-1)
        inverse_pair_features = torch.cat((second_clip_features, first_clip_features), dim=-1)
        logits = self.order_mlp(torch.cat((clip_pair_features, inverse_pair_features), dim=0))

        if progresspredict:
            logits = logits[:3 * n, 2:]
            if self.bin_num > 0:
                logits = self.discrete_support.vector_to_scalar(logits)
        else:
            logits = logits.softmax(dim=-1)
            logits = logits[:3 * n, 0] - logits[3 * n:, 0]
        return logits.reshape(3, n)

    def visualize(self, image, text_feature, num_clips=10, progresspredict=True):
        """Pairwise ``(num_clips, num_clips)`` score matrix over evenly sampled clips.

        ``image`` is ``(n, t, c, h, w)``; clips are normalized with ``img_norm`` first.
        Entry ``[a, b]`` is the softmax column 2 (``progresspredict``) or column 0 for
        the pair (clip a, clip b).
        """
        n, t, c, h, w = image.shape
        sample_id = torch.linspace(0, n - 1, num_clips).long()
        image = self.img_norm(image[sample_id].reshape(-1, c, h, w)).reshape(num_clips, t, c, h, w)
        text_feature = text_feature.unsqueeze(0).expand(num_clips * t, 1, -1)
        video_features = self.encode_video_with_text(image, text_feature)

        video_features_expanded_1 = video_features.unsqueeze(1).expand(-1, num_clips, -1).reshape(num_clips * num_clips, -1)
        video_features_expanded_2 = video_features.repeat(num_clips, 1)
        video_feature_pairs = torch.cat((video_features_expanded_1, video_features_expanded_2), dim=-1)
        logits = self.order_mlp(video_feature_pairs).softmax(dim=-1)
        column = 2 if progresspredict else 0
        return logits[:, column].view(num_clips, num_clips)

    # ----------------------------------------------------------------- training

    def forward_3class(self, image, text, label_id=None, use_order=False, num_clips=4, label_id_neg=None, one_for_class=False, use_reverse=True, mix_order=False):
        """Pairwise temporal-order logits for training.

        Args:
            image: either a whole video ``(b, t_all, c, h, w)``, from which up to
                ``num_clips`` windows of ``self.T`` frames (stride 2) are sampled at
                random, or pre-cut clips ``(b, num_clips, t, c, h, w)``.
            text: token ids for ``encode_text`` / ``cache_text``.
            num_clips: number of windows to sample in the whole-video case.
            label_id_neg: ``(b, neg_num)`` indices into the encoded texts. Column 0
                is used as the matching text; the rest are wrong texts.
            use_reverse: also encode time-reversed clips and use their pairs as
                negatives (``neg_num == 1`` only).
            mix_order: with ``use_reverse``, also pair forward clips with reversed ones.
            label_id, use_order, one_for_class: unused; kept for call compatibility.

        Returns:
            ``(logits_order_wrong, logits_order)``. ``logits_order`` stacks the
            in-order pair logits on top of negative-pair logits with columns 0/1
            swapped. ``logits_order_wrong`` is ``None`` when ``neg_num == 1``.

        Raises:
            NotImplementedError: ``label_id_neg`` is missing or the model is two-headed.
            ValueError: fewer than two clips are available to form a pair.
        """
        text_features = self._select_text_features(text, label_id_neg)
        b, n, f = text_features.shape

        if image.dim() == 5:  # whole video as input
            image_b, _, c, h, w = image.shape
            # (b, num_windows, T, c, h, w)
            windows = image.unfold(1, self.T, 2).permute(0, 1, 5, 2, 3, 4)
            sample_id = torch.randperm(windows.shape[1])[:num_clips]
            num_clips = sample_id.numel()  # short videos may yield fewer windows
            image = windows[:, sample_id].reshape(image_b * num_clips, self.T, c, h, w)
        else:
            image_b, num_clips, t, c, h, w = image.shape
            image = image.reshape(image_b * num_clips, t, c, h, w)
        if num_clips < 2:
            raise ValueError(f'at least two clips are needed to form pairs, got {num_clips}')

        # One text block per frame, in the same clip-major/time-minor order that
        # encode_video_with_text uses for the frames.
        frames_per_clip = image.shape[1]
        frame_text = text_features.unsqueeze(1).expand(b, num_clips * frames_per_clip, n, f).reshape(-1, n, f)

        if use_reverse:
            image_reverse = torch.flip(image, [1])
            features_all = self.encode_video_with_text(
                torch.cat((image, image_reverse), dim=0),
                torch.cat((frame_text, frame_text), dim=0),
            )
            # The encoder output has one row per (clip, text), so split after b*num_clips*n rows.
            split = b * num_clips * n
            video_features = features_all[:split].reshape(b, num_clips, n, f)
            video_features_reverse = features_all[split:].reshape(b, num_clips, n, f)
        else:
            video_features = self.encode_video_with_text(image, frame_text).reshape(b, num_clips, n, f)

        first, second = self._pair_indices(num_clips)
        feats_1 = self._gather_pairs(video_features, first)
        feats_2 = self._gather_pairs(video_features, second)
        # (b * num_pairs, num_text, 2 * f), pair-major
        batch_inorder = torch.cat((feats_1, feats_2), dim=-1)
        batch_inverse = torch.cat((feats_2, feats_1), dim=-1)

        if n == 1:
            negatives = [batch_inverse[:, 0]]
            if use_reverse:
                rev_1 = self._gather_pairs(video_features_reverse, first)
                rev_2 = self._gather_pairs(video_features_reverse, second)
                negatives += [torch.cat((rev_1, rev_2), dim=-1)[:, 0],
                              torch.cat((rev_2, rev_1), dim=-1)[:, 0]]
                if mix_order:
                    negatives += [torch.cat((feats_1, rev_2), dim=-1)[:, 0],
                                  torch.cat((rev_1, feats_2), dim=-1)[:, 0]]
            neg_features = torch.cat(negatives, dim=0)

            inorder = batch_inorder[:, 0]
            num_pairs = inorder.shape[0]
            # Balance classes: as many sampled negatives as positive pairs.
            sampled_neg = neg_features[torch.randperm(neg_features.size(0))[:num_pairs]]
            logits = self.order_mlp(torch.cat((inorder, sampled_neg), dim=0))
            logits_order = torch.cat((logits[:num_pairs], self._swap_order_columns(logits[num_pairs:])), dim=0)
            return None, logits_order

        return self._score_with_wrong_text(batch_inorder, batch_inverse)

    def forward_progresspred(self, image, text, label_id=None, use_order=False, num_clips=4, label_id_neg=None, progress=None, one_for_class=False, use_reverse=True, mix_order=False):
        """Pairwise order classification plus progress-difference regression.

        Args:
            image: clips ``(b, num_clips, t, c, h, w)``.
            text: token ids for ``encode_text`` / ``cache_text``.
            label_id_neg: ``(b, neg_num)`` indices into the encoded texts (column 0
                is used as the matching text).
            progress: ``(b, num_clips)`` progress value of each clip.
            label_id, use_order, num_clips, one_for_class, use_reverse, mix_order:
                unused (``num_clips`` is read from ``image``); kept for compatibility.

        Returns:
            With ``neg_num == 1``: ``(combined_loss, classification_loss,
            regression_loss, logits_order)``. In-order pairs are labelled class 0 and
            swapped pairs class 1. Column 2 regresses ``progress[j] - progress[i]``
            (negated for swapped pairs). Otherwise: ``(logits_order_wrong, logits_order)``.

        Raises:
            NotImplementedError: ``label_id_neg`` is missing or the model is two-headed.
            ValueError: ``image`` is not 6-D, ``progress`` does not match it, or
                there are fewer than two clips.
        """
        text_features = self._select_text_features(text, label_id_neg)
        b, n, f = text_features.shape

        if image.dim() != 6:
            raise ValueError(f'length of image shape should be 6, current shape: {tuple(image.shape)}')
        image_b, num_clips, t, c, h, w = image.shape
        if progress is None or tuple(progress.shape) != (image_b, num_clips):
            got = None if progress is None else tuple(progress.shape)
            raise ValueError(f'progress must have shape {(image_b, num_clips)}, got {got}')
        if num_clips < 2:
            raise ValueError(f'at least two clips are needed to form pairs, got {num_clips}')
        image = image.reshape(image_b * num_clips, t, c, h, w)

        frame_text = text_features.unsqueeze(1).expand(b, num_clips * t, n, f).reshape(-1, n, f)
        video_features = self.encode_video_with_text(image, frame_text).reshape(b, num_clips, n, f)

        first, second = self._pair_indices(num_clips)
        feats_1 = self._gather_pairs(video_features, first)
        feats_2 = self._gather_pairs(video_features, second)
        progress_label = self._gather_pairs(progress, second) - self._gather_pairs(progress, first)
        batch_inorder = torch.cat((feats_1, feats_2), dim=-1)
        batch_inverse = torch.cat((feats_2, feats_1), dim=-1)

        if n == 1:
            num_pairs = batch_inorder.shape[0]
            logits_order = self.order_mlp(torch.cat((batch_inorder[:, 0], batch_inverse[:, 0]), dim=0))

            # Classification: column 0 = "first clip precedes second", column 1 = reversed.
            classification_labels = torch.cat((
                torch.zeros(num_pairs, dtype=torch.long),
                torch.ones(num_pairs, dtype=torch.long),
            )).to(logits_order.device)
            classification_loss = nn.CrossEntropyLoss()(logits_order[:, :2], classification_labels)

            # Regression: signed progress difference, negated for swapped pairs.
            regression_logits = logits_order[:, 2]
            regression_labels = torch.cat((progress_label, -progress_label), dim=0).to(
                device=regression_logits.device, dtype=regression_logits.dtype)
            regression_loss = nn.MSELoss()(regression_logits, regression_labels)

            combined_loss = classification_loss + 0.5 * regression_loss
            return combined_loss, classification_loss, regression_loss, logits_order

        return self._score_with_wrong_text(batch_inorder, batch_inverse)

    def summary(self):
        """Print the names of all trainable parameters."""
        print('Trainable params:')
        for name, param in self.named_parameters():
            if param.requires_grad:
                print(name)

def build_model(state_dict: dict, T=8, use_checkpoint=False, logger=None, prompts_alpha=1e-1, mit_layers=4,
                bin_num=-1, progressor=False, lora=False, oneframe=False):
    """Build a ``TimeRewarder`` sized from a CLIP-style state dict and load the weights.

    Args:
        state_dict: checkpoint weights. A ``{'model': ...}`` wrapper is unwrapped.
            The caller's mapping is not modified.
        T, mit_layers, bin_num, progressor, oneframe: forwarded to ``TimeRewarder``.
        logger: optional logger for the load report; ``print`` is used when ``None``.
        lora: the checkpoint's visual tower was wrapped by ``peft`` (keys under
            ``visual.base_model.model.``); the freshly built tower is wrapped the same way.
        use_checkpoint, prompts_alpha: unused; kept for call compatibility.

    Returns:
        The model in eval mode, loaded with ``strict=False``.
    """
    from collections import OrderedDict

    if 'model' in state_dict:
        state_dict = state_dict['model']
    # Copy so the metadata deletions below do not mutate the caller's dict; keep
    # torch's `_metadata` attribute, which load_state_dict may consult.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        state_dict._metadata = metadata

    for key in state_dict:
        if 'conv' in key and 'weight' in key:
            print(key, state_dict[key].shape)

    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    elif lora:
        vision_width = state_dict['visual.base_model.model.conv1.weight'].shape[0]
        vision_layers = len([k for k in state_dict if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.base_model.model.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.base_model.model.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        # ResNet-style CLIP checkpoint.
        # NOTE(review): TimeRewarder always builds a VisionTransformer, which cannot
        # take vision_patch_size=None, so this branch is not expected to work as-is.
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)

        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    model = TimeRewarder(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
        T=T, mit_layers=mit_layers, bin_num=bin_num, progressor=progressor, oneframe=oneframe,
    )

    # Non-tensor metadata entries found in OpenAI CLIP checkpoints.
    for key in ["input_resolution", "context_length", "vocab_size"]:
        state_dict.pop(key, None)

    if lora:
        lora_config = peft.LoraConfig(
            r=8,
            lora_alpha=8,
            lora_dropout=0.0,
            target_modules=['in_proj_weight', "in_proj_bias", 'out_proj'],
        )
        model.visual = peft.get_peft_model(model.visual, lora_config)

        print('LORA Trainable params:')
        model.visual.print_trainable_parameters()

    msg = model.load_state_dict(state_dict, strict=False)
    if logger is not None:
        logger.info(f"load pretrained CLIP: {msg}")
    else:
        print(f"load pretrained CLIP: {msg}")

    return model.eval()

def load(model_path, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
        T=3, use_checkpoint=False, logger=None, mit_layers=1, bin_num=-1, progressor=False,
):
    """Load a ``TimeRewarder`` checkpoint from ``model_path``.

    The weights are loaded onto the CPU. A checkpoint path containing ``'lora'``
    is treated as a LoRA (peft) checkpoint. ``device`` is only used to decide
    whether to cast the model to float32 (when it is ``"cpu"``); the model is
    NOT moved to ``device``.

    Returns:
        ``(model, model.state_dict())``.

    Raises:
        ValueError: the checkpoint is empty.

    NOTE(review): ``torch.load`` unpickles the file; only load trusted checkpoints.
    """
    state_dict = torch.load(model_path, map_location="cpu")
    if not state_dict:
        # The original fallback `model.state_dict()` referenced `model` before it existed.
        raise ValueError(f"checkpoint at {model_path!r} contains no weights")
    use_lora = 'lora' in str(model_path)  # str() so pathlib.Path also works
    print('Use LORA:', use_lora, 'model path:', model_path)
    model = build_model(state_dict, T=T,
                        use_checkpoint=use_checkpoint, logger=logger,
                        mit_layers=mit_layers, bin_num=bin_num, progressor=progressor,
                        lora=use_lora,
                        )
    if str(device) == "cpu":
        model.float()
    return model, model.state_dict()


class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention from queries ``x`` onto ``context``.

    Args:
        query_dim: feature size of the queries and of the output.
        context_dim: feature size of the keys/values source; ``None`` means
            ``query_dim`` (self-attention).
        heads: number of attention heads.
        dim_head: per-head feature size.
        dropout: dropout applied after the output projection.
    """

    def __init__(self, query_dim=512, context_dim=768, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = heads * dim_head
        context_dim = query_dim if context_dim is None else context_dim

        self.scale = dim_head ** -0.5
        self.heads = heads

        # Registration order is kept (q, k, v, out) so initialization and
        # state_dict layout are unchanged.
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def _split_heads(self, tensor):
        """``(b, n, heads * d)`` -> ``(b * heads, n, d)``, with heads as the inner batch axis."""
        batch, length, _ = tensor.shape
        per_head = tensor.reshape(batch, length, self.heads, -1).permute(0, 2, 1, 3)
        return per_head.reshape(batch * self.heads, length, -1)

    def _merge_heads(self, tensor):
        """``(b * heads, n, d)`` -> ``(b, n, heads * d)``; inverse of ``_split_heads``."""
        batch_heads, length, head_dim = tensor.shape
        batch = batch_heads // self.heads
        merged = tensor.reshape(batch, self.heads, length, head_dim).permute(0, 2, 1, 3)
        return merged.reshape(batch, length, self.heads * head_dim)

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` ``(b, n, query_dim)`` to ``context`` ``(b, m, context_dim)``.

        ``mask`` (boolean, ``(b, m)`` or any shape flattening to it) marks the
        context positions that may be attended to. Returns ``(b, n, query_dim)``.
        """
        source = x if context is None else context
        query = self._split_heads(self.to_q(x))
        key = self._split_heads(self.to_k(source))
        value = self._split_heads(self.to_v(source))

        scores = einsum('b i d, b j d -> b i j', query, key) * self.scale

        if mask is not None:
            keep = mask.reshape(mask.shape[0], -1)
            keep = keep.repeat_interleave(self.heads, dim=0).unsqueeze(1)
            scores.masked_fill_(~keep, -torch.finfo(scores.dtype).max)

        weights = scores.softmax(dim=-1)
        attended = einsum('b i j, b j d -> b i d', weights, value)
        return self.to_out(self._merge_heads(attended))


class VisionTransformer(nn.Module):
    """CLIP-style ViT image encoder that also exposes its patch tokens.

    Args:
        input_resolution: square input image side length.
        patch_size: side length of the square patches (conv kernel and stride).
        width: token width inside the transformer.
        layers: number of transformer layers.
        heads: number of attention heads.
        output_dim: size of the projected CLS feature.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        # Parameters are created in the same order as CLIP so random init and
        # checkpoint keys line up.
        init_std = width ** -0.5
        num_tokens = (input_resolution // patch_size) ** 2 + 1  # patches + CLS
        self.class_embedding = nn.Parameter(init_std * torch.randn(width))
        self.positional_embedding = nn.Parameter(init_std * torch.randn(num_tokens, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(init_std * torch.randn(width, output_dim))
        # Placeholder attribute; presumably present because peft wrappers look for `config`.
        self.config = None

    def forward(self, x: torch.Tensor):
        """Encode images ``(B, 3, H, W)``.

        Returns:
            ``(cls_feature, patch_tokens)``: the layer-normed CLS token projected
            by ``proj`` ``(B, output_dim)``, and the final un-normalized,
            un-projected patch tokens ``(B, grid**2, width)``.
        """
        patches = self.conv1(x)  # [B, width, grid, grid]
        batch, channels = patches.shape[0], patches.shape[1]
        tokens = patches.reshape(batch, channels, -1).permute(0, 2, 1)  # [B, grid**2, width]

        cls_token = self.class_embedding.to(tokens.dtype) + torch.zeros(
            batch, 1, tokens.shape[-1], dtype=tokens.dtype, device=tokens.device)
        tokens = torch.cat([cls_token, tokens], dim=1)  # [B, grid**2 + 1, width]
        tokens = self.ln_pre(tokens + self.positional_embedding.to(tokens.dtype))

        # The transformer works sequence-first (LND).
        tokens = self.transformer(tokens.permute(1, 0, 2)).permute(1, 0, 2)

        cls_feature = self.ln_post(tokens[:, 0, :])
        if self.proj is not None:
            cls_feature = cls_feature @ self.proj
        return cls_feature, tokens[:, 1:, :]