import pickle
import numpy as np
import torch


def convert_to_tensor(x, device):
    """Return ``x`` as a float32 torch tensor placed on ``device``.

    ``x`` is first viewed as a NumPy array, then copied into a new tensor,
    cast to float32 and moved to the requested device.
    """
    as_array = np.asarray(x)
    # torch.tensor always copies, so the result never aliases ``x``.
    tensor = torch.tensor(as_array)
    return tensor.float().to(device)

def calc_cum_rewards(rewards, gamma, device):
    """Compute discounted returns-to-go along the last (time) axis.

    For every time step ``t`` the result holds
    ``sum_{j >= t} gamma ** (j - t) * rewards[..., j]``.

    Args:
        rewards: Tensor of shape ``(num_trajectory, horizon)`` or
            ``(num_trajectory, horizon, 1)``. Only a trailing singleton axis
            is removed, so a single trajectory or a horizon of one keeps its
            own axis.
        gamma: Discount factor applied per time step.
        device: Device on which the computation is done and the result lives.

    Returns:
        Floating-point tensor of shape ``(num_trajectory, horizon)`` with the
        discounted return-to-go at every step.
    """
    rewards = rewards.to(device)
    # A bare ``squeeze()`` would also drop the trajectory axis when there is
    # only one trajectory; remove nothing but the trailing reward axis.
    if rewards.dim() > 2 and rewards.shape[-1] == 1:
        rewards = rewards.squeeze(-1)
    if not rewards.is_floating_point():
        rewards = rewards.float()

    horizon = rewards.shape[-1]
    cdr = torch.empty_like(rewards)
    running = rewards.new_zeros(rewards.shape[:-1])
    # Backward recurrence G_t = r_t + gamma * G_{t+1}. Unlike dividing a
    # discounted cumulative sum by gamma ** t, this stays finite when
    # gamma == 0 or when gamma ** t underflows for long horizons.
    for t in range(horizon - 1, -1, -1):
        running = rewards[..., t] + gamma * running
        cdr[..., t] = running
    return cdr

class Dataset(torch.utils.data.Dataset):
    """Trajectory dataset yielding a random K-step window plus a short prompt.

    All trajectories are loaded eagerly, stacked into tensors on the
    configured device, and their discounted returns-to-go are precomputed.
    Each item is a randomly positioned window of ``K`` steps from the
    requested trajectory together with the first ``K // 4`` steps of a
    randomly chosen "prompt" trajectory.
    """

    def __init__(self, path, config):
        """Load trajectories from one or more pickle files.

        Args:
            path: A single file path or a list of file paths. Each file must
                unpickle to a sequence of dicts with the keys
                ``context_states``, ``context_actions``,
                ``context_next_states`` and ``context_rewards``.
            config: Mapping providing ``device``, ``horizon``, ``gamma``
                and ``K``.
        """
        super().__init__()
        self.device = config['device']
        self.shuffle = False
        self.horizon = config['horizon']
        self.gamma = config['gamma']
        self.K = config['K']
        self.prompt_K = self.K // 4

        # Accept a single path as well as a list of paths.
        if not isinstance(path, list):
            path = [path]
        self.trajs = []
        for p in path:
            # SECURITY: pickle.load executes arbitrary code embedded in the
            # file; only load trajectory files from trusted sources.
            with open(p, 'rb') as f:
                self.trajs += pickle.load(f)

        keys = (
            'context_states',
            'context_actions',
            'context_next_states',
            'context_rewards',
        )
        arrays = {
            key: np.array([traj[key] for traj in self.trajs]) for key in keys
        }
        # Give scalar per-step rewards an explicit trailing axis.
        if arrays['context_rewards'].ndim < 3:
            arrays['context_rewards'] = arrays['context_rewards'][:, :, None]

        self.dataset = {
            key: convert_to_tensor(arr, device=self.device)
            for key, arr in arrays.items()
        }
        self.dataset['cumulative_rewards'] = calc_cum_rewards(
            rewards=self.dataset['context_rewards'],
            gamma=self.gamma,
            device=self.device,
        )

    def __len__(self):
        """Return the number of loaded trajectories."""
        return len(self.dataset['context_states'])

    def __getitem__(self, index):
        """Sample one training window and one prompt for trajectory ``index``.

        Args:
            index: Index of the trajectory the main window is cut from.

        Returns:
            Dict with the window start (``start``), the ``K``-step slices
            ``states``, ``actions``, ``returns_to_go`` (with an added trailing
            axis) and ``timesteps``, plus the matching ``prompt_*`` entries
            covering the first ``prompt_K`` steps of the prompt trajectory.

        Raises:
            ValueError: If ``K`` is larger than ``horizon``.
        """
        dataset_size = len(self.dataset['context_states'])

        last_start = self.horizon - self.K
        if last_start < 0:
            raise ValueError(
                f"window length K={self.K} exceeds horizon={self.horizon}"
            )
        # randint's upper bound is exclusive; +1 makes the final window
        # (ending at the last step) reachable and supports horizon == K.
        start = np.random.randint(0, last_start + 1)
        stop = start + self.K

        # The offset is uniform in [-dataset_size, dataset_size); results
        # outside the valid range are clamped to the first/last trajectory.
        offset = np.random.randint(-dataset_size, dataset_size)
        prompt_idx = min(max(0, offset + index), dataset_size - 1)

        states = self.dataset['context_states']
        actions = self.dataset['context_actions']
        returns = self.dataset['cumulative_rewards']

        # NOTE: the timestep tensors are created by torch.arange without a
        # device argument, so they live on torch's default device.
        res = {
            'start': start,
            'states': states[index][start:stop],
            'actions': actions[index][start:stop],
            'returns_to_go': returns[index][start:stop][:, None],
            'timesteps': torch.arange(start, stop),
            'prompt_states': states[prompt_idx][:self.prompt_K],
            'prompt_actions': actions[prompt_idx][:self.prompt_K],
            'prompt_returns_to_go': returns[prompt_idx][:self.prompt_K][:, None],
            'prompt_timesteps': torch.arange(0, self.prompt_K),
        }

        return res

