from typing import Tuple, Union
import torch
from torch import nn, einsum
import numpy as np
from .mit import MultiframeIntegrationTransformer
from .mit_withtext import MultiframeIntegrationTransformerwithText
from .prompt import VideoSpecificPrompt
from .cct import CrossFrameCommunicationTransformer
import sys
import warnings
sys.path.append("../")
from clip.model import CLIP,LayerNorm,Transformer
import clip
from einops import rearrange, repeat
from .discretesupport import DiscreteSupport
import peft
import os

class TimeRewarder(CLIP):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 # video
                 T=8,
                 droppath=0.,
                 mlp_droprate=0.2,
                 mit_layers=1,
                 # prompt
                 prompts_alpha=1e-4,
                 prompts_layers=1,
                 # other
                 use_cache=True,
                 use_checkpoint=False,
                 two_head=False,
                 train_order=True,
                 train_class=False,
                 use_similarity=False,
                 use_mit_withtext=True,
                 use_bin=True,
                 rank2reward=False,
                 progressor=False,
                 oneframe=False
                 ):
        """Build the CLIP backbone plus the ordering / progress head.

        The first ten positional arguments are forwarded unchanged to
        ``CLIP.__init__``; the text tower, the visual tower and the text
        transformer are then re-created on this instance.

        Args:
            T: stored as ``self.T`` and passed to the multiframe integration
                transformer.
            droppath: only used to derive a local per-layer schedule that no
                sub-module built here consumes.
            mlp_droprate: dropout probability of ``order_mlp``.
            mit_layers: layer count of the multiframe integration transformer.
            prompts_alpha, prompts_layers, use_checkpoint: accepted but not
                referenced anywhere in this constructor.
            use_cache, two_head, train_order, train_class, use_similarity,
            use_mit_withtext, rank2reward, progressor, oneframe: stored on the
                instance under the same names.
            use_bin: selects a 22-wide head (and creates ``discrete_support``)
                instead of a 3-wide head.

        NOTE(review): ``self.use_bin`` is forced to False when ``progressor``
        or ``rank2reward`` is set, but the head width below is chosen from the
        raw ``use_bin`` argument, so those modes still get the 22-wide head
        unless ``use_bin=False`` is passed explicitly. Confirm this is
        intended before changing it (it affects checkpoint shapes).
        """
        super().__init__(
            embed_dim,
            image_resolution, vision_layers, vision_width, vision_patch_size,
            context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
        )

        # ---- plain configuration flags -------------------------------------
        self.T = T
        self.use_cache = use_cache
        self.two_head = two_head
        self.train_order = train_order
        self.train_class = train_class
        self.mlp_droprate = mlp_droprate
        self.use_similarity = use_similarity
        self.use_mit_withtext = use_mit_withtext
        if progressor or rank2reward:
            self.use_bin = False
        else:
            self.use_bin = use_bin
        self.rank2reward = rank2reward
        self.progressor = progressor
        self.oneframe = oneframe

        # Per-layer drop-path schedule; computed for parity with the original
        # code path but not handed to any module below.
        if droppath > 0.:
            drop_path_rates = [rate.item() for rate in torch.linspace(0, droppath, vision_layers)]
        else:
            drop_path_rates = None

        # ---- text tower ------------------------------------------------------
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # Filled lazily by cache_text(); keyed by the number of text rows.
        self.cache_text_features = {}

        # ---- visual tower ----------------------------------------------------
        self.visual = VisionTransformer(
            input_resolution=image_resolution,
            patch_size=vision_patch_size,
            width=vision_width,
            layers=vision_layers,
            heads=vision_width // 64,
            output_dim=embed_dim
        )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        # ---- text-conditioned temporal modules -------------------------------
        self.order_attn = CrossAttention(query_dim=embed_dim, context_dim=768, heads=8, dim_head=64, dropout=0.0)
        self.order_ln = LayerNorm(embed_dim)
        self.mit = MultiframeIntegrationTransformer(T=T, embed_dim=embed_dim, layers=mit_layers,)

        # ---- ordering / progress head ----------------------------------------
        # The head consumes 3, 1 or 2 concatenated clip embeddings.
        clips_per_sample = 3 if progressor else (1 if oneframe else 2)
        head_in_dim = clips_per_sample * embed_dim
        head_out_dim = 22 if use_bin else 3

        self.order_mlp = nn.Sequential(
            LayerNorm(head_in_dim),
            nn.Linear(head_in_dim, head_out_dim),
            nn.Dropout(self.mlp_droprate),
        )
        if use_bin:
            self.discrete_support = DiscreteSupport()

        self.initialize_parameters()

    
    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return the set of parameter-name keywords reported by this model.

        Presumably consumed by the optimizer builder to exempt matching
        parameters from weight decay — verify against the caller.
        """
        keywords = {'positional_embedding'}
        return keywords

    def encode_text(self, text):
        """Encode tokenized text into one projected feature row per text.

        Args:
            text: integer token ids; the embedding of ``text`` must be 3-D
                (texts, tokens, width).

        Returns:
            Tensor with ``text.shape[0]`` rows, obtained by projecting the
            hidden state at each row's argmax token position through
            ``text_projection``.
        """
        # The end-of-text token has the highest id in each sequence, so its
        # position is the argmax along the token axis.
        eot_positions = text.argmax(dim=-1)

        embedded = self.token_embedding(text)
        num_texts, _, _ = embedded.shape

        hidden = embedded + self.positional_embedding
        # The transformer runs sequence-first: NLD -> LND and back.
        hidden = self.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden)

        row_index = torch.arange(hidden.shape[0])
        projected = hidden[row_index, eot_positions] @ self.text_projection
        return projected.reshape(num_texts, -1)

    def encode_video_with_text(self, image, text, clean=True):
        """Encode clip frames, optionally fusing them with text features.

        Args:
            image: 5-D tensor unpacked as (b, t, c, h, w); frames are
                flattened to (b * t, c, h, w) before the visual tower.
            text: query passed to ``order_attn``; ignored when ``clean``.
            clean: when True, return the first output of ``self.visual``
                directly and skip the text fusion path.

        Returns:
            The first output of the visual tower when ``clean``; otherwise
            the output of ``self.mit`` applied to the layer-normed,
            text-attended frame features regrouped as (-1, t, dim).
        """
        b, t, c, h, w = image.size()
        frames = image.reshape(-1, c, h, w)
        cls_features, img_features = self.visual(frames)
        if clean:
            return cls_features

        # Text queries attend over per-frame visual tokens.
        fused = self.order_attn(text, context=img_features)
        num_text, dim = fused.shape[-2], fused.shape[-1]
        # (b, t, num_text, dim) -> (b, num_text, t, dim) -> (b * num_text, t, dim)
        fused = fused.reshape(b, t, num_text, dim).permute(0, 2, 1, 3)
        fused = fused.reshape(-1, t, dim)
        return self.mit(self.order_ln(fused))

    def cache_text(self, text):
        """Return (and memoize) the text features for ``text``.

        On a cache miss the text is encoded in chunks of 2048 rows under
        ``torch.no_grad()`` with the module temporarily in eval mode; the
        module's previous train/eval mode is restored afterwards. On a cache
        hit the module mode is left untouched.

        NOTE(review): the cache key is only ``text.shape[0]``, so a different
        set of texts with the same row count returns the earlier features.

        Args:
            text: tokenized text tensor accepted by ``encode_text``.

        Returns:
            The cached feature tensor for this number of text rows.
        """
        text_number = text.shape[0]
        if text_number not in self.cache_text_features:
            # Remember the caller's mode: unconditionally switching back to
            # train() would silently re-enable dropout during evaluation.
            was_training = self.training
            self.eval()
            try:
                with torch.no_grad():
                    chunk_size = 2048
                    text_features = [self.encode_text(chunk) for chunk in text.split(chunk_size)]
                    self.cache_text_features[text_number] = torch.cat(text_features)
            finally:
                self.train(was_training)
        return self.cache_text_features[text_number]
    
    def get_score_neighbor(self, image, text_feature):
        """Score each clip against the clip five positions earlier.

        A 4-D ``image`` is first expanded into 3-frame windows made of each
        frame and its two predecessors (the first frame is repeated at the
        start). Each clip feature is paired with the feature rolled by five
        positions; the first five neighbours are set to the first clip.

        Returns:
            1-D tensor of length n: softmax probability of output column 0
            for the (neighbour, current) pairs.
        """
        if len(image.shape) == 4:
            # Shift by one along the frame axis, clamping at the first frame.
            prev_1 = torch.roll(image, 1, 0)
            prev_1[0] = image[0]
            prev_2 = torch.roll(prev_1, 1, 0)
            prev_2[0] = prev_1[0]
            image = torch.cat([frames.unsqueeze(1) for frames in (prev_2, prev_1, image)], dim=1)

        n, t, c, h, w = image.shape
        text_feature = text_feature.unsqueeze(0).expand(n * t, 1, -1)
        clip_features = self.encode_video_with_text(image, text_feature).reshape(n, -1)

        earlier_features = torch.roll(clip_features, 5, 0)
        earlier_features[:5] = clip_features[0]

        forward_pairs = torch.cat((earlier_features, clip_features), dim=-1)
        backward_pairs = torch.cat((clip_features, earlier_features), dim=-1)
        # Both directions go through the head; only the forward rows are read.
        probabilities = self.order_mlp(torch.cat((forward_pairs, backward_pairs), dim=0)).softmax(dim=-1)
        return probabilities[:n, 0]


    def get_score(self, image, text_feature, progresspredict=True):
        """Score every clip of a video against three reference clips.

        A 4-D ``image`` is first expanded into 3-frame windows made of each
        frame and its two predecessors (the first frame is repeated at the
        start). For clip ``k`` three pairs are formed:

        * row 0: (first clip, clip ``k // 2``)
        * row 1: (clip ``k // 2``, clip ``k``)
        * row 2: (first clip, clip ``k``)

        Each pair is evaluated in both orders and the returned score is the
        forward value minus the reversed value.

        Args:
            image: (n, c, h, w) frames or (n, t, c, h, w) clips.
            text_feature: text feature of shape (1, f); it is expanded and
                passed on to ``encode_video_with_text``.
            progresspredict: when True, use raw head output column 2;
                otherwise use the softmax probability of column 0.

        Returns:
            Tensor of shape (3, n).
        """
        if len(image.shape) == 4:
            # Shift by one along the frame axis, clamping at the first frame.
            image_1 = torch.roll(image, 1, 0)
            image_1[0] = image[0]
            image_2 = torch.roll(image_1, 1, 0)
            image_2[0] = image_1[0]
            image = torch.cat((image_2.unsqueeze(1), image_1.unsqueeze(1), image.unsqueeze(1)), dim=1)

        n, t, c, h, w = image.shape
        text_feature = text_feature.unsqueeze(0).expand(n * t, 1, -1)
        video_features = self.encode_video_with_text(image, text_feature).reshape(n, -1)

        init_features = video_features[0].unsqueeze(0).expand(n, -1)
        # Row k holds the feature of clip k // 2.
        half_features = torch.repeat_interleave(video_features[:(n + 1) // 2], 2, dim=0)[:n]

        first_clip_features = torch.cat((init_features, half_features, init_features), dim=0)
        second_clip_features = torch.cat((half_features, video_features, video_features), dim=0)
        clip_pair_features = torch.cat((first_clip_features, second_clip_features), dim=-1)
        inverse_pair_features = torch.cat((second_clip_features, first_clip_features), dim=-1)

        # Rows [0, 3n) are forward pairs, rows [3n, 6n) the same pairs reversed.
        logits = self.order_mlp(torch.cat((clip_pair_features, inverse_pair_features), dim=0))
        if progresspredict:
            scores = logits[:3 * n, 2] - logits[3 * n:, 2]
        else:
            logits = logits.softmax(dim=-1)
            scores = logits[:3 * n, 0] - logits[3 * n:, 0]
        return scores.reshape(3, n)


    def visualize(self, image, text_feature, num_clips=64, progresspredict=True):
        """Build pairwise head-output matrices over evenly sampled clips.

        A 4-D ``image`` is first expanded into 3-frame windows made of each
        frame and its two predecessors. ``num_clips`` clips are then picked at
        evenly spaced indices and every ordered pair (row clip, column clip)
        is fed to ``order_mlp`` followed by a softmax.

        NOTE(review): the pairing assumes ``encode_video_with_text`` yields
        one feature row per sampled clip — TODO confirm for t > 1.

        Returns:
            ``(matrix_col2, matrix_col1)`` when ``progresspredict`` is True,
            otherwise ``(matrix_col0, None)``; each matrix has shape
            (num_clips, num_clips) and holds the softmax value of the named
            output column.
        """
        if len(image.shape) == 4:
            prev_1 = torch.roll(image, 1, 0)
            prev_1[0] = image[0]
            prev_2 = torch.roll(prev_1, 1, 0)
            prev_2[0] = prev_1[0]
            image = torch.cat([frames.unsqueeze(1) for frames in (prev_2, prev_1, image)], dim=1)

        n, t, c, h, w = image.shape
        picked = torch.linspace(0, n - 1, num_clips).long()
        image = image[picked]

        text_feature = text_feature.unsqueeze(0).expand(num_clips * t, 1, -1)
        clip_features = self.encode_video_with_text(image, text_feature)

        # Row-major grid: left operand varies slowest, right operand fastest.
        left = clip_features.unsqueeze(1).expand(-1, num_clips, -1).reshape(num_clips*num_clips, -1)
        right = clip_features.repeat(num_clips, 1)
        probabilities = self.order_mlp(torch.cat((left, right), dim=-1)).softmax(dim=-1)

        if not progresspredict:
            return probabilities[:, 0].view(num_clips, num_clips), None

        logits_matrix = probabilities[:, 2].view(num_clips, num_clips)
        class_matrix = probabilities[:, 1].view(num_clips, num_clips)
        return logits_matrix, class_matrix


    def forward_3class(self, image, text, label_id=None, use_order=False, num_clips=4, label_id_neg=None, one_for_class=False, use_reverse=True, mix_order=False):
        """Produce ordering logits for all clip pairs of each video.

        Args:
            image: either a whole video (b, t_all, c, h, w), from which
                ``num_clips`` windows of length ``self.T`` (stride 2) are
                sampled at random, or pre-cut clips
                (b, num_clips, t, c, h, w).
            text: tokenized texts, encoded via the cache or ``encode_text``.
            label_id_neg: (b, neg_num) indices into the encoded texts;
                required. Column 0 is treated as the matching text, the
                remaining columns as non-matching texts.
            num_clips: number of windows to sample for 5-D input; replaced
                by ``image.shape[1]`` for 6-D input.
            use_reverse: also encode frame-reversed clips and use them as
                additional negatives in the single-text case.
            mix_order: with ``use_reverse``, additionally pair forward with
                reversed clips as negatives.
            label_id, use_order, one_for_class: unused; kept for interface
                compatibility.

        Returns:
            ``(None, logits_order)`` when ``neg_num == 1``, otherwise
            ``(logits_order_wrong, logits_order)``. In ``logits_order`` the
            rows of the negative half have output columns 0 and 1 swapped.

        Raises:
            NotImplementedError: if ``label_id_neg`` is None or the model is
                in ``two_head`` mode.
            ValueError: if fewer than two clips are available.

        NOTE(review): for 5-D input the sampled windows keep the random
        ``randperm`` order, so pair (i, j) is not guaranteed to be in
        temporal order — confirm this is intended.
        """
        if self.use_cache:
            text_features = self.cache_text(text)
        else:
            text_features = self.encode_text(text)

        if label_id_neg is None or self.two_head:
            raise NotImplementedError('label_id_neg is required and two_head model is not supported')
        b, neg_num = label_id_neg.shape
        text_features = text_features[label_id_neg.view(-1)].view(b, neg_num, -1)

        if len(image.shape) == 5:  # whole video as input
            b, t_all, c, h, w = image.shape
            step = 2
            image_unfold = image.unfold(1, self.T, step).permute(0, 1, 5, 2, 3, 4)
            sample_id = torch.randperm(image_unfold.shape[1])[:num_clips]
            image_unfold = image_unfold[:, sample_id]
            # (b, num_clips, t, c, h, w)
            image = image_unfold.reshape(b * num_clips, self.T, c, h, w)
        else:
            b, num_clips, t, c, h, w = image.shape
            image = image.reshape(b * num_clips, t, c, h, w)

        b, n, f = text_features.shape
        text_features = text_features.unsqueeze(1).expand(b, 2 * num_clips * self.T, n, f).reshape(b * 2 * num_clips * self.T, n, f)

        if use_reverse:
            # Forward and frame-reversed clips share one encoder pass.
            image_batch = torch.cat((image, torch.flip(image, [1])), dim=0)
            video_features_batch = self.encode_video_with_text(image_batch, text_features)
            video_features = video_features_batch[:b * num_clips].reshape(b, num_clips, n, f)
            video_features_reverse = video_features_batch[b * num_clips:].reshape(b, num_clips, n, f)
        else:
            video_features = self.encode_video_with_text(image, text_features).reshape(b, num_clips, n, f)
            video_features_reverse = None

        # All unordered clip pairs (i < j), stacked along the batch axis.
        pairs = [(i, j) for i in range(num_clips) for j in range(i + 1, num_clips)]
        if not pairs:
            raise ValueError('at least two clips are required to form ordered pairs')
        first_batch = torch.cat([video_features[:, i] for i, _ in pairs], dim=0)
        second_batch = torch.cat([video_features[:, j] for _, j in pairs], dim=0)
        # b * n(n-1)/2, num_text, 2 * f
        video_feature_batch_inorder = torch.cat((first_batch, second_batch), dim=-1)
        video_feature_batch_inverse = torch.cat((second_batch, first_batch), dim=-1)

        inorder_batch_correct_text = video_feature_batch_inorder[:, 0]
        inverse_batch_correct_text = video_feature_batch_inverse[:, 0]

        # Negative pool for the single-text case; swapped-order pairs are
        # always available, reversed-clip pairs only with use_reverse.
        negative_pool = [inverse_batch_correct_text]
        if use_reverse:
            rev_first_batch = torch.cat([video_features_reverse[:, i] for i, _ in pairs], dim=0)
            rev_second_batch = torch.cat([video_features_reverse[:, j] for _, j in pairs], dim=0)
            negative_pool.append(torch.cat((rev_first_batch, rev_second_batch), dim=-1)[:, 0])
            negative_pool.append(torch.cat((rev_second_batch, rev_first_batch), dim=-1)[:, 0])
            if mix_order:
                negative_pool.append(torch.cat((first_batch, rev_second_batch), dim=-1)[:, 0])
                negative_pool.append(torch.cat((rev_first_batch, second_batch), dim=-1)[:, 0])

        num_inorder = inorder_batch_correct_text.shape[0]
        if neg_num == 1:
            neg_features = torch.cat(negative_pool, dim=0)
            # Draw as many negatives as there are in-order positives.
            sampled_neg_features = neg_features[torch.randperm(neg_features.size(0))[:num_inorder]]
            cat_logits_all = self.order_mlp(torch.cat((inorder_batch_correct_text, sampled_neg_features), dim=0))
            logits_order_inorder = cat_logits_all[:num_inorder]
            logits_order_inverse = cat_logits_all[num_inorder:]
            # Swap columns 0 and 1 for the negative rows.
            logits_order_inverse = logits_order_inverse.index_select(1, torch.tensor([1, 0, 2]).to(logits_order_inverse.device))
            logits_order = torch.cat((logits_order_inorder, logits_order_inverse), dim=0)
            return None, logits_order

        inorder_batch_wrong_text = video_feature_batch_inorder[:, 1:].reshape(-1, video_feature_batch_inorder.shape[-1])
        inverse_batch_wrong_text = video_feature_batch_inverse[:, 1:].reshape(-1, video_feature_batch_inverse.shape[-1])
        batch_wrong_text = torch.cat((inorder_batch_wrong_text, inverse_batch_wrong_text), dim=0)

        cat_features_all = torch.cat((inorder_batch_correct_text, inverse_batch_correct_text, batch_wrong_text), dim=0)
        cat_logits_all = self.order_mlp(cat_features_all)
        logits_order_inorder = cat_logits_all[:num_inorder]
        logits_order_inverse = cat_logits_all[num_inorder:2 * num_inorder]
        logits_order_wrong = cat_logits_all[2 * num_inorder:]

        logits_order_inverse = logits_order_inverse.index_select(1, torch.tensor([1, 0, 2]).to(logits_order_inverse.device))
        logits_order = torch.cat((logits_order_inorder, logits_order_inverse), dim=0)

        return logits_order_wrong, logits_order
    
    def forward(self, image, text, label_id=None, use_order=False, num_clips=4, label_id_neg=None, progress=None, one_for_class=False, use_reverse=True, mix_order=False, show_bin=False, no_neg=False):
        """Compute the classification + progress-regression training losses.

        Args:
            image: pre-cut clips of shape (b, num_clips, t, c, h, w).
            text: tokenized texts, encoded via the cache or ``encode_text``.
            label_id_neg: (b, neg_num) indices into the encoded texts;
                required. Column 0 is treated as the matching text.
            progress: (b, num_clips) progress value of every clip; required.
            show_bin: in the single-text case, also return per-interval
                error metrics.
            no_neg: with binned regression, compute the regression loss on
                the forward pairs only.
            label_id, use_order, num_clips, one_for_class, use_reverse,
            mix_order: unused (``num_clips`` is overwritten from the input
                shapes); kept for interface compatibility.

        Returns:
            ``(combined_loss, classification_loss, regression_loss,
            logits_order)``; in the single-text case with ``show_bin`` a
            fifth element ``bin_metrics`` (dict of floats) is appended.

        Raises:
            NotImplementedError: if ``label_id_neg`` is None or the model is
                in ``two_head`` mode.
            ValueError: if ``image`` has fewer than 6 dimensions, if
                ``progress`` is None, or if fewer than two clips are given
                in pairwise mode.
        """
        if self.use_cache:
            text_features = self.cache_text(text)
        else:
            text_features = self.encode_text(text)

        if label_id_neg is None or self.two_head:
            raise NotImplementedError('label_id_neg is required and two_head model is not supported')
        b, neg_num = label_id_neg.shape
        text_features = text_features[label_id_neg.view(-1)].view(b, neg_num, -1)

        if len(image.shape) < 6:
            raise ValueError('lenth of image shape should be 6, current shape:', image.shape)
        if progress is None:
            raise ValueError('progress is required to build the regression targets')

        b, num_clips, t, c, h, w = image.shape
        image = image.reshape(b * num_clips, t, c, h, w)
        # Forward and frame-reversed clips share one encoder pass.
        image_batch = torch.cat((image, torch.flip(image, [1])), dim=0)

        b, num_clips = progress.shape
        b, n, f = text_features.shape
        text_features = text_features.unsqueeze(1).expand(b, 2 * num_clips * self.T, n, f).reshape(2 * b * num_clips * self.T, n, f)

        video_features_batch = self.encode_video_with_text(image_batch, text_features)
        video_features = video_features_batch[:b * num_clips].reshape(b, num_clips, n, f)
        video_features_reverse = video_features_batch[b * num_clips:].reshape(b, num_clips, n, f)

        if self.progressor:
            # Three-clip mode: regress where clip 1 lies between clips 0 and 2.
            # NOTE(review): yields inf/nan if progress[:, 2] == progress[:, 0].
            progress_interval = (progress[:, 1] - progress[:, 0]) / (progress[:, 2] - progress[:, 0])
            video_features_batch = torch.cat(
                (video_features[:, 0], video_features[:, 1], video_features[:, 2]), dim=-1
            ).squeeze(1)
            logits_order = self.order_mlp(video_features_batch)
            regression_loss = nn.MSELoss()(logits_order[:, 2], progress_interval)
            classification_loss = torch.zeros_like(regression_loss)
            combined_loss = regression_loss
            return combined_loss, classification_loss, regression_loss, logits_order

        if self.oneframe:
            # Single-clip mode: regress the absolute progress of clip 1.
            regression_labels = progress[:, 1]
            logits_order = self.order_mlp(video_features[:, 1])
            if not self.use_bin:
                regression_loss = nn.MSELoss()(logits_order[:, 0, 2], regression_labels)
            else:
                regression_labels_bin = self.discrete_support.scalar_to_vector(regression_labels)
                regression_loss = nn.CrossEntropyLoss()(logits_order[:, 0, 2:], regression_labels_bin)
            classification_loss = torch.zeros_like(regression_loss)
            combined_loss = regression_loss
            return combined_loss, classification_loss, regression_loss, logits_order

        # Pairwise mode: all clip pairs (i < j), stacked along the batch axis.
        pairs = [(i, j) for i in range(num_clips) for j in range(i + 1, num_clips)]
        if not pairs:
            raise ValueError('at least two clips are required to form ordered pairs')
        first_batch = torch.cat([video_features[:, i] for i, _ in pairs], dim=0)
        second_batch = torch.cat([video_features[:, j] for _, j in pairs], dim=0)
        rev_first_batch = torch.cat([video_features_reverse[:, i] for i, _ in pairs], dim=0)
        rev_second_batch = torch.cat([video_features_reverse[:, j] for _, j in pairs], dim=0)
        progress_label = torch.cat([progress[:, j] - progress[:, i] for i, j in pairs], dim=0)

        # b * n(n-1)/2, num_text, 2 * f
        video_feature_batch_inorder = torch.cat((first_batch, second_batch), dim=-1)
        # The "inverse" sample swaps the clip order AND uses the
        # frame-reversed clips; its progress target is the negated interval.
        video_feature_batch_inverse = torch.cat((rev_second_batch, rev_first_batch), dim=-1)

        inorder_batch_correct_text = video_feature_batch_inorder[:, 0]
        inverse_batch_correct_text = video_feature_batch_inverse[:, 0]

        if neg_num == 1:
            cat_features_all = torch.cat((inorder_batch_correct_text, inverse_batch_correct_text), dim=0)
            logits_order = self.order_mlp(cat_features_all)

            # Classification loss: every row carries label 0.
            classification_labels = torch.zeros(cat_features_all.shape[0]).long().to(logits_order.device)
            classification_loss = nn.CrossEntropyLoss()(logits_order[:, :2], classification_labels)

            # Regression loss: column 2 (scalar) or columns 2.. (bins).
            if not self.use_bin:
                regression_logits = logits_order[:, 2]
            else:
                regression_logits = logits_order[:, 2:]
            regression_labels = torch.cat((progress_label, -progress_label), dim=0).to(regression_logits.device)
            if self.rank2reward:
                regression_labels = (regression_labels > 0).float()

            if self.use_bin:
                regression_labels_bin = self.discrete_support.scalar_to_vector(regression_labels)
                if no_neg:
                    num_forward = len(progress_label)
                    regression_loss = nn.CrossEntropyLoss()(regression_logits[:num_forward], regression_labels_bin[:num_forward])
                else:
                    regression_loss = nn.CrossEntropyLoss()(regression_logits, regression_labels_bin)
            elif self.rank2reward:
                # Binary "is later" target; must not be overwritten by MSE.
                regression_loss = nn.BCEWithLogitsLoss()(regression_logits, regression_labels)
            else:
                regression_loss = nn.MSELoss()(regression_logits, regression_labels)

            combined_loss = classification_loss + 0.5 * regression_loss

            if not show_bin:
                return combined_loss, classification_loss, regression_loss, logits_order

            with torch.no_grad():
                intervals = [
                    (-1.0, -0.5),
                    (-0.5, -0.25),
                    (-0.25, -0.125),
                    (-0.125, -0.0625),
                    (-0.0625, 0),
                    (0, 0.0625),
                    (0.0625, 0.125),
                    (0.125, 0.25),
                    (0.25, 0.5),
                    (0.5, 1.0),
                ]
                bin_metrics = {}
                r_logits = regression_logits.detach()
                if self.use_bin:
                    r_logits = self.discrete_support.vector_to_scalar(r_logits)
                r_labels = regression_labels.detach()
                for low, high in intervals:
                    # Half-open interval [low, high) on the label value.
                    mask = (r_labels >= low) & (r_labels < high)
                    if not mask.any():
                        continue
                    selected_logits = r_logits[mask]
                    selected_labels = r_labels[mask]
                    sign_error_bin = ((selected_logits * selected_labels) < 0).float().mean()
                    mae_bin = torch.mean(torch.abs(selected_logits - selected_labels))
                    interval_str = f"[{low},{high}]"
                    bin_metrics[f'bin_{interval_str}_sign_error'] = sign_error_bin.item()
                    bin_metrics[f'bin_{interval_str}_mae'] = mae_bin.item()

            return combined_loss, classification_loss, regression_loss, logits_order, bin_metrics

        # Multi-text mode: the remaining text columns act as class-1 samples.
        inorder_batch_wrong_text = video_feature_batch_inorder[:, 1:].reshape(-1, video_feature_batch_inorder.shape[-1])
        inverse_batch_wrong_text = video_feature_batch_inverse[:, 1:].reshape(-1, video_feature_batch_inverse.shape[-1])
        batch_wrong_text = torch.cat((inorder_batch_wrong_text, inverse_batch_wrong_text), dim=0)

        cat_features_all = torch.cat((inorder_batch_correct_text, inverse_batch_correct_text, batch_wrong_text), dim=0)
        logits_order = self.order_mlp(cat_features_all)

        # Classification loss
        num_correct = inorder_batch_correct_text.shape[0]
        classification_labels = torch.cat((
            torch.zeros(num_correct).long(),
            torch.zeros(inverse_batch_correct_text.shape[0]).long(),
            torch.ones(batch_wrong_text.shape[0]).long()
        )).to(logits_order.device)
        classification_loss = nn.CrossEntropyLoss()(logits_order[:, :2], classification_labels)

        # Regression loss on the correct-text rows only.
        # NOTE(review): this path always regresses column 2 with MSE, even
        # when self.use_bin is True — confirm whether bins were intended.
        regression_logits = logits_order[:2 * num_correct, 2]
        regression_labels = torch.cat((progress_label, -progress_label), dim=0).to(regression_logits.device)
        regression_loss = nn.MSELoss()(regression_logits, regression_labels)

        # Combined loss
        combined_loss = classification_loss + 0.5 * regression_loss

        return combined_loss, classification_loss, regression_loss, logits_order

    def summary(self):
        """Print a header followed by the name of every trainable parameter."""
        print('Trainable params:')
        trainable_names = (
            param_name
            for param_name, tensor in self.named_parameters()
            if tensor.requires_grad
        )
        for param_name in trainable_names:
            print(param_name)

def build_model(state_dict: dict, T=8, droppath=0., use_checkpoint=False, logger=None, prompts_alpha=1e-1, prompts_layers=2, 
    use_cache=True, mit_layers=4, train_order=True, train_class=False, two_head=False, mlp_droprate=0.2, lora=True):
    """Instantiate a ``TimeRewarder`` sized from a checkpoint and load it.

    The architecture hyper-parameters are read off the tensor shapes and key
    names in ``state_dict`` (unwrapped from its ``'model'`` entry when
    present). Weights are loaded non-strictly and the load report is logged
    via ``logger`` or printed.

    Args:
        state_dict: checkpoint mapping; the keys ``input_resolution``,
            ``context_length`` and ``vocab_size`` are removed from it in place.
        lora: when True, wrap ``model.visual`` with a PEFT LoRA adapter
            before loading the weights.
        logger: optional logger with an ``info`` method.
        Remaining keyword arguments are forwarded to ``TimeRewarder``.

    Returns:
        The model in eval mode.
    """
    if 'model' in state_dict:
        state_dict = state_dict['model']

    # ---- vision tower geometry ----------------------------------------------
    if "visual.proj" in state_dict:
        # ViT-style checkpoint.
        patch_conv = state_dict["visual.conv1.weight"]
        vision_width = patch_conv.shape[0]
        vision_layers = sum(
            1 for key in state_dict.keys()
            if key.startswith("visual.") and key.endswith(".attn.in_proj_weight")
        )
        vision_patch_size = patch_conv.shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        # ResNet-style checkpoint: count distinct block indices per stage.
        vision_layers = tuple(
            len({key.split(".")[2] for key in state_dict if key.startswith(f"visual.layer{stage}")})
            for stage in (1, 2, 3, 4)
        )
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        num_pool_positions = state_dict["visual.attnpool.positional_embedding"].shape[0]
        output_width = round((num_pool_positions - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == num_pool_positions
        image_resolution = output_width * 32

    # ---- text tower geometry --------------------------------------------------
    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    resblock_ids = {key.split(".")[2] for key in state_dict if key.startswith("transformer.resblocks")}
    transformer_layers = len(resblock_ids)

    model = TimeRewarder(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
        T=T, droppath=droppath, mlp_droprate=mlp_droprate, mit_layers=mit_layers,
        prompts_alpha=prompts_alpha, prompts_layers=prompts_layers,
        use_checkpoint=use_checkpoint, use_cache=use_cache,
        train_order=train_order, train_class=train_class,
        two_head=two_head,
    )

    # Scalar metadata entries are not module parameters; drop them if present.
    for metadata_key in ("input_resolution", "context_length", "vocab_size"):
        state_dict.pop(metadata_key, None)

    if lora:
        lora_config = peft.LoraConfig(
            r=8,
            lora_alpha=8,
            lora_dropout=0.0,
            target_modules=['in_proj_weight', "in_proj_bias", 'out_proj'],
        )
        model.visual = peft.get_peft_model(model.visual, lora_config)

        print('LORA Trainable params:')
        model.visual.print_trainable_parameters()

    msg = model.load_state_dict(state_dict, strict=False)
    report = f"load pretrained CLIP: {msg}"
    if logger is None:
        print(report)
    else:
        logger.info(report)

    return model.eval()


def load(model_path, name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 
         jit=True, T=8, droppath=0., use_checkpoint=False, logger=None, use_cache=True, prompts_alpha=1e-1, prompts_layers=2, mit_layers=1,
         train_order=True, train_class=False, two_head=False, mlp_droprate=0.2
):
    """Load pretrained weights from disk and build a model from them.

    The file at ``model_path`` is first opened as a TorchScript (JIT)
    archive. If that raises ``RuntimeError`` it is re-read with
    ``torch.load`` as a plain state dict. Either way the resulting weights
    are handed to ``build_model`` together with the architecture options.

    Args:
        model_path: Location of the checkpoint. When ``None`` the path is
            obtained from ``clip._download(clip._MODELS[name])``.
        name: Key into ``clip._MODELS``; only consulted when ``model_path``
            is ``None``.
        device: Used as ``map_location`` for the JIT archive when ``jit`` is
            true, and to decide whether the built model is cast to float32
            (only when ``str(device) == "cpu"``).
        jit: Whether the checkpoint is expected to be a JIT archive. A
            warning is issued if it is true but the file is not an archive.
        logger: Forwarded to ``build_model``.
        T, droppath, use_checkpoint, use_cache, prompts_alpha,
        prompts_layers, mit_layers, train_order, train_class, two_head,
        mlp_droprate: Forwarded unchanged to ``build_model``.

    Returns:
        A ``(model, state_dict)`` tuple where ``state_dict`` is
        ``model.state_dict()`` of the freshly built model.
    """
    if model_path is None:
        model_path = clip._download(clip._MODELS[name])

    try:
        # First attempt: treat the file as a TorchScript archive.
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # Fallback: the file is a regular pickled state dict.
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        state_dict = torch.load(model_path, map_location="cpu")

    # LoRA is switched on purely by the checkpoint path containing "lora".
    # The flag is always bound, and str() lets path-like objects (e.g.
    # pathlib.Path) be inspected instead of raising TypeError on ``in``.
    use_lora = 'lora' in str(model_path)

    model = build_model(state_dict or model.state_dict(), T=T, droppath=droppath, 
                        use_checkpoint=use_checkpoint, logger=logger,
                        prompts_alpha=prompts_alpha, 
                        prompts_layers=prompts_layers,
                        use_cache=use_cache,
                        mlp_droprate=mlp_droprate,
                        mit_layers=mit_layers,
                        train_order=train_order, 
                        train_class=train_class,
                        two_head=two_head,
                        lora=use_lora,
                        )

    if str(device) == "cpu":
        model.float()
    return model, model.state_dict()

class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention from ``x`` onto ``context``.

    Queries are projected from ``x`` and keys/values from ``context`` (or
    from ``x`` itself when no context is supplied). All three projections
    are bias-free and map to ``heads * dim_head`` features; the attended
    result is projected back to ``query_dim`` and passed through dropout.
    """

    def __init__(self, query_dim=512, context_dim=768, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        # With no separate context width, keys/values share the query width.
        context_dim = query_dim if context_dim is None else context_dim
        hidden_dim = dim_head * heads

        self.heads = heads
        self.scale = dim_head ** -0.5

        # Submodule names and creation order are kept stable so that
        # checkpoints and seeded initialisation stay unchanged.
        self.to_q = nn.Linear(query_dim, hidden_dim, bias=False)
        self.to_k = nn.Linear(context_dim, hidden_dim, bias=False)
        self.to_v = nn.Linear(context_dim, hidden_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(hidden_dim, query_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` and return a tensor shaped like ``x``.

        Args:
            x: Query input of shape ``(batch, n, query_dim)``.
            context: Key/value input of shape ``(batch, m, context_dim)``;
                ``x`` is used when this is ``None``.
            mask: Optional mask with a leading batch axis; remaining axes
                are flattened to the key axis. Positions where the mask is
                false are excluded from attention.
        """
        n_heads = self.heads
        source = x if context is None else context

        # Project, then fold the head axis into the batch axis.
        split_heads = 'b n (h d) -> (b h) n d'
        queries = rearrange(self.to_q(x), split_heads, h=n_heads)
        keys = rearrange(self.to_k(source), split_heads, h=n_heads)
        values = rearrange(self.to_v(source), split_heads, h=n_heads)

        scores = einsum('b i d, b j d -> b i j', queries, keys) * self.scale

        if mask is not None:
            flat_mask = rearrange(mask, 'b ... -> b (...)')
            # One copy of the mask per head, broadcast over query positions.
            head_mask = repeat(flat_mask, 'b j -> (b h) () j', h=n_heads)
            fill_value = -torch.finfo(scores.dtype).max
            scores.masked_fill_(~head_mask, fill_value)

        weights = scores.softmax(dim=-1)

        attended = einsum('b i j, b j d -> b i d', weights, values)
        # Unfold the heads back into the feature axis.
        merged = rearrange(attended, '(b h) n d -> b n (h d)', h=n_heads)
        return self.to_out(merged)


class VisionTransformer(nn.Module):
    """Patch-based image encoder returning the class feature and patch tokens.

    The image is cut into patches by a strided convolution, a learned class
    token is prepended, positional embeddings are added, and the sequence is
    run through a ``Transformer``. The class token is layer-normalised and
    projected to ``output_dim``; the patch tokens are returned as they leave
    the transformer.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim

        # Non-overlapping patch embedding: kernel and stride equal the patch size.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        init_scale = width ** -0.5
        # One position per patch plus one for the class token.
        num_positions = (input_resolution // patch_size) ** 2 + 1
        self.class_embedding = nn.Parameter(init_scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(init_scale * torch.randn((num_positions, width)))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(init_scale * torch.randn(width, output_dim))
        # NOTE(review): never read inside this class; presumably a hook for
        # external wrappers that expect a ``config`` attribute - confirm.
        self.config = None

    def forward(self, x: torch.Tensor):
        """Encode a batch of images.

        Returns:
            A tuple ``(cls_feature, patch_tokens)``: the layer-normalised
            (and, when ``proj`` is set, projected) class token, and the
            remaining tokens of the transformer output without ``ln_post``.
        """
        patches = self.conv1(x)  # [*, width, grid, grid]
        tokens = patches.reshape(patches.shape[0], patches.shape[1], -1)  # [*, width, grid ** 2]
        tokens = tokens.permute(0, 2, 1)  # [*, grid ** 2, width]

        # Broadcast the class embedding to one token per batch element.
        cls_slot = torch.zeros(tokens.shape[0], 1, tokens.shape[-1], dtype=tokens.dtype, device=tokens.device)
        cls_tokens = self.class_embedding.to(tokens.dtype) + cls_slot
        tokens = torch.cat([cls_tokens, tokens], dim=1)  # [*, grid ** 2 + 1, width]

        tokens = tokens + self.positional_embedding.to(tokens.dtype)
        tokens = self.ln_pre(tokens)

        # The transformer runs sequence-first, so swap to LND and back to NLD.
        tokens = tokens.permute(1, 0, 2)
        tokens = self.transformer(tokens)
        tokens = tokens.permute(1, 0, 2)

        cls_feature = self.ln_post(tokens[:, 0, :])
        if self.proj is not None:
            cls_feature = cls_feature @ self.proj

        patch_tokens = tokens[:, 1:, :]
        return cls_feature, patch_tokens