import torch
from torch.nn import functional as F


def _clip_actions(algo, actions):
    """Pull out-of-range actions back inside the action-space bounds.

    The bounds are read from ``algo._env_spec.action_space.low`` / ``.high``
    (NumPy arrays), moved to ``algo.device`` and shrunk by ``1e-6`` on each
    side. The correction is computed without gradient tracking and then added
    to ``actions``, so the returned tensor carries the clipped values while
    its gradient with respect to ``actions`` is the identity.

    Args:
        algo: Object exposing ``_env_spec.action_space.low``,
            ``_env_spec.action_space.high`` and ``device``.
        actions: Tensor that broadcasts against the action bounds.

    Returns:
        ``actions`` plus a constant (non-differentiable) offset that moves
        every out-of-range entry onto the nearest shrunken bound.
    """
    margin = 1e-6
    space = algo._env_spec.action_space
    low_bound = torch.from_numpy(space.low).to(algo.device) + margin
    high_bound = torch.from_numpy(space.high).to(algo.device) - margin

    with torch.no_grad():
        # 1.0 where the action violates a bound, 0.0 elsewhere.
        above = (actions > high_bound).float()
        below = (actions < low_bound).float()
        # Signed distance from each offending action to the violated bound.
        correction = (high_bound - actions) * above + (low_bound - actions) * below

    return actions + correction


def _stack_hidden_states(step_states):
    """Merge per-step recurrent states into one batch-major ``(h, c)`` pair.

    Args:
        step_states: Sequence of ``(h, c)`` tuples, one per time step, each
            tensor laid out as ``(num_layers, batch, hidden)``.

    Returns:
        Tuple ``(h, c)`` of tensors shaped
        ``(num_layers, batch * steps, hidden)``. The second axis is ordered
        exactly like ``x.flatten(0, 1)`` of a ``(batch, steps, ...)`` tensor,
        i.e. entry ``b * steps + t`` holds the state of batch item ``b``
        after step ``t``.
    """
    # Stacking on a new axis *after* the batch axis and then flattening the
    # (batch, step) pair yields batch-major order. Concatenating along the
    # batch axis would instead give step-major order (t * batch + b).
    hs = torch.stack([h for h, _ in step_states], dim=2).flatten(1, 2)
    cs = torch.stack([c for _, c in step_states], dim=2).flatten(1, 2)
    return hs, cs


def update_loss_qf(
    algo,
    tensors,
    v,
    obs,
    actions,
    next_obs,
    dones,
    rewards,
    policy,
):
    """Compute both critic losses and record them in ``tensors``.

    The critics ``algo.qf1`` / ``algo.qf2`` are evaluated on the whole
    ``(obs, actions)`` sequence. The TD target is built from the target
    critics, each evaluated on ``next_obs`` at step ``t`` while seeded with
    the recurrent state it reached after consuming ``(obs, actions)`` up to
    and including step ``t``. The whole target is computed under
    ``torch.no_grad()`` because it is a constant in the critic loss.

    Args:
        algo: Algorithm object providing ``qf1``, ``qf2``, ``target_qf1``,
            ``target_qf2``, ``log_alpha.param`` and ``discount``.
        tensors: Mapping updated in place with ``QTargetsMean``,
            ``QTdErrsMean``, ``Q1PredMean``, ``Q2PredMean``, ``LossQf1`` and
            ``LossQf2``.
        v: Mapping updated in place with ``mask``, the time-step indices that
            were used for the loss.
        obs: Observations, indexed as ``[batch, time, feature]``.
        actions: Actions, indexed as ``[batch, time, feature]``.
        next_obs: Next observations, same leading layout as ``obs``.
        dones: Termination flags; must broadcast against ``rewards``.
        rewards: Rewards; assumed ``[batch, time]`` so that they line up with
            the per-step log-probabilities -- TODO confirm against caller.
        policy: Callable returning a sequence whose first element is the
            action distribution for the given observations.
    """
    q1_pred = algo.qf1(obs, actions)[0].flatten(1)
    q2_pred = algo.qf2(obs, actions)[0].flatten(1)

    # Nothing below contributes gradients to the critic loss, so skip the
    # autograd bookkeeping for the policy and target-network passes.
    with torch.no_grad():
        alpha = algo.log_alpha.param.exp()

        # The policy sees the first observation followed by every next
        # observation; the sample for that leading step is dropped below so
        # that one action remains per entry of ``next_obs``.
        full_obs = torch.cat([obs[:, 0, :].unsqueeze(1), next_obs], dim=1)
        next_action_dists, *_ = policy(full_obs)
        if hasattr(next_action_dists, "rsample_with_pre_tanh_value"):
            new_next_actions_pre_tanh, new_next_actions = (
                next_action_dists.rsample_with_pre_tanh_value()
            )
            new_next_action_log_probs = next_action_dists.log_prob(
                new_next_actions, pre_tanh_value=new_next_actions_pre_tanh
            )
        else:
            new_next_actions = next_action_dists.rsample()
            new_next_actions = _clip_actions(algo, new_next_actions)
            new_next_action_log_probs = next_action_dists.log_prob(new_next_actions)
        new_next_actions = new_next_actions[:, 1:, :]
        new_next_action_log_probs = new_next_action_log_probs[:, 1:]

        # Roll both target critics through the sequence one step at a time
        # and keep the recurrent state reached after every step.
        states_target_qf1, states_target_qf2 = [], []
        last_state_qf1, last_state_qf2 = None, None
        for i in range(obs.shape[1]):
            step_obs = obs[:, i, :].unsqueeze(1)
            step_actions = actions[:, i, :].unsqueeze(1)
            _, last_state_qf1 = algo.target_qf1(
                step_obs, step_actions, hidden_states=last_state_qf1
            )
            states_target_qf1.append(last_state_qf1)
            _, last_state_qf2 = algo.target_qf2(
                step_obs, step_actions, hidden_states=last_state_qf2
            )
            states_target_qf2.append(last_state_qf2)  # num_layers, batch, H_cell

        # Every (batch, time) pair becomes an independent length-1 sequence.
        # The stacked states use the same batch-major order as flatten(0, 1),
        # so each transition is paired with its own recurrent state.
        flat_next_obs = next_obs.flatten(0, 1).unsqueeze(1)
        flat_next_actions = new_next_actions.flatten(0, 1).unsqueeze(1)
        target_q_values = torch.min(
            algo.target_qf1(
                flat_next_obs,
                flat_next_actions,
                hidden_states=_stack_hidden_states(states_target_qf1),
            )[0].flatten(),
            algo.target_qf2(
                flat_next_obs,
                flat_next_actions,
                hidden_states=_stack_hidden_states(states_target_qf2),
            )[0].flatten(),
        ).reshape_as(rewards)
        target_q_values = target_q_values - alpha * new_next_action_log_probs
        target_q_values = target_q_values * algo.discount

        q_target = (rewards + target_q_values * (1.0 - dones)).flatten(1)

    assert q1_pred.shape == q_target.shape

    # Every time step of every trajectory contributes to the loss. The index
    # tensor is kept (and published through ``v``) so that a subsampling
    # scheme, e.g. ``torch.randint(0, T, (B, k))``, can be swapped in here.
    mask = (
        torch.arange(rewards.shape[1], device=q1_pred.device)
        .unsqueeze(0)
        .repeat(rewards.shape[0], 1)
    )
    q1_pred = q1_pred.gather(1, mask).flatten()
    q2_pred = q2_pred.gather(1, mask).flatten()
    q_target = q_target.gather(1, mask).flatten()

    # critic loss weight: 0.5
    loss_qf1 = F.mse_loss(q1_pred, q_target) * 0.5
    loss_qf2 = F.mse_loss(q2_pred, q_target) * 0.5

    tensors.update(
        {
            "QTargetsMean": q_target.mean(),
            "QTdErrsMean": ((q_target - q1_pred).mean() + (q_target - q2_pred).mean())
            / 2,
            "Q1PredMean": q1_pred.mean(),
            "Q2PredMean": q2_pred.mean(),
            "LossQf1": loss_qf1,
            "LossQf2": loss_qf2,
        }
    )
    v.update(
        {
            "mask": mask,
        }
    )


def update_loss_sacp(
    algo,
    tensors,
    v,
    obs,
    policy,
):
    """Compute the policy loss and record it in ``tensors``.

    Fresh actions are sampled from ``policy(obs)`` and scored by both
    critics. The loss is the mean of
    ``alpha * log_prob - min(Q1, Q2)`` over the positions selected by
    ``v["mask"]``.

    Args:
        algo: Algorithm object providing ``qf1``, ``qf2`` and
            ``log_alpha.param``.
        tensors: Mapping updated in place with ``SacpNewActionLogProbMean``
            and ``LossSacp``.
        v: Mapping that must already hold ``mask`` (column indices gathered
            along dim 1); updated in place with the gathered, flattened
            ``new_action_log_probs``.
        obs: Observations passed unchanged to the policy and the critics.
        policy: Callable returning a sequence whose first element is the
            action distribution.
    """
    # The entropy coefficient is treated as a constant in this loss.
    with torch.no_grad():
        temperature = algo.log_alpha.param.exp()

    dist, *_ = policy(obs)
    if not hasattr(dist, "rsample_with_pre_tanh_value"):
        # Plain distribution: sample, then force the sample into bounds.
        sampled_actions = _clip_actions(algo, dist.rsample())
        log_probs = dist.log_prob(sampled_actions)  # [B, L]
    else:
        # Squashed distribution: reuse the pre-tanh value for log_prob.
        pre_tanh, sampled_actions = dist.rsample_with_pre_tanh_value()
        log_probs = dist.log_prob(sampled_actions, pre_tanh_value=pre_tanh)

    q1_flat = algo.qf1(obs, sampled_actions)[0].flatten()
    q2_flat = algo.qf2(obs, sampled_actions)[0].flatten()
    pessimistic_q = torch.min(q1_flat, q2_flat)

    # Keep only the positions chosen when the critic loss was computed.
    selected = v["mask"]
    pessimistic_q = pessimistic_q.reshape_as(log_probs).gather(1, selected).flatten()
    log_probs = log_probs.gather(1, selected).flatten()

    loss_sacp = (temperature * log_probs.flatten() - pessimistic_q).mean()

    tensors.update(
        {
            "SacpNewActionLogProbMean": log_probs.mean(),
            "LossSacp": loss_sacp,
        }
    )
    v.update({"new_action_log_probs": log_probs})


def update_loss_alpha(
    algo,
    tensors,
    v,
):
    """Compute the entropy-coefficient loss and record it in ``tensors``.

    The loss is the mean of
    ``-log_alpha * (log_prob + target_entropy)``, where the log-probabilities
    come from ``v["new_action_log_probs"]`` and are detached so that only
    ``algo.log_alpha.param`` receives gradient.

    Args:
        algo: Algorithm object providing ``log_alpha.param`` and
            ``_target_entropy``.
        tensors: Mapping updated in place with ``Alpha`` (the exponentiated
            parameter) and ``LossAlpha``.
        v: Mapping that must already hold ``new_action_log_probs``.
    """
    log_alpha = algo.log_alpha.param
    # Treat the policy's log-probabilities as constants for this loss.
    entropy_gap = v["new_action_log_probs"].detach() + algo._target_entropy
    loss_alpha = (-log_alpha * entropy_gap).mean()

    tensors.update(
        {
            "Alpha": log_alpha.exp(),
            "LossAlpha": loss_alpha,
        }
    )


def update_targets(algo):
    """Soft-update the target q-functions towards the online q-functions.

    Every parameter of ``algo.target_qf1`` / ``algo.target_qf2`` is replaced
    in place by ``(1 - tau) * target + tau * online``, with ``tau`` read from
    ``algo.tau``. Parameters are paired positionally in the order returned by
    ``parameters()``; buffers are not touched.

    Args:
        algo: Algorithm object providing ``qf1``, ``qf2``, ``target_qf1``,
            ``target_qf2`` and ``tau``.
    """
    target_qfs = [algo.target_qf1, algo.target_qf2]
    qfs = [algo.qf1, algo.qf2]
    # Update under no_grad instead of going through ``.data``: the arithmetic
    # is the same, but autograd still sees the in-place write and can flag a
    # stale graph instead of silently using modified values.
    with torch.no_grad():
        for target_qf, qf in zip(target_qfs, qfs):
            for t_param, param in zip(target_qf.parameters(), qf.parameters()):
                t_param.copy_(t_param * (1.0 - algo.tau) + param * algo.tau)
