#!/usr/bin/env python3
import numpy as np
import torch
import os
import re
import time

from envs.obs_transforms import (OBS_TRANSFORMS, RECORD_TRANSFORMS,
                                 obs_transform_default,
                                 record_transform_default)
from video import VideoRecorderWithStates
from logger import Logger
from replay_buffer import ReplayBuffer
from reward_module import UpdatingRewardModule
import utils

import hydra
import wandb

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


class Workspace(object):
    """Owns the env, agent, buffers and reward module for one incremental
    skill-learning run, and drives training, evaluation and checkpointing."""

    def __init__(self, cfg):
        """Build every training component from the Hydra config ``cfg``.

        Initialises wandb and the logger, seeds RNGs, creates the env and the
        observation/record transforms, normalises ``cfg.num_steps_per_skill``
        into a per-skill list (also setting ``cfg.total_skills``), then
        instantiates the agent, replay buffer, reward module and video
        recorder. When ``cfg.load_saved`` is set, previously saved buffers and
        agent state are restored from ``cfg.load_directory``.
        """
        self.work_dir = os.getcwd()
        print(f'workspace: {self.work_dir}')

        self.cfg = cfg
        wandb.init(
            project='incremental_primitives',
            group=cfg.group,
            tags=cfg.tags,
            monitor_gym=True,
        )
        # Add our experiment to our run name
        wandb.run.name = cfg.experiment + '-' + wandb.run.name
        wandb.run.save()

        wandb.config.update({
            'env': cfg.env,
            'num_train_steps': cfg.num_train_steps,
            'saved_latent_per_skill': cfg.saved_latent_per_skill,
            'sparse': cfg.sparse_reward,
            'path': self.work_dir
        })

        self.logger = Logger(self.work_dir,
                             save_tb=cfg.log_save_tb,
                             log_frequency=cfg.log_frequency)

        utils.set_seed_everywhere(cfg.seed)

        self.device = torch.device(cfg.device)
        self.env = utils.make_env(cfg)

        # cfg.env must contain exactly one '.'; the part after it is the name.
        _, env_name = cfg.env.split('.')
        # Only 'Ant-block' buries blocks during training. A negative index
        # enables removal in run(), starting from the last body in the model.
        if env_name == 'Ant-block':
            self._next_block_to_remove = -1
        else:
            self._next_block_to_remove = 0
        # Spread the removal of cfg.num_blocks blocks (in groups of
        # cfg.blocks_to_remove_at_once) evenly over training. Both quotients
        # are clamped to >= 1 so small configs cannot cause a division by
        # zero here or in the modulo check in run().
        removal_rounds = max(1, int(cfg.num_blocks) //
                             int(cfg.blocks_to_remove_at_once))
        self._block_removal_steps = max(
            1, int(cfg.num_train_steps) // removal_rounds)

        env_prefix = re.split('-|_', env_name)[0]  # match either
        self.obs_transform = OBS_TRANSFORMS.get(
            env_prefix,
            obs_transform_default
        )
        sample_transformed_obs = self.obs_transform(self.env, self.env.reset())

        # A scalar num_steps_per_skill is expanded into one budget per skill.
        if isinstance(self.cfg.num_steps_per_skill, (int, float)):
            steps_per_skill = int(self.cfg.num_steps_per_skill)
            num_total_skills = int(self.cfg.num_train_steps //
                                   steps_per_skill)
            self.cfg.num_steps_per_skill = [steps_per_skill for _ in
                                            range(num_total_skills)]

            self.cfg.total_skills = num_total_skills
        else:
            self.cfg.total_skills = len(self.cfg.num_steps_per_skill)
        self.cfg.transformed_obs_shape = sample_transformed_obs.shape

        self.env._max_episode_steps = cfg.max_episode_steps
        self.max_episode_steps = cfg.max_episode_steps
        self.collected_trajectories = int(self.cfg.collected_trajectories)

        cfg.agent.obs_dim = self.env.observation_space.shape[0]
        cfg.agent.t_obs_dim = sample_transformed_obs.shape[0]
        cfg.agent.action_dim = self.env.action_space.shape[0]
        cfg.agent.action_range = [
            float(self.env.action_space.low.min()),
            float(self.env.action_space.high.max())
        ]
        self.agent = hydra.utils.instantiate(cfg.agent)

        self.replay_buffer = ReplayBuffer(self.env.observation_space.shape,
                                          self.cfg.transformed_obs_shape,
                                          self.env.action_space.shape,
                                          int(cfg.replay_buffer_capacity),
                                          self.device)

        self.reward_module = UpdatingRewardModule(
            self.cfg.transformed_obs_shape,
            self.max_episode_steps,
            int(self.cfg.saved_latent_per_skill),
            int(self.cfg.total_skills),
            max_running_obses=int(self.cfg.max_running_obses),
            slow_update_coeff=int(self.cfg.slow_update_coeff),
            device=self.device,
            alpha=self.cfg.alpha,
            beta=self.cfg.beta,
            use_t_obs=self.cfg.use_t_obs,
            use_t_vel=self.cfg.use_t_vel,
            use_timesteps=self.cfg.use_timesteps,
            use_entropy=self.cfg.use_entropy
        )
        self.agent.register_reward_module(self.reward_module)
        self.reward_module.register_logger(self.logger)

        self.rec_transform = RECORD_TRANSFORMS.get(
            env_prefix,
            record_transform_default
        )
        sample_record_obs = self.rec_transform(self.env, self.env.reset())
        self.video_recorder = VideoRecorderWithStates(
            self.work_dir if cfg.save_video else None,
            fps=90,
            obs_transforms=self.rec_transform,
            num_transforms=sample_record_obs.shape[0])
        self.step = 0
        self.episode = 0

        if cfg.load_saved:
            load_path = cfg.load_directory
            self.replay_buffer.load(path=load_path)
            self.reward_module.load_buffers(load_path)
            self.agent._load_agent(load_path)
            self.replay_buffer.purge_frac(frac=self.cfg.load_purge_frac)
            self.reward_module._saved_reward_call *= cfg.updates_per_step
            print(self.agent.current_skill_num)
            print('Successfully loaded previous model')

    def evaluate(self, record=True):
        """Roll out every learned skill deterministically and log metrics.

        For each skill index up to ``agent.current_skill_num`` this runs
        ``cfg.num_eval_episodes`` episodes, optionally recording video. It
        logs the mean intrinsic episode reward and variance, entropy and
        covariance statistics of the episodes' final transformed
        observations, and saves those final observations to
        ``eval_trajectories/<step>.npy``.

        Args:
            record: whether the video recorder is enabled for this run.
        """
        average_episode_reward = 0
        self.env._max_episode_steps = self.cfg.max_test_episode_steps
        t_obs_stack = []      # final transformed obs of every eval episode
        t_obs_var_stack = []  # per-skill variance of those final obs

        self.video_recorder.init(enabled=record)
        for skill_idx in range(self.agent.current_skill_num + 1):
            self.video_recorder.init_new_skill()
            for episode in range(self.cfg.num_eval_episodes):
                self.video_recorder.record_blank(
                    f'Skill {skill_idx} Episode {episode}'
                )
                obs = self.env.reset()
                t_obs = self.transform_obs(obs)
                t_obses = [t_obs]
                prev_obs = obs.copy()
                prev_transformed_obs = t_obs.copy()
                self.agent.reset()
                done = False
                while not done:
                    with utils.eval_mode(self.agent):
                        action = self.agent.act(obs, t_obs,
                                                prev_obs, prev_transformed_obs,
                                                sample=False,
                                                skill_index=skill_idx)
                    next_obs, _, done, _ = self.env.step(action)
                    # Transform the observation *after* the step.
                    next_t_obs = self.transform_obs(next_obs)
                    t_obses.append(next_t_obs)
                    self.video_recorder.record(self.env, no_mujoco=True)
                    prev_obs = obs.copy()
                    prev_transformed_obs = t_obs.copy()
                    obs = next_obs.copy()
                    t_obs = next_t_obs.copy()

                # The endpoint is the state reached after the final step.
                t_obs_stack.append(t_obs)

                # Score with the intrinsic reward. Timesteps are 1-based to
                # match the episode_step values used in run() and
                # collect_trajectories().
                episode_reward = self.compute_episode_reward(
                    t_obses[1:], t_obses[:-1], range(1, len(t_obses)))
                average_episode_reward += episode_reward

            # Compute the variance between skill endpoints
            skill_ends = np.stack(t_obs_stack[-self.cfg.num_eval_episodes:])
            t_obs_var_stack.append(skill_ends.var(axis=0))

        self.video_recorder.save(f'step_{self.step}.mp4')
        average_episode_reward /= (self.cfg.num_eval_episodes *
                                   (self.agent.current_skill_num + 1))
        average_episode_reward = float(average_episode_reward)
        self.logger.log('eval/episode_reward', average_episode_reward,
                        self.step)
        wandb.log({
            'rewards/eval_reward': average_episode_reward,
        })
        self.logger.dump(self.step)
        self.env._max_episode_steps = self.cfg.max_episode_steps

        all_t_obs = np.stack(t_obs_stack)
        t_obs_var = all_t_obs.var(axis=0)
        t_obs_var_stack = np.stack(t_obs_var_stack)
        t_obs_mean_var = t_obs_var_stack.mean(axis=0)

        entropy = utils.pointwise_entropy(all_t_obs)
        cov_det, mean_cov_det = utils.compute_cov(all_t_obs,
                                                  len(t_obs_var_stack))
        logging_dict = {
            'epoch': self.episode,
            'environment_step': self.step,
            'rewards/final_state_entropy': entropy,
            'variance/covariance_det': cov_det,
            'inter_skill_variance/mean_covariance_det': mean_cov_det,
        }
        for i, (var, mean_var) in enumerate(zip(t_obs_var, t_obs_mean_var)):
            logging_dict[f'variance/var_{i}'] = var
            logging_dict[f'inter_skill_variance/mean_var_{i}'] = mean_var
        path = os.path.join(utils.make_dir(self.work_dir, 'eval_trajectories'),
                            f'{self.step}.npy')
        np.save(path, all_t_obs)
        wandb.log(logging_dict)

    def collect_trajectories(self):
        """Roll out the current skill and hand its transitions to the reward
        module.

        Runs ``self.collected_trajectories`` deterministic episodes and feeds
        every transition to ``reward_module.add_collected_trajectory``.
        """
        self.env._max_episode_steps = self.cfg.max_episode_steps
        for _ in range(self.collected_trajectories):
            obs = self.env.reset()
            transformed_obs = self.transform_obs(obs)
            prev_obs = obs.copy()
            prev_transformed_obs = transformed_obs.copy()
            self.agent.reset()
            done = False
            episode_step = 0
            while not done:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, transformed_obs,
                                            prev_obs, prev_transformed_obs,
                                            sample=False)
                next_obs, _, done, _ = self.env.step(action)
                next_transformed_obs = self.transform_obs(next_obs)
                episode_step += 1  # 1-based, like run()
                self.reward_module.add_collected_trajectory(
                    transformed_obs, obs, next_transformed_obs, next_obs,
                    episode_step, done)
                prev_obs = obs.copy()
                prev_transformed_obs = transformed_obs.copy()
                obs = next_obs.copy()
                transformed_obs = next_transformed_obs.copy()

    def compute_episode_reward(self, next_t_obs_stack,
                               t_obs_stack, timestep_stack, eval=True):
        """Return the summed intrinsic reward of a sequence of transitions.

        Args:
            next_t_obs_stack: sequence of transformed next observations.
            t_obs_stack: sequence of transformed observations.
            timestep_stack: sequence of per-transition timesteps.
            eval: forwarded to ``reward_module.get_rewards``.

        Returns:
            ``reward_module.get_rewards(...).sum()``.
        """
        device = self.device
        # Stack with numpy first: torch.tensor() on a list of ndarrays is slow.
        next_t_obs = torch.as_tensor(np.asarray(next_t_obs_stack),
                                     device=device)
        t_obs = torch.as_tensor(np.asarray(t_obs_stack), device=device)
        timesteps = torch.tensor(list(timestep_stack), device=device)
        episode_reward = self.reward_module.get_rewards(
            next_t_obs, t_obs, timesteps, eval=eval)
        return episode_reward.sum()

    @staticmethod
    def _bootstrap_done(done, episode_step, max_episode_steps):
        """Return the ``done`` flag used for bootstrapping.

        ``episode_step`` is the number of steps taken so far, already
        incremented for the current transition. A transition that ends
        because the time limit was reached is not terminal, so it yields 0.
        """
        return 0 if episode_step == max_episode_steps else done

    def run(self):
        """Main training loop: learn skills one after another.

        Each skill trains for ``cfg.num_steps_per_skill[i]`` env steps. At
        episode boundaries the loop logs, periodically evaluates, and
        computes the intrinsic episode reward. When a skill's budget is spent,
        it collects that skill's trajectories and advances to the next skill.
        After the last skill, it runs a final evaluation, collection and
        checkpoint.
        """
        episode, episode_reward, done = 0, 0, True
        skill_steps = 0
        skill_now = 0
        num_skills = len(self.cfg.num_steps_per_skill)
        start_time = time.time()
        saved_actor = False
        self.all_skill_rewards = []

        # Initialize the reward computation stacks
        t_obs_stack, next_t_obs_stack, timestep_stack = [], [], []
        while self.step < self.cfg.num_train_steps:
            if done:
                if self.step > 0:
                    self.logger.log('train/duration',
                                    time.time() - start_time, self.step)
                    start_time = time.time()
                    self.logger.dump(
                        self.step, save=(self.step > self.cfg.num_seed_steps))

                # evaluate agent periodically
                if self.step > 0 and episode % self.cfg.eval_frequency == 0:
                    self.logger.log('eval/episode', episode, self.step)
                    self.evaluate(record=(episode % self.cfg.record_freq == 0))

                # Compute the intrinsic reward of the episode that just ended.
                if t_obs_stack:
                    episode_reward = float(self.compute_episode_reward(
                        next_t_obs_stack, t_obs_stack, timestep_stack))
                    # Reinitialize the reward computation stacks
                    t_obs_stack, next_t_obs_stack, timestep_stack = [], [], []
                    wandb.log({
                        'rewards/intrinsic_training_reward': episode_reward
                    })
                self.logger.log('train/episode_reward', episode_reward,
                                self.step)

                skill_budget = self.cfg.num_steps_per_skill[skill_now]
                # Save the actor a third of the way through each skill.
                if skill_steps >= skill_budget // 3 and not saved_actor:
                    print('Saving actor')
                    self.agent.save_actor_model(filename='last_actor.pt')
                    saved_actor = True

                skill_finished = skill_steps >= skill_budget
                if skill_finished and skill_now + 1 >= num_skills:
                    # The last configured skill is done. The post-loop code
                    # collects its trajectories; advancing would index past
                    # the end of num_steps_per_skill.
                    break

                # Gotta collect trajectories before we do any reset.
                if skill_finished:
                    print("Collecting trajectories")
                    # First, collect completed trajectories from this skill.
                    self.collect_trajectories()
                    print("Done collecting trajectories")
                    saved_actor = False

                obs = self.env.reset()
                transformed_obs = self.transform_obs(obs)
                prev_obs = obs.copy()
                prev_transformed_obs = transformed_obs.copy()
                self.agent.reset()
                done = False
                episode_reward = 0
                episode_step = 0
                episode += 1

                self.episode = episode

                self.logger.log('train/episode', episode, self.step)

                if skill_finished:
                    skill_now += 1
                    skill_steps = 0
                    self.replay_buffer.purge_frac(self.cfg.train_purge_frac)
                    next_budget = self.cfg.num_steps_per_skill[skill_now]
                    self.reward_module.add_new_skill(next_budget)
                    self.agent.add_new_skill(next_budget)
                    self.save_everything()
                    if hasattr(self.env, 'add_new_skill'):
                        self.env.add_new_skill()

            # Block removal is only enabled (negative index) for Ant-block.
            # Check that first so the modulo is skipped otherwise.
            if (self._next_block_to_remove < 0 and
                    (self.step + 1) % self._block_removal_steps == 0):
                # Remove a block from env by burying it down.
                print("---REMOVING BLOCKS---")
                for _ in range(int(self.cfg.blocks_to_remove_at_once)):
                    # Stop once cfg.num_blocks bodies (counted from the end)
                    # have been buried. Assumes the blocks are the last
                    # num_blocks bodies of the model — TODO confirm.
                    if -self._next_block_to_remove > int(self.cfg.num_blocks):
                        break
                    self.env.model.body_pos[self._next_block_to_remove, -1] = -4.
                    self._next_block_to_remove -= 1
            # sample action for data collection
            if skill_steps < self.cfg.num_seed_steps:
                action = self.env.action_space.sample()
            else:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, transformed_obs,
                                            prev_obs, prev_transformed_obs,
                                            sample=True)
                # run training update
                for _ in range(self.cfg.updates_per_step):
                    self.agent.update(self.replay_buffer,
                                      self.logger,
                                      self.step)

            next_obs, reward, done, _ = self.env.step(action)
            transformed_next_obs = self.transform_obs(next_obs)

            episode_step += 1
            self.step += 1
            skill_steps += 1

            # Enforce the episode length even if the env does not.
            if episode_step >= self.max_episode_steps:
                done = True

            # allow infinite bootstrap: time-limit ends are not terminal.
            done = float(done)
            done_no_max = self._bootstrap_done(
                done, episode_step, self.env._max_episode_steps)
            episode_reward += reward

            latent_reward = 0
            self.reward_module.add_current(transformed_obs, obs,
                                           transformed_next_obs, next_obs,
                                           episode_step, done)

            self.replay_buffer.add(obs, transformed_obs, action, latent_reward,
                                   next_obs, transformed_next_obs,
                                   prev_obs, prev_transformed_obs,
                                   episode_step, done, done_no_max)

            t_obs_stack.append(transformed_obs.copy())
            next_t_obs_stack.append(transformed_next_obs.copy())
            timestep_stack.append(episode_step)
            prev_obs = obs.copy()
            prev_transformed_obs = transformed_obs.copy()
            obs = next_obs.copy()
            transformed_obs = transformed_next_obs.copy()

        print('Done training')
        self.evaluate(True)
        self.collect_trajectories()
        self.reward_module.add_new_skill()
        self.agent.add_new_skill()
        self.save_everything()

    def transform_obs(self, obs):
        """Apply the env-specific observation transform (identity if unset)."""
        if self.obs_transform:
            return self.obs_transform(self.env, obs)
        return obs

    def save_everything(self):
        """Persist the replay buffer, reward-module buffers and agent."""
        self.replay_buffer.save()
        self.reward_module.save_buffers()
        self.agent.save_agent()


@hydra.main(config_path='config', config_name='train')
def main(cfg):
    """Hydra entry point: build the training workspace and run it to the end.

    ``cfg`` is the composed 'train' config from the ``config`` directory.
    """
    trainer = Workspace(cfg)
    trainer.run()


if __name__ == '__main__':
    # Hydra's decorator supplies ``cfg``, so no arguments are passed here.
    main()
