# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import pdb  # pylint: disable=unused-import
import logging
import dataclasses
import typing as tp

from url_benchmark import pretrain  # NEEDS TO BE FIRST NON-STANDARD IMPORT (sets up env variables)

import omegaconf as omgcf
import hydra
from hydra.core.config_store import ConfigStore
import torch
import numpy as np
import json
import wandb

from url_benchmark import dmc
from dm_env import specs
from url_benchmark import goals as _goals
from url_benchmark import utils
from url_benchmark.in_memory_replay_buffer_psm import ReplayBuffer # pylint: disable=unused-import
from url_benchmark.logger import Logger, InferenceLogger, ActorLogger
# from url_benchmark.in_memory_replay_buffer_psm import ReplayBuffer
from url_benchmark.video import TrainVideoRecorder, VideoRecorder
from url_benchmark.replay_buffer import EpisodeBatch # pylint: disable=unused-import
from url_benchmark import agent as agents
from url_benchmark.d4rl_benchmark import D4RLReplayBufferBuilder, D4RLWrapper

logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = True

from pathlib import Path
import sys
# Append the directory one level above this file's folder, and its "url_benchmark"
# sub-directory, to sys.path so that they can be imported from.
base = Path(__file__).absolute().parents[1]
for fp in [base, base / "url_benchmark"]:
    assert fp.exists()  # fail at import time if the expected layout is missing
    if str(fp) not in sys.path:  # do not add the same entry twice
        sys.path.append(str(fp))

@dataclasses.dataclass
class Config:
    """Base structured configuration of a workspace.

    Fields are grouped by purpose (misc, experiment, task settings, eval,
    checkpoint, training). ``agent`` has no default and must be provided.
    ``batch_size`` is an OmegaConf interpolation of ``agent.batch_size``.
    """
    agent: tp.Any
    # misc
    seed: int = 1
    device: str = "cuda"  # the workspace falls back to "cpu" if cuda is unavailable
    save_video: bool = False
    use_tb: bool = False
    use_wandb: bool = True
    use_hiplog: bool = False
    # experiment
    experiment: str = "online"  # used as the first part of the wandb run name
    # task settings
    task: str = "walker_stand"  # the part before the first "_" is used as domain
    obs_type: str = "states"  # [states, pixels]
    frame_stack: int = 3  # only works if obs_type=pixels
    action_repeat: int = 1  # set to 2 for pixels
    discount: float = 0.99
    future: float = 0.99  # discount of future sampling, future=1 means no future sampling
    goal_space: tp.Optional[str] = None
    append_goal_to_observation: bool = False
    # eval
    num_eval_episodes: int = 10
    custom_reward: tp.Optional[str] = None  # activates custom eval if not None
    final_tests: int = 10  # number of repetitions of the final test (0 disables it)
    # checkpoint
    # frames at which a "snapshot_<frame>.pt" checkpoint is written during training
    snapshot_at: tp.Tuple[int, ...] = (100000, 200000, 500000, 800000, 1000000, 1500000,
                                       2000000, 3000000, 4000000, 5000000, 9000000, 10000000)
    checkpoint_every: int = 100000  # period (in frames) of the "latest.pt" checkpoint
    load_model: tp.Optional[str] = None  # checkpoint to load when no "latest.pt" exists
    # training
    num_seed_frames: int = 4000
    replay_buffer_episodes: int = 5000
    update_encoder: bool = True
    batch_size: int = omgcf.II("agent.batch_size")  # resolved from the agent configuration

@dataclasses.dataclass
class OfflineConfig(Config):
    """Configuration of the offline workspace.

    Extends :class:`Config` with the gradient step budget, the logging and
    evaluation periods, and the location of the dataset (replay buffer).
    """
    # training
    num_grad_steps: int = 1000000  # training stops once global_step reaches this value
    num_seed_frames: int = 0
    log_every_steps: int = 1000
    # eval
    num_eval_episodes: int = 10
    eval_every_steps: int = 10000
    # dataset
    load_replay_buffer: tp.Optional[str] = None  # if set, the buffer is loaded from this checkpoint
    expl_agent: str = "proto"  # sub-folder of the dataset directory to read from
    replay_buffer_dir: str = omgcf.SI("../../../../datasets")  # make sure to update this if you change hydra run dir
    # misc
    experiment: str = "offline"
    reward_free: bool = False


ConfigStore.instance().store(name="workspace_config", node=OfflineConfig)

def make_agent(
    obs_type: str, obs_spec, action_spec, num_expl_steps: int, cfg: omgcf.DictConfig
) -> tp.Union[agents.DiscretePSMAgent]:
    """Completes the agent configuration with environment information
    and instantiates the agent.

    Parameters
    ----------
    obs_type: str
        observation type, written to ``cfg.obs_type``
    obs_spec:
        observation spec, its shape is written to ``cfg.obs_shape``
    action_spec:
        action spec; for a discrete spec the action shape is ``(num_values,)``,
        otherwise the shape of the spec is used
    num_expl_steps: int
        written to ``cfg.num_expl_steps``
    cfg: DictConfig
        agent configuration, updated in place then instantiated through hydra
    """
    cfg.obs_type = obs_type
    cfg.obs_shape = obs_spec.shape
    if isinstance(action_spec, specs.DiscreteArray):
        action_shape = (action_spec.num_values, )
    else:
        action_shape = action_spec.shape
    cfg.action_shape = action_shape
    cfg.num_expl_steps = num_expl_steps
    return hydra.utils.instantiate(cfg)


C = tp.TypeVar("C", bound=Config)


def _update_legacy_class(obj: tp.Any, classes: tp.Sequence[tp.Type[tp.Any]]) -> tp.Any:
    """Updates a legacy class (eg: agent.FBDDPGAgent) to the new
    class (url_benchmark.agent.FBDDPGAgent)

    Parameters
    ----------
    obj: Any
        Object to update
    classes: Types
        Possible classes to update the object to. If current name is one of the classes
        name, the object class will be remapped to it.
    """
    classes = tuple(classes)
    if not isinstance(obj, classes):
        clss = {x.__name__: x for x in classes}
        cls = clss.get(obj.__class__.__name__, None)
        if cls is not None:
            logger.warning(f"Promoting legacy object {obj.__class__} to {cls}")
            obj.__class__ = cls


def _init_eval_sm(workspace: "BaseWorkspace", custom_reward: tp.Optional[_goals.BaseReward] = None) -> agents.MetaDict:
    """Prepares the agent of the workspace for an evaluation episode.

    Parameters
    ----------
    workspace: BaseWorkspace
        workspace providing the agent, the configuration, the replay buffer and the loggers
    custom_reward: BaseReward or None
        if provided, the agent inference is driven by this reward: first through its goal
        (``get_goal``), and if anything fails, through rewards relabeled on sampled batches

    Returns
    -------
    With a custom reward: the output of the agent inference call.
    Without: the goal registered for ``(cfg.goal_space, cfg.task)`` after running the agent
    inference and actor distillation on it, or None if there is no such goal.
    """
    known_agents = (agents.FBDDPGAgent, agents.SFAgent, agents.SFSVDAgent, agents.APSAgent,
                    agents.NEWAPSAgent, agents.GoalSMAgent, agents.UVFAgent, agents.PSMAgent)
    agent = workspace.agent
    # we need to check against name for legacy reason when reloading old checkpoints
    _update_legacy_class(agent, known_agents)

    if custom_reward is not None:
        try:  # if the custom reward implements a goal, use it
            # NOTE(review): infer_w_goal is called with different arguments further down
            # and in eval_maze_goals - confirm which signature is the right one
            return agent.infer_w_goal(custom_reward.get_goal(workspace.cfg.goal_space))
        except Exception:  # pylint: disable=broad-except
            # any failure above falls back to inference from observations and rewards
            # we cannot fully type because of the FBBDPG string check :s
            num_steps = agent.cfg.num_inference_steps  # type: ignore
            obs_chunks, reward_chunks = [], []
            num_collected = 0
            # sample batches until enough transitions are available
            while num_collected < num_steps:
                batch = workspace.replay_loader.sample(
                    workspace.cfg.batch_size, custom_reward=custom_reward
                ).to(workspace.cfg.device)
                if workspace.cfg.goal_space is not None:
                    obs_chunks.append(batch.next_goal)
                else:
                    obs_chunks.append(batch.next_obs)
                reward_chunks.append(batch.reward)
                num_collected += batch.next_obs.size(0)
            all_obs = torch.cat(obs_chunks, 0)  # type: ignore
            all_rewards = torch.cat(reward_chunks, 0)  # type: ignore
            # keep exactly num_steps transitions
            obs_t, reward_t = all_obs[:num_steps], all_rewards[:num_steps]
            return agent.infer_w_from_obs_and_rewards(workspace.replay_loader, obs_t, reward_t)

    if workspace.cfg.goal_space is None:
        return None
    funcs = _goals.goals.funcs.get(workspace.cfg.goal_space, {})
    if workspace.cfg.task not in funcs:
        return None
    g = funcs[workspace.cfg.task]()
    # both calls are run for their effect on the agent, their outputs are not used
    agent.infer_w_goal(workspace.replay_loader, workspace.inference_logger, g)
    agent.distill_actor_ddpg(workspace.replay_loader, workspace.actor_logger, g)
    return g


class BaseWorkspace(tp.Generic[C]):
    def __init__(self, cfg: C) -> None:
        """Builds all the components of the workspace from the configuration.

        In order: seeds the run, selects the device (falling back to cpu if cuda
        is not available), creates the train and eval environments, the agent,
        the loggers (and the wandb run if enabled), the replay buffer and the video
        recorder, then reloads a checkpoint if one is available.

        Parameters
        ----------
        cfg: Config
            configuration of the workspace. ``cfg.device`` and ``cfg.agent.device``
            are overridden with "cpu" if cuda is not available.

        Raises
        ------
        ValueError
            if ``cfg.goal_space`` is provided but not registered for the domain
        """
        self.work_dir = Path.cwd()
        print(f'Workspace: {self.work_dir}')
        print(f'Running code in : {Path(__file__).parent.resolve().absolute()}')
        logger.info(f'Workspace: {self.work_dir}')
        logger.info(f'Running code in : {Path(__file__).parent.resolve().absolute()}')

        self.cfg = cfg
        utils.set_seed_everywhere(cfg.seed)
        # fall back to cpu (for both the workspace and the agent) if cuda is not available
        if not torch.cuda.is_available():
            if cfg.device != "cpu":
                logger.warning(f"Falling back to cpu as {cfg.device} is not available")
                cfg.device = "cpu"
                cfg.agent.device = "cpu"
        self.device = torch.device(cfg.device)
        # goal_spec: tp.Optional[specs.Array] = None
        # if cfg.goal_space is not None:
        #     g = _goals.goals.funcs[cfg.goal_space][cfg.task]()
        #     goal_spec = specs.Array((len(g),), np.float32, 'goal')

        # create envs
        # task = PRIMAL_TASKS[self.domain]
        task = cfg.task
        # the domain is the part of the task name before the first underscore,
        # except for point_mass_maze which contains underscores itself
        if task.startswith('point_mass_maze'):
            self.domain = 'point_mass_maze'
        else:
            self.domain = task.split('_', maxsplit=1)[0]

        # _make_env reads self.cfg and self.domain, which must be set beforehand
        self.train_env = self._make_env()
        self.eval_env = self._make_env()
        # create agent
        self.agent = make_agent(cfg.obs_type,
                                self.train_env.observation_spec(),
                                self.train_env.action_spec(),
                                cfg.num_seed_frames // cfg.action_repeat,
                                cfg.agent)

        # create logger
        self.logger = Logger(self.work_dir,
                             use_tb=cfg.use_tb,
                             use_wandb=cfg.use_wandb,
                             use_hiplog=cfg.use_hiplog)

        # dedicated logger handed to the agent inference (see _init_eval_sm)
        self.inference_logger = InferenceLogger(self.work_dir,
                                          use_tb=cfg.use_tb,
                                          use_wandb=cfg.use_wandb,
                                          use_hiplog=cfg.use_hiplog)

        # dedicated logger handed to the actor distillation (see _init_eval_sm)
        self.actor_logger = ActorLogger(self.work_dir,
                                        use_tb=cfg.use_tb,
                                        use_wandb=cfg.use_wandb,
                                        use_hiplog=cfg.use_hiplog)

        if cfg.use_wandb:
            # run name is <experiment>_<agent name>_<domain>
            exp_name = '_'.join([
                cfg.experiment, cfg.agent.name, self.domain
            ])
            wandb.init(project="controllable_agent", group=cfg.agent.name, name=exp_name,  # mode="disabled",
                       config=omgcf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True))  # type: ignore

            # each family of metrics is plotted against its own frame counter
            wandb.define_metric('train/frame')
            wandb.define_metric('eval/frame')
            wandb.define_metric('inf/frame')
            wandb.define_metric('actor/frame')

            wandb.define_metric('train/*', step_metric='train/frame')
            wandb.define_metric('eval/*', step_metric='eval/frame')
            wandb.define_metric('inf/*', step_metric='inf/frame')
            wandb.define_metric('actor/*', step_metric='actor/frame')

        if cfg.use_hiplog:
            # record config now that it is filled
            # top-level entries whose name contains one of these parts are not recorded
            parts = ("snapshot", "_type", "_shape", "num_", "save_", "frame", "device", "use_tb", "use_wandb")
            skipped = [x for x in cfg if any(y in x for y in parts)]  # type: ignore
            self.logger.hiplog.flattened({x: y for x, y in cfg.items() if x not in skipped})  # type: ignore
            self.logger.hiplog(workdir=self.work_dir.stem)
            for rm in ("agent/use_tb", "agent/use_wandb", "agent/device"):
                del self.logger.hiplog._content[rm]
            self.logger.hiplog(observation_size=np.prod(self.train_env.observation_spec().shape))

        # # create replay buffer
        # self._data_specs: tp.List[tp.Any] = [self.train_env.observation_spec(),
        #                                      self.train_env.action_spec(), ]
        # make sure the requested goal space is registered for this domain
        if cfg.goal_space is not None:
            if cfg.goal_space not in _goals.goal_spaces.funcs[self.domain]:
                raise ValueError(f"Unregistered goal space {cfg.goal_space} for domain {self.domain}")
        #     g = _goals.goals.funcs[cfg.goal_space][cfg.task]()
        #     self._data_specs.append(specs.Array((len(g),), np.float32, 'goal'))
        # self._data_specs.extend([specs.Array((1,), np.float32, 'reward'),
        #                          specs.Array((1,), np.float32, 'discount')])

        self.replay_loader = ReplayBuffer(max_episodes=cfg.replay_buffer_episodes, discount=cfg.discount, future=cfg.future)

        # # create data storage
        # self.replay_storage = ReplayBufferStorage(data_specs, meta_specs,
        #                                           self.work_dir / 'buffer')
        #
        # # create replay buffer
        # self.replay_loader = make_replay_loader(self.replay_storage,
        #                                         cfg.replay_buffer_size,
        #                                         cfg.batch_size,
        #                                         cfg.replay_buffer_num_workers,
        #                                         False, True, cfg.nstep, cfg.discount)

        # create video recorders
        # cam_id = 2 if 'quadruped' not in self.domain else 1
        # cam_id = 1  # centered on subject
        cam_id = 0 if 'quadruped' not in self.domain else 2

        # passing None as directory when save_video is False
        # (presumably disables recording - TODO confirm in VideoRecorder)
        self.video_recorder = VideoRecorder(self.work_dir if cfg.save_video else None,
                                            camera_id=cam_id, use_wandb=self.cfg.use_wandb)

        self.timer = utils.Timer()
        self.global_step = 0
        self.global_episode = 0
        self.eval_rewards_history: tp.List[float] = []
        # resume from the latest checkpoint of this work dir if any,
        # otherwise load the provided model (everything but its replay buffer)
        self._checkpoint_filepath = self.work_dir / "models" / "latest.pt"
        if self._checkpoint_filepath.exists():
            self.load_checkpoint(self._checkpoint_filepath)
        elif cfg.load_model is not None:
            self.load_checkpoint(cfg.load_model, exclude=["replay_loader"])

        # reward instance only created for the maze multi-goal setting
        # (it is passed to the GoalTD3Agent updates in Workspace.train)
        self.reward_cls: tp.Optional[_goals.BaseReward] = None
        if self.cfg.custom_reward == "maze_multi_goal":
            self.reward_cls = self._make_custom_reward(seed=self.cfg.seed)

    def _make_env(self) -> dmc.EnvWrapper:
        """Creates a new environment for the configured task.

        For the "d4rl" domain, the gym environment named by the second
        "_"-separated part of the task is wrapped, otherwise the environment
        is created through ``dmc.make``.
        """
        cfg = self.cfg
        if self.domain != "d4rl":
            return dmc.make(
                cfg.task,
                cfg.obs_type,
                cfg.frame_stack,
                cfg.action_repeat,
                cfg.seed,
                goal_space=cfg.goal_space,
                append_goal_to_observation=cfg.append_goal_to_observation,
            )
        # imported lazily since only needed for d4rl tasks
        # (d4rl is imported for its side effects only)
        import d4rl  # type: ignore # pylint: disable=unused-import
        import gym
        gym_name = cfg.task.split('_')[1]
        return dmc.EnvWrapper(D4RLWrapper(gym.make(gym_name)))

    @property
    def global_frame(self) -> int:
        """Current frame count: the global step scaled by the action repeat."""
        frames_per_step = self.cfg.action_repeat
        return self.global_step * frames_per_step

    def _make_custom_reward(self, seed: int) -> tp.Optional[_goals.BaseReward]:
        """Builds the reward function named by ``cfg.custom_reward`` with
        the given seed, or returns None when no custom reward is configured
        """
        name = self.cfg.custom_reward
        return None if name is None else _goals.get_reward_function(name, seed)

    def eval_maze_goals(self) -> None:
        """Evaluates the agent on every goal of ``MazeMultiGoal``.

        For each goal, ``cfg.num_eval_episodes`` episodes are run (only the first
        one is recorded on video) and the episode rewards are averaged. The mean
        over goals is appended to ``eval_rewards_history`` and logged.
        """
        needs_cov = isinstance(self.agent, (agents.SFAgent, agents.SFSVDAgent, agents.NEWAPSAgent))
        if needs_cov and len(self.replay_loader) > 0:
            self.agent.precompute_cov(self.replay_loader)
        reward_cls = _goals.MazeMultiGoal()
        rewards = []
        for g in reward_cls.goals:
            goal_rewards = []
            goal_distances = []
            # run for its effect on the agent, the output is not used
            self.agent.infer_w_goal(self.replay_loader, g)
            for episode in range(self.cfg.num_eval_episodes):
                self.video_recorder.init(self.eval_env, enabled=(episode == 0))
                time_step = self.eval_env.reset()
                episode_reward = 0.0
                while not time_step.last():
                    with torch.no_grad(), utils.eval_mode(self.agent):
                        action = self.agent.act(time_step.observation, g)
                    time_step = self.eval_env.step(action)
                    self.video_recorder.record(self.eval_env)
                    assert isinstance(time_step, dmc.ExtendedGoalTimeStep)
                    step_reward, distance = reward_cls.from_goal(time_step.goal, g)
                    episode_reward += step_reward
                goal_rewards.append(episode_reward)
                # distance of the last step of the episode
                goal_distances.append(distance)
                self.video_recorder.save(f'{g}.mp4')
            print(f"goal: {g}, avg_reward: {round(float(np.mean(goal_rewards)), 2)}, avg_distance: {round(float(np.mean(goal_distances)), 5)}")
            rewards.append(float(np.mean(goal_rewards)))
        self.eval_rewards_history.append(float(np.mean(rewards)))
        with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
            entries = (
                ('episode_reward', self.eval_rewards_history[-1]),
                ('step', self.global_step),
                ('episode', self.global_episode),
            )
            for key, value in entries:
                log(key, value)

    def eval(self) -> None:
        """Runs ``cfg.num_eval_episodes`` evaluation episodes and logs the results.

        For each episode, a custom reward is created (if configured) and the agent
        is prepared through ``_init_eval_sm``. When a custom reward is used, it
        replaces the reward of the environment. The mean episode reward is appended
        to ``eval_rewards_history``. Only the first episode is recorded on video.
        """
        num_steps, num_episodes = 0, 0
        eval_until_episode = utils.Until(self.cfg.num_eval_episodes)
        physics_agg = dmc.PhysicsAggregator()
        rewards: tp.List[float] = []
        normalized_scores: tp.List[float] = []
        is_d4rl_task = self.cfg.task.partition('_')[0] == 'd4rl'
        while eval_until_episode(num_episodes):
            time_step = self.eval_env.reset()
            # create custom reward if need be (if field exists)
            # the seed differs for each episode of this evaluation
            seed = 12 * self.cfg.num_eval_episodes + len(rewards)
            custom_reward = self._make_custom_reward(seed=seed)
            # prepares the agent, the returned value is not used here
            _init_eval_sm(self, custom_reward)
            total_reward = 0.0
            self.video_recorder.init(self.eval_env, enabled=(num_episodes == 0))
            while not time_step.last():
                with torch.no_grad(), utils.eval_mode(self.agent):
                    action = self.agent.act(time_step.observation)
                time_step = self.eval_env.step(action)
                physics_agg.add(self.eval_env)
                self.video_recorder.record(self.eval_env)
                if custom_reward is not None:
                    # the custom reward overrides the reward of the environment
                    time_step.reward = custom_reward.from_env(self.eval_env)
                total_reward += time_step.reward
                num_steps += 1
            if is_d4rl_task:
                normalized_scores.append(self.eval_env.get_normalized_score(total_reward))
            rewards.append(total_reward)
            num_episodes += 1
            self.video_recorder.save(f'{self.global_frame}.mp4')

        mean_reward = float(np.mean(rewards))
        self.eval_rewards_history.append(mean_reward)
        with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
            if is_d4rl_task:
                log('episode_normalized_score', float(100 * np.mean(normalized_scores)))
            log('episode_reward', self.eval_rewards_history[-1])
            if len(rewards) > 1:
                log('episode_reward#std', float(np.std(rewards)))
            # average length in frames
            log('episode_length', num_steps * self.cfg.action_repeat / num_episodes)
            log('episode', self.global_episode)
            log('step', self.global_step)
            for key, val in physics_agg.dump():
                log(key, val)

    _CHECKPOINTED_KEYS = ('agent', 'global_step', 'global_episode', "replay_loader")

    def save_checkpoint(self, fp: tp.Union[Path, str], exclude: tp.Sequence[str] = ()) -> None:
        """Saves the checkpointed attributes of the workspace to a file

        Parameters
        ----------
        fp: Path or str
            file to write to (parent directories are created if need be)
        exclude: sequence of str
            names among ``_CHECKPOINTED_KEYS`` which must not be saved
        """
        logger.info(f"Saving checkpoint to {fp}")
        excluded = list(exclude)
        assert all(x in self._CHECKPOINTED_KEYS for x in excluded)
        target = Path(fp)
        target.parent.mkdir(exist_ok=True, parents=True)
        # this is just a dumb security check to not forget about it
        assert isinstance(self.replay_loader, ReplayBuffer), "Is this buffer designed for checkpointing?"
        payload = {}
        for key in self._CHECKPOINTED_KEYS:
            if key not in excluded:
                payload[key] = self.__dict__[key]
        with target.open('wb') as f:
            torch.save(payload, f, pickle_protocol=4)

    def load_checkpoint(self, fp: tp.Union[Path, str], only: tp.Optional[tp.Sequence[str]] = None, exclude: tp.Sequence[str] = ()) -> None:
        """Reloads a checkpoint or part of it

        The file can either contain a dictionary of checkpointed attributes
        (as written by save_checkpoint) or a pickled ReplayBuffer alone.

        Parameters
        ----------
        fp: Path or str
            checkpoint file to read
        only: None or sequence of str
            reloads only a specific subset (defaults to all)
        exclude: sequence of str
            does not reload the provided keys

        Raises
        ------
        KeyError
            if a key provided in ``only`` is not present in the checkpoint
        """
        print(f"loading checkpoint from {fp}")
        fp = Path(fp)
        # A checkpoint written on a GPU host cannot be deserialized as is on a host
        # without cuda (where the workspace falls back to cpu), so storages are
        # remapped to cpu in that case. Default placement is kept otherwise.
        map_location = None if torch.cuda.is_available() else "cpu"
        with fp.open('rb') as f:
            # NOTE: this unpickles arbitrary objects, only load checkpoints you trust
            payload = torch.load(f, map_location=map_location)
        _update_legacy_class(payload, (ReplayBuffer,))
        if isinstance(payload, ReplayBuffer):  # compatibility with pure buffers pickles
            payload = {"replay_loader": payload}
        if only is not None:
            only = list(only)
            assert all(x in self._CHECKPOINTED_KEYS for x in only)
            payload = {x: payload[x] for x in only}
        exclude = list(exclude)
        assert all(x in self._CHECKPOINTED_KEYS for x in exclude)
        for x in exclude:
            payload.pop(x, None)
        for name, val in payload.items():
            logger.info("Reloading %s from %s", name, fp)
            if name == "agent":
                # the current agent is kept and initialized from the reloaded one
                self.agent.init_from(val)
            elif name == "replay_loader":
                _update_legacy_class(val, (ReplayBuffer,))
                assert isinstance(val, ReplayBuffer)
                # pylint: disable=protected-access
                # drop unnecessary meta which could make a mess
                val._current_episode.clear()  # make sure we can start over
                # sampling parameters come from the current config, not from the checkpoint
                val._future = self.cfg.future
                val._discount = self.cfg.discount
                val._max_episodes = len(val._storage["discount"])
                self.replay_loader = val
                self.replay_loader.restore_episode_lengths()
            else:
                assert hasattr(self, name)
                setattr(self, name, val)
                if name == "global_episode":
                    logger.warning("Reloaded agent at global episode %s", self.global_episode)

    def finalize(self) -> None:
        """Runs the final tests and dumps their rewards to "test_rewards.json".

        Nothing is done if ``cfg.final_tests`` is 0. With the "maze_multi_goal"
        custom reward, one multi-goal evaluation with ``cfg.final_tests`` episodes
        per goal is run. Otherwise, for the cheetah, quadruped and walker domains,
        each task of the domain is evaluated ``cfg.final_tests`` times with one
        episode each (other domains are skipped and nothing is written).
        The configuration (task, custom_reward, seed, num_eval_episodes) is
        modified in the process, and ``eval_rewards_history`` is restored at the end.
        """
        print("Running final test", flush=True)
        repeat = self.cfg.final_tests
        if not repeat:
            return

        domain_tasks = {
            "cheetah": ['walk', 'walk_backward', 'run', 'run_backward'],
            "quadruped": ['stand', 'walk', 'run', 'jump'],
            "walker": ['stand', 'walk', 'run', 'flip'],
        }
        is_maze = self.cfg.custom_reward == "maze_multi_goal"
        if not is_maze and self.domain not in domain_tasks:
            return

        saved_history = self.eval_rewards_history
        results = {}
        if is_maze:
            self.eval_rewards_history = []
            self.cfg.num_eval_episodes = repeat
            self.eval_maze_goals()
            results["rewards"] = self.eval_rewards_history
        else:
            for suffix in domain_tasks[self.domain]:
                task = f"{self.domain}_{suffix}"
                self.cfg.task = task
                self.cfg.custom_reward = task  # for the replay buffer
                self.cfg.seed += 1  # for the sake of avoiding similar seeds
                self.eval_env = self._make_env()
                self.eval_rewards_history = []
                self.cfg.num_eval_episodes = 1
                for _ in range(repeat):
                    self.eval()
                results[task] = self.eval_rewards_history
        self.eval_rewards_history = saved_history  # restore
        with (self.work_dir / "test_rewards.json").open("w") as f:
            json.dump(results, f)


class Workspace(BaseWorkspace[OfflineConfig]):
    def __init__(self, cfg: OfflineConfig) -> None:
        """Builds the workspace and fills its replay buffer for offline training.

        The replay buffer comes from (by order of priority):
        - the checkpoint ``cfg.load_replay_buffer`` if provided
        - a previously relabeled buffer file, if it exists next to the dataset
        - the raw dataset in ``<replay_buffer_dir>/<domain>/<expl_agent>/buffer``,
          which is loaded and relabeled, then saved as the relabeled buffer file.

        Parameters
        ----------
        cfg: OfflineConfig
            configuration of the workspace
        """
        super().__init__(cfg)
        self.agent.cfg.update_every_steps = 1
        datasets_dir = self.work_dir / cfg.replay_buffer_dir
        replay_dir = datasets_dir.resolve() / self.domain / cfg.expl_agent / 'buffer'
        print(f'replay dir: {replay_dir}')

        if self.cfg.load_replay_buffer is not None:
            # print does not perform %-formatting, so the path is interpolated explicitly
            print(f"loading Replay from {self.cfg.load_replay_buffer}")
            self.load_checkpoint(self.cfg.load_replay_buffer, only=["replay_loader"])
        else:
            relabeled_replay_file_path = replay_dir / f"../relabeled_replay_{cfg.task}_{cfg.replay_buffer_episodes}.pt"
            if relabeled_replay_file_path.exists():
                print(f"loading Replay from {relabeled_replay_file_path.resolve()}")
                self.load_checkpoint(relabeled_replay_file_path, only=["replay_loader"])
            else:
                print("loading and relabeling...")
                goal_func = None if cfg.goal_space is None else _goals.goal_spaces.funcs[self.domain][cfg.goal_space]
                self.replay_loader.load(self.train_env, replay_dir, relabel=True, goal_func=goal_func)
                print("loading is done")
                # cache the relabeled buffer (as a pure buffer pickle, which
                # load_checkpoint supports) to skip relabeling next time
                with relabeled_replay_file_path.open('wb') as f:
                    torch.save(self.replay_loader, f)

        # pylint: disable=protected-access
        # sampling parameters always come from the current configuration,
        # and the capacity is set to the number of stored episodes
        self.replay_loader._future = cfg.future
        self.replay_loader._discount = cfg.discount
        self.replay_loader._max_episodes = len(self.replay_loader._storage["discount"])

        if isinstance(self.agent, agents.GoalTD3Agent) and self.agent.cfg.fb_reward:
            self.agent.precompute_cov(self.replay_loader)
            

    def train(self):
        """Runs the offline training loop for ``cfg.num_grad_steps`` agent updates.

        Along the way, the agent is periodically evaluated, training metrics are
        logged, and snapshots / checkpoints (without the replay buffer) are saved.
        At the end, a full checkpoint is saved and the final tests are run.
        """
        cfg = self.cfg
        train_until_step = utils.Until(cfg.num_grad_steps)
        should_eval = utils.Every(cfg.eval_every_steps)
        should_log = utils.Every(cfg.log_every_steps)

        while train_until_step(self.global_step):
            # try to evaluate
            if should_eval(self.global_step + 1):
                self.logger.log('eval_total_time', self.timer.total_time(), self.global_step)
                run_eval = self.eval_maze_goals if cfg.custom_reward == "maze_multi_goal" else self.eval
                run_eval()

            # GoalTD3Agent updates additionally receive the reward instance
            update_args = [self.replay_loader, self.global_step]
            if isinstance(self.agent, agents.GoalTD3Agent):
                update_args.append(self.reward_cls)
            metrics = self.agent.update(*update_args)
            self.logger.log_metrics(metrics, self.global_step, ty='train')

            if should_log(self.global_step):
                elapsed_time, total_time = self.timer.reset()
                with self.logger.log_and_dump_ctx(self.global_step, ty='train') as log:
                    log('fps', cfg.log_every_steps / elapsed_time)
                    log('total_time', total_time)
                    log('step', self.global_step)

            self.global_step += 1
            # try to save snapshot
            if self.global_frame in cfg.snapshot_at:
                snapshot_fp = self._checkpoint_filepath.with_name(f'snapshot_{self.global_frame}.pt')
                self.save_checkpoint(snapshot_fp, exclude=["replay_loader"])
            # save checkpoint to reload
            if self.global_frame % cfg.checkpoint_every == 0:
                self.save_checkpoint(self._checkpoint_filepath, exclude=["replay_loader"])

        self.save_checkpoint(self._checkpoint_filepath)  # make sure we save the final checkpoint
        self.finalize()

    # def load_checkpoint(self, fp: tp.Union[Path, str]) -> None:
    #     fp = Path(fp)
    #     with fp.open('rb') as f:
    #         payload = torch.load(f)
    #     self.agent.init_from(payload['agent'])


@hydra.main(config_path='.', config_name='base_config')
def main(cfg: omgcf.DictConfig) -> None:
    """Entry point: builds the offline workspace from the hydra config and trains it.

    For a reward-free DDPGAgent, the reward is first trained on the replay buffer.
    """
    workspace = Workspace(cfg)  # type: ignore
    agent = workspace.agent
    if isinstance(agent, agents.DDPGAgent) and agent.reward_free:
        agent.train_reward(workspace.replay_loader)
    workspace.train()


if __name__ == '__main__':
    main()
