from ast import Pass
from .CLIP.clip import clip
from .prompting_module import PromptModule
from .temporal_module import Temporal_Module
from .temporalShiftModule.ops.utils import AverageMeter, accuracy
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
import random
from torchvision import transforms
from .temporalShiftModule.ops.transforms import *
import torch.nn.functional as F
import os
from copy import deepcopy
from scipy.optimize import linear_sum_assignment
import shap
from .text_template import template_dict

from ptflops import get_model_complexity_info


class DummyModel:
    """Minimal model whose prediction is the mean of each input row.

    An instance is handed to ``VideoFeatureSelection`` as the function that the
    SHAP explainer attributes: each row of ``x`` is one sample and its
    prediction is the mean over axis 1.
    """

    def predict(self, x):
        """Return the mean of ``x`` over axis / dimension 1.

        Args:
            x: a ``numpy.ndarray`` or a ``torch.Tensor``.

        Returns:
            The same kind of object as ``x``, reduced over axis 1.
        """
        # Imported explicitly: at module level ``np`` is only available through
        # a wildcard import, which is fragile.
        import numpy as np

        if isinstance(x, np.ndarray):
            return x.mean(1)
        return torch.mean(x, dim=1)


class VideoFeatureSelection:
    """Pick the most important frames of each video using SHAP attributions.

    A ``shap.KernelExplainer`` is run on ``model.predict``. Each frame of a video
    is treated as one sample and each feature dimension as one input feature.
    Per-feature attributions are summed into one score per frame. The ``topk``
    highest-scoring frames are kept and averaged into a context embedding.
    """

    def __init__(self, model, topk=3, device=None):
        """
        Args:
            model: object exposing ``predict(x)`` on a 2-D numpy array.
            topk: number of frames kept per video.
            device: stored on the instance; not used by the methods below.
        """
        self.device = device
        self.mean_model = model
        self.topk = topk

    def compute_shap_values(self, videos_features, background=None):
        """Compute SHAP values for every frame of every video.

        Args:
            videos_features: tensor of shape ``(B, N, D)``.
            background: background data for the explainer; defaults to one
                all-zero row of width ``D`` (float32).

        Returns:
            numpy array of shape ``(B, N, D)``, or ``(B, N, 1)`` when the
            explainer returns a single value per frame.
        """
        import numpy as np

        batch_size, _, dim = videos_features.shape
        if background is None:
            background = np.zeros((1, dim), dtype=np.float32)
        explainer = shap.KernelExplainer(self.mean_model.predict, background)
        features_np = videos_features.detach().cpu().numpy()

        per_video = []
        for i in range(batch_size):
            shap_val = np.array(explainer.shap_values(features_np[i]))
            if shap_val.ndim == 1:
                shap_val = shap_val[:, np.newaxis]
            per_video.append(shap_val)
        return np.stack(per_video, axis=0)

    def select_topk_frames(self, videos_features, shap_values=None):
        """Return the ``topk`` frames with the highest SHAP importance.

        Args:
            videos_features: tensor of shape ``(B, N, D)``.
            shap_values: either per-frame scores of shape ``(B, N)``, or
                per-feature attributions of shape ``(B, N, D')`` as returned by
                ``compute_shap_values``. A numpy array or a tensor is accepted.
                When ``None``, the values are computed with
                ``compute_shap_values``.

        Returns:
            tensor of shape ``(B, min(topk, N), D)``, with frames ordered by
            decreasing importance.
        """
        if shap_values is None:
            shap_values = self.compute_shap_values(videos_features)
        # compute_shap_values returns numpy; argsort and indexing need a tensor.
        scores = torch.as_tensor(shap_values, device=videos_features.device)
        if scores.dim() == 3:
            # Per-feature attributions: a frame's score is the sum of its
            # feature attributions.
            scores = scores.sum(dim=-1)
        topk_idx = torch.argsort(scores, dim=1, descending=True)[:, :self.topk]
        batch_idx = torch.arange(videos_features.size(0), device=videos_features.device).unsqueeze(1)
        return videos_features[batch_idx, topk_idx]

    def get_final_context_embedding(self, videos_features):
        """Average the SHAP-selected frames of each video.

        Args:
            videos_features: tensor of shape ``(B, N, D)``.

        Returns:
            tensor of shape ``(B, D)``.
        """
        shap_values = self.compute_shap_values(videos_features)
        selected_frames = self.select_topk_frames(videos_features, shap_values)
        return torch.mean(selected_frames, dim=1)

class BaselineModel(nn.Module):
    
    def __init__(self, device, conf, experiment):
        """Build the CLIP-based video model and read its configuration.

        Loads the CLIP model and freezes it. Adapter parameters are left
        trainable when enabled. Optionally creates the prompt module and the
        temporal module, and reads the remaining options from ``conf``.

        Args:
            device: device passed to ``clip.load`` and to the created submodules.
            conf: nested configuration dict. Sections read here:
                ``feature_encoder``, ``modulation_module``, ``Adapter``, ``shap``,
                ``MoE``, ``pos_knowledge_preserved``, ``dataset``, ``imp_reg``,
                plus several top-level keys.
            experiment: stored as ``self.experiment``; not used in this method.
        """
        super(BaselineModel, self).__init__()
    
        self.conf = conf
        self.experiment = experiment 
        self.device = device
        type_clip_model = conf['feature_encoder']['type_clip_model'] 
        type_mod = conf['modulation_module']['type_mod'] 
        self.type_mod = type_mod
        self.num_segments = conf['feature_encoder']['num_segments']  
        self.use_text_template = conf['use_text_template']
        self.use_mean_text_template = conf['use_mean_text_template']
        self.template_dict = template_dict["imagenet_R_templates"]
        # Temporal-prompt options.
        self.is_tem_prompts = conf['feature_encoder']['is_tem_prompts']
        self.is_auxiliary_training = conf['feature_encoder']['is_auxiliary_training']

        # Read but not used below; the lookup still requires the key to exist.
        by_instance = conf['modulation_module']['mod_by_instance']
        
        L_tx_gn = conf['modulation_module']['prompt_module']['length_tex_gn'] 

        # Adapter configuration (the whole section is also forwarded to clip.load).
        self.Adapter = conf['Adapter']
        self.is_use_adapter = conf['Adapter']['is_use_adapter']
        self.is_only_use_adapter = conf['Adapter']['is_only_use_adapter']
        self.is_cosin_rounter = conf['Adapter']['is_cosin_rounter']

        self.clip_model, _ = clip.load(type_clip_model, device=device, type_mod = type_mod, L_tx_gn = L_tx_gn, is_use_adapter=self.is_use_adapter, config=self.Adapter)
        self.enable_temporal_module = conf['feature_encoder']['enable_temporal_module'] 
        

        # Freeze the whole CLIP model; only the per-block adapters (if enabled)
        # are switched back to trainable.
        for param in self.clip_model.parameters(): 
            param.requires_grad = False
        if self.is_use_adapter:
            for block in self.clip_model.visual.transformer.resblocks:
                block.adapter.train()
                for param in block.adapter.parameters():
                    param.requires_grad = True
        
        width_vid = self.clip_model.vision_width  
        width_txt = self.clip_model.transformer.width 
        width_temp = conf['feature_encoder']['dim_model']  
        self.width_temp = width_temp 
        self.crop_size = self.clip_model.visual.input_resolution  
        # Resize target before cropping, scaled as 256/224 of the crop size.
        self.scale_size = self.crop_size * 256 // 224 

        # Per-channel normalisation constants (the values used by CLIP preprocessing).
        self.input_mean = [0.48145466, 0.4578275, 0.40821073]
        self.input_std = [0.26862954, 0.26130258, 0.27577711]

        # Grown by add_num_classes as new tasks arrive.
        self.num_classes = 0
        # class_id -> list of stored videos (filled by add_samples_to_mem).
        self.memory = {}
        self.type_cls = conf['type_loss']
        self.list_val_acc_ii = {'val': [], 'test': []}  
        self.num_training_phases = conf['num_training_phases'] 

        self.prompt_module = None
        self.pre_pro_train_mode = False
        self.curr_pro_train_mode = False
        self.type_task = conf['type_task']
        self.training_phase_task_selector = conf['training_phase_task_selector'] 
        self.teacher_forcing = conf['modulation_module']['teacher_forcing']
        self.weights = conf['weights']
        self.val_weights = conf['val_weights']
        # NOTE: self.type_prompt is only defined in 'Prompt' mode; other methods
        # check type_mod == 'Prompt' first before reading it.
        if type_mod == 'Prompt':
            conf_prompt_module = conf['modulation_module']['prompt_module']
            self.pre_pro_train_mode = conf_prompt_module['pre_pro_train_mode']
            self.curr_pro_train_mode = conf_prompt_module['curr_pro_train_mode']
            self.type_prompt = conf_prompt_module['type_prompt']
            self.prompt_module = PromptModule(conf_prompt_module, width_vid, width_temp, width_txt, self.type_prompt, self.num_segments, self.type_task, device)

        self.classes = None
        self.task_id = 0

        self.is_distill = conf['pos_knowledge_preserved']['is_distill']

        self.is_pos_distill = conf['pos_knowledge_preserved']['is_pos_distill']
        self.old_temporal_module = None
        if self.enable_temporal_module:
            if self.is_auxiliary_training:
                # Depends on self.task_id, which must already be set above.
                tem_prompts = self.encode_Tem_Prompts()
                self.temporal_module = Temporal_Module(device, conf['feature_encoder'], tem_prompts=tem_prompts, task_id=self.task_id)
            else:
                self.temporal_module = Temporal_Module(device, conf['feature_encoder']) 
            self.temporal_module = self.temporal_module.to(self.device)

            pytorch_total_params = sum(p.numel() for p in self.temporal_module.parameters() if p.requires_grad)
            print('num params of temp_module: ',pytorch_total_params)

        # Frame-selection options used when no temporal module is enabled.
        self.is_use_shape = conf['shap']['is_use_shape']
        self.num_valid_frames = conf['feature_encoder']['num_valid_frames']
        self.is_select = conf['shap']['is_select']
        self.is_weight_mean = conf['shap']['is_weight_mean']

        self.is_use_MoE = conf['MoE']['is_use_MoE']
        # Running parameter average of past temporal modules (see update_average);
        # created even when the temporal module is disabled.
        self.avg_model = Temporal_Module(device, conf['feature_encoder'])  
        self.avg_model = self.avg_model.to(self.device)
        self.task_count = 1
        self.conf_feature_encoder = conf['feature_encoder']
        self.is_save_model = conf['MoE']['is_save_model']
        # Per-task snapshots of the temporal module (used when is_save_model).
        self.all_model = []
        self.diff_all_model = []
        self.is_use_diff_feat = conf['MoE']['is_use_diff_feat']
        self.diff_feat_module = None
        self.mean_diff_feat = []
       
        self.setting = conf['dataset']
        self.num_tasks = self.setting['num_tasks']

        self.imp_reg = conf['imp_reg']
        self.is_use_imp_reg = conf['imp_reg']['is_use_imp_reg']
        self.class_stats = {}
        self.old_backbone = None
        self.weight_distill = None
        self.eps = 1e-8


    def weighted_feature_sum(self, features_tensor):
        """Pool frame features using their similarity to the per-video mean frame.

        Each frame's weight is its cosine similarity (clamped to [-1, 1]) to the
        mean frame of its video, divided by the sum of those similarities.

        Args:
            features_tensor: tensor indexed as ``(batch, frames, dim)``.

        Returns:
            tensor of shape ``(batch, dim)``.
        """
        centroid = features_tensor.mean(dim=1)
        unit_frames = F.normalize(features_tensor, p=2, dim=2)
        unit_centroid = F.normalize(centroid, p=2, dim=1)
        # (B, N, D) x (B, D, 1) -> (B, N): similarity of each frame to its video mean.
        similarity = torch.bmm(unit_frames, unit_centroid.unsqueeze(2)).squeeze(2)
        similarity = similarity.clamp(-1.0, 1.0)
        frame_weights = similarity / similarity.sum(dim=1, keepdim=True)
        return (features_tensor * frame_weights.unsqueeze(2)).sum(dim=1)
            

    def update_average(self):
        """Retire the current temporal module and start a fresh one for the next task.

        With ``is_save_model``, a deep copy of ``temporal_module`` is appended to
        ``all_model``. Otherwise each ``avg_model`` parameter becomes the running
        mean ``((task_count - 1) * avg + new) / task_count``. In both cases
        ``task_count`` is incremented and ``temporal_module`` is replaced by a
        newly constructed ``Temporal_Module``.
        """
        if self.is_save_model:
            snapshot = deepcopy(self.temporal_module).to(self.device)
            self.all_model.append(snapshot)
        else:
            n = self.task_count
            for running, latest in zip(self.avg_model.parameters(), self.temporal_module.parameters()):
                running.data = (n - 1) * running.data + latest.data
                running.data /= n
        self.task_count += 1
        self.temporal_module = Temporal_Module(self.device, self.conf_feature_encoder).to(self.device)

    def compute_video_feature(self, frames, top_k=4):
        """Pool frame features, weighting each frame by its similarity to the others.

        A frame's importance is its mean cosine similarity to the *other*
        frames of the same video. The softmax of these importances over frames
        weights the raw (unnormalised) frame features.

        Args:
            frames: tensor indexed as ``(batch, num_frames, dim)``. Any
                ``num_frames >= 1`` is accepted (previously hard-coded to 8).
            top_k: unused; kept so existing callers that pass it keep working.

        Returns:
            tensor of shape ``(batch, dim)``.
        """
        num_frames = frames.shape[1]
        frames_normalized = F.normalize(frames, p=2, dim=2)
        similarity_matrix = torch.matmul(frames_normalized, frames_normalized.transpose(1, 2))
        # Zero the diagonal so a frame's similarity with itself is ignored.
        mask = 1.0 - torch.eye(num_frames, device=frames.device, dtype=frames.dtype).unsqueeze(0)
        # Average over the other frames; the max() guards the single-frame 0/0 case.
        importance_scores = (similarity_matrix * mask).sum(dim=2) / max(num_frames - 1, 1)
        weights = F.softmax(importance_scores, dim=1)
        return torch.sum(weights.unsqueeze(2) * frames, dim=1)
        
    def select_topk_frames(self, video_feat, text_feat=None, topk=3):
        """Keep the ``topk`` frames most similar to a reference feature.

        The reference is ``text_feat`` when given, otherwise the per-video mean
        frame. Frames are returned in decreasing order of cosine similarity.

        Args:
            video_feat: 3-D tensor indexed as ``(batch, frames, dim)``.
            text_feat: optional tensor indexed as ``(batch, dim)``.
            topk: number of frames to keep.

        Returns:
            tensor of shape ``(batch, topk, dim)``.
        """
        _, _, feat_dim = video_feat.shape
        reference = video_feat.mean(dim=1) if text_feat is None else text_feat
        scores = F.cosine_similarity(video_feat, reference.unsqueeze(1), dim=-1)
        order = scores.topk(topk, largest=True).indices
        gather_idx = order.unsqueeze(-1).expand(-1, -1, feat_dim)
        return video_feat.gather(1, gather_idx)
    
    def ContrastLoss(self, labels, logits_per_image, logits_per_text):
        """Symmetric supervised contrastive loss over image/text logits.

        Entries whose labels are equal count as positives. For each row the
        loss is ``-log((sum_pos exp(l) + eps) / (sum_all exp(l) + eps))``,
        averaged over rows and then over the two directions.

        Args:
            labels: 1-D tensor of labels, one per row/column.
            logits_per_image: square logits, image -> text.
            logits_per_text: square logits, text -> image.

        Returns:
            scalar tensor.
        """
        eps = 1e-8
        label_mask = torch.eq(labels.unsqueeze(1), labels.unsqueeze(0))

        def direction_loss(logits, mask):
            # Subtract the detached row max before exp for numerical stability.
            shifted = logits - torch.max(logits, dim=1, keepdim=True)[0].detach()
            exp_logits = torch.exp(shifted)
            pos = torch.sum(exp_logits * mask, dim=1)
            total = torch.sum(exp_logits, dim=1)
            return -torch.log((pos + eps) / (total + eps)).mean()

        # Previously the image->text exponentials (``i2t_exp``) were never
        # computed, so this method always raised NameError.
        loss_i2t = direction_loss(logits_per_image, label_mask)
        loss_t2i = direction_loss(logits_per_text, label_mask.T)
        return (loss_i2t + loss_t2i) / 2


    def HungarianLoss(self, pred_logits, groundtruth=None):
        """Cross-entropy loss and accuracy for per-slot position prediction.

        Despite the name, no Hungarian matching is done. Slot ``j`` of every
        sample is supervised with target ``groundtruth[j]``, or ``j`` when no
        ground truth is given.

        Args:
            pred_logits: tensor of shape ``(batch, num, num_positions)``.
            groundtruth: optional tensor expandable to ``(batch, num)`` holding
                target positions; defaults to ``arange(num)``.

        Returns:
            ``(loss, accuracy)``. ``loss`` is the mean cross-entropy divided by
            ``batch``; ``accuracy`` is the fraction of slots whose argmax
            matches the target.
        """
        batch_size, num, num_positions = pred_logits.shape
        # Targets must live on the same device as the logits for cross_entropy.
        device = pred_logits.device
        if groundtruth is not None:
            groundtruth = groundtruth.expand(batch_size, -1).to(device)
        else:
            groundtruth = torch.arange(num, device=device).expand(batch_size, -1)

        predicted_classes = torch.argmax(pred_logits, dim=-1)
        accuracy = (predicted_classes == groundtruth).sum().float() / groundtruth.numel()

        # Use the real number of positions (was hard-coded to 8). reshape also
        # handles the non-contiguous tensor produced by ``expand``, where
        # ``view`` fails.
        pos_loss = F.cross_entropy(
            pred_logits.reshape(-1, num_positions),
            groundtruth.reshape(-1),
        )
        return pos_loss / batch_size, accuracy
    
    
    def add_num_classes(self, num_next_classes):
        """Increase the running class count by ``num_next_classes``."""
        self.num_classes = self.num_classes + num_next_classes

    def prepare_for_next_classes(self, num_next_classes):
        """Register ``num_next_classes`` new classes and rebuild a linear head.

        Does nothing when ``num_next_classes`` is ``None``. With a ``'Linear'``
        classifier, ``create_fc`` builds a new head for the updated class count.
        """
        if num_next_classes is None:
            return
        self.add_num_classes(num_next_classes)
        if self.type_cls == 'Linear':
            self.create_fc()
            print('Classifier augmented')

    def get_augmentation(self, flip=True):
        """Build the training-time group augmentation pipeline.

        Always applies ``GroupMultiScaleCrop`` to ``self.crop_size`` with scales
        ``[1, .875, .75, .66]``. ``GroupRandomHorizontalFlip`` is appended when
        ``flip`` is truthy; otherwise a warning banner is printed instead.
        """
        if not flip:
            print('#' * 20, 'NO FLIP!!!')
        pipeline = [GroupMultiScaleCrop(self.crop_size, [1, .875, .75, .66])]
        if flip:
            pipeline.append(GroupRandomHorizontalFlip(is_flow=False))
        return transforms.Compose(pipeline)
      
    def add_samples_to_mem(self, data, m):
        """Merge ``data`` into the memory and cap each class at ``m`` videos.

        Entries of ``data`` override memory entries with the same class id.
        Every class list (old and new) is shuffled and, unless ``m`` is
        ``'ALL'``, truncated to its first ``m`` items. The per-class sizes are
        printed afterwards.

        Args:
            data: mapping ``class_id -> sequence of videos``. It is not
                modified (previously its lists were shuffled in place).
            m: per-class cap (int) or ``'ALL'`` to keep every video.
        """
        merged = {**self.memory, **data}
        for class_id, videos in merged.items():
            # Shuffle a copy so the caller's lists are left untouched; the
            # permutation drawn is the same as shuffling in place.
            videos = list(videos)
            random.shuffle(videos)
            merged[class_id] = videos if m == 'ALL' else videos[:m]
        self.memory = merged

        for class_id, videos in self.memory.items():
            print('Memory... Class: {}, num videos: {}'.format(class_id, len(videos)))

    def create_fc(self):
        """(Re)build the linear classification head with ``self.num_classes`` outputs.

        Only acts when ``type_cls == 'Linear'``. The input width is the CLIP
        model's ``output_dim``, and the new layer is moved to ``self.device``.
        """
        if self.type_cls != 'Linear':
            return
        head = nn.Linear(self.clip_model.output_dim, self.num_classes)
        self.classifier = head
        head.to(self.device)

    def create_init_key(self, classes):
        """Return the mean text embedding of the classes not yet assigned to a task.

        Classes already present in ``prompt_module.cls_to_task_id`` are skipped.
        The remaining names are encoded with ``modulate=False`` and averaged over
        dim 0, keeping that dimension.
        """
        known = self.prompt_module.cls_to_task_id
        unseen = [name for name in classes if name not in known]
        embeddings = self.encode_labels(unseen, modulate = False)
        return embeddings.mean(dim=0, keepdim=True)
        


    def prepare_trainining(self, classes, task_id, training_phase = 1, pre_pro_train_mode = False, curr_pro_train_mode = True):
        """Configure which prompt / temporal parameters are trainable for ``task_id``.

        Single-phase training:
            prepares the task prompts, applies the given train modes, and (with
            auxiliary training + temporal prompts) lets the position head and
            temporal prompts train only on task 0.
        Multi-phase training:
            in phases 1 and 3, task prompts get train modes ``(False, False)``
            and the temporal module is made trainable. In other phases, task
            prompts are prepared with the given modes and the temporal module
            is frozen.

        Spatial ('ViT') prompts are gated by ``L_sp_tk > 0`` and temporal
        ('temp') prompts by ``L_tp_tk > 0``. The single-phase path previously
        checked ``L_sp_tk`` for both.
        """
        # ``type_prompt`` only exists in 'Prompt' mode; ``and`` short-circuits.
        use_task_prompts = self.type_mod == 'Prompt' and self.type_prompt != 'general'

        def set_prompt_modes(pre_mode, curr_mode):
            # Each prompt kind is only touched when its length is positive.
            if self.prompt_module.L_sp_tk > 0:
                self.prompt_module.set_train_mode_task_prompts(task_id, 'ViT', pre_mode, curr_mode)
            if self.prompt_module.L_tp_tk > 0:
                self.prompt_module.set_train_mode_task_prompts(task_id, 'temp', pre_mode, curr_mode)

        def set_temporal_grad(flag):
            for param in self.temporal_module.parameters():
                param.requires_grad = flag

        if self.num_training_phases == 1:
            if use_task_prompts:
                self.prompt_module.prepare_task_prompt(classes, task_id)
                set_prompt_modes(pre_pro_train_mode, curr_pro_train_mode)
            if self.is_auxiliary_training and self.is_tem_prompts:
                # Position head and temporal prompts train on the first task only.
                trainable = bool(task_id == 0)
                for p in self.temporal_module.position_head.parameters():
                    p.requires_grad = trainable
                self.temporal_module.tem_prompts.requires_grad = trainable
        elif training_phase == 1 or training_phase == 3:
            if use_task_prompts:
                set_prompt_modes(False, False)
            if self.enable_temporal_module:
                set_temporal_grad(True)
        else:
            if use_task_prompts:
                self.prompt_module.prepare_task_prompt(classes, task_id)
                set_prompt_modes(pre_pro_train_mode, curr_pro_train_mode)
            if self.enable_temporal_module:
                set_temporal_grad(False)

    def get_optimizer(self, training_phase = 1):
        """Create ``self.optimizer`` (SGD with ``lr = conf['lr']``) over the trainable parts.

        Single-phase training collects:
            prompt-module parameters (Prompt mode), the linear classifier,
            temporal-module parameters with ``requires_grad`` and adapter
            parameters with ``requires_grad``.
        Multi-phase training collects:
            the temporal module + classifier in phases 1 and 3, and the prompt
            module + classifier otherwise.

        Args:
            training_phase: current phase; only read when
                ``num_training_phases != 1``.
        """
        print('here4 optimizer')
        params_out = []
        if self.num_training_phases == 1:
            print('here 1 phases optimizer')
            if self.type_mod == 'Prompt':
                params_out.extend(self.prompt_module.parameters())

            if self.type_cls == 'Linear':
                print('adding linear cls parameters')
                params_out.extend(self.classifier.parameters())

            if self.enable_temporal_module:
                print('adding temp parameters')
                for p in self.temporal_module.parameters():
                    if p.requires_grad:
                        params_out.append(p)
                    else:
                        print("p.requires_grad")
            if self.is_use_adapter:
                for block in self.clip_model.visual.transformer.resblocks:
                    for p in block.adapter.parameters():
                        if p.requires_grad:
                            params_out.append(p)
                        else:
                            print("p.requires_grad")
        else:
            print('here 2 phases optimizer')
            if training_phase == 1 or training_phase == 3:
                print('here first phase')
                if self.enable_temporal_module:
                    print('adding temp parameters')
                    params_out.extend(self.temporal_module.parameters())
                if self.type_cls == 'Linear':
                    print('adding linear cls parameters')
                    params_out.extend(self.classifier.parameters())
            else:
                print('here second phase')
                if self.type_mod == 'Prompt':
                    print('adding prompting parameters')
                    params_out.extend(self.prompt_module.parameters())
                if self.type_cls == 'Linear':
                    # These are classifier parameters (the log previously said 'temp').
                    print('adding linear cls parameters')
                    params_out.extend(self.classifier.parameters())

        self.optimizer = torch.optim.SGD(params_out, lr=self.conf['lr'])

    
    def set_losses(self, cls_loss, text_con_loss = None):
        """Store the classification loss and the optional text-contrastive loss."""
        self.cls_loss, self.text_con_loss = cls_loss, text_con_loss

    def forward_selector(self, videos, classes, modulate_txt):
        """Predict, for each video, the id of the task that learned its closest class.

        Frames are encoded with the CLIP image encoder (no prompt module) and
        regrouped into videos of ``num_segments`` frames. They are pooled by the
        temporal module (or a frame mean) and compared, by cosine similarity,
        with the text embeddings of the already-learned classes in ``classes``.

        Args:
            videos: frame tensor reshaped here to ``(-1, 3, H, W)``.
            classes: iterable of class names. Names absent from
                ``prompt_module.cls_to_task_id`` (also after removing spaces)
                are ignored.
            modulate_txt: forwarded to ``encode_labels``.

        Returns:
            ``LongTensor`` on ``self.device`` with one task id per video.
        """
        videos = videos.to(self.device)
        videos = videos.view((-1, 3) + videos.size()[-2:])
        videos_features = self.clip_model.encode_image(videos)
        videos_features = videos_features.view(videos_features.size(0) // self.num_segments, self.num_segments, -1)

        if self.enable_temporal_module:
            video_emb = self.temporal_module(videos_features)
            # Elsewhere in this class the temporal module returns
            # (context_emb, pos_embed, position_logits, perm); keep the embedding.
            if isinstance(video_emb, tuple):
                video_emb = video_emb[0]
        else:
            video_emb = torch.mean(videos_features, dim=1)

        dict_saved_classes = self.prompt_module.cls_to_task_id

        def saved_key(cls):
            # Class names may be stored with or without spaces.
            if cls in dict_saved_classes:
                return cls
            compact = cls.replace(' ', '')
            return compact if compact in dict_saved_classes else None

        learned_classes = [cls for cls in classes if saved_key(cls) is not None]

        classes_emb = self.encode_labels(learned_classes, modulate_txt)
        video_emb = video_emb.unsqueeze(dim=1)
        video_emb = video_emb.expand(video_emb.size(0), classes_emb.size(0), video_emb.size(2))
        classes_emb = classes_emb.unsqueeze(dim=0)
        classes_emb = classes_emb.expand(video_emb.size(0), classes_emb.size(1), classes_emb.size(2))
        output = F.cosine_similarity(video_emb, classes_emb, dim=2)

        _, class_ids = torch.max(output, 1)
        # Columns of ``output`` follow ``learned_classes``, not ``classes``, so
        # the argmax must index into ``learned_classes``.
        task_ids_preds = [dict_saved_classes[saved_key(learned_classes[i])] for i in class_ids.tolist()]
        return torch.LongTensor(task_ids_preds).to(self.device)


    def encode_videos(self, videos, modulate = False, type_task = None, only_old_temporal=False, is_train=True, end_train=False, classes=None, is_use_imp_reg=False):
        """Encode a batch of videos into video-level embeddings.

        Frames are encoded with the CLIP image encoder (prompted when
        ``type_mod == 'Prompt'``, ``modulate`` and ``type_task`` are set) and
        regrouped into videos of ``self.num_segments`` frames. What happens
        next depends on the configuration:

        * ``is_use_imp_reg``:
            L2-normalised frame mean (or ``weighted_feature_sum``). Returns
            ``context_emb``. With ``is_use_diff_feat`` and a temporal module,
            returns ``(context_emb, diff_feat)``, where ``diff_feat`` is the
            normalised difference to the temporal-module embedding.
        * Temporal module with MoE (``is_use_MoE``, ``task_count > 1``,
          ``not is_train`` and ``end_train``):
            with ``is_save_model``, the normalised outputs of the per-task
            modules in ``all_model`` are combined. Weights are 1, or router
            weights from CLIP/text similarity and/or difference features. The
            result is averaged with the normalised frame pooling. Otherwise
            ``avg_model`` is used.
        * Temporal module otherwise:
            ``temporal_module`` output. On the plain path with
            ``is_use_adapter``, the frame features are also returned as a
            fifth element.
        * No temporal module:
            top-k frame selection, SHAP frame selection, redundancy-weighted
            pooling, or a frame mean. Returns ``context_emb`` only.

        The default return with a temporal module is
        ``(context_emb, pos_embed, position_logits, perm)``.

        Args:
            videos: frame tensor reshaped here to ``(-1, 3, H, W)``.
            modulate, type_task: enable prompt modulation (Prompt mode only).
            only_old_temporal: unused here.
            is_train, end_train: gate the MoE path and the extra CLIP features.
            classes: class names for the CLIP router (MoE path).
            is_use_imp_reg: return the pooled frame embedding(s) early.
        """
        # Imported explicitly instead of relying on a wildcard import.
        import math

        videos = videos.to(self.device)
        videos = videos.view((-1, 3) + videos.size()[-2:])

        if self.type_mod == 'Prompt' and modulate and type_task is not None:
            videos_features = self.clip_model.encode_image(videos, promptModule = self.prompt_module, type_task = type_task)
        else:
            videos_features = self.clip_model.encode_image(videos)
            if self.Adapter['use_only_clip'] and end_train:
                # Frame features for the CLIP router below: either the features
                # just computed, or a second pass with Adapter['use_only_clip'].
                # NOTE(review): clip_feature is not defined on the prompted path.
                if self.Adapter['use_clip_and adapter']:
                    clip_feature = videos_features
                else:
                    clip_feature = self.clip_model.encode_image(videos, self.Adapter['use_only_clip'])
                clip_feature = clip_feature.view(clip_feature.size(0) // self.num_segments, self.num_segments, -1)
        # Regroup frames into videos: (num_videos, num_segments, dim).
        videos_features = videos_features.view(videos_features.size(0) // self.num_segments, self.num_segments, -1)

        if is_use_imp_reg:
            if self.is_cosin_rounter:
                context_emb = self.weighted_feature_sum(videos_features)
            else:
                context_emb = torch.mean(videos_features, dim=1)
            context_emb = context_emb / context_emb.norm(dim=-1, keepdim=True)
            if self.is_use_diff_feat and self.enable_temporal_module:
                context_emb_trans, _, _, _ = self.temporal_module(videos_features, is_pos_distill=self.is_pos_distill)
                context_emb_trans = context_emb_trans / context_emb_trans.norm(dim=-1, keepdim=True)
                diff_feat = context_emb_trans - context_emb
                diff_feat = diff_feat / diff_feat.norm(dim=-1, keepdim=True)
                return context_emb, diff_feat
            else:
                return context_emb

        if self.enable_temporal_module:
            if self.is_use_MoE and self.task_count > 1 and not is_train and end_train:
                if self.is_save_model:
                    all_features = None
                    cnt = 0
                    if self.Adapter['use_only_clip'] and self.Adapter['is_use_clip_router']:
                        # Router weights: similarity of the CLIP video feature to
                        # the class text embeddings of each task.
                        w_model = []
                        w_diff = []
                        esp = 1e-8
                        if self.is_cosin_rounter:
                            clip_feature = self.weighted_feature_sum(clip_feature)
                        else:
                            clip_feature = torch.mean(clip_feature, dim=1)
                        clip_feature = clip_feature/clip_feature.norm(dim=-1, keepdim=True)
                        if not isinstance(classes, list):
                            classes = list(classes)
                        classes_emb = self.encode_labels(classes)

                        classes_emb = classes_emb / classes_emb.norm(dim=-1, keepdim=True)
                        clip_feature = clip_feature.unsqueeze(dim=1)
                        clip_feature = clip_feature.expand(clip_feature.size(0), classes_emb.size(0), clip_feature.size(2))
                        classes_emb = classes_emb.unsqueeze(dim=0)
                        classes_emb = classes_emb.expand(clip_feature.size(0), classes_emb.size(1), classes_emb.size(2))
                        output = F.cosine_similarity(clip_feature, classes_emb, dim=2)

                        # Class ranges per task: the first task owns ``task_1``
                        # classes, each later task ``num_class_per_task``.
                        if self.setting['is_use_half']:
                            if self.setting['num_classes'] == 174:
                                task_1 = 84
                            else:
                                task_1 = math.ceil(self.setting['num_classes']/2)
                            num_class_per_task = int((self.setting['num_classes'] - task_1) / (self.num_tasks-1))
                        else:
                            extral_num_class = self.setting['num_classes'] % self.num_tasks
                            num_class_per_task = int(self.setting['num_classes'] / self.num_tasks)
                            task_1 = int(num_class_per_task + extral_num_class)

                        if self.is_use_diff_feat:
                            # Router weights from difference features: compare each
                            # stored model's (embedding - frame mean) direction with
                            # the stored per-class mean differences.
                            mean_cls_diff = torch.stack(self.mean_diff_feat, dim=0)

                            mean_cls_diff = mean_cls_diff.unsqueeze(dim=0)
                            mean_cls_diff = mean_cls_diff.expand(clip_feature.size(0), mean_cls_diff.size(1), mean_cls_diff.size(2))
                            mean_cls_diff = mean_cls_diff.to(self.device)
                            count = 0
                            start_class_diff = task_1
                            if self.is_cosin_rounter:
                                context_emb_ = self.weighted_feature_sum(videos_features)
                            else:
                                context_emb_ = torch.mean(videos_features, dim=1)
                            context_emb_ = context_emb_ / context_emb_.norm(dim=-1, keepdim=True)

                            for i, model in enumerate(self.all_model):
                                count += 1

                                model.eval()
                                feature, pos_embed, position_logits, perm = model(videos_features)
                                feature = feature / feature.norm(dim=-1, keepdim=True)
                                cls_diff = feature - context_emb_
                                cls_diff = cls_diff / cls_diff.norm(dim=-1, keepdim=True)
                                cls_diff = cls_diff.unsqueeze(dim=1)
                                # (batch, num_stored_classes, dim)
                                cls_diff = cls_diff.expand(cls_diff.size(0), mean_cls_diff.size(1), cls_diff.size(2))
                                cls_diff = cls_diff.to(self.device)
                                output_diff = F.cosine_similarity(cls_diff, mean_cls_diff, dim=2)
                                if count == 1:
                                    first_task_similarity_diff = output_diff[:, :task_1]
                                    max_similarity_first_task_diff = torch.max(first_task_similarity_diff, dim=1)[0]
                                    w_diff.append(max_similarity_first_task_diff)

                                else:
                                    end_class_diff = start_class_diff + num_class_per_task
                                    task_similarity_diff = output_diff[:, start_class_diff:end_class_diff]
                                    max_similarity_task_diff = torch.max(task_similarity_diff, dim=1)[0]
                                    w_diff.append(max_similarity_task_diff)
                                    start_class_diff = end_class_diff

                            count = 0

                        first_task_similarity = output[:, :task_1]
                        max_similarity_first_task = torch.max(first_task_similarity, dim=1)[0]
                        w_model.append(max_similarity_first_task)

                        start_class = task_1
                        if self.task_count-1 > 1:
                            for task_id in range(1, self.task_count-1):
                                end_class = start_class + num_class_per_task
                                # The last task takes all remaining classes.
                                if task_id == self.task_count-2:
                                    end_class = len(classes)
                                task_similarity = output[:, start_class:end_class]
                                max_similarity_task = torch.max(task_similarity, dim=1)[0]
                                w_model.append(max_similarity_task)
                                start_class = end_class

                        w_model = torch.stack(w_model, dim=1)

                        # Min-max normalise over the whole (batch, tasks) matrix.
                        w_model_normalized = (w_model-torch.min(w_model)) / (torch.max(w_model)-torch.min(w_model) + esp)
                        if self.is_use_diff_feat:
                            w_diff = torch.stack(w_diff, dim=1)

                            w_diff_normalized = (w_diff-torch.min(w_diff)) / (torch.max(w_diff)-torch.min(w_diff) + esp)

                            if not self.Adapter['use_clip_and adapter']:
                                w_model_normalized = w_diff_normalized

                    # Weighted sum of the normalised outputs of all stored models.
                    for i, model in enumerate(self.all_model):
                        if self.Adapter['use_only_clip'] and self.Adapter['is_use_clip_router']:
                            w = w_model_normalized[:,i:i+1]
                        else:
                            w = 1
                        model.eval()
                        feature, pos_embed, position_logits, perm = model(videos_features)
                        feature = feature / feature.norm(dim=-1, keepdim=True)
                        if all_features is None:
                            all_features = feature*w
                        else:
                            all_features += feature*w
                        cnt += w
                    if self.is_cosin_rounter:
                        adapter_emb = self.weighted_feature_sum(videos_features)
                    else:
                        adapter_emb = torch.mean(videos_features, dim=1)

                    adapter_emb = adapter_emb/adapter_emb.norm(dim=-1, keepdim=True)
                    w_trans = 1
                    w_adapter = 1
                    context_emb = (w_trans*(all_features/cnt) + w_adapter*adapter_emb) /  (w_adapter + w_trans)

                else:
                    context_emb, pos_embed, position_logits, perm = self.avg_model(videos_features)
            elif self.type_mod == 'Prompt' and modulate and type_task is not None:
                context_emb, pos_embed, position_logits, perm = self.temporal_module(videos_features, promptModule = self.prompt_module, type_task = type_task)
            else:
                context_emb, pos_embed, position_logits, perm = self.temporal_module(videos_features, is_pos_distill=self.is_pos_distill)
                if self.is_use_adapter:
                    return context_emb, pos_embed, position_logits, perm, videos_features
        else:
            if self.is_select:
                video_select = self.select_topk_frames(videos_features, topk=self.num_valid_frames)
                context_emb = torch.mean(video_select,dim=1)
            elif self.is_use_shape:
                mean_model = DummyModel()
                feature_selector = VideoFeatureSelection(mean_model, topk=self.num_valid_frames, device=self.device)
                # Previously referenced an undefined ``video_emb`` and never set
                # context_emb. get_final_context_embedding already averages the
                # selected frames over dim 1.
                context_emb = feature_selector.get_final_context_embedding(videos_features)
            elif self.is_weight_mean:
                context_emb = self.compute_video_feature(videos_features)
            else:
                context_emb = torch.mean(videos_features, dim=1)
            return context_emb
        return context_emb, pos_embed, position_logits, perm
    
    
    def encode_labels(self, textual_descrips, modulate = False, type_task = None):
        """Encode class names into CLIP text embeddings.

        When both ``self.use_mean_text_template`` and ``self.use_text_template``
        are set, every class name is formatted with each template in
        ``self.template_dict`` (lower-cased), all variants are encoded, and the
        embeddings are averaged per class.  Otherwise each name is encoded once,
        prefixed with ``'A good photo of a '`` if ``self.use_text_template`` is
        set.

        Args:
            textual_descrips: sequence of class-name strings.
            modulate: whether prompt modulation may be applied.
            type_task: task type forwarded to the prompt module; prompts are only
                used when ``self.type_mod == 'Prompt'``, ``modulate`` is truthy
                and ``type_task`` is not None.

        Returns:
            The (unnormalised) text embeddings.  In the template-averaging path
            this is shaped ``(num_classes, D)``; otherwise it is whatever
            ``clip_model.encode_text`` returns for the tokenized names
            (presumably one row per name -- confirm).
        """
        use_prompts = self.type_mod == 'Prompt' and modulate and type_task is not None

        def encode(tokens):
            # Single place that decides whether prompt modulation is injected.
            if use_prompts:
                return self.clip_model.encode_text(tokens, promptModule = self.prompt_module, type_task = type_task)
            return self.clip_model.encode_text(tokens)

        if self.use_mean_text_template and self.use_text_template:
            formatted_texts = [[template.format(cls_txt.lower()) for template in self.template_dict] for cls_txt in textual_descrips]
            num_classes = len(formatted_texts)
            num_templates = len(formatted_texts[0])
            flat_texts = [text for texts in formatted_texts for text in texts]
            text_tokens = clip.tokenize(flat_texts).to(self.device)
            text_features = encode(text_tokens)
            # Regroup per class, then average over the templates.
            text_features = text_features.view(num_classes, num_templates, -1)
            return text_features.mean(dim=1)

        if self.use_text_template:
            textual_descrips = ['A good photo of a '+cls_txt.lower() for cls_txt in textual_descrips]
        text_tokens = clip.tokenize(textual_descrips).to(self.device)
        return encode(text_tokens)
    
    def encode_Tem_Prompts(self, modulate = False, type_task = None):
        """Encode one CLIP text prompt per frame position.

        Builds ``"The <ordinal> frame of the video"`` for each of the
        ``self.num_segments`` positions and encodes them with the CLIP text
        encoder (prompt-modulated when ``self.type_mod == 'Prompt'``,
        ``modulate`` is truthy and ``type_task`` is not None).

        Positions 1-10 use English words ("first" ... "tenth"); later positions
        fall back to numeric ordinals ("11th", "21st", ...), whereas the
        previous version raised ``IndexError`` beyond ten segments.

        Raises:
            ValueError: if ``self.is_tem_prompts`` is false (previously this
                surfaced as an ``UnboundLocalError`` on ``tem_prompt``).
        """
        if not self.is_tem_prompts:
            raise ValueError('encode_Tem_Prompts requires is_tem_prompts to be enabled')

        ordinal_numbers = ["first", "second", "third", "fourth", "fifth", "sixth", "seventh", "eighth", "ninth", "tenth"]

        def ordinal(position):
            # position is 1-based.
            if position <= len(ordinal_numbers):
                return ordinal_numbers[position - 1]
            if 11 <= position % 100 <= 13:
                suffix = 'th'
            else:
                suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(position % 10, 'th')
            return f"{position}{suffix}"

        tem_prompt = [f"The {ordinal(i + 1)} frame of the video" for i in range(self.num_segments)]
        text_tokens = clip.tokenize(tem_prompt).to(self.device)
        if self.type_mod == 'Prompt' and modulate and type_task is not None:
            text_features = self.clip_model.encode_text(text_tokens, promptModule = self.prompt_module, type_task = type_task)
        else:
            text_features = self.clip_model.encode_text(text_tokens)
        return text_features

    def set_train_mode(self, task_id, training_phase, pre_pro_train_mode, curr_pro_train_mode):
        """Put each sub-module into train or eval mode for the given phase.

        With a single training phase, or in phases other than 1 and 3 of a
        multi-phase schedule, prompt modules follow ``pre_pro_train_mode`` /
        ``curr_pro_train_mode`` (or plain ``train()`` for general prompts).
        In phases 1 and 3 of a multi-phase schedule, prompts are frozen
        (``False, False`` / ``eval()``).

        The temporal module is trained when there is a single phase or in phases
        1 and 3; otherwise it is put in eval mode.  A linear classifier is always
        trained, CLIP is kept in eval mode, and when adapters are enabled each
        visual residual block's adapter is switched back to train mode.
        """
        single_phase = self.num_training_phases == 1
        # Multi-phase phases 1 and 3 freeze the prompts but train the temporal module.
        frozen_prompt_phase = (not single_phase) and training_phase in (1, 3)

        if self.type_mod == 'Prompt':
            if self.type_prompt == 'general':
                if frozen_prompt_phase:
                    self.prompt_module.eval()
                else:
                    self.prompt_module.train()
            else:
                if frozen_prompt_phase:
                    pre_mode, curr_mode = False, False
                else:
                    pre_mode, curr_mode = pre_pro_train_mode, curr_pro_train_mode
                for length_attr, prompt_kind in (('L_sp_tk', 'ViT'), ('L_tp_tk', 'temp')):
                    if getattr(self.prompt_module, length_attr) > 0:
                        self.prompt_module.set_train_mode_task_prompts(task_id, prompt_kind, pre_mode, curr_mode)

        if self.enable_temporal_module:
            if single_phase or frozen_prompt_phase:
                self.temporal_module.train()
            else:
                self.temporal_module.eval()

        if self.type_cls == 'Linear':
            self.classifier.train()
        self.clip_model.eval()
        if self.is_use_adapter:
            # Only the adapters inside the frozen CLIP visual tower are trained.
            for resblock in self.clip_model.visual.transformer.resblocks:
                resblock.adapter.train()


    def set_eval_mode(self):
        """Switch every enabled sub-module (prompts, linear head, temporal module, CLIP) to eval mode."""
        modules_to_freeze = []
        if self.type_mod == 'Prompt':
            modules_to_freeze.append(self.prompt_module)
        if self.type_cls == 'Linear':
            modules_to_freeze.append(self.classifier)
        if self.enable_temporal_module:
            modules_to_freeze.append(self.temporal_module)
        # CLIP is always frozen for evaluation.
        modules_to_freeze.append(self.clip_model)
        for module in modules_to_freeze:
            module.eval()
        
    def count_accuracy(self, video_emb, classes, labels, modulate_txt):
        """Top-1 accuracy of nearest-class prediction by cosine similarity.

        Each video embedding is compared against the text embedding of every
        class in ``classes``.  The resulting ``(num_videos, num_classes)``
        similarity matrix is scored with ``accuracy(..., topk=(1,))``.

        Args:
            video_emb: 2-D tensor of video embeddings, one row per video
                (the expand calls below fix this rank).
            classes: class names; any iterable is converted to a list.
            labels: ground-truth class indices into ``classes``.
            modulate_txt: forwarded to ``encode_labels`` as ``modulate``.

        Returns:
            The top-1 accuracy as a Python number (``.item()``).
        """
        if not isinstance(classes, list):
            classes = list(classes)
        label_emb = self.encode_labels(classes, modulate_txt)
        label_emb = label_emb / label_emb.norm(dim=-1, keepdim=True)

        n_videos = video_emb.size(0)
        n_classes = label_emb.size(0)
        # Broadcast both sides to (n_videos, n_classes, D) so cosine_similarity
        # can compare every video with every class along the feature axis.
        video_grid = video_emb.unsqueeze(dim=1).expand(n_videos, n_classes, video_emb.size(1))
        label_grid = label_emb.unsqueeze(dim=0).expand(n_videos, n_classes, label_emb.size(1))
        similarity = F.cosine_similarity(video_grid, label_grid, dim=2)

        top1 = accuracy(similarity.data, labels, topk=(1,))[0]
        return top1.item()

    def validate_task(self, val_cilDatasetList, current_task_id, modulate_vid, modulate_txt, classes, training_phase, is_final, type_val):
        """Evaluate on the validation split of the current task only.

        Used by ``famework_validation_task`` for the intermediate phases of a
        multi-phase schedule.  When ``is_final`` is set and this is the first
        task (``current_task_id == 0``), the accuracy is logged, appended to
        ``self.list_val_acc_ii[type_val]``, and the (never-updated, hence
        zero) backward-forgetting meter is logged as well.

        Returns:
            ``(average top-1 accuracy, None)``.
        """
        print('Init val task')
        acc_meter = AverageMeter()
        loader, _ = val_cilDatasetList.get_validation_task(current_task_id, curr_classes=True)
        self.set_eval_mode()
        bwf_meter = AverageMeter()
        with torch.no_grad():
            for _, _, videos, _, labels, text_descrip in loader:
                labels = labels.to(self.device)
                with autocast():
                    acc_val, _, _, _, _, _, _ = self.forward_pass_video(
                        False, modulate_vid, modulate_txt, videos, labels, text_descrip,
                        classes, training_phase, is_train=False)
                acc_meter.update(acc_val, videos.size(0))

            if is_final and current_task_id == 0:
                log_step = current_task_id + 1
                self.experiment.log_metric("Acc_task_{}".format(current_task_id+1), acc_meter.avg, step=log_step)
                self.list_val_acc_ii[type_val].append(acc_meter.avg)
                self.experiment.log_metric("Total_Acc_Per_task", acc_meter.avg, step=log_step)
                self.experiment.log_metric("Total_BWF_Per_task", bwf_meter.avg, step=log_step)

        return acc_meter.avg, None
    
    def famework_validation_task(self, val_cilDatasetList, current_task_id, classes, training_phase, type_val = 'val', is_final= False, modulate_vid= False, modulate_txt = False, split_batch = False):
        """Dispatch to full (all seen tasks) or current-task-only validation.

        Full validation (``validate``) runs for single-phase training or in
        phase 3 of a multi-phase schedule.  Every other case uses
        ``validate_task``.  Returns whatever the chosen validator returns.
        """
        run_full_validation = self.num_training_phases == 1 or (self.num_training_phases > 1 and training_phase == 3)
        if not run_full_validation:
            print('here val task, training phase: ', training_phase)
            return self.validate_task(val_cilDatasetList, current_task_id, modulate_vid, modulate_txt,
                                      classes, training_phase, is_final, type_val)
        print('here total val, training phase: ', training_phase)
        return self.validate(val_cilDatasetList, current_task_id, classes, training_phase,
                             type_val, is_final, modulate_vid, modulate_txt, split_batch)


    def validate(self, val_cilDatasetList, current_task_id, classes, training_phase, type_val = 'val', is_final= False, modulate_vid= False, modulate_txt = False, split_batch = False, end_train=False):
        """Evaluate on the validation loaders of every task seen so far.

        Loaders come from ``get_valSet_by_taskNum(current_task_id + 1)``.  For
        each task the per-batch top-1 accuracy is averaged (weighted by batch
        size), then folded into an overall accuracy weighted by the task's
        ``num_classes``.  When the prompt task selector is active, its
        auxiliary accuracy is aggregated the same way.

        With ``is_final`` set, per-task accuracies are logged.  The newest
        task's accuracy is appended to ``self.list_val_acc_ii[type_val]``, and
        for older tasks the drop from their recorded accuracy is accumulated
        as backward forgetting.

        Returns:
            ``(overall accuracy, overall auxiliary accuracy or None)``.
        """
        def track_aux():
            # Whether task-selector ("aux") accuracy is tracked.  Evaluated
            # at each call site, exactly where the inline checks used to be.
            return (self.type_mod == 'Prompt' and self.type_task == 'CIL'
                    and self.type_prompt != 'general'
                    and training_phase == self.training_phase_task_selector
                    and len(self.prompt_module.cls_to_task_id) > 0)

        print('Init val')
        task_acc = AverageMeter()
        overall_acc = AverageMeter()
        task_aux_acc = AverageMeter()
        overall_aux_acc = AverageMeter()
        task_loaders, _ = val_cilDatasetList.get_valSet_by_taskNum(current_task_id+1)
        forgetting = AverageMeter()
        self.set_eval_mode()

        with torch.no_grad():
            for task_idx, (loader, n_classes) in enumerate(task_loaders):
                for _, _, videos, _, labels, text_descrip in loader:
                    labels = labels.to(self.device)
                    with autocast():
                        acc_val, _, aux_acc_val, _, _, _, _ = self.forward_pass_video(
                            split_batch, modulate_vid, modulate_txt, videos, labels, text_descrip,
                            classes, training_phase, is_train=False, end_train=end_train)
                    batch_size = videos.size(0)
                    task_acc.update(acc_val, batch_size)
                    if track_aux():
                        task_aux_acc.update(aux_acc_val, batch_size)

                overall_acc.update(task_acc.avg, n_classes)
                print('Train... task : {}, acc with classifier: {}'.format(task_idx, task_acc.avg))
                if track_aux():
                    overall_aux_acc.update(task_aux_acc.avg, n_classes)
                    print('Train... task : {}, aux acc with classifier: {}'.format(task_idx, task_aux_acc.avg))
                if is_final:
                    self.experiment.log_metric("Acc_task_{}".format(task_idx+1), task_acc.avg, step=current_task_id+1)
                    if task_idx == current_task_id:
                        self.list_val_acc_ii[type_val].append(task_acc.avg)
                    elif task_idx < current_task_id:
                        # Forgetting = accuracy recorded when the task was newest minus accuracy now.
                        forgetting.update(self.list_val_acc_ii[type_val][task_idx] - task_acc.avg, n_classes)

                task_acc.reset()
                if track_aux():
                    task_aux_acc.reset()

        summary = 'Pre Testing Results: Pre_Acc {:.3f}'.format(overall_acc.avg)
        if is_final:
            log_step = current_task_id + 1
            self.experiment.log_metric("Total_Acc_Per_task", overall_acc.avg, step=log_step)
            self.experiment.log_metric("Total_BWF_Per_task", forgetting.avg, step=log_step)
        print(summary)

        if not track_aux():
            return overall_acc.avg, None
        print('Pre Testing Results: Pre_Acc_AUX {:.3f}'.format(overall_aux_acc.avg))
        return overall_acc.avg, overall_aux_acc.avg

    def load_best_checkpoint(self, path_model, current_task):
        """Restore trainable sub-modules from the best checkpoint of ``current_task``.

        Does nothing when ``path_model`` does not exist or the stored
        ``current_task`` differs from ``current_task``.  Otherwise it loads the
        state dicts of whichever sub-modules are enabled: prompt module
        (``type_mod == 'Prompt'``), linear classifier (``type_cls == 'Linear'``)
        and temporal module (``enable_temporal_module``).

        Note: ``torch.load`` unpickles the file; only load checkpoints this
        code wrote itself (see ``save_checkpoint``).
        """
        if not os.path.exists(path_model):
            return
        # map_location lets a checkpoint written on one device (e.g. a GPU) be
        # restored when self.device differs (another GPU or a CPU-only run).
        checkpoint_dict = torch.load(path_model, map_location=self.device)
        task_to_load = checkpoint_dict['current_task']
        epoch_to_load = checkpoint_dict['current_epoch']
        if task_to_load != current_task:
            return
        print('Loading best checkpoint ... epoch: {}, task: {}'.format(epoch_to_load, task_to_load))
        if self.type_mod == 'Prompt':
            self.prompt_module.load_state_dict(checkpoint_dict['state_dict_prompt'])
        if self.type_cls == 'Linear':
            self.classifier.load_state_dict(checkpoint_dict['state_dict_classifier'])
        if self.enable_temporal_module:
            self.temporal_module.load_state_dict(checkpoint_dict['state_dict_temporal_module'])

    def get_task_label(self, text_descriptions):
        """Map class names to their stored task ids.

        Each name is looked up in ``self.prompt_module.cls_to_task_id`` as-is,
        then with all spaces removed.  Names found under neither key are
        silently skipped, so the result can be shorter than the input.

        Returns:
            A ``LongTensor`` of task ids on ``self.device``.
        """
        cls_to_task = self.prompt_module.cls_to_task_id
        found_ids = []
        for name in text_descriptions:
            for key in (name, name.replace(' ', '')):
                if key in cls_to_task:
                    found_ids.append(cls_to_task[key])
                    break
        return torch.LongTensor(found_ids).to(self.device)

    def forward_pass_video(self, split_batch, modulate_vid, modulate_txt, videos, labels, text_descrip, classes, training_phase, curr_task_id=None, is_train = False, end_train=False):
        """Encode a batch of videos and compute accuracy and all training losses.

        Args:
            split_batch: when true (and prompts are in use with a non-empty
                class->task registry), samples of already-registered classes
                are encoded both with and without prompt modulation, and novel
                classes without modulation.  The two groups are then shuffled
                together.
            modulate_vid / modulate_txt: enable prompt modulation of the video /
                text encoders.
            videos, labels, text_descrip: the batch.
            classes: all class names used to score accuracy.
            training_phase: current phase of a multi-phase schedule.
            curr_task_id: forwarded to ``split_batch``.
            is_train: enables the positional-embedding distillation loss.
            end_train: disables the adapter-specific losses.

        Returns:
            ``(acc_train, loss, aux_acc_train, loss_pos_distill,
            w_Hungarian * loss_Hungarian, loss_adapter_contrast,
            w_WDLoss * WDLoss)``.  Loss terms a configuration does not
            compute are returned as 0.
        """
        # Loss weights / logit temperatures.  These used to be defined only in
        # the text-classifier branch, so the Linear branch raised NameError.
        logit_scale = 20
        logit_adapter_scale = 50
        w_Hungarian = 100
        w_WDLoss = 2e6

        # Defaults for every optional term; each branch below overwrites only
        # what it computes (previously several paths left some of these unbound).
        loss_Hungarian = 0
        loss_pos_distill = 0
        loss_adapter_contrast = 0
        WDLoss = 0
        pos_embed, old_pos_embed = None, None
        video_feature = None

        if split_batch and self.type_mod == 'Prompt' and len(self.prompt_module.cls_to_task_id)>0:
            batch_to_prompt, label_prompt, text_decrip_prompt, batch_novel_cls, label_cls, text_descrip_cls = self.split_batch(videos, labels, text_descrip, curr_task_id)
            if batch_to_prompt is not None:
                if self.teacher_forcing or self.type_task == 'TIL' or self.type_prompt == 'general':
                    self.prompt_module.set_current_video_classes(text_decrip_prompt)
                    video_emb_prompt = self.encode_videos(batch_to_prompt, True, 'TIL')
                else:
                    task_ids_preds = self.forward_selector(batch_to_prompt, classes, modulate_txt)
                    self.prompt_module.set_current_task_ids(task_ids_preds)
                    video_emb_prompt = self.encode_videos(batch_to_prompt, True, 'CIL')
                # Seen classes are encoded twice: with and without prompt modulation.
                video_emb_prompt_wo_mod = self.encode_videos(batch_to_prompt, False)
                video_emb_prompt = torch.cat([video_emb_prompt, video_emb_prompt_wo_mod], dim=0)
                video_emb = video_emb_prompt
                label_prompt = torch.cat([label_prompt, label_prompt], dim=0)
                labels = label_prompt
                text_decrip_prompt.extend(text_decrip_prompt)
                text_descrip = text_decrip_prompt
                batch_to_prompt = torch.cat([batch_to_prompt, batch_to_prompt], dim=0)
                videos = batch_to_prompt
            if batch_novel_cls is not None:
                video_emb_cls = self.encode_videos(batch_novel_cls, False)
                video_emb = video_emb_cls
            if batch_novel_cls is not None and batch_to_prompt is not None:
                video_emb = torch.cat([video_emb_prompt, video_emb_cls], dim = 0)
                labels = torch.cat([label_prompt, label_cls], dim = 0)
                text_decrip_prompt.extend(text_descrip_cls)
                # Shuffle both groups together so the loss does not see them in blocks.
                idx = list(range(video_emb.size(0)))
                random.shuffle(idx)
                text_descrip = [text_decrip_prompt[i] for i in idx]
                idx = torch.LongTensor(idx)
                video_emb = video_emb[idx]
                labels = labels[idx]

                batch_videos = torch.cat([batch_to_prompt, batch_novel_cls], dim = 0)
                videos = batch_videos[idx]
        else:
            if self.type_mod == 'Prompt':
                enable_cil = self.num_training_phases == 1 or (self.num_training_phases > 1 and training_phase == 3)
                if self.type_task == 'CIL' and len(self.prompt_module.cls_to_task_id)>0 and enable_cil:
                    if modulate_vid and self.type_prompt != 'general':
                        task_ids_pred = self.forward_selector(videos, classes, modulate_txt)
                        self.prompt_module.set_current_task_ids(task_ids_pred)
                    video_emb, pos_embed, position_logits = self.encode_videos(videos, modulate_vid, 'CIL', is_train=is_train)
                else:
                    self.prompt_module.set_current_video_classes(text_descrip)
                    video_emb, pos_embed, position_logits = self.encode_videos(videos, modulate_vid, 'TIL', is_train=is_train)
            elif self.enable_temporal_module:
                if self.is_tem_prompts:
                    video_emb, pos_embed, position_logits, perm = self.encode_videos(videos, modulate_vid, is_train=is_train)
                    loss_Hungarian, pos_acc = self.HungarianLoss(position_logits, perm)
                    print(pos_acc)
                elif self.is_use_adapter and not end_train:
                    video_emb, _, _, _, video_feature = self.encode_videos(videos, modulate_vid, is_train=is_train, end_train=end_train, classes=classes)
                    video_feature = self._pool_frame_features(video_feature)
                    video_feature = video_feature / video_feature.norm(dim=-1, keepdim=True)
                    if self.is_use_imp_reg:
                        WDLoss = self._importance_weighted_distill_loss(videos, video_feature)
                elif self.Adapter['is_use_clip_router']:
                    video_emb, pos_embed, position_logits, _ = self.encode_videos(videos, modulate_vid, is_train=is_train, end_train=end_train, classes=classes)
                else:
                    video_emb, pos_embed, position_logits, _ = self.encode_videos(videos, modulate_vid, is_train=is_train, end_train=end_train)
            else:  # only CLIP or only adapter
                video_emb = self.encode_videos(videos, modulate_vid, is_train=is_train)
                video_feature = video_emb
            if self.is_pos_distill and self.old_temporal_module is not None:
                _, old_pos_embed, _ = self.encode_videos(videos, modulate_vid, only_old_temporal=True)

        aux_acc_train = None
        if self.type_mod == 'Prompt' and self.type_task == 'CIL' and self.type_prompt != 'general' and training_phase == self.training_phase_task_selector and len(self.prompt_module.cls_to_task_id)>0:
            # Accuracy of the task selector against the registered task ids.
            gt_tasks_ids = self.get_task_label(text_descrip)
            task_ids_pred = self.forward_selector(videos, classes, modulate_txt)
            aux_acc_train = torch.eq(task_ids_pred, gt_tasks_ids)
            aux_acc_train = (aux_acc_train.float().sum(0))*100/gt_tasks_ids.size(0)
            aux_acc_train = aux_acc_train.item()

        if self.type_cls == 'Linear':
            preds = self.classifier(video_emb)
            loss_cls = self.cls_loss(preds, labels)
            acc_train = accuracy(preds.data, labels, topk=(1,))[0]
            acc_train = acc_train.item()

            if self.num_training_phases == 1:
                if self.is_use_adapter and not end_train and not self.is_only_use_adapter:
                    loss_adapter_contrast = self.cls_loss(video_feature, labels)
                loss = loss_cls + loss_adapter_contrast + w_WDLoss*WDLoss
            else:
                # The previous version used contrastive logits that are never
                # computed in the Linear branch (NameError); use the classifier loss.
                loss = loss_cls + w_Hungarian*loss_Hungarian + w_WDLoss*WDLoss
        else:
            text_emb = self.encode_labels(text_descrip, modulate_txt)
            video_emb = video_emb / video_emb.norm(dim=-1, keepdim=True)
            if self.is_only_use_adapter:
                if self.is_use_imp_reg and not end_train:
                    WDLoss = self._importance_weighted_distill_loss(videos, video_emb)
                else:
                    WDLoss = 0

            text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)

            if self.is_pos_distill and self.old_temporal_module is not None and is_train:
                w_distill = 1e2
                old_pos_embed = old_pos_embed / old_pos_embed.norm(dim=-1, keepdim=True)
                pos_embed = pos_embed / pos_embed.norm(dim=-1, keepdim=True)
                loss_pos_distill = F.mse_loss(old_pos_embed, pos_embed) * w_distill

            use_adapter_logits = self.is_use_adapter and not end_train and not self.is_only_use_adapter
            if use_adapter_logits:
                logits_adapter_imge = logit_adapter_scale * video_feature @ text_emb.t()
                logits_adapter_text = logits_adapter_imge.t()
            logits_per_image = logit_scale * video_emb @ text_emb.t()
            logits_per_text = logits_per_image.t()

            if self.num_training_phases == 1:
                loss_contrast = self.ContrastLoss(labels, logits_per_image, logits_per_text)
                if use_adapter_logits:
                    loss_adapter_contrast = self.ContrastLoss(labels, logits_adapter_imge, logits_adapter_text)
                loss = loss_contrast + loss_pos_distill + w_Hungarian*loss_Hungarian + loss_adapter_contrast + w_WDLoss*WDLoss
            else:
                ground_truth = torch.arange(len(videos),dtype=torch.long,device=self.device)
                loss = (self.cls_loss(logits_per_image,ground_truth) + self.cls_loss(logits_per_text,ground_truth))/2 + loss_pos_distill + w_Hungarian*loss_Hungarian + w_WDLoss*WDLoss

            acc_train = self.count_accuracy(video_emb, classes, labels, modulate_txt)

        return acc_train, loss, aux_acc_train, loss_pos_distill, w_Hungarian*loss_Hungarian, loss_adapter_contrast, w_WDLoss*WDLoss

    def _pool_frame_features(self, frame_features):
        """Pool per-frame features over dim 1: router-weighted sum if ``is_cosin_rounter``, else mean."""
        if self.is_cosin_rounter:
            return self.weighted_feature_sum(frame_features)
        return torch.mean(frame_features, dim=1)

    def _importance_weighted_distill_loss(self, videos, current_emb):
        """Importance-weighted squared-error distillation against ``self.old_backbone``.

        Returns 0 while ``self.task_count <= 1``.  Otherwise it encodes
        ``videos`` frame by frame with the frozen old backbone, pools and
        L2-normalises the result, and compares it with ``current_emb``.  The
        per-dimension mean squared error is weighted by the min-max rescaled
        absolute ``self.weight_distill``.  Gradients flow only through
        ``current_emb``.
        """
        if self.task_count <= 1:
            return 0
        with torch.no_grad():
            frames = videos.to(self.device)
            # Flatten (batch, segments) into one axis of 3-channel images for the image encoder.
            frames = frames.view((-1, 3) + frames.size()[-2:])
            video_old = self.old_backbone.encode_image(frames).detach()
            del frames
            video_old = video_old.view(video_old.size(0) // self.num_segments, self.num_segments, -1)
            video_old = self._pool_frame_features(video_old)
            video_old = video_old / video_old.norm(dim=-1, keepdim=True)
            # Min-max rescale of the importance weights; eps guards a zero range.
            weight_distill = (self.weight_distill - torch.min(self.weight_distill))/(torch.max(self.weight_distill)-torch.min(self.weight_distill) + self.eps)
            weight_distill = abs(weight_distill.to(self.device))
        return (weight_distill * torch.mean(((current_emb - video_old)**2), dim=0)).mean()

    def split_batch(self, x, labels, text_description, curr_task_id):
        """Partition a batch into classes already registered with the prompt module and novel ones.

        A sample is "registered" when its class name, or the name with spaces
        removed, is a key of ``self.prompt_module.cls_to_task_id``.
        ``curr_task_id`` is accepted for interface compatibility and unused.

        Returns:
            ``(batch_to_prompt, label_prompt, text_decrip_prompt,
            batch_novel_cls, label_cls, text_descrip_cls)``.  Each tensor is
            stacked along dim 0, or ``None`` if its group is empty.  The text
            entries are always lists, possibly empty.
        """
        registered = self.prompt_module.cls_to_task_id
        seen_idx, novel_idx = [], []
        for i, cls in enumerate(text_description):
            is_seen = cls in registered or cls.replace(' ', '') in registered
            (seen_idx if is_seen else novel_idx).append(i)

        def gather(source, indices):
            # Stack the selected rows, or None when there are none to stack.
            return torch.stack([source[i] for i in indices], dim=0) if indices else None

        return (gather(x, seen_idx),
                gather(labels, seen_idx),
                [text_description[i] for i in seen_idx],
                gather(x, novel_idx),
                gather(labels, novel_idx),
                [text_description[i] for i in novel_idx])
    
    def save_checkpoint(self, path_model, acc_val, epoch, task_id, is_best):
        """Write the trainable sub-modules to ``path_model`` when ``is_best`` is set.

        Nothing is written unless the model has a prompt module
        (``type_mod == 'Prompt'``) or a linear classifier
        (``type_cls == 'Linear'``).  The checkpoint holds the accuracy, epoch,
        task id and optimizer state, plus the state dicts of the enabled prompt
        module, classifier and temporal module.
        """
        if not is_best:
            return
        if not (self.type_mod == 'Prompt' or self.type_cls == 'Linear'):
            return

        print('Saving ... ')
        checkpoint = {
            'accuracy': acc_val,
            'current_epoch': epoch,
            'current_task': task_id,
            'optimizer': self.optimizer.state_dict(),
        }
        if self.type_mod == 'Prompt':
            checkpoint['state_dict_prompt'] = self.prompt_module.state_dict()
        if self.type_cls == 'Linear':
            checkpoint['state_dict_classifier'] = self.classifier.state_dict()
        if self.enable_temporal_module:
            checkpoint['state_dict_temporal_module'] = self.temporal_module.state_dict()
        torch.save(checkpoint, path_model)
        # Task and epoch are reported 1-based.
        print("Save Best Networks for task: {}, epoch: {}".format(task_id + 1, epoch + 1), flush=True)
    
    def train_phase(self, task_id, training_phase, pre_pro_train_mode, curr_pro_train_mode, train_dataloader_cil, val_cilDatasetList, modulate_vid, modulate_txt, split_batch, is_final):

        eval_freq = self.conf['checkpoints']['eval_freq']
        path_model = self.conf['checkpoints']['path_model']
        num_epochs = self.conf['epochs']
        if task_id == 0 and self.setting['is_use_half']:
            num_epochs = self.conf['epochs']
        best_acc_val = 0
        
        self.prepare_trainining(train_dataloader_cil.dataset.classes, task_id, training_phase, pre_pro_train_mode, curr_pro_train_mode)
        self.get_optimizer(training_phase)
        self.optimizer.zero_grad()
        with self.experiment.train():
            if self.classes is not None:         
                for value in list(train_dataloader_cil.dataset.classes):
                    self.classes.append(value)
            for epoch in range(num_epochs):
                if self.classes is None:
                    self.classes = list(train_dataloader_cil.dataset.classes)
                self.set_train_mode(task_id, training_phase, pre_pro_train_mode, curr_pro_train_mode)
                acc_Avg = AverageMeter()
                loss_Avg = AverageMeter()
                aux_acc_Avg = AverageMeter()
                loss_pos_distill_Avg = AverageMeter()
                loss_Hungarian_Avg = AverageMeter()
                loss_adapter_contrast_Avg = AverageMeter()
                loss_WD = AverageMeter()
                for i, (indices, _, videos, _, labels, text_descrip) in enumerate(train_dataloader_cil): 
                    labels = labels.to(self.device)
                    self.optimizer.zero_grad()
                    with autocast():  
                        acc_train, loss, aux_acc_train, loss_pos_distill, loss_Hungarian, loss_adapter_contrast, WDLoss = self.forward_pass_video(split_batch, modulate_vid, modulate_txt, videos, labels, text_descrip, self.classes, training_phase, curr_task_id=task_id, is_train=True)
                    if self.type_mod == 'Prompt' and self.type_task == 'CIL' and self.type_prompt != 'general' and training_phase == self.training_phase_task_selector and len(self.prompt_module.cls_to_task_id)>0:
                        aux_acc_Avg.update(aux_acc_train, videos.size(0))

                    loss.backward()
                    self.optimizer.step()
                    
                    loss_Avg.update(loss.item(), videos.size(0))
                    acc_Avg.update(acc_train, videos.size(0))
                    if self.is_pos_distill and loss_pos_distill > 0:
                        loss_pos_distill_Avg.update(loss_pos_distill.item(), 1)
                    if self.is_auxiliary_training:
                        loss_Hungarian_Avg.update(loss_Hungarian.item(), 1)
                    if self.is_use_adapter and not self.is_only_use_adapter:
                        loss_adapter_contrast_Avg.update(loss_adapter_contrast.item(), 1)
                    if self.is_use_imp_reg and self.task_count>1:
                        loss_WD.update(WDLoss.item(), 1)

                    if (i+1) % 2 == 0: 
                        if self.is_pos_distill and loss_pos_distill > 0:
                            print('Epoch [%d/%d], Loss: %.4f, pos_distill_Loss: %.4f' 
                                %(epoch+1, num_epochs, loss.item(), loss_pos_distill.item()))
                        if self.is_auxiliary_training:
                            print('Epoch [%d/%d], Loss: %.4f, loss_Hungarian: %.4f' 
                                %(epoch+1, num_epochs, loss.item(), loss_Hungarian.item()))
                        if self.is_use_adapter:
                            if self.is_only_use_adapter:
                                if self.is_use_imp_reg and self.task_count>1:
                                    print('Epoch [%d/%d], Loss: %.4f, WDLoss: %.4f' 
                                        %(epoch+1, num_epochs, loss.item(), WDLoss.item()))
                                else:
                                    print('Epoch [%d/%d], Loss: %.4f' 
                                        %(epoch+1, num_epochs, loss.item()))
                            else:
                                if self.is_use_imp_reg and self.task_count>1:
                                    print('Epoch [%d/%d], Loss: %.4f, loss_adapter_contrast: %.4f, WDLoss: %.4f' 
                                        %(epoch+1, num_epochs, loss.item(), loss_adapter_contrast.item(), WDLoss.item()))
                                else:
                                    print('Epoch [%d/%d], Loss: %.4f, loss_adapter_contrast: %.4f' 
                                        %(epoch+1, num_epochs, loss.item(), loss_adapter_contrast.item()))
                        else:
                            print('Epoch [%d/%d], Loss: %.4f, pos_distill_Loss: %.4f' 
                                %(epoch+1, num_epochs, loss.item(), 0.0000))
                        
                self.experiment.log_metric("Epoch_Acc_task_{}".format(task_id+1), acc_Avg.avg)
                self.experiment.log_metric("Epoch_Loss_task_{}".format(task_id+1), loss_Avg.avg)
                if self.is_pos_distill and loss_pos_distill > 0:
                    self.experiment.log_metric("Epoch_pos_distll_Loss_task_{}".format(task_id+1), loss_pos_distill_Avg.avg)

                if self.is_auxiliary_training:
                    self.experiment.log_metric("Epoch_loss_Hungarian_task_{}".format(task_id+1), loss_Hungarian_Avg.avg)
                
                if self.is_use_adapter:
                    if self.is_only_use_adapter:
                        if self.is_use_imp_reg:
                            self.experiment.log_metric("Epoch_loss_adapter_contrast_task_{}".format(task_id+1), loss_WD.avg)
                    else:
                        if self.is_use_imp_reg:
                            self.experiment.log_metric("Epoch_loss_adapter_contrast_task_{}".format(task_id+1), loss_adapter_contrast_Avg.avg, loss_WD.avg)
                        else:
                            self.experiment.log_metric("Epoch_loss_adapter_contrast_task_{}".format(task_id+1), loss_adapter_contrast_Avg.avg)

                if self.type_mod == 'Prompt' and self.type_task == 'CIL' and self.type_prompt != 'general' and training_phase == self.training_phase_task_selector:
                    self.experiment.log_metric("Epoch_Aux_Acc_task_{}".format(task_id+1), aux_acc_Avg.avg)  # Aux_Acc=task_id acc

                if ((epoch + 1) % eval_freq == 0 or epoch == num_epochs - 1) and val_cilDatasetList is not None:
                    with self.experiment.validate():  
                       
                        classes = list(dict.fromkeys(self.classes))  
                        acc_val, aux_acc_val = self.famework_validation_task(val_cilDatasetList, task_id, classes, training_phase, 'val', False, modulate_vid, modulate_txt, False)
                        self.experiment.log_metric("Acc_at_task_{}".format(task_id+1), acc_val)
                        
            
            if self.is_use_imp_reg: 
                mean = None
                var = None
                with torch.no_grad():
                    self.old_backbone = deepcopy(self.clip_model)
                    classes = list(dict.fromkeys(self.classes)) 
                if isinstance(classes, list):
                    pass
                else:
                    classes = list(classes) 
                class_features = {}
                diff_class_feature = {}
                for i, (_, _, videos, _, labels, _) in enumerate(train_dataloader_cil): 
                    with torch.no_grad():
                        if self.is_use_diff_feat:
                            features, diff_feat = self.encode_videos(videos, is_use_imp_reg=self.is_use_imp_reg)
                            features = features.detach()
                            diff_feat = diff_feat.detach()
                            diff_feat = diff_feat.cpu()
                            features = features.cpu()
                        else:
                            features =  self.encode_videos(videos, is_use_imp_reg=self.is_use_imp_reg).detach()
                    features = features.cpu()
                    labels = labels.cpu()
                    batch = features.size(0)
                    for i in range(batch):
                        if labels[i].item() not in class_features:
                            class_features[labels[i].item()] = []
                        class_features[labels[i].item()].append(features[i]) 
                        if self.is_use_diff_feat:
                            if labels[i].item() not in diff_class_feature:
                                diff_class_feature[labels[i].item()] = []
                            diff_class_feature[labels[i].item()].append(diff_feat[i]) 
                        
                with torch.no_grad():
                    classes_emb = self.encode_labels(classes).detach()  
                classes_emb = classes_emb / classes_emb.norm(dim=-1, keepdim=True)
                classes_emb = classes_emb.cpu()
                for label, feat in class_features.items():
                    feat = torch.stack(feat) 
                    mean = feat.mean(dim=0) 
    
                    var = feat.var(dim=0, unbiased=False)
                    var = var / (self.eps + var.max())
                    mean = (mean * classes_emb[label])/(torch.norm(mean) * torch.norm(classes_emb[label]) + self.eps)

                    
                    weight = mean / (self.eps + var)
                    if self.weight_distill is None:
                        self.weight_distill = weight
                    else:
                        self.weight_distill += weight
                del class_features
                if self.is_use_diff_feat:
                    for label, feat in diff_class_feature.items():
                        feat = torch.stack(feat)
                        mean = feat.mean(dim=0)
                        self.mean_diff_feat.append(mean) 
                    del diff_class_feature

            if self.is_use_diff_feat:
                if self.is_use_imp_reg:
                    pass
                else:
                    diff_class_feature = {}
                    for i, (_, _, videos, _, labels, _) in enumerate(train_dataloader_cil): 
                        with torch.no_grad():
                            _, diff_feat = self.encode_videos(videos, is_use_imp_reg=True)
                            diff_feat = diff_feat.detach()
                        diff_feat = diff_feat.cpu()
                        labels = labels.cpu()
                        batch = diff_feat.size(0)
                        for i in range(batch):
                            if labels[i].item() not in diff_class_feature:
                                diff_class_feature[labels[i].item()] = []
                            diff_class_feature[labels[i].item()].append(diff_feat[i]) 
                    for label, feat in diff_class_feature.items():
                        feat = torch.stack(feat)
                        mean = feat.mean(dim=0)
                        self.mean_diff_feat.append(mean) 
                    del diff_class_feature

            if self.is_use_MoE:
                with torch.no_grad():
                    if self.task_count == 1:
                        if self.is_save_model:
                            self.all_model.append(deepcopy(self.temporal_module).to(self.device))
                           
                        else:
                            self.avg_model.load_state_dict(self.temporal_module.state_dict())  # init
                        self.task_count += 1
                        self.temporal_module = Temporal_Module(self.device, self.conf_feature_encoder).to(self.device)
                    else:
                        self.update_average()   # update self.avg_model
            else:
                self.task_count += 1

            # get t-1 temporal_module 
            if self.is_pos_distill:
                self.old_temporal_module =  deepcopy(self.temporal_module).to(self.device)
            else:
                self.old_temporal_module = None
    
    def train_task(self, task_id, train_dataloader_cil, train_dataloader_men, val_cilDatasetList, num_tasks=11):
        """Run training for one incremental task.

        Prints the one-based task number, then stores ``num_tasks`` and
        ``task_id`` on the instance. Only the single-phase configuration
        (``self.num_training_phases == 1``) actually trains; it delegates to
        ``self.train_phase`` with ``1`` as the second positional argument
        (presumably the phase index -- confirm against ``train_phase``) and
        ``modulate_vid=True, modulate_txt=False, split_batch=False,
        is_final=True``.

        Args:
            task_id: zero-based task index (printed as ``task_id + 1``).
            train_dataloader_cil: training loader forwarded to ``train_phase``.
            train_dataloader_men: accepted for interface compatibility; not
                used by this method.
            val_cilDatasetList: validation data forwarded to ``train_phase``.
            num_tasks: total number of tasks, stored on ``self.num_tasks``.

        Returns:
            ``None`` after single-phase training; ``0`` when
            ``num_training_phases`` is anything other than 1, in which case
            no training is performed.
        """
        print('Task # {}'.format(task_id+1))
        self.num_tasks = num_tasks
        self.task_id = task_id

        # Guard: multi-phase schedules are not handled here -- bail out early.
        if self.num_training_phases != 1:
            return 0

        self.train_phase(
            task_id,
            1,
            self.pre_pro_train_mode,
            self.curr_pro_train_mode,
            train_dataloader_cil,
            val_cilDatasetList,
            modulate_vid=True,
            modulate_txt=False,
            split_batch=False,
            is_final=True,
        )