"""
Copyright 2022 Div Garg. All rights reserved.

Standalone IQ-Learn algorithm. See LICENSE for licensing terms.
"""
import torch
import torch.nn.functional as F

# Full IQ-Learn objective with other divergences and options
def _project(values, weight):
    """Return ``values @ weight.T`` without transposing a 1-D ``weight``.

    ``Tensor.T`` on a tensor whose number of dimensions is not 2 is deprecated
    in PyTorch. For a 1-D ``weight`` the transpose is a no-op, so the plain
    product yields the same result without triggering the deprecation.
    """
    if weight.ndim == 1:
        return values @ weight
    return values @ weight.T


def iq_loss(agent, current_Q, current_v, next_v, weights, index, batch):
    """Compute the IQ-Learn critic loss plus a neighbouring-weight constraint term.

    The loss is assembled from:
      1. the expert term ``-E_expert[phi'(r) * r]`` with ``r = Q(s, a) - gamma * V(s')``
         projected onto ``weights[index]``;
      2. a value term selected by ``args.method.loss``
         (``"value_expert"``, ``"value"`` or ``"v0"``);
      3. optional gradient penalty (``args.method.grad_pen``);
      4. optional chi^2 regularisers (``args.method.div == "chi"`` /
         ``args.method.chi`` on expert samples, ``args.method.regularize`` on
         all samples);
      5. a constraint term penalising the difference between the reward
         recovered on policy samples under ``weights[index]`` and under the two
         adjacent rows ``weights[index - 1]`` / ``weights[index + 1]``
         (indices wrap around).

    Args:
        agent: Object exposing ``args`` (config), ``gamma``, ``getV``,
            ``get_targetV``, ``critic`` and ``critic_net.grad_pen``.
        current_Q: Q-values of the batch transitions.
        current_v: State values of the batch observations.
        next_v: State values of the batch next-observations.
            NOTE(review): ``current_Q``/``current_v``/``next_v`` are presumably
            ``(batch, k)`` with ``k == weights.shape[-1]`` -- confirm with caller.
        weights: Tensor of weight vectors indexed along dim 0.
        index: Position in ``weights`` of the weight vector being optimised.
        batch: Tuple ``(obs, next_obs, action, env_reward, done, is_expert)``.
            ``is_expert`` is a boolean mask with a singleton second dimension
            (it is squeezed along dim 1 before being used as an index).

    Returns:
        ``(loss, loss_dict)``: the total loss tensor and a dict of the scalar
        value (``.item()``) of every term that was added.

    Raises:
        ValueError: If ``args.method.loss`` is not a supported strategy.
    """
    args = agent.args
    method = args.method
    gamma = agent.gamma
    obs, next_obs, action, _env_reward, done, is_expert = batch

    # 1-D view of the expert mask; used for every row selection below.
    expert_mask = is_expert.squeeze(1)
    w_cur = weights[index]

    loss_dict = {}
    # Keep track of the value of initial (expert) states. It is only consumed by
    # the "v0" strategy, but is evaluated unconditionally to keep the sequence of
    # calls into `agent.getV` unchanged.
    v0 = agent.getV(obs[expert_mask, ...], w_cur).mean(dim=0)

    # 1st term of the IQ loss: -E_(rho_expert)[Q(s, a) - gamma * V(s')]
    y = (1 - done) * gamma * next_v
    reward = _project((current_Q - y)[expert_mask], w_cur)

    with torch.no_grad():
        # Divergence-specific weighting phi'(r). For the chi^2 divergence a
        # separate Bellman-error-like term is added further below instead.
        if method.div == "hellinger":
            phi_grad = 1 / (1 + reward) ** 2
        elif method.div == "kl":
            # original dual form for kl divergence (sub optimal)
            phi_grad = torch.exp(-reward - 1)
        elif method.div == "kl2":
            # biased dual form for kl divergence
            phi_grad = F.softmax(-reward, dim=0) * reward.shape[0]
        elif method.div == "kl_fix":
            # unbiased form for fixing kl divergence
            phi_grad = torch.exp(-reward)
        elif method.div == "js":
            # Jensen-Shannon
            phi_grad = torch.exp(-reward) / (2 - torch.exp(-reward))
        else:
            phi_grad = 1
    loss = -(phi_grad * reward).mean()
    loss_dict['softq_loss'] = loss.item()

    # 2nd term of the IQ loss; the strategies differ in which states are sampled.
    if method.loss == "value_expert":
        # expert states only (works offline): E_(rho)[V(s) - gamma * V(s')]
        value_loss = _project((current_v - y)[expert_mask], w_cur).mean()
        loss += value_loss
        loss_dict['value_loss'] = value_loss.item()
    elif method.loss == "value":
        # expert and policy states (works online): E_(rho)[V(s) - gamma * V(s')]
        value_loss = _project(current_v - y, w_cur).mean()
        loss += value_loss
        loss_dict['value_loss'] = value_loss.item()
    elif method.loss == "v0":
        # initial states only (works offline, usually worse than "value_expert"):
        # (1 - gamma) * E_(rho0)[V(s0)]
        v0_loss = (1 - gamma) * _project(v0, w_cur)
        loss += v0_loss
        loss_dict['v0_loss'] = v0_loss.item()
    else:
        # NOTE: the "value_policy" (policy states only) and "value_mix" (weighted
        # expert/policy mix via `method.mix_coeff`) strategies are not implemented.
        raise ValueError(f'This sampling method is not implemented: {method.loss}')

    if method.grad_pen:
        # add a gradient penalty to the loss (Wasserstein_1 metric)
        gp_loss = agent.critic_net.grad_pen(obs[expert_mask, ...],
                                            action[expert_mask, ...],
                                            obs[~expert_mask, ...],
                                            action[~expert_mask, ...],
                                            method.lambda_gp)
        loss_dict['gp_loss'] = gp_loss.item()
        loss += gp_loss

    if method.div == "chi" or method.chi:  # TODO: Deprecate method.chi argument for method.div
        # chi^2 regularisation over expert states only (works offline).
        # The 1-D mask is used here like everywhere else: the unsqueezed
        # `is_expert` cannot index a 1-D projected reward.
        reward_all = _project(current_Q - y, w_cur)
        chi2_loss = 1 / (4 * method.alpha) * (reward_all ** 2)[expert_mask].mean()
        loss += chi2_loss
        loss_dict['chi2_loss'] = chi2_loss.item()

    if method.regularize:
        # chi^2 regularisation over expert and policy states (works online).
        reward_all = _project(current_Q - y, w_cur)
        chi2_loss = 1 / (4 * method.alpha) * (reward_all ** 2).mean()
        loss += chi2_loss
        loss_dict['regularize_loss'] = chi2_loss.item()

    # Constraint term: rewards recovered on policy samples should agree between
    # the current weight vector and its two neighbours (with wrap-around).
    n_weights = weights.shape[0]
    neighbours = (
        w_cur,
        weights[(index - 1) % n_weights],
        weights[(index + 1) % n_weights],
    )

    # Critics whose target path contains "DoubleQ" are queried with both=False.
    critic_kwargs = {"both": False} if "DoubleQ" in args.q_net._target_ else {}
    cur_qs = [agent.critic(obs, action, w, **critic_kwargs) for w in neighbours]

    if args.train.use_target:
        with torch.no_grad():
            next_vs = [agent.get_targetV(next_obs, w) for w in neighbours]
    else:
        next_vs = [agent.getV(next_obs, w) for w in neighbours]

    policy_mask = is_expert.logical_not().squeeze(1)
    reward1, reward2, reward3 = (
        (q - (1 - done) * gamma * v)[policy_mask]
        for q, v in zip(cur_qs, next_vs)
    )
    # Guard the normalisation: a batch without policy samples would otherwise
    # divide by zero and turn the whole loss into NaN.
    n_policy = max(reward1.numel(), 1)
    disagreement = (torch.norm(reward1 - reward2, p=2)
                    + torch.norm(reward1 - reward3, p=2))
    constraint_loss = args.agent.constraint_coef * (disagreement / n_policy)
    loss += constraint_loss
    loss_dict['constraint_loss'] = constraint_loss.item()

    loss_dict['total_loss'] = loss.item()
    return loss, loss_dict
