from functools import partial
from typing import Optional

import numpy as onp
import jax
import jax.numpy as jnp
import gym
import d4rl

from opelab.core.baselines.pgd.util import *


# dataset
def load_dataset(args, normalize, val_split=0.0):
    """Load the train/validation trajectories and optionally normalize them.

    Args:
        args: Configuration object forwarded to ``_load_dataset``.
        normalize: If truthy, actions are clipped to ``[-0.999, 0.999]``,
            passed through ``arctanh`` and every field is normalized with
            statistics computed on the training split.
        val_split: Fraction of episodes forwarded to ``_load_dataset`` for
            the validation split.

    Returns:
        ``(trajs, val_trajs, stats, (obs_dim, num_actions))`` when
        ``normalize`` is truthy, otherwise
        ``(trajs, val_trajs, (obs_dim, num_actions))``. ``val_trajs`` is
        whatever ``_load_dataset`` returned for the validation split
        (possibly ``None``).
    """

    def _unsquash_actions(batch):
        # Clip before arctanh so actions at exactly +/-1 stay finite.
        clipped = jnp.clip(batch.action, -0.999, 0.999)
        return batch._replace(action=jnp.arctanh(clipped))

    trajs, val_trajs = _load_dataset(args, val_split=val_split)

    if not normalize:
        dims = (trajs.obs.shape[-1], trajs.action.shape[-1])
        return trajs, val_trajs, dims

    trajs, trajectory_norm_stats = _normalize_dataset(_unsquash_actions(trajs))
    if val_trajs is not None:
        # Validation data reuses the training statistics.
        val_trajs = _normalize_from_stats(
            _unsquash_actions(val_trajs), trajectory_norm_stats
        )
    dims = (trajs.obs.shape[-1], trajs.action.shape[-1])
    return trajs, val_trajs, trajectory_norm_stats, dims


def _load_d4rl_data(args):
    """Load a D4RL dataset as JAX arrays and split it into episodes.

    The final transition of the raw dataset is dropped so that every kept
    step has a ``next_observations`` entry taken from the following row.
    A ``done`` flag (terminal OR timeout) is added and used to cut the flat
    arrays into episodes.

    Args:
        args: Object exposing ``dataset_name``, the id passed to ``gym.make``.

    Returns:
        A list with one dict per episode. Each dict maps ``observations``,
        ``actions``, ``rewards``, ``terminals``, ``timeouts``,
        ``next_observations`` and ``done`` to that episode's array slice.
    """
    # --- Load data and convert to Jax Numpy ---
    dataset = gym.make(args.dataset_name).get_dataset()
    trajs = {
        attr: dataset[attr][:-1]
        for attr in ["observations", "actions", "rewards", "terminals", "timeouts"]
    }
    trajs["next_observations"] = dataset["observations"][1:]
    trajs["done"] = jnp.logical_or(dataset["terminals"][:-1], dataset["timeouts"][:-1])
    trajs = jax.tree_util.tree_map(jnp.array, trajs)

    # --- Split data on terminal or timeout flags ---
    # Squeeze only the trailing axis: a bare squeeze() would collapse a
    # single done index into a 0-d array that cannot be indexed below.
    split_idxs = jnp.argwhere(trajs["done"]).squeeze(axis=-1) + 1
    # Omit final index if present (it would create an empty trailing episode)
    if split_idxs.size > 0 and split_idxs[-1] == len(trajs["done"]):
        split_idxs = split_idxs[:-1]
    trajs = jax.tree_util.tree_map(lambda x: jnp.array_split(x, split_idxs), trajs)

    # --- Return list of episode dicts ---
    num_episodes = len(split_idxs) + 1
    result = [{k: v[i] for k, v in trajs.items()} for i in range(num_episodes)]
    print(result[0].keys())
    return result


def _load_dataset(args, val_split=0.0):
    """
    Loads flattened D4RL dataset.

    Episodes are concatenated together,
    then split into args.trajectory_length around done flags.

    Args:
        args: Object exposing ``dataset_name``, ``seed``,
            ``trajectory_length`` and (when ``trajectory_length > 1``)
            ``dataset_stride``.
        val_split: Fraction of episodes held out for validation. Episodes
            are chosen without replacement using ``PRNGKey(args.seed)``.

    Returns:
        ``(train_trajs, val_trajs)`` where each element is a ``Transition``
        of subtrajectories; ``val_trajs`` is ``None`` unless
        ``val_split > 0``.
    """
    print(f"Loading D4RL dataset {args.dataset_name}", end="...")

    # --- Load training and validation episodes ---
    eps = _load_d4rl_data(args)
    if val_split > 0.0:
        num_val_eps = int(val_split * len(eps))
        print(
            f"found {len(eps)} episodes, splitting off {num_val_eps} for validation.",
        )
        assert (
            num_val_eps > 0
        ), f"Val split {val_split} too small given {len(eps)} episodes"
        # Convert to Python ints once so list indexing and membership tests
        # do not go through a JAX array for every episode.
        val_ep_idxs = jax.random.choice(
            jax.random.PRNGKey(args.seed),
            len(eps),
            shape=(num_val_eps,),
            replace=False,
        ).tolist()
        held_out = set(val_ep_idxs)
        val_eps = [eps[i] for i in val_ep_idxs]
        eps = [ep for i, ep in enumerate(eps) if i not in held_out]
    else:
        print(f"found {len(eps)} episodes, no validation set.")

    def _assemble_dataset(eps):
        """
        Assemble subtrajectory dataset from list of episodes.

        Subtrajectories have length args.trajectory_length,
        with args.dataset_stride stride across dataset.

        Subtrajectories never reset at intermediate steps, or timeout at
        any step (done flag corresponds to terminal only).
        """
        if args.trajectory_length > 1:
            # --- Concatenate episodes and find global episode start indices ---
            print("Assembling dataset", end="...")
            flat_done = jnp.concatenate([ep["done"] for ep in eps], axis=0)
            done_idxs = jnp.argwhere(flat_done).squeeze(axis=-1)
            # Guard the empty case before looking at the last done index
            if done_idxs.size > 0 and done_idxs[-1] == len(flat_done) - 1:
                done_idxs = done_idxs[:-1]
            # Keep an integer dtype: a float zero would promote every index
            # to float32 and lose exactness for very large datasets.
            init_idxs = jnp.concatenate(
                [jnp.zeros(1, dtype=done_idxs.dtype), done_idxs + 1], axis=0
            )

            # --- Compute subtrajectory indices without intermediate episode resets ---
            # Sliding-window count of done flags over the first
            # (trajectory_length - 1) steps of each candidate window.
            any_done = jax.jit(partial(jnp.convolve, mode="valid"))(
                a=jnp.ones(args.trajectory_length - 1), v=flat_done[:-1]
            )
            valid_start_idxs = jnp.argwhere(any_done == 0).squeeze(axis=-1)

            # --- Compute subtrajectories ending with terminal or timeout ---
            flat_term = jnp.concatenate([ep["terminals"] for ep in eps], axis=0)
            term_idxs = jnp.argwhere(flat_term).squeeze(axis=-1)
            flat_timeout = jnp.concatenate([ep["timeouts"] for ep in eps], axis=0)
            timeout_idxs = jnp.argwhere(flat_timeout).squeeze(axis=-1)
            print(
                f"{len(term_idxs)} terminal, {len(timeout_idxs)} timeout flags found",
                end="...",
            )
            # Shift from the flagged step to the start of the window ending there
            term_idxs -= args.trajectory_length - 1
            timeout_idxs -= args.trajectory_length - 1

            # --- Compute subtrajectory indices ---
            valid_set = set(valid_start_idxs.tolist())
            # Add strided subtrajectories
            start_idxs = set(valid_start_idxs[:: args.dataset_stride].tolist())
            # Add the start and end (final step terminal) of episodes
            start_idxs |= valid_set & set(term_idxs.tolist())
            start_idxs |= valid_set & set(init_idxs.tolist())
            # Remove subtrajectories ending in timeout
            start_idxs -= set(timeout_idxs.tolist())
            # Compute index array from list of start positions
            start_idxs = jnp.array(list(start_idxs), dtype=jnp.int32)
            subtraj_idxs = jax.jit(
                jax.vmap(lambda x: jnp.arange(args.trajectory_length) + x)
            )(start_idxs)
        else:
            # --- Remove timeout transitions ---
            flat_timeout = jnp.concatenate([ep["timeouts"] for ep in eps], axis=0)
            subtraj_idxs = jnp.argwhere(~flat_timeout).squeeze(axis=-1)

        # --- Construct subtrajectories from indices ---
        def _construct_tensor(data, add_singleton=False):
            # --- Construct Jax Numpy array from subtrajectory indices ---
            ret = jnp.concatenate(data, axis=0)
            ret = jnp.take(ret, subtraj_idxs, axis=0)
            if add_singleton:
                # Add singleton dimension
                return jnp.expand_dims(ret, axis=-1)
            return ret

        trajectories = Transition(
            obs=_construct_tensor([ep["observations"] for ep in eps]),
            action=_construct_tensor([ep["actions"] for ep in eps]),
            reward=_construct_tensor([ep["rewards"] for ep in eps], add_singleton=True),
            next_obs=_construct_tensor([ep["next_observations"] for ep in eps]),
            # done is built from terminals only; timeouts are excluded above
            done=_construct_tensor([ep["terminals"] for ep in eps], add_singleton=True),
            value=None,
            log_prob=None,
            info=None,
        )
        print(f"done ({len(subtraj_idxs)} subtrajectories constructed).")
        print(f"Number of terminals: {jnp.sum(trajectories.done)}")
        assert ~jnp.any(
            trajectories.done[:, :-1]
        ), "Done flags in the middle of subtrajectory"
        return trajectories

    # --- Return assembled training and validation datasets ---
    return (
        _assemble_dataset(eps),
        _assemble_dataset(val_eps) if val_split > 0.0 else None,
    )


def _normalize_dataset(trajs):
    """Normalize every trajectory field and return the statistics used.

    Observations, actions, rewards and done flags are each normalized via
    ``normalise_traj``; next observations reuse the observation statistics.

    Returns:
        ``(normalized_trajs, stats)`` where ``stats`` maps ``"obs"``,
        ``"action"``, ``"reward"`` and ``"done"`` to a
        ``{"mean": ..., "std": ...}`` dict.
    """

    def _fit(values):
        # normalise_traj without stats returns (normalized, mean, std)
        normed, mean, std = normalise_traj(values)
        return normed, {"mean": mean, "std": std}

    obs, obs_stats = _fit(trajs.obs)
    # Next observations share the observation statistics
    next_obs = normalise_traj(trajs.next_obs, obs_stats)
    action, action_stats = _fit(trajs.action)
    reward, reward_stats = _fit(trajs.reward)
    done, done_stats = _fit(trajs.done)

    trajectory_norm_stats = {
        "obs": obs_stats,
        "action": action_stats,
        "reward": reward_stats,
        "done": done_stats,
    }
    normalized = trajs._replace(
        obs=obs,
        action=action,
        reward=reward,
        done=done,
        next_obs=next_obs,
    )
    return normalized, trajectory_norm_stats


def _normalize_from_stats(trajs, stats):
    """Apply precomputed normalization statistics to a trajectory batch.

    ``stats`` is indexed by ``"obs"``, ``"action"``, ``"reward"`` and
    ``"done"``; the ``"obs"`` entry is used for both ``obs`` and
    ``next_obs``.
    """
    obs_stats = stats["obs"]
    replaced = {
        "obs": normalise_traj(trajs.obs, obs_stats),
        "next_obs": normalise_traj(trajs.next_obs, obs_stats),
    }
    # Remaining fields use the statistics stored under their own name
    for field in ("action", "reward", "done"):
        replaced[field] = normalise_traj(getattr(trajs, field), stats[field])
    return trajs._replace(**replaced)


# offline_rollout
class DatasetRolloutGenerator:
    """Base class for rollout generators that sample batches from a dataset."""

    def __init__(self, dataset, batch_size):
        """Store the dataset and build the jitted batch sampler.

        Args:
            dataset: Pytree with an ``obs`` attribute; the leading axis of
                ``obs`` gives the number of sampleable entries.
            batch_size: Number of entries drawn (without replacement) per
                call to ``batch_rollout``.
        """
        self._dataset = dataset
        self._num_transitions = self._dataset.obs.shape[0]
        self._batch_size = batch_size

        def _sample_batch(data, rng):
            # Draw distinct row indices for this batch
            chosen = jax.random.choice(
                rng,
                jnp.arange(self._num_transitions),
                shape=(self._batch_size,),
                replace=False,
            )

            def _gather(leaf):
                rows = jnp.take(leaf, chosen, axis=0)
                # Insert a length-1 axis after the batch axis to conform
                # with the online rollout shape
                return jnp.reshape(rows, (rows.shape[0], 1, *rows.shape[1:]))

            return jax.tree_util.tree_map(_gather, data)

        self.batch_fn = jax.jit(_sample_batch)

    def batch_rollout(self, rng):
        """Return one sampled batch from the stored dataset using ``rng``."""
        return self.batch_fn(self._dataset, rng)


class OfflineRolloutGenerator(DatasetRolloutGenerator):
    """Rollout generator that samples single transitions from a D4RL dataset."""

    def __init__(
        self,
        args,
        obs_shape,
        action_dim,
        action_lims,
        num_env_steps,
        agent_apply_fn=None,
        batch_size=None,
    ):
        """Load the unnormalized dataset and prepare it for batch sampling.

        Args:
            args: Configuration passed to ``load_dataset``; also supplies
                ``batch_size`` when the ``batch_size`` argument is ``None``.
            obs_shape: Stored as ``self.obs_shape``.
            action_dim: Stored as ``self.action_dim``.
            action_lims: Stored as ``self.action_lims``.
            num_env_steps: Stored as ``self.num_env_steps``.
            agent_apply_fn: Optional policy apply function, stored as-is.
            batch_size: Number of transitions per sampled batch.
        """
        self.num_env_steps = num_env_steps
        self.agent_apply_fn = agent_apply_fn
        self.obs_shape = obs_shape
        self.action_dim = action_dim
        self.action_lims = action_lims

        # NOTE(review): this overwrites the caller's args.trajectory_length
        # in place so that the loader returns single transitions.
        args.trajectory_length = 1
        transitions = load_dataset(args, normalize=False)[0]
        # Flatten dataset (jax.tree_util.tree_map is the non-deprecated
        # spelling and matches the parent class)
        transitions = jax.tree_util.tree_map(
            lambda x: x.reshape((-1, x.shape[-1])), transitions
        )
        # Per-dimension observation statistics over the whole dataset
        self.obs_stats = {
            "mean": transitions.obs.mean(axis=0),
            "std": transitions.obs.std(axis=0),
        }
        if batch_size is None:
            batch_size = args.batch_size
        super().__init__(transitions, batch_size)

    def set_apply_fn(self, agent_apply_fn):
        """Replace the stored policy apply function."""
        self.agent_apply_fn = agent_apply_fn


# rollout
def get_gym_env(dataset_name: str, num_env_workers: int):
    """Returns a Gym environment, matching the D4RL dataset configuration.

    Args:
        dataset_name: Environment id passed to ``gym.vector.make``.
        num_env_workers: Number of parallel environments in the vector env.

    Returns:
        ``(env, has_dict_obs, max_episode_steps)``: the vector environment
        wrapped in ``RecordEpisodeStatistics``, whether its single
        observation space is a ``gym.spaces.Dict``, and the
        ``max_episode_steps`` read from the environment spec.
    """
    env = gym.vector.make(dataset_name, num_envs=num_env_workers)
    # A throwaway environment is built only to read the spec; close it so it
    # does not stay alive for the rest of the process.
    probe_env = env.env_fns[0]()
    try:
        max_episode_steps = probe_env.spec.max_episode_steps
    finally:
        probe_env.close()
    env = gym.wrappers.RecordEpisodeStatistics(env)
    has_dict_obs = isinstance(env.single_observation_space, gym.spaces.Dict)
    return env, has_dict_obs, max_episode_steps


class GymRolloutWrapper:
    """Collects rollouts from a vectorised Gym environment with a JAX policy."""

    def __init__(
        self,
        env_name: str,
        num_env_steps: Optional[int] = None,
        agent_apply_fn: Optional[callable] = None,
        num_env_workers: Optional[int] = None,
    ):
        """Create the vector environment and store rollout settings.

        When ``num_env_steps`` is ``None`` the environment's
        ``max_episode_steps`` is used as the rollout length.
        """
        self.env_name = env_name
        env, has_dict_obs, episode_limit = get_gym_env(env_name, num_env_workers)
        self.env = env
        self.convert_dict_obs = has_dict_obs
        self.agent_apply_fn = agent_apply_fn
        self.num_env_steps = episode_limit if num_env_steps is None else num_env_steps

    def batch_reset(self, rng, num_env_workers):
        """Reset the vector environment with one seed per worker.

        ``rng`` is split into ``num_env_workers`` keys and the first element
        of each key is used as that worker's integer seed.
        """
        worker_keys = jax.random.split(rng, num_env_workers)
        worker_seeds = [int(key[0]) for key in worker_keys]
        first_obs = self.env.reset(seed=worker_seeds)
        return stack_dict_obs(first_obs) if self.convert_dict_obs else first_obs

    def batch_rollout(self, rng, agent_state, last_obs):
        """Step the environment ``self.num_env_steps`` times under the agent.

        Args:
            rng: PRNG key, split into one key per worker.
            agent_state: Object whose ``params`` are fed to the apply function.
            last_obs: Current observations; the leading axis is the number
                of workers.

        Returns:
            The per-step ``Transition`` objects combined with ``tree_stack``.
        """

        def _act(key, obs):
            # --- Compute next action for a single state ---
            dist = self.agent_apply_fn(agent_state.params, obs)
            key, sample_key = jax.random.split(key)
            action, log_prob = dist.sample_and_log_prob(seed=sample_key)
            # Replace NaN actions and sum log-probs over action dimensions
            return key, jnp.nan_to_num(action), log_prob.sum(axis=-1)

        policy_step = jax.jit(jax.vmap(_act))

        num_workers = last_obs.shape[0]
        worker_rngs = jax.random.split(rng, num_workers)
        # Whether each worker has already reported a finished episode
        reported = [False] * num_workers
        collected = []
        for _ in range(self.num_env_steps):
            # --- Take step in environment ---
            worker_rngs, action, log_prob = policy_step(
                worker_rngs, jnp.array(last_obs)
            )
            next_obs, reward, done, step_info = self.env.step(onp.array(action))
            if self.convert_dict_obs:
                next_obs = stack_dict_obs(next_obs)

            # --- Track cumulative reward ---
            # Only the first finished episode of each worker is recorded;
            # every other entry is NaN.
            episode_returns = []
            for worker_id, has_reported in enumerate(reported):
                if "episode" in step_info[worker_id].keys() and not has_reported:
                    episode_returns.append(step_info[worker_id]["episode"]["r"])
                    reported[worker_id] = True
                else:
                    episode_returns.append(jnp.nan)
            episode_returns = jnp.array(episode_returns)
            episode_scores = (
                d4rl.get_normalized_score(self.env_name, episode_returns) * 100.0
            )
            summary = {
                "returned_episode_returns": episode_returns,
                "returned_episode_scores": episode_scores,
            }

            # --- Construct transition ---
            collected.append(
                Transition(
                    obs=last_obs,
                    action=action,
                    reward=jnp.expand_dims(reward, axis=-1),
                    done=jnp.expand_dims(done, axis=-1),
                    next_obs=next_obs,
                    log_prob=jnp.expand_dims(log_prob, axis=-1),
                    value=None,
                    info=summary,
                )
            )
            last_obs = next_obs
        return tree_stack(collected)

    @property
    def obs_shape(self):
        """Shape of a single observation."""
        obs_space = self.env.single_observation_space
        if self.convert_dict_obs:
            return dict_obs_shape(obs_space)
        return obs_space.shape

    @property
    def action_dim(self):
        """Size of the first axis of the single action space."""
        return self.env.single_action_space.shape[0]

    @property
    def action_lims(self):
        """``(low, high)`` taken from the first action dimension."""
        act_space = self.env.single_action_space
        return (act_space.low[0], act_space.high[0])

    def set_apply_fn(self, agent_apply_fn):
        """Replace the stored policy apply function."""
        self.agent_apply_fn = agent_apply_fn