from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import numpy as np
from dotmap import DotMap
from dtaidistance import dtw_ndim
import cv2
from scipy.io import loadmat
class Agent:
    """A general class for RL agents.

    The agent wraps an environment and rolls out a policy in it, scoring the
    resulting observation sequence against an expert trajectory with DTW.
    """

    # Frame size used for the video writer when no frame was rendered.
    _DEFAULT_FRAME_SIZE = (500, 500)
    # Frame rate of recorded rollouts.
    _VIDEO_FPS = 40

    def __init__(self, params):
        """Initializes an agent.

        Arguments:
            params: (DotMap) A DotMap of agent parameters.
                .env: (OpenAI gym environment) The environment for this agent.
                .noisy_actions: (bool) Must be absent or False; noisy actions
                    are not supported by this agent.

        Raises:
            AssertionError: If params.noisy_actions is set to anything but False.
            ValueError: If no environment was provided.
        """
        assert params.get("noisy_actions", False) is False, \
            "noisy_actions is not supported by this agent."
        self.env = params.env

        # NOTE(review): presumably a missing `.env` key shows up here as an
        # auto-created empty DotMap, which is why a DotMap value is rejected.
        if isinstance(self.env, DotMap):
            raise ValueError("Environment must be provided to the agent at initialization.")

    @staticmethod
    def _frame_size(frames, default=(500, 500)):
        """Returns the (width, height) the video writer must be opened with.

        The size is taken from the first rendered frame (arrays shaped
        (height, width, ...)); `default` is returned when `frames` is empty.
        """
        if not frames:
            return default
        height, width = frames[0].shape[:2]
        return (int(width), int(height))

    def sample(self, horizon, policy,expert_traj, env_seed=None, record_fname=None,var_rec=False):
        """Samples a rollout from the agent.

        Arguments:
            horizon: (int) The maximum length of the rollout to generate.
            policy: (policy) The policy that the agent will use for actions. It must
                provide reset(), act(obs, t, obs_history), cur_progress, expert_max
                and expert_min.
            expert_traj: The expert trajectory the rollout is compared against. It is
                indexed, and .shape and .numpy() are used on it (presumably a
                torch tensor of shape (steps, obs_dim) -- confirm against callers).
            env_seed: (int/None) If None, the environment is reset to the first expert
                state via reset_to_certain_state; otherwise the environment is seeded
                with this value and reset normally.
            record_fname: (str/None) The name of the file to which a recording of the rollout
                will be saved. If None, the rollout will not be recorded.
            var_rec: Unused; kept for interface compatibility.

        Returns: (dict) A dictionary containing data from the rollout.
            The keys of the dictionary are 'obs', 'ac', 'reward_sum', 'rewards'
            and 'dtw_distance' (DTW distance between the min-max normalized
            observations and the equally normalized expert trajectory).
        """
        video_record = record_fname is not None
        times, rewards, frames = [], [], []
        self.env.reset()
        if env_seed is None:
            O = [self.env.reset_to_certain_state(expert_traj[0])]
        else:
            self.env.seed(env_seed)
            O = [self.env.reset()]
        A, reward_sum = [], 0
        # Once the policy reaches the end of the expert trajectory, exactly one
        # more environment step is taken before the rollout stops.
        one_more_step = False
        policy.reset()
        for t in range(horizon):
            if video_record:
                frames.append(self.env.render(mode='rgb_array'))
            start = time.time()
            A.append(policy.act(O[t], t, O))
            times.append(time.time() - start)
            obs, reward, done, _ = self.env.step(A[t])

            O.append(obs)
            reward_sum += reward
            rewards.append(reward)

            if done or one_more_step:
                break
            if policy.cur_progress >= expert_traj.shape[0]-1:
                one_more_step = True

        # astype() copies, so the normalization below touches neither the
        # returned observations nor the caller's expert trajectory.
        output = np.array(O).astype(np.double)
        target = expert_traj.numpy().astype(np.double)
        for s_i in range(output.shape[1]):
            tmp_max, tmp_min = policy.expert_max[s_i].item(), policy.expert_min[s_i].item()
            # The epsilon keeps constant dimensions from dividing by zero.
            span = tmp_max-tmp_min+1e-8
            output[:, s_i] = (output[:, s_i]-tmp_min)/span
            target[:, s_i] = (target[:, s_i]-tmp_min)/span

        d = dtw_ndim.distance_fast(output, target)
        if video_record:
            # The writer must be opened with the actual frame size; a mismatch
            # with the rendered frames yields an unusable recording.
            size = self._frame_size(frames, self._DEFAULT_FRAME_SIZE)
            out = cv2.VideoWriter(record_fname,cv2.VideoWriter_fourcc(*'MJPG'),self._VIDEO_FPS, size)
            try:
                for frame in frames:
                    out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
            finally:
                # Always finalize the file, even if a frame fails to convert.
                out.release()
        # An empty rollout (horizon == 0) has no timings to average.
        mean_time = np.mean(times) if times else float("nan")
        print("Average action selection time: ", mean_time)
        print("Rollout length: ", len(A))
        return {
            "obs": np.array(O),
            "ac": np.array(A),
            "reward_sum": reward_sum,
            "rewards": np.array(rewards),
            "dtw_distance": d
        }
