'''
Create distance model for Action Distances.
'''
import numpy as np
import torch
import torch.nn as nn
import abc
import hydra

class ActionDistance(abc.ABC):
    """Abstract interface for objects that compute action distances.

    The class derives from ``abc.ABC`` so that the ``@abc.abstractmethod``
    markers below are actually enforced: a subclass can only be instantiated
    once it overrides both ``state_distance`` and ``traj_distance``. (On a
    plain ``object`` base the decorator is silently ignored.)
    """

    @abc.abstractmethod
    def state_distance(self, s1, s2):
        """Compute the action distance between two states ``s1`` and ``s2``."""

    @abc.abstractmethod
    def traj_distance(self, t1, t2):
        """Compute the action distance between two trajectories ``t1`` and ``t2``.

        Implementations are expected to make use of ``self.state_distance``.
        """


class ActionDistanceModel(ActionDistance):
    """Action distance backed by a model built from a hydra config.

    Both distance methods delegate directly to the instantiated model, which
    is called with the two inputs as positional arguments.
    """

    def __init__(self, model_cfg):
        """Instantiate the underlying model from ``model_cfg`` via hydra."""
        build = hydra.utils.instantiate
        self.model = build(model_cfg)

    def state_distance(self, s1, s2):
        """Return the model's output for the state pair ``(s1, s2)``."""
        distance = self.model(s1, s2)
        return distance

    def traj_distance(self, t1, t2):
        """Return the model's output for the trajectory pair ``(t1, t2)``.

        The trajectories are handed to the model as-is, so the model itself
        must be able to handle batched states.
        """
        distance = self.model(t1, t2)
        return distance





################################
################################
################################
#### ACTION DISTANCE LOSS ######
################################
################################
################################
class ActionDistanceLoss(nn.Module):
    """Action distance embedding loss, a thin wrapper around ``nn.MSELoss``.

    Any keyword arguments given to the constructor are forwarded unchanged
    to ``nn.MSELoss``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        criterion = nn.MSELoss(**kwargs)
        self.loss = criterion

    def forward(self, y, y_pred):
        """Return the wrapped MSE criterion evaluated as ``loss(y, y_pred)``."""
        return self.loss(y, y_pred)


class L2EmbedLoss(nn.Module):
    """L2 embedding loss, a thin wrapper around ``nn.MSELoss``.

    Any keyword arguments given to the constructor are forwarded unchanged
    to ``nn.MSELoss``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        criterion = nn.MSELoss(**kwargs)
        self.loss = criterion

    def forward(self, y, y_pred):
        """Return the wrapped MSE criterion evaluated as ``loss(y, y_pred)``."""
        return self.loss(y, y_pred)

################################
################################
################################
#### REWARD TRIPLET LOSS #######
################################
################################
################################

class RewardTripletLoss(nn.Module):
    """Triplet loss over an anchor, a positive and a negative reward.

    A thin wrapper around ``nn.TripletMarginLoss``; constructor keyword
    arguments are forwarded to it unchanged.
    """

    def __init__(self, **kwargs):
        super().__init__()
        criterion = nn.TripletMarginLoss(**kwargs)
        self.loss = criterion

    def forward(self, aR, pR, nR):
        """Evaluate the triplet loss for anchor ``aR``, positive ``pR`` and negative ``nR``."""
        triplet = self.loss(aR, pR, nR)
        return triplet



class DyanamicsPredictionLoss(nn.Module):
    """Dynamics prediction loss: MSE between next state and predicted next state.

    A thin wrapper around ``nn.MSELoss``; constructor keyword arguments are
    forwarded to it unchanged. The (misspelled) class name is kept as-is so
    existing references keep working.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.loss = nn.MSELoss(**kwargs)

    def forward(self, y, y_pred):
        """Return the MSE between ``y`` and ``y_pred``.

        Args:
            y: Target next state, either a tensor or a ``numpy.ndarray``.
            y_pred: Predicted next state tensor.

        Returns:
            The value of the wrapped ``nn.MSELoss`` for ``(y, y_pred)``.
        """
        if isinstance(y, np.ndarray):
            # NumPy arrays default to float64 while predictions are typically
            # float32; match both device and dtype of the prediction so the
            # loss (and its backward pass) does not hit a dtype mismatch.
            y = torch.from_numpy(y).to(device=y_pred.device, dtype=y_pred.dtype)
        return self.loss(y, y_pred)
