import pickle
import random

from pathlib import Path

import numpy as np
import torch

from torch.utils.data import (
    Dataset,
    WeightedRandomSampler,
)


def discount_cumsum(x, gamma):
    """Return the discounted suffix sums of ``x`` along its first axis.

    The result satisfies ``out[-1] = x[-1]`` and
    ``out[t] = x[t] + gamma * out[t + 1]`` for every earlier ``t``.

    Args:
        x: Array-like of per-step values; the recursion runs over axis 0.
        gamma: Discount factor applied to the already accumulated tail.

    Returns:
        An array with the same shape as ``x``. Floating-point inputs keep
        their dtype; other inputs (e.g. integers) are promoted together with
        ``gamma`` so a fractional discount is not truncated. An input with
        zero length along axis 0 yields an equally empty array.
    """
    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.inexact):
        out_dtype = x.dtype
    else:
        # Integer/bool rewards would otherwise truncate ``gamma * tail``.
        out_dtype = np.result_type(x, gamma)

    out = np.zeros(x.shape, dtype=out_dtype)
    if x.shape[0] == 0:
        return out

    out[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out


class SingleTaskDataset(Dataset):
    """Trajectory dataset that serves fixed-length, end-padded sub-sequences.

    The dataset file is a pickle holding a sequence of trajectory dicts. Each
    dict is read through the keys ``"observations"``, ``"actions"``,
    ``"rewards"`` and either ``"terminals"`` or ``"dones"``.

    Indexing the dataset picks a random start step inside the selected
    trajectory and returns a window of at most ``context_size`` steps, padded
    at the end up to ``context_size``.
    """

    def __init__(
        self,
        dataset_path: Path,
        context_size: int = 20,
        pct_traj: float = 1.0,
        normalize_inputs: bool = True,
        delayed: bool = False,
        max_ep_len: int = 200,
        scale: float = 1.0,
        add_one_hot: bool = False,
        one_hot_idx: int = 0,
        one_hot_len: int = 1,
    ):
        """Load the trajectories and pre-compute dataset statistics.

        Args:
            dataset_path: Location of the pickled trajectories (a ``str`` is
                accepted and converted to ``Path``).
            context_size: Number of steps in every returned window.
            pct_traj: Fraction of all timesteps to keep, taken from the
                highest-return trajectories.
            normalize_inputs: Whether returned states are standardized with
                the dataset mean/std.
            delayed: If true, each trajectory's rewards are rewritten in place
                so the whole return sits on the last step.
            max_ep_len: Returned timestep indices are clipped to
                ``max_ep_len - 1``.
            scale: Divisor applied to the returns-to-go.
            add_one_hot: Whether to append a one-hot vector to every
                observation.
            one_hot_idx: Position of the 1 inside the one-hot vector.
            one_hot_len: Length of the one-hot vector.
        """
        super().__init__()

        # Coerce early: `.stem` is used for the summary below, and a plain
        # string would otherwise only fail after the whole file was loaded.
        dataset_path = Path(dataset_path)

        self.dataset_path = dataset_path
        self.context_size = context_size
        self.pct_traj = pct_traj
        self.normalize_inputs = normalize_inputs
        self.delayed = delayed
        self.max_ep_len = max_ep_len
        self.scale = scale
        self.add_one_hot = add_one_hot
        self.one_hot_idx = one_hot_idx
        self.one_hot_len = one_hot_len

        # NOTE: unpickling can execute arbitrary code; only load dataset files
        # that come from a trusted source.
        with open(dataset_path, "rb") as f:
            self.trajectories = pickle.load(f)

        # used for input normalization; computed on all trajectories, i.e.
        # before the pct_traj filtering below
        self.state_mean, self.state_std = self.calculate_state_stats(self.trajectories)

        # only train on top pct_traj trajectories (for %BC experiment)
        self.trajectories = self.top_pct_traj(self.trajectories, pct_traj=self.pct_traj, delayed=self.delayed)

        if self.add_one_hot:
            self.trajectories_to_one_hot()
            # mean 0 / std 1 leaves the one-hot entries untouched by normalization
            self.state_mean = np.concatenate([self.state_mean, [0] * self.one_hot_len])
            self.state_std = np.concatenate([self.state_std, [1] * self.one_hot_len])

        self.state_dim = self.trajectories[0]["observations"].shape[-1]
        self.act_dim = self.trajectories[0]["actions"].shape[-1]

        self.traj_lens = self.calculate_traj_lens(self.trajectories)
        self.returns = self.calculate_returns(self.trajectories, delayed=self.delayed)
        self.num_timesteps = sum(self.traj_lens)

        print("=" * 50)
        print(f"Dataset: {dataset_path.stem}")
        print(f"{len(self.traj_lens)} trajectories, {self.num_timesteps} timesteps found")
        print(f"Average return: {np.mean(self.returns):.2f}, std: {np.std(self.returns):.2f}")
        print(f"Max return: {np.max(self.returns):.2f}, min: {np.min(self.returns):.2f}")
        print("=" * 50)

    def create_weighted_sampler(self):
        """Build a sampler that draws trajectories proportionally to their length.

        Returns:
            A ``WeightedRandomSampler`` that yields ``len(self)`` indices per
            pass, weighted by each trajectory's share of all timesteps.
        """
        length_share = self.traj_lens / self.traj_lens.sum()
        weights = torch.from_numpy(length_share).float()
        return WeightedRandomSampler(weights, len(weights))

    def trajectories_to_one_hot(self):
        """Append the configured one-hot vector to every observation, in place."""
        one_hot = np.zeros(self.one_hot_len)
        one_hot[self.one_hot_idx] = 1
        for traj in self.trajectories:
            states = traj["observations"]
            # one copy of the one-hot row per step of the trajectory
            tiled = one_hot[np.newaxis, :].repeat(len(states), axis=0)
            traj["observations"] = np.concatenate([states, tiled], axis=1)

    @staticmethod
    def top_pct_traj(trajectories, *, pct_traj: float = 1.0, delayed: bool = False):
        """Keep the highest-return trajectories within a timestep budget.

        The budget is ``pct_traj`` times the total number of timesteps (at
        least 1). The best trajectory is always kept; further trajectories are
        added in order of decreasing return until the next one would exceed
        the budget.

        Args:
            trajectories: Sequence of trajectory dicts.
            pct_traj: Fraction of all timesteps to keep.
            delayed: Forwarded to ``calculate_returns`` (which then rewrites
                the rewards in place).

        Returns:
            A list of the kept trajectories ordered from lowest to highest
            return; an empty list if ``trajectories`` is empty.
        """
        if len(trajectories) == 0:
            return []

        traj_lens = SingleTaskDataset.calculate_traj_lens(trajectories)
        returns = SingleTaskDataset.calculate_returns(trajectories, delayed=delayed)

        budget = max(int(pct_traj * traj_lens.sum()), 1)

        order = np.argsort(returns)  # lowest to highest
        kept = 1
        used = traj_lens[order[-1]]
        ind = len(trajectories) - 2
        while ind >= 0 and used + traj_lens[order[ind]] <= budget:
            used += traj_lens[order[ind]]
            kept += 1
            ind -= 1

        return [trajectories[i] for i in order[-kept:]]

    @staticmethod
    def calculate_returns(trajectories, *, delayed: bool = False):
        """Return the undiscounted return of every trajectory as an array.

        With ``delayed=True`` each trajectory's ``"rewards"`` array is modified
        in place: the last entry receives the full sum and all earlier entries
        are set to zero.
        """
        returns = []
        for path in trajectories:
            if delayed:  # delayed: all rewards moved to end of trajectory
                path["rewards"][-1] = path["rewards"].sum()
                path["rewards"][:-1] = 0.0
            returns.append(path["rewards"].sum())
        return np.array(returns)

    @staticmethod
    def calculate_traj_lens(trajectories):
        """Return the number of observations of every trajectory as an array."""
        return np.array([len(path["observations"]) for path in trajectories])

    @staticmethod
    def calculate_state_stats(trajectories):
        """Return per-dimension mean and std of all observations.

        A small constant (1e-6) is added to the std so that dividing by it is
        safe for constant dimensions.
        """
        states = np.concatenate([path["observations"] for path in trajectories], axis=0)
        state_mean = np.mean(states, axis=0)
        state_std = np.std(states, axis=0) + 1e-6
        return state_mean, state_std

    def __len__(self):
        """Return the number of trajectories."""
        return len(self.trajectories)

    def __getitem__(self, idx):
        """Sample a random window from trajectory ``idx``.

        Returns:
            A tuple ``(states, actions, rewards, dones, rtg, timesteps, mask)``
            of NumPy arrays. All have ``context_size`` rows except ``rtg``,
            which has ``context_size + 1``. ``mask`` is 1 for real steps and 0
            for padding and is returned without a dtype cast.
        """
        traj = self.trajectories[idx]

        traj_states = traj["observations"]
        traj_actions = traj["actions"]
        traj_rewards = traj["rewards"]
        traj_dones = traj["terminals"] if "terminals" in traj else traj["dones"]
        traj_mask = np.ones(traj_states.shape[0])

        # Windows are padded at the end only; padding from the start was
        # turned off, as it breaks DT.
        si = random.randint(0, traj_rewards.shape[0] - 1)
        window = slice(si, si + self.context_size)

        states = traj_states[window]
        actions = traj_actions[window]
        rewards = traj_rewards[window].reshape(-1, 1)
        dones = traj_dones[window]
        mask = traj_mask[window]

        tlen = states.shape[0]
        pad = self.context_size - tlen

        timesteps = np.arange(si, si + tlen)
        timesteps[timesteps >= self.max_ep_len] = self.max_ep_len - 1  # padding cutoff

        # returns-to-go carry one extra entry for the step after the window;
        # when the trajectory ends inside the window that entry is zero
        rtg = discount_cumsum(traj_rewards[si:], gamma=1.0)[: tlen + 1].reshape(-1, 1)
        if rtg.shape[0] <= tlen:
            rtg = np.concatenate([rtg, np.zeros((1, 1))], axis=0)

        # states are zero-padded before normalization
        states = np.concatenate([states, np.zeros((pad, self.state_dim))], axis=0)
        if self.normalize_inputs:
            states = (states - self.state_mean) / self.state_std

        # actions are padded with -10 and dones with 2 to enable easier
        # debugging, we mask them either way
        actions = np.concatenate([actions, np.ones((pad, self.act_dim)) * -10.0], axis=0)
        rewards = np.concatenate([rewards, np.zeros((pad, 1))], axis=0)
        dones = np.concatenate([dones, np.ones(pad) * 2], axis=0)
        rtg = np.concatenate([rtg, np.zeros((pad, 1))], axis=0) / self.scale
        timesteps = np.concatenate([timesteps, np.zeros(pad)], axis=0)
        mask = np.concatenate([mask, np.zeros(pad)], axis=0)

        return (
            states.astype(np.float32),
            actions.astype(np.float32),
            rewards.astype(np.float32),
            dones.astype(np.int64),
            rtg.astype(np.float32),
            timesteps.astype(np.int64),
            mask,
        )
