# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import json
import pdb  # pylint: disable=unused-import
import logging
import dataclasses
import typing as tp
import warnings
from pathlib import Path

# Suppress every DeprecationWarning for the whole process (presumably noise from
# third-party libraries imported below).
warnings.filterwarnings('ignore', category=DeprecationWarning)


# These environment variables are set before numpy/torch/MuJoCo are imported
# further down, so the libraries see them at import time.
# NOTE(review): MKL_SERVICE_FORCE_INTEL=1 is commonly used to work around the
# "MKL_THREADING_LAYER=INTEL is incompatible with libgomp" error — confirm.
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
# if the default egl does not work, you may want to try:
# export MUJOCO_GL=glfw
os.environ['MUJOCO_GL'] = os.environ.get('MUJOCO_GL', 'egl')

import hydra
from hydra.core.config_store import ConfigStore
import numpy as np
import torch
import wandb
import omegaconf as omgcf
# from dm_env import specs

from url_benchmark import dmc
from dm_env import specs
from url_benchmark import utils
from url_benchmark import goals as _goals
from url_benchmark.logger import Logger, InferenceLogger
from url_benchmark.in_memory_replay_buffer_psm import ReplayBuffer
from url_benchmark.video import TrainVideoRecorder, VideoRecorder
from url_benchmark import agent as agents
from url_benchmark.gridworld.env import build_gridworld_task
from url_benchmark.evaluate import evaluate

logger = logging.getLogger(__name__)
# Let cuDNN benchmark and pick the fastest conv algorithms; this trades
# run-to-run determinism for speed.
torch.backends.cudnn.benchmark = True
# os.environ['WANDB_MODE']='offline'

# from url_benchmark.dmc_benchmark import PRIMAL_TASKS


# # # Config # # #

@dataclasses.dataclass
class Config:
    """Base structured config shared by the workspaces in this file.

    Fields are read through ``self.cfg`` by ``BaseWorkspace`` and its
    subclasses. Some fields (e.g. ``visualize_policy``, ``goal_space``,
    ``final_tests``, ``update_encoder``) are not read anywhere in this file.
    """
    agent: tp.Any  # agent config node, passed to hydra.utils.instantiate by make_agent
    # misc
    seed: int = 1
    device: str = "cuda"  # BaseWorkspace falls back to "cpu" when CUDA is unavailable
    save_video: bool = False  # enables VideoRecorder output under the work dir
    use_tb: bool = False
    use_wandb: bool = False
    use_hiplog: bool = False
    visualize_policy: bool = True
    # experiment
    experiment: str = "online"  # prefix of the wandb run name
    # task settings
    task: str = "grid_simple"  # "<domain>_<gridworld task>"; split on '_' by BaseWorkspace
    obs_type: str = "states"  # [states, pixels]
    frame_stack: int = 3  # only works if obs_type=pixels
    action_repeat: int = 1  # set to 2 for pixels
    discount: float = 0.99  # passed to the ReplayBuffer
    future: float = 0.99  # discount of future sampling, future=1 means no future sampling
    goal_space: tp.Optional[str] = None
    append_goal_to_observation: bool = False
    # eval
    num_eval_episodes: int = 10
    custom_reward: tp.Optional[str] = None  # activates custom eval if not None
    final_tests: int = 10
    # checkpoint (checkpoint save/reload is currently commented out in this file)
    snapshot_at: tp.Tuple[int, ...] = (10000, 20000, 50000, 800000, 1000000, 1500000,
                                       2000000, 3000000, 4000000, 5000000, 9000000, 10000000)
    checkpoint_every: int = 40000
    load_model: tp.Optional[str] = None
    # training
    num_seed_frames: int = 4000  # divided by action_repeat to give the agent's num_expl_steps
    replay_buffer_episodes: int = 5000  # ReplayBuffer max_episodes
    update_encoder: bool = True
    # OmegaConf interpolation: resolved from agent.batch_size when the config is composed
    batch_size: int = omgcf.II("agent.batch_size")


@dataclasses.dataclass
class GridworldConfig(Config):
    """Config for the gridworld ``Workspace``: adds training length and eval cadence."""
    # mode
    reward_free: bool = True
    # train settings
    num_train_frames: int = 130000  # training stops once this many frames are reached
    # evaluation cadence (in frames) and optional replay buffer to load
    eval_every_frames: int = 40000
    load_replay_buffer: tp.Optional[str] = None
    # replay buffer
    # replay_buffer_num_workers: int = 4
    # nstep: int = omgcf.II("agent.nstep")
    # misc
    save_train_video: bool = False  # enables TrainVideoRecorder output under the work dir


# Register GridworldConfig as the structured-config node "workspace_config" so
# a Hydra YAML can pull it in (e.g. from its defaults list).
# NOTE(review): this was previously described as "loaded as base_pretrain in
# pretrain.yaml", but main() loads config_name='base_config_final' — confirm
# which YAML actually references this node.
# We keep the YAML since it's easier to configure plugins from it.
ConfigStore.instance().store(name="workspace_config", node=GridworldConfig)


# # # Implem # # #


def make_agent(
    obs_type: str, obs_spec, action_spec, num_expl_steps: int, cfg: omgcf.DictConfig
) -> tp.Union[agents.DiscretePSMAgent]:
    """Complete the agent config from the env specs and instantiate the agent.

    ``cfg`` is mutated in place: ``obs_type``, ``obs_shape``, ``action_shape``
    and ``num_expl_steps`` are written onto it before it is handed to
    ``hydra.utils.instantiate``.

    A ``specs.DiscreteArray`` action spec is described by ``(num_values,)``;
    any other spec contributes its ``shape`` unchanged.
    """
    cfg.obs_type = obs_type
    cfg.obs_shape = obs_spec.shape
    if isinstance(action_spec, specs.DiscreteArray):
        shape = (action_spec.num_values, )
    else:
        shape = action_spec.shape
    cfg.action_shape = shape
    cfg.num_expl_steps = num_expl_steps
    agent = hydra.utils.instantiate(cfg)
    return agent


# Type parameter for BaseWorkspace's config; bounded so any Config subclass fits.
C = tp.TypeVar("C", bound=Config)


def _update_legacy_class(obj: tp.Any, classes: tp.Sequence[tp.Type[tp.Any]]) -> tp.Any:
    """Updates a legacy class (eg: agent.FBDDPGAgent) to the new
    class (url_benchmark.agent.FBDDPGAgent)

    Parameters
    ----------
    obj: Any
        Object to update
    classes: Types
        Possible classes to update the object to. If current name is one of the classes
        name, the object class will be remapped to it.

    Returns
    -------
    Any
        ``obj`` itself, with its ``__class__`` possibly reassigned in place.
        It is returned unchanged when it already is an instance of one of
        ``classes`` or when no class in ``classes`` shares its class name.
    """
    classes = tuple(classes)
    if isinstance(obj, classes):
        return obj
    by_name = {klass.__name__: klass for klass in classes}
    target = by_name.get(obj.__class__.__name__)
    if target is not None:
        logger.warning("Promoting legacy object %s to %s", obj.__class__, target)
        # In-place class swap: the instance keeps its __dict__, only the type changes.
        obj.__class__ = target
    return obj


# def _init_eval_sm(workspace: "BaseWorkspace", custom_reward: tp.Optional[_goals.BaseReward] = None) -> agents.MetaDict:
#     assert isinstance(workspace.agent, agents.DiscretePSMAgent)
#     # metrics = workspace.agent.infer_w(workspace.replay_loader, workspace.inf_logger, workspace.eval_env.get_goal_obs())
#     metrics = workspace.agent.infer_w_pos_neg(workspace.replay_loader, workspace.inf_logger, workspace.eval_env.get_goal_obs(), workspace.eval_env.get_neg_goal_obs())
#     return metrics


class BaseWorkspace(tp.Generic[C]):
    """Shared setup for the gridworld workspaces.

    Builds the train/eval environments, the agent, a Bellman-Ford reference
    agent, the loggers, the replay buffer and the video recorder. Everything
    is rooted at the current working directory.
    """

    def __init__(self, cfg: C) -> None:
        """Build every run component from ``cfg``.

        Parameters
        ----------
        cfg: Config
            Composed run config. When CUDA is unavailable and ``cfg.device``
            is not "cpu", both ``cfg.device`` and ``cfg.agent.device`` are
            overwritten with "cpu".
        """
        self.work_dir = Path.cwd()
        print(f'Workspace: {self.work_dir}')
        print(f'Running code in : {Path(__file__).parent.resolve().absolute()}')
        logger.info(f'Workspace: {self.work_dir}')
        logger.info(f'Running code in : {Path(__file__).parent.resolve().absolute()}')

        self.cfg = cfg
        utils.set_seed_everywhere(cfg.seed)
        if not torch.cuda.is_available():
            if cfg.device != "cpu":
                logger.warning(f"Falling back to cpu as {cfg.device} is not available")
                cfg.device = "cpu"
                cfg.agent.device = "cpu"
        self.device = torch.device(cfg.device)

        # create envs; the domain is the task prefix before the first '_' (e.g. "grid")
        task = cfg.task

        self.domain = task.split('_', maxsplit=1)[0]

        self.train_env = self._make_env()
        self.eval_env = self._make_env()
        # Only train_env loads the evaluation set; evaluation in Workspace.train
        # therefore runs on train_env.
        self.train_env.load_evaluation_set('new_eval_set.pkl')
        # create agent
        self.agent = make_agent(cfg.obs_type,
                                self.train_env.observation_spec(),
                                self.train_env.action_spec(),
                                cfg.num_seed_frames // cfg.action_repeat,
                                cfg.agent)

        self.bellman_ford = agents.BellmanFordAgent(self.eval_env)

        # create logger
        self.logger = Logger(self.work_dir,
                             use_tb=cfg.use_tb,
                             use_wandb=cfg.use_wandb,
                             use_hiplog=cfg.use_hiplog)

        self.inf_logger = InferenceLogger(self.work_dir,
                                          use_tb=cfg.use_tb,
                                          use_wandb=cfg.use_wandb,
                                          use_hiplog=cfg.use_hiplog)

        if cfg.use_wandb:
            exp_name = '_'.join([
                cfg.experiment, cfg.agent.name, self.domain
            ])
            wandb.init(project="controllable_agent", group=cfg.agent.name, name=exp_name,  # mode="disabled",
                       config=omgcf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True))  # type: ignore

            # Each metric family is plotted against its own frame counter.
            wandb.define_metric('train/frame')
            wandb.define_metric('eval/frame')
            wandb.define_metric('inf/frame')

            wandb.define_metric('train/*', step_metric='train/frame')
            wandb.define_metric('eval/*', step_metric='eval/frame')
            wandb.define_metric('inf/*', step_metric='inf/frame')

        if cfg.use_hiplog:
            # record config now that it is filled
            parts = ("snapshot", "_type", "_shape", "num_", "save_", "frame", "device", "use_tb", "use_wandb")
            skipped = [x for x in cfg if any(y in x for y in parts)]  # type: ignore
            self.logger.hiplog.flattened({x: y for x, y in cfg.items() if x not in skipped})  # type: ignore
            self.logger.hiplog(workdir=self.work_dir.stem)
            for rm in ("agent/use_tb", "agent/use_wandb", "agent/device"):
                del self.logger.hiplog._content[rm]
            self.logger.hiplog(observation_size=np.prod(self.train_env.observation_spec().shape))

        self.replay_loader = ReplayBuffer(max_episodes=cfg.replay_buffer_episodes, discount=cfg.discount, future=cfg.future)

        cam_id = 2

        self.video_recorder = VideoRecorder(self.work_dir if cfg.save_video else None,
                                            camera_id=cam_id, use_wandb=self.cfg.use_wandb)

        self.timer = utils.Timer()
        self.global_step = 0
        self.global_episode = 0
        self.eval_rewards_history: tp.List[float] = []
        self._checkpoint_filepath = self.work_dir / "models" / "latest.pt"
        # NOTE: checkpoint reloading (from _checkpoint_filepath or cfg.load_model)
        # and custom-reward construction are currently disabled.

        self.reward_cls: tp.Optional[_goals.BaseReward] = None

    def _make_env(self) -> dmc.EnvWrapper:
        """Build a wrapped gridworld env for the task named after the first '_' in cfg.task."""
        # NOTE(review): split('_')[1] keeps only the second '_'-separated token
        # (e.g. "grid_four_rooms" -> "four"), unlike the maxsplit=1 split used for
        # the domain; confirm this matches build_gridworld_task's task names.
        return dmc.EnvWrapper(build_gridworld_task(self.cfg.task.split('_')[1]))

    @property
    def global_frame(self) -> int:
        """Number of env frames so far: global_step times action_repeat."""
        return self.global_step * self.cfg.action_repeat

    def _make_custom_reward(self, seed: int) -> tp.Optional[_goals.BaseReward]:
        """Creates a custom reward function if provided in configuration
        else returns None
        """
        if self.cfg.custom_reward is None:
            return None
        return _goals.get_reward_function(self.cfg.custom_reward, seed)

    def eval_bf(self) -> None:
        """Run ``cfg.num_eval_episodes`` 'test' evaluations on eval_env and log the counts.

        Logs the summed positive/negative counts and the number of episodes
        under the 'eval' logger type, plus whatever the physics aggregator dumps.
        """
        episode = 0
        eval_until_episode = utils.Until(self.cfg.num_eval_episodes)
        physics_agg = dmc.PhysicsAggregator()
        # NOTE(review): nothing ever appends to ``rewards``; the history entry
        # recorded below is therefore always NaN.
        rewards: tp.List[float] = []
        total_pos = 0
        total_neg = 0
        while eval_until_episode(episode):
            self.video_recorder.init(self.eval_env, enabled=(episode == 0))
            # NOTE(review): this call passes ``self.reward_cls`` before
            # ``self.work_dir`` and unpacks two values, whereas Workspace.train calls
            # evaluate(env, agent, loader, inf_logger, mode, work_dir, eval_idx, i)
            # and unpacks three. One call site is stale — confirm against
            # url_benchmark.evaluate before relying on this method.
            num_pos, num_neg = evaluate(self.eval_env, self.agent, self.replay_loader, self.inf_logger, 'test', self.reward_cls, self.work_dir, (self.global_frame // self.cfg.eval_every_frames) * self.cfg.num_eval_episodes + episode)
            total_pos += num_pos
            total_neg += num_neg
            # Advance the episode counter; without it the Until predicate never
            # becomes false and this loop does not terminate.
            episode += 1
        # Same value as np.mean([]) (NaN) without numpy's empty-slice RuntimeWarning.
        self.eval_rewards_history.append(float(np.mean(rewards)) if rewards else float('nan'))

        # Log inference metrics: TODO
        # self.logger.log_metrics(inference_metrics, self.global_frame, ty='inf')
        with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
            log('step', self.global_step)
            log('num_pos', total_pos)
            log('num_neg', total_neg)
            log('num_episodes', episode)
            for key, val in physics_agg.dump():
                log(key, val)


class Workspace(BaseWorkspace[GridworldConfig]):
    """Gridworld workspace.

    Fills the replay buffer with every (state, action) transition, then runs
    agent updates until ``num_train_frames``. It periodically evaluates in
    'goal' and 'rni' modes and logs success rates and inference time.
    """

    # Number of evaluate() calls per mode ('goal' and 'rni') at each evaluation point.
    _NUM_EVAL_ROUNDS = 10

    def __init__(self, cfg: GridworldConfig) -> None:
        super().__init__(cfg)
        self.train_video_recorder = TrainVideoRecorder(self.work_dir if cfg.save_train_video else None,
                                                       camera_id=self.video_recorder.camera_id, use_wandb=self.cfg.use_wandb)

    @staticmethod
    def _success_rate(num_pos: float, num_neg: float) -> float:
        """Return num_pos / (num_pos + num_neg), or 0.0 when no trials were counted."""
        total = num_pos + num_neg
        return num_pos / total if total else 0.0

    def _fill_replay_buffer(self) -> None:
        """Add a (reset-at-state, single transition) pair for every state/action combination."""
        state_list = self.train_env.get_state_list()
        action_list = self.train_env.get_action_list()
        for state in state_list:
            for action in action_list:
                self.replay_loader.add(self.train_env.reset_at_state(state))
                self.replay_loader.add(self.train_env.get_single_transition(action))

    def _evaluate_mode(self, mode: str) -> tp.Tuple[float, float, float]:
        """Call evaluate() ``_NUM_EVAL_ROUNDS`` times in ``mode``.

        Returns the summed (num_pos, num_neg, inference_time) over the rounds.
        Evaluation runs on train_env, the env that loaded the evaluation set.
        """
        eval_idx = self.global_frame // self.cfg.eval_every_frames
        pos_sum, neg_sum, time_sum = 0.0, 0.0, 0.0
        for round_idx in range(self._NUM_EVAL_ROUNDS):
            num_pos, num_neg, inf_time = evaluate(self.train_env, self.agent, self.replay_loader, self.inf_logger,
                                                  mode, self.work_dir, eval_idx, round_idx)
            pos_sum += num_pos
            neg_sum += num_neg
            time_sum += inf_time
        return pos_sum, neg_sum, time_sum

    def _evaluate(self) -> None:
        """Run 'goal' and 'rni' evaluations and log their counts, success rates and mean time."""
        print('Evaluating---------')
        self.logger.log('eval_total_time', self.timer.total_time(),
                        self.global_frame)
        gc_pos, gc_neg, gc_time = self._evaluate_mode('goal')
        rni_pos, rni_neg, rni_time = self._evaluate_mode('rni')
        # Mean inference time over all evaluate() calls of both modes.
        avg_time = (gc_time + rni_time) / (2 * self._NUM_EVAL_ROUNDS)

        with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
            log('step', self.global_step)
            log('gc_pos', gc_pos)
            log('gc_neg', gc_neg)
            log('gc_success_rate', self._success_rate(gc_pos, gc_neg))
            log('rni_pos', rni_pos)
            log('rni_neg', rni_neg)
            log('rni_success_rate', self._success_rate(rni_pos, rni_neg))
            log('inference_time', avg_time)

    def train(self) -> None:
        """Fill the replay buffer, then update the agent until num_train_frames.

        Evaluation is triggered by ``utils.Every(eval_every_frames)`` checked on
        ``global_step + 1``. Train metrics are logged and dumped every step.
        """
        # predicates
        train_until_step = utils.Until(self.cfg.num_train_frames,
                                       self.cfg.action_repeat)
        eval_every_step = utils.Every(self.cfg.eval_every_frames,
                                      self.cfg.action_repeat)

        time_step = self.train_env.reset()
        self.train_video_recorder.init(time_step.observation)

        self._fill_replay_buffer()

        while train_until_step(self.global_step):
            if eval_every_step(self.global_step + 1):
                self._evaluate()
            metrics = self.agent.update(self.replay_loader, self.global_step)
            self.logger.log_metrics(metrics, self.global_frame, ty='train')
            # NOTE(review): dump uses global_step while log_metrics uses
            # global_frame; these agree only when action_repeat == 1.
            self.logger.dump(self.global_step, 'train')
            self.global_step += 1
        # NOTE: periodic/final checkpoint saving is currently disabled.


@hydra.main(config_path='.', config_name='base_config_final', version_base="1.1")
def main(cfg: omgcf.DictConfig) -> None:
    """Hydra entry point: build a Workspace from the composed config and train it.

    Hydra passes a DictConfig. Workspace is annotated for GridworldConfig, so the
    mismatch is silenced for the type checker.
    """
    Workspace(cfg).train()  # type: ignore


if __name__ == '__main__':
    main()
