import time
import torch 
import torch.nn as nn  
import torch.nn.functional as F 

import numpy as np  
def safe_cross_entropy(p, logq, dim=-1):
    """Return ``-sum(p * logq)`` along ``dim``, skipping terms where ``p`` is 0.

    Wherever ``p`` is exactly zero the matching ``logq`` entry is swapped for
    1 before multiplying, so a ``-inf`` log-probability at that position
    contributes 0 to the sum instead of NaN (``0 * -inf``).
    """
    placeholder = torch.ones_like(logq)
    guarded_logq = torch.where(p != 0, logq, placeholder)
    weighted = p * guarded_logq
    return -weighted.sum(dim=dim)


def lossop_func(logits_train, ideal_probs_train, R_train, baseline, lengths, gamma_decay, entropy_weight):
    """Return the policy-gradient loss with a decayed entropy bonus.

    The result is ``mean((R_train - baseline) * neglogp)
    - entropy_weight * mean(entropy)``. Both ``neglogp`` and ``entropy`` are
    per-sequence sums over the time axis (dim 0) that skip every step at or
    beyond the sequence's length.

    Args:
        logits_train: Unnormalised scores; softmax is taken over dim 2.
            Assumed to share the ``(max_time_step, n_train, n_choices)``
            layout of ``ideal_probs_train`` -- TODO confirm with callers.
        ideal_probs_train: 3-D tensor ``(max_time_step, n_train, n_choices)``
            of target probabilities. Entries that are exactly zero are
            excluded from the cross-entropy.
        R_train: Reward term; presumably one value per sequence so that it
            lines up with the per-sequence sums -- TODO confirm.
        baseline: Value subtracted from ``R_train``.
        lengths: Number of valid steps per sequence (1-D with ``n_train``
            entries; list, ndarray or tensor).
        gamma_decay: Step ``t`` of the entropy is weighted by
            ``gamma_decay ** t``.
        entropy_weight: Multiplier of the entropy bonus.

    Returns:
        The loss tensor ``loss_gp + loss_entropy``.
    """
    max_time_step, _, _ = ideal_probs_train.shape
    device, dtype = logits_train.device, logits_train.dtype

    # Build the masks on the logits' device and dtype. A plain
    # torch.tensor(ndarray) is CPU float64, which fails against logits on
    # another device and silently promotes a float32 loss to float64.
    steps = torch.arange(max_time_step, device=device)
    seq_lengths = torch.as_tensor(lengths, device=device).reshape(1, -1)
    # mask_length[t, i] is 1 while step t is inside sequence i, else 0.
    mask_length = (steps.unsqueeze(1) < seq_lengths).to(dtype)
    decay = torch.pow(float(gamma_decay), steps.to(dtype))
    entropy_decay_mask = decay.unsqueeze(1) * mask_length

    probs = torch.nn.functional.softmax(logits_train, dim=2)
    logprobs = torch.nn.functional.log_softmax(logits_train, dim=2)

    # Policy-gradient term: advantage-weighted negative log-likelihood.
    neglogp_per_step = safe_cross_entropy(ideal_probs_train, logprobs, dim=2)
    neglogp = torch.sum(neglogp_per_step * mask_length, dim=0)
    loss_gp = torch.mean((R_train - baseline) * neglogp)

    # Entropy bonus: the minus sign makes higher entropy lower the loss.
    entropy_per_step = safe_cross_entropy(probs, logprobs, dim=2)
    entropy = torch.sum(entropy_per_step * entropy_decay_mask, dim=0)
    loss_entropy = -entropy_weight * torch.mean(entropy)

    return loss_gp + loss_entropy

def lossvar_func(logits_train, ideal_probs_train, R_train, baseline, lengths, gamma_decay, entropy_weight):
    """Return the policy-gradient loss with a decayed entropy bonus.

    Computes ``mean((R_train - baseline) * neglogp)
    - entropy_weight * mean(entropy)``, where ``neglogp`` and ``entropy`` are
    summed over the time axis (dim 0) for each sequence, ignoring steps at or
    past that sequence's length.

    Args:
        logits_train: Unnormalised scores, normalised here over dim 2.
            Assumed to be laid out like ``ideal_probs_train`` -- TODO confirm.
        ideal_probs_train: 3-D tensor ``(max_time_step, n_train, n_choices)``
            of target probabilities; exact zeros drop out of the
            cross-entropy.
        R_train: Reward term; presumably one value per sequence -- TODO
            confirm against callers.
        baseline: Value subtracted from ``R_train``.
        lengths: Valid step count of each sequence (1-D with ``n_train``
            entries; list, ndarray or tensor).
        gamma_decay: Base of the per-step entropy weight ``gamma_decay ** t``.
        entropy_weight: Multiplier of the entropy bonus.

    Returns:
        The loss tensor ``loss_gp + loss_entropy``.
    """
    max_time_step, _, _ = ideal_probs_train.shape
    device, dtype = logits_train.device, logits_train.dtype

    # Masks are created on the logits' device/dtype: converting a NumPy
    # array with torch.tensor gives a CPU float64 tensor, which breaks for
    # logits on another device and upcasts a float32 loss to float64.
    time_index = torch.arange(max_time_step, device=device)
    seq_lengths = torch.as_tensor(lengths, device=device).reshape(1, -1)
    # mask_length[t, i] is 1 while step t is inside sequence i, else 0.
    mask_length = (time_index.unsqueeze(1) < seq_lengths).to(dtype)
    decay = torch.pow(float(gamma_decay), time_index.to(dtype))
    entropy_decay_mask = decay.unsqueeze(1) * mask_length

    probs = torch.nn.functional.softmax(logits_train, dim=2)
    logprobs = torch.nn.functional.log_softmax(logits_train, dim=2)

    # Policy-gradient term: advantage-weighted negative log-likelihood.
    neglogp_per_step = safe_cross_entropy(ideal_probs_train, logprobs, dim=2)
    neglogp = torch.sum(neglogp_per_step * mask_length, dim=0)
    loss_gp = torch.mean((R_train - baseline) * neglogp)

    # Entropy bonus: subtracting it rewards higher-entropy policies.
    entropy_per_step = safe_cross_entropy(probs, logprobs, dim=2)
    entropy = torch.sum(entropy_per_step * entropy_decay_mask, dim=0)
    loss_entropy = -entropy_weight * torch.mean(entropy)

    return loss_gp + loss_entropy


def losstwovar_func(logits_train, ideal_probs_train, R_train, baseline, lengths, gamma_decay, entropy_weight, additional_mask):
    """Return the policy-gradient loss with an entropy bonus and an extra mask.

    Computes ``mean((R_train - baseline) * neglogp)
    - entropy_weight * mean(entropy)``. ``neglogp`` and ``entropy`` are
    per-sequence sums over the time axis (dim 0); a step counts only if it
    lies inside the sequence's length, and its contribution is scaled by
    ``additional_mask``.

    Args:
        logits_train: Unnormalised scores, normalised here over dim 2.
            Assumed to be laid out like ``ideal_probs_train`` -- TODO confirm.
        ideal_probs_train: 3-D tensor ``(max_time_step, n_train, n_choices)``
            of target probabilities; exact zeros drop out of the
            cross-entropy.
        R_train: Reward term; presumably one value per sequence -- TODO
            confirm against callers.
        baseline: Value subtracted from ``R_train``.
        lengths: Valid step count of each sequence (1-D with ``n_train``
            entries; list, ndarray or tensor).
        gamma_decay: Base of the per-step entropy weight ``gamma_decay ** t``.
        entropy_weight: Multiplier of the entropy bonus.
        additional_mask: Multiplied element-wise into both the length mask
            and the decayed entropy mask; presumably
            ``(max_time_step, n_train)`` or broadcastable to it -- TODO
            confirm.

    Returns:
        The loss tensor ``loss_gp + loss_entropy``.
    """
    max_time_step, _, _ = ideal_probs_train.shape
    device, dtype = logits_train.device, logits_train.dtype

    # Every mask lives on the logits' device/dtype: torch.tensor(ndarray)
    # would be CPU float64, which breaks for logits on another device and
    # upcasts a float32 loss to float64.
    steps = torch.arange(max_time_step, device=device)
    seq_lengths = torch.as_tensor(lengths, device=device).reshape(1, -1)
    # mask_length[t, i] is 1 while step t is inside sequence i, else 0.
    mask_length = (steps.unsqueeze(1) < seq_lengths).to(dtype)
    extra_mask = torch.as_tensor(additional_mask, dtype=dtype, device=device)
    combined_mask = mask_length * extra_mask

    decay = torch.pow(float(gamma_decay), steps.to(dtype))
    entropy_decay_mask = decay.unsqueeze(1) * mask_length * extra_mask

    probs = torch.nn.functional.softmax(logits_train, dim=2)
    logprobs = torch.nn.functional.log_softmax(logits_train, dim=2)

    # Policy-gradient term: advantage-weighted negative log-likelihood.
    neglogp_per_step = safe_cross_entropy(ideal_probs_train, logprobs, dim=2)
    neglogp = torch.sum(neglogp_per_step * combined_mask, dim=0)
    loss_gp = torch.mean((R_train - baseline) * neglogp)

    # Entropy bonus: subtracting it rewards higher-entropy policies.
    entropy_per_step = safe_cross_entropy(probs, logprobs, dim=2)
    entropy = torch.sum(entropy_per_step * entropy_decay_mask, dim=0)
    loss_entropy = -entropy_weight * torch.mean(entropy)

    return loss_gp + loss_entropy