#!/usr/bin/env python3
import numpy as np
import torch
import torch.nn.functional as F
import os
import time
from video import VideoRecorder
from logger import Logger
from DST.utils import ReplayBuffer, show_sparsity
import utils
import sys
import dmc2gym
import hydra
import gym
import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)




def make_env(cfg):
    """Create and seed the environment selected by ``cfg``.

    Args:
        cfg: Configuration object exposing ``env_type`` (one of
            ``'dm_control'``, ``'gym'`` or ``'deepmimic'``), ``env`` (the
            environment name) and ``seed``.

    Returns:
        The seeded environment, with ``action_scale_high`` and
        ``action_scale_low`` attached as extra attributes.

    Raises:
        ValueError: If ``cfg.env_type`` is not one of the supported values.
    """
    if cfg.env_type == 'dm_control':
        if cfg.env == 'ball_in_cup_catch':
            # The domain name itself contains underscores, so splitting on
            # the first '_' would yield the wrong domain/task pair.
            domain_name, task_name = 'ball_in_cup', 'catch'
        else:
            # '<domain>_<task>' where only the task may contain underscores.
            domain_name, _, task_name = cfg.env.partition('_')

        env = dmc2gym.make(domain_name=domain_name,
                           task_name=task_name,
                           seed=cfg.seed,
                           visualize_reward=True)
    elif cfg.env_type in ('gym', 'deepmimic'):
        # Both kinds are resolved by name through gym.make.
        env = gym.make(cfg.env)
    else:
        # Fail loudly here instead of printing and then crashing on the
        # unbound ``env`` a few lines further down.
        raise ValueError(
            f"choose correct env: unsupported env_type {cfg.env_type!r} "
            "(expected 'dm_control', 'gym' or 'deepmimic')"
        )

    env.seed(cfg.seed)
    env.action_scale_high = env.action_space.high.max()
    # NOTE(review): this reads ``high.min()``, not ``low.min()``; kept as-is
    # because code outside this file may depend on it -- confirm the intent.
    env.action_scale_low = env.action_space.high.min()
    return env

def make_agent(obs_dim, action_dim, action_range, cfg):
    """Fill in the environment-dependent fields of ``cfg`` and instantiate it.

    ``cfg`` is modified in place: ``obs_dim``, ``action_dim`` and
    ``action_range`` are written onto it before it is handed to
    ``hydra.utils.instantiate``, whose result is returned.
    """
    env_dependent_fields = (
        ('obs_dim', obs_dim),
        ('action_dim', action_dim),
        ('action_range', action_range),
    )
    for field_name, field_value in env_dependent_fields:
        setattr(cfg, field_name, field_value)
    return hydra.utils.instantiate(cfg)

class Workspace(object):
    """Training workspace tying together environment, agent, buffer and logging.

    The workspace builds its output directory name from the configuration,
    creates the logger, seeds everything, then creates the environment, the
    agent, the replay buffer and the video recorder. ``run`` performs the
    training loop and ``evaluate`` runs deterministic evaluation episodes.
    """

    def __init__(self, cfg):
        """Build every component needed for training from ``cfg``."""
        self.cfg = cfg
        # The logger is rooted at work_dir, so the path is computed first.
        self.set_work_dir()
        self.set_logger()
        self.set_seed()

        self.device = torch.device(cfg.device)
        self.env = make_env(cfg)

        # Single source of truth for agent construction (see set_agent).
        self.set_agent()

        self.replay_buffer = ReplayBuffer(self.env.observation_space.shape[0],
                                          self.env.action_space.shape[0],
                                          int(cfg.replay_buffer_capacity),)

        # Passing None as directory when cfg.save_video is false.
        self.video_recorder = VideoRecorder(self.work_dir if cfg.save_video else None)
        self.step = 0

    def set_agent(self):
        """(Re)create ``self.agent`` from ``self.cfg.agent`` and the env spaces.

        The action range passed to the agent is ``[low.min(), high.max()]``
        of the environment's action space.
        """
        self.agent = make_agent(self.env.observation_space.shape[0],
                                self.env.action_space.shape[0],
                                [float(self.env.action_space.low.min()), float(self.env.action_space.high.max())],
                                self.cfg.agent)

    def set_logger(self):
        """Create ``self.logger`` rooted at ``self.work_dir``."""
        self.logger = Logger(self.work_dir,
                             save_tb=self.cfg.log_save_tb,
                             log_frequency=self.cfg.log_frequency,
                             agent=self.cfg.agent_name)

    def set_work_dir(self):
        """Compute ``self.work_dir`` as a run-specific path under the cwd.

        The last path component encodes the main hyper-parameters of the run.
        Only the string is built here; this method does not create the
        directory itself.
        """
        self.work_dir = os.getcwd()
        self.work_dir = self.work_dir + f'/algo={self.cfg["agent_name"]},pruning_algo={self.cfg.agent.pruning_algo},env={self.cfg["env"]},' \
                                        f'env_type={self.cfg["env_type"]},seed={self.cfg["seed"]},' \
                                        f'bs={self.cfg.agent.batch_size},h_dim={self.cfg.diag_gaussian_actor.hidden_dim},' \
                                        f'h_depth={self.cfg.hidden_depth},kr={self.cfg.keep_ratio},lr={self.cfg.lr}'
        print(f'workspace: {self.work_dir}')

    def set_seed(self):
        """Seed all random number generators through ``utils``."""
        utils.set_seed_everywhere(self.cfg.seed)

    def reset_episodic_storage(self):
        """Reset ``self.storage`` to a dict of empty per-key lists."""
        self.storage = {'observations': [], 'actions': [], 'rewards': [], 'terminals': [], 'next_observations': [], 'episodic_returns': [], 'success': []}

    def evaluate(self):
        """Run ``cfg.num_eval_episodes`` evaluation episodes and log averages.

        Actions are taken with ``sample=False``. Transitions of all episodes
        are collected in ``self.storage``; only the first episode is recorded
        to video.

        Returns:
            The average episode return over the evaluation episodes.
        """
        self.reset_episodic_storage()
        average_episode_reward = 0
        average_episode_len = 0
        # NOTE(review): this is never accumulated in the loop below, so the
        # 'metaworld' branch always logs 0 -- confirm where success should
        # come from before relying on that metric.
        average_success_rate = 0

        for episode in range(self.cfg.num_eval_episodes):
            obs = self.env.reset()
            episode_len = 0
            self.agent.reset()
            self.video_recorder.init(enabled=(episode == 0))
            done = False
            episode_reward = 0

            while not (done or episode_len >= self.env._max_episode_steps):
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, sample=False)
                self.storage['observations'].append(obs.astype(np.float32))
                self.storage['actions'].append(action.astype(np.float32))

                obs, reward, done, info = self.env.step(action*self.env.action_scale_high)
                # Force termination on the last allowed step of the episode.
                if (episode_len + 1) == self.env._max_episode_steps:
                    done = True
                    print(f'reset episode {episode}')

                self.storage['next_observations'].append(obs.astype(np.float32))
                self.storage['rewards'].append(reward)
                self.storage['terminals'].append(int(done))

                self.video_recorder.record(self.env)
                episode_reward += reward
                # Accumulates steps over all episodes; averaged after the loop.
                average_episode_len += 1
                episode_len += 1

            self.storage['episodic_returns'].append(episode_reward)
            average_episode_reward += episode_reward

            self.video_recorder.save(f'{self.step}.mp4')
        average_episode_reward /= self.cfg.num_eval_episodes
        average_episode_len /= self.cfg.num_eval_episodes
        if self.cfg.env_type == 'metaworld':
            average_success_rate /= self.cfg.num_eval_episodes
            print(f'Average success rate {average_success_rate}')
            self.logger.log('eval/success_rate', average_success_rate, self.step)
        self.logger.log('eval/episode_reward', average_episode_reward, self.step)
        self.logger.log('eval/episode_len', average_episode_len, self.step)

        self.logger.dump(self.step)
        return average_episode_reward

    def run(self):
        """Run the training loop for ``cfg.num_train_steps`` environment steps.

        During the first ``cfg.num_seed_steps`` steps actions are sampled
        uniformly from the action space; afterwards they come from the agent
        and ``agent.update_rigl`` is called every step. Evaluation is armed
        every ``cfg.eval_frequency`` steps and executed at the next episode
        boundary, saving the 'final' checkpoint each time and the 'best'
        checkpoint whenever the average evaluation return improves.
        """
        print("Training a sparse actor network:", show_sparsity(self.agent.actor.state_dict()))
        print("Training a sparse critic network:", show_sparsity(self.agent.critic.state_dict()))

        # Start from -inf so that the first evaluation always produces a
        # 'best' checkpoint, even on tasks whose returns are all negative.
        best_ep_ret = float('-inf')
        episode, episode_reward, done = 0, 0, True
        start_time = time.time()
        activate_eval = False
        while self.step < self.cfg.num_train_steps:
            if done:

                if self.step > 0:
                    self.logger.log('train/duration', time.time() - start_time, self.step)
                    start_time = time.time()
                    self.logger.dump(self.step, save=(self.step > self.cfg.num_seed_steps))

                # evaluate agent periodically
                if activate_eval:
                    self.logger.log('eval/episode', episode, self.step)
                    avg_eval_ret = self.evaluate()
                    self.agent.save(self.work_dir, step='final')  # saves the last evaluation
                    if avg_eval_ret > best_ep_ret:
                        self.agent.save(self.work_dir, step='best')
                        best_ep_ret = avg_eval_ret
                    # Disarm until the next multiple of eval_frequency;
                    # otherwise every following episode end would trigger a
                    # full evaluation.
                    activate_eval = False

                self.logger.log('train/episode_reward', episode_reward, self.step)

                obs = self.env.reset()
                self.agent.reset()
                done = False
                episode_reward = 0
                episode_step = 0
                episode += 1

                self.logger.log('train/episode', episode, self.step)

            # sample action for data collection
            if self.step < self.cfg.num_seed_steps:
                action = self.env.action_space.sample()
            else:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, sample=True)

            # run training update
            if self.step >= self.cfg.num_seed_steps:
                self.agent.update_rigl(self.replay_buffer, self.logger, self.step, self.cfg.batch_size)

            next_obs, reward, done, _ = self.env.step(action*self.env.action_scale_high)
            # Force termination on the last allowed step of the episode.
            if episode_step + 1 == self.env._max_episode_steps:
                done = True
            done = float(done)
            episode_reward += reward
            # NOTE(review): ``done`` stored here also covers the time-limit
            # cut-off above, the sixth argument re-uses ``action``, and the
            # last argument is always False for a positive step limit
            # (episode_step never reaches _max_episode_steps before the
            # reset). Confirm against ReplayBuffer.add before changing.
            self.replay_buffer.add(obs, action, next_obs, reward, done, action, episode_step >= self.env._max_episode_steps)
            obs = next_obs
            episode_step += 1
            self.step += 1

            # Applies for rlx2
            if self.cfg.use_dynamic_buffer and (self.step + 1) % self.cfg.buffer_adjustment_interval == 0:
                # Index window of 8 batches worth of consecutive slots,
                # wrapping around the end of the buffer.
                if self.replay_buffer.size == self.replay_buffer.max_size:
                    ind = (self.replay_buffer.ptr + np.arange(8 * self.cfg.agent.batch_size)) % self.replay_buffer.max_size
                else:
                    ind = (self.replay_buffer.left_ptr + np.arange(8 * self.cfg.agent.batch_size)) % self.replay_buffer.max_size
                batch_state = torch.FloatTensor(self.replay_buffer.state[ind]).to(self.cfg.device)
                batch_action_mean = torch.FloatTensor(self.replay_buffer.action_mean[ind]).to(self.cfg.device)
                with torch.no_grad():
                    # Half the MSE between the current policy mean and the
                    # stored action means on that window.
                    current_action = self.agent.actor(batch_state).mean
                    distance = F.mse_loss(current_action, batch_action_mean) / 2
                if distance > self.cfg.buffer_threshold and self.replay_buffer.size > self.cfg.buffer_min_size:
                    self.replay_buffer.shrink()

            # Arm evaluation; it is executed at the next episode boundary.
            if self.step % self.cfg.eval_frequency == 0:
                activate_eval = True

        # save the recorder only on the last eval
        self.video_recorder = VideoRecorder(self.work_dir)
        self.agent.save(self.work_dir, step='final')



# Ask Hydra for complete stack traces instead of its abbreviated errors.
os.environ["HYDRA_FULL_ERROR"] = "1"
@hydra.main(config_path='./config', config_name='train_pruned')
def main(cfg):
    """Hydra entry point: validate the pruning algorithm, then train.

    Exits through ``sys.exit`` with an explanatory message when
    ``cfg.agent.pruning_algo`` is neither ``'rlx2'`` nor ``'rigl'``.
    """
    # Workspace is imported through the module name rather than used
    # directly from this script's namespace.
    from train_pruned_rlx2 import Workspace as W
    supported_algos = ['rlx2', 'rigl']
    if cfg.agent.pruning_algo in supported_algos:
        W(cfg).run()
    else:
        sys.exit("select 'agent.pruning_algo'='rlx2' or 'agent.pruning_algo'='rigl' ")


# Only start training when this file is executed as a script.
if '__main__' == __name__:
    main()

