import random

import numpy as np
from buffer import ReplayBuffer
from collections import namedtuple
from normalization import GaussianNormalizer
import torch

# Sample containers returned by the datasets below.
#
# Each typename matches the variable it is bound to. Pickle resolves a class
# through its module and qualified name, so a namedtuple whose typename
# differs from its binding cannot be pickled, and its repr is misleading.

# One planning sample: trajectory window, its conditioning, and a scaled return.
RewardBatch = namedtuple('RewardBatch', 'trajectories conditions returns')
# Same as RewardBatch but without the return term.
Batch = namedtuple('Batch', 'trajectories conditions')
# One transition sample: the prediction target(s), the conditioning input,
# the reward, the terminal flag, and a combined "outputs + reward" target.
SARBatch = namedtuple('SARBatch', 'outputs conditions rewards terminals outputs_r')


class SARPairs():
    """Flat dataset of transitions: (observation, action) -> (next observation, reward).

    Indexing an instance returns a ``SARBatch`` whose fields are, in order:
    next observations, conditions (observation concatenated with the
    normalized action), rewards, done flags, and next observations
    concatenated with rewards.
    """

    def __init__(self, data):
        """Build the dataset from a dict of per-transition arrays.

        Args:
            data: mapping with the keys ``'observations'``, ``'actions'``,
                ``'rewards'``, ``'terminals'`` and ``'next_observations'``.
                Each value must expose ``.ndim`` and support ``[:, None]``
                indexing (presumably NumPy arrays sharing the same first
                dimension -- TODO confirm against the caller).

        NOTE: ``data`` is modified in place by :meth:`normalize` (entries are
        replaced and a ``'norm_actions'`` entry is added).
        """
        self.fields = dict()

        data = self.normalize(data)

        self.fields['dones'] = data['terminals'].astype(np.float32)
        self.fields['observations'] = data['observations']
        self.observation_dim = len(self.fields['observations'][0])
        # Raw (un-normalized) actions are kept here; the normalized copy is
        # only used inside ``self.conditions``.
        self.fields['actions'] = data['actions']
        self.action_dim = len(self.fields['actions'][0])
        self.fields['rewards'] = data['rewards']
        self.fields['next_observations'] = data['next_observations']

        # Per-dimension bounds of the *normalized* observations and the global
        # bounds of the (un-normalized) rewards.
        self.dim_max = np.max(data['observations'], axis=0)
        self.dim_min = np.min(data['observations'], axis=0)
        self.rew_max, self.rew_min = np.max(data['rewards']), np.min(data['rewards'])

        # Model input: [normalized observation, normalized action].
        self.conditions = np.concatenate([self.fields['observations'], data['norm_actions']], axis=1)
        # Combined target: [normalized next observation, reward].
        self.s_and_r = np.concatenate([self.fields['next_observations'], self.fields['rewards']], axis=1)
        self.length = self.conditions.shape[0]

    def to_tensor(self, device='cuda'):
        """Convert the arrays served by ``__getitem__`` to float32 tensors.

        Converts ``next_observations``, ``rewards``, ``dones``, ``conditions``
        and ``s_and_r``; ``observations`` and ``actions`` are left untouched.

        Args:
            device: target device passed to ``torch.tensor``. Defaults to
                ``'cuda'`` (the previous hard-coded value); pass ``'cpu'`` or
                e.g. ``'cuda:1'`` to place the data elsewhere.
        """
        def _as_float32(array):
            return torch.tensor(array, dtype=torch.float32, device=device)

        for key in ('next_observations', 'rewards', 'dones'):
            self.fields[key] = _as_float32(self.fields[key])
        self.conditions = _as_float32(self.conditions)
        self.s_and_r = _as_float32(self.s_and_r)

    def __len__(self):
        """Return the number of transitions."""
        return self.length

    def __getitem__(self, idx):
        """Return the ``SARBatch`` for ``idx`` (any index the stored arrays accept)."""
        return SARBatch(
            self.fields['next_observations'][idx],
            self.conditions[idx],
            self.fields['rewards'][idx],
            self.fields['dones'][idx],
            self.s_and_r[idx],
        )

    def normalize(self, data):
        """Make every field at least 2-D and normalize observations/actions.

        A ``GaussianNormalizer`` is fitted on ``observations`` and on
        ``actions`` and stored in ``self.normalizer``. Observations are
        replaced by their normalized values, next observations are normalized
        with the *observation* normalizer, and normalized actions are stored
        under the new key ``'norm_actions'`` (the raw actions are kept).
        Rewards and terminals are only reshaped.

        Args:
            data: the dict described in ``__init__``; modified in place.

        Returns:
            The same ``data`` dict.
        """
        self.normalizer = {}
        # The key order matters: 'observations' must be processed before
        # 'next_observations', which reuses the observation normalizer.
        for key in ['observations', 'rewards', 'actions', 'terminals', 'next_observations']:
            if data[key].ndim < 2:
                # Promote 1-D fields to column vectors so they can be concatenated.
                data[key] = data[key][:, None]
            if key == 'observations':
                self.normalizer[key] = GaussianNormalizer(data[key])
                data[key] = self.normalizer[key](data[key])
            elif key == 'actions':
                self.normalizer[key] = GaussianNormalizer(data[key])
                data['norm_actions'] = self.normalizer[key](data[key])
            elif key == 'next_observations':
                data[key] = self.normalizer['observations'](data[key])
        return data

    def samplings0(self):
        """Return one (normalized) observation chosen uniformly at random."""
        index = random.randint(0, self.length - 1)
        return self.fields['observations'][index]

    def get_observations(self):
        """Return all stored (normalized) observations."""
        return self.fields['observations']

class SequenceDataset():
    """Dataset of fixed-length trajectory windows drawn from stored episodes.

    Episodes are loaded into a ``ReplayBuffer``; every sample is a window of
    ``horizon`` consecutive steps with actions and observations concatenated
    along the last axis, plus the first observation as conditioning and,
    optionally, a scaled discounted return.
    """

    def __init__(self, data, horizon=30, max_n_episodes=10000, max_path_length=300, termination_penalty=0,
                 discount=0.99,
                 include_returns=True, returns_scale=300) -> None:
        """Load ``data`` into a replay buffer and enumerate all sample windows.

        Args:
            data: indexable collection of episodes; ``data[i]`` is passed to
                ``ReplayBuffer.add_path`` for ``i`` in ``range(len(data))``.
            horizon: number of steps in each sampled window.
            max_n_episodes: forwarded to ``ReplayBuffer``.
            max_path_length: forwarded to ``ReplayBuffer``; also bounds the
                window start positions and the length of the discount table.
            termination_penalty: forwarded to ``ReplayBuffer``.
            discount: per-step discount factor used for returns.
            include_returns: if true, samples are ``RewardBatch`` (with
                returns); otherwise ``Batch``.
            returns_scale: divisor applied to the discounted return.
        """
        self.horizon = horizon
        fields = ReplayBuffer(max_n_episodes, max_path_length, termination_penalty)

        self.max_path_length = max_path_length
        self.include_returns = include_returns
        self.returns_scale = returns_scale

        self.discount = discount
        # Column vector [discount**0, discount**1, ...] of shape (max_path_length, 1).
        self.discounts = self.discount ** np.arange(self.max_path_length)[:, None]

        for i in range(len(data)):
            fields.add_path(data[i])
        fields.finalize()
        self.indices = self.make_indices(fields.path_lengths, horizon)
        self.fields = fields

    def get_conditions(self, observations):
        '''
            condition on current observation for planning
        '''
        return {0: observations[0]}

    def make_indices(self, path_lengths, horizon):
        '''
            makes indices for sampling from dataset;
            each index maps to a datapoint

            returns an array with one (path index, start, end) row per window
        '''
        indices = []
        for i, path_length in enumerate(path_lengths):
            max_start = min(path_length - horizon, self.max_path_length - horizon)
            # NOTE(review): range() excludes max_start itself, so the last
            # full window of each path is not generated and a path of length
            # exactly `horizon` contributes no sample. Kept as-is because
            # changing it would alter the dataset size -- confirm intent.
            for start in range(max_start):
                indices.append((i, start, start + horizon))
        return np.array(indices)

    def __len__(self):
        """Return the number of available windows."""
        return len(self.indices)

    def _window(self, path_ind, start, end):
        """Slice one window; return (trajectory, observations, returns-or-None)."""
        observations = self.fields.observations[path_ind, start:end]
        actions = self.fields.actions[path_ind, start:end]
        trajectory = np.concatenate([actions, observations], axis=-1)

        returns = None
        if self.include_returns:
            # The return covers every stored step from `start` to the end of
            # the buffer row, not only the steps inside the window.
            rewards = self.fields.rewards[path_ind, start:]
            discounts = self.discounts[:len(rewards)]
            discounted = (discounts * rewards).sum()
            returns = np.array([discounted / self.returns_scale], dtype=np.float32)
        return trajectory, observations, returns

    def __getitem__(self, idx):
        """Return one sample (integer ``idx``) or a batch of samples.

        For an integer index (Python ``int`` or NumPy integer) the result
        holds a single trajectory array, ``{0: first_observation}`` and, when
        returns are enabled, a length-1 float32 array.

        For any other index accepted by the index array (slice, list, integer
        array) the result holds *lists* with one entry per selected window:
        trajectories, ``{0: [first observations]}`` and, when enabled, returns.
        """
        # NumPy integers (e.g. from np.random.randint or a sampler) are not
        # instances of the built-in int type object, so check both.
        if isinstance(idx, (int, np.integer)):
            path_ind, start, end = self.indices[idx]
            trajectory, observations, returns = self._window(path_ind, start, end)
            conditions = self.get_conditions(observations)
            if self.include_returns:
                return RewardBatch(trajectory, conditions, returns)
            return Batch(trajectory, conditions)

        trajectories = []
        first_observations = []
        returns = []
        for path_ind, start, end in self.indices[idx]:
            trajectory, observations, return_ = self._window(path_ind, start, end)
            trajectories.append(trajectory)
            first_observations.append(observations[0])
            if self.include_returns:
                returns.append(return_)

        conditions = {0: first_observations}
        if self.include_returns:
            return RewardBatch(trajectories, conditions, returns)
        return Batch(trajectories, conditions)
