import pdb

import torch
import pdb


def _masked_mean(values, mask):
    """Return the mean of ``values[mask]``, or ``None`` if the mask selects nothing.

    ``Tensor.mean`` over an empty selection yields NaN, which would silently
    poison the loss and every gradient derived from it, so the caller picks
    an explicit fallback instead.
    """
    selected = values[mask]
    if selected.numel() == 0:
        return None
    return selected.mean()


def iq_loss(agent, current_Q, current_v, next_v, z, gamma=0.99,
            expert_only=True, alpha=0.5, is_expert=None, v0=None,
            batch_size=512, grad_pen=0):
    """Compute the IQ-Learn loss (soft-Q term + value term + chi^2 regularizer).

    With ``y = gamma * next_v`` and ``reward = current_Q - y`` the loss is::

        -mean(reward[expert])                         # 'softq_loss'
        + value_term                                  # 'value_loss'
        + 1 / (4 * alpha) * mean(reward ** 2)         # 'regularize_loss'
        + grad_pen

    where ``value_term`` is ``(1 - gamma) * mean(v0)`` when ``v0`` is given,
    otherwise ``mean((current_v - y)[policy])`` over the non-expert samples.

    Args:
        agent: Unused; kept for interface compatibility.
        current_Q: Q(s, a) for the batch.
        current_v: V(s) for the batch.
        next_v: V(s') for the batch. NOTE(review): no ``done`` masking is
            applied here; presumably the caller already zeroes V(s') for
            terminal transitions - confirm.
        z: Unused; kept for interface compatibility.
        gamma: Discount factor.
        expert_only: Unused; kept for interface compatibility.
        alpha: chi^2 regularization strength (the term is scaled by 1/(4*alpha)).
        is_expert: Per-sample mask (anything with ``.bool()``) marking expert
            samples; required. Presumably broadcast-compatible with
            ``current_Q`` (e.g. (batch, 1)) - confirm against callers.
        v0: Optional initial-state value(s). A scalar or per-sample tensor;
            it is averaged before use.
        batch_size: Unused; kept for interface compatibility.
        grad_pen: Extra penalty term added to the total loss.

    Returns:
        ``(loss, reward, loss_dict)``: the scalar loss tensor, the full-batch
        ``current_Q - gamma * next_v`` tensor, and a dict of float diagnostics
        with keys 'softq_loss', 'value_loss', 'regularize_loss', 'total_loss'.

    Raises:
        ValueError: if ``is_expert`` is None.
    """
    if is_expert is None:
        raise ValueError(
            "iq_loss requires an `is_expert` mask to separate expert and policy samples"
        )

    expert_mask = is_expert.bool()
    policy_mask = ~expert_mask

    y = gamma * next_v
    # Implicit reward over the whole batch; also returned to the caller.
    reward = current_Q - y

    loss_dict = {}

    # 1st term: -E_(rho_expert)[Q(s, a) - gamma * V(s')]
    expert_reward = _masked_mean(reward, expert_mask)
    if expert_reward is None:
        # No expert samples in this batch: the expert term contributes nothing
        # (instead of NaN).
        softq_loss = reward.new_zeros(())
    else:
        softq_loss = -expert_reward
    loss_dict['softq_loss'] = softq_loss.item()

    # 2nd term: either (1 - gamma) * E[V(s0)] or E_(rho_policy)[V(s) - gamma * V(s')].
    if v0 is not None:
        # .mean() keeps scalar v0 unchanged and also accepts per-sample values.
        v0_tensor = torch.as_tensor(v0, dtype=current_v.dtype, device=current_v.device)
        value_loss = (1 - gamma) * v0_tensor.mean()
    else:
        value_diff = current_v - y
        value_loss = _masked_mean(value_diff, policy_mask)
        if value_loss is None:
            # No policy samples (e.g. an all-expert/offline batch): estimate the
            # term over the whole batch instead of producing NaN.
            value_loss = value_diff.mean()
    loss_dict['value_loss'] = value_loss.item()

    # chi^2 regularizer over all samples in the batch.
    chi2_loss = 1 / (4 * alpha) * (reward ** 2).mean()
    loss_dict['regularize_loss'] = chi2_loss.item()

    # Out-of-place sum: avoids in-place autograd/broadcast pitfalls.
    loss = softq_loss + value_loss + chi2_loss + grad_pen
    loss_dict['total_loss'] = loss.item()
    return loss, reward, loss_dict
