import torch.nn as nn
import torch.nn.functional as F
import torch
from utils.buffer import Buffer_audio
from einops import rearrange


class MATS_der_pp(nn.Module):
    """Replay-based continual-learning wrapper that trains ``model`` on selected video patches.

    During training each batch is reduced in two steps:

    1. It is cut down to at most ``num_frames`` frames. Frames are chosen by how
       much their patch embeddings change relative to the previous frame.
    2. It is cut down to the ``1 - core_video_ratio`` fraction of video patches
       with the largest frame-to-frame embedding change.

    Samples replayed from ``self.buffer`` are reduced the same way and appended to
    the batch. An MSE penalty weighted by ``alpha`` then pulls the backbone's
    current audio/video outputs on those samples towards the outputs stored with
    them (a DER-style logit-replay term, per the class name). Outside training the
    backbone is called unchanged.
    """

    def __init__(self, model: nn.Module, batch_size: int, device, alpha: float, mem_args,
                 core_video_ratio: float, **kwargs) -> None:
        """
        Args:
            model: Backbone. Must expose ``transformer.patch_embed_v`` (used to
                score frames/patches) and ``transformer.mask_comp_ratio_v``
                (set here to ``core_video_ratio``).
            batch_size: Number of samples requested from the replay buffer per step.
            device: Device that replayed samples and auxiliary tensors are moved to.
            alpha: Weight of the logit-matching penalty.
            mem_args: Keyword arguments forwarded to ``Buffer_audio``.
            core_video_ratio: Fraction of video tokens that are dropped; the
                remaining ``1 - core_video_ratio`` fraction is kept.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        super().__init__()
        self.backbone = model
        self.device = device
        self.batch_size = batch_size

        # Der hyperparameter: weight of the logit-replay penalty.
        self.alpha = alpha
        # Replay buffer holding past (frame-selected) videos, audio and logits.
        self.buffer = Buffer_audio(**mem_args, device=self.device)

        self.core_video_ratio = core_video_ratio
        self.backbone.transformer.mask_comp_ratio_v = core_video_ratio
        # Not used in this class; presumably read elsewhere — TODO confirm.
        self.num_freq_tokens = 8

        # Maximum number of frames kept per clip during training.
        self.num_frames = 4

        # Flags presumably read by the surrounding training loop (not used here).
        self._req_penalty = True
        self._req_opt = False

    def forward(self, inputs):
        """Run the backbone; in training mode, on selected core patches plus replayed samples.

        Args:
            inputs: Dict with ``video_data`` of shape (N, T, C, H, W) and
                ``audio_data`` in training mode. If ``inputs['retrieval']`` is
                truthy, the backbone is called on ``inputs`` directly.

        Returns:
            The backbone output dict. In training mode it also holds
            ``penalty_loss``: the alpha-weighted MSE between current and stored
            buffer logits, or a one-element zero tensor when the buffer is empty.

        Note:
            In training mode ``inputs['video_data']`` is replaced in place by the
            frame-selected video. That video, not the raw one, is what gets
            stored in the replay buffer.
        """
        if ('retrieval' in inputs and inputs['retrieval']) or not self.training:
            return self.backbone(inputs)

        video = inputs['video_data']
        N, _, C, H, W = video.shape

        # Keep the selected frames (ids come back sorted in time order).
        frame_ids = self.mats_frame_selection(video)[:, :self.num_frames]
        video = torch.gather(
            video, dim=1,
            index=frame_ids[:, :, None, None, None].expand(-1, -1, C, H, W),
        )
        inputs['video_data'] = video

        # Rank current patches by temporal embedding change and keep the top ones.
        v_sort_ids = self.mats_sampling(video=video)
        core_inputs = {
            "video_data": self.importance_based_patch_selection(v_sort_ids, video),
            "audio_data": inputs['audio_data'],
            "att_map_av_ids": v_sort_ids,
        }

        buf_inputs = None
        if not self.buffer.is_empty():
            # Replay past samples and reduce them exactly like the current batch.
            buf_inputs = {k: v.to(self.device, non_blocking=True)
                          for k, v in self.buffer.get_data(self.batch_size).items()}
            buf_v_sort_ids = self.mats_sampling(video=buf_inputs['video_data'])
            buf_core_video_patches = self.importance_based_patch_selection(
                buf_v_sort_ids, buf_inputs['video_data']
            )

            core_inputs['video_data'] = torch.cat((core_inputs['video_data'], buf_core_video_patches), dim=0)
            core_inputs['audio_data'] = torch.cat((core_inputs['audio_data'], buf_inputs['audio_data']), dim=0)
            core_inputs['att_map_av_ids'] = torch.cat((core_inputs['att_map_av_ids'], buf_v_sort_ids), dim=0)

        output = self.backbone(core_inputs)

        if buf_inputs is not None:
            # Rows N: belong to the replayed samples. Slicing at N (rather than
            # -batch_size) stays correct when the buffer returns fewer samples.
            buf_logits_a = output["audio_output"][N:]
            buf_logits_v = output["video_output"][N:]
            penalty = self.alpha * F.mse_loss(buf_logits_a, buf_inputs['logits_a']) + \
                      self.alpha * F.mse_loss(buf_logits_v, buf_inputs['logits_v'])
        else:
            penalty = torch.zeros(1, device=self.device)

        output['penalty_loss'] = penalty

        # Detach so the stored targets never keep this step's autograd graph alive.
        self.buffer.add_data(video_data=video, audio_data=inputs['audio_data'],
                             logits_a=output["audio_output"][:N].detach(),
                             logits_v=output["video_output"][:N].detach())
        return output

    def _transition_distances(self, video):
        """Return per-patch L2 distances between consecutive frames' patch embeddings.

        Args:
            video: Tensor of shape (N, T, C, H, W).

        Returns:
            Tensor of shape (N, T - 1, L), where L is the number of patches per
            frame produced by ``backbone.transformer.patch_embed_v``. It is computed
            without autograd, because only rankings derived from it are used.
        """
        N, T, C, H, W = video.shape
        with torch.no_grad():
            x_v = self.backbone.transformer.patch_embed_v(video.reshape(N * T, C, H, W))
            _, L, D = x_v.shape
            embeds = x_v.reshape(N, T, L, D)
            return torch.norm(embeds[:, :-1] - embeds[:, 1:], p=2, dim=3)

    def mats_sampling(self, video):
        """Rank all video patches by how much their embedding changes over time.

        Each patch of frame t >= 1 is scored by the L2 distance between its
        embeddings in frames t-1 and t. Frame 0 reuses the scores of the 0->1
        transition.

        Args:
            video: Tensor of shape (N, T, C, H, W).

        Returns:
            LongTensor (N, T * L) of flattened, frame-major patch indices
            (``t * L + l``), sorted by descending score.
        """
        distance = self._transition_distances(video)
        importance = torch.cat((distance[:, 0, :], distance.flatten(1)), dim=1)
        return torch.argsort(importance, dim=1, descending=True)

    def mats_frame_selection(self, video):
        """Sample up to ``self.num_frames`` frames, favouring frames with large embedding change.

        A patch of transition t -> t+1 counts as "important" if its distance is
        among the top ``1 - core_video_ratio`` fraction over all transitions.
        Frame t+1 is drawn without replacement, with probability proportional to
        the number of important patches in that transition. Frame 0 is always kept.

        Args:
            video: Tensor of shape (N, T, C, H, W).

        Returns:
            LongTensor (N, min(num_frames, T)) of increasing frame indices,
            starting with 0.
        """
        N, T = video.shape[:2]
        # Frame 0 is always kept, so only num_frames - 1 further frames are drawn.
        # Drawing all T - 1 transitions would always return every frame, and the
        # scores would have no effect once the caller truncates to num_frames.
        num_samples = min(self.num_frames, T) - 1
        if num_samples <= 0:
            return torch.zeros(N, 1, dtype=torch.int64, device=video.device)

        distance = self._transition_distances(video)  # (N, T-1, L)
        L = distance.shape[2]
        importance = distance.flatten(1)

        # Mark the top (1 - core_video_ratio) fraction of patch transitions.
        num_top = int((1 - self.core_video_ratio) * (T - 1) * L)
        top_ids = torch.argsort(importance, dim=1, descending=True)[:, :num_top]
        mask = torch.zeros(importance.shape, device=importance.device)
        mask.scatter_(1, top_ids, 1.0)
        score = mask.reshape(N, T - 1, L).sum(dim=2)
        # multinomial rejects all-zero rows (e.g. core_video_ratio == 1): use uniform weights.
        score = torch.where(score.sum(dim=1, keepdim=True) > 0, score, torch.ones_like(score))

        # Transition i leads into frame i + 1.
        ids = torch.multinomial(score, num_samples, replacement=False) + 1
        ids = torch.cat([torch.zeros_like(ids[:, :1]), ids], dim=1)
        return torch.sort(ids, dim=1).values

    def compute_core_video_indices(self, v_val):
        """Return the indices that sort ``v_val`` along dim 1 in descending order."""
        v_val_ids = torch.argsort(v_val, dim=1, descending=True)
        return v_val_ids

    def importance_based_patch_selection(self, att_av_ids, video_data, patch_size=16):
        """Gather the highest-ranked video patches.

        Args:
            att_av_ids: LongTensor (N, T * h * w) of patch indices sorted by
                importance (e.g. the output of ``mats_sampling``).
            video_data: Tensor (N, T, C, H, W). H and W must be divisible by
                ``patch_size``.
            patch_size: Side length of the square patches (default 16).

        Returns:
            Tensor (N, K, C, patch_size, patch_size) holding the first
            K = int(T * h * w * (1 - core_video_ratio)) patches listed in
            ``att_av_ids``, in that order.
        """
        # Patches are indexed in (t, h, w) order. This assumes patch_embed_v
        # orders patches row-major within a frame, so indices line up with the
        # scores — TODO confirm.
        video_patches = rearrange(video_data, 'b t c (h p0) (w p1) -> b (t h w) c p0 p1',
                                  p0=patch_size, p1=patch_size)
        _, L_V, C = video_patches.shape[:3]
        num_core_video_tokens = int(L_V * (1 - self.core_video_ratio))
        video_keep_ids = att_av_ids[:, :num_core_video_tokens]
        index = video_keep_ids[:, :, None, None, None].expand(-1, -1, C, patch_size, patch_size)
        return torch.gather(video_patches, dim=1, index=index)



