import numpy as np
import copy
import os
import jax
import equinox as eqx

from gymnasium.utils import save_video

class Sampler:
    """Sampler to collect and evaluate agent behavior."""

    def __init__(self, env, episode_limit=1000, init_random_samples=1000,
                 visual_env=False):
        """
        Store the environment and initialise the rollout bookkeeping.

        Parameters
        ----------
        env : Environment to run the agent. The same object is also used as the evaluation environment.
        episode_limit : Maximum number of timesteps per trajectory, default is 1000.
        init_random_samples : Number of initial timesteps to execute random behavior, default is 1000.
        visual_env : Environment returns visual observations, default is False.
        """
        # Training and evaluation share one environment object; no copy is made.
        self._env = self._eval_env = env
        self._visual_env = visual_env

        # Per-episode step cap and length of the random warm-up phase.
        self._el, self._nr = episode_limit, init_random_samples

        # _tc counts every collected step; _ct counts steps in the current
        # sample_steps episode.
        self._tc = self._ct = 0

        # No observation yet, so the first sample_steps call resets the env.
        self._ob, self._reset = None, True

    def _handle_ob(self, ob):
        """Return ``ob['obs']`` for visual environments, otherwise ``ob`` unchanged."""
        return ob['obs'] if self._visual_env else ob

    def sample_steps(self, policy, noise_stddev, n_steps=1, dac_augmentation=False):
        """Collect a number of transition steps with policy.

        Parameters
        ----------
        policy : Object exposing ``get_action(batched_obs, noise_stddev)``. It is only
            queried once the total step counter has reached ``init_random_samples``;
            before that, actions come from ``env.action_space.sample()``.
        noise_stddev : Forwarded unchanged to ``policy.get_action``.
        n_steps : Number of environment steps to take, default is 1.
        dac_augmentation : If True, a step that ends an episode is rewritten to lead
            into ``env.absorbing_state`` (with done=False) and one extra
            absorbing -> absorbing transition with zero action and reward is appended.

        Returns
        -------
        dict with stacked 'obs', 'nobs', 'act', arrays 'rew' and 'don', and 'n'
        (always ``n_steps``, even when augmentation adds rows). For visual
        environments 'ims' holds one ``env.get_ims()`` result per real step.
        """
        cur_obs, next_obs, actions, rewards, terminals = [], [], [], [], []
        frames = [] if self._visual_env else None

        for _ in range(n_steps):
            # Start a new episode after a terminal step or at the step cap.
            if self._reset or self._ct >= self._el:
                self._ct, self._reset = 0, False
                # self._ob = self._handle_ob(self._env.reset())
                self._ob = self._env.reset()

            if self._tc >= self._nr:
                batched = np.expand_dims(self._ob.astype('float32'), axis=0)
                action = np.array(policy.get_action(batched, noise_stddev))[0]
            else:
                # Random warm-up phase.
                action = self._env.action_space.sample()

            cur_obs.append(self._ob)
            actions.append(action)

            # The environment is expected to return a 4-tuple from step().
            step_ob, reward, finished, _info = self._env.step(action)
            self._ob, self._reset = step_ob, finished
            if frames is not None:
                # frames.append(self._ob['im'])
                frames.append(self._env.get_ims())
            # self._ob = self._handle_ob(self._ob)

            next_obs.append(step_ob)
            rewards.append(reward)
            terminals.append(finished)
            self._ct += 1
            self._tc += 1

            if dac_augmentation and finished:
                # Redirect the terminal transition into the absorbing state ...
                next_obs[-1] = self._env.absorbing_state
                terminals[-1] = False
                # ... and append the absorbing self-loop transition.
                cur_obs.append(self._env.absorbing_state)
                next_obs.append(self._env.absorbing_state)
                actions.append(np.zeros(self._env.action_space.shape))
                rewards.append(0.0)
                terminals.append(False)
                self._ct += 1

        # The next call always begins with an environment reset.
        self._reset = True

        batch = {
            'obs': np.stack(cur_obs),
            'nobs': np.stack(next_obs),
            'act': np.stack(actions),
            'rew': np.array(rewards),
            'don': np.array(terminals),
            'n': n_steps,
        }
        if frames is not None:
            batch['ims'] = np.stack(frames)
        return batch
    
    def sample_trajectory(self, policy, replay_buffer, get_ims=True):
        """Collect a full trajectory with policy and push every step to a replay buffer.

        Parameters
        ----------
        policy : Object exposing ``sample_actions(ob)``. It is only queried once the
            total step counter has reached ``init_random_samples``; before that,
            actions come from ``env.action_space.sample()``.
        replay_buffer : Object exposing ``insert(transition_dict)``; called once per step.
        get_ims : For visual environments, also collect ``env.get_ims()`` per step,
            default is True.

        Returns
        -------
        dict with stacked 'observations', 'next_observations', 'actions', arrays
        'rewards', 'dones', 'masks' (1 - dones) and 'n' (number of steps taken).
        'ims' is added when images were collected. For environments whose spec id
        starts with 'SweepToTop', 'rewards' is replaced by the value of
        ``score_on_end_of_traj()``.
        """
        collect_ims = self._visual_env and get_ims
        obs, nobs, acts, rews, dones = [], [], [], [], []
        visual_obs = []
        ct = 0
        done = False
        ob = self._env.reset()
        if collect_ims:
            # The image produced right after reset is fetched and discarded.
            _ = self._env.get_ims()
        while not done and ct < self._el:
            # first nr with random policy, then learner
            if self._tc < self._nr:
                act = self._env.action_space.sample()
            else:
                act = policy.sample_actions(ob)
            obs.append(ob)
            acts.append(act)
            # The environment is expected to return a 4-tuple from step().
            nob, rew, done, _ = self._env.step(act)
            if collect_ims:
                visual_obs.append(self._env.get_ims())
            nobs.append(nob)
            rews.append(rew)
            dones.append(done)
            ct += 1
            self._tc += 1
            replay_buffer.insert({'observations': ob,
                                  'actions': act,
                                  'rewards': rew,
                                  'masks': 1.0 - float(done),
                                  'dones': float(done),
                                  'next_observations': nob})
            # Advance to the next observation. The original assigned the other
            # way round, so every step reused the reset observation.
            ob = nob

        # Same spec guard as sample_test_trajectories: spec may be None.
        spec = self._eval_env.spec
        if spec is not None and spec.id.split("-")[0] == 'SweepToTop':
            # NOTE(review): this replaces the per-step reward list with the
            # environment's end-of-trajectory score; confirm that is intended.
            rews = self._eval_env.score_on_end_of_traj()
        out = {'observations': np.stack(obs), 'next_observations': np.stack(nobs), 'actions': np.stack(acts),
               'rewards': np.array(rews), 'dones': np.array(dones), 'n': ct, 'masks': 1 - np.array(dones)}
        if collect_ims:
            out['ims'] = np.stack(visual_obs)
        return out

    def sample_test_trajectories(self, policy, n=5, visualize=True, get_ims=False):
        """Collect multiple trajectories with policy keeping track of trajectory-specific statistics.

        Parameters
        ----------
        policy : Object exposing ``sample_actions(ob, temperature=0.0)``, or None to
            act with ``action_space.sample()``.
        n : Number of trajectories to roll out, default is 5.
        visualize : Render every step with ``render(mode='rgb_array')`` and write the
            frames of the last trajectory to the current directory with
            ``save_video``, default is True.
        get_ims : When rendering, also return the rendered frames of all
            trajectories under 'ims', default is False.

        Returns
        -------
        dict with stacked 'observations', 'next_observations', 'actions', arrays
        'rewards' and 'dones', 'n' (length of the last trajectory) and 'ret'
        (list with one return per trajectory). For environments whose spec id
        starts with 'SweepToTop' the return is taken from ``score_on_end_of_traj()``.
        """
        obs, nobs, acts, rews, dones, rets = [], [], [], [], [], []
        images = []
        # Defined up front so that n == 0 cannot hit an unbound name below.
        frames = []
        ct = 0
        if policy is None:
            print('WARNING: running random policy')
        for _ in range(n):
            ret = 0
            ct = 0
            done = False
            # Only the most recent trajectory's frames are kept for the video.
            frames = []
            ob = self._eval_env.reset()
            while not done and ct < self._el:
                if policy is not None:
                    act = policy.sample_actions(ob, temperature=0.0)
                else:
                    act = self._eval_env.action_space.sample()
                obs.append(ob)
                acts.append(act)
                # The environment is expected to return a 4-tuple from step().
                ob, rew, done, _info = self._eval_env.step(act)
                if visualize:
                    img = self._eval_env.render(mode='rgb_array')
                    frames.append(img)
                    if get_ims:
                        # Previously never filled, so stacking 'ims' always failed.
                        images.append(img)
                nobs.append(ob)
                rews.append(rew)
                dones.append(done)
                ret += rew
                ct += 1
            if self._eval_env.spec is not None and self._eval_env.spec.id.split("-")[0] == 'SweepToTop':
                ret = self._eval_env.score_on_end_of_traj()
            rets.append(ret)
        # Only write a video when frames were actually rendered.
        if visualize and frames:
            save_video.save_video(frames, video_folder=".", fps=10)
        out = {'observations': np.stack(obs), 'next_observations': np.stack(nobs), 'actions': np.stack(acts),
               'rewards': np.array(rews), 'dones': np.array(dones), 'n': ct, 'ret': rets}
        if visualize and get_ims:
            out['ims'] = np.stack(images)
        return out

    def evaluate(self, policy, n, viz=False, log=True, get_ims=False):
        """Roll out ``n`` test trajectories and summarise their returns.

        Returns a ``(statistics, trajectories)`` pair: the dict produced by
        ``log_trajectory_statistics`` for the per-trajectory returns, and the raw
        dict produced by ``sample_test_trajectories``.
        """
        trajectories = self.sample_test_trajectories(
            policy, n, visualize=viz, get_ims=get_ims)
        statistics = self.log_trajectory_statistics(trajectories['ret'], log)
        return statistics, trajectories

    def log_trajectory_statistics(self, trajectory_rewards, log=True):
        """Summarise per-trajectory rewards and optionally print the summary.

        Returns a dict with 'n' (count), 'single_rew' (the last entry), and the
        'mean', 'max', 'min' and 'std' of ``trajectory_rewards``.
        """
        stats = {
            'n': len(trajectory_rewards),
            'single_rew': trajectory_rewards[-1],
            'mean': np.mean(trajectory_rewards),
            'max': np.max(trajectory_rewards),
            'min': np.min(trajectory_rewards),
            'std': np.std(trajectory_rewards),
        }
        if not log:
            return stats

        report = (
            ('Number of completed trajectories - {}', 'n'),
            ('Latest trajectories mean reward - {}', 'mean'),
            ('Latest trajectories max reward - {}', 'max'),
            ('Latest trajectories min reward - {}', 'min'),
            ('Latest trajectories std reward - {}', 'std'),
        )
        for template, key in report:
            print(template.format(stats[key]))
        print(f"Latest trajectories reward - {stats['single_rew']}")
        return stats
    
    def sample_test_steps(self, policy, noise_stddev, n_steps=5, visualize=False, only_visual_data=False, get_ims=False):
        """Collect a fixed number of evaluation steps, restarting episodes as they end.

        Parameters
        ----------
        policy : Object exposing ``get_action(batched_obs, noise_stddev)``, or None to
            act with ``action_space.sample()``.
        noise_stddev : Forwarded unchanged to ``policy.get_action``.
        n_steps : Total number of environment steps to take, default is 5.
        visualize : Call ``render()`` after every step, default is False.
        only_visual_data : Unused; kept for interface compatibility.
        get_ims : For visual environments, also collect ``get_ims()`` per step,
            default is False.

        Returns
        -------
        dict with stacked 'obs', 'nobs', 'act', arrays 'rew', 'don', 'ids' (episode
        index of each step), 'n' (= n_steps) and 'ret' (returns of the episodes that
        finished or hit the step cap). 'ims' is added when images were collected.
        """
        obs, nobs, acts, rews, dones, rets, ids = [], [], [], [], [], [], []
        if policy is None:
            print('WARNING: running random policy')
        collect_ims = self._visual_env and get_ims
        visual_obs = []
        ret = 0
        ct = 0
        # Start as "done" so the first iteration resets the environment.
        done = True
        nep = 0
        for i in range(n_steps):
            if done or ct >= self._el:
                ret = 0
                ct = 0
                done = False
                # ob = self._handle_ob(self._eval_env.reset())
                ob = self._eval_env.reset()
                if collect_ims:
                    # NOTE(review): the device id '4' is machine-specific.
                    # The caller's CUDA_VISIBLE_DEVICES is restored afterwards,
                    # also on error, instead of being set to a hard-coded list.
                    previous_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
                    os.environ['CUDA_VISIBLE_DEVICES'] = '4'
                    try:
                        # The image produced right after reset is discarded.
                        _ = self._eval_env.get_ims()
                    finally:
                        if previous_devices is None:
                            os.environ.pop('CUDA_VISIBLE_DEVICES', None)
                        else:
                            os.environ['CUDA_VISIBLE_DEVICES'] = previous_devices
            if policy is not None:
                act = np.array(policy.get_action(np.expand_dims(ob.astype('float32'), axis=0), noise_stddev))[0]
            else:
                act = self._eval_env.action_space.sample()
            obs.append(ob)
            acts.append(act)
            # The environment is expected to return a 4-tuple from step().
            ob, rew, done, _info = self._eval_env.step(act)
            if visualize:
                self._eval_env.render()
            if collect_ims:
                # visual_obs.append(ob['im'])
                visual_obs.append(self._eval_env.get_ims())
            # ob = self._handle_ob(ob)
            nobs.append(ob)
            rews.append(rew)
            dones.append(done)
            ids.append(nep)
            ret += rew
            ct += 1
            if done or ct >= self._el:
                # Episode finished or hit the step cap: record its return.
                print(i+1, ct, ret, done)
                rets.append(ret)
                nep += 1
        out = {'obs': np.stack(obs), 'nobs': np.stack(nobs), 'act': np.stack(acts),
               'rew': np.array(rews), 'don': np.array(dones), 'n': n_steps, 'ret': rets,
               'ids': np.array(ids)}
        if collect_ims:
            out['ims'] = np.stack(visual_obs)
        return out


class NoisySampler(Sampler):
    """Sampler to collect and evaluate perturbed agent behavior."""

    def __init__(self, env, episode_limit=1000, init_random_samples=1000,
                 visual_env=False):
        """Forward all constructor arguments unchanged to the parent class."""
        super().__init__(
            env,
            episode_limit=episode_limit,
            init_random_samples=init_random_samples,
            visual_env=visual_env,
        )

    def sample_test_trajectories(self, policy, noise_stddev, n=5, visualize=False, post_noise=0.0, get_ims=False):
        """Collect multiple trajectories with perturbed policy keeping track of trajectory-specific statistics.

        Parameters
        ----------
        policy : Object exposing ``get_action(batched_obs, noise_stddev)``.
        noise_stddev : Forwarded unchanged to ``policy.get_action``.
        n : Number of trajectories to roll out, default is 5.
        visualize : Call ``render()`` after every step, default is False.
        post_noise : Scale of a single standard-normal draw per step that is added
            to the policy action before clipping to [-1, 1], default is 0.0.
        get_ims : For visual environments, also collect ``get_ims()`` per step,
            default is False.

        Returns
        -------
        dict with stacked 'obs', 'nobs', 'act', arrays 'rew', 'don', 'ids'
        (trajectory index of each step), 'n' (length of the last trajectory) and
        'ret' (one return per trajectory). 'ims' is added when images were collected.
        """
        collect_ims = self._visual_env and get_ims
        cur_obs, next_obs, actions, rewards, terminals = [], [], [], [], []
        returns, traj_ids = [], []
        if collect_ims:
            frames = []

        for traj_idx in range(n):
            episode_return, ct, finished = 0, 0, False
            # ob = self._handle_ob(self._eval_env.reset())
            ob = self._eval_env.reset()
            while not (finished or ct >= self._el):
                # Draw the perturbation before querying the policy.
                perturbation = np.random.randn() * post_noise
                batched = np.expand_dims(ob.astype('float32'), axis=0)
                raw_action = np.array(policy.get_action(batched, noise_stddev))[0]
                action = np.clip(raw_action + perturbation, -1., 1.)

                cur_obs.append(ob)
                actions.append(action)
                # The environment is expected to return a 4-tuple from step().
                ob, reward, finished, _info = self._eval_env.step(action)
                if visualize:
                    self._eval_env.render()
                if collect_ims:
                    # frames.append(ob['im'])
                    frames.append(self._eval_env.get_ims())
                # ob = self._handle_ob(ob)

                next_obs.append(ob)
                rewards.append(reward)
                terminals.append(finished)
                traj_ids.append(traj_idx)
                episode_return += reward
                ct += 1
            returns.append(episode_return)

        batch = {
            'obs': np.stack(cur_obs),
            'nobs': np.stack(next_obs),
            'act': np.stack(actions),
            'rew': np.array(rewards),
            'don': np.array(terminals),
            'n': ct,
            'ret': returns,
            'ids': np.array(traj_ids),
        }
        if collect_ims:
            batch['ims'] = np.stack(frames)
        return batch