# import os
# from itertools import accumulate
# from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
#
# import matplotlib.pyplot as plt
# import numpy as np
# import torch
# import torch as th
# from stable_baselines3.common.torch_layers import create_mlp
# from stable_baselines3.common.utils import update_learning_rate
# from torch import nn
# from tqdm import tqdm
#
# from utils.model_utils import dirichlet_kl_divergence_loss
import copy
# class ConstraintNet(nn.Module):
#     def __init__(
#             self,
#             obs_dim: int,
#             acs_dim: int,
#             hidden_sizes: Tuple[int, ...],
#             batch_size: int,
#             lr_schedule: Callable[[float], float],
#             expert_obs: np.ndarray,
#             expert_acs: np.ndarray,
#             is_discrete: bool,
#             task: str = 'ICRL',
#             regularizer_coeff: float = 0.,
#             obs_select_dim: Optional[Tuple[int, ...]] = None,
#             acs_select_dim: Optional[Tuple[int, ...]] = None,
#             optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
#             optimizer_kwargs: Optional[Dict[str, Any]] = None,
#             no_importance_sampling: bool = False,
#             per_step_importance_sampling: bool = False,
#             clip_obs: Optional[float] = 10.,
#             initial_obs_mean: Optional[np.ndarray] = None,
#             initial_obs_var: Optional[np.ndarray] = None,
#             action_low: Optional[float] = None,
#             action_high: Optional[float] = None,
#             target_kl_old_new: float = -1,
#             target_kl_new_old: float = -1,
#             train_gail_lambda: Optional[bool] = False,
#             eps: float = 1e-5,
#             device: str = "cpu",
#             log_file=None
#     ):
#         super(ConstraintNet, self).__init__()
#         self.task = task
#         self.obs_dim = obs_dim
#         self.acs_dim = acs_dim
#         self.obs_select_dim = obs_select_dim
#         self.acs_select_dim = acs_select_dim
#         self._define_input_dims()
#
#         self.expert_obs = expert_obs
#         self.expert_acs = expert_acs
#
#         self.hidden_sizes = hidden_sizes
#         self.batch_size = batch_size
#         self.is_discrete = is_discrete
#         self.regularizer_coeff = regularizer_coeff
#         self.importance_sampling = not no_importance_sampling
#         self.per_step_importance_sampling = per_step_importance_sampling
#         self.clip_obs = clip_obs
#         self.device = device
#         self.eps = eps
#
#         self.train_gail_lambda = train_gail_lambda
#
#         if optimizer_kwargs is None:
#             optimizer_kwargs = {}
#             # Small values to avoid NaN in Adam optimizer
#             if optimizer_class == th.optim.Adam:
#                 optimizer_kwargs["eps"] = 1e-5
#         self.optimizer_kwargs = optimizer_kwargs
#         self.optimizer_class = optimizer_class
#         self.lr_schedule = lr_schedule
#
#         self.current_obs_mean = initial_obs_mean
#         self.current_obs_var = initial_obs_var
#         self.action_low = action_low
#         self.action_high = action_high
#
#         self.target_kl_old_new = target_kl_old_new
#         self.target_kl_new_old = target_kl_new_old
#
#         self.current_progress_remaining = 1.
#         self.log_file = log_file
#
#         self._build()
#
#     def _define_input_dims(self) -> None:
#         self.select_obs_dim = []
#         self.select_acs_dim = []
#         if self.obs_select_dim is None:
#             self.select_obs_dim += [i for i in range(self.obs_dim)]
#         elif self.obs_select_dim[0] != -1:
#             self.select_obs_dim += self.obs_select_dim
#         obs_len = len(self.select_obs_dim)
#         if self.acs_select_dim is None:
#             self.select_acs_dim += [i for i in range(self.acs_dim)]
#         elif self.acs_select_dim[0] != -1:
#             self.select_acs_dim += self.acs_select_dim
#         self.select_dim = self.select_obs_dim + [i+obs_len for i in self.select_acs_dim]
#         self.input_dims = len(self.select_dim)
#         assert self.input_dims > 0, ""
#
#     def _build(self) -> None:
#
#         # Create network and add sigmoid at the end
#         self.network = nn.Sequential(
#             *create_mlp(self.input_dims, 1, self.hidden_sizes),
#             nn.Sigmoid()
#         )
#         self.network.to(self.device)
#
#         # Build optimizer
#         if self.optimizer_class is not None:
#             self.optimizer = self.optimizer_class(self.parameters(), lr=self.lr_schedule(1), **self.optimizer_kwargs)
#         else:
#             self.optimizer = None
#         if self.train_gail_lambda:
#             self.criterion = nn.BCELoss()
#
#     def forward(self, x: th.tensor) -> th.tensor:
#         pred = self.network(x)
#         return pred
#
#     def cost_function(self, obs: np.ndarray, acs: np.ndarray) -> np.ndarray:
#         assert obs.shape[-1] == self.obs_dim, ""
#         if not self.is_discrete:
#             assert acs.shape[-1] == self.acs_dim, ""
#
#         x = self.prepare_data(obs, acs)
#         with th.no_grad():
#             out = self.__call__(x)
#
#         cost = 1 - out.detach().cpu().numpy()
#         return cost.squeeze(axis=-1)
#
#     def call_forward(self, x: np.ndarray):
#         with th.no_grad():
#             out = self.__call__(th.tensor(x, dtype=th.float32).to(self.device))
#         return out
#
#     def train_nn(
#             self,
#             iterations: np.ndarray,
#             nominal_obs: np.ndarray,
#             nominal_acs: np.ndarray,
#             episode_lengths: np.ndarray,
#             obs_mean: Optional[np.ndarray] = None,
#             obs_var: Optional[np.ndarray] = None,
#             current_progress_remaining: float = 1,
#     ) -> Dict[str, Any]:
#
#         # Update learning rate
#         self._update_learning_rate(current_progress_remaining)
#
#         # Update normalization stats
#         self.current_obs_mean = obs_mean
#         self.current_obs_var = obs_var
#
#         # Prepare data
#         nominal_data = self.prepare_data(nominal_obs, nominal_acs)
#         expert_data = self.prepare_data(self.expert_obs, self.expert_acs)
#
#         # Save current network predictions if using importance sampling
#         if self.importance_sampling:
#             with th.no_grad():
#                 start_preds = self.forward(nominal_data).detach()
#
#         early_stop_itr = iterations
#         loss = th.tensor(np.inf)
#
#         loss_all = []
#         expert_loss_all = []
#         nominal_loss_all = []
#         regularizer_loss_all = []
#         is_weights_all = []
#         nominal_preds_all = []
#         expert_preds_all = []
#
#         for itr in tqdm(range(iterations)):
#             # Compute IS weights
#             if self.importance_sampling:
#                 with th.no_grad():
#                     current_preds = self.forward(nominal_data).detach()
#                 is_weights, kl_old_new, kl_new_old = self.compute_is_weights(start_preds.clone(),
#                                                                              current_preds.clone(),
#                                                                              episode_lengths)
#                 # Break if kl is very large
#                 if ((self.target_kl_old_new != -1 and kl_old_new > self.target_kl_old_new) or
#                         (self.target_kl_new_old != -1 and kl_new_old > self.target_kl_new_old)):
#                     early_stop_itr = itr
#                     break
#             else:
#                 is_weights = th.ones(nominal_data.shape[0]).to(self.device)
#
#             # Do a complete pass on data
#             for nom_batch_indices, exp_batch_indices in self.get(nominal_data.shape[0], expert_data.shape[0]):
#                 # Get batch data
#                 nominal_batch = nominal_data[nom_batch_indices]
#                 expert_batch = expert_data[exp_batch_indices]
#                 is_batch = is_weights[nom_batch_indices][..., None]
#
#                 # Make predictions
#                 nominal_preds = self.__call__(nominal_batch)
#                 expert_preds = self.__call__(expert_batch)
#
#
#                 # Calculate loss
#                 if self.train_gail_lambda:
#                     nominal_loss = self.criterion(nominal_preds, th.zeros(*nominal_preds.size()))
#                     expert_loss = self.criterion(expert_preds, th.ones(*expert_preds.size()))
#                     regularizer_loss = th.tensor(0)
#                     loss = nominal_loss + expert_loss
#                 else:
#                     expert_loss = th.mean(th.log(expert_preds + self.eps))
#                     nominal_loss = th.mean(is_batch * th.log(nominal_preds + self.eps))
#                     regularizer_loss = th.mean(1 - expert_preds) + th.mean(1 - nominal_preds)
#                     loss = (-expert_loss + nominal_loss) + self.regularizer_coeff * regularizer_loss
#
#                 loss_all.append(loss)
#                 expert_loss_all.append(expert_loss)
#                 nominal_loss_all.append(nominal_loss)
#                 regularizer_loss_all.append(regularizer_loss)
#                 is_weights_all.append(is_weights)
#                 expert_preds_all.append(expert_preds)
#                 nominal_preds_all.append(nominal_preds)
#
#                 # Update
#                 self.optimizer.zero_grad()
#                 loss.backward()
#                 self.optimizer.step()
#
#         loss_all = torch.stack(loss_all, dim=0)
#         expert_loss_all = torch.stack(expert_loss_all, dim=0)
#         nominal_loss_all = torch.stack(nominal_loss_all, dim=0)
#         regularizer_loss_all = torch.stack(regularizer_loss_all, dim=0)
#         is_weights_all = torch.cat(is_weights_all, dim=0)
#         nominal_preds_all = torch.cat(nominal_preds_all, dim=0)
#         expert_preds_all = torch.cat(expert_preds_all, dim=0)
#
#         bw_metrics = {"backward/cn_loss": th.mean(loss_all).item(),
#                       "backward/expert_loss": th.mean(expert_loss_all).item(),
#                       "backward/unweighted_nominal_loss": th.mean(th.log(nominal_preds_all + self.eps)).item(),
#                       "backward/nominal_loss": th.mean(nominal_loss_all).item(),
#                       "backward/regularizer_loss": th.mean(regularizer_loss_all).item(),
#                       "backward/is_mean": th.mean(is_weights_all).detach().item(),
#                       "backward/is_max": th.max(is_weights_all).detach().item(),
#                       "backward/is_min": th.min(is_weights_all).detach().item(),
#                       "backward/data_shape": list(nominal_data.shape),
#                       "backward/nominal/preds_max": th.max(nominal_preds_all).item(),
#                       "backward/nominal/preds_min": th.min(nominal_preds_all).item(),
#                       "backward/nominal/preds_mean": th.mean(nominal_preds_all).item(),
#                       "backward/expert/preds_max": th.max(expert_preds_all).item(),
#                       "backward/expert/preds_min": th.min(expert_preds_all).item(),
#                       "backward/expert/preds_mean": th.mean(expert_preds_all).item(), }
#         if self.importance_sampling:
#             stop_metrics = {"backward/kl_old_new": kl_old_new.item(),
#                             "backward/kl_new_old": kl_new_old.item(),
#                             "backward/early_stop_itr": early_stop_itr}
#             bw_metrics.update(stop_metrics)
#
#         return bw_metrics
#
#     def compute_is_weights(self, preds_old: th.Tensor, preds_new: th.Tensor, episode_lengths: np.ndarray) -> th.tensor:
#         with th.no_grad():
#             n_episodes = len(episode_lengths)
#             cumulative = [0] + list(accumulate(episode_lengths))
#
#             ratio = (preds_new + self.eps) / (preds_old + self.eps)
#             prod = [th.prod(ratio[cumulative[j]:cumulative[j + 1]])
#                     for j in range(n_episodes)]
#             prod = th.tensor(prod)
#             normed = n_episodes * prod / (th.sum(prod) + self.eps)
#
#             if self.per_step_importance_sampling:
#                 is_weights = th.tensor(ratio / th.mean(ratio))
#             else:
#                 is_weights = []
#                 for length, weight in zip(episode_lengths, normed):
#                     is_weights += [weight] * length
#                 is_weights = th.tensor(is_weights)
#
#             # Compute KL(old, current)
#             kl_old_new = th.mean(-th.log(prod + self.eps))
#             # Compute KL(current, old)
#             prod_mean = th.mean(prod)
#             kl_new_old = th.mean((prod - prod_mean) * th.log(prod + self.eps) / (prod_mean + self.eps))
#
#         return is_weights.to(self.device), kl_old_new, kl_new_old
#
#     def prepare_data(
#             self,
#             obs: np.ndarray,
#             acs: np.ndarray,
#     ) -> th.tensor:
#
#         obs = self.normalize_obs(obs, self.current_obs_mean, self.current_obs_var, self.clip_obs)
#         acs = self.reshape_actions(acs)
#         acs = self.clip_actions(acs, self.action_low, self.action_high)
#
#         concat = self.select_appropriate_dims(select_dim=self.select_dim, x=np.concatenate([obs, acs], axis=-1))
#
#         return th.tensor(concat, dtype=th.float32).to(self.device)
#
#     def select_appropriate_dims(self, select_dim: list, x: Union[np.ndarray, torch.tensor]) -> Union[
#         np.ndarray, torch.tensor]:
#         return x[..., select_dim]
#
#
#     def normalize_obs(self, obs: np.ndarray, mean: Optional[float] = None, var: Optional[float] = None,
#                       clip_obs: Optional[float] = None) -> np.ndarray:
#         if mean is not None and var is not None:
#             mean, var = mean[None], var[None]
#             obs = (obs - mean) / np.sqrt(var + self.eps)
#         if clip_obs is not None:
#             obs = np.clip(obs, -self.clip_obs, self.clip_obs)
#
#         return obs
#
#     def reshape_actions(self, acs):
#         if self.is_discrete:
#             acs_ = acs.astype(int)
#             if len(acs.shape) > 1:
#                 acs_ = np.squeeze(acs_, axis=-1)
#             acs = np.zeros([acs.shape[0], self.acs_dim])
#             acs[np.arange(acs_.shape[0]), acs_] = 1.
#
#         return acs
#
#     def clip_actions(self, acs: np.ndarray, low: Optional[float] = None, high: Optional[float] = None) -> np.ndarray:
#         if high is not None and low is not None:
#             acs = np.clip(acs, low, high)
#
#         return acs
#
#     # def get(self, nom_size: int, exp_size: int) -> np.ndarray:
#     #     if self.batch_size is None:
#     #         yield np.arange(nom_size), np.arange(exp_size)
#     #     else:
#     #         size = min(nom_size, exp_size)
#     #         indices = np.random.permutation(size)
#     #
#     #         batch_size = self.batch_size
#     #         # Return everything, don't create minibatches
#     #         if batch_size is None:
#     #             batch_size = size
#     #
#     #         start_idx = 0
#     #         while start_idx < size:
#     #             batch_indices = indices[start_idx:start_idx + batch_size]
#     #             yield batch_indices, batch_indices
#     #             start_idx += batch_size
#
#     def get(self, nom_size: int, expert_size: int) -> np.ndarray:
#         if self.batch_size is None:
#             # Return everything, don't create minibatches
#             yield np.arange(nom_size), np.arange(expert_size)
#         else:
#             size = min(nom_size, expert_size)
#             expert_indices = np.random.permutation(expert_size)
#             # print(expert_indices)
#             nom_indices = np.random.permutation(nom_size)
#
#             batch_size = self.batch_size
#             start_idx = 0
#             while start_idx < size:
#                 batch_expert_indices = expert_indices[start_idx:start_idx + batch_size]
#                 batch_nom_indices = nom_indices[start_idx:start_idx + batch_size]
#                 yield batch_nom_indices, batch_expert_indices
#                 start_idx += batch_size
#
#     def _update_learning_rate(self, current_progress_remaining) -> None:
#         self.current_progress_remaining = current_progress_remaining
#         update_learning_rate(self.optimizer, self.lr_schedule(current_progress_remaining))
#
#     def save(self, save_path):
#         state_dict = dict(
#             cn_network=self.network.state_dict(),
#             cn_optimizer=self.optimizer.state_dict(),
#             obs_dim=self.obs_dim,
#             acs_dim=self.acs_dim,
#             is_discrete=self.is_discrete,
#             obs_select_dim=self.obs_select_dim,
#             acs_select_dim=self.acs_select_dim,
#             clip_obs=self.clip_obs,
#             obs_mean=self.current_obs_mean,
#             obs_var=self.current_obs_var,
#             action_low=self.action_low,
#             action_high=self.action_high,
#             device=self.device,
#             hidden_sizes=self.hidden_sizes
#         )
#         th.save(state_dict, save_path)
#
#     def _load(self, load_path):
#         state_dict = th.load(load_path)
#         if "cn_network" in state_dict:
#             self.network.load_state_dict(dic["cn_network"])
#         if "cn_optimizer" in state_dict and self.optimizer is not None:
#             self.optimizer.load_state_dict(dic["cn_optimizer"])
#
#     # Provide basic functionality to load this class
#     @classmethod
#     def load(
#             cls,
#             load_path: str,
#             obs_dim: Optional[int] = None,
#             acs_dim: Optional[int] = None,
#             is_discrete: bool = None,
#             obs_select_dim: Optional[Tuple[int, ...]] = None,
#             acs_select_dim: Optional[Tuple[int, ...]] = None,
#             clip_obs: Optional[float] = None,
#             obs_mean: Optional[np.ndarray] = None,
#             obs_var: Optional[np.ndarray] = None,
#             action_low: Optional[float] = None,
#             action_high: Optional[float] = None,
#             device: str = None
#     ):
#
#         state_dict = th.load(load_path)
#         # If value isn't specified, then get from state_dict
#         if obs_dim is None:
#             obs_dim = state_dict["obs_dim"]
#         if acs_dim is None:
#             acs_dim = state_dict["acs_dim"]
#         if is_discrete is None:
#             is_discrete = state_dict["is_discrete"]
#         if obs_select_dim is None:
#             obs_select_dim = state_dict["obs_select_dim"]
#         if acs_select_dim is None:
#             acs_select_dim = state_dict["acs_select_dim"]
#         if clip_obs is None:
#             clip_obs = state_dict["clip_obs"]
#         if obs_mean is None:
#             obs_mean = state_dict["obs_mean"]
#         if obs_var is None:
#             obs_var = state_dict["obs_var"]
#         if action_low is None:
#             action_low = state_dict["action_low"]
#         if action_high is None:
#             action_high = state_dict["action_high"]
#         if device is None:
#             device = state_dict["device"]
#
#         # Create network
#         hidden_sizes = state_dict["hidden_sizes"]
#         constraint_net = cls(
#             obs_dim=obs_dim,
#             acs_dim=acs_dim,
#             hidden_sizes=hidden_sizes,
#             batch_size=None,
#             lr_schedule=None,
#             expert_obs=None,
#             expert_acs=None,
#             is_discrete=is_discrete,
#             regularizer_coeff=None,
#             obs_select_dim=obs_select_dim,
#             acs_select_dim=acs_select_dim,
#             clip_obs=clip_obs,
#             initial_obs_mean=obs_mean,
#             initial_obs_var=obs_var,
#             action_low=action_low,
#             action_high=action_high,
#             optimizer_class=None,
#             device=device
#         )
#         constraint_net.network.load_state_dict(state_dict["cn_network"])
#
#         return constraint_net


# =====================================================================
# Plotting utilities
# =====================================================================

# The following functions exploit knowledge about the environment


import os
from itertools import accumulate
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

import matplotlib.pyplot as plt
import numpy as np
import torch as th
from stable_baselines3.common.torch_layers import create_mlp
from stable_baselines3.common.utils import update_learning_rate
from torch import nn
from tqdm import tqdm

from utils.data_utils import idx2vector


class ConstraintNet(nn.Module):
    def __init__(
            self,
            obs_dim: int,
            acs_dim: int,
            hidden_sizes: Tuple[int, ...],
            batch_size: int,
            lr_schedule: Callable[[float], float],
            expert_obs: np.ndarray,
            expert_acs: np.ndarray,
            is_discrete: bool,
            task: str = 'ICRL',
            regularizer_coeff: float = 0.,
            obs_select_dim: Optional[Tuple[int, ...]] = None,
            acs_select_dim: Optional[Tuple[int, ...]] = None,
            optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
            optimizer_kwargs: Optional[Dict[str, Any]] = None,
            no_importance_sampling: bool = True,
            per_step_importance_sampling: bool = False,
            clip_obs: Optional[float] = 10.,
            initial_obs_mean: Optional[np.ndarray] = None,
            initial_obs_var: Optional[np.ndarray] = None,
            action_low: Optional[float] = None,
            action_high: Optional[float] = None,
            target_kl_old_new: float = -1,
            target_kl_new_old: float = -1,
            train_gail_lambda: Optional[bool] = False,
            eps: float = 1e-5,
            recon_obs: bool = False,
            env_configs: dict = None,
            device: str = "cpu",
            log_file=None,
    ):
        """Store the configuration, resolve the input columns and build the network.

        Args:
            obs_dim: Width of an observation vector.
            acs_dim: Width of an action vector.
            hidden_sizes: Hidden layer sizes passed to ``create_mlp``.
            batch_size: Stored on the instance for later use.
            lr_schedule: Callable evaluated at ``1`` by ``_build`` to obtain the
                initial learning rate.
            expert_obs: Expert observations, stored on the instance.
            expert_acs: Expert actions, stored on the instance.
            is_discrete: When true, ``cost_function`` skips its action-width check.
            task: Task label, stored on the instance.
            regularizer_coeff: Weight of the regularizer term of the training loss.
            obs_select_dim: Observation indices fed to the network; ``None`` selects
                all of them, a leading ``-1`` selects none.
            acs_select_dim: Action indices fed to the network, same convention.
            optimizer_class: Optimizer type; ``None`` disables optimizer creation.
            optimizer_kwargs: Extra optimizer keyword arguments. When ``None`` an
                empty dict is used, with ``eps=1e-5`` added for Adam.
            no_importance_sampling: Stored inverted as ``self.importance_sampling``.
            per_step_importance_sampling: Stored on the instance.
            clip_obs: Stored on the instance.
            initial_obs_mean: Initial value of ``self.current_obs_mean``.
            initial_obs_var: Initial value of ``self.current_obs_var``.
            action_low: Stored on the instance.
            action_high: Stored on the instance.
            target_kl_old_new: KL threshold used for early stopping; ``-1`` disables it.
            target_kl_new_old: KL threshold used for early stopping; ``-1`` disables it.
            train_gail_lambda: When true, ``_build`` creates a ``BCELoss`` criterion.
            eps: Small constant added inside logarithms.
            recon_obs: When true, ``cost_function`` skips its observation-width check.
            env_configs: Stored on the instance.
            device: Device the network is moved to.
            log_file: Stored on the instance.
        """
        super().__init__()

        # --- Input layout -------------------------------------------------
        self.obs_dim, self.acs_dim = obs_dim, acs_dim
        self.obs_select_dim, self.acs_select_dim = obs_select_dim, acs_select_dim
        self._define_input_dims()

        # --- Data and general configuration -------------------------------
        self.task = task
        self.expert_obs, self.expert_acs = expert_obs, expert_acs
        self.hidden_sizes = hidden_sizes
        self.batch_size = batch_size
        self.is_discrete = is_discrete
        self.recon_obs = recon_obs
        self.env_configs = env_configs
        self.device = device
        self.log_file = log_file

        # --- Loss configuration -------------------------------------------
        self.regularizer_coeff = regularizer_coeff
        self.importance_sampling = not no_importance_sampling
        self.per_step_importance_sampling = per_step_importance_sampling
        self.train_gail_lambda = train_gail_lambda
        self.eps = eps
        self.target_kl_old_new = target_kl_old_new
        self.target_kl_new_old = target_kl_new_old

        # --- Normalisation and clipping -----------------------------------
        self.clip_obs = clip_obs
        self.current_obs_mean = initial_obs_mean
        self.current_obs_var = initial_obs_var
        self.action_low = action_low
        self.action_high = action_high

        # --- Optimizer configuration --------------------------------------
        if optimizer_kwargs is None:
            # A larger-than-default epsilon keeps Adam away from NaNs.
            optimizer_kwargs = {"eps": 1e-5} if optimizer_class == th.optim.Adam else {}
        self.optimizer_kwargs = optimizer_kwargs
        self.optimizer_class = optimizer_class
        self.lr_schedule = lr_schedule
        self.current_progress_remaining = 1.

        self._build()

    # def _define_input_dims(self) -> None:
    #     self.select_dim = []
    #     if self.obs_select_dim is None:
    #         self.select_dim += [i for i in range(self.obs_dim)]
    #     elif self.obs_select_dim[0] != -1:
    #         self.select_dim += self.obs_select_dim
    #     obs_len = len(self.select_obs_dim)
    #     if self.acs_select_dim is None:
    #         self.select_dim += [obs_len+i for i in range(self.acs_dim)]
    #     elif self.acs_select_dim[0] != -1:
    #         self.select_dim += [obs_len+i for i in self.acs_select_dim]
    #     assert len(self.select_dim) > 0, ""
    #
    #     self.input_dims = len(self.select_dim)

    def _define_input_dims(self) -> None:
        """Resolve which observation and action columns are fed to the network.

        For each of ``obs_select_dim`` / ``acs_select_dim``: ``None`` selects every
        dimension, a selection whose first entry is ``-1`` selects nothing, and any
        other selection is used as given.

        Sets ``input_obs_dim``, ``input_acs_dim``, ``select_dim`` (observation indices
        followed by shifted action indices) and ``input_dims`` (their count).

        Raises:
            AssertionError: If no column at all is selected.
        """

        def _resolve(selection, full_dim):
            # None -> all indices, leading -1 -> nothing, otherwise the given indices.
            if selection is None:
                return list(range(full_dim))
            if selection[0] != -1:
                return list(selection)
            return []

        self.input_obs_dim = _resolve(self.obs_select_dim, self.obs_dim)
        n_selected_obs = len(self.input_obs_dim)
        self.input_acs_dim = _resolve(self.acs_select_dim, self.acs_dim)

        # NOTE(review): action indices are shifted by the number of *selected*
        # observation columns, not by obs_dim - confirm this matches the layout of
        # the array that select_dim is later used to index.
        shifted_acs = [n_selected_obs + idx for idx in self.input_acs_dim]
        self.select_dim = self.input_obs_dim + shifted_acs
        self.input_dims = len(self.select_dim)
        assert self.input_dims > 0, ""

    def _build(self) -> None:
        """Create the network and, when configured, the optimizer and BCE criterion.

        The network is the ``create_mlp`` stack (``input_dims`` inputs, one output)
        followed by a sigmoid, moved to ``self.device``. ``self.optimizer`` is ``None``
        when no optimizer class was given. ``self.criterion`` is only created when
        ``train_gail_lambda`` is set.
        """
        # MLP body with a single output, squashed by a trailing sigmoid.
        mlp_layers = list(create_mlp(self.input_dims, 1, self.hidden_sizes))
        self.network = nn.Sequential(*mlp_layers, nn.Sigmoid())
        self.network.to(self.device)

        # The optimizer is optional; its initial learning rate is the schedule at 1.
        if self.optimizer_class is None:
            self.optimizer = None
        else:
            self.optimizer = self.optimizer_class(
                self.parameters(),
                lr=self.lr_schedule(1),
                **self.optimizer_kwargs,
            )

        if self.train_gail_lambda:
            self.criterion = nn.BCELoss()

    def forward(self, x: th.tensor) -> th.tensor:
        return self.network(x)

    def cost_function(self, obs: np.ndarray, acs: np.ndarray, force_mode: str = None) -> np.ndarray:
        """Return ``1 - network output`` for the given observations and actions.

        Args:
            obs: Observations; unless ``recon_obs`` is set, the last axis must have
                ``obs_dim`` entries.
            acs: Actions; for non-discrete actions the last axis must have
                ``acs_dim`` entries.
            force_mode: Not used by this method.

        Returns:
            NumPy array of costs with the trailing axis of the network output
            squeezed away.
        """
        if not self.recon_obs:
            assert obs.shape[-1] == self.obs_dim, ""
        if not self.is_discrete:
            assert acs.shape[-1] == self.acs_dim, ""

        net_input = self.prepare_data(obs, acs)
        # Inference only: no autograd graph is needed.
        with th.no_grad():
            preds = self(net_input)
        preds_np = preds.detach().cpu().numpy()
        return (1 - preds_np).squeeze(axis=-1)

    def call_forward(self, x: np.ndarray):
        with th.no_grad():
            out = self.__call__(th.tensor(x, dtype=th.float32).to(self.device))
        return out

    def train_gridworld_nn(
            self,
            iterations: np.ndarray,
            nominal_obs: np.ndarray,
            nominal_acs: np.ndarray,
            episode_lengths: np.ndarray,
            obs_mean: Optional[np.ndarray] = None,
            obs_var: Optional[np.ndarray] = None,
            env_configs: Dict = None,
            current_progress_remaining: float = 1,
    ) -> Dict[str, Any]:
        """Train the constraint network game by game and report the last update.

        Every entry of ``nominal_obs``/``nominal_acs`` and of the stored expert data
        is prepared separately (one "game" per entry). For each iteration and each
        game index present in both sets, all mini-batches of that game are forwarded,
        the predictions are concatenated, and one optimizer step is taken on the
        resulting loss.

        Args:
            iterations: Number of passes over the games (used as ``range(iterations)``).
            nominal_obs: Indexable collection of per-game nominal observations.
            nominal_acs: Indexable collection of per-game nominal actions.
            episode_lengths: Forwarded unchanged to ``compute_is_weights``.
            obs_mean: Replaces ``self.current_obs_mean``.
            obs_var: Replaces ``self.current_obs_var``.
            env_configs: Not used by this method.
            current_progress_remaining: Passed to ``_update_learning_rate``.

        Returns:
            Dictionary of ``backward/*`` metrics. Loss values come from the last
            optimizer step, prediction statistics from the last mini-batch, and
            importance-weight statistics from the last game visited. With importance
            sampling enabled the KL values and the early-stop iteration are added.

        Note:
            The metrics read variables that are only bound inside the training loop,
            so building them fails if no optimizer step was executed (for example
            when ``iterations`` is 0).
        """
        # Update learning rate
        self._update_learning_rate(current_progress_remaining)

        # Update normalization stats
        self.current_obs_mean = obs_mean
        self.current_obs_var = obs_var

        # Prepare data, one tensor per game
        nominal_data_games = [self.prepare_data(game_obs, nominal_acs[i])
                              for i, game_obs in enumerate(nominal_obs)]
        expert_data_games = [self.prepare_data(game_obs, self.expert_acs[i])
                             for i, game_obs in enumerate(self.expert_obs)]
        # Only game indices available in both data sets are trained on.
        n_games = min(len(nominal_data_games), len(expert_data_games))

        early_stop_itr = iterations
        stop_training = False
        for itr in tqdm(range(iterations)):
            for gid in range(n_games):
                nominal_data = nominal_data_games[gid]
                expert_data = expert_data_games[gid]

                if self.importance_sampling:
                    # NOTE(review): the "start" snapshot is taken right before the
                    # "current" predictions with no parameter update in between, so
                    # both are identical here - confirm whether the snapshot was
                    # meant to be taken once before the iteration loop.
                    with th.no_grad():
                        start_preds = self.forward(nominal_data).detach()
                        current_preds = self.forward(nominal_data).detach()
                    # NOTE(review): episode_lengths is applied to a single game's
                    # data - verify it describes that game.
                    is_weights, kl_old_new, kl_new_old = self.compute_is_weights(start_preds.clone(),
                                                                                 current_preds.clone(),
                                                                                 episode_lengths)
                    # Stop training altogether once a KL threshold is exceeded, so
                    # that the reported early_stop_itr is the iteration at which
                    # updates really ended (a bare break would only skip the
                    # remaining games of this iteration).
                    if ((self.target_kl_old_new != -1 and kl_old_new > self.target_kl_old_new) or
                            (self.target_kl_new_old != -1 and kl_new_old > self.target_kl_new_old)):
                        early_stop_itr = itr
                        stop_training = True
                        break
                else:
                    is_weights = th.ones(nominal_data.shape[0]).to(self.device)

                nominal_preds_all = []
                expert_preds_all = []
                nominal_weights_all = []
                # Do a complete pass on the game data
                for nom_batch_indices, exp_batch_indices in self.get(nominal_data.shape[0], expert_data.shape[0]):
                    # Make predictions
                    nominal_preds = self(nominal_data[nom_batch_indices])
                    nominal_preds_all.append(nominal_preds)
                    expert_preds = self(expert_data[exp_batch_indices])
                    expert_preds_all.append(expert_preds)
                    # Keep the weights of exactly the samples that were forwarded,
                    # in the same order, so they line up with the concatenated
                    # predictions below.
                    nominal_weights_all.append(is_weights[nom_batch_indices][..., None])

                # Calculate loss
                expert_preds_all = th.concat(expert_preds_all)
                nominal_preds_all = th.concat(nominal_preds_all)
                nominal_weights_all = th.concat(nominal_weights_all)
                if self.train_gail_lambda:
                    # The product of all per-step predictions of the game is
                    # classified: nominal -> 0, expert -> 1.
                    nominal_preds_prod = nominal_preds_all.prod(dim=0)
                    nominal_loss = self.criterion(nominal_preds_prod,
                                                  th.zeros(*nominal_preds_prod.size()).to(self.device))
                    expert_preds_prod = expert_preds_all.prod(dim=0)
                    expert_loss = self.criterion(expert_preds_prod, th.ones(*expert_preds_prod.size()).to(self.device))
                    regularizer_loss = th.tensor(0)
                    loss = nominal_loss + expert_loss
                else:
                    expert_loss = th.sum(th.log(expert_preds_all + self.eps))
                    nominal_loss = th.sum(nominal_weights_all * th.log(nominal_preds_all + self.eps))
                    regularizer_loss = self.regularizer_coeff * (th.mean(1 - expert_preds_all)
                                                                 + th.mean(1 - nominal_preds_all))
                    loss = (-expert_loss + nominal_loss) + regularizer_loss
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            if stop_training:
                break

        bw_metrics = {"backward/cn_loss": loss.item(),
                      "backward/expert_loss": expert_loss.item(),
                      "backward/unweighted_nominal_loss": th.mean(th.log(nominal_preds + self.eps)).item(),
                      "backward/nominal_loss": nominal_loss.item(),
                      "backward/regularizer_loss": regularizer_loss.item(),
                      "backward/is_mean": th.mean(is_weights).detach().item(),
                      "backward/is_max": th.max(is_weights).detach().item(),
                      "backward/is_min": th.min(is_weights).detach().item(),
                      "backward/nominal_preds_max": th.max(nominal_preds).item(),
                      "backward/nominal_preds_min": th.min(nominal_preds).item(),
                      "backward/nominal_preds_mean": th.mean(nominal_preds).item(),
                      "backward/expert_preds_max": th.max(expert_preds).item(),
                      "backward/expert_preds_min": th.min(expert_preds).item(),
                      "backward/expert_preds_mean": th.mean(expert_preds).item(), }
        if self.importance_sampling:
            stop_metrics = {"backward/kl_old_new": kl_old_new.item(),
                            "backward/kl_new_old": kl_new_old.item(),
                            "backward/early_stop_itr": early_stop_itr}
            bw_metrics.update(stop_metrics)

        return bw_metrics

    def train_nn(
            self,
            iterations: int,
            nominal_obs: np.ndarray,
            nominal_acs: np.ndarray,
            episode_lengths: np.ndarray,
            obs_mean: Optional[np.ndarray] = None,
            obs_var: Optional[np.ndarray] = None,
            current_progress_remaining: float = 1,
    ) -> Dict[str, Any]:
        """Run the backward (constraint-learning) update and return logging metrics.

        The learning rate is refreshed from ``current_progress_remaining``, the
        normalisation statistics are replaced by ``obs_mean``/``obs_var``, and
        both the nominal data (arguments) and the expert data (``self.expert_obs``,
        ``self.expert_acs``) are passed through ``prepare_data``.

        Each of the ``iterations`` passes over the data:

        * with importance sampling enabled, recomputes the weights and the two
          KL estimates from the predictions taken before training started, and
          stops early once an enabled KL target is exceeded (a target of ``-1``
          disables that check);
        * otherwise uses a weight of one for every nominal sample;
        * then performs one optimizer step per mini-batch yielded by ``self.get``.

        :param iterations: number of passes over the data (used as ``range(iterations)``).
        :param nominal_obs: nominal observations, forwarded to ``prepare_data``.
        :param nominal_acs: nominal actions, forwarded to ``prepare_data``.
        :param episode_lengths: per-episode step counts, forwarded to ``compute_is_weights``.
        :param obs_mean: normalisation mean stored as ``self.current_obs_mean``.
        :param obs_var: normalisation variance stored as ``self.current_obs_var``.
        :param current_progress_remaining: value handed to the learning-rate schedule.
        :return: dictionary of ``backward/*`` metrics computed from the last processed
            mini-batch; KL and early-stop entries are added when importance sampling is on.

        NOTE(review): ``expert_loss``, ``nominal_loss``, ``regularizer_loss``,
        ``nominal_preds`` and ``expert_preds`` are only bound inside the mini-batch
        loop. If no mini-batch is ever processed (e.g. ``iterations == 0`` or an
        early stop on the very first pass) building the metrics dictionary would
        fail -- confirm that callers never hit this case.
        """

        # Update learning rate
        self._update_learning_rate(current_progress_remaining)

        # Update normalization stats
        self.current_obs_mean = obs_mean
        self.current_obs_var = obs_var

        # Prepare data
        nominal_data = self.prepare_data(nominal_obs, nominal_acs)
        expert_data = self.prepare_data(self.expert_obs, self.expert_acs)

        # Save current network predictions if using importance sampling
        if self.importance_sampling:
            with th.no_grad():
                start_preds = self.forward(nominal_data).detach()

        # Defaults reported when the loop is never cut short / never updates the loss
        early_stop_itr = iterations
        loss = th.tensor(np.inf)
        for itr in tqdm(range(iterations)):
            # Compute IS weights
            if self.importance_sampling:
                with th.no_grad():
                    current_preds = self.forward(nominal_data).detach()
                is_weights, kl_old_new, kl_new_old = self.compute_is_weights(start_preds.clone(), current_preds.clone(),
                                                                             episode_lengths)
                # Break if kl is very large
                if ((self.target_kl_old_new != -1 and kl_old_new > self.target_kl_old_new) or
                        (self.target_kl_new_old != -1 and kl_new_old > self.target_kl_new_old)):
                    early_stop_itr = itr
                    break
            else:
                # No importance sampling: every nominal sample gets unit weight
                is_weights = th.ones(nominal_data.shape[0]).to(self.device)

            # Do a complete pass on data
            for nom_batch_indices, exp_batch_indices in self.get(nominal_data.shape[0], expert_data.shape[0]):
                # Get batch data
                nominal_batch = nominal_data[nom_batch_indices]
                expert_batch = expert_data[exp_batch_indices]
                # Trailing axis added to the selected weights; presumably so they
                # broadcast against per-sample predictions -- confirm the shape
                # returned by the forward pass.
                is_batch = is_weights[nom_batch_indices][..., None]

                # Make predictions
                nominal_preds = self.__call__(nominal_batch)
                expert_preds = self.__call__(expert_batch)

                # Calculate loss
                if self.train_gail_lambda:
                    # Classification-style objective: nominal targets are 0, expert targets are 1
                    nominal_loss = self.criterion(nominal_preds, th.zeros(*nominal_preds.size()).to(self.device))
                    expert_loss = self.criterion(expert_preds, th.ones(*expert_preds.size()).to(self.device))
                    regularizer_loss = th.tensor(0)
                    loss = nominal_loss + expert_loss
                else:
                    # Raise the expert log-prediction, lower the (weighted) nominal one,
                    # and penalise predictions that move away from 1.
                    expert_loss = th.mean(th.log(expert_preds + self.eps))
                    nominal_loss = th.mean(is_batch * th.log(nominal_preds + self.eps))
                    regularizer_loss = self.regularizer_coeff * (th.mean(1 - expert_preds) + th.mean(1 - nominal_preds))
                    loss = (-expert_loss + nominal_loss) + regularizer_loss

                # Update
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        # Metrics below describe the last mini-batch that was processed
        bw_metrics = {"backward/cn_loss": loss.item(),
                      "backward/expert_loss": expert_loss.item(),
                      "backward/unweighted_nominal_loss": th.mean(th.log(nominal_preds + self.eps)).item(),
                      "backward/nominal_loss": nominal_loss.item(),
                      "backward/regularizer_loss": regularizer_loss.item(),
                      "backward/is_mean": th.mean(is_weights).detach().item(),
                      "backward/is_max": th.max(is_weights).detach().item(),
                      "backward/is_min": th.min(is_weights).detach().item(),
                      "backward/nominal_preds_max": th.max(nominal_preds).item(),
                      "backward/nominal_preds_min": th.min(nominal_preds).item(),
                      "backward/nominal_preds_mean": th.mean(nominal_preds).item(),
                      "backward/expert_preds_max": th.max(expert_preds).item(),
                      "backward/expert_preds_min": th.min(expert_preds).item(),
                      "backward/expert_preds_mean": th.mean(expert_preds).item(), }
        if self.importance_sampling:
            stop_metrics = {"backward/kl_old_new": kl_old_new.item(),
                            "backward/kl_new_old": kl_new_old.item(),
                            "backward/early_stop_itr": early_stop_itr}
            bw_metrics.update(stop_metrics)

        return bw_metrics

    def compute_is_weights(self, preds_old: th.Tensor, preds_new: th.Tensor,
                           episode_lengths: np.ndarray) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Compute importance-sampling weights and two KL estimates from old/new predictions.

        The per-step ratio ``(preds_new + eps) / (preds_old + eps)`` is multiplied
        within each episode, where episode ``j`` is the consecutive slice of
        ``episode_lengths[j]`` steps. The per-episode products are then normalised
        so that they sum to (approximately) the number of episodes.

        :param preds_old: predictions recorded before the current round of updates.
        :param preds_new: predictions of the current network on the same samples.
        :param episode_lengths: integer number of steps in each episode, in data order.
        :return: ``(is_weights, kl_old_new, kl_new_old)``. ``is_weights`` is moved to
            ``self.device``. In per-step mode it is ``ratio / mean(ratio)`` and keeps the
            shape of the predictions; otherwise it is 1-D with each episode's normalised
            weight repeated once per step of that episode.

        NOTE(review): in per-step mode the weights keep the prediction shape, while
        the episode mode returns a flat vector -- confirm callers handle both.
        """
        with th.no_grad():
            n_episodes = len(episode_lengths)
            cumulative = [0] + list(accumulate(episode_lengths))

            ratio = (preds_new + self.eps) / (preds_old + self.eps)
            if n_episodes > 0:
                # One product per episode, stacked directly instead of round-tripping
                # through a Python list of scalars.
                prod = th.stack([th.prod(ratio[start:stop])
                                 for start, stop in zip(cumulative[:-1], cumulative[1:])])
            else:
                # th.stack rejects an empty list; keep the empty-input behaviour.
                prod = ratio.new_zeros(0)
            normed = n_episodes * prod / (th.sum(prod) + self.eps)

            if self.per_step_importance_sampling:
                # Already a tensor: no copy-construction needed.
                is_weights = ratio / th.mean(ratio)
            else:
                # Broadcast each episode weight over that episode's steps.
                lengths = th.tensor([int(length) for length in episode_lengths],
                                    dtype=th.long, device=normed.device)
                is_weights = th.repeat_interleave(normed, lengths, dim=0)

            # Compute KL(old, current)
            kl_old_new = th.mean(-th.log(prod + self.eps))
            # Compute KL(current, old)
            prod_mean = th.mean(prod)
            kl_new_old = th.mean((prod - prod_mean) * th.log(prod + self.eps) / (prod_mean + self.eps))

        return is_weights.to(self.device), kl_old_new, kl_new_old

    def prepare_data(
            self,
            obs: np.ndarray,
            acs: np.ndarray,
    ) -> th.tensor:
        """Turn raw observations and actions into the network's float32 input tensor.

        Observations are either expanded with ``idx2vector`` (when ``recon_obs`` is
        set, using the map size from ``env_configs``) or copied, then normalised.
        Actions are reshaped and clipped. Both are concatenated along the last
        axis, reduced to the selected dimensions, and moved to ``self.device``.
        """
        if not self.recon_obs:
            features = copy.copy(obs)
        else:
            env_cfg = self.env_configs
            features = idx2vector(obs, height=env_cfg['map_height'], width=env_cfg['map_width'])

        features = self.normalize_obs(features, self.current_obs_mean, self.current_obs_var, self.clip_obs)
        actions = self.clip_actions(self.reshape_actions(acs), self.action_low, self.action_high)

        joined = np.concatenate([features, actions], axis=-1)
        selected = self.select_appropriate_dims(joined)

        return th.tensor(selected, dtype=th.float32).to(self.device)

    def select_appropriate_dims(self, x: Union[np.ndarray, th.tensor]) -> Union[np.ndarray, th.tensor]:
        """Return ``x`` restricted to the entries of ``self.select_dim`` along its last axis."""
        kept_columns = self.select_dim
        return x[..., kept_columns]

    def normalize_obs(self, obs: np.ndarray, mean: Optional[Union[float, np.ndarray]] = None,
                      var: Optional[Union[float, np.ndarray]] = None,
                      clip_obs: Optional[float] = None) -> np.ndarray:
        """Standardise observations and optionally clip them to a symmetric range.

        :param obs: observations to normalise.
        :param mean: mean to subtract; normalisation is skipped unless both ``mean``
            and ``var`` are given.
        :param var: variance; ``self.eps`` is added before the square root.
        :param clip_obs: if not ``None``, the result is clipped to
            ``[-clip_obs, clip_obs]``.
        :return: the normalised (and possibly clipped) observations; the input is
            returned unchanged when neither step applies.
        """
        if mean is not None and var is not None:
            # A leading axis lets the statistics broadcast over the batch; asarray
            # also makes plain Python floats (as the annotations allow) work.
            mean, var = np.asarray(mean)[None], np.asarray(var)[None]
            obs = (obs - mean) / np.sqrt(var + self.eps)
        if clip_obs is not None:
            # Honour the argument rather than silently using self.clip_obs.
            obs = np.clip(obs, -clip_obs, clip_obs)

        return obs

    def reshape_actions(self, acs):
        """One-hot encode discrete actions; pass continuous actions through untouched.

        For discrete action spaces the actions are cast to integers, a trailing
        axis is squeezed away when the input has more than one dimension, and a
        ``(len(acs), acs_dim)`` array with a single 1 per row is returned.
        """
        if not self.is_discrete:
            return acs

        labels = acs.astype(int)
        if acs.ndim > 1:
            labels = labels.squeeze(axis=-1)

        one_hot = np.zeros([acs.shape[0], self.acs_dim])
        row_ids = np.arange(labels.shape[0])
        one_hot[row_ids, labels] = 1.

        return one_hot

    def clip_actions(self, acs: np.ndarray, low: Optional[float] = None, high: Optional[float] = None) -> np.ndarray:
        """Clip actions to ``[low, high]``; return them unchanged unless both bounds are given."""
        if low is None or high is None:
            return acs

        return np.clip(acs, low, high)

    def get(self, nom_size: int, exp_size: int) -> np.ndarray:
        """Yield ``(nominal_indices, expert_indices)`` pairs for one pass over the data.

        With no batch size configured a single pair covering every index is
        produced. Otherwise both index ranges are shuffled independently and
        walked in steps of ``batch_size`` until the smaller of the two sizes is
        reached.
        """
        step = self.batch_size
        if step is None:
            yield np.arange(nom_size), np.arange(exp_size)
            return

        limit = min(nom_size, exp_size)
        # Expert permutation is drawn first, then nominal (keeps the RNG order).
        shuffled_expert = np.random.permutation(exp_size)
        shuffled_nominal = np.random.permutation(nom_size)

        cursor = 0
        while cursor < limit:
            window = slice(cursor, cursor + step)
            yield shuffled_nominal[window], shuffled_expert[window]
            cursor += step

            # size = min(nom_size, exp_size)
            # indices = np.random.permutation(size)
            # batch_size = self.batch_size
            # # Return everything, don't create minibatches
            # if batch_size is None:
            #     batch_size = size
            #
            # start_idx = 0
            # while start_idx < size:
            #     batch_indices = indices[start_idx:start_idx + batch_size]
            #     yield batch_indices, batch_indices
            #     start_idx += batch_size

    def _update_learning_rate(self, current_progress_remaining) -> None:
        """Apply the scheduled learning rate to the optimizer and log it.

        :param current_progress_remaining: value passed to ``self.lr_schedule``; it is
            also stored on the instance as ``self.current_progress_remaining``.
        """
        self.current_progress_remaining = current_progress_remaining
        # Evaluate the schedule once so the logged value is exactly the one applied.
        learning_rate = self.lr_schedule(current_progress_remaining)
        update_learning_rate(self.optimizer, learning_rate)
        print("Learning rate is {0}.".format(learning_rate),
              file=self.log_file,
              flush=True)

    def save(self, save_path):
        """Write the network weights, optimizer state and configuration to ``save_path``.

        Everything is gathered into one dictionary and serialised with ``th.save``.
        """
        checkpoint = {
            "cn_network": self.network.state_dict(),
            "cn_optimizer": self.optimizer.state_dict(),
            "obs_dim": self.obs_dim,
            "acs_dim": self.acs_dim,
            "is_discrete": self.is_discrete,
            "obs_select_dim": self.obs_select_dim,
            "acs_select_dim": self.acs_select_dim,
            "clip_obs": self.clip_obs,
            "obs_mean": self.current_obs_mean,
            "obs_var": self.current_obs_var,
            "action_low": self.action_low,
            "action_high": self.action_high,
            "device": self.device,
            "hidden_sizes": self.hidden_sizes,
        }
        th.save(checkpoint, save_path)

    def _load(self, load_path):
        """Restore network (and, when available, optimizer) state from a checkpoint.

        :param load_path: location handed to ``th.load``. Entries ``"cn_network"``
            and ``"cn_optimizer"`` are applied when present; the optimizer entry is
            skipped if this instance has no optimizer.
        """
        state_dict = th.load(load_path)
        # The loaded dictionary is `state_dict`; the previous code referred to an
        # undefined name here and could never restore anything.
        if "cn_network" in state_dict:
            self.network.load_state_dict(state_dict["cn_network"])
        if "cn_optimizer" in state_dict and self.optimizer is not None:
            self.optimizer.load_state_dict(state_dict["cn_optimizer"])

    # Alternate constructor: rebuild an instance from a checkpoint written by `save`
    @classmethod
    def load(
            cls,
            load_path: str,
            obs_dim: Optional[int] = None,
            acs_dim: Optional[int] = None,
            is_discrete: bool = None,
            obs_select_dim: Optional[Tuple[int, ...]] = None,
            acs_select_dim: Optional[Tuple[int, ...]] = None,
            clip_obs: Optional[float] = None,
            obs_mean: Optional[np.ndarray] = None,
            obs_var: Optional[np.ndarray] = None,
            action_low: Optional[float] = None,
            action_high: Optional[float] = None,
            device: str = "auto"
    ):
        """Build an instance from the checkpoint at ``load_path`` and load its weights.

        Any argument passed as ``None`` is replaced by the value stored in the
        checkpoint under the same name; ``hidden_sizes`` always comes from the
        checkpoint. Because ``device`` defaults to ``"auto"``, the stored device is
        only used when ``device=None`` is passed explicitly.

        The instance is created without batch size, learning-rate schedule, expert
        data or optimizer class, and only the network weights (``"cn_network"``)
        are restored.

        :return: the newly constructed instance.
        """
        state_dict = th.load(load_path)

        # Explicit arguments win; missing ones are looked up in the checkpoint,
        # in the same order as the parameters are declared.
        requested = {
            "obs_dim": obs_dim,
            "acs_dim": acs_dim,
            "is_discrete": is_discrete,
            "obs_select_dim": obs_select_dim,
            "acs_select_dim": acs_select_dim,
            "clip_obs": clip_obs,
            "obs_mean": obs_mean,
            "obs_var": obs_var,
            "action_low": action_low,
            "action_high": action_high,
            "device": device,
        }
        resolved = {}
        for key, given in requested.items():
            resolved[key] = state_dict[key] if given is None else given

        # Create network
        constraint_net = cls(
            obs_dim=resolved["obs_dim"],
            acs_dim=resolved["acs_dim"],
            hidden_sizes=state_dict["hidden_sizes"],
            batch_size=None,
            lr_schedule=None,
            expert_obs=None,
            expert_acs=None,
            optimizer_class=None,
            is_discrete=resolved["is_discrete"],
            obs_select_dim=resolved["obs_select_dim"],
            acs_select_dim=resolved["acs_select_dim"],
            clip_obs=resolved["clip_obs"],
            initial_obs_mean=resolved["obs_mean"],
            initial_obs_var=resolved["obs_var"],
            action_low=resolved["action_low"],
            action_high=resolved["action_high"],
            device=resolved["device"]
        )
        constraint_net.network.load_state_dict(state_dict["cn_network"])

        return constraint_net

# =====================================================================
# Plotting utilities
# =====================================================================

# The following functions exploit knowledge about the environment

# def plot_constraints(cost_function, env, env_id, select_dim, obs_dim, acs_dim,
#                      save_name, observations=None):
#     if env_id in ['PointEnv-v0',
#                   'PointEnvTest-v0',
#                   'HCWithPos-v0',
#                   'AntWallTest-v0',
#                   'HCWithPosTest-v0',
#                   'PointCircle-v0',
#                   'PointCircleTest-v0',
#                   'WalkerWithPos-v0',
#                   'WalkerWithPosTest-v0']:
#         plot_for_gym_envs(cost_function, env, env_id, select_dim, obs_dim, acs_dim,
#                           save_name, observations)
#     elif env_id in ["D2B-v0", "DD2B-v0", "CDD2B-v0", "TCDD2B-v0",
#                     "D3B-v0", "DD3B-v0", "CDD3B-v0", "TCDD3B-v0"]:
#         plot_for_bridges_envs(cost_function, save_name, observations)
#     elif env_id in ["LGW-v0", "CLGW-v0"]:
#         pass
#         # plot_for_lap_env(cost_function, save_name, observations)
#     else:
#         print("Env id not recognized; skipping plotting of constraint net")
#
#
# def plot_for_gym_envs(cost_function, env, env_id, select_dim, obs_dim, acs_dim,
#                       save_name, obs=None):
#     if len(select_dim) > 2:
#         print("Cannot plot; constraint net has more than two dimensions.")
#         return
#
#     x_range = [-12, 12] if "Point" in env_id else [-20, 20]
#
#     if len(select_dim) == 1:
#         fig, ax = plt.subplots(1, 2, figsize=(30, 15))
#         num_points = 1000
#         obs_all = np.linspace(x_range[0], x_range[1], num_points)[..., None]
#         obs_all = np.concatenate((obs_all, np.zeros((num_points, obs_dim - 1))), axis=-1)
#
#         action = np.zeros((num_points, acs_dim))
#         preds = 1 - cost_function(obs_all, action)
#         ax[0].plot(obs_all, preds)
#         if obs is not None:
#             ax[0].scatter(obs[..., 0], 0.2 + np.zeros(obs.shape[0]))
#             ax[1].hist(obs[:, 0], bins=40, range=(-20, 20))
#             ax[1].set_axisbelow(True)
#             # Turn on the minor TICKS, which are required for the minor GRID
#             ax[1].minorticks_on()
#             ax[1].grid(which='major', linestyle='-', linewidth='0.5', color='red')
#             ax[1].grid(which='minor', linestyle=':', linewidth='0.5', color='black')
#         ax[0].set_ylim([0, 1])
#         ax[0].set_xlim(x_range)
#         ax[0].set_axisbelow(True)
#         # Turn on the minor TICKS, which are required for the minor GRID
#         ax[0].minorticks_on()
#         ax[0].grid(which='major', linestyle='-', linewidth='0.5', color='red')
#         ax[0].grid(which='minor', linestyle=':', linewidth='0.5', color='black')
#         fig.savefig(save_name)
#         plt.close(fig=fig)
#
#     elif len(select_dim) == 2:
#         fig, ax = plt.subplots(1, 1, figsize=(30, 15))
#         r = np.arange(-20, 20, 0.1)
#         X, Y = np.meshgrid(r, r)
#         obs_all = np.concatenate([X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=-1)
#         obs_all = np.concatenate((obs_all, np.zeros((np.size(X), obs_dim - 2))), axis=-1)
#
#         action = np.zeros((np.size(X), acs_dim))
#         outs = 1 - cost_function(obs_all, action)
#         im = ax.imshow(outs.reshape(X.shape), extent=[-20, 20, -20, 20], cmap='jet_r',
#                        vmin=0, vmax=1, origin='lower')
#         # axs[action].set_title('Action: %s' % action_desc[action])
#         fig.colorbar(im, ax=ax)
#
#         if obs is not None:
#             obs = np.clip(obs, -20, 20)
#             ax.scatter(obs[..., 0], obs[..., 1], clip_on=False)
#         ax.set_ylim([-20, 20])
#         ax.set_xlim([-20, 20])
#         plt.grid('on')
#         fig.savefig(save_name)
#         plt.close(fig=fig)
#
#
# def plot_for_bridges_envs(cost_function, save_name, observations=None):
#     if observations is not None:
#         # Unnormalize
#         observations = 10 * (observations + 1)
#     action_desc = ['right', 'left', 'up', 'bottom']
#     r = np.arange(0, 20, 0.1)
#     X, Y = np.meshgrid(r, r)
#     obs = np.concatenate([X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=-1)
#     # Normalize
#     obs = obs / 10 - 1
#
#     fig, axs = plt.subplots(2, 2)
#     axs = [ax for axs_ in axs for ax in axs_]
#     fig.set_size_inches(20, 20)
#     for action in range(4):
#         acs = action * np.ones([obs.shape[0], 1])
#         outs = 1 - cost_function(obs, acs)
#         im = axs[action].imshow(outs.reshape(X.shape), extent=[0, 20, 0, 20], cmap='jet_r',
#                                 vmin=0, vmax=1, origin='lower')
#         axs[action].set_title('Action: %s' % action_desc[action])
#         fig.colorbar(im, ax=axs[action])
#
#         if observations is not None:
#             axs[action].scatter(observations[:, 0], observations[:, 1], clip_on=False)
#
#         axs[action].set_xlim(0, 20)
#         axs[action].set_ylim(0, 20)
#
#     fig.savefig(save_name)
#     plt.close(fig=fig)
#
#
# def plot_for_lap_env(cost_function, save_name, obs=None):
#     # Assumes a lap size of 11
#     LAP_SIZE = 11
#     action_desc = ['Forward', 'Backward']
#     number_of_cells = (LAP_SIZE - 1) * 4
#     all_obs = np.arange(number_of_cells)[:, None]
#     fig, ax = plt.subplots(1, 2, figsize=(30, 15))
#     if obs is not None: obs = np.round((obs[:, :1] + 1) * (number_of_cells / 2))
#     for action in range(2):
#         ac = action * np.ones([all_obs.shape[0], 1])
#         outs = 1 - cost_function(all_obs, ac)
#         outs = _reshape_lap_to_grid(outs)
#         # Plotting
#         c = ax[action].pcolor(outs, edgecolors='w', linewidths=2, cmap='jet_r', vmin=-0.0, vmax=1.0)
#         ax[action].set_title('Action: %s' % action_desc[action])
#         if obs is not None:
#             co_ords = []
#             for i in range(obs.shape[0]):
#                 for j in range(obs.shape[1]):
#                     co_ords.append(_idx_to_xy(obs[i, j]))
#             x, y = zip(*co_ords)
#             x, y = np.array(x) + 0.5, np.array(y) + 0.5
#             ax[action].scatter(x, y)
#
#     plt.axis('off')
#     fig.savefig(save_name)
#     plt.close(fig=fig)
