


import torch
import torch.nn as nn
from . import transformer as tr
import numpy as np
from .models import META_ARCHITECTURES as registry
from .feature_head import build_feature_head
from .temporal_head import build_temporal_head
import random

class LSTR(nn.Module):
    """Long Short-Term Transformer (LSTR) with optional task-specific attention branches.

    The first ``LONG_MEMORY_NUM_SAMPLES`` frames of each input (indexed along
    dim 1) feed the long-term memory and the remaining frames feed the work
    (short-term) memory.  The long memory is encoded by ``enc_modules`` and used
    as the cross-attention ``memory`` of ``dec_modules``, which decodes the work
    memory; ``classifier`` maps the decoder output to ``DATA.NUM_CLASSES`` scores.

    ``multihead_attention_list``/``task_head`` (long memory) and
    ``work_multihead_attention_list``/``work_task_head`` (work memory) start
    empty; when populated, ``forward`` adds their outputs to the memories as a
    residual followed by ``norm1``/``norm2``.

    NOTE(review): those containers are plain Python lists, so modules stored in
    them are NOT registered submodules (absent from ``parameters()`` and
    ``state_dict()``, not moved by ``.to()``).  They are left as lists to keep
    checkpoint keys unchanged -- confirm the caller registers them elsewhere.
    """

    def __init__(self, cfg):
        super(LSTR, self).__init__()

        # Long-term memory feature head (only built when long memory is used).
        self.long_memory_num_samples = cfg.MODEL.LSTR.LONG_MEMORY_NUM_SAMPLES
        self.long_enabled = self.long_memory_num_samples > 0
        if self.long_enabled:
            self.feature_head_long = build_feature_head(cfg)

        # Work (short-term) memory.  forward() always decodes the work memory and
        # d_model is read from its feature head, so it is mandatory; fail with a
        # clear message instead of an AttributeError on ``feature_head_work``.
        self.work_memory_num_samples = cfg.MODEL.LSTR.WORK_MEMORY_NUM_SAMPLES
        self.work_enabled = self.work_memory_num_samples > 0
        if not self.work_enabled:
            raise ValueError(
                'LSTR requires MODEL.LSTR.WORK_MEMORY_NUM_SAMPLES > 0, got {}'.format(
                    self.work_memory_num_samples))
        # Build the work head first so d_model is known when sizing the norms, but
        # assign it to ``self`` after them so submodule/parameter registration
        # order (and hence optimizer state layout) is unchanged.
        feature_head_work = build_feature_head(cfg)
        self.d_model = feature_head_work.d_model

        self.work_task_head = []
        self.task_head = []
        self.dropout1 = nn.Dropout(0.1)  # NOTE(review): not used anywhere in this file
        # Post-residual norms for the long / work task-attention branches; sized
        # from d_model so they match the memories they normalize.
        self.norm1 = nn.LayerNorm(self.d_model)
        self.norm2 = nn.LayerNorm(self.d_model)
        self.feature_head_work = feature_head_work

        self.num_heads = cfg.MODEL.LSTR.NUM_HEADS
        self.dim_feedforward = cfg.MODEL.LSTR.DIM_FEEDFORWARD
        self.dropout = cfg.MODEL.LSTR.DROPOUT
        self.activation = cfg.MODEL.LSTR.ACTIVATION
        self.num_classes = cfg.DATA.NUM_CLASSES

        self.pos_encoding = tr.PositionalEncoding(self.d_model, self.dropout)
        # Task-specific attention branches and bookkeeping; presumably populated
        # by incremental-learning code outside this class -- TODO confirm.
        self.multihead_attention_list = []
        self.work_multihead_attention_list = []
        self.multihead_attention_weights = []
        self.param_wgt = np.zeros((cfg.DATA.NUM_CLASSES))
        self.work_param_wgt = np.zeros((cfg.DATA.NUM_CLASSES))
        self.acc = []

        # Long-memory encoder: one stage per ENC_MODULE entry.  param[0] is the
        # number of learned queries (-1 selects a query-free self-attention
        # encoder); param[1] and param[2] go to the transformer stack and
        # tr.layer_norm respectively.
        if self.long_enabled:
            self.enc_queries = nn.ModuleList()
            self.enc_modules = nn.ModuleList()
            for param in cfg.MODEL.LSTR.ENC_MODULE:
                if param[0] != -1:
                    self.enc_queries.append(nn.Embedding(param[0], self.d_model))
                    enc_layer = tr.TransformerDecoderLayer(
                        self.d_model, self.num_heads, self.dim_feedforward,
                        self.dropout, self.activation)
                    self.enc_modules.append(tr.TransformerDecoder(
                        enc_layer, param[1], tr.layer_norm(self.d_model, param[2])))
                else:
                    self.enc_queries.append(None)
                    enc_layer = tr.TransformerEncoderLayer(
                        self.d_model, self.num_heads, self.dim_feedforward,
                        self.dropout, self.activation)
                    self.enc_modules.append(tr.TransformerEncoder(
                        enc_layer, param[1], tr.layer_norm(self.d_model, param[2])))
        else:
            self.register_parameter('enc_queries', None)
            self.register_parameter('enc_modules', None)

        # Work-memory decoder: cross-attends to the long memory when it exists,
        # otherwise a plain self-attention encoder over the work memory.
        if self.long_enabled:
            param = cfg.MODEL.LSTR.DEC_MODULE
            dec_layer = tr.TransformerDecoderLayer(
                self.d_model, self.num_heads, self.dim_feedforward,
                self.dropout, self.activation)
            self.dec_modules = tr.TransformerDecoder(
                dec_layer, param[1], tr.layer_norm(self.d_model, param[2]))
        else:
            param = cfg.MODEL.LSTR.DEC_MODULE
            dec_layer = tr.TransformerEncoderLayer(
                self.d_model, self.num_heads, self.dim_feedforward,
                self.dropout, self.activation)
            self.dec_modules = tr.TransformerEncoder(
                dec_layer, param[1], tr.layer_norm(self.d_model, param[2]))

        self.classifier = nn.Linear(self.d_model, self.num_classes)

    def _apply_task_attention(self, memories, attention_layers, task_heads, norm):
        """Add the task-attention residual to ``memories`` and normalize.

        Returns ``memories`` unchanged when ``attention_layers`` is empty.
        Otherwise, for each index ``i`` of ``self.multihead_attention_weights``,
        self-attention layer ``attention_layers[i]`` is applied to a copy of the
        memories and projected by ``task_heads[i]``; the projections are summed,
        added to ``memories`` and passed through ``norm``.
        """
        if not attention_layers:
            return memories
        source = memories.clone()
        residual = 0
        for i in range(len(self.multihead_attention_weights)):
            attended, _ = attention_layers[i](source, source, source)
            residual = task_heads[i](attended) + residual
        return norm(memories + residual)

    def forward(self, visual_inputs, motion_inputs, memory_key_padding_mask=None):
        """Score the work-memory positions of a clip.

        Args:
            visual_inputs, motion_inputs: inputs indexed along dim 1 by frame;
                the first ``long_memory_num_samples`` frames feed the long
                memory, the rest feed the work memory.
            memory_key_padding_mask: forwarded to the first encoder stage only.

        Returns:
            Classifier scores over the work-memory positions, transposed back
            with ``transpose(0, 1)`` (presumably to batch-first -- TODO confirm).
        """
        memory = None
        if self.long_enabled:
            long_memories = self.pos_encoding(self.feature_head_long(
                visual_inputs[:, :self.long_memory_num_samples],
                motion_inputs[:, :self.long_memory_num_samples],
            ).transpose(0, 1))
            long_memories = self._apply_task_attention(
                long_memories, self.multihead_attention_list, self.task_head, self.norm1)

            if len(self.enc_modules) > 0:
                # Tile each learned query along dim 1 to match the memories.
                enc_queries = [
                    enc_query.weight.unsqueeze(1).repeat(1, long_memories.shape[1], 1)
                    if enc_query is not None else None
                    for enc_query in self.enc_queries
                ]
                # Only the first stage receives the padding mask.
                if enc_queries[0] is not None:
                    long_memories = self.enc_modules[0](
                        enc_queries[0], long_memories,
                        memory_key_padding_mask=memory_key_padding_mask)
                else:
                    long_memories = self.enc_modules[0](long_memories)
                for enc_query, enc_module in zip(enc_queries[1:], self.enc_modules[1:]):
                    if enc_query is not None:
                        long_memories = enc_module(enc_query, long_memories)
                    else:
                        long_memories = enc_module(long_memories)
            memory = long_memories

        # The work memory always exists (enforced in __init__).  Its positional
        # encoding is offset by the long-memory length.
        work_memories = self.pos_encoding(self.feature_head_work(
            visual_inputs[:, self.long_memory_num_samples:],
            motion_inputs[:, self.long_memory_num_samples:],
        ).transpose(0, 1), padding=self.long_memory_num_samples)
        work_memories = self._apply_task_attention(
            work_memories, self.work_multihead_attention_list, self.work_task_head, self.norm2)

        # Square subsequent-position mask over the work memory (from tr).
        mask = tr.generate_square_subsequent_mask(work_memories.shape[0])
        mask = mask.to(work_memories.device)

        if self.long_enabled:
            output = self.dec_modules(work_memories, memory=memory, tgt_mask=mask)
        else:
            output = self.dec_modules(work_memories, src_mask=mask)

        score = self.classifier(output)
        return score.transpose(0, 1)

    def add_samples_to_mem(self, cilsettask, data, m, type_sampling='icarl'):
        """Update ``self.memory`` so each class keeps at most ``m`` exemplar videos.

        Args:
            cilsettask: provides ``get_dataloader(dict)``; its loader yields
                4-tuples whose 2nd item is a video name and 3rd a video tensor.
            data: mapping ``class_id -> videos`` for the classes being added.
            m: maximum number of exemplars kept per class; applied to every
                class in memory, old and new.
            type_sampling: ``'icarl'`` for greedy selection on encoder features;
                any other value samples videos uniformly at random.

        Relies on ``self.memory``, ``self.device`` and ``self.feature_encoder``,
        which this class does not set -- presumably assigned by training code.
        """
        if type_sampling == 'icarl':
            for class_id, videos in data.items():
                data_class = {class_id: videos}
                class_loader = cilsettask.get_dataloader(data_class)
                features = []
                video_names = []
                for _, video_name, video, _ in class_loader:
                    video = video.to(self.device)
                    feature = self.feature_encoder(video).data.cpu().numpy()
                    # NOTE(review): normalizes by the norm of the whole batch and
                    # keeps only row 0 -- assumes a batch size of 1; confirm.
                    feature = feature / np.linalg.norm(feature)
                    features.append(feature[0])
                    video_names.append(video_name)

                features = np.array(features)
                class_mean = np.mean(features, axis=0)
                class_mean = class_mean / np.linalg.norm(class_mean)

                # Greedy iCaRL-style selection: each candidate mean adds one
                # video's feature to the sum of already-selected exemplars, and a
                # not-yet-selected video among the k+1 candidates nearest the
                # class mean is taken.
                exemplar_set = []
                exemplar_features = []
                list_selected_idx = []
                for k in range(m):
                    S = np.sum(exemplar_features, axis=0)
                    phi = features
                    mu = class_mean
                    mu_p = 1.0 / (k + 1) * (phi + S)
                    # NOTE(review): divides by the norm of the whole candidate
                    # matrix, not per row -- confirm this is intended.
                    mu_p = mu_p / np.linalg.norm(mu_p)

                    dist = np.sqrt(np.sum((mu - mu_p) ** 2, axis=1))
                    # argpartition yields the k+1 nearest indices (unordered
                    # among themselves); near the end every index qualifies.
                    if k <= len(dist) - 2:
                        list_idx = np.argpartition(dist, k)[:k + 1]
                    elif k < len(dist):
                        fixed_k = len(dist) - 2
                        list_idx = np.argpartition(dist, fixed_k)[:fixed_k + 2]
                    else:
                        break  # m exceeds the number of videos in this class

                    for idx in list_idx:
                        if idx not in list_selected_idx:
                            list_selected_idx.append(idx)
                            # video_name comes from the loader; [0] presumably
                            # unwraps a batch of one -- TODO confirm.
                            exemplar_set.append(video_names[idx][0])
                            exemplar_features.append(features[idx])
                            break

                self.memory[class_id] = exemplar_set

            # Shrink every class (old and new) to at most m exemplars.
            self.memory = {class_id: videos[:m] for class_id, videos in self.memory.items()}
        else:
            merged = {**self.memory, **data}
            sampled = {}
            for class_id, videos in merged.items():
                # Shuffle a copy so the caller's ``data`` lists (and the previous
                # memory lists) are not reordered in place.
                shuffled = list(videos)
                random.shuffle(shuffled)
                sampled[class_id] = shuffled[:m]
            self.memory = sampled

        for class_id, videos in self.memory.items():
            print('Memory... Class: {}, num videos: {}'.format(class_id, len(videos)))


@registry.register('LSTR')
class LSTRStream(LSTR):
    """LSTR variant supporting incremental (streaming) inference.

    ``stream_inference`` keeps a rolling cache of long-memory features, so each
    call only needs the features of newly arrived long-memory frames.  It also
    caches the output of the first encoder stage, so a call that provides no new
    long-memory frames can skip recomputing it.
    """

    def __init__(self, cfg):
        super(LSTRStream, self).__init__(cfg)

        # Rolling window of long-memory features (window along dim 0, before
        # positional encoding); None until the first call with long inputs.
        self.long_memories_cache = None
        # Output of enc_modules[0] for the current window; reused by calls that
        # provide no new long-memory inputs.
        self.compressed_long_memories_cache = None

    def _expand_enc_queries(self, batch_size):
        """Tile each learned encoder query along dim 1 (None for query-free stages)."""
        return [
            enc_query.weight.unsqueeze(1).repeat(1, batch_size, 1)
            if enc_query is not None else None
            for enc_query in self.enc_queries
        ]

    def _run_remaining_enc_modules(self, enc_queries, long_memories):
        """Apply encoder stages 1..N, i.e. everything after the streamed first stage."""
        for enc_query, enc_module in zip(enc_queries[1:], self.enc_modules[1:]):
            if enc_query is None:
                long_memories = enc_module(long_memories)
            else:
                long_memories = enc_module(enc_query, long_memories)
        return long_memories

    def stream_inference(self,
                         long_visual_inputs,
                         long_motion_inputs,
                         work_visual_inputs,
                         work_motion_inputs,
                         memory_key_padding_mask=None):
        """Score the work-memory positions, reusing cached long-memory state.

        Args:
            long_visual_inputs, long_motion_inputs: inputs for the newly arrived
                long-memory frame(s); if either is None, the compressed long
                memory cached by an earlier call is reused.
            work_visual_inputs, work_motion_inputs: work-memory inputs.
            memory_key_padding_mask: forwarded to the first encoder stage.

        Returns:
            Classifier scores transposed with ``transpose(0, 1)``, as in
            ``LSTR.forward``.

        Raises:
            AssertionError: if long-term memory or the LSTR encoder is disabled.
            RuntimeError: if no long-memory inputs are given and nothing has been
                cached yet.

        NOTE(review): unlike ``LSTR.forward``, this path does not apply the
        task-specific attention branches (``multihead_attention_list`` /
        ``work_multihead_attention_list``), so scores can differ from
        ``forward`` once those are populated -- confirm this is intended.
        """
        assert self.long_enabled, 'Long-term memory cannot be empty for stream inference'
        assert len(self.enc_modules) > 0, 'LSTR encoder cannot be disabled for stream inference'

        if (long_visual_inputs is not None) and (long_motion_inputs is not None):
            new_long_memories = self.feature_head_long(
                long_visual_inputs,
                long_motion_inputs,
            ).transpose(0, 1)

            # Slide the window: drop the first entry along dim 0 and append the
            # new features.
            if self.long_memories_cache is None:
                self.long_memories_cache = new_long_memories
            else:
                self.long_memories_cache = torch.cat((
                    self.long_memories_cache[1:], new_long_memories
                ))

            long_memories = self.long_memories_cache
            pos = self.pos_encoding.pe[:self.long_memory_num_samples, :]
            enc_queries = self._expand_enc_queries(long_memories.shape[1])

            # The first encoder stage has its own streaming entry point; cache
            # its output for calls without new long-memory inputs.
            long_memories = self.enc_modules[0].stream_inference(
                enc_queries[0], long_memories, pos,
                memory_key_padding_mask=memory_key_padding_mask)
            self.compressed_long_memories_cache = long_memories
        else:
            if self.compressed_long_memories_cache is None:
                raise RuntimeError(
                    'stream_inference was called without long-memory inputs '
                    'before any long-memory frames were provided')
            long_memories = self.compressed_long_memories_cache
            enc_queries = self._expand_enc_queries(long_memories.shape[1])

        long_memories = self._run_remaining_enc_modules(enc_queries, long_memories)

        # Work memory; its positional encoding is offset by the long-memory
        # length.  LSTR's constructor always builds the work feature head.
        work_memories = self.pos_encoding(self.feature_head_work(
            work_visual_inputs,
            work_motion_inputs,
        ).transpose(0, 1), padding=self.long_memory_num_samples)

        mask = tr.generate_square_subsequent_mask(work_memories.shape[0])
        mask = mask.to(work_memories.device)

        # long_enabled is asserted above, so dec_modules is the cross-attending
        # decoder.
        output = self.dec_modules(
            work_memories,
            memory=long_memories,
            tgt_mask=mask,
        )

        score = self.classifier(output)
        return score.transpose(0, 1)
