
import mcac.algos.core as core
import mcac.utils.pytorch_utils as ptu

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import url_benchmark.utils as utils

import copy
import os


class SSG(nn.Module):
    """Three small MLP heads and their per-sample prediction errors.

    The module owns:

    * ``forward_net``: maps ``cat(obs, action)`` to a predicted next
      observation of size ``obs_dim``.
    * ``backward_net``: maps ``cat(obs, predicted next obs)`` to a predicted
      action of size ``action_dim``, squashed by ``Tanh``.
    * ``reward_net``: maps ``obs`` to a predicted reward of size ``rew_dim``.

    All layers are initialised with ``utils.weight_init``.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, rew_dim):
        """Build the three heads.

        Args:
            obs_dim: Size of the last dimension of an observation.
            action_dim: Size of the last dimension of an action.
            hidden_dim: Width of both hidden layers in every head.
            rew_dim: Size of the reward head's output.
        """
        super().__init__()

        self.forward_net = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, obs_dim))

        self.backward_net = nn.Sequential(nn.Linear(2 * obs_dim, hidden_dim),
                                          nn.ReLU(),
                                          nn.Linear(hidden_dim, hidden_dim),
                                          nn.ReLU(),
                                          nn.Linear(hidden_dim, action_dim),
                                          nn.Tanh())

        self.reward_net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, rew_dim))

        self.apply(utils.weight_init)

    def forward(self, obs, action, next_obs, reward):
        """Compute the prediction errors of the three heads.

        Args:
            obs: Observations; last dimension must be ``obs_dim``.
            action: Actions; last dimension must be ``action_dim``.
            next_obs: Target next observations, same leading size as ``obs``.
            reward: Target rewards. Presumably ``(batch,)`` or
                ``(batch, rew_dim)`` -- TODO confirm against callers.

        Returns:
            A tuple ``(forward_error, backward_error, reward_error)``:

            * ``forward_error``: L2 norm of ``next_obs - next_obs_hat`` over
              the last dimension, kept as a trailing dimension of size 1.
            * ``backward_error``: L2 norm of ``action - action_hat`` over the
              last dimension, kept as a trailing dimension of size 1.
            * ``reward_error``: element-wise absolute reward error. It has
              the shape of ``reward`` whenever ``reward`` holds as many
              elements as the reward head's output.
        """
        assert obs.shape[0] == next_obs.shape[0], \
            "obs and next_obs must share their leading dimension"
        assert obs.shape[0] == action.shape[0], \
            "obs and action must share their leading dimension"

        next_obs_hat = self.forward_net(torch.cat([obs, action], dim=-1))

        # NOTE(review): the inverse model is fed the *predicted* next
        # observation rather than `next_obs` -- confirm this is intended.
        action_hat = self.backward_net(torch.cat([obs, next_obs_hat], dim=-1))

        forward_error = torch.norm(next_obs - next_obs_hat,
                                   dim=-1,
                                   p=2,
                                   keepdim=True)

        backward_error = torch.norm(action - action_hat,
                                    dim=-1,
                                    p=2,
                                    keepdim=True)

        reward_hat = self.reward_net(obs)
        if torch.is_tensor(reward) and reward.numel() == reward_hat.numel():
            # Align the prediction with the target's layout. A bare
            # `.squeeze()` would turn a (B, 1) prediction into (B,), and
            # subtracting that from a (B, 1) target silently broadcasts to
            # a (B, B) matrix instead of a per-sample error.
            reward_hat = reward_hat.reshape(reward.shape)
        else:
            # Scalar / broadcastable targets: keep the previous behaviour.
            reward_hat = reward_hat.squeeze()

        reward_error = torch.abs(reward - reward_hat)

        return forward_error, backward_error, reward_error