# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from re import I
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
import os
import sys
import time
import pickle as pkl
import itertools
import random
from collections import OrderedDict

from video import VideoRecorder
from logger import Logger
from replay_buffer import ReplayBuffer
import utils
from agent.encoder import Encoder_Decoder
from agent.encoder import VAE
from agent.actor import DiagGaussianW

from torch.utils.tensorboard import SummaryWriter
from torch import linalg as LA
from torch.autograd import Variable

import hydra
from omegaconf import OmegaConf
from omegaconf import DictConfig
from torchmetrics import MeanSquaredError
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score

import metaworld


class Workspace(object):
    def __init__(self, cfg):
        """Build the workspace used to fit the reward model ``r ~= phi . w``.

        In order, this sets up: the logger and TensorBoard writer, RNG seeding,
        the training (and "unseen" evaluation) environments, an expert agent
        (weights loaded from disk) and a second, untrained agent used as the
        "random" behaviour policy, the ``phi`` / ``W`` models, the replay
        buffer, the video recorder and the Adam optimizer.

        Args:
            cfg: Experiment configuration. Fields read here include
                ``log_save_tb``, ``log_frequency``, ``name``, ``seed``,
                ``device``, ``goal_mode``, ``env``, ``single_task``,
                ``expert_model_date``, ``replay_buffer_capacity``,
                ``save_video``, the ``agent`` sub-config (also written to)
                and the ``representation`` sub-config.

        Raises:
            ValueError: If ``cfg.env`` is not one of the handled names, if the
                dm_suite multi-goal MLP setup is asked for an unhandled task,
                if neither ``learn_w`` nor ``learn_phi`` is set, or if
                ``w_model`` is not ``'distr'``, ``'mlp'`` or ``'vector'``.
        """
        self.work_dir = os.getcwd()
        print(f'workspace: {self.work_dir}')

        self.cfg = cfg

        self.logger = Logger(self.work_dir,
                             save_tb=cfg.log_save_tb,
                             log_frequency=cfg.log_frequency,
                             agent=cfg.name)

        # TensorBoard events go to a relative 'tb' directory.
        self.writer = SummaryWriter(log_dir='tb')
        self.batch_size = cfg.agent.batch_size

        # Seed before any environment / network is created below; the
        # `random.choice` calls and module constructors that follow come after
        # this point, so their order presumably matters for reproducibility.
        utils.set_seed_everywhere(cfg.seed)
        self.device = torch.device(cfg.device)

        self.goal_mode = cfg.goal_mode
        self.eval_on_unseen = True
        self.task = cfg.env
        # NOTE(review): `single_task` is hard-coded to False, so the ML1
        # branch further down is never taken from here; `cfg.single_task`
        # (a task name, judging by its use with `metaworld.ML1`) is only kept
        # in `mt_task`.
        self.single_task = False
        self.mt_task = cfg.single_task

        # Pick the goal shape stored in the replay buffer and the benchmark
        # family for the requested environment.
        if cfg.env == 'ant_walk':
            env_sample_goal_shape = (3, )
            self.benchmark = 'dm_suite'
        elif cfg.env == 'reacher_easy':
            env_sample_goal_shape = (2, )
            self.benchmark = 'dm_suite'
        elif cfg.env == 'walker_stand' or cfg.env == 'walker_walk' \
            or cfg.env == 'walker_run':
            env_sample_goal_shape = (1, )
            self.benchmark = 'dm_suite'
        elif cfg.env == 'metaworld':
            # This is gonna be task id
            env_sample_goal_shape = (1, )
            self.benchmark = 'metaworld'
        else:
            raise ValueError('Invalid benchmark env.')

        # Training environments, keyed by an integer env/task id.
        self.all_envs = OrderedDict()
        if self.benchmark == 'dm_suite':
            if self.goal_mode == 'multi_goal':
                # Eight training tasks, ids 0..7.
                for i in range(8):
                    env = utils.make_env(cfg, task_id=i)
                    env_id = i
                    self.all_envs.update({env_id: env})
                env_sample_key = list(self.all_envs.keys())[0]
                env_sample = self.all_envs[env_sample_key]
            else:
                # Single-goal mode uses the fixed task id 6.
                env_sample = utils.make_env(cfg, task_id=6)
                self.env = env_sample

            if self.eval_on_unseen:
                # NOTE(review): ids 5..7 are also created as training envs in
                # the multi-goal loop above; whether `make_env` treats them
                # differently is not visible here.
                self.eval_envs_unseen = OrderedDict()
                for i in range(5, 8):
                    env = utils.make_env(cfg, task_id=i)
                    env_id = i
                    self.eval_envs_unseen.update({env_id: env})
        else:
            # Metaworld
            if not self.single_task:
                # One environment per ML2 training class, each bound to a
                # randomly chosen task of that class.
                ml2 = metaworld.ML2()
                for i, (name, env_cls) in enumerate(ml2.train_classes.items()):
                    env = env_cls()
                    task = random.choice([
                        task for task in ml2.train_tasks
                        if task.env_name == name
                    ])
                    env.set_task(task)
                    env_id = i
                    self.all_envs.update({env_id: env})
                if self.eval_on_unseen:
                    # The held-out environment is hard-coded to
                    # 'reach-wall-v2' and registered under env id 2.
                    self.eval_envs_unseen = OrderedDict()
                    env = ml2.test_classes['reach-wall-v2']()
                    task = random.choice(ml2.test_tasks)
                    env.set_task(task)
                    env_id = 2
                    self.eval_envs_unseen.update({env_id: env})
            else:
                # Single Metaworld task; no unseen-eval envs are created in
                # this branch.
                ml1 = metaworld.ML1(self.mt_task)
                env = ml1.train_classes[self.mt_task]()
                task = random.choice(ml1.train_tasks)
                env.set_task(task)
                self.all_envs.update({0: env})

        # Representative environment used only to read observation/action
        # spaces (for dm_suite multi-goal this repeats the lookup done above).
        if self.goal_mode == 'multi_goal' or self.benchmark == 'metaworld':
            env_sample_key = list(self.all_envs.keys())[0]
            env_sample = self.all_envs[env_sample_key]
        else:
            env_sample = self.env

        # Fill in the agent config from the sampled environment before the
        # agents are instantiated.
        obs_dim = env_sample.observation_space.shape[0]
        action_dim = env_sample.action_space.shape[0]
        cfg.agent.obs_dim = obs_dim
        cfg.agent.action_dim = action_dim
        cfg.agent.action_range = [
            float(env_sample.action_space.low.min()),
            float(env_sample.action_space.high.max())
        ]

        cfg.agent.env_id_dim = 0
        cfg.agent.goal_dim = env_sample_goal_shape[0]
        # Two agents built from the same config: the expert gets pretrained
        # weights loaded below, the "random" one keeps its initial weights.
        self.agent_expert = hydra.utils.instantiate(cfg.agent,
                                                    _recursive_=False)
        self.agent_random = hydra.utils.instantiate(cfg.agent,
                                                    _recursive_=False)

        self.model_dir = self.work_dir + '/agent_model'
        # Everything up to the first 'runs' path component of the cwd.
        base_dir = self.work_dir.split('runs')[0] + 'runs/'

        expert_model_date = cfg.expert_model_date
        self.expert_dir = base_dir + str(expert_model_date) + '/'
        # if self.single_task:
        #     if cfg.single_task == 'reach-v2':
        #         self.expert_dir += 'metaworld_SACReacherExpert_env=metaworld,experiment=SACReacherExpert,goal_mode=single_goal/seed=1/'
        #     elif cfg.single_task == 'door-v2':
        #         self.expert_dir += 'metaworld_SACDoorCloseExpert_env=metaworld,experiment=SACDoorCloseExpert,goal_mode=single_goal,single_task=door-close-v2/seed=1/'
        # else:
        # NOTE(review): the expert run directory is hard-coded to a Metaworld
        # experiment regardless of `cfg.env`.
        self.expert_dir += 'metaworld_SACReachDoorRWallV2_env=metaworld,experiment=SACReachDoorRWallV2/seed=2/'

        self.expert_dir = self.expert_dir + 'agent_model/'

        # Agent load in weights
        latest_step = utils.get_latest_file(self.expert_dir)
        self.agent_expert.load(self.expert_dir, latest_step)

        # Which of the two factors of the reward model are trainable.
        self.learn_w = cfg.representation.learn_w
        self.learn_phi = cfg.representation.learn_phi

        self.w_model = cfg.representation.w_model
        self.phi_model = cfg.representation.phi_model

        # Probability that an evaluation episode is driven by the expert
        # (used as the Bernoulli parameter in `evaluate`).
        self.p_expert = cfg.representation.p_expert
        self.reg_lambda = cfg.representation.reg_lambda
        self.w_norm = cfg.representation.w_norm

        # This should be 14 when learning both phi and w
        # otherwise the hard-coded version assumes states only
        # `state_features` is the input width of phi, `latent_state_features`
        # the shared output width of phi and W.
        if self.benchmark == 'dm_suite':
            if self.goal_mode == 'multi_goal':

                if self.phi_model == 'mlp':
                    if self.task == 'reacher_easy':
                        self.state_features = 14
                        self.goal_dim = 2
                        # self.state_features = 10
                        self.latent_state_features = 14
                    elif self.task == 'ant_walk':
                        # including state obs, actions, next_states
                        self.state_features = 326
                        self.latent_state_features = 28
                        self.goal_dim = 3
                    else:
                        raise ValueError('Invalid Env.')
                elif self.phi_model == 'vae':
                    # NOTE(review): only `state_features` is set here;
                    # `goal_dim` and `latent_state_features` stay unassigned
                    # in this branch although later code reads them.
                    self.state_features = 14
            else:
                if self.task == 'walker_stand' or self.task == 'walker_walk' \
                    or self.task == 'walker_run':
                    self.state_features = 54
                    self.goal_dim = 1
                    self.latent_state_features = cfg.representation.latent_size
                else:
                    # NOTE(review): `goal_dim` is not assigned in this branch.
                    self.state_features = 4
                    self.latent_state_features = 4
        else:
            # Metaworld input dim
            self.state_features = 82
            self.goal_dim = 1
            self.latent_state_features = cfg.representation.latent_size

        # Random policy action noise
        self.policy_noise = 0.5
        self.noise_clip = 0.5
        self.min_action = cfg.agent.action_range[0]
        self.max_action = cfg.agent.action_range[1]

        #######################
        # Learn both phi and w
        ######################
        if self.learn_w and self.learn_phi:

            self.phi = utils.mlp(input_dim=self.state_features,
                                 hidden_dim=cfg.representation.phi_hidden_dim,
                                 output_dim=self.latent_state_features,
                                 hidden_depth=3,
                                 custome=False).to(self.device)

            # hidden_depth=0: presumably a single linear layer from goal to
            # latent weights, as in the 'mlp' case of the next branch.
            self.W = utils.mlp(input_dim=self.goal_dim,
                               hidden_dim=cfg.representation.w_hidden_dim,
                               output_dim=self.latent_state_features,
                               hidden_depth=0,
                               weight_norm=self.w_norm).to(self.device)

            # self.W_gt = torch.ones(self.batch_size,
            #                        self.state_features,
            #                        requires_grad=False,
            #                        device=self.device)

        ###################
        # Fix phi, learn w
        ###################
        elif self.learn_w and not (self.learn_phi):
            # `Phi` is a fixed matrix
            # (a (batch_size, state_features) buffer overwritten in place by
            # `train_phi_and_w` on every step).
            self.phi = torch.ones(self.batch_size,
                                  self.state_features,
                                  requires_grad=False,
                                  device=self.device)

            if self.w_model == 'mlp':
                # `W` is a single layer linear MLP
                # NOTE(review): the goal dimension is hard-coded to 2 here
                # rather than taken from `self.goal_dim`.
                self.W = utils.mlp(input_dim=2,
                                   hidden_dim=32,
                                   output_dim=self.latent_state_features,
                                   hidden_depth=0).to(self.device)

            elif self.w_model == 'distr':
                ###################
                # Learn w, gaussian
                ###################
                hidden_depth = 2
                hidden_dim = 1024
                log_std_bounds = [-5, 2]
                self.W = DiagGaussianW(goal_dim=2,
                                       hidden_dim=hidden_dim,
                                       hidden_depth=hidden_depth,
                                       log_std_bounds=log_std_bounds).to(
                                           self.device)
            # NOTE(review): no `self.W` is created in this branch for
            # w_model == 'vector', yet the optimizer setup below reads it.

        ###################
        # Fix w, learn phi
        ###################
        elif not (self.learn_w) and self.learn_phi:
            # `Phi` is a MLP
            self.phi = utils.mlp(input_dim=self.state_features,
                                 hidden_dim=256,
                                 output_dim=self.latent_state_features,
                                 hidden_depth=2).to(self.device)
            # `W` is a fixed vector
            # (a (batch_size, latent_state_features) buffer overwritten in
            # place by `train_phi_and_w`).
            self.W = torch.ones(self.batch_size,
                                self.latent_state_features,
                                requires_grad=False,
                                device=self.device)

        else:
            raise ValueError('Need to learn either w or phi or both.')

        self.replay_buffer = ReplayBuffer(env_sample.observation_space.shape,
                                          env_sample.action_space.shape,
                                          env_sample_goal_shape,
                                          int(cfg.replay_buffer_capacity),
                                          self.device)

        # Passing None as the directory when `cfg.save_video` is false
        # presumably disables recording - confirm in VideoRecorder.
        self.video_recorder = VideoRecorder(
            self.work_dir if cfg.save_video else None,
            benchmark=self.benchmark)
        self.step = 0

        # Collect the trainable parameters of whichever factors are learned.
        if self.w_model == 'distr' or self.w_model == 'mlp':
            if self.learn_w and self.learn_phi:
                params = list(self.phi.parameters()) + list(
                    self.W.parameters())
            elif self.learn_w and not (self.learn_phi):
                params = list(self.W.parameters())
            elif not (self.learn_w) and self.learn_phi:
                params = list(self.phi.parameters())

        elif self.w_model == 'vector':
            # params = list(self.phi.parameters()) + [self.W]
            # NOTE(review): the only plain-tensor `self.W` created above has
            # requires_grad=False; confirm this mode is actually usable.
            params = [self.W]

        else:
            raise ValueError('Unsupported W model.')

        # One optimizer for phi and/or W, with a fixed learning rate.
        self.phi_w_optimizer = torch.optim.Adam(params, lr=3e-4)

        # Put the learnable modules into training mode.
        self.train()

    def train(self, training=True):
        """Toggle train/eval mode on the learnable parts of the reward model.

        Args:
            training: Mode flag forwarded to each module's ``train`` method
                and stored on ``self.training``.

        ``phi`` is only touched when it is being learned; ``W`` only when it
        is being learned and its model type is ``'distr'`` or ``'mlp'``.
        """
        self.training = training

        if self.learn_phi:
            self.phi.train(training)

        # Nothing more to do when W is held fixed.
        if not self.learn_w:
            return

        if self.w_model in ('distr', 'mlp'):
            self.W.train(training)

    def train_phi_and_w(self, obs, action, next_obs, reward, goal, writer,
                        step):

        if self.learn_w and self.learn_phi:
            # # Calculate phi GT
            # phi = torch.ones_like(self.W_gt)
            # phi[:, 1:3] = obs[:, 0:2]
            # phi[:, 3:] = ((LA.norm(obs[:, 0:2], dim=1, keepdim=True))**2)

            # # Calculate w GT
            # W = torch.ones_like(self.W_gt)
            # W[:, 0:1] = -((LA.norm(goal, dim=1, keepdim=True))**2)
            # W[:, 1:3] = 2 * goal
            # W[:, 3:] = -torch.ones_like(W[:, 3:])

            if self.w_model == 'mlp':
                # if self.goal_mode == 'multi_goal':

                if self.phi_model == 'mlp':
                    obs_action = torch.cat([obs, action, next_obs], dim=-1)
                    latent_phi = self.phi(obs_action)

                elif self.phi_model == 'vae':
                    obs_action = torch.cat([obs, action, next_obs], dim=-1)
                    latent_phi, mean, std = self.phi(obs_action)

                latent_phi_w = torch.bmm(
                    latent_phi.view(self.batch_size, 1,
                                    self.latent_state_features),
                    self.W(goal).view(self.batch_size,
                                      self.latent_state_features, 1))
                # else:
                #     obs_wt_goal = torch.ones([self.batch_size,
                #                               4]).to(self.device)
                #     obs_wt_goal[:, 0:2] = obs[:, 0:2]
                #     obs_wt_goal[:, 2:] = obs[:, 4:]

                #     latent_phi_w = torch.bmm(
                #         self.phi(obs_wt_goal).view(self.batch_size, 1,
                #                                    self.latent_state_features),
                #         self.W(goal).view(self.batch_size,
                #                           self.latent_state_features, 1))
                latent_phi_w = latent_phi_w.squeeze(-1)

        elif self.learn_w and not (self.learn_phi):
            # Calculate phi
            self.phi[:, 1:3] = obs[:, 0:2]
            self.phi[:, 3:] = ((LA.norm(obs[:, 0:2], dim=1, keepdim=True))**2)

            # Calculate w GT
            W = torch.ones_like(self.phi)
            W[:, 0:1] = -((LA.norm(goal, dim=1, keepdim=True))**2)
            W[:, 1:3] = 2 * goal
            W[:, 3:] = -torch.ones_like(W[:, 3:])

            if self.w_model == 'mlp':
                latent_phi_w = torch.bmm(
                    self.phi.view(self.batch_size, 1, self.state_features),
                    self.W(goal).view(self.batch_size, self.state_features, 1))
                latent_phi_w = latent_phi_w.squeeze(-1)
            elif self.w_model == 'distr':
                w_dist = self.W(goal)
                w_out = w_dist.mean
                latent_phi_w = (self.phi(obs) * w_out).mean(1, keepdim=True)
            elif self.w_model == 'vector':
                latent_phi_w = torch.bmm(
                    self.phi.view(self.batch_size, 1, self.state_features),
                    self.W.view(self.batch_size, self.state_features, 1))

        elif not (self.learn_w) and self.learn_phi:
            # Calculate phi GT
            phi = torch.ones_like(self.W)
            phi[:, 1:3] = obs[:, 0:2]
            phi[:, 3:] = ((LA.norm(obs[:, 0:2], dim=1, keepdim=True))**2)

            # Calculate w
            self.W[:, 0:1] = -((LA.norm(goal, dim=1, keepdim=True))**2)
            self.W[:, 1:3] = 2 * goal
            self.W[:, 3:] = -torch.ones_like(self.W[:, 3:])

            obs_wt_goal = torch.ones([self.batch_size, 4]).to(self.device)
            obs_wt_goal[:, 0:2] = obs[:, 0:2]
            obs_wt_goal[:, 2:] = obs[:, 4:]

            latent_phi_w = torch.bmm(
                self.phi(obs_wt_goal).view(self.batch_size, 1,
                                           self.state_features),
                self.W.view(self.batch_size, self.state_features, 1))
            latent_phi_w = latent_phi_w.squeeze(-1)

        if self.phi_model == 'mlp':

            # l2_reg = Variable(torch.FloatTensor(1),
            #                   requires_grad=True).to(self.device)
            # for w_param in self.W.parameters():
            #     l2_reg = l2_reg + w_param.norm(2)
            # l2_reg *= self.reg_lambda
            # Mel: not used, weight norm added via weight_norm layer
            # rep_loss = F.mse_loss(latent_phi_w, reward) + l2_reg
            rep_loss = F.mse_loss(latent_phi_w, reward)
            self.phi_w_optimizer.zero_grad()
            rep_loss.backward()
            self.phi_w_optimizer.step()
            writer.add_scalar('train_rep/loss', rep_loss, step)

        elif self.phi_model == 'vae':
            phi_recon_loss = F.mse_loss(latent_phi, obs_action)
            KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) -
                              std.pow(2)).mean()
            vae_loss = phi_recon_loss + 0.5 * KL_loss

            rep_loss = F.mse_loss(latent_phi_w, reward)

            total_loss = vae_loss + rep_loss

            self.phi_w_optimizer.zero_grad()
            total_loss.backward()
            self.phi_w_optimizer.step()

            writer.add_scalar('train_rep_vae/loss', vae_loss, step)
            writer.add_scalar('train_rep/loss', rep_loss, step)

        mean_squared_error = MeanSquaredError()
        rew_pred = latent_phi_w.detach().cpu()
        mse_error = mean_squared_error(rew_pred, reward.cpu())
        writer.add_scalar('train_rep/mse', mse_error, step)
        # These GTs become meaningless when learning \phi.w jointly!
        # if self.learn_w and self.learn_phi:
        #     w_pred = self.W(goal).detach().cpu()
        #     mse_w_gt = mean_squared_error(w_pred, W.cpu())
        #     writer.add_scalar('train_w_GT/mse', mse_w_gt, step)

        #     if self.goal_mode == 'multi_goal':
        #         obs_action = torch.cat([obs, action, next_obs], dim=-1)
        #         if self.phi_model == 'mlp':
        #             phi_pred = self.phi(obs_action).detach().cpu()
        #         elif self.phi_model == 'vae':
        #             phi_pred, _, _ = self.phi(obs_action)
        #             phi_pred = phi_pred.detach().cpu()
        #     else:
        #         obs_wt_goal = torch.ones([self.batch_size, 4]).to(self.device)
        #         obs_wt_goal[:, 0:2] = obs[:, 0:2]
        #         obs_wt_goal[:, 2:] = obs[:, 4:]
        #         phi_pred = self.phi(obs_wt_goal).detach().cpu()

        #     mse_w_gt = mean_squared_error(phi_pred, phi.cpu())
        #     writer.add_scalar('train_phi_GT/mse', mse_w_gt, step)

        if self.learn_w and not (self.learn_phi):
            w_pred = self.W(goal).detach().cpu()
            mse_w_gt = mean_squared_error(w_pred, W.cpu())
            writer.add_scalar('train_w_GT/mse', mse_w_gt, step)

        elif not (self.learn_w) and self.learn_phi:

            obs_wt_goal = torch.ones([self.batch_size, 4]).to(self.device)
            obs_wt_goal[:, 0:2] = obs[:, 0:2]
            obs_wt_goal[:, 2:] = obs[:, 4:]

            phi_pred = self.phi(obs_wt_goal).detach().cpu()
            mse_w_gt = mean_squared_error(phi_pred, phi.cpu())
            writer.add_scalar('train_phi_GT/mse', mse_w_gt, step)

    def update(self, replay_buffer, logger, writer, step):
        """Draw one minibatch and run a single phi/W training step on it.

        Args:
            replay_buffer: Source of transitions; ``sample(batch_size)`` must
                return a 7-tuple ``(obs, action, reward, next_obs, not_done,
                not_done_no_max, goal)``.
            logger: Unused here.
            writer: Forwarded to ``train_phi_and_w`` for scalar logging.
            step: Forwarded to ``train_phi_and_w`` as the logging step.
        """
        batch = replay_buffer.sample(self.batch_size)
        # The two termination masks are part of the sample but not needed for
        # reward regression.
        (obs, action, reward, next_obs, _not_done, _not_done_no_max,
         goal) = batch
        self.train_phi_and_w(obs, action, next_obs, reward, goal, writer, step)

    def evaluate(self):
        """Roll out evaluation episodes on the training envs and log metrics.

        Each episode is driven either by the expert agent or by the untrained
        agent plus clipped Gaussian action noise, chosen per episode with
        probability ``self.p_expert`` for the expert. For every transition the
        true reward is compared with the ``phi . w`` prediction.

        Logged under ``eval_seen/``: reward-prediction MSE and R^2 over *all*
        transitions of all evaluation episodes, and the mean episode return.
        The mean return is also written to TensorBoard as
        ``eval/episode_reward``.

        Raises:
            ValueError: If reward prediction is not implemented for the
                current ``w_model`` / ``phi_model`` / learn-flag combination.
        """
        average_episode_reward = 0
        # Accumulated across episodes so the MSE / R^2 below describe the
        # whole evaluation run, not just its final episode.
        eps_rewards = []
        eps_reward_preds = []
        for episode in range(self.cfg.num_eval_episodes):
            if self.goal_mode == 'multi_goal' or self.benchmark == 'metaworld':
                env_id = random.sample(list(self.all_envs), 1)[0]
                self.env = self.all_envs[env_id]
            if self.benchmark == 'dm_suite':
                obs, goal = self.env.reset(query_goal=True)
            elif self.benchmark == 'metaworld':
                obs = self.env.reset()
                # With MT2 goal corresponds to the task id itself.
                # assuming tasks are single goal.
                goal = [env_id]

            self.agent_expert.reset()
            self.agent_random.reset()
            # Only the first episode of each evaluation is recorded.
            self.video_recorder.init(enabled=(episode == 0))
            done = False
            episode_reward = 0
            episode_step = 0
            # One Bernoulli draw per episode picks the behaviour policy.
            expert_agent = np.random.binomial(1, self.p_expert)
            while not done:
                if expert_agent:
                    with utils.eval_mode(self.agent_expert):
                        action = self.agent_expert.act(obs, goal, sample=False)
                else:
                    with utils.eval_mode(self.agent_random):
                        # Select action according to random policy and add
                        # clipped noise
                        action = self.agent_random.act(obs, goal, sample=False)
                        action = torch.tensor(action)
                        noise = (torch.randn_like(action) *
                                 self.policy_noise).clamp(
                                     -self.noise_clip, self.noise_clip)
                        action = (action + noise).clamp(
                            self.min_action, self.max_action)
                        action = action.numpy()

                current_obs = obs
                obs, reward, done, extras = self.env.step(action)

                # Predict with the goal that was active for this step; the
                # dm_suite goal is refreshed only afterwards.
                eps_rewards.append(reward)
                eps_reward_preds.append(
                    self._predict_reward_mlp(current_obs, action, obs, goal))

                if self.benchmark == 'dm_suite':
                    goal = extras['goal']
                self.video_recorder.record(self.env)
                episode_reward += reward
                episode_step += 1
                if self.benchmark == 'metaworld':
                    # Metaworld episodes are cut off manually at the env's
                    # path-length limit.
                    if (episode_step >= self.env.max_path_length - 1):
                        done = True

            average_episode_reward += episode_reward
            self.video_recorder.save(f'{self.step}.mp4')

        mse_error = mean_squared_error(eps_rewards, eps_reward_preds)
        self.logger.log('eval_seen/reward_pred_mse', mse_error, self.step)
        r2_score_ = r2_score(eps_rewards, eps_reward_preds)
        self.logger.log('eval_seen/reward_r2_score', r2_score_, self.step)
        average_episode_reward /= self.cfg.num_eval_episodes
        self.logger.log('eval_seen/episode_reward', average_episode_reward,
                        self.step)
        self.logger.dump(self.step)
        self.writer.add_scalar('eval/episode_reward', average_episode_reward,
                               self.step)

    def _predict_reward_mlp(self, current_obs, action, obs, goal):
        """Return the scalar ``phi . w`` reward prediction for one transition.

        Args:
            current_obs: Observation before the step.
            action: Action that was executed.
            obs: Observation after the step.
            goal: Goal (or task id wrapped in a list) active during the step.

        Raises:
            ValueError: If ``w_model`` is not ``'mlp'``, ``phi_model`` is not
                handled, or neither factor is learned.
        """
        if self.w_model != 'mlp':
            # The former 'distr' / 'vector' code paths were unreachable and
            # untested, so other W models are rejected explicitly.
            raise ValueError(
                'Reward prediction eval is only implemented for w_model=mlp.')

        latent = self.latent_state_features
        # Pure inference: avoid building autograd graphs on every env step.
        with torch.no_grad():
            obs_tens = torch.as_tensor(obs, device=self.device).float()
            goal_tens = torch.as_tensor(goal, device=self.device).float()

            if self.learn_w and self.learn_phi:
                current_obs_tens = torch.as_tensor(
                    current_obs, device=self.device).float()
                act_tens = torch.as_tensor(action,
                                           device=self.device).float()
                w_vec = self.W(goal_tens)
                obs_action = torch.cat(
                    [current_obs_tens, act_tens, obs_tens],
                    dim=-1).unsqueeze(0)
                if self.phi_model == 'mlp':
                    phi_vec = self.phi(obs_action)
                elif self.phi_model == 'vae':
                    phi_vec, _, _ = self.phi(obs_action)
                else:
                    raise ValueError('Unsupported phi model.')

            elif self.learn_w and not (self.learn_phi):
                # Hand-built features, same layout as in training:
                # [1, obs_xy, |obs_xy|^2, ...].
                phi_vec = torch.ones_like(self.phi[0])
                phi_vec[1:3] = obs_tens[0:2]
                phi_vec[3:] = ((LA.norm(obs_tens[0:2], dim=0,
                                        keepdim=True))**2)
                w_vec = self.W(goal_tens)

            elif not (self.learn_w) and self.learn_phi:
                # Hand-built weights, same layout as in training:
                # [-|goal|^2, 2*goal, -1, ...].
                w_vec = torch.ones_like(self.W[0])
                w_vec[0:1] = -((LA.norm(goal_tens[0:2], dim=0,
                                        keepdim=True))**2)
                w_vec[1:3] = 2 * goal_tens[0:2]
                w_vec[3:] = -torch.ones_like(w_vec[3:])

                # when goal is part of observation from env, remove it
                # for the hard-coded version its not needed.
                obs_wt_goal = torch.ones([4]).to(self.device)
                obs_wt_goal[0:2] = obs_tens[0:2]
                obs_wt_goal[2:] = obs_tens[4:]
                phi_vec = self.phi(obs_wt_goal)

            else:
                raise ValueError('Need to learn either w or phi or both.')

            latent_phi_w = torch.bmm(phi_vec.view(1, 1, latent),
                                     w_vec.view(1, latent, 1))
            return float(latent_phi_w.squeeze())

    def evaluate_unseen(self):
        """Roll out the expert on the held-out envs and log prediction metrics.

        For every transition the true reward is compared with the ``phi . w``
        prediction. Logged under ``eval_unseen/``: reward-prediction MSE and
        R^2 over *all* transitions of all episodes, and the mean episode
        return (also written to TensorBoard).

        Raises:
            ValueError: If the configuration is not "learn both phi and w
                with an MLP ``W``", or ``phi_model`` is not handled.
        """
        # Fail before touching any environment; only the joint MLP-W setting
        # has a prediction path here.
        if not (self.learn_w and self.learn_phi) or self.w_model != 'mlp':
            raise ValueError('Eval unseen not implemented for this setting.')

        latent = self.latent_state_features
        average_episode_reward = 0
        # Accumulated across episodes so the MSE / R^2 below describe the
        # whole evaluation run, not just its final episode.
        eps_rewards = []
        eps_reward_preds = []
        for episode in range(self.cfg.num_eval_episodes):
            env_id = random.sample(list(self.eval_envs_unseen), 1)[0]
            self.env = self.eval_envs_unseen[env_id]

            if self.benchmark == 'dm_suite':
                obs, goal = self.env.reset(query_goal=True)
            elif self.benchmark == 'metaworld':
                obs = self.env.reset()
                # The task id stands in for the goal.
                goal = [env_id]
            self.agent_expert.reset()
            # Only the first episode of each evaluation is recorded.
            self.video_recorder.init(enabled=(episode == 0))
            done = False
            episode_reward = 0
            episode_step = 0

            while not done:
                with utils.eval_mode(self.agent_expert):
                    action = self.agent_expert.act(obs, goal, sample=False)
                current_obs = obs
                obs, reward, done, extras = self.env.step(action)
                if self.benchmark == 'dm_suite':
                    goal = extras['goal']
                episode_reward += reward
                episode_step += 1
                self.video_recorder.record(self.env)
                if self.benchmark == 'metaworld':
                    # Metaworld episodes are cut off manually at the env's
                    # path-length limit.
                    if (episode_step >= self.env.max_path_length - 1):
                        done = True

                # Pure inference: no autograd graph needed.
                with torch.no_grad():
                    obs_tens = torch.as_tensor(obs, device=self.device).float()
                    goal_tens = torch.as_tensor(goal,
                                                device=self.device).float()
                    current_obs_tens = torch.as_tensor(
                        current_obs, device=self.device).float()
                    act_tens = torch.as_tensor(action,
                                               device=self.device).float()

                    w_pred = self.W(goal_tens)

                    # Computed for every goal mode, as in `evaluate`; gating
                    # this on multi_goal left `phi_pred` undefined otherwise.
                    obs_action = torch.cat(
                        [current_obs_tens, act_tens, obs_tens],
                        dim=-1).unsqueeze(0)
                    if self.phi_model == 'mlp':
                        phi_pred = self.phi(obs_action)
                    elif self.phi_model == 'vae':
                        phi_pred, _, _ = self.phi(obs_action)
                    else:
                        raise ValueError('Unsupported phi model.')

                    latent_phi_w = torch.bmm(phi_pred.view(1, 1, latent),
                                             w_pred.view(1, latent, 1))

                eps_rewards.append(reward)
                eps_reward_preds.append(float(latent_phi_w.squeeze()))

            average_episode_reward += episode_reward
            self.video_recorder.save(f'unseen_{self.step}.mp4')

        mse_error = mean_squared_error(eps_rewards, eps_reward_preds)
        self.logger.log('eval_unseen/reward_pred_mse', mse_error, self.step)
        r2_score_ = r2_score(eps_rewards, eps_reward_preds)
        self.logger.log('eval_unseen/reward_r2_score', r2_score_, self.step)
        average_episode_reward /= self.cfg.num_eval_episodes
        self.logger.log('eval_unseen/episode_reward', average_episode_reward,
                        self.step)
        self.logger.dump(self.step)

        self.writer.add_scalar('eval_unseen/episode_reward',
                               average_episode_reward, self.step)

    def run(self):
        """Collect experience and train until ``cfg.num_train_steps`` is reached.

        Each loop iteration performs one environment step. At every episode
        boundary (``done`` is truthy, which includes the very first
        iteration) the method logs timing and the finished episode's reward,
        optionally evaluates and checkpoints, picks/resets the environment,
        resets both agents and draws which agent acts for the new episode.

        Per step it selects an action, runs ``self.update`` once
        ``self.step >= cfg.num_seed_steps``, steps the environment and stores
        the transition in ``self.replay_buffer``.

        NOTE(review): ``obs``, ``goal`` and ``done_no_max`` are only assigned
        when ``self.benchmark`` is ``'dm_suite'`` or ``'metaworld'``; any other
        value would fail on first use of an unbound local.
        """
        # `done` starts True so the first iteration goes through the
        # episode-reset branch below.
        episode, episode_reward, done = 0, 0, True
        start_time = time.time()
        while self.step < self.cfg.num_train_steps:
            if done:
                # Skip duration logging on the very first pass: no episode
                # has been run yet.
                if self.step > 0:
                    self.logger.log('train/duration',
                                    time.time() - start_time, self.step)
                    start_time = time.time()
                    self.logger.dump(
                        self.step, save=(self.step > self.cfg.num_seed_steps))

                # Run evaluation as sanity check
                # also checkpoint the models via self.save
                # (episode 0 satisfies the modulo test, so this also runs
                # before any training step).
                if episode % self.cfg.eval_frequency == 0:
                    self.logger.log('eval_seen/episode', episode, self.step)
                    self.evaluate()
                    if self.eval_on_unseen:
                        self.evaluate_unseen()
                    self.save(self.model_dir, episode)

                # Reward of the episode that just ended (0 on the first pass).
                self.logger.log('train/episode_reward', episode_reward,
                                self.step)
                self.writer.add_scalar('train/episode_reward', episode_reward,
                                       self.step)
                # Pick one key of self.all_envs at random and switch to
                # that environment for the new episode.
                if self.goal_mode == 'multi_goal' or self.benchmark == 'metaworld':
                    env_id = random.sample(list(self.all_envs), 1)[0]
                    self.env = self.all_envs[env_id]

                if self.benchmark == 'dm_suite':
                    obs, goal = self.env.reset(query_goal=True)
                elif self.benchmark == 'metaworld':
                    obs = self.env.reset()
                    # For metaworld the sampled environment key itself is
                    # used as the goal.
                    goal = [env_id]

                self.agent_expert.reset()
                self.agent_random.reset()
                # One Bernoulli(p_expert) draw per episode: 1 -> the expert
                # agent acts for the whole episode, 0 -> the random agent.
                expert_agent = np.random.binomial(1, self.p_expert)
                done = False
                episode_reward = 0
                episode_step = 0
                episode += 1
                self.logger.log('train/episode', episode, self.step)
                self.writer.add_scalar('train/episode', episode, self.step)

            if expert_agent:
                action = self.agent_expert.act(obs, goal, sample=True)
            else:
                # Select action according to policy and add clipped noise
                action = self.agent_random.act(obs, goal, sample=True)
                action = torch.tensor(action)
                # Gaussian noise scaled by policy_noise, clipped to
                # [-noise_clip, noise_clip]; the noisy action is then clipped
                # to [min_action, max_action].
                noise = (torch.randn_like(action) * self.policy_noise).clamp(
                    -self.noise_clip, self.noise_clip)
                action = (action + noise).clamp(self.min_action,
                                                self.max_action)
                action = action.numpy()

            # run training update
            # (happens before the current step's transition is added to the
            # replay buffer).
            if self.step >= self.cfg.num_seed_steps:
                self.update(self.replay_buffer, self.logger, self.writer,
                            self.step)

            next_obs, reward, done, extras = self.env.step(action)
            if self.benchmark == 'dm_suite':
                # dm_suite reports the goal in the step's extras.
                goal = extras['goal']

            if self.benchmark == 'metaworld':
                # Force episode termination once max_path_length steps have
                # been taken.
                if (episode_step >= self.env.max_path_length - 1):
                    done = True
            # allow infinite bootstrap
            done = float(done)
            # done_no_max is 0 when the episode ended because the step limit
            # was hit; otherwise it equals `done`.
            if self.benchmark == 'metaworld':
                done_no_max = 0 if episode_step + 1 == self.env.max_path_length else done
            elif self.benchmark == 'dm_suite':
                done_no_max = 0 if episode_step + 1 == self.env._max_episode_steps else done
            episode_reward += reward

            self.replay_buffer.add(obs, action, reward, next_obs, done,
                                   done_no_max, goal)

            obs = next_obs
            episode_step += 1
            self.step += 1

    def save(self, model_dir, step):
        """Write model and optimizer checkpoints into ``model_dir``.

        Files are named ``phi_<step>.pt``, ``w_<step>.pt`` and
        ``phi_w_optim_<step>.pt``. Which model files are written depends on
        the ``learn_phi`` / ``learn_w`` flags:

        * ``learn_phi`` set: ``self.phi.state_dict()`` is saved.
        * ``learn_w`` and ``learn_phi`` both set: ``self.W.state_dict()`` is
          saved as well, without consulting ``w_model``.
        * only ``learn_w`` set: ``self.W.state_dict()`` is saved when
          ``w_model`` is ``'distr'`` or ``'mlp'``; ``self.W`` itself is saved
          when ``w_model`` is ``'vector'``; any other ``w_model`` saves no W.

        The optimizer state (``self.phi_w_optimizer``) is always saved.

        Args:
            model_dir: Target directory; created (with parents) if missing.
            step: Value embedded in the checkpoint file names.
        """
        # exist_ok=True replaces the former exists()-then-makedirs() pair,
        # which could raise if the directory appeared between the two calls.
        os.makedirs(model_dir, exist_ok=True)

        phi_path = '%s/phi_%s.pt' % (model_dir, step)
        w_path = '%s/w_%s.pt' % (model_dir, step)
        optim_path = '%s/phi_w_optim_%s.pt' % (model_dir, step)

        if self.learn_phi:
            torch.save(self.phi.state_dict(), phi_path)

        if self.learn_w:
            # w_model is only consulted when W is trained without phi; the
            # short-circuit keeps the joint case independent of it.
            if self.learn_phi or self.w_model in ('distr', 'mlp'):
                torch.save(self.W.state_dict(), w_path)
            elif self.w_model == 'vector':
                # No state_dict() call in this mode: the W object itself is
                # serialized.
                torch.save(self.W, w_path)

        torch.save(self.phi_w_optimizer.state_dict(), optim_path)


@hydra.main(config_path="config/", config_name="train_rep")
def main(cfg: DictConfig) -> None:
    """Print the resolved config as YAML, then build and run a Workspace.

    Args:
        cfg: Configuration object handed in through the ``hydra.main``
            decorator (``config/train_rep``).
    """
    config_yaml = OmegaConf.to_yaml(cfg)
    print(config_yaml)

    # Workspace is pulled in through the `train_phi_w` module name rather
    # than referenced directly.
    # NOTE(review): presumably so the class is bound to an importable module
    # instead of `__main__` - confirm before changing.
    from train_phi_w import Workspace as W

    W(cfg).run()


# Script entry point. main() is called without arguments; its `cfg` parameter
# is supplied through the @hydra.main decorator applied to it.
if __name__ == '__main__':
    main()