import time
from typing import Dict, List, Any
import numpy as np
import gym

from augment.rl.augmentation_functions.augmentation_function import AugmentationFunction


class SwimmerReflect(AugmentationFunction):
    """Reflect Swimmer transitions across the x-axis (y -> -y).

    Mirroring the swimmer about its axis of travel flips the sign of every
    angle, every angular velocity, the lateral velocity ``vy`` and every joint
    torque, while leaving the forward velocity ``vx`` untouched.

    Assumes the Gym/MuJoCo Swimmer observation layout with the x/y root
    position excluded, i.e. for an ``n``-link swimmer the last axis has size
    ``2 * n + 2``::

        [theta_0 .. theta_{n-1}, vx, vy, omega_0 .. omega_{n-1}]

    (n = 3, size 8, for the stock environment) -- TODO confirm against the
    env wrapper used by callers.
    """

    def __init__(self, sigma=0.1, k=2, **kwargs):
        """Store configuration.

        Args:
            sigma: Stored as ``self.sigma``; not read by this class.
            k: Stored as ``self.k``; not read by this class. ``_augment``
                derives the number of links from the observation width instead.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.sigma = sigma
        self.k = k

    def _augment(self,
                 obs: np.ndarray,
                 next_obs: np.ndarray,
                 action: np.ndarray,
                 reward: np.ndarray,
                 done: np.ndarray,
                 infos: List[Dict[str, Any]],
                 delta=None,
                 p=None,
                 ):
        """Reflect a batch of transitions in place.

        Args:
            obs, next_obs: Observations whose last axis follows the layout in
                the class docstring. Any number of leading (batch) axes is
                accepted, including none.
            action: Joint torques; every component is negated.
            reward, done, infos: Returned unchanged. Swimmer's reward is
                presumably invariant under this reflection (forward
                x-velocity minus a quadratic control cost) -- verify if a
                custom reward is used.
            delta, p: Unused; accepted for interface compatibility.

        Returns:
            ``(obs, next_obs, action, reward, done, infos)``. The first three
            are the same array objects that were passed in, modified in place.

        Raises:
            ValueError: If an observation is too narrow to hold even one link.
        """
        self._reflect_obs(obs)
        self._reflect_obs(next_obs)
        # Mirroring the body reverses the sense of every joint torque.
        action *= -1
        return obs, next_obs, action, reward, done, infos

    @staticmethod
    def _reflect_obs(obs: np.ndarray) -> None:
        """Negate, in place, every entry of ``obs`` that is odd under reflection."""
        dim = obs.shape[-1]
        # Invert dim = 2 * n_links + 2. The stock 8-dim Swimmer gives 3.
        n_links = (dim - 2) // 2
        if n_links < 1:
            raise ValueError(
                f"Swimmer observation must have at least 4 entries, got {dim}")
        obs[..., :n_links] *= -1        # link / joint angles
        obs[..., n_links + 1] *= -1     # lateral velocity vy (vx is kept)
        # `dim - n_links` rather than `-n_links` so the slice can never
        # degenerate into `-0:` (the whole row).
        obs[..., dim - n_links:] *= -1  # angular velocities


# Maps each Swimmer augmentation name to its AugmentationFunction subclass.
SWIMMER_AUG_FUNCTIONS = dict(
    reflect=SwimmerReflect,
)