import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np




def _gather_tokens(x, idx):
    """Select per-sample tokens ``idx`` (batch, k) from ``x`` (tokens, batch, dim); returns (k, batch, dim)."""
    num_tokens, batch, dim = x.shape[0], x.shape[1], x.shape[-1]
    # Turn per-sample token positions into row numbers of the flattened
    # (batch * tokens, dim) view built below.
    offsets = torch.arange(batch, device=idx.device).unsqueeze(1) * num_tokens
    flat_idx = (idx + offsets).reshape(-1)
    flat_x = x.permute(1, 0, 2).reshape(-1, dim)
    picked = torch.index_select(flat_x, 0, flat_idx)
    # Explicit sizes (instead of -1) keep the reshape valid when k == 0.
    return picked.reshape(batch, idx.shape[1], dim).permute(1, 0, 2)


def drop(x, attention, drop_tokens, option='drop', indices=None, return_indices=True):
    """
    Keep the tokens of ``x`` whose attention rows have the lowest entropy.

    Token 0 (the class embedding) is always kept. The remaining tokens are
    ranked by the entropy of their attention distribution and the lowest
    entropy ones are selected.

    Args:
        x: tokens laid out as (tokens, batch, dim).
        attention: attention probabilities whose last dim is summed into an
            entropy; (batch, tokens, keys) or (batch, heads, tokens, keys).
            With heads, the per-head entropies are averaged.
        drop_tokens: how much to drop; only ``abs(drop_tokens)`` is used.
            Below 1 it is the fraction of non-class tokens to drop, otherwise
            the absolute number of non-class tokens to drop. At least one
            non-class token is always kept.
        option: one of
            'drop'   - return only the kept tokens,
            'detach' - keep every token but stop gradients through the
                       non-selected ones,
            'merge'  - return the kept tokens followed by one extra token that
                       is the mean of the highest-entropy tokens.
            Any other value leaves ``x`` untouched (indices are still computed).
        indices: optional (batch, k) token indices from an earlier call. They
            are reused as-is unless more than k non-class tokens must be kept,
            in which case the selection is recomputed from ``attention``.
        return_indices: if True return ``(x, indices)``, otherwise just ``x``.

    Returns:
        The processed tokens, and (when ``return_indices``) the sorted
        (batch, k) indices of the kept tokens, class token included.
    """
    # Number of non-class tokens to keep.
    if abs(drop_tokens) < 1:
        N = max(1, int((len(x) - 1) * (1 - abs(drop_tokens))))
    else:
        N = max(1, len(x) - 1 - int(abs(drop_tokens)))

    # Entropy of every attention row. xlogy returns 0 where the probability is
    # exactly 0, whereas p * log(p) would be NaN there and corrupt the ranking.
    entropy = -torch.sum(torch.xlogy(attention, attention), dim=-1)
    if entropy.dim() == 3:  # multi-head attention: average over heads
        entropy = entropy.mean(dim=1)

    if indices is None or N > indices.shape[1]:
        # Lowest-entropy non-class tokens; +1 undoes the [:, 1:] slice.
        _, idx = torch.topk(-entropy[:, 1:], N, dim=1)
        idx = idx + 1
        cls_idx = torch.zeros(idx.shape[0], 1, dtype=torch.long, device=idx.device)
        idx = torch.cat([cls_idx, idx], dim=1)  # always keep the class embedding
        idx = torch.sort(idx, dim=1)[0]  # preserve the original token order
        indices = idx.clone()
    else:
        idx = indices.clone()

    # Use the real number of selected indices: supplied ``indices`` may hold
    # more columns than N + 1, and assuming N + 1 breaks the reshapes.
    num_kept = idx.shape[1]

    if option == 'drop':
        x = _gather_tokens(x, idx)
    elif option == 'detach':
        # x (tokens, batch, dim), idx (batch, k): 1 for selected tokens, 0 otherwise.
        mask = torch.zeros(x.shape[0], x.shape[1]).to(x.device)
        mask.scatter_(0, idx.T, 1)
        # Forward value is unchanged; gradients flow only through selected tokens.
        x = x * mask.unsqueeze(-1) + x.detach() * (1 - mask).unsqueeze(-1)
    elif option == 'merge':
        num_dropped = entropy.shape[1] - num_kept
        # Highest-entropy non-class tokens (smallest negated entropy).
        # NOTE(review): this set is ranked from ``attention``; when ``indices``
        # is supplied by the caller it is not guaranteed to be the exact
        # complement of the kept tokens - confirm against callers.
        _, neg_idx = torch.topk(-entropy[:, 1:], num_dropped, dim=1, largest=False)
        neg_idx = neg_idx + 1  # undo the [:, 1:] slice
        kept = _gather_tokens(x, idx)
        if num_dropped > 0:
            merged_target = _gather_tokens(x, neg_idx).mean(dim=0, keepdim=True)
            x = torch.cat([kept, merged_target], dim=0)
        else:
            # Nothing was dropped, so there is nothing to merge.
            x = kept

    if return_indices:
        return x, indices
    return x


def all_trials(x):
    """Call the built-in ``breakpoint()`` and return None; ``x`` is not used."""
    # NOTE(review): looks like an unfinished debugging placeholder - confirm
    # whether a real implementation is intended before relying on it.
    breakpoint()


