import torch.nn as nn
import torch.nn.functional as F
import torch
from utils.buffer import Buffer_audio
from einops import rearrange

class Random_comp(nn.Module):
    """Backbone wrapper that trains on a random subset of patches with logit replay.

    In training mode every sample gets an independent random ordering of its
    video and audio patches; only the leading ``1 - core_*_ratio`` fraction of
    each ordering is handed to the backbone. When the replay buffer is not
    empty, a batch of stored samples is compressed the same way and appended
    to the batch, and an MSE penalty (scaled by ``alpha``) between the stored
    logits and the freshly computed logits is returned as
    ``output['penalty_loss']``. In eval mode the inputs are passed to the
    backbone unchanged.
    """

    # Side length of one square patch; both modalities are cut with this size.
    patch_size = 16

    def __init__(self, model: nn.Module, batch_size: int, num_patches_v: int, num_patches_a: int,
                 core_video_ratio: float, core_audio_ratio: float, alpha: float, mem_args, device,
                 **kwargs) -> None:
        """Store the hyper-parameters and build the replay buffer.

        Args:
            model: backbone; must expose a ``transformer`` attribute, on which
                the two compression ratios are set.
            batch_size: number of samples requested from the buffer per step.
            num_patches_v: number of video patches per sample.
            num_patches_a: number of audio patches per sample.
            core_video_ratio: fraction of video patches that is dropped.
            core_audio_ratio: fraction of audio patches that is dropped.
            alpha: weight of the logit-replay penalty.
            mem_args: keyword arguments forwarded to ``Buffer_audio``.
            device: device used for index tensors and replayed data.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.backbone = model
        self.device = device
        self.batch_size = batch_size

        self.num_patches_v = num_patches_v
        self.num_patches_a = num_patches_a
        self.core_video_ratio = core_video_ratio
        self.core_audio_ratio = core_audio_ratio

        # Buffer initialization
        self.buffer = Buffer_audio(**mem_args, device=self.device)
        self.backbone.transformer.mask_comp_ratio_v = core_video_ratio
        self.backbone.transformer.mask_comp_ratio_a = core_audio_ratio

        self.alpha = alpha

        # Flags consumed outside this class (their readers are not visible here).
        self._req_penalty = True
        self._req_opt = False

    def forward(self, inputs):
        """Run the backbone; in training mode add compression, replay and penalty.

        Args:
            inputs: dict with at least ``"video_data"`` and ``"audio_data"``.

        Returns:
            The backbone output. In training mode it additionally carries
            ``"penalty_loss"`` (a shape-(1,) zero tensor while the buffer is
            empty), and the current batch is pushed into the buffer.
        """
        if not self.training:
            return self.backbone(inputs)

        num_new = len(inputs["video_data"])
        ids_v = self._random_patch_order(num_new, self.num_patches_v)
        ids_a = self._random_patch_order(num_new, self.num_patches_a)

        video_patches, audio_patches = self.importance_based_patch_selection(
            ids_v, ids_a, inputs["video_data"], inputs["audio_data"]
        )
        rand_inputs = {
            "video_data": video_patches,
            "audio_data": audio_patches,
            "att_map_av_ids": ids_v,
            "att_map_va_ids": ids_a,
        }

        buf_inputs = None
        if not self.buffer.is_empty():
            # Load past samples; `.to` (unlike `.cuda`) also works on CPU devices.
            buf_inputs = self.buffer.get_data(self.batch_size)
            buf_inputs = {k: v.to(self.device, non_blocking=True) for k, v in buf_inputs.items()}

            # The buffer may hand back fewer than `batch_size` samples, so size
            # the index tensors from what was actually returned.
            num_buf = len(buf_inputs["video_data"])
            buf_ids_v = self._random_patch_order(num_buf, self.num_patches_v)
            buf_ids_a = self._random_patch_order(num_buf, self.num_patches_a)

            buf_video_patches, buf_audio_patches = self.importance_based_patch_selection(
                buf_ids_v, buf_ids_a, buf_inputs["video_data"], buf_inputs["audio_data"]
            )

            # Replayed samples are appended after the current ones.
            rand_inputs["video_data"] = torch.cat((rand_inputs["video_data"], buf_video_patches), dim=0)
            rand_inputs["audio_data"] = torch.cat((rand_inputs["audio_data"], buf_audio_patches), dim=0)
            rand_inputs["att_map_av_ids"] = torch.cat((rand_inputs["att_map_av_ids"], buf_ids_v), dim=0)
            rand_inputs["att_map_va_ids"] = torch.cat((rand_inputs["att_map_va_ids"], buf_ids_a), dim=0)

        output = self.backbone(rand_inputs)

        if buf_inputs is not None:
            # Rows after the first `num_new` belong to the replayed samples.
            buf_logits_a = output["audio_output"][num_new:]
            buf_logits_v = output["video_output"][num_new:]

            penalty = self.alpha * F.mse_loss(buf_inputs["logits_a"], buf_logits_a) + \
                      self.alpha * F.mse_loss(buf_inputs["logits_v"], buf_logits_v)
        else:
            penalty = torch.zeros(1, device=self.device)

        output["penalty_loss"] = penalty

        # Store detached logits: they serve as fixed targets later and must not
        # keep this step's autograd graph alive.
        self.buffer.add_data(video_data=inputs["video_data"], audio_data=inputs["audio_data"],
                             logits_a=output["audio_output"][:num_new].detach(),
                             logits_v=output["video_output"][:num_new].detach())

        return output

    def _random_patch_order(self, num_samples, num_patches):
        """Return a (num_samples, num_patches) tensor of per-row random permutations."""
        noise = torch.rand(num_samples, num_patches, device=self.device)
        # Sorting uniform noise yields an independent random ordering per row.
        return torch.argsort(noise, dim=1)

    @staticmethod
    def _gather_leading(patches, order, drop_ratio):
        """Keep the patches named by the leading ``1 - drop_ratio`` share of ``order``.

        ``patches`` is (batch, num_patches, *patch_shape); ``order`` holds
        patch indices per sample.
        """
        num_keep = int(patches.shape[1] * (1 - drop_ratio))
        keep_ids = order[:, :num_keep]
        # Broadcast the ids over the real trailing patch shape instead of
        # assuming a fixed channel count and patch size.
        index = keep_ids[:, :, None, None, None].expand(-1, -1, *patches.shape[2:])
        return torch.gather(patches, dim=1, index=index)

    def importance_based_patch_selection(self, att_av_ids, att_va_ids, video_data, audio_data):
        """Cut both modalities into square patches and keep the leading ones.

        Args:
            att_av_ids: per-sample ordering of video patch indices.
            att_va_ids: per-sample ordering of audio patch indices.
            video_data: 5-D tensor (batch, time, channels, height, width) with
                height and width divisible by ``patch_size``.
            audio_data: 4-D tensor (batch, channels, time, freq) with time and
                freq divisible by ``patch_size``.

        Returns:
            Tuple ``(core_video_patches, core_audio_patches)``, each of shape
            (batch, kept_patches, channels, patch_size, patch_size).
        """
        p = self.patch_size

        # Video: bring channels in front of time, flatten (t, h, w) into one
        # patch axis, then make the patch axis the second dimension.
        video_patches = video_data.transpose(1, 2)
        video_patches = rearrange(video_patches, 'b c t (h p0) (w p1) -> b c (t h w) p0 p1', p0=p, p1=p)
        video_patches = video_patches.transpose(1, 2)  # B x patches x c x p x p
        core_video_patches = self._gather_leading(video_patches, att_av_ids, self.core_video_ratio)

        # Audio: same scheme over the (time, freq) grid.
        audio_patches = rearrange(audio_data, 'b c (t p0) (f p1) -> b c (t f) p0 p1', p0=p, p1=p)
        audio_patches = audio_patches.transpose(1, 2)  # B x patches x c x p x p
        core_audio_patches = self._gather_leading(audio_patches, att_va_ids, self.core_audio_ratio)

        return core_video_patches, core_audio_patches
