import torch
import numpy as np
import warnings


class LinearFeatureBaseline(object):
    """
    Linear (polynomial) time-state dependent return baseline model
    (see. Duan et al. 2016, "Benchmarking Deep Reinforcement Learning for Continuous Control", ICML)
    https://github.com/rll/rllab/

    For every time step t of a trajectory the feature vector is
    [clip(o), clip(o) ** 2, t / 100, (t / 100) ** 2, (t / 100) ** 3, 1],
    where observations are clipped to [-10, 10]. The coefficients are fitted on the
    discounted returns by ridge-regularized (damped) least squares.
    """

    def __init__(self, reg_coeff=1e-5, *args, **kwargs):
        """
        Args:
            reg_coeff (float): initial regularization coefficient added to the diagonal of the
                normal equations; it is multiplied by 10 after each failed fit attempt
            *args, **kwargs: accepted for interface compatibility and ignored
        """
        self._coeffs = None
        self._reg_coeff = reg_coeff

    def predict(self, traj_data):
        """
        Predicts the linear reward baselines estimates for a provided trajectory / path.
        If the baseline is not fitted - returns zero baseline

        Args:
           traj_data (dict): dict of lists/numpy array containing trajectory / path information
                 such as "observations", "rewards", ...
                 When fitted, "observations" is read (one entry per trajectory).
                 When not fitted, "returns" is read (falling back to "discounted_returns")
                 to shape the zero baselines.

        Returns:
             Updated trajectory data (the same dict, mutated in place) with a "baselines" entry:
             a list with one torch tensor per trajectory
        """
        if self._coeffs is None:
            # Unfitted: zeros shaped like each trajectory's returns. "discounted_returns" is the
            # key update() consumes, so accept it when "returns" is not present.
            if "returns" in traj_data:
                returns = traj_data["returns"]
            else:
                returns = traj_data["discounted_returns"]
            traj_data["baselines"] = [torch.zeros_like(r) for r in returns]
        else:
            traj_data["baselines"] = [
                torch.tensor(self._features(obs).dot(self._coeffs), dtype=torch.get_default_dtype())
                for obs in traj_data["observations"]
            ]
        return traj_data

    def update(self, traj_data):
        """
        Fits the linear baseline model with the provided paths via damped least squares.

        Solves (F^T F + reg * I) w = F^T y, where F stacks the features of all trajectories and
        y the concatenated discounted returns. If the solve raises LinAlgError or produces NaN
        coefficients, the regularization is increased tenfold and the solve retried (up to 10
        attempts in total).

        Args:
            traj_data (dict): trajectory data with "observations" (one array per trajectory)
                and "discounted_returns" (one torch tensor per trajectory)

        Raises:
            np.linalg.LinAlgError: if the last attempt still fails to converge
        """
        max_attempts = 10
        featmat = np.concatenate([self._features(obs) for obs in traj_data["observations"]], axis=0)
        # detach/cpu so returns that require grad or live on an accelerator can be converted
        target = torch.cat(traj_data["discounted_returns"]).detach().cpu().numpy()

        # Loop-invariant parts of the normal equations.
        gram = featmat.T.dot(featmat)
        rhs = featmat.T.dot(target)
        identity = np.identity(featmat.shape[1])

        reg_coeff = self._reg_coeff
        for attempt in range(max_attempts):
            try:
                coeffs = np.linalg.lstsq(gram + reg_coeff * identity, rhs, rcond=-1)[0]
            except np.linalg.LinAlgError:
                if attempt == max_attempts - 1:
                    raise
                warnings.warn(f'LSTSQ did not converge with reg coef {reg_coeff}', RuntimeWarning)
            else:
                self._coeffs = coeffs
                if not np.any(np.isnan(coeffs)):
                    return
                # NaN solution: fall through and retry with stronger regularization
            reg_coeff *= 10

        warnings.warn(f'LSTSQ produced NaN coefficients with reg coef up to {reg_coeff / 10}',
                      RuntimeWarning)

    def _features(self, obs):
        """
        Builds the polynomial state/time feature matrix of one trajectory.

        Args:
            obs: observations of one trajectory, array-like of shape (T, obs_dim) or (T,)
                (a 1-D sequence is treated as a single observation dimension)

        Returns:
            np.ndarray of shape (T, 2 * obs_dim + 4)
        """
        obs = np.clip(np.asarray(obs), -10, 10)
        if obs.ndim == 1:
            obs = obs.reshape(-1, 1)
        path_length = len(obs)
        # Time index scaled down so its polynomial powers stay well conditioned.
        time_step = np.arange(path_length).reshape(-1, 1) / 100.0
        return np.concatenate([obs, obs ** 2, time_step, time_step ** 2, time_step ** 3, np.ones((path_length, 1))],
                              axis=1)

    def set_param_values(self, value):
        """
        Sets the parameter values of the baseline object

        Args:
            value: numpy array of linear_regression coefficients (or None to mark as unfitted)

        """
        self._coeffs = value

    def get_param_values(self):
        """
        Returns the parameter values of the baseline object

        Returns:
            numpy array of linear_regression coefficients, or None if not fitted

        """
        return self._coeffs


class ZeroBaseline(object):
    """
    Trivial baseline that always predicts zero returns; fitting is a no-op.
    """

    def __init__(self, *args, **kwargs):
        """Accepts and ignores any arguments, for interface compatibility with other baselines."""
        super().__init__()

    def predict(self, traj_data):
        """
        Attaches all-zero baselines shaped like each trajectory's returns.

        Args:
            traj_data (dict): trajectory data with "discounted_returns" (falling back to
                "returns" when absent), one torch tensor per trajectory

        Returns:
            The same dict, mutated in place, with a "baselines" list of zero tensors
        """
        if "discounted_returns" in traj_data:
            returns = traj_data["discounted_returns"]
        else:
            returns = traj_data["returns"]
        traj_data["baselines"] = [torch.zeros_like(r) for r in returns]
        return traj_data

    def update(self, traj_data):
        """No-op fit; returns traj_data unchanged."""
        return traj_data

