# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import json
import pdb  # pylint: disable=unused-import
import logging
import dataclasses
import typing as tp
import warnings
from pathlib import Path

# Ignore every DeprecationWarning for the whole process.
warnings.filterwarnings('ignore', category=DeprecationWarning)


# These environment variables are set before the third-party imports further
# down, presumably so MKL / MuJoCo read them at import time (TODO confirm).
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
# MUJOCO_GL defaults to 'egl' unless the caller already exported a value.
# if the default egl does not work, you may want to try:
# export MUJOCO_GL=glfw
os.environ['MUJOCO_GL'] = os.environ.get('MUJOCO_GL', 'egl')

import hydra
from hydra.core.config_store import ConfigStore
import numpy as np
import torch
import wandb
import omegaconf as omgcf
# from dm_env import specs

from url_benchmark import dmc
from dm_env import specs
from url_benchmark import utils
from url_benchmark import goals as _goals
from url_benchmark.logger import Logger, InferenceLogger
from url_benchmark.in_memory_replay_buffer_psm import ReplayBuffer
from url_benchmark.video import TrainVideoRecorder, VideoRecorder
from url_benchmark import agent as agents
from url_benchmark.d4rl_benchmark import D4RLReplayBufferBuilder, D4RLWrapper
from url_benchmark.gridworld.env import build_gridworld_task

logger = logging.getLogger(__name__)
# Enable cuDNN autotuning (torch.backends.cudnn.benchmark) for all convolutions.
torch.backends.cudnn.benchmark = True
# os.environ['WANDB_MODE']='offline'

# from url_benchmark.dmc_benchmark import PRIMAL_TASKS


# # # Config # # #

@dataclasses.dataclass
class Config:
    """Base workspace configuration (Hydra structured config).

    Field order and defaults define the dataclass constructor, so new fields
    must only be appended with defaults.
    """
    agent: tp.Any  # agent config node, completed and instantiated by `make_agent`
    # misc
    seed: int = 1  # passed to utils.set_seed_everywhere
    device: str = "cuda"  # the workspace falls back to "cpu" when CUDA is unavailable
    save_video: bool = False  # eval videos get the work dir as output dir only when True
    use_tb: bool = False
    use_wandb: bool = True
    use_hiplog: bool = False
    visualize_policy: bool = True  # plot the agent's pos/neg Q-function after each eval episode
    # experiment
    experiment: str = "online"  # first component of the wandb run name
    # task settings
    task: str = "grid_simple"  # "<domain>_<gridworld task name>"
    obs_type: str = "states"  # [states, pixels]
    frame_stack: int = 3  # only works if obs_type=pixels
    action_repeat: int = 1  # set to 2 for pixels
    discount: float = 0.99  # passed to the ReplayBuffer (also re-applied to a reloaded buffer)
    future: float = 0.99  # discount of future sampling, future=1 means no future sampling
    goal_space: tp.Optional[str] = None
    append_goal_to_observation: bool = False
    # eval
    num_eval_episodes: int = 10
    custom_reward: tp.Optional[str] = None  # reward name for _goals.get_reward_function; activates custom eval if not None
    final_tests: int = 10
    # checkpoint
    # NOTE(review): snapshot_at / checkpoint_every are not read by the code in this file
    snapshot_at: tp.Tuple[int, ...] = (10000, 20000, 50000, 800000, 1000000, 1500000,
                                       2000000, 3000000, 4000000, 5000000, 9000000, 10000000)
    checkpoint_every: int = 40000
    load_model: tp.Optional[str] = None  # checkpoint reloaded (without its replay buffer) if no models/latest.pt exists
    # training
    num_seed_frames: int = 4000  # divided by action_repeat to give the agent's num_expl_steps
    replay_buffer_episodes: int = 5000  # max_episodes of the in-memory ReplayBuffer
    update_encoder: bool = True
    batch_size: int = omgcf.II("agent.batch_size")  # interpolated from the agent config


@dataclasses.dataclass
class GridworldConfig(Config):
    """Configuration of the gridworld training workspace (`Workspace`)."""
    # mode
    reward_free: bool = True
    # train settings
    num_train_frames: int = 2000010  # frame budget of the `Workspace.train` update loop
    # snapshot
    eval_every_frames: int = 40000  # `eval()` is triggered every this many frames
    # checkpoint (or pickled ReplayBuffer) whose buffer is loaded at start-up when not
    # resuming from models/latest.pt; for "d4rl_*" tasks a D4RL buffer is built instead
    load_replay_buffer: tp.Optional[str] = None
    # replay buffer
    # replay_buffer_num_workers: int = 4
    # nstep: int = omgcf.II("agent.nstep")
    # misc
    save_train_video: bool = False  # train videos get the work dir as output dir only when True


# Register GridworldConfig in Hydra's ConfigStore under the name "workspace_config",
# presumably so a yaml config can use it as its schema/defaults.
# NOTE(review): an earlier note said this is "loaded as base_pretrain in pretrain.yaml";
# `main` below uses config_name='base_config' -- confirm which yaml references it.
# we keep the yaml since it's easier to configure plugins from it
ConfigStore.instance().store(name="workspace_config", node=GridworldConfig)


# # # Implem # # #


def make_agent(
    obs_type: str, obs_spec, action_spec, num_expl_steps: int, cfg: omgcf.DictConfig
) -> tp.Union[agents.DiscretePSMAgent]:
    """Complete an agent config with env-derived fields and instantiate it.

    ``cfg`` is mutated in place: ``obs_type``, ``obs_shape``, ``action_shape``
    and ``num_expl_steps`` are written before ``hydra.utils.instantiate(cfg)``.
    For a discrete action spec, the action shape is ``(num_values,)``;
    otherwise it is the spec's own shape.
    """
    is_discrete = isinstance(action_spec, specs.DiscreteArray)
    if is_discrete:
        shape_of_actions = (action_spec.num_values, )
    else:
        shape_of_actions = action_spec.shape

    cfg.obs_type = obs_type
    cfg.obs_shape = obs_spec.shape
    cfg.action_shape = shape_of_actions
    cfg.num_expl_steps = num_expl_steps
    agent = hydra.utils.instantiate(cfg)
    return agent


# Type of the configuration held by a workspace; BaseWorkspace is generic over it.
C = tp.TypeVar("C", bound=Config)


def _update_legacy_class(obj: tp.Any, classes: tp.Sequence[tp.Type[tp.Any]]) -> tp.Any:
    """Updates a legacy class (eg: agent.FBDDPGAgent) to the new
    class (url_benchmark.agent.FBDDPGAgent)

    Parameters
    ----------
    obj: Any
        Object to update
    classes: Types
        Possible classes to update the object to. If current name is one of the classes
        name, the object class will be remapped to it.
    """
    classes = tuple(classes)
    if not isinstance(obj, classes):
        clss = {x.__name__: x for x in classes}
        cls = clss.get(obj.__class__.__name__, None)
        if cls is not None:
            logger.warning(f"Promoting legacy object {obj.__class__} to {cls}")
            obj.__class__ = cls


def _init_eval_sm(workspace: "BaseWorkspace", custom_reward: tp.Optional[_goals.BaseReward] = None) -> agents.MetaDict:
    """Run the agent's positive/negative goal inference before an eval episode.

    Calls ``agent.infer_w_pos_neg`` with the workspace replay buffer, the
    inference logger and the eval env's positive and negative goal
    observations, and returns its metrics.

    ``custom_reward`` is currently unused; it is kept for signature
    compatibility with callers.
    """
    # NOTE: a previous variant used agent.infer_w(...) with the positive goal only.
    psm_agent = workspace.agent
    assert isinstance(psm_agent, agents.DiscretePSMAgent)
    env = workspace.eval_env
    return psm_agent.infer_w_pos_neg(
        workspace.replay_loader,
        workspace.inf_logger,
        env.get_goal_obs(),
        env.get_neg_goal_obs(),
    )


class BaseWorkspace(tp.Generic[C]):
    """Shared setup, evaluation and checkpointing for gridworld experiments.

    On construction it seeds everything and builds the train/eval environments,
    the agent, a Bellman-Ford helper, the loggers (optionally wandb / hiplog),
    the in-memory replay buffer and the video recorder. It then reloads
    ``<cwd>/models/latest.pt`` if it exists, otherwise ``cfg.load_model`` (minus
    the replay buffer) if set.
    """

    def __init__(self, cfg: C) -> None:
        self.work_dir = Path.cwd()
        print(f'Workspace: {self.work_dir}')
        print(f'Running code in : {Path(__file__).parent.resolve().absolute()}')
        logger.info(f'Workspace: {self.work_dir}')
        logger.info(f'Running code in : {Path(__file__).parent.resolve().absolute()}')

        self.cfg = cfg
        utils.set_seed_everywhere(cfg.seed)
        if not torch.cuda.is_available():
            if cfg.device != "cpu":
                logger.warning(f"Falling back to cpu as {cfg.device} is not available")
                # keep the agent config in sync: the agent is instantiated from cfg.agent below
                cfg.device = "cpu"
                cfg.agent.device = "cpu"
        self.device = torch.device(cfg.device)

        # create envs; the domain is the task prefix before the first "_"
        task = cfg.task
        self.domain = task.split('_', maxsplit=1)[0]

        self.train_env = self._make_env()
        self.eval_env = self._make_env()
        # create agent
        self.agent = make_agent(cfg.obs_type,
                                self.train_env.observation_spec(),
                                self.train_env.action_spec(),
                                cfg.num_seed_frames // cfg.action_repeat,
                                cfg.agent)

        # Bellman-Ford helper on the eval env; `eval` calls its `solve()` after each episode
        self.bellman_ford = agents.BellmanFordAgent(self.eval_env)

        # create loggers
        self.logger = Logger(self.work_dir,
                             use_tb=cfg.use_tb,
                             use_wandb=cfg.use_wandb,
                             use_hiplog=cfg.use_hiplog)

        self.inf_logger = InferenceLogger(self.work_dir,
                                          use_tb=cfg.use_tb,
                                          use_wandb=cfg.use_wandb,
                                          use_hiplog=cfg.use_hiplog)

        if cfg.use_wandb:
            exp_name = '_'.join([
                cfg.experiment, cfg.agent.name, self.domain
            ])
            wandb.init(project="controllable_agent", group=cfg.agent.name, name=exp_name,  # mode="disabled",
                       config=omgcf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True))  # type: ignore

            # each metric family is plotted against its own frame counter
            wandb.define_metric('train/frame')
            wandb.define_metric('eval/frame')
            wandb.define_metric('inf/frame')

            wandb.define_metric('train/*', step_metric='train/frame')
            wandb.define_metric('eval/*', step_metric='eval/frame')
            wandb.define_metric('inf/*', step_metric='inf/frame')

        if cfg.use_hiplog:
            # record config now that it is filled
            parts = ("snapshot", "_type", "_shape", "num_", "save_", "frame", "device", "use_tb", "use_wandb")
            skipped = [x for x in cfg if any(y in x for y in parts)]  # type: ignore
            self.logger.hiplog.flattened({x: y for x, y in cfg.items() if x not in skipped})  # type: ignore
            self.logger.hiplog(workdir=self.work_dir.stem)
            for rm in ("agent/use_tb", "agent/use_wandb", "agent/device"):
                del self.logger.hiplog._content[rm]
            self.logger.hiplog(observation_size=np.prod(self.train_env.observation_spec().shape))

        self.replay_loader = ReplayBuffer(max_episodes=cfg.replay_buffer_episodes, discount=cfg.discount, future=cfg.future)

        self.video_recorder = VideoRecorder(self.work_dir if cfg.save_video else None,
                                            camera_id=2, use_wandb=self.cfg.use_wandb)

        self.timer = utils.Timer()
        self.global_step = 0
        self.global_episode = 0
        self.eval_rewards_history: tp.List[float] = []
        self._checkpoint_filepath = self.work_dir / "models" / "latest.pt"
        if self._checkpoint_filepath.exists():
            self.load_checkpoint(self._checkpoint_filepath)
        elif cfg.load_model is not None:
            self.load_checkpoint(cfg.load_model, exclude=["replay_loader"])

        self.reward_cls: tp.Optional[_goals.BaseReward] = None
        # if self.cfg.custom_reward == "maze_multi_goal":
        #     self.reward_cls = self._make_custom_reward(seed=self.cfg.seed)

    def _make_env(self) -> dmc.EnvWrapper:
        """Build a gridworld env wrapped in ``dmc.EnvWrapper``.

        The gridworld task name is the second ``_``-separated token of
        ``cfg.task`` (``"grid_simple"`` -> ``"simple"``).
        """
        # NOTE(review): tokens after a second underscore are dropped
        # ("grid_a_b" -> "a"); confirm no gridworld task name contains "_".
        return dmc.EnvWrapper(build_gridworld_task(self.cfg.task.split('_')[1]))

    @property
    def global_frame(self) -> int:
        """Number of environment frames elapsed (steps times action repeat)."""
        return self.global_step * self.cfg.action_repeat

    def _make_custom_reward(self, seed: int) -> tp.Optional[_goals.BaseReward]:
        """Creates a custom reward function if provided in configuration
        else returns None
        """
        if self.cfg.custom_reward is None:
            return None
        return _goals.get_reward_function(self.cfg.custom_reward, seed)

    def _run_eval(self, after_episode: tp.Callable[[], tp.Any]) -> None:
        """Roll out ``num_eval_episodes`` pos/neg-goal episodes and log them.

        ``after_episode`` is called once per episode, right after the
        episode's return is recorded. It is the only difference between
        ``eval`` and ``eval_bf``.

        Logs under ``ty='eval'``:

        - the mean and std of the episode reward;
        - the mean episode length in frames;
        - the placeholder ``num_pos``/``num_neg`` counters (always 0);
        - the aggregated physics stats.
        """
        step, episode = 0, 0
        eval_until_episode = utils.Until(self.cfg.num_eval_episodes)
        physics_agg = dmc.PhysicsAggregator()
        rewards: tp.List[float] = []
        while eval_until_episode(episode):
            time_step = self.eval_env.reset()
            # create custom reward if need be (if field exists)
            seed = 12 * self.cfg.num_eval_episodes + len(rewards)
            custom_reward = self._make_custom_reward(seed=seed)
            # Called for its effect on the agent (infer_w_pos_neg).
            # TODO: the returned inference metrics are not logged yet.
            _init_eval_sm(self, custom_reward)
            total_reward = 0.0
            # only the first episode of each evaluation is recorded
            self.video_recorder.init(self.eval_env, enabled=(episode == 0))
            while not time_step.last():
                with torch.no_grad(), utils.eval_mode(self.agent):
                    action = self.agent.act_pos_neg(time_step.observation,
                                                    self.eval_env.get_goal_obs(),
                                                    self.eval_env.get_neg_goal_obs())
                time_step = self.eval_env.step(action)
                physics_agg.add(self.eval_env)
                self.video_recorder.record(self.eval_env)
                # for legacy reasons, we need to check the name :s
                if custom_reward is not None:
                    time_step.reward = custom_reward.from_env(self.eval_env)
                total_reward += time_step.reward
                step += 1
            rewards.append(total_reward)
            after_episode()
            episode += 1
            self.video_recorder.save(f'{self.global_frame}.mp4')
            if self.cfg.visualize_policy:
                print('Plotting Policy')
                # unique plot index across evaluations; `episode` is 1-based here
                plot_index = (self.global_frame // self.cfg.eval_every_frames) * self.cfg.num_eval_episodes + episode
                self.agent.plot_q_function_pos_neg(self.work_dir, plot_index, self.eval_env,
                                                   self.eval_env.get_goal_obs(),
                                                   self.eval_env.get_neg_goal_obs())

        if not rewards:
            # avoid np.mean([]) (nan) and the division by zero episodes below
            logger.warning("No evaluation episode was run (num_eval_episodes=%s)", self.cfg.num_eval_episodes)
            return
        self.eval_rewards_history.append(float(np.mean(rewards)))

        # Log inference metrics: TODO
        with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
            log('episode_reward', self.eval_rewards_history[-1])
            if len(rewards) > 1:
                log('episode_reward#std', float(np.std(rewards)))
            log('episode_length', step * self.cfg.action_repeat / episode)
            log('episode', self.global_episode)
            log('step', self.global_step)
            # placeholders: positive/negative state counts are not computed anymore
            log('num_pos', 0)
            log('num_neg', 0)
            log('num_episodes', episode)
            for key, val in physics_agg.dump():
                log(key, val)

    def eval_bf(self) -> None:
        """Evaluate the agent, printing the eval env's goal observation after each
        episode (no Bellman-Ford solve), then log the results under ``eval``."""
        self._run_eval(after_episode=lambda: print(self.eval_env.get_goal_obs()))

    def eval(self) -> None:
        """Evaluate the agent, calling ``self.bellman_ford.solve()`` after each
        episode, then log the results under ``eval``."""
        self._run_eval(after_episode=lambda: self.bellman_ford.solve())

    # attributes persisted by save_checkpoint / restorable by load_checkpoint
    _CHECKPOINTED_KEYS = ('agent', 'global_step', 'global_episode', "replay_loader")

    def save_checkpoint(self, fp: tp.Union[Path, str], exclude: tp.Sequence[str] = ()) -> None:
        """Pickle the checkpointed attributes (minus ``exclude``) to ``fp`` with torch.save.

        Parent directories are created as needed.
        """
        logger.info(f"Saving checkpoint to {fp}")
        exclude = list(exclude)
        assert all(x in self._CHECKPOINTED_KEYS for x in exclude)
        fp = Path(fp)
        fp.parent.mkdir(exist_ok=True, parents=True)
        # this is just a dumb security check to not forget about it
        assert isinstance(self.replay_loader, ReplayBuffer), "Is this buffer designed for checkpointing?"
        payload = {k: self.__dict__[k] for k in self._CHECKPOINTED_KEYS if k not in exclude}
        with fp.open('wb') as f:
            torch.save(payload, f, pickle_protocol=4)

    def load_checkpoint(self, fp: tp.Union[Path, str], only: tp.Optional[tp.Sequence[str]] = None, exclude: tp.Sequence[str] = ()) -> None:
        """Reloads a checkpoint or part of it

        A bare pickled ``ReplayBuffer`` is accepted too and treated as
        ``{"replay_loader": buffer}``.

        Parameters
        ----------
        only: None or sequence of str
            reloads only a specific subset (defaults to all)
        exclude: sequence of str
            does not reload the provided keys
        """
        print(f"loading checkpoint from {fp}")
        fp = Path(fp)
        # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
        # On PyTorch >= 2.6, weights_only defaults to True, which may reject the pickled
        # agent / replay buffer; passing weights_only=False may be needed there.
        with fp.open('rb') as f:
            payload = torch.load(f)
        _update_legacy_class(payload, (ReplayBuffer,))
        if isinstance(payload, ReplayBuffer):  # compatibility with pure buffers pickles
            payload = {"replay_loader": payload}
        if only is not None:
            only = list(only)
            assert all(x in self._CHECKPOINTED_KEYS for x in only)
            payload = {x: payload[x] for x in only}
        exclude = list(exclude)
        assert all(x in self._CHECKPOINTED_KEYS for x in exclude)
        for x in exclude:
            payload.pop(x, None)
        for name, val in payload.items():
            logger.info("Reloading %s from %s", name, fp)
            if name == "agent":
                self.agent.init_from(val)
            elif name == "replay_loader":
                _update_legacy_class(val, (ReplayBuffer,))
                assert isinstance(val, ReplayBuffer)
                # pylint: disable=protected-access
                # drop unnecessary meta which could make a mess
                val._current_episode.clear()  # make sure we can start over
                # current config wins over the values stored with the buffer
                val._future = self.cfg.future
                val._discount = self.cfg.discount
                val._max_episodes = len(val._storage["discount"])
                self.replay_loader = val
            else:
                assert hasattr(self, name)
                setattr(self, name, val)
                if name == "global_episode":
                    logger.warning(f"Reloaded agent at global episode {self.global_episode}")


class Workspace(BaseWorkspace[GridworldConfig]):
    """Gridworld workspace.

    It fills the replay buffer with every (state, action) transition of the
    train env, then runs agent updates on it with periodic evaluation.
    """

    def __init__(self, cfg: GridworldConfig) -> None:
        super().__init__(cfg)
        self.train_video_recorder = TrainVideoRecorder(self.work_dir if cfg.save_train_video else None,
                                                       camera_id=self.video_recorder.camera_id, use_wandb=self.cfg.use_wandb)
        # don't reload a replay buffer if there is a checkpoint (it was already restored)
        if not self._checkpoint_filepath.exists() and cfg.load_replay_buffer is not None:
            if self.domain == "d4rl":
                d4rl_replay_buffer_builder = D4RLReplayBufferBuilder()
                self.replay_storage = d4rl_replay_buffer_builder.prepare_replay_buffer_d4rl(self.train_env, self.agent.init_meta(), self.cfg)
                self.replay_loader = self.replay_storage
            else:
                self.load_checkpoint(cfg.load_replay_buffer, only=["replay_loader"])

    def train(self) -> None:
        """Fill the replay buffer exhaustively, then run the update loop.

        For every state in ``train_env.get_state_list()`` and every action in
        ``train_env.get_action_list()``, the step returned by
        ``reset_at_state(state)`` and then the step returned by
        ``get_single_transition(action)`` are added to the replay buffer.
        The agent is then updated once per step while
        ``Until(num_train_frames, action_repeat)`` holds. ``eval()`` runs
        whenever ``Every(eval_every_frames, action_repeat)`` fires.
        """
        # predicates
        train_until_step = utils.Until(self.cfg.num_train_frames,
                                       self.cfg.action_repeat)
        eval_every_step = utils.Every(self.cfg.eval_every_frames,
                                      self.cfg.action_repeat)

        time_step = self.train_env.reset()
        self.train_video_recorder.init(time_step.observation)

        # Add data to replay buffer.
        # NOTE: this happens even if a buffer was reloaded in __init__; the
        # enumerated transitions are then added on top of it.
        state_list = self.train_env.get_state_list()
        action_list = self.train_env.get_action_list()
        for state in state_list:
            for action in action_list:
                self.replay_loader.add(self.train_env.reset_at_state(state))
                self.replay_loader.add(self.train_env.get_single_transition(action))

        while train_until_step(self.global_step):
            # try to evaluate
            if eval_every_step(self.global_step + 1):
                print('Evaluating---------')
                self.logger.log('eval_total_time', self.timer.total_time(),
                                self.global_frame)
                self.eval()
            # self.reward_cls is initialised to None and not reassigned in this file,
            # so the plain update is the path taken unless a caller sets it
            if isinstance(self.agent, agents.GoalTD3Agent) and isinstance(self.reward_cls, _goals.MazeMultiGoal):
                metrics = self.agent.update(self.replay_loader, self.global_step, self.reward_cls)
            else:
                metrics = self.agent.update(self.replay_loader, self.global_step)
            self.logger.log_metrics(metrics, self.global_frame, ty='train')
            # NOTE(review): dumped at global_step while logged at global_frame; these
            # only coincide when action_repeat == 1
            self.logger.dump(self.global_step, 'train')
            self.global_step += 1
            # Periodic and final checkpointing (save_checkpoint) is currently disabled:
            # if not self.global_frame % self.cfg.checkpoint_every:
            #     self.save_checkpoint(self._checkpoint_filepath)
        # self.save_checkpoint(self._checkpoint_filepath)  # make sure we save the final checkpoint


@hydra.main(config_path='.', config_name='base_config', version_base="1.1")
def main(cfg: omgcf.DictConfig) -> None:
    """Hydra entry point: build the gridworld workspace from ``cfg`` and train it.

    ``cfg`` is treated as a ``GridworldConfig``, although Hydra actually
    passes a ``DictConfig``.
    """
    Workspace(cfg).train()  # type: ignore


if __name__ == '__main__':
    # the @hydra.main decorator composes the config and passes it to main(cfg)
    main()
