import torch
from ..peptide_sampling import forward_reward
import torch.nn.functional as F

def sa_double_loss(env, sa_forward_net, sa_logz, forward_net, logz,
                   rnd, rnd_tgt, beta_e=0.25, beta_i=1.0, beta_sn=1.0, training=True):
    """Roll out ``env`` with ``sa_forward_net`` and return three scalar losses.

    For up to ``env.seq_size`` steps, every row that ``env.alive`` still marks
    as active gets an action from ``env.get_actions(sa_logits, training)``.
    The log-probability of that action is accumulated under both
    ``sa_forward_net`` and ``forward_net``, and the L2 norm of
    ``rnd(state) - rnd_tgt(state)`` on the pre-action state is accumulated as
    an intrinsic reward ``r_i``.

    Side effects: ``env.state[:, t]`` is overwritten with the chosen actions
    and ``env.alive`` is updated in place (a row stays alive only while its
    action is non-zero).

    Args:
        env: environment exposing ``batch_size``, ``seq_size``, ``state``,
            ``alive``, ``get_actions(logits, training)`` and ``log_reward()``.
        sa_forward_net: policy whose logits drive the action choice.
        sa_logz: log-partition term used in the first returned loss.
        forward_net: second policy, scored on the same actions.
        logz: log-partition term used in the second returned loss.
        rnd: network whose output is compared against ``rnd_tgt``.
        rnd_tgt: network providing the comparison target for ``rnd``.
        beta_e: multiplier on ``env.log_reward()`` before exponentiation.
        beta_i: exponent applied to the accumulated intrinsic reward.
        beta_sn: exponent applied to the combined reward.
        training: forwarded unchanged to ``env.get_actions``.

    Returns:
        Tuple ``(sa_tb_loss, tb_loss, rnd_loss)`` where
        ``sa_tb_loss = mean((sa_logz + sa_log_p - log(r + 1e-9)) ** 2)`` with
        ``r = (exp(beta_e * log_r) + r_i ** beta_i) ** beta_sn``,
        ``tb_loss = mean((logz + log_p - log_r) ** 2)`` and
        ``rnd_loss = mean(r_i)``.
    """
    # Allocate the accumulators next to the environment tensors: they are
    # indexed with `active` and summed with network outputs, so leaving them
    # on the default device breaks as soon as the rollout runs on an accelerator.
    device = env.state.device
    sa_log_p = torch.zeros(env.batch_size, device=device)
    log_p = torch.zeros(env.batch_size, device=device)
    r_i = torch.zeros(env.batch_size, device=device)

    for t in range(env.seq_size):
        active = env.alive.nonzero(as_tuple=True)[0]
        if active.numel() == 0:
            # Every row has already terminated.
            break
        s_sub = env.state.index_select(0, active)  # [B_t, L]

        sa_logits = sa_forward_net(s_sub)  # [B_t, 20]
        logits = forward_net(s_sub)  # [B_t, 20]

        # Actions are chosen from the sa_* policy only; forward_net is merely
        # scored on those same actions.
        actions = env.get_actions(sa_logits, training)
        picked = actions.unsqueeze(1)
        sa_log_p[active] += F.log_softmax(sa_logits, dim=-1).gather(1, picked).squeeze(1)
        log_p[active] += F.log_softmax(logits, dim=-1).gather(1, picked).squeeze(1)

        # Intrinsic reward is measured on the state *before* the action is written.
        s_float = s_sub.float()
        delta = rnd(s_float) - rnd_tgt(s_float)
        r_i[active] += torch.linalg.norm(delta, dim=1)

        with torch.no_grad():
            env.state[active, t] = actions

        # A zero action ends the row's rollout.
        env.alive[active] = (actions != 0)

    log_r = env.log_reward()
    r = ((beta_e * log_r).exp() + r_i ** beta_i) ** beta_sn
    # NOTE(review): r_i is not detached, so if rnd is trainable sa_tb_loss
    # also back-propagates into rnd through sa_log_r -- confirm this is intended.
    sa_log_r = torch.log(r + 1e-9)  # epsilon guards against log(0)

    sa_tb_loss = (sa_logz + sa_log_p - sa_log_r).pow(2).mean()
    tb_loss = (logz + log_p - log_r).pow(2).mean()
    rnd_loss = r_i.mean()

    return sa_tb_loss, tb_loss, rnd_loss

def cosh_tb(env, fnet, logz):
    """Return the batch mean of ``exp(d) + exp(-d) - 2``.

    ``d`` is ``forward_reward(env, fnet, logz) - env.log_reward()``. The
    per-sample term equals ``2 * (cosh(d) - 1)``: it is zero at ``d == 0`` and
    symmetric in the sign of ``d``.
    """
    # Call order is kept as-is: forward_reward receives env and presumably
    # advances it before log_reward() is read -- verify in peptide_sampling.
    model_log_r = forward_reward(env, fnet, logz)
    target_log_r = env.log_reward()

    gap = model_log_r - target_log_r
    grow, shrink = torch.exp(gap), torch.exp(-gap)
    return (grow + shrink - 2.0).mean()

def linex_tb(env, fnet, logz):
    """Return the batch mean of ``exp(d) - d - 1``.

    ``d`` is ``forward_reward(env, fnet, logz) - env.log_reward()``. The
    per-sample term is zero at ``d == 0`` and asymmetric: it grows
    exponentially for positive ``d`` and roughly linearly for negative ``d``.
    """
    # Call order is kept as-is: forward_reward receives env and presumably
    # advances it before log_reward() is read -- verify in peptide_sampling.
    model_log_r = forward_reward(env, fnet, logz)
    target_log_r = env.log_reward()

    gap = model_log_r - target_log_r
    exp_minus_linear = torch.exp(gap) - gap
    return (exp_minus_linear - 1.0).mean()