# env_utils.py
import collections
import re
import time

import gymnasium
import numpy as np
import ogbench
from gymnasium.spaces import Box

from utils.datasets import Dataset


class EpisodeMonitor(gymnasium.Wrapper):
    """Gymnasium wrapper that records per-episode statistics in the step ``info``.

    On every step it:
      * deletes ``info`` entries whose key matches any of ``filter_regexes``
        (matched with ``re.match``, i.e. anchored at the start of the key);
      * stores the running step count across all episodes under
        ``info['total']['timesteps']``.

    On the step that ends an episode (``terminated or truncated``) it also adds
    ``info['episode']`` with the last step's reward, the accumulated return,
    the episode length and the wall-clock duration since the last reset. If the
    unwrapped env exposes ``get_normalized_score``, the return normalized by it
    and multiplied by 100 is reported as ``normalized_return``.
    """

    def __init__(self, env, filter_regexes=None):
        super().__init__(env)
        self._reset_stats()
        # Counts steps over the wrapper's whole lifetime; never reset.
        self.total_timesteps = 0
        self.filter_regexes = [] if filter_regexes is None else filter_regexes

    def _reset_stats(self):
        """Zero the per-episode accumulators and restart the episode clock."""
        self.reward_sum, self.episode_length, self.start_time = 0.0, 0, time.time()

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)

        # Strip unwanted info entries, one pattern at a time.
        for pattern in self.filter_regexes:
            doomed_keys = [key for key in info if re.match(pattern, key) is not None]
            for key in doomed_keys:
                del info[key]

        self.reward_sum += reward
        self.episode_length += 1
        self.total_timesteps += 1
        info['total'] = {'timesteps': self.total_timesteps}

        episode_over = terminated or truncated
        if not episode_over:
            return observation, reward, terminated, truncated, info

        episode_stats = {
            'final_reward': reward,
            'return': self.reward_sum,
            'length': self.episode_length,
            'duration': time.time() - self.start_time,
        }
        info['episode'] = episode_stats

        if hasattr(self.unwrapped, 'get_normalized_score'):
            normalized = self.unwrapped.get_normalized_score(episode_stats['return'])
            episode_stats['normalized_return'] = normalized * 100.0

        return observation, reward, terminated, truncated, info

    def reset(self, *args, **kwargs):
        """Reset the per-episode statistics, then reset the wrapped env."""
        self._reset_stats()
        return self.env.reset(*args, **kwargs)


class FrameStackWrapper(gymnasium.Wrapper):
    """Gymnasium wrapper that concatenates the last ``num_stack`` observations.

    Frames are concatenated along the last axis, and the observation space's
    ``low``/``high`` bounds are tiled the same way. On reset the first
    observation is repeated ``num_stack`` times. If the reset ``info`` carries a
    ``'goal'``, it is tiled along the last axis to match.
    """

    def __init__(self, env, num_stack):
        super().__init__(env)

        self.num_stack = num_stack
        # Bounded deque: appending beyond num_stack drops the oldest frame.
        self.frames = collections.deque(maxlen=num_stack)

        base_space = self.observation_space
        self.observation_space = Box(
            low=np.concatenate([base_space.low for _ in range(num_stack)], axis=-1),
            high=np.concatenate([base_space.high for _ in range(num_stack)], axis=-1),
            dtype=base_space.dtype,
        )

    def get_observation(self):
        """Return the stacked frames as a single array (oldest first)."""
        assert len(self.frames) == self.num_stack
        return np.concatenate(tuple(self.frames), axis=-1)

    def reset(self, **kwargs):
        first_ob, reset_info = self.env.reset(**kwargs)
        # Fill the whole window with the initial observation.
        self.frames.extend([first_ob] * self.num_stack)
        if 'goal' in reset_info:
            reset_info['goal'] = np.concatenate(
                [reset_info['goal']] * self.num_stack, axis=-1
            )
        return self.get_observation(), reset_info

    def step(self, action):
        next_ob, rew, term, trunc, step_info = self.env.step(action)
        self.frames.append(next_ob)
        stacked = self.get_observation()
        return stacked, rew, term, trunc, step_info


def train_val_split(dataset, val_ratio=0.1, seed=0):
    """Randomly split a dictionary-like dataset into train and validation dicts.

    One random permutation of the row indices, drawn from
    ``np.random.default_rng(seed)``, is shared by every field. Rows of
    different fields therefore stay aligned. The first ``int(N * val_ratio)``
    permuted indices form the validation split and the rest form the
    training split.

    NOTE(review): because the indices are shuffled, neither split preserves
    the original transition order. Confirm that downstream code does not rely
    on contiguous trajectories.

    Args:
        dataset: Mapping from field name (e.g. 'observations', 'actions') to an
            array-like indexable by an integer index array. Every field must
            have the same leading length as ``dataset['observations']``.
        val_ratio: Fraction of rows to put in the validation split, in [0, 1].
        seed: Seed for the permutation.

    Returns:
        A tuple ``(train_dict, val_dict)`` of plain dicts with the same keys as
        ``dataset``.

    Raises:
        ValueError: If ``val_ratio`` is outside [0, 1] (or NaN), or if a field's
            leading length differs from that of ``dataset['observations']``.
    """
    # Negative ratios would silently slice from the end and ratios > 1 would
    # leave the training set empty, so reject them explicitly.
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f'val_ratio must be in [0, 1], got {val_ratio}')

    num_rows = dataset['observations'].shape[0]
    val_size = int(num_rows * val_ratio)

    rng = np.random.default_rng(seed)
    indices = rng.permutation(num_rows)
    val_indices = indices[:val_size]
    train_indices = indices[val_size:]

    train_dict = {}
    val_dict = {}
    for key, arr in dataset.items():
        # A length mismatch would otherwise misalign rows silently (longer
        # arrays) or fail with an opaque IndexError (shorter arrays).
        if len(arr) != num_rows:
            raise ValueError(
                f"Field '{key}' has length {len(arr)}, expected {num_rows} "
                "(the length of 'observations')"
            )
        train_dict[key] = arr[train_indices]
        val_dict[key] = arr[val_indices]

    return train_dict, val_dict


def _make_d4rl_env_and_datasets(env_name, val_ratio, seed):
    """Build a D4RL env pair and a random train/val split of its dataset.

    Returns:
        A tuple of (env, eval_env, train_dataset, val_dataset).
    """
    from envs import d4rl_utils

    env = d4rl_utils.make_env(env_name)
    eval_env = d4rl_utils.make_env(env_name)

    dataset = d4rl_utils.get_dataset(env, env_name)
    train_dict, val_dict = train_val_split(dataset, val_ratio=val_ratio, seed=seed)
    return env, eval_env, Dataset.create(**train_dict), Dataset.create(**val_dict)


def _clip_dataset_actions(dataset, eps):
    """Return a copy of ``dataset`` with actions clipped to [-1 + eps, 1 - eps]."""
    clipped = np.clip(dataset['actions'], -1 + eps, 1 - eps)
    return dataset.copy(add_or_replace=dict(actions=clipped))


def make_env_and_datasets(
    env_name,
    frame_stack=None,
    action_clip_eps=1e-5,
    val_ratio=0.1,
    split_seed=0,
):
    """Make offline RL environment and datasets.

    ``env_name`` is matched by substring against these families, in order:

    * ``'singletask'``: OGBench tasks, which ship their own train/val split.
    * ``'antmaze'`` together with ``'diverse'``, ``'play'`` or ``'umaze'``:
      D4RL AntMaze.
    * ``'pen'``, ``'hammer'``, ``'relocate'`` or ``'door'``: D4RL Adroit.

    For the D4RL families the loaded dataset is split randomly with
    ``train_val_split``.

    Args:
        env_name: Name of the environment or dataset.
        frame_stack: Number of frames to stack, or None to disable stacking.
        action_clip_eps: Dataset actions (train and val) are clipped to
            ``[-1 + action_clip_eps, 1 - action_clip_eps]``. None disables
            clipping.
        val_ratio: Fraction of transitions placed in the validation split for
            the D4RL families. It is ignored for OGBench, which has a built-in
            split.
        split_seed: Seed for the D4RL train/val split.

    Returns:
        A tuple of (env, eval_env, train_dataset, val_dataset).

    Raises:
        ValueError: If ``env_name`` matches none of the supported families.
    """
    if 'singletask' in env_name:
        # OGBench tasks -- these have built-in train/val splits.
        env, train_dataset, val_dataset = ogbench.make_env_and_datasets(env_name)
        eval_env = ogbench.make_env_and_datasets(env_name, env_only=True)

        # Drop privileged / proprioceptive entries from the step info dicts.
        info_filters = ('.*privileged.*', '.*proprio.*')
        env = EpisodeMonitor(env, filter_regexes=list(info_filters))
        eval_env = EpisodeMonitor(eval_env, filter_regexes=list(info_filters))

        # Wrap them as Dataset objects.
        train_dataset = Dataset.create(**train_dataset)
        val_dataset = Dataset.create(**val_dataset)

    elif 'antmaze' in env_name and any(
        tag in env_name for tag in ('diverse', 'play', 'umaze')
    ):
        # D4RL AntMaze.
        env, eval_env, train_dataset, val_dataset = _make_d4rl_env_and_datasets(
            env_name, val_ratio, split_seed
        )

    elif any(task in env_name for task in ('pen', 'hammer', 'relocate', 'door')):
        # D4RL Adroit. The hand-manipulation suite is imported only for its side
        # effects (presumably registering the Adroit envs -- confirm).
        import d4rl.hand_manipulation_suite  # noqa: F401

        env, eval_env, train_dataset, val_dataset = _make_d4rl_env_and_datasets(
            env_name, val_ratio, split_seed
        )

    else:
        raise ValueError(f'Unsupported environment: {env_name}')

    # Optionally frame-stack both envs.
    if frame_stack is not None:
        env = FrameStackWrapper(env, frame_stack)
        eval_env = FrameStackWrapper(eval_env, frame_stack)

    # Do an initial reset so they are ready.
    env.reset()
    eval_env.reset()

    # Clip dataset actions if requested.
    if action_clip_eps is not None:
        train_dataset = _clip_dataset_actions(train_dataset, action_clip_eps)
        if val_dataset is not None:
            val_dataset = _clip_dataset_actions(val_dataset, action_clip_eps)

    return env, eval_env, train_dataset, val_dataset
