# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Two types of filters which can be applied to policy output sequences.

1. Simple exponential filter
2. Butterworth filter - lowpass or bandpass

The implementation of the Butterworth filter follows scipy's lfilter:
https://github.com/scipy/scipy/blob/v1.2.1/scipy/signal/signaltools.py

We re-implement the logic in order to manage the filter state (the input x
and output y histories) explicitly.

The filter implements::
       a[0]*y[n] = b[0]*x[n] + b[1]*x[n-1] + ... + b[M]*x[n-M]
                             - a[1]*y[n-1] - ... - a[N]*y[n-N]

We assume M == N.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import collections
# from absl import logging
import numpy as np
from scipy.signal import butter

# Default filter order used by ActionFilterButter.
ACTION_FILTER_ORDER = 2
# Default low cutoff frequency used by ActionFilterButter; a value of 0
# selects a lowpass (rather than bandpass) design.
ACTION_FILTER_LOW_CUT = 0.0
# Default high cutoff frequency used by ActionFilterButter, in the same units
# as the filter's sampling_rate (Hz).
ACTION_FILTER_HIGH_CUT = 4.0

class ActionFilter(object):
  """Implements a generic lowpass or bandpass action filter.

  Evaluates the difference equation from the module docstring independently
  for every joint. The input history (``xhist``) and output history
  (``yhist``) are kept explicitly in deques, newest sample first, so callers
  can seed or reset the filter state.
  """

  def __init__(self, a, b, order, num_joints, ftype='lowpass'):
    """Initializes filter.

    Either one per joint or same for all joints.

    The coefficient arrays passed in are never modified. Normalized copies
    are stored in ``self.a`` / ``self.b`` as (num_joints, n_coeffs) arrays.

    Args:
      a: filter output history coefficients. Either a list holding one
        coefficient array per joint (or a single-element list), or a single
        coefficient array shared by all joints.
      b: filter input coefficients, in the same layout as ``a``.
      order: filter order
      num_joints: robot DOF
      ftype: filter type. 'lowpass' or 'bandpass'

    Raises:
      ValueError: if the number of coefficient sets is neither 1 nor
        num_joints, or if ftype is not supported.
    """
    self.num_joints = num_joints
    if isinstance(a, list):
      a_list, b_list = a, b
    else:
      a_list, b_list = [a], [b]

    # Either a set of parameters per joint must be specified as a list
    # Or one filter is applied to every joint
    if not ((len(a_list) == len(b_list) == num_joints) or (
        len(a_list) == len(b_list) == 1)):
      raise ValueError('Incorrect number of filter values specified')

    # Normalize by a[0]. Out-of-place division builds new arrays, so the
    # caller's coefficient arrays are never modified. It also works for
    # integer arrays and plain lists, where in-place '/=' would fail.
    self.a = []
    self.b = []
    for a_i, b_i in zip(a_list, b_list):
      a_i = np.asarray(a_i)
      a0 = a_i[0]
      self.a.append(a_i / a0)
      self.b.append(np.asarray(b_i) / a0)

    # Convert single filter to same format as filter per joint. A new list
    # is built (instead of '*=') so the caller's list is not extended.
    if len(self.a) == 1:
      self.a = self.a * num_joints
      self.b = self.b * num_joints
    self.a = np.stack(self.a)
    self.b = np.stack(self.b)

    if ftype == 'bandpass':
      assert len(self.b[0]) == len(self.a[0]) == 2 * order + 1
      self.hist_len = 2 * order
    elif ftype == 'lowpass':
      assert len(self.b[0]) == len(self.a[0]) == order + 1
      self.hist_len = order
    else:
      raise ValueError('%s filter type not supported' % (ftype))

    # Each deque holds hist_len arrays of shape (num_joints, 1); index 0 is
    # the most recent sample (see filter()).
    self.yhist = collections.deque(maxlen=self.hist_len)
    self.xhist = collections.deque(maxlen=self.hist_len)
    self.reset()

  def reset(self):
    """Resets the history buffers to 0."""
    self.yhist.clear()
    self.xhist.clear()
    for _ in range(self.hist_len):
      self.yhist.appendleft(np.zeros((self.num_joints, 1)))
      self.xhist.appendleft(np.zeros((self.num_joints, 1)))

  def filter(self, x):
    """Returns filtered x.

    Args:
      x: current input sample with num_joints elements. A 1-D (num_joints,)
        array yields a (num_joints,) output. A (num_joints, 1) array would
        broadcast against the per-joint coefficients.

    Returns:
      The filtered output y[n], one value per joint.
    """
    x = np.asarray(x)
    # Columns are newest-first: xs[:, k] holds x[n-1-k] and pairs with
    # b[:, k+1] (likewise ys with a[:, k+1]).
    xs = np.concatenate(list(self.xhist), axis=-1)
    ys = np.concatenate(list(self.yhist), axis=-1)
    y = (np.multiply(x, self.b[:, 0])
         + np.sum(np.multiply(xs, self.b[:, 1:]), axis=-1)
         - np.sum(np.multiply(ys, self.a[:, 1:]), axis=-1))
    # Store copies so later in-place edits of x / y by the caller cannot
    # corrupt the history. The deque's maxlen drops the oldest entry.
    self.xhist.appendleft(x.reshape((self.num_joints, 1)).copy())
    self.yhist.appendleft(y.reshape((self.num_joints, 1)).copy())
    return y

  def init_history(self, x):
    """Fills every slot of the input and output histories with x.

    Each slot receives its own copy of x as a (num_joints, 1) column.
    Consequently, in-place edits of one history entry (e.g. per-joint resets)
    affect neither the other slots nor the caller's array.

    Args:
      x: sample with num_joints elements used to seed both histories.
    """
    # np.expand_dims returns a view of x, hence the explicit copies below.
    column = np.expand_dims(x, axis=-1)
    for i in range(self.hist_len):
      self.xhist[i] = column.copy()
      self.yhist[i] = column.copy()

class ActionFilterButter(ActionFilter):
  """Butterworth filter (lowpass or bandpass), one per joint or shared."""

  def __init__(self,
               lowcut=None,
               highcut=None,
               sampling_rate=None,
               order=ACTION_FILTER_ORDER,
               num_joints=None):
    """Initializes a butterworth filter.

    Either one per joint or same for all joints.

    Args:
      lowcut: list of values convertible to float (e.g. strings) defining
        the low cutoff frequencies.
        The list must contain either 1 element (same filter for all joints)
        or num_joints elements
        0 for lowpass, > 0 for bandpass. Either all values must be 0
        or all > 0
      highcut: list of values convertible to float (e.g. strings) defining
        the high cutoff frequencies.
        The list must contain either 1 element (same filter for all joints)
        or num_joints elements
        All must be > 0
      sampling_rate: frequency of samples in Hz
      order: filter order
      num_joints: robot DOF

    Raises:
      ValueError: on mismatched lowcut/highcut lengths, missing
        sampling_rate or num_joints, mixed lowpass/bandpass cutoffs, or a
        non-positive highcut.
    """
    self.lowcut = ([float(x) for x in lowcut]
                   if lowcut is not None else [ACTION_FILTER_LOW_CUT])
    self.highcut = ([float(x) for x in highcut]
                    if highcut is not None else [ACTION_FILTER_HIGH_CUT])
    if len(self.lowcut) != len(self.highcut):
      raise ValueError('Number of lowcut and highcut filter values should '
                       'be the same')

    if sampling_rate is None:
      raise ValueError('sampling_rate should be provided.')

    if num_joints is None:
      raise ValueError('num_joints should be provided.')

    # A non-zero lowcut selects a bandpass design; all filters must agree.
    if np.any(self.lowcut):
      if not np.all(self.lowcut):
        raise ValueError('All the filters must be of the same type: '
                         'lowpass or bandpass')
      self.ftype = 'bandpass'
    else:
      self.ftype = 'lowpass'

    a_coeffs = []
    b_coeffs = []
    for low, high in zip(self.lowcut, self.highcut):
      if high <= 0.0:
        raise ValueError('Highcut must be > 0')

      b, a = self.butter_filter(low, high, sampling_rate, order)
      b_coeffs.append(b)
      a_coeffs.append(a)

    super(ActionFilterButter, self).__init__(
        a_coeffs, b_coeffs, order, num_joints, self.ftype)

  def butter_filter(self, lowcut, highcut, fs, order=5):
    """Returns the coefficients of a butterworth filter.

    If lowcut = 0, the function returns the coefficients of a low pass filter.
    Otherwise, the coefficients of a band pass filter are returned.
    Highcut should be > 0

    Args:
      lowcut: low cutoff frequency
      highcut: high cutoff frequency
      fs: sampling rate
      order: filter order
    Return:
      b, a: parameters of a butterworth filter
    """
    # scipy expects cutoffs normalized to the Nyquist frequency.
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    if low:
      b, a = butter(order, [low, high], btype='band')
    else:
      # For lowpass designs scipy documents Wn as a scalar, not a sequence.
      b, a = butter(order, high, btype='low')
    return b, a

  def reset_by_ids(self, action_ids):
    """Resets the history buffers to 0 for the selected joints only.

    Args:
      action_ids: index expression (e.g. an int, a sequence of ints or a
        boolean mask) selecting rows of each (num_joints, 1) history entry.
    """
    # The stored history arrays are zeroed in place; the order of the
    # entries in both deques is left untouched.
    for x_entry, y_entry in zip(self.xhist, self.yhist):
      x_entry[action_ids] = 0.
      y_entry[action_ids] = 0.

class ActionFilterButterTorch(ActionFilterButter):
  """Butterworth filter whose per-step update runs on torch tensors.

  The numpy histories of the base class are still created, but ``filter`` and
  ``reset_hist`` use this class's own torch buffers ``xhist_torch`` and
  ``yhist_torch``. Each buffer has shape (hist_len, num_joints), with row 0
  holding the most recent sample.
  """

  def __init__(self,
               lowcut=None,
               highcut=None,
               sampling_rate=None,
               order=ACTION_FILTER_ORDER,
               num_joints=None,
               device='cpu'):
    """Initializes the filter and copies its coefficients to `device`.

    Args:
      lowcut, highcut, sampling_rate, order, num_joints: as for
        ActionFilterButter.
      device: torch device holding the coefficients and history tensors.
    """
    super(ActionFilterButterTorch, self).__init__(
        lowcut, highcut, sampling_rate, order, num_joints
    )
    self.device = device
    self.a_torch = torch.tensor(self.a, dtype=torch.float32, device=self.device)
    self.b_torch = torch.tensor(self.b, dtype=torch.float32, device=self.device)
    self.xhist_torch = torch.zeros(
        (self.hist_len, self.num_joints), dtype=torch.float32,
        device=self.device)
    self.yhist_torch = torch.zeros(
        (self.hist_len, self.num_joints), dtype=torch.float32,
        device=self.device)

  def filter_old(self, x):
    """Filters x with the numpy implementation inherited from the base."""
    return super(ActionFilterButterTorch, self).filter(x)

  def reset_old(self, action_ids):
    """Zeros the numpy (not torch) history of the selected joints."""
    return super(ActionFilterButterTorch, self).reset_by_ids(action_ids)

  def filter(self, x):
    """Returns filtered x and advances the torch history by one step.

    Shapes:
      x: (num_joints,). It is inserted as a new row of the
        (hist_len, num_joints) histories via ``x.unsqueeze(0)``.
      a_torch, b_torch: (num_joints, hist_len + 1).

    Returns:
      y with shape (num_joints,).
    """
    # Transposed histories are (num_joints, hist_len) and line up with the
    # coefficient columns 1..hist_len (newest sample first).
    y = (x * self.b_torch[:, 0]
         + torch.sum(self.xhist_torch.T * self.b_torch[:, 1:], dim=1)
         - torch.sum(self.yhist_torch.T * self.a_torch[:, 1:], dim=1))

    # Shift: the newest sample becomes row 0 and the oldest row is dropped.
    self.xhist_torch = torch.cat([x.unsqueeze(0), self.xhist_torch[:-1]], dim=0)
    self.yhist_torch = torch.cat([y.unsqueeze(0), self.yhist_torch[:-1]], dim=0)
    return y

  def reset_hist(self, action_ids=None):
    """Zeros the torch history of the selected joints (all if None).

    Args:
      action_ids: index expression selecting columns (joints) of the
        (hist_len, num_joints) history tensors. None resets every joint.
    """
    if action_ids is None:
      self.xhist_torch.zero_()
      self.yhist_torch.zero_()
      return
    self.xhist_torch[:, action_ids] = 0.
    self.yhist_torch[:, action_ids] = 0.

class ActionFilterExp(ActionFilter):
  """Exponential smoothing filter.

  Computes y[n] = alpha * x[n] + (1 - alpha) * y[n-1] per joint. It is
  expressed as a first-order lowpass with a = [1, alpha - 1] and
  b = [alpha, 0].
  """

  def __init__(self, alpha, num_joints):
    """Builds one first-order smoothing filter per alpha value.

    Args:
      alpha: list of values convertible to float (e.g. strings).
        Either 1 element (same filter for all joints) or num_joints
        elements. Expected range: 0 < alpha <= 1.
      num_joints: robot DOF
    """
    self.alphas = [float(value) for value in alpha]
    self.ftype = 'lowpass'

    # In the a[0]*y[n] = b.x - a[1:].y convention, a[1] = alpha - 1 produces
    # the (1 - alpha) * y[n-1] term, and b[1] = 0 ignores x[n-1].
    a_coeffs = [np.asarray([1., value - 1.]) for value in self.alphas]
    b_coeffs = [np.asarray([value, 0]) for value in self.alphas]

    super(ActionFilterExp, self).__init__(
        a_coeffs, b_coeffs, 1, num_joints, self.ftype)