from collections import defaultdict
from typing import NamedTuple

import numpy as np


def rlkit_buffer_to_macaw_format(buffer, discount_factor, path_length, **kwargs):
    """Export the filled prefix of an rlkit replay buffer as a flat dict.

    Rows ``[0, buffer._top)`` of the observation, action, reward, next-obs and
    terminal arrays, and of every env-info array, are copied into the result.
    The result also holds the discount factor and the trajectory end indices
    from ``compute_end_indices``. Monte-Carlo targets ('terminal_obs',
    'terminal_discounts', 'mc_rewards') are then added in place by
    ``add_trajectory_data_to_buffer``. ``kwargs`` is forwarded to it
    unchanged, and from there to ``yield_trajectories`` (start_idx / end_idx).

    Returns:
        The populated dict.
    """
    n_rows = buffer._top
    end_indices = compute_end_indices(buffer._terminals, n_rows, path_length)
    window = slice(None, n_rows)  # same rows as ``[:n_rows]``
    data = {
        'obs': buffer._observations[window],
        'actions': buffer._actions[window],
        'rewards': buffer._rewards[window],
        'next_obs': buffer._next_obs[window],
        'terminals': buffer._terminals[window],
        'discount_factor': discount_factor,
        'end_indices': end_indices,
    }
    # Env-info arrays are stored under their own keys; they are added after
    # the fixed keys, so a clashing name overwrites the fixed entry.
    data.update({key: values[window] for key, values in buffer._env_infos.items()})
    add_trajectory_data_to_buffer(buffer, data, discount_factor, path_length, **kwargs)
    return data


def _trajectory_to_arrays(traj, env_info_keys, add_done_info):
    """Stack one trajectory (a list of ``Experience``) into per-field arrays.

    Args:
        traj: Non-empty list of Experience tuples.
        env_info_keys: Env-info keys to extract from each step's ``env_info``.
        add_done_info: If True, each step's ``done`` value is concatenated
            onto both ``state`` and ``next_state``.

    Returns:
        ``(arrays, env_info_arrays)``. ``arrays`` maps 'obs', 'actions',
        'rewards', 'next_obs' and 'terminals' to arrays with one row per
        step. ``env_info_arrays`` maps each key in ``env_info_keys`` the same
        way.
    """
    if add_done_info:
        obs = np.array([np.concatenate([e.state, e.done], axis=0) for e in traj])
        next_obs = np.array([np.concatenate([e.next_state, e.done], axis=0) for e in traj])
    else:
        obs = np.array([e.state for e in traj])
        next_obs = np.array([e.next_state for e in traj])
    arrays = {
        'obs': obs,
        'actions': np.array([e.action for e in traj]),
        'rewards': np.array([e.reward for e in traj]),
        'next_obs': next_obs,
        # Always zero regardless of e.done. NOTE(review): presumably so that
        # trajectory ends inside a meta-episode are not treated as terminal.
        'terminals': np.array([np.zeros_like(e.done) for e in traj]),
    }
    env_info_arrays = {k: np.array([e.env_info[k] for e in traj]) for k in env_info_keys}
    return arrays, env_info_arrays


def rlkit_buffer_to_borel_format(
        buffer,
        discount_factor,
        path_length,
        meta_episode_len=600,
        add_done_info=True,
        start_idx=0,
        end_idx=None,
):
    """Convert an rlkit replay buffer into fixed-length meta-episodes.

    Consecutive trajectories from ``yield_trajectories`` are concatenated
    until at least ``meta_episode_len`` steps are collected. The result is
    then truncated to exactly ``meta_episode_len`` steps. Surplus steps are
    discarded; they do not start the next meta-episode. Collection stops at
    the first meta-episode that cannot be filled.

    Args:
        buffer: rlkit-style replay buffer (its private arrays are read
            directly).
        discount_factor: Stored unchanged under 'discount_factor'.
        path_length: Maximum trajectory length used to split the buffer.
            Also stored under 'trajectory_len'.
        meta_episode_len: Number of steps per meta-episode; must be positive.
        add_done_info: If True, append each step's done value to obs and
            next_obs.
        start_idx, end_idx: Row range forwarded to ``yield_trajectories``.

    Returns:
        dict containing:
        - 'obs', 'actions', 'rewards', 'next_obs', 'terminals' and every
          env-info key, each stacked as (meta_episode_len, num_meta_episodes,
          dim). This assumes each per-step entry is 1-D — TODO confirm for
          all buffers.
        - 'discount_factor' and 'trajectory_len'.
        All terminals are zero.

    Raises:
        ValueError: if ``meta_episode_len`` is not positive or no complete
            meta-episode could be built.
    """
    if meta_episode_len <= 0:
        raise ValueError("meta_episode_len must be positive")

    fields = ('obs', 'actions', 'rewards', 'next_obs', 'terminals')
    env_info_keys = list(buffer._env_infos.keys())
    episodes = {field: [] for field in fields}
    env_infos = defaultdict(list)
    traj_iter = yield_trajectories(buffer, path_length, start_idx, end_idx)

    while True:
        # Group consecutive trajectories into one meta-episode.
        meta = {field: [] for field in fields}
        meta_env_infos = defaultdict(list)
        current_meta_episode_len = 0
        while current_meta_episode_len < meta_episode_len:
            traj = next(traj_iter, None)
            if traj is None:
                break
            arrays, env_info_arrays = _trajectory_to_arrays(traj, env_info_keys, add_done_info)
            for field in fields:
                meta[field].append(arrays[field])
            for k in env_info_keys:
                meta_env_infos[k].append(env_info_arrays[k])
            current_meta_episode_len += len(arrays['obs'])

        if current_meta_episode_len < meta_episode_len:
            # Trajectories ran out before this meta-episode filled up; drop it.
            break
        for field in fields:
            episodes[field].append(np.concatenate(meta[field], axis=0)[:meta_episode_len])
        for k in env_info_keys:
            env_infos[k].append(np.concatenate(meta_env_infos[k], axis=0)[:meta_episode_len])

    if not episodes['obs']:
        # Without this check the transpose below fails with an opaque numpy error.
        raise ValueError(
            "not enough data in the buffer to build a single meta-episode of "
            f"length {meta_episode_len}"
        )

    data = {field: np.array(episodes[field]).transpose(1, 0, 2) for field in fields}
    data['discount_factor'] = discount_factor
    data['trajectory_len'] = path_length
    for k in env_info_keys:
        data[k] = np.array(env_infos[k]).transpose(1, 0, 2)
    return data


class Experience(NamedTuple):
    """One transition, read from a single row ``i`` of an rlkit-style replay buffer.

    Instances are built by ``yield_trajectories``. Each field holds the raw
    row entry of the corresponding buffer array, so the real types and shapes
    depend on the buffer; the annotations below are nominal.
    """
    state: np.ndarray  # buffer._observations[i]
    action: np.ndarray  # buffer._actions[i]
    next_state: np.ndarray  # buffer._next_obs[i]
    reward: float  # buffer._rewards[i]; possibly a length-1 array rather than a float — TODO confirm
    done: bool  # buffer._terminals[i]; concatenated onto obs in rlkit_buffer_to_borel_format, so presumably a length-1 array — confirm
    env_info: dict  # {key: buffer._env_infos[key][i]} for every env-info key


def yield_trajectories(buffer, path_length, start_idx=0, end_idx=None):
    """Yield consecutive trajectories from ``buffer`` as lists of ``Experience``.

    Trajectory boundaries come from ``compute_end_indices``: a terminal flag,
    or ``path_length`` steps since the previous boundary. Boundaries are
    counted from row 0, so they do not depend on ``start_idx``. Only rows in
    the normalized range ``[start_idx, end_idx)`` (see
    ``clean_start_end_idx``) are read. Consequently:
    - if ``start_idx`` falls mid-trajectory, the first list starts at
      ``start_idx``;
    - steps after the last boundary before ``end_idx`` are never yielded.

    Args:
        buffer: rlkit-style replay buffer (private arrays are read directly).
        path_length: Maximum trajectory length.
        start_idx: First row to read (negative counts back from ``end_idx``).
            Defaults to 0, so callers that forward an empty ``**kwargs``
            work.
        end_idx: One past the last row to read; ``None`` means
            ``buffer._size``.

    Yields:
        Non-empty lists of Experience, one per complete trajectory.

    Raises:
        ValueError: propagated from ``clean_start_end_idx`` for invalid ranges.
    """
    start_idx, end_idx = clean_start_end_idx(buffer._size, start_idx, end_idx)
    # NOTE(review): the range uses buffer._size but boundaries use buffer._top;
    # these agree only until the buffer wraps around — confirm intent.
    # A set makes the per-row boundary test O(1); list membership made the
    # scan quadratic in the number of rows.
    end_indices = set(compute_end_indices(buffer._terminals, buffer._top, path_length))
    current_traj = []
    for i in range(start_idx, end_idx):
        experience = Experience(
            state=buffer._observations[i],
            action=buffer._actions[i],
            next_state=buffer._next_obs[i],
            reward=buffer._rewards[i],
            done=buffer._terminals[i],
            env_info={k: infos[i] for k, infos in buffer._env_infos.items()}
        )
        current_traj.append(experience)
        if i in end_indices:
            yield current_traj
            current_traj = []


def clean_start_end_idx(num_steps, start_idx, end_idx):
    """Validate and normalize a ``[start_idx, end_idx)`` row range.

    Args:
        num_steps: Number of initialized rows that may be indexed.
        start_idx: First row to include. A negative value counts back from
            the normalized ``end_idx``, Python-slice style.
        end_idx: One past the last row to include, or ``None`` for
            ``num_steps``. Values past ``num_steps`` are clamped to it, with a
            printed notice.

    Returns:
        ``(start_idx, end_idx)`` satisfying
        ``0 <= start_idx < end_idx <= num_steps``.

    Raises:
        ValueError: if the range is empty or inverted (before or after
            clamping), if ``end_idx`` is negative, or if a negative
            ``start_idx`` reaches before row 0.
    """
    if end_idx is None:
        end_idx = num_steps
        if num_steps == start_idx:
            raise ValueError("nothing to copy!")
    if end_idx < 0 or end_idx <= start_idx:
        raise ValueError("end_idx must be larger than start_idx")
    if end_idx > num_steps:
        # Best-effort: clamp rather than fail when the request runs past the
        # initialized region.
        print("Would have index into uninitialized region. Setting end_idx to max end_idx")
        end_idx = num_steps
        # Clamping can invert the range when start_idx itself lies past the
        # initialized region; reject it instead of returning an empty range.
        if end_idx <= start_idx:
            raise ValueError("end_idx must be larger than start_idx")
    if start_idx < 0:
        start_idx = end_idx + start_idx
        if start_idx < 0:
            raise ValueError("start_idx is negative but end_idx is too small")
    return start_idx, end_idx


def add_trajectory_data_to_buffer(buffer, data, discount_factor, path_length, **kwargs):
    """Add Monte-Carlo targets to ``data`` in place.

    For each step of every trajectory yielded by ``yield_trajectories``, the
    row ``r`` equal to that step's buffer row receives:
    - ``data['terminal_obs'][r]``: ``next_state`` of the trajectory's last step.
    - ``data['terminal_discounts'][r]``: ``discount_factor ** k``, where ``k``
      is the number of steps from ``r`` to the trajectory end, inclusive
      (``k == 1`` for the last step).
    - ``data['mc_rewards'][r]``: the discounted sum of rewards from ``r`` to
      the trajectory end.
    Rows not covered by a yielded trajectory stay zero.

    Args:
        buffer: rlkit-style replay buffer.
        data: dict whose 'obs', 'terminals' and 'rewards' arrays are indexed
            by buffer row from 0 (as built by ``rlkit_buffer_to_macaw_format``).
            Used as shape/dtype templates for the new arrays.
        discount_factor: Per-step discount.
        path_length: Maximum trajectory length.
        **kwargs: ``start_idx`` / ``end_idx``, forwarded to
            ``yield_trajectories``.
    """
    all_terminal_obs = np.zeros_like(data['obs'])
    all_terminal_discounts = np.zeros_like(data['terminals'], dtype=np.float64)
    all_mc_rewards = np.zeros_like(data['rewards'])

    # Trajectories are yielded in buffer order, starting at the normalized
    # start row. Begin writing there so targets line up with data['obs']
    # instead of always starting at row 0.
    row, _ = clean_start_end_idx(buffer._size, kwargs.get('start_idx', 0), kwargs.get('end_idx'))
    for trajectory in yield_trajectories(buffer, path_length, **kwargs):
        terminal_obs = trajectory[-1].next_state
        last_row = row + len(trajectory) - 1
        mc_reward = 0
        terminal_factor = 1
        # Walk backwards so the return and the discount accumulate with one
        # step of work each. Map each reversed step back to its own
        # (forward) row; the original code wrote the values in reverse order.
        for steps_from_end, experience in enumerate(reversed(trajectory)):
            target = last_row - steps_from_end
            all_terminal_obs[target] = terminal_obs
            terminal_factor *= discount_factor
            all_terminal_discounts[target] = terminal_factor
            mc_reward = experience.reward + discount_factor * mc_reward
            all_mc_rewards[target] = mc_reward
        row += len(trajectory)

    data['terminal_obs'] = all_terminal_obs
    data['terminal_discounts'] = all_terminal_discounts
    data['mc_rewards'] = all_mc_rewards


def compute_end_indices(terminals, size, path_length):
    """Return the row indices at which trajectories end, in increasing order.

    Rows ``0 .. size - 1`` are scanned in order. A row closes the current
    trajectory when it is the ``path_length``-th step since the previous
    boundary, or when ``terminals`` at that row is truthy. The length test is
    evaluated first.
    """
    ends = []
    steps_in_traj = 0
    for idx in range(size):
        steps_in_traj += 1
        if steps_in_traj == path_length or terminals[idx]:
            ends.append(idx)
            steps_in_traj = 0
    return ends