import numpy as np
import torch
import random
import os

def set_seed(seed):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU
    and CUDA generators with ``seed``, exports the seed as the
    ``PYTHONHASHSEED`` environment variable, and switches cuDNN to
    deterministic mode with autotuning (benchmark) turned off.

    Args:
        seed: Seed value handed unchanged to each library's seeding call.
    """
    random.seed(seed)
    # NOTE: hash randomization is fixed when an interpreter starts, so this
    # export only influences Python processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # NumPy first, then the PyTorch CPU generator, then all CUDA devices.
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Give up cuDNN autotuning in exchange for deterministic kernel choices.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True

def get_att_entropy_per_batch(attention_maps, traj_mask):
    """Return the masked mean attention entropy, averaged over the batch.

    Per-token entropies from ``get_att_entropy`` are averaged over the
    positions selected by the mask (last dim), and the result is then
    averaged over dim 0.

    Args:
        attention_maps: Attention weights. Dim 1 is read as the number of
            layers and entropy is taken over the last dim. Presumably shaped
            ``(batch, n_layers, 3 * T, 3 * T)`` -- TODO confirm with caller.
        traj_mask: Mask that is tiled three times along its last dim to
            cover the token sequence. Presumably ``(batch, T)`` with 1 for
            real timesteps and 0 for padding -- TODO confirm with caller.

    Returns:
        Tensor of masked mean entropies with dim 0 averaged away (one value
        per layer under the shapes assumed above).
    """
    n_layers = attention_maps.shape[1]
    # NOTE(review): the mask is tiled block-wise ([m_1..m_T] three times). If
    # the tokens are interleaved per timestep, the mask would need to be
    # interleaved as well -- verify against the model's token layout.
    att_map_mask = torch.cat([traj_mask, traj_mask, traj_mask], dim=-1).unsqueeze(1).repeat(1, n_layers, 1)
    att_entropy_for_each_token = get_att_entropy(attention_maps)

    masked_entropy_sum = torch.sum(att_entropy_for_each_token * att_map_mask, dim=-1)
    valid_token_count = torch.sum(att_map_mask, dim=-1)
    # A fully masked row has a count of 0 and a numerator built only from
    # `entropy * 0` terms; dividing would yield 0/0 = NaN and poison the
    # batch mean. Dividing by 1 instead makes such rows contribute 0.
    safe_token_count = valid_token_count.masked_fill(valid_token_count == 0, 1)
    entropy_per_batch_per_layer = masked_entropy_sum / safe_token_count

    attention_loss = torch.mean(entropy_per_batch_per_layer, 0)  # mean over batch
    return attention_loss



def get_att_entropy(attention_maps):
    """Compute the entropy of ``attention_maps`` along its last dimension.

    Evaluates ``sum(-p * log(p + 1e-10))`` over the final axis, so the
    returned tensor has one dimension fewer than the input.

    Args:
        attention_maps: Tensor whose last dimension holds the weights the
            entropy is taken over (presumably attention probabilities that
            sum to 1 -- not checked here).

    Returns:
        Tensor of entropies with the last dimension summed out.
    """
    # Small offset so entries that are exactly zero never hit log(0).
    eps = 1e-10
    log_weights = torch.log(attention_maps + eps)
    negative_weighted_logs = -attention_maps * log_weights
    return torch.sum(negative_weighted_logs, dim=-1)

