from bigym.bigym_env import BiGymEnv, CONTROL_FREQUENCY_MAX
from bigym.action_modes import JointPositionActionMode
from robobase.utils import DemoEnv, add_demo_to_replay_buffer
from robobase.envs.utils.bigym_utils import TASK_MAP
import gymnasium as gym
from gymnasium.wrappers import TimeLimit
from robobase.envs.env import EnvFactory
from robobase.envs.utils.bigym_utils import ErrorCalculator
from robobase.envs.wrappers import (
    RescaleFromTanhWithMinMax,
    OnehotTime,
    ActionSequence,
    AppendDemoInfo,
    FrameStack,
    ConcatDim,
    RecedingHorizonControl,
)
from omegaconf import DictConfig
from bigym.utils.observation_config import ObservationConfig, CameraConfig
from bigym.action_modes import PelvisDof
import multiprocessing as mp
import logging
import numpy as np

from demonstrations.demo import DemoStep
from demonstrations.demo_store import DemoStore
from demonstrations.utils import Metadata

from typing import List, Dict, Tuple, Callable
import copy
from tqdm import tqdm

from third_party.demonstrations.build.lib.demonstrations.demo_player import DemoPlayer

# Module-level flag. Nothing in the visible part of this file reads it;
# presumably consumed by external test code — TODO confirm or remove.
UNIT_TEST = False


def rescale_demo_actions(
    rescale_fn: Callable, demos: List[List[DemoStep]], cfg: DictConfig
):
    """Map raw demo actions into the [-1, 1] Tanh range, in place.

    RoboBase assumes everything lives in [-1, 1], so every step whose
    ``info`` dict carries a ``"demo_action"`` entry has that entry replaced
    by ``rescale_fn(info, cfg)``. Steps without the key are left untouched.

    Args:
        rescale_fn: callable invoked as ``rescale_fn(info, cfg)``; its return
            value becomes the new ``info["demo_action"]``.
        demos: demo episodes (each a list of steps) whose actions are raw,
            i.e. not yet scaled.
        cfg: Configs, forwarded unchanged to ``rescale_fn``.

    Returns:
        The same ``demos`` object that was passed in; its step infos have
        been mutated.
    """
    for episode in demos:
        for timestep in episode:
            step_info = timestep.info
            if "demo_action" not in step_info:
                continue
            # Overwrite the raw action with its rescaled counterpart.
            step_info["demo_action"] = rescale_fn(step_info, cfg)
    return demos


def _task_name_to_env_class(task_name: str) -> type[BiGymEnv]:
    """Return the entry registered under ``task_name`` in ``TASK_MAP``.

    Raises:
        KeyError: if ``task_name`` is not a key of ``TASK_MAP``.
    """
    env_class = TASK_MAP[task_name]
    return env_class


class EEFWrapper(gym.ObservationWrapper):
    """Observation wrapper that adds an ``"eef"`` entry to dict observations.

    The advertised observation space is a deep copy of the wrapped env's
    space, extended with an unbounded float64 ``Box`` of shape ``(16,)``
    under the key ``"eef"``.
    NOTE(review): 16 presumably covers position, quaternion and gripper
    state for both arms — confirm against ``ErrorCalculator._read_eef``.
    """

    def __init__(self, env):
        super().__init__(env)
        eef_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(16,), dtype=np.float64
        )
        # Copy first so the wrapped env's own space object is not modified.
        extended_space = copy.deepcopy(env.observation_space)
        extended_space.spaces["eef"] = eef_space
        self.observation_space = extended_space

    def observation(self, observation):
        """Fill in ``observation["eef"]`` from the env when it is missing.

        The incoming dict is mutated and returned; an existing ``"eef"``
        entry is kept as-is.
        """
        if "eef" in observation:
            return observation
        observation["eef"] = ErrorCalculator._read_eef(self.env, return_array=True)
        return observation


def _worker_add_eef(demos, cfg, high_gain):
    """Replay demos in a render-free env and attach EEF readings to each step.

    A fresh BiGym env is built for ``cfg.env.task_name`` with rendering and
    cameras disabled. For every demo the env is reset with the demo's seed,
    each step's ``executed_action`` is applied, and the value returned by
    ``ErrorCalculator._read_eef`` after that action is stored in
    ``step.observation['eef']``.

    Args:
        demos: demo objects exposing ``timesteps`` and ``seed``; every
            timestep exposes ``executed_action`` and a dict ``observation``.
        cfg: Configs; reads ``cfg.env.*``, ``cfg.pixels`` and, optionally,
            ``cfg.block_until_reached``.
        high_gain: forwarded to the env constructor as ``high_gain``.

    Returns:
        The same ``demos`` object, with observations mutated in place.
    """
    bigym_class = _task_name_to_env_class(cfg.env.task_name)

    # Both action-mode variants share these arguments; the only difference is
    # whether the floating DOFs are listed explicitly.
    action_mode_kwargs = dict(
        absolute=cfg.env.action_mode == "absolute",
        block_until_reached=getattr(cfg, "block_until_reached", False),
        floating_base=True,
    )
    if cfg.env.enable_all_floating_dof:
        action_mode_kwargs["floating_dofs"] = [
            PelvisDof.X,
            PelvisDof.Y,
            PelvisDof.Z,
            PelvisDof.RZ,
        ]
    action_mode = JointPositionActionMode(**action_mode_kwargs)

    # turn off rendering for demo playback
    raw_env = bigym_class(
        render_mode=None,  # turn off render
        action_mode=action_mode,
        observation_config=ObservationConfig(
            cameras=[],  # turn off render
            proprioception=True,
            privileged_information=False if cfg.pixels else True,
        ),
        control_frequency=CONTROL_FREQUENCY_MAX // cfg.env.demo_down_sample_rate,
        high_gain=high_gain,
    )

    # Close the env even when replay fails, so a bad demo does not leak it.
    try:
        for demo in demos:
            timesteps = demo.timesteps
            raw_env.reset(seed=demo.seed)
            # "left_position", "left_quat", "left_grip", "right_position", "right_quat", "right_grip"
            # NOTE(review): this reset-time reading is overwritten by the
            # first iteration of the loop below, so timesteps[0] ends up with
            # the post-action EEF — confirm that is the intended alignment.
            timesteps[0].observation['eef'] = ErrorCalculator._read_eef(
                raw_env, return_array=True
            )
            for step in timesteps:
                raw_env.step(step.executed_action)
                step.observation['eef'] = ErrorCalculator._read_eef(
                    raw_env, return_array=True
                )
    finally:
        raw_env.close()
    return demos


class BiGymEnvFactory(EnvFactory):
    """Factory that builds BiGym environments and prepares demonstrations.

    Covers env construction (`_create_env`), the wrapper stack (`_wrap_env`),
    demo loading and statistics (`collect_or_fetch_demos`) and feeding demos
    into a replay buffer (`load_demos_into_replay`).
    """
    # Default for the env's `high_gain` constructor argument; `_create_env`
    # uses it whenever its own `high_gain` parameter is left as None.
    HIGH_GAIN=False
    def _wrap_env(
        self,
        env,
        cfg,
        demo_env=False,
        train=True,
        return_raw_spaces=False,
        chunk_env=False,
        ignore_qvel=False,
    ):
        """Apply the RoboBase wrapper stack to ``env``.

        Wrappers are applied in this order: ``RescaleFromTanhWithMinMax``,
        ``EEFWrapper``, ``ConcatDim``, optionally ``OnehotTime``,
        ``FrameStack`` (skipped for demo envs), ``TimeLimit``, optionally a
        chunking wrapper, and finally ``AppendDemoInfo``.

        Args:
            env: the environment to wrap.
            cfg: Configs; ``cfg.demos`` must be non-zero and
                ``cfg.action_repeat`` must be 1.
            demo_env: when true, ``FrameStack`` is not applied.
            train: with ``chunk_env`` set, selects ``ActionSequence`` (true)
                or ``RecedingHorizonControl`` (false).
            return_raw_spaces: when true, also return deep copies of the
                action and observation spaces taken before any wrapping.
            chunk_env: whether to add a chunking wrapper at all.
            ignore_qvel: forwarded to ``ConcatDim``.

        Returns:
            The wrapped env, or ``(env, action_space, observation_space)``
            when ``return_raw_spaces`` is true.
        """
        assert cfg.demos != 0
        assert cfg.action_repeat == 1

        # Snapshot the spaces before any wrapper has a chance to alter them.
        raw_action_space = copy.deepcopy(env.action_space)
        raw_observation_space = copy.deepcopy(env.observation_space)

        env = RescaleFromTanhWithMinMax(
            env=env,
            action_stats=self._action_stats,
            min_max_margin=cfg.min_max_margin,
        )
        obs_stats = self._obs_stats if cfg.norm_obs else None

        # We normalize the low dimensional observations in the ConcatDim wrapper.
        # This is to be consistent with the original ACT implementation.
        env = EEFWrapper(env)  # add eef in obs
        env = ConcatDim(
            env,
            shape_length=1,
            dim=-1,
            new_name="low_dim_state",
            norm_obs=cfg.norm_obs,
            obs_stats=obs_stats,
            keys_to_ignore=["proprioception_floating_base_actions", "eef"],
            ignore_qvel=ignore_qvel,  # drop qvel from proprioception
            proprio_key="proprioception",
            qpos_dim=30,  # optional; defaults to half the vector
        )
        if cfg.use_onehot_time_and_no_bootstrap:
            env = OnehotTime(env, cfg.env.episode_length)
        if not demo_env:
            env = FrameStack(env, cfg.frame_stack)

        # Episode length expressed in down-sampled control steps.
        horizon = cfg.env.episode_length // cfg.env.demo_down_sample_rate
        env = TimeLimit(env, horizon)

        if chunk_env:
            if train:
                env = ActionSequence(env, cfg.action_sequence)
            else:
                env = RecedingHorizonControl(
                    env,
                    cfg.action_sequence,
                    horizon,
                    cfg.execution_length,
                    temporal_ensemble=cfg.temporal_ensemble,
                    gain=cfg.temporal_ensemble_gain,
                )

        env = AppendDemoInfo(env)

        if not return_raw_spaces:
            return env
        return env, raw_action_space, raw_observation_space

    def _create_env(self, cfg: DictConfig, work_dir: str = None, high_gain=None) -> BiGymEnv:
        bigym_class = _task_name_to_env_class(cfg.env.task_name)
        camera_configs = [
            CameraConfig(
                name=camera_name,
                rgb=True,
                depth=False,
                resolution=cfg.visual_observation_shape,
            )
            for camera_name in cfg.env.cameras
        ]
        if cfg.env.enable_all_floating_dof:
            action_mode = JointPositionActionMode(
                absolute=cfg.env.action_mode == "absolute",
                block_until_reached=getattr(cfg, "block_until_reached", False),
                floating_base=True,
                floating_dofs=[PelvisDof.X, PelvisDof.Y, PelvisDof.Z, PelvisDof.RZ],
            )
        else:
            action_mode = JointPositionActionMode(
                absolute=cfg.env.action_mode == "absolute",
                block_until_reached=getattr(cfg, "block_until_reached", False),
                floating_base=True,
            )

        
        env_high_gain = high_gain if high_gain is not None else self.HIGH_GAIN
        return bigym_class(
            render_mode=cfg.env.render_mode,
            action_mode=action_mode,
            observation_config=ObservationConfig(
                cameras=camera_configs if cfg.pixels else [],
                proprioception=True,
                privileged_information=False if cfg.pixels else True,
            ),
            control_frequency=CONTROL_FREQUENCY_MAX // cfg.env.demo_down_sample_rate,
            high_gain = env_high_gain,
        )

    def make_train_env(self, cfg: DictConfig) -> gym.vector.VectorEnv:
        vec_env_class = gym.vector.SyncVectorEnv
        return vec_env_class(
            [
                lambda: self._wrap_env(
                    self._create_env(cfg),
                    cfg,
                    demo_env=False,
                    train=True,
                    chunk_env=getattr(cfg, "chunk_env", True),
                    ignore_qvel=getattr(cfg, "ignore_qvel", False),
                )
                for _ in range(cfg.num_train_envs)
            ],
        )

    def make_eval_env(self, cfg: DictConfig, work_dir:str) -> gym.Env:
        env, self._action_space, self._observation_space = self._wrap_env(
            env=self._create_env(cfg, work_dir),
            cfg=cfg,
            demo_env=False,
            train=False,
            return_raw_spaces=True,
            chunk_env=getattr(cfg, "chunk_env", True),
            ignore_qvel=getattr(cfg, "ignore_qvel", False),
        )
        return env


    def _get_demo_fn(self, cfg: DictConfig, num_demos: int):
        demos = []

        logging.info("Start to load demos.")
        env = self._create_env(cfg)

        demo_store = DemoStore()
        if np.isinf(num_demos):
            num_demos = -1

        demos = demo_store.get_demos(
            Metadata.from_env(env),
            amount=num_demos,
            frequency=CONTROL_FREQUENCY_MAX // cfg.env.demo_down_sample_rate,
        )

        for demo in demos:
            for ts in demo.timesteps:
                ts.observation = {
                    k: np.array(v, dtype=np.float32) for k, v in ts.observation.items()
                }

        env.close()
        logging.info("Finished loading demos.")

        actuated_names = [mjcf.name for mjcf in env.robot.limb_actuators]
        prop_names = {j.mjcf.name: i for i, j in enumerate(env.robot._joints)}
        self.prop_to_actuated_idx = np.array([prop_names[a_n] for a_n in actuated_names])

        return demos

    def extract_action_from_observation(self, observation):
        prop_qpos = observation['proprioception'][:30]
        actuated_qpos = prop_qpos[np.array(self.prop_to_actuated_idx)]
        fb_qpos = observation['proprioception_floating_base']
        gripper_qpos = observation['proprioception_grippers']
        action = np.concatenate([fb_qpos, actuated_qpos, gripper_qpos])
        return action

    def replace_action_in_observation(self, demos):
        for demo in demos:
            for i in range(len(demo.timesteps)-1):
                next_obs = demo.timesteps[i+1].observation
                action = self.extract_action_from_observation(next_obs)
                demo.timesteps[i].info['demo_action'] = action
        return demos

    def collect_or_fetch_demos(self, cfg: DictConfig, num_demos: int):
        demos = self._get_demo_fn(cfg, num_demos)
        if getattr(cfg, "replace_demo_action", False):
            print("NOTE: Replacing demo actions with those extracted from observations.")
            demos = self.replace_action_in_observation(demos)
            assert(np.all(demos[0].timesteps[0].info['demo_action']==self.extract_action_from_observation(demos[0].timesteps[1].observation)))
        self._raw_demos = demos
        
        if cfg.add_eef:
            self.add_eef_in_demos(cfg, None)

        self._action_stats = self._compute_action_stats(cfg, self._raw_demos)
        self._obs_stats = self._compute_obs_stats(cfg, self._raw_demos)
        return self._raw_demos

    def add_eef_in_demos(self, cfg, work_dir, num_workers=8):
        if num_workers <= 1:
            self._raw_demos = _worker_add_eef(self._raw_demos, cfg, self.HIGH_GAIN)
        else:
            print(f"Adding EEF in demos with {num_workers} workers...")
            from multiprocessing import Pool
            chunk_size = int(np.ceil(len(self._raw_demos) / num_workers))
            chunks = [self._raw_demos[i:i + chunk_size] for i in range(0, len(self._raw_demos), chunk_size)]
            
            with Pool(num_workers) as p:
                results = p.starmap(_worker_add_eef, [(chunk, cfg, self.HIGH_GAIN) for chunk in chunks])
            
            self._raw_demos = [demo for chunk in results for demo in chunk]

    def post_collect_or_fetch_demos(self, cfg: DictConfig, work_dir: str):
        demo_list = [demo.timesteps for demo in self._raw_demos]

        demo_list = rescale_demo_actions(
            self._rescale_demo_action_helper, demo_list, cfg
        )
        self._demos = self._demo_to_steps(cfg, demo_list)
        return self._demos


    def plot_data(self,fig_name, qpos, target_qpos, work_dir=None): 
        import matplotlib.pyplot as plt      
        qpos = np.array(qpos)
        target_qpos = np.array(target_qpos)
        num_dims = qpos.shape[1]  
        timestep = qpos.shape[0]  
        
        cols = 4  
        rows = (num_dims + cols - 1) // cols  

        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4)) 
        axes = axes.flatten()  
        
        for i in range(num_dims):
            axes[i].plot(range(timestep), qpos[:, i], label=f"real qpos {i+1}")
            axes[i].plot(range(timestep), target_qpos[:, i], label=f"target qpos {i+1}")
            axes[i].set_title(f"Dimension {i+1}")
            axes[i].set_xlabel("Timestep")
            axes[i].set_ylabel("Value")
            axes[i].legend()
        
        for j in range(num_dims, len(axes)):
            axes[j].axis("off")
        
        plt.tight_layout()  
        if work_dir is not None:
            save_dir = work_dir/f'plot'
            save_dir.mkdir(parents=True, exist_ok=True) 
            save_path = save_dir / fig_name
            fig.savefig(save_path)
        plt.close(fig)
        
        
    def load_demos_into_replay(self, cfg: DictConfig, buffer, is_demo_buffer, labels = None):
        """See base class for documentation."""
        assert hasattr(self, "_demos"), (
            "There's no _demo attribute inside the factory, "
            "Check `collect_or_fetch_demos` is called before calling this method."
        )
        if is_demo_buffer:
            # Filter successful demonstrations
            demos = []
            for i, demo in enumerate(self._demos):
                successful = (demo[0][-1]["demo"] == 1)
                if successful:
                    demos.append(demo)
                else:
                    print(f"Skipping failed demonstration {i}")
                    continue
        else:
            demos = self._demos
        # check labels and demos match
        if is_demo_buffer and labels is not None:
            for i in range(len(demos)):
                print(len(labels[i]), len(demos[i]))
                assert len(labels[i]) == len(demos[i])
        demo_env = self._wrap_env(
            DemoEnv(
                copy.deepcopy(demos), self._action_space, self._observation_space
            ),
            cfg,
            demo_env=True,
            train=False,
            ignore_qvel=getattr(cfg, "ignore_qvel", False),
        )
        for _ in range(len(demos)):
            add_demo_to_replay_buffer(demo_env, buffer)

    def _demo_to_steps(
        self, cfg: DictConfig, demo_list: List[List[DemoStep]]
    ) -> List[DemoStep]:
        ret_demos = []

        for demo in demo_list:
            cur_demo = []
            last_timestep = False
            
            # Detect whether this demo is successful or not
            rewards = []
            for step in demo:
                reward = step.reward
                rewards.append(reward)
            successful_demo = sum(rewards) > 0.25
            
            for i, step in enumerate(demo):
                step.info.update({"demo": int(successful_demo)})
                if i == 0:
                    cur_demo.append((step.observation, step.info))
                else:
                    term, trunc = step.termination, step.truncation
                    reward = step.reward
                    if i == len(demo) - 1 or reward > 0:
                        if not (term or trunc):
                            term = False
                            trunc = True
                        last_timestep = True

                    cur_demo.append((step.observation, reward, term, trunc, step.info))
                if last_timestep:
                    break
            ret_demos.append(cur_demo)

        return ret_demos

    def _compute_action_stats(
        self, cfg: DictConfig, demos: List[List[DemoStep]]
    ) -> Dict:
        actions = []
        for demo in demos:
            for step in demo.timesteps:
                info = step.info
                if "demo_action" in info:
                    actions.append(info["demo_action"])
        actions = np.stack(actions)

        mean, std, gmax, gmin = self._get_gripper_action_stats(cfg)
        action_mean = np.hstack([np.mean(actions, 0)[:-2], mean, mean])
        action_std = np.hstack([np.std(actions, 0)[:-2], std, std])
        action_max = np.hstack([np.max(actions, 0)[:-2], gmax, gmax])
        action_min = np.hstack([np.min(actions, 0)[:-2], gmin, gmin])
        action_stats = {
            "mean": action_mean,
            "std": action_std,
            "max": action_max,
            "min": action_min,
        }
        return action_stats

    def _compute_obs_stats(self, cfg: DictConfig, demos: List[List[DemoStep]]) -> Dict:
        import torch
        # Using torch to accelerate calculation
        obs = []
        for demo in demos:
            for step in demo.timesteps:
                obs.append(step.observation)
        
        keys = obs[0].keys()
        obs_torch = {key: torch.stack([torch.tensor(o[key]) for o in obs], dim=0) for key in keys}

        obs_mean = {key: torch.mean(obs_torch[key], dim=0) for key in keys}
        obs_std = {key: torch.std(obs_torch[key], dim=0) for key in keys}
        obs_min = {key: torch.min(obs_torch[key], dim=0).values for key in keys}
        obs_max = {key: torch.max(obs_torch[key], dim=0).values for key in keys}
        
        obs_mean_np = {key: value.numpy() for key, value in obs_mean.items()}
        obs_std_np = {key: value.numpy() for key, value in obs_std.items()}
        obs_min_np = {key: value.numpy() for key, value in obs_min.items()}
        obs_max_np = {key: value.numpy() for key, value in obs_max.items()}
        
        obs_stats = {
            "mean": obs_mean_np,
            "std": obs_std_np,
            "max": obs_max_np,
            "min": obs_min_np,
        }
        
        return obs_stats

    def _get_gripper_action_stats(
        self, cfg: DictConfig
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        if cfg.env.action_mode in ["absolute", "delta"]:
            return (0.5, 0.25, 1, 0)
        else:
            raise NotImplementedError("Unsupported action mode.")

    def _rescale_demo_action_helper(self, info, cfg: DictConfig):
        return RescaleFromTanhWithMinMax.transform_to_tanh(
            info["demo_action"],
            action_stats=self._action_stats,
            min_max_margin=cfg.min_max_margin,
        )
