"""
 - This module is the main part of :py:mod:`Algorithms<machina.algos>`.
    - If you only want to change the loss function of a certain algorithm, you only need to modify this module.
 - Inputs are :data:`batch` generated by :py:meth:`iterate<machina.traj.traj.Traj.iterate>`.
 - Output is loss.
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from machina.utils import detach_tensor_dict, get_device


def pg_clip(pol, batch, clip_param, ent_beta):
    """
    Policy Gradient with clipping.

    The likelihood ratio ``exp(new_llh - old_llh)`` is clipped to
    ``[1 - clip_param, 1 + clip_param]``. The element-wise maximum of the
    unclipped and clipped negative surrogates (``-ratio * advs``) is averaged
    with the mask weights, and the mask-weighted mean entropy scaled by
    `ent_beta` is subtracted from it.

    Parameters
    ----------
    pol : Pol
        Policy. This function reads ``pol.rnn`` and ``pol.pd``, calls
        ``pol.reset()`` and then ``pol(obs, h_masks=h_masks)``, which must
        return a 3-tuple whose last element is the distribution parameters.
    batch : dict of torch.Tensor
        Must contain ``'obs'``, ``'acs'`` and ``'advs'``. The whole dict is
        also handed to ``pol.pd.llh`` as the parameters for the old
        log-likelihood (presumably it carries the distribution parameters
        recorded at sampling time; verify against the caller).
        If ``pol.rnn`` is true, ``'h_masks'`` and ``'out_masks'`` are required.
        Otherwise an optional ``'wei'`` entry is used as per-element weights
        (all ones if absent).
    clip_param : float
        Half-width of the clipping interval around 1 for the ratio.
    ent_beta : float
        entropy coefficient

    Returns
    -------
    pol_loss : torch.Tensor
    """
    obs = batch['obs']
    acs = batch['acs']
    advs = batch['advs']

    if pol.rnn:
        h_masks = batch['h_masks']
        out_masks = batch['out_masks']
    else:
        h_masks = None
        # Without explicit weights every element counts equally.
        out_masks = batch['wei'] if 'wei' in batch else torch.ones_like(advs)

    pd = pol.pd

    # Old log-likelihood is evaluated with the parameters stored in the batch.
    old_llh = pd.llh(acs, batch)

    pol.reset()
    _, _, pd_params = pol(obs, h_masks=h_masks)

    new_llh = pd.llh(acs, pd_params)
    ratio = torch.exp(new_llh - old_llh)
    pol_loss1 = - ratio * advs
    pol_loss2 = - torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advs
    # Maximum of the negated surrogates == minimum of the surrogates.
    pol_loss = torch.max(pol_loss1, pol_loss2)

    # The same normaliser is shared by the surrogate and the entropy term.
    mask_sum = torch.sum(out_masks)
    pol_loss = torch.sum(pol_loss * out_masks) / mask_sum

    ent = pd.ent(pd_params)
    pol_loss -= ent_beta * torch.sum(ent * out_masks) / mask_sum

    return pol_loss

def pg(pol, batch, ent_beta=0):
    """
    Policy Gradient.

    Returns the negated mask-weighted mean of ``llh * advs``, minus the
    mask-weighted mean entropy scaled by `ent_beta`.

    Parameters
    ----------
    pol : Pol
        Policy. This function reads ``pol.pd`` and ``pol.rnn``, calls
        ``pol.reset()`` and then calls ``pol`` on the observations; the call
        must return a 3-tuple whose last element is the distribution
        parameters.
    batch : dict of torch.Tensor
        Must contain ``'obs'``, ``'acs'`` and ``'advs'``.
        If ``pol.rnn`` is true, ``'h_masks'`` and ``'out_masks'`` are required.
        Otherwise an optional ``'wei'`` entry is used as per-element weights
        (all ones if absent).
    ent_beta : float
        entropy coefficient

    Returns
    -------
    pol_loss : torch.Tensor
    """
    obs = batch['obs']
    acs = batch['acs']
    advs = batch['advs']

    pd = pol.pd
    pol.reset()
    if pol.rnn:
        h_masks = batch['h_masks']
        out_masks = batch['out_masks']
        _, _, pd_params = pol(obs, h_masks=h_masks)
    else:
        # Without explicit weights every element counts equally.
        out_masks = batch['wei'] if 'wei' in batch else torch.ones_like(advs)
        _, _, pd_params = pol(obs)

    llh = pd.llh(acs, pd_params)

    # The same normaliser is shared by the likelihood and the entropy term.
    mask_sum = torch.sum(out_masks)
    pol_loss = - torch.sum(llh * advs * out_masks) / mask_sum

    ent = pd.ent(pd_params)
    pol_loss -= ent_beta * torch.sum(ent * out_masks) / mask_sum

    return pol_loss


def monte_carlo(vf, batch, clip_param=0.2, clip=False):
    """
    Montecarlo loss for V function.

    Half of the mask-weighted mean squared error between the predicted
    values and ``batch['rets']``. With ``clip=True`` the squared error is
    replaced element-wise by the larger of itself and the squared error of a
    prediction that is kept within `clip_param` of ``batch['vs']``.

    Parameters
    ----------
    vf : SVfunction
        Value function. This function calls ``vf.reset()``, reads ``vf.rnn``
        and then calls ``vf`` on the observations; the call must return a
        2-tuple whose first element is the predicted values.
    batch : dict of torch.Tensor
        Must contain ``'obs'`` and ``'rets'``, and ``'vs'`` when `clip` is
        true. If ``vf.rnn`` is true, ``'h_masks'`` and ``'out_masks'`` are
        required. Otherwise an optional ``'wei'`` entry is used as
        per-element weights (all ones if absent).
    clip_param : float
        Maximum allowed deviation of the clipped prediction from
        ``batch['vs']``. Only used when `clip` is true.
    clip : bool
        Whether to use the clipped value loss.

    Returns
    -------
    vf_loss : torch.Tensor
    """
    obs = batch['obs']
    rets = batch['rets']

    vf.reset()
    if vf.rnn:
        h_masks = batch['h_masks']
        out_masks = batch['out_masks']
        vs, _ = vf(obs, h_masks=h_masks)
    else:
        # Without explicit weights every element counts equally.
        out_masks = batch['wei'] if 'wei' in batch else torch.ones_like(rets)
        vs, _ = vf(obs)

    vfloss = (vs - rets)**2

    if clip:
        old_vs = batch['vs']
        vpredclipped = old_vs + \
            torch.clamp(vs - old_vs, -clip_param, clip_param)
        # Keep the larger of the unclipped and clipped squared errors.
        vfloss = torch.max(vfloss, (vpredclipped - rets)**2)

    vf_loss = 0.5 * torch.sum(vfloss * out_masks) / torch.sum(out_masks)
    return vf_loss