from collections import namedtuple
import numpy as np
import torch
import pdb
import joblib

# from .preprocessing import get_preprocess_fn
# from .d4rl import load_environment, sequence_dataset
# from .normalization import DatasetNormalizer
# from .buffer import ReplayBuffer

# A sampled datapoint: the stacked (action, observation) window plus the
# conditioning dict produced for it.
Batch = namedtuple('Batch', ['trajectories', 'conditions'])
# A Batch extended with the window's value target.
ValueBatch = namedtuple('ValueBatch', ['trajectories', 'conditions', 'values'])

def normalize(x, min_value, max_value):
    '''
        Affinely map ``x`` from [min_value, max_value] onto [-1, 1].

        Works element-wise on numpy arrays as well as on scalars; values outside
        the given range map outside [-1, 1].
    '''
    span = max_value - min_value
    unit = (x - min_value) / span   # now in [0, 1]
    return unit * 2 - 1             # now in [-1, 1]


def normalize_obs(obs, bounds=((-2, 2), (-1, 1), (-5, 30), (0, 1))):
    '''
        Rescale the leading observation columns from fixed ranges to [-1, 1].

        obs: 2-D array; column ``j`` (for ``j < len(bounds)``) is replaced by
            ``normalize(obs[:, j], *bounds[j])``. Remaining columns are left as-is.
            Floating-point numpy arrays are modified in place and returned.
            Non-floating numpy arrays are first converted to float64, so the
            rescaled values are not truncated when assigned back into the array.
        bounds: sequence of ``(min_value, max_value)`` pairs, one per column to
            rescale. The default reproduces the original hard-coded ranges.
        returns: the normalized array.
    '''
    if isinstance(obs, np.ndarray) and not np.issubdtype(obs.dtype, np.floating):
        # Writing float results into an integer array would silently truncate them.
        obs = obs.astype(np.float64)
    for col, (lo, hi) in enumerate(bounds):
        obs[:, col] = normalize(obs[:, col], lo, hi)
    return obs


class CarlaDataset(torch.utils.data.Dataset):
    '''
        Fixed-length windows sampled from CARLA trajectory records.

        Each record file is loaded with ``joblib.load`` and iterated as a
        collection of trajectories. Each trajectory is a sequence of per-timestep
        mappings with 'obs', 'act' and 'rew' entries; the final timestep of every
        trajectory is not used. Observations are passed through
        ``normalize_obs`` and all fields are stored as float32.

        A datapoint is ``Batch(trajectories, conditions)``:
        ``trajectories`` is the window's actions concatenated with its
        observations along the last axis, and ``conditions`` comes from
        ``get_conditions``.
    '''

    def __init__(self, records, horizon=64):
        '''
            records: iterable of paths accepted by ``joblib.load``.
            horizon: number of consecutive timesteps in each sampled window.

            Raises ValueError if no trajectory contributes any timestep.
        '''
        self.horizon = horizon

        self.obs = []
        self.action = []
        self.reward = []
        for record_file in records:
            print('loading', record_file)
            record = joblib.load(record_file)
            for trajectory in record:
                self._add_trajectory(trajectory)

        if not self.obs:
            raise ValueError('no usable trajectories were loaded from the given records')

        self.observation_dim = self.obs[0].shape[-1]
        self.action_dim = self.action[0].shape[-1]
        self.indices = self.make_indices(self.horizon)

    def _add_trajectory(self, trajectory):
        '''Stack one trajectory's timesteps (all but the last) into float32 arrays.'''
        timesteps = trajectory[:-1]
        if len(timesteps) == 0:
            # np.stack cannot stack an empty list, and such a trajectory could
            # never yield a window anyway.
            return
        obs = np.stack([step['obs'] for step in timesteps])
        action = np.stack([step['act'] for step in timesteps])
        reward = np.stack([step['rew'] for step in timesteps]).astype(np.float32)
        self.obs.append(normalize_obs(obs).astype(np.float32))
        self.action.append(action.astype(np.float32))
        # Rewards are stored 2-D: (num_timesteps, reward_dim).
        self.reward.append(reward.reshape((len(reward), -1)))

    def make_indices(self, horizon):
        '''
            makes indices for sampling from dataset;
            each index maps to a datapoint

            Returns an int64 array of shape (num_windows, 3) whose rows are
            ``(trajectory_id, start, end)`` with ``end - start == horizon``.
            Trajectories shorter than ``horizon`` contribute no rows.
        '''
        indices = [
            (i, start, start + horizon)
            for i, traj_obs in enumerate(self.obs)
            for start in range(traj_obs.shape[0] - horizon + 1)
        ]
        # reshape keeps the (N, 3) layout even when no window fits.
        return np.array(indices, dtype=np.int64).reshape(-1, 3)

    def get_conditions(self, observations):
        '''
            condition on current observation for planning

            Returns ``{0: observations[0]}``, i.e. the window's first observation
            keyed by its timestep offset.
        '''
        return {0: observations[0]}

    def __len__(self):
        '''Number of sampleable windows.'''
        return len(self.indices)

    def __getitem__(self, idx, eps=1e-4):
        '''
            Return the ``idx``-th window as ``Batch(trajectories, conditions)``.

            ``eps`` is unused; it is kept only for signature compatibility.
        '''
        trajectory_id, start, end = self.indices[idx]

        observations = self.obs[trajectory_id][start:end]
        actions = self.action[trajectory_id][start:end]

        conditions = self.get_conditions(observations)
        trajectories = np.concatenate([actions, observations], axis=-1)
        return Batch(trajectories, conditions)


class AdvDataset(CarlaDataset):
    '''
        adds a value field to the datapoints for training the value function

        Datapoints are ``ValueBatch(trajectories, conditions, values)``, where
        ``values`` is a length-1 float32 array holding the discounted sum of the
        window's rewards.
    '''

    def __init__(self, *args, discount=0.99, **kwargs):
        '''
            discount: per-step discount factor used when summing rewards.
            All other arguments are forwarded to ``CarlaDataset``.
        '''
        super().__init__(*args, **kwargs)
        self.discount = discount
        # discount**t for t in [0, horizon), shaped (horizon, 1) so it broadcasts
        # against a window's 2-D rewards. This must be sized by the window length
        # rather than by the number of trajectories; otherwise it is too short
        # (and fails to broadcast) whenever there are fewer trajectories than
        # steps in a window.
        self.discounts = self.discount ** np.arange(self.horizon)[:, None]
        # This dataset expects every trajectory to yield exactly one window.
        assert len(self.obs) == len(self.indices)

    def __getitem__(self, idx):
        '''Return the ``idx``-th window with its discounted-return value attached.'''
        batch = super().__getitem__(idx)
        trajectory_id, start, end = self.indices[idx]
        rewards = self.reward[trajectory_id][start:end]
        # Rewards after the first step of a window are expected to sum to zero
        # (presumably all reward sits on the first step — verify with the data).
        assert sum(rewards[1:]) == 0, idx
        discounts = self.discounts[:len(rewards)]
        value = (discounts * rewards).sum()
        value = np.array([value], dtype=np.float32)
        return ValueBatch(*batch, value)

# class SequenceDataset(torch.utils.data.Dataset):
#
#     def __init__(self, env='hopper-medium-replay', horizon=64,
#         normalizer='LimitsNormalizer', preprocess_fns=[], max_path_length=1000,
#         max_n_episodes=10000, termination_penalty=0, use_padding=True, seed=None):
#         self.preprocess_fn = get_preprocess_fn(preprocess_fns, env)
#         self.env = env = load_environment(env)
#         self.env.seed(seed)
#         self.horizon = horizon
#         self.max_path_length = max_path_length
#         self.use_padding = use_padding
#         itr = sequence_dataset(env, self.preprocess_fn)
#
#         fields = ReplayBuffer(max_n_episodes, max_path_length, termination_penalty)
#         for i, episode in enumerate(itr):
#             fields.add_path(episode)
#         fields.finalize()
#
#         self.normalizer = DatasetNormalizer(fields, normalizer, path_lengths=fields['path_lengths'])
#         self.indices = self.make_indices(fields.path_lengths, horizon)
#
#         self.observation_dim = fields.observations.shape[-1]
#         self.action_dim = fields.actions.shape[-1]
#         self.fields = fields
#         self.n_episodes = fields.n_episodes
#         self.path_lengths = fields.path_lengths
#         self.normalize()
#
#         print(fields)
#         # shapes = {key: val.shape for key, val in self.fields.items()}
#         # print(f'[ datasets/mujoco ] Dataset fields: {shapes}')
#
#     def normalize(self, keys=['observations', 'actions']):
#         '''
#             normalize fields that will be predicted by the diffusion model
#         '''
#         for key in keys:
#             array = self.fields[key].reshape(self.n_episodes*self.max_path_length, -1)
#             normed = self.normalizer(array, key)
#             self.fields[f'normed_{key}'] = normed.reshape(self.n_episodes, self.max_path_length, -1)
#
#     def make_indices(self, path_lengths, horizon):
#         '''
#             makes indices for sampling from dataset;
#             each index maps to a datapoint
#         '''
#         indices = []
#         for i, path_length in enumerate(path_lengths):
#             max_start = min(path_length - 1, self.max_path_length - horizon)
#             if not self.use_padding:
#                 max_start = min(max_start, path_length - horizon)
#             for start in range(max_start):
#                 end = start + horizon
#                 indices.append((i, start, end))
#         indices = np.array(indices)
#         return indices
#
#     def get_conditions(self, observations):
#         '''
#             condition on current observation for planning
#         '''
#         return {0: observations[0]}
#
#     def __len__(self):
#         return len(self.indices)
#
#     def __getitem__(self, idx, eps=1e-4):
#         path_ind, start, end = self.indices[idx]
#
#         observations = self.fields.normed_observations[path_ind, start:end]
#         actions = self.fields.normed_actions[path_ind, start:end]
#
#         conditions = self.get_conditions(observations)
#         trajectories = np.concatenate([actions, observations], axis=-1)
#         batch = Batch(trajectories, conditions)
#         return batch
#
#
# class GoalDataset(SequenceDataset):
#
#     def get_conditions(self, observations):
#         '''
#             condition on both the current observation and the last observation in the plan
#         '''
#         return {
#             0: observations[0],
#             self.horizon - 1: observations[-1],
#         }
#
#
# class ValueDataset(SequenceDataset):
#     '''
#         adds a value field to the datapoints for training the value function
#     '''
#
#     def __init__(self, *args, discount=0.99, normed=False, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.discount = discount
#         self.discounts = self.discount ** np.arange(self.max_path_length)[:,None]
#         self.normed = False
#         if normed:
#             self.vmin, self.vmax = self._get_bounds()
#             self.normed = True
#
#     def _get_bounds(self):
#         print('[ datasets/sequence ] Getting value dataset bounds...', end=' ', flush=True)
#         vmin = np.inf
#         vmax = -np.inf
#         for i in range(len(self.indices)):
#             value = self.__getitem__(i).values.item()
#             vmin = min(value, vmin)
#             vmax = max(value, vmax)
#         print('✓')
#         return vmin, vmax
#
#     def normalize_value(self, value):
#         ## [0, 1]
#         normed = (value - self.vmin) / (self.vmax - self.vmin)
#         ## [-1, 1]
#         normed = normed * 2 - 1
#         return normed
#
#     def __getitem__(self, idx):
#         batch = super().__getitem__(idx)
#         path_ind, start, end = self.indices[idx]
#         rewards = self.fields['rewards'][path_ind, start:]
#         discounts = self.discounts[:len(rewards)]
#         value = (discounts * rewards).sum()
#         if self.normed:
#             value = self.normalize_value(value)
#         value = np.array([value], dtype=np.float32)
#         value_batch = ValueBatch(*batch, value)
#         return value_batch
