"""Utilities used throughout the codebase."""

from __future__ import annotations

from cProfile import run
from collections import namedtuple
import glob
import json
import os
import random
import shutil
from typing import Callable, Dict, List, Optional, Tuple, Union

from torch.utils.data import IterableDataset, TensorDataset
from d4rl import offline_env
from gym import spaces
import numpy as np
import torch
import wandb
from wandb.sdk.wandb_run import Run
import gym
# NOTE: this import rebinds the name ``Dict`` (previously typing.Dict).
from gym.spaces import Dict, Box, Discrete
from sklearn.cluster import MiniBatchKMeans
from tqdm import tqdm

from rvs import step, train


def configure_gpu(use_gpu: bool, which_gpu: int) -> torch.device:
    """Select the torch device for training and adjust the CUDA environment.

    Args:
        use_gpu: If true the "cuda" device is returned, otherwise "cpu".
        which_gpu: Currently has no effect; restricting CUDA_VISIBLE_DEVICES
            to this index was disabled because it was causing issues.

    Returns:
        The selected ``torch.device``.
    """
    if not use_gpu:
        # Hide all GPUs from CUDA when running on the CPU.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        return torch.device("cpu")

    # Enumerate devices by PCI bus id (see issue #152), following
    # https://stackoverflow.com/questions/37893755/
    # tensorflow-set-cuda-visible-devices-within-jupyter
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    return torch.device("cuda")


def set_seed(seed: Optional[int]) -> None:
    """Seed numpy, ``random`` and torch; a seed of ``None`` leaves them untouched."""
    if seed is None:
        return
    # Same order as always: numpy, then the stdlib generator, then torch.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)


def extract_traj_markers(
    dataset: Dict[str, np.ndarray],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return trajectory starts, ends, and lengths for a D4RL-style dataset.

    A timestep closes a trajectory when either its "terminals" or its
    "timeouts" entry is set.
    """
    terminals = dataset["terminals"]
    timeouts = dataset["timeouts"]
    return extract_done_markers(np.logical_or(terminals, timeouts))


def extract_done_markers(
    dones: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Given a per-timestep dones vector, return starts, ends, and lengths of trajs.

    Args:
        dones: 1-D array whose truthy entries mark the last step of a trajectory.

    Returns:
        ``(starts, ends, lengths)``: three equally long integer arrays holding the
        inclusive first index, the inclusive last index, and the number of steps
        of each completed trajectory. Steps after the final done flag do not
        belong to any returned trajectory. When no flag is set all three arrays
        are empty.
    """
    (ends,) = np.where(dones)
    if ends.size == 0:
        # Without any done flag there is no completed trajectory; returning a
        # lone start index of 0 would leave starts and ends with different sizes.
        return ends.copy(), ends, ends.copy()

    # Every trajectory but the first begins right after the previous one ends.
    starts = np.concatenate(([0], ends[:-1] + 1))
    lengths = ends - starts + 1

    return starts, ends, lengths


def collect_timestep_indices(
    dones: np.ndarray,
    trajectory_indices: Union[List[int], np.ndarray],
) -> np.ndarray:
    """Find all timestep indices within the given trajectory indices.

    Args:
        dones: Per-timestep done flags, as accepted by ``extract_done_markers``.
        trajectory_indices: Indices of the trajectories to select.

    Returns:
        The timestep indices of the selected trajectories, concatenated in the
        order the trajectories were requested. The result is an integer array
        even when it is empty, so it can always be used to index other arrays.
    """
    starts, ends, _ = extract_done_markers(dones)
    starts = starts[trajectory_indices]
    ends = ends[trajectory_indices]

    per_trajectory = [np.arange(start, end + 1) for start, end in zip(starts, ends)]
    if not per_trajectory:
        # np.array([]) would be float64, which numpy rejects as an index array.
        return np.array([], dtype=starts.dtype)

    return np.concatenate(per_trajectory)


def concatenate_two_boxes(box_a: spaces.Box, box_b: spaces.Box) -> spaces.Box:
    """Build one Box whose bounds are those of ``box_a`` followed by ``box_b``.

    The dtype of the result is the numpy result type of the two input dtypes.

    Raises:
        ValueError: If either argument is not a Box space.
    """
    for candidate in (box_a, box_b):
        if not isinstance(candidate, spaces.Box):
            raise ValueError("This method will only concatenate Box spaces")

    joined_low = np.concatenate((box_a.low, box_b.low))
    joined_high = np.concatenate((box_a.high, box_b.high))
    joined_dtype = np.result_type(box_a.dtype, box_b.dtype)

    return spaces.Box(low=joined_low, high=joined_high, dtype=joined_dtype)


def duplicate_observation_space(observation_space: spaces.Box) -> spaces.Box:
    """Return a Box made of the observation space concatenated with itself.

    Raises:
        ValueError: If ``observation_space`` is not a Box space.
    """
    if isinstance(observation_space, spaces.Box):
        return concatenate_two_boxes(observation_space, observation_space)

    raise ValueError(
        "This method will only duplicate Box observation_spaces")


def flatten_observation_goal_spaces(observation_space: spaces.Dict) -> spaces.Box:
    """Concatenate the "observation" and "desired_goal" entries of a Dict space.

    Raises:
        ValueError: If ``observation_space`` is not a Dict space.
    """
    if isinstance(observation_space, spaces.Dict):
        observation_box = observation_space["observation"]
        goal_box = observation_space["desired_goal"]
        return concatenate_two_boxes(observation_box, goal_box)

    raise ValueError(
        "This method will only flatten Dict observation_spaces")


def create_observation_goal_space(
    observation_space: Union[spaces.Box, spaces.Dict],
) -> spaces.Box:
    """Produce a Box space that holds both an observation and a goal.

    Dict spaces are flattened from their "observation" and "desired_goal"
    entries; any other space is duplicated (which requires it to be a Box).
    """
    builder = (
        flatten_observation_goal_spaces
        if isinstance(observation_space, spaces.Dict)
        else duplicate_observation_space
    )
    return builder(observation_space)


def add_scalar_to_space(
    observation_space: Union[spaces.Box, spaces.Dict],
) -> spaces.Box:
    """Extend the observation Box by one unbounded scalar dimension.

    For a Dict space the "observation" entry is the Box that gets extended.

    Raises:
        ValueError: If the space to extend is not a Box.
    """
    box = observation_space
    if isinstance(box, spaces.Dict):
        box = box["observation"]
    if not isinstance(box, spaces.Box):
        raise ValueError(
            "This method can only add reward to a Box observation_space")

    # The extra dimension is unbounded in both directions.
    extended_low = np.concatenate([box.low, [-np.inf]])
    extended_high = np.concatenate([box.high, [np.inf]])
    return spaces.Box(low=extended_low, high=extended_high, dtype=box.dtype)


def resolve_out_directory(run_id: str, entity: str, project: str, use_cached: bool = False) -> Tuple[str, Run]:
    """Download wandb run and return its local output directory.

    Args:
        run_id: Id of the wandb run to resume.
        entity: wandb entity that owns the run.
        project: wandb project that contains the run.
        use_cached: If true and more than one local ``./wandb/run*{run_id}``
            folder exists, copy the ``files`` directory of the second-to-last
            folder (in sorted order) instead of downloading from wandb.

    Returns:
        A tuple ``(wandb_run.dir, wandb_run)`` for the resumed run.
    """
    # get the wandb run from the api
    api = wandb.Api()
    api_run = api.run(f"{entity}/{project}/{run_id}")

    # Make sure ./wandb/<run_id> exists before resuming.
    # NOTE(review): presumably wandb needs this directory — confirm.
    if not os.path.exists(os.path.join('./wandb', run_id)):
        os.makedirs(os.path.join('./wandb', run_id))

    # resume the wandb run
    wandb.init(
        entity=entity,
        project=project,
        id=run_id,
        resume="must",
    )
    wandb_run = wandb.run

    # Local run folders in lexicographic order.
    # NOTE(review): assumes the sort order is chronological, so [-1] is the
    # folder of the run that was just resumed — confirm.
    run_folders = sorted(glob.glob(f"./wandb/run*{run_id}"))
    if use_cached and len(run_folders) > 1:
        # Replace the newest folder's files with those of the previous folder.
        shutil.rmtree(os.path.join(run_folders[-1], 'files'))
        shutil.copytree(os.path.join(
            run_folders[-2], 'files'), os.path.join(run_folders[-1], 'files'))
        print("Copying cached files from previous run")
    else:
        # Restore every file recorded for the run through the resumed run.
        print("Downloading files from wandb...")
        for file in api_run.files():
            wandb_run.restore(file.name)
        print("Successfully downloaded files")

    return wandb_run.dir, wandb_run


def sorted_glob(*args, **kwargs) -> List[str]:
    """Run ``glob.glob`` with the given arguments and return the matches sorted.

    Sorting makes the result independent of the order in which the file system
    lists entries.
    """
    matches = glob.glob(*args, **kwargs)
    matches.sort()
    return matches


def parse_val_loss(filename: str) -> float:
    """Parse val_loss from the checkpoint filename.

    The value is the text between ``"val_loss="`` and the following ``".ckpt"``.
    A trailing version suffix of the form ``-v<digits>`` (for example ``-v1`` or
    ``-v2``) is stripped before the number is converted.

    Raises:
        ValueError: If the filename lacks ``"val_loss="`` or ``".ckpt"``, or if
            the text between them is not a valid float.
    """
    marker = "val_loss="
    start = filename.index(marker) + len(marker)
    end = filename.index(".ckpt", start)
    raw_value = filename[start:end]

    # Drop any "-v<N>" version suffix, not just "-v1".
    head, separator, version = raw_value.rpartition("-v")
    if separator and version.isdigit():
        raw_value = head

    return float(raw_value)


def get_best_val_checkpoint(
    checkpoint_dir,
) -> Union[Tuple[str, np.float64], Tuple[None, None]]:
    """Return the ``*val*.ckpt`` file with the lowest val_loss and that loss.

    Returns ``(None, None)`` when the directory holds no matching checkpoint.
    """
    candidates = sorted_glob(os.path.join(checkpoint_dir, "*val*.ckpt"))
    if not candidates:
        return None, None

    val_losses = np.array([parse_val_loss(path) for path in candidates])
    best = np.argmin(val_losses)
    return candidates[best], val_losses[best]


def get_checkpoints(
    out_directory: str,
    last_checkpoints_too: bool = False,
) -> Tuple[List[str], List[Dict[str, Union[int, float, str]]]]:
    """Collect checkpoint paths and matching attribute dicts from a run directory.

    The best-validation checkpoint is listed first when one exists. The
    ``last.ckpt`` checkpoint is added when ``last_checkpoints_too`` is set, and
    always when there is no validation checkpoint.
    """
    checkpoint_dir = os.path.join(out_directory, train.checkpoint_dir)
    best_path, best_loss = get_best_val_checkpoint(checkpoint_dir)
    has_validation = best_path is not None

    checkpoints = [best_path] if has_validation else []
    attribute_dicts = (
        [{"Checkpoint": "Validation", "val_loss": best_loss}]
        if has_validation
        else []
    )

    # Fall back to the last checkpoint when no validation checkpoint exists.
    if last_checkpoints_too or not has_validation:
        last_matches = sorted_glob(os.path.join(checkpoint_dir, "last.ckpt"))
        checkpoints.append(last_matches[-1])
        attribute_dicts.append({"Checkpoint": "Last"})

    return checkpoints, attribute_dicts


def get_parameters(out_directory: str) -> Dict[str, Union[int, float, str, bool]]:
    """Read the JSON arguments file stored in the output directory."""
    args_path = os.path.join(out_directory, train.args_filename)
    with open(args_path, "r") as args_handle:
        return json.load(args_handle)


def load_experiment(
    out_directory: str,
    last_checkpoints_too: bool = False,
) -> Tuple[
    List[str],
    List[Dict[str, Union[int, float, str]]],
    Dict[str, Union[int, float, str, bool]],
    Union[step.GCSLToGym, offline_env.OfflineEnv],
    Callable[[], Union[step.GCSLToGym, offline_env.OfflineEnv]],
]:
    """Load experiment from the output directory.

    Args:
        out_directory: Directory holding the run's checkpoints and args file.
        last_checkpoints_too: Forwarded to ``get_checkpoints``.

    Returns:
        A 5-tuple of: paths to model checkpoints, their associated attribute
        dictionaries, the parameters of the experimental run (with defaults
        filled in for optional keys), an environment created from those
        parameters, and a zero-argument callable that creates another
        environment with the same arguments.
    """
    checkpoints, attribute_dicts = get_checkpoints(
        out_directory,
        last_checkpoints_too=last_checkpoints_too,
    )
    parameters = get_parameters(out_directory)

    # Fill in keys that may be missing from the stored arguments.
    optional_defaults = (
        ("unconditional_policy", False),
        ("reward_conditioning", False),
        ("cumulative_reward_to_go", False),
        ("seed", None),
    )
    for key, default in optional_defaults:
        parameters.setdefault(key, default)

    def make_env():
        # Reads ``parameters`` at call time, like the returned env itself did.
        return step.create_env(
            parameters["env_name"],
            parameters["max_episode_steps"],
            parameters["discretize"],
            parameters.get("discrete_clusters"),
            seed=parameters["seed"],
        )

    # Seed the global generators before the first environment is built.
    set_seed(parameters["seed"])
    env = make_env()

    return checkpoints, attribute_dicts, parameters, env, make_env

# Utilities for using our bootstrapping code with the RvS codebase.


# One trajectory; every field is a per-timestep sequence.
Trajectory = namedtuple(
    "Trajectory",
    (
        "obs",
        "actions",
        "continuous_actions",
        "rewards",
        "infos",
        "policy_infos",
    ),
)

# Convert D4RL to trajectory list


def get_trajs(dataset):
    """Split a flat D4RL-style dataset into a list of ``Trajectory`` tuples.

    A trajectory ends at every timestep whose "terminals" or "timeouts" entry is
    truthy. Steps after the last such flag form a final, unterminated
    trajectory.

    Args:
        dataset: Mapping with "observations", "actions", "rewards", "terminals"
            and "timeouts" entries indexed by timestep, and optionally
            "continuous_actions". Without the latter, "actions" is used for the
            ``continuous_actions`` field as well.

    Returns:
        A list of ``Trajectory`` objects whose fields are per-step lists;
        ``infos`` and ``policy_infos`` contain ``None`` for every step.
    """
    n_steps = dataset['rewards'].shape[0]
    observations = dataset['observations']
    actions = dataset['actions']
    if 'continuous_actions' in dataset.keys():
        continuous_actions = dataset['continuous_actions']
    else:
        continuous_actions = actions
    rewards = dataset['rewards']
    terminals = dataset['terminals']
    timeouts = dataset['timeouts']

    def build(start, stop):
        # Trajectory covering the half-open step range [start, stop).
        length = stop - start
        return Trajectory(obs=list(observations[start:stop]),
                          actions=list(actions[start:stop]),
                          continuous_actions=list(
                              continuous_actions[start:stop]),
                          rewards=list(rewards[start:stop]),
                          infos=[None] * length,
                          policy_infos=[None] * length)

    trajs = []
    start = 0
    for i in range(n_steps):
        if bool(terminals[i]) or bool(timeouts[i]):
            trajs.append(build(start, i + 1))
            start = i + 1

    # Keep trailing steps that were never closed by a terminal/timeout flag.
    if start < n_steps:
        trajs.append(build(start, n_steps))
    return trajs


class SegmentDataset(IterableDataset):
    """Iterable dataset that yields randomly sampled timesteps of trajectories.

    Every sample is drawn uniformly (with replacement) over all timesteps of all
    trajectories and consists of the observation, the action, and a label value
    for that timestep. Behaviour is configured by assigning the public
    attributes after construction:

    * ``epoch_len``: number of samples produced per iteration (default 1e6).
    * ``combine_goal``: if true the label is appended to the observation and
      ``(obs, action)`` pairs are yielded instead of ``(obs, action, label)``.
    * ``append_time``: if true a sin/cos encoding of the relative timestep is
      appended to the observation.
    * ``seed``: seed for the sampling generator (``None`` for fresh entropy).
    * ``continuous_actions``: if true ``traj.continuous_actions`` is used
      instead of ``traj.actions``.
    """

    rand: np.random.Generator

    def __init__(self, trajs, label_fn, label_to_sample=None):
        """Store the trajectories and precompute their labels and tensors.

        Args:
            trajs: Sequence of trajectory objects exposing ``obs``, ``actions``,
                ``continuous_actions`` and ``rewards`` per-step sequences.
            label_fn: Called once per trajectory; its result is stored as the
                single label row of that trajectory.
            label_to_sample: Index of the label row to use for every sample, or
                ``None`` to pick a row at random for each sample.
        """
        self.trajs = trajs
        self.traj_labels = [[label_fn(traj)] for traj in self.trajs]
        self.epoch_len = 1e6
        self.combine_goal = False
        self.append_time = False
        self.seed = None
        self.label_to_sample = label_to_sample
        self.continuous_actions = False

        self.cache_tensors()

    def cache_tensors(self):
        """(Re)build the float tensors and the flat timestep index tables."""

        def as_float_tensor(values):
            return torch.from_numpy(np.array(values)).float()

        self.tensor_traj_obs = [
            as_float_tensor(traj.obs) for traj in self.trajs]
        self.tensor_traj_actions = [
            as_float_tensor(traj.actions) for traj in self.trajs]
        self.tensor_traj_continuous_actions = [
            as_float_tensor(traj.continuous_actions) for traj in self.trajs]
        self.tensor_traj_labels = [
            as_float_tensor(traj_labels) for traj_labels in self.traj_labels]

        # Entry j of both tables describes global timestep j:
        # which trajectory it belongs to and its offset inside that trajectory.
        self.traj_idxs = []
        self.t_idxs = []
        for traj_idx, traj in enumerate(self.trajs):
            n_steps = len(traj.rewards)
            self.traj_idxs.extend([traj_idx] * n_steps)
            self.t_idxs.extend(range(n_steps))

    def segment_generator(self, epoch_len):
        """Yield ``epoch_len`` random samples using ``self.rand``."""
        for _ in range(epoch_len):
            data_idx = self.rand.integers(len(self.traj_idxs))
            traj_idx = self.traj_idxs[data_idx]
            t_idx = self.t_idxs[data_idx]
            # Compare against None explicitly: label index 0 is a valid choice
            # and must not be treated as "no label selected".
            if self.label_to_sample is not None:
                lab_idx = self.label_to_sample
            else:
                lab_idx = self.rand.integers(len(self.traj_labels[traj_idx]))

            o1 = self.tensor_traj_obs[traj_idx][t_idx]
            if self.continuous_actions:
                a1 = self.tensor_traj_continuous_actions[traj_idx][t_idx]
            else:
                a1 = self.tensor_traj_actions[traj_idx][t_idx]
            l1 = self.tensor_traj_labels[traj_idx][lab_idx][t_idx]

            if self.append_time:
                # Represent time using a sin/cos pair
                t = t_idx / len(self.trajs[traj_idx].obs)
                o1 = torch.cat(
                    [o1, torch.tensor([np.sin(2 * np.pi * t), np.cos(2 * np.pi * t)])]).float()

            if self.combine_goal:
                o1 = torch.cat([o1, l1.view(1)], dim=0)
                yield o1, a1
            else:
                yield o1, a1, l1

    def __len__(self):
        """Return the number of samples per epoch."""
        return int(self.epoch_len)

    def __iter__(self):
        """Create the sampling generator for this process or worker.

        In a DataLoader worker the epoch is split evenly across workers and,
        when a seed is set, each worker uses ``seed + worker id``.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process data loading: produce the full epoch.
            seed = self.seed
            n_samples = int(self.epoch_len)
        else:
            # An unset seed stays unset so every worker draws fresh entropy;
            # adding the worker id to None would raise a TypeError.
            seed = None if self.seed is None else self.seed + worker_info.id
            n_samples = int(self.epoch_len / float(worker_info.num_workers))
        self.rand = np.random.default_rng(seed)
        return self.segment_generator(n_samples)

    def mean_std(self):
        """Return mean and std tensors of observations followed by the labels.

        The last element of each tensor is the statistic over all label values
        of all trajectories; 1e-6 is added to every std.
        """
        obs = np.concatenate([np.array(traj.obs)
                             for traj in self.trajs], axis=0)
        obs_mean, obs_std = obs.mean(axis=0), obs.std(axis=0) + 1e-6
        # each labels is a (k, timesteps) array, so we need to flatten it
        labels = np.concatenate([np.array(labels).flatten()
                                for labels in self.traj_labels], axis=0)
        labels_mean = labels.mean(axis=0).reshape(1)
        labels_std = labels.std(axis=0).reshape(1) + 1e-6

        mean = np.concatenate([obs_mean, labels_mean])
        std = np.concatenate([obs_std, labels_std])

        return torch.tensor(mean), torch.tensor(std)

    def mean_std_acts(self):
        """Return per-dimension mean and std (+1e-6) tensors of the actions."""
        if self.continuous_actions:
            per_traj = [np.array(traj.continuous_actions)
                        for traj in self.trajs]
        else:
            per_traj = [np.array(traj.actions) for traj in self.trajs]
        actions = np.concatenate(per_traj, axis=0)
        actions_mean, actions_std = actions.mean(
            axis=0), actions.std(axis=0) + 1e-6

        return torch.tensor(actions_mean), torch.tensor(actions_std)

    def convert_to_tensor_dataset(self):
        """Materialise every timestep once as a ``TensorDataset``.

        Uses ``traj.actions`` and the label row ``label_to_sample``; the
        ``append_time`` and ``continuous_actions`` settings are not applied
        here.
        """
        # Only works if we're sampling a single label
        assert self.label_to_sample is not None
        obs = np.concatenate([np.array(traj.obs)
                              for traj in self.trajs], axis=0)
        actions = np.concatenate([np.array(traj.actions)
                                  for traj in self.trajs], axis=0)
        labels = np.concatenate([np.array(labels[self.label_to_sample])
                                 for labels in self.traj_labels], axis=0)
        # Convert to tensors
        obs = torch.from_numpy(obs).float()
        actions = torch.from_numpy(actions).float()
        labels = torch.from_numpy(labels).float()

        if self.combine_goal:
            obs = torch.cat([obs, labels.view(-1, 1)], dim=1)
            return TensorDataset(obs, actions)

        return TensorDataset(obs, actions, labels)


class FastTensorDataLoader:
    """
    A DataLoader-like object for a set of tensors that can be much faster than
    TensorDataset + DataLoader because dataloader grabs individual indices of
    the dataset and calls cat (slow).
    Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6

    Each epoch draws ``epoch_len`` row indices uniformly with replacement and
    serves them in consecutive batches; the last batch may be smaller.
    """

    def __init__(self, tensor_dataset, batch_size=32, shuffle=True, device='cpu',
                 epoch_len=int(1e6)):
        """
        Initialize a FastTensorDataLoader.

        :param tensor_dataset: object with a ``tensors`` attribute holding the
            tensors to serve. Must have the same length @ dim 0.
        :param batch_size: batch size to load.
        :param shuffle: must be True; sequential iteration is not implemented
            and iterating with ``shuffle=False`` raises NotImplementedError.
        :param device: device the tensors are moved to once, up front.
        :param epoch_len: number of rows sampled per epoch.

        :returns: A FastTensorDataLoader.
        """
        # Move the tensors to the right device
        self.tensors = [t.to(device) for t in tensor_dataset.tensors]

        self.dataset_len = self.tensors[0].shape[0]
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.epoch_len = int(epoch_len)

        # Number of batches per epoch, counting a final partial batch.
        n_batches, remainder = divmod(self.epoch_len, self.batch_size)
        if remainder > 0:
            n_batches += 1
        self.n_batches = n_batches

    def __iter__(self):
        """Draw the epoch's random row indices and reset the batch cursor."""
        if not self.shuffle:
            raise NotImplementedError(
                "FastTensorDataLoader only supports shuffle=True")
        # Get epoch len random indices (sample w replacement)
        self.r = np.random.randint(
            0, self.dataset_len, size=self.epoch_len)
        self.i = 0
        return self

    def __next__(self):
        """Return the next batch as a tuple with one tensor per stored tensor."""
        if self.i >= self.epoch_len:
            raise StopIteration
        batch_indices = self.r[self.i:self.i + self.batch_size]
        batch = tuple(t[batch_indices] for t in self.tensors)
        self.i += self.batch_size
        return batch

    def __len__(self):
        """Return the number of batches per epoch."""
        return self.n_batches


def tensor_dataset_from_iterable_dataset(iterable_dataset, n_epochs=1):
    """Convert an IterableDataset to a TensorDataset.

    Args:
        iterable_dataset: Dataset whose samples are equally long tuples of
            tensors (or values convertible with ``torch.as_tensor``).
        n_epochs: How many full passes over ``iterable_dataset`` to collect.

    Returns:
        A ``TensorDataset`` with one tensor per sample field; the i-th tensor
        stacks the i-th field of every collected sample along a new first
        dimension.
    """
    samples = []
    for _ in range(n_epochs):
        samples.extend(tqdm(iterable_dataset))

    # TensorDataset needs real tensors with a shared first dimension, so each
    # field's per-sample values are stacked rather than passed on as a tuple.
    columns = [
        torch.stack([torch.as_tensor(value) for value in column])
        for column in zip(*samples)
    ]
    return TensorDataset(*columns)


def return_labels(traj, discount=1.0, initial_ret=(lambda _: 0)):
    """Compute the discounted return-to-go for every step of ``traj``.

    Args:
        traj: Object with a ``rewards`` sequence.
        discount: Factor applied to the accumulated return at each step.
        initial_ret: Called with ``traj`` to obtain the return value that
            follows the final step.

    Returns:
        A list with one entry per reward, in trajectory order.
    """
    running = initial_ret(traj)
    backwards = []
    # Walk from the last step to the first, accumulating the return.
    for step_reward in reversed(traj.rewards):
        running *= discount
        running += float(step_reward)
        backwards.append(running)
    backwards.reverse()
    return backwards


class DiscretizeWrapper(gym.ActionWrapper):
    """Expose a continuous-action env through a discrete action space.

    The wrapped env must provide ``get_dataset()`` (as D4RL envs do). K-means
    with ``k`` clusters is fitted on the dataset's actions, and discrete action
    ``i`` is translated to the i-th cluster center.
    """

    def __init__(self, env, k=128):
        super().__init__(env)
        self.k = k
        dataset_actions = self.env.get_dataset()['actions']
        print(f"Running k-means on {dataset_actions.shape[0]} actions...")
        clusterer = MiniBatchKMeans(n_clusters=k, random_state=0, n_init=3)
        self.kmeans = clusterer.fit(dataset_actions)
        self.action_space = gym.spaces.Discrete(k)

    def action(self, action):
        """Map a discrete action index to its cluster-center action."""
        centers = self.kmeans.cluster_centers_
        return centers[action]

    def reverse_action(self, actions):
        """Map continuous actions to the indices of their nearest clusters."""
        return self.kmeans.predict(actions)


def dataset_traj_segments(dataset):
    """Split a flat dataset into per-trajectory dictionaries.

    A segment ends at every timestep whose "terminals" or "timeouts" entry is
    truthy; each segment holds the matching slice of every dataset key. Steps
    after the last such flag are not returned.
    """
    segments = []
    seg_start = 0
    for step_idx in range(dataset['rewards'].shape[0]):
        if not (dataset['terminals'][step_idx] or dataset['timeouts'][step_idx]):
            continue
        seg_stop = step_idx + 1
        segments.append(
            {key: dataset[key][seg_start:seg_stop] for key in dataset})
        seg_start = seg_stop
    return segments


def concatenate_datasets(datasets):
    """Join a list of datasets key by key, using the keys of the first one."""
    first = datasets[0]
    return {
        key: np.concatenate([part[key] for part in datasets])
        for key in first
    }


def slice_traj_segment(seg, start_idx, end_idx):
    """Return a new segment dict with every value sliced ``[start_idx:end_idx]``.

    Works like python list slicing applied to each entry of the segment.
    """
    window = slice(start_idx, end_idx)
    return {key: seg[key][window] for key in seg}


def slice_trajs(dataset, overlap=5, min_len=50):
    """Cut every long trajectory of the dataset into two overlapping halves.

    Trajectories longer than ``min_len`` steps are replaced by a first part
    covering steps ``[0, mid + overlap)`` and a second part covering
    ``[mid, end)``, where ``mid`` is half the trajectory length. The first part
    gets a timeout flag on its last step. Shorter trajectories are kept as is.

    The input dataset is not modified.

    Args:
        dataset: Flat dataset with at least "rewards", "terminals" and
            "timeouts" arrays.
        overlap: Number of steps shared by the two parts.
        min_len: Trajectories with at most this many steps are not split.

    Returns:
        A new flat dataset built by concatenating all resulting segments.
    """
    seg_list = []
    for seg in dataset_traj_segments(dataset):
        traj_len = seg['rewards'].shape[0]
        if traj_len <= min_len:
            seg_list.append(seg)
            continue

        mid = traj_len // 2
        begin_seg = slice_traj_segment(seg, 0, mid + overlap)
        end_seg = slice_traj_segment(seg, mid, traj_len)

        # Slices of numpy arrays are views. Copy before writing the timeout so
        # the flag neither leaks into the caller's dataset nor into the
        # overlapping steps of ``end_seg``.
        begin_seg['timeouts'] = np.array(begin_seg['timeouts'], copy=True)
        # Add timeout flag to the end of the segment
        begin_seg['timeouts'][-1] = True

        seg_list.append(begin_seg)
        seg_list.append(end_seg)
        # Double check that there is no reward in the first segment
        assert np.sum(begin_seg['rewards']) == 0
        # Double check that if there was a reward at the last timestep of
        # the original trajectory, it is in the second segment
        assert seg['rewards'][-1] == 0 or end_seg['rewards'][-1] > 0
    return concatenate_datasets(seg_list)


def preprocess_dataset(dataset, method='none'):
    """Apply the named preprocessing method to the dataset.

    ``'none'`` returns the dataset unchanged; ``'slice_trajs'`` applies
    ``slice_trajs`` three times in a row.

    Raises:
        ValueError: For any other method name.
    """
    if method == 'slice_trajs':
        return slice_trajs(slice_trajs(slice_trajs(dataset)))
    if method != 'none':
        raise ValueError(f"Unknown preprocessing method {method}")
    return dataset
