import torch
from ..diffusionGrid_sampling import forward_reward

def sa_double_loss(env, sa_forward_net, sa_backward_net, sa_logz, forward_net, backward_net, logz,
                   rnd, rnd_tgt, beta_e = 0.25, beta_i = 1.0, beta_sn=1.0, training=True):
    """Roll out one batch of trajectories and score it under two policy pairs.

    Actions are chosen by ``env.get_forward_actions`` from the policy that
    ``env.get_forward_pol`` returns for ``sa_forward_net``. The policy obtained
    for ``forward_net`` never picks actions; it is only evaluated on the
    actions that were taken.

    Args:
        env: Environment object. Reads ``batch_size``, ``T``, ``batch_ids``,
            ``stopped`` and ``backward_action_dim``; calls ``get_forward_pol``,
            ``get_forward_actions``, ``get_backward_pol``, ``obs``, ``apply``
            and ``log_reward``. It is advanced through ``env.apply`` for
            ``env.T`` steps.
        sa_forward_net, sa_backward_net: Networks handed to the env to build
            the policy pair that drives the rollout.
        sa_logz: Scalar term added to the first residual.
        forward_net, backward_net: Networks for the second policy pair.
        logz: Scalar term added to the second residual.
        rnd, rnd_tgt: Callables applied to ``env.obs()``; presumably an RND
            predictor and its target network -- confirm with the caller.
        beta_e: Multiplier on ``env.log_reward()`` before exponentiation.
        beta_i: Exponent applied to the accumulated prediction error.
        beta_sn: Exponent applied to the sum of both reward terms.
        training: Forwarded unchanged to ``env.get_forward_actions``.

    Returns:
        Tuple ``(sa_tb_loss, tb_loss, rnd_loss)``. The first two are batch
        means of ``(logz + log_p - log_r) ** 2`` for the respective policy
        pair (the first one against the augmented log-reward). ``rnd_loss`` is
        the batch mean of the accumulated prediction error.
    """
    batch = env.batch_size
    aug_traj_logp = torch.zeros(batch)
    base_traj_logp = torch.zeros(batch)
    pred_error = torch.zeros(batch)

    for _ in range(env.T):
        # Rows that have not stopped yet; only these receive log-prob updates.
        live = env.batch_ids[env.stopped == 0]

        aug_fwd = env.get_forward_pol(sa_forward_net)
        actions = env.get_forward_actions(aug_fwd, training=training)
        chosen = actions[live]
        aug_traj_logp[live] += aug_fwd[live, chosen].clamp(min=1e-9).log()

        base_fwd = env.get_forward_pol(forward_net)
        base_traj_logp[live] += base_fwd[live, chosen].clamp(min=1e-9).log()

        # Accumulated for every batch row at every step (stopped rows too),
        # using the observation from before the action is applied.
        predicted = rnd(env.obs())
        target = rnd_tgt(env.obs())
        pred_error = pred_error + torch.linalg.norm(predicted - target, dim=1)

        env.apply(actions)

        # Backward probabilities are subtracted only for actions whose index is
        # below env.backward_action_dim.
        aug_bwd = env.get_backward_pol(sa_backward_net)
        has_backward = chosen < env.backward_action_dim
        back_ids = live[has_backward]
        back_actions = chosen[has_backward]
        aug_traj_logp[back_ids] -= aug_bwd[back_ids, back_actions].clamp(min=1e-9).log()

        base_bwd = env.get_backward_pol(backward_net)
        base_traj_logp[back_ids] -= base_bwd[back_ids, back_actions].clamp(min=1e-9).log()

    log_r = env.log_reward()
    tempered_reward = torch.exp(beta_e * log_r)
    error_bonus = pred_error ** beta_i
    # NOTE(review): pred_error is not detached here, so sa_tb_loss can
    # back-propagate into `rnd` if it is differentiable -- confirm intended.
    sa_log_r = torch.log((tempered_reward + error_bonus) ** beta_sn + 1e-9)

    sa_tb_loss = ((sa_logz + aug_traj_logp - sa_log_r) ** 2).mean()
    tb_loss = ((logz + base_traj_logp - log_r) ** 2).mean()
    rnd_loss = pred_error.mean()

    return sa_tb_loss, tb_loss, rnd_loss

def cosh_tb(env, fnet, bnet, logz):
    """Return the batch mean of ``exp(d) + exp(-d) - 2`` (i.e. ``2 * (cosh(d) - 1)``).

    ``d`` is ``forward_reward(env, fnet, bnet, logz) - env.log_reward()``.
    ``forward_reward`` is imported from another module; presumably it returns
    the model's per-trajectory log-reward estimate -- confirm there.

    Args:
        env: Environment object; passed to ``forward_reward`` and queried via
            ``env.log_reward()``.
        fnet, bnet, logz: Forwarded unchanged to ``forward_reward``.

    Returns:
        Scalar tensor: the mean penalty over the batch. It is zero where
        ``d == 0`` and grows symmetrically in ``|d|``.
    """
    gfn_log_r = forward_reward(env, fnet, bnet, logz)
    log_r = env.log_reward()
    log_ratio = gfn_log_r - log_r
    # exp(d) + exp(-d) - 2 == 4 * sinh(d / 2) ** 2 exactly. The sinh form avoids
    # subtracting two nearly equal numbers when d is close to 0, where the
    # naive expression loses most of its float32 precision.
    return (4.0 * torch.sinh(0.5 * log_ratio).pow(2)).mean()

def linex_tb(env, fnet, bnet, logz):
    """Return the batch mean of ``exp(d) - d - 1``.

    ``d`` is ``forward_reward(env, fnet, bnet, logz) - env.log_reward()``.
    ``forward_reward`` is imported from another module; presumably it returns
    the model's per-trajectory log-reward estimate -- confirm there.

    Args:
        env: Environment object; passed to ``forward_reward`` and queried via
            ``env.log_reward()``.
        fnet, bnet, logz: Forwarded unchanged to ``forward_reward``.

    Returns:
        Scalar tensor: the mean penalty over the batch. It is zero where
        ``d == 0``, and asymmetric: exponential for positive ``d``, roughly
        linear for negative ``d``.
    """
    gfn_log_r = forward_reward(env, fnet, bnet, logz)
    log_r = env.log_reward()
    log_ratio = gfn_log_r - log_r
    # expm1(d) computes exp(d) - 1 without the cancellation that makes
    # exp(d) - d - 1 lose precision when d is close to 0.
    return (torch.expm1(log_ratio) - log_ratio).mean()

def dual_forward_trajectory_log_prob(env, fnet, bnet, teacher_fnet, teacher_bnet):
    """Roll out trajectories with the student policy and score them twice.

    Actions come from ``env.get_forward_actions`` (always ``training=True``)
    applied to the policy built from ``fnet``. The teacher networks never pick
    actions; their policies are only evaluated on the actions that were taken.

    Args:
        env: Environment object. Reads ``batch_size``, ``T``, ``batch_ids``,
            ``stopped`` and ``backward_action_dim``; calls ``get_forward_pol``,
            ``get_forward_actions``, ``get_backward_pol`` and ``apply``. It is
            advanced through ``env.apply`` for ``env.T`` steps.
        fnet, bnet: Student forward / backward networks.
        teacher_fnet, teacher_bnet: Teacher forward / backward networks.

    Returns:
        Tuple ``(student_log_p, teacher_log_p)`` of length-``env.batch_size``
        tensors. Each holds, per row, the summed forward log-probabilities
        minus the summed backward log-probabilities along the rollout.
    """
    batch = env.batch_size
    student_total = torch.zeros(batch)
    teacher_total = torch.zeros(batch)

    for _ in range(env.T):
        # Rows that have not stopped yet; only these receive updates.
        live = env.batch_ids[env.stopped == 0]

        student_fwd = env.get_forward_pol(fnet)
        teacher_fwd = env.get_forward_pol(teacher_fnet)
        actions = env.get_forward_actions(student_fwd, training=True)
        chosen = actions[live]
        student_total[live] += student_fwd[live, chosen].clamp(min=1e-9).log()
        teacher_total[live] += teacher_fwd[live, chosen].clamp(min=1e-9).log()

        env.apply(actions)

        student_bwd = env.get_backward_pol(bnet)
        teacher_bwd = env.get_backward_pol(teacher_bnet)
        # Backward probabilities are subtracted only for actions whose index is
        # below env.backward_action_dim.
        has_backward = chosen < env.backward_action_dim
        back_ids = live[has_backward]
        back_actions = chosen[has_backward]
        student_total[back_ids] -= student_bwd[back_ids, back_actions].clamp(min=1e-9).log()
        teacher_total[back_ids] -= teacher_bwd[back_ids, back_actions].clamp(min=1e-9).log()

    return student_total, teacher_total