import pickle
import random
import numpy as np
import torch


def convert_to_tensor(x, device):
    """Turn array-like ``x`` into a float32 tensor placed on ``device``.

    The input is passed through ``np.asarray`` first, so nested Python
    sequences and existing NumPy arrays are both accepted.
    """
    as_array = np.asarray(x)
    tensor = torch.tensor(as_array)
    return tensor.float().to(device)

def calc_cum_rewards(rewards, gamma, device):
    """Compute the discounted reward-to-go for every step of every trajectory.

    The value at step ``t`` is ``sum_{k >= t} gamma ** (k - t) * rewards[k]``,
    evaluated along the last (time) axis.

    Args:
        rewards: Tensor of shape ``(num_trajectory, horizon)``. Trailing
            singleton dimensions (e.g. ``(num_trajectory, horizon, 1)``) are
            dropped; a 0-D or 1-D tensor is treated as a single trajectory.
        gamma: Discount factor.
        device: Device the computation runs on and the result lives on.

    Returns:
        Tensor with the same shape as the normalised ``rewards``
        (``(num_trajectory, horizon)`` for the documented input).
    """
    rewards = rewards.to(device)

    # Drop only trailing singleton dims. A bare ``squeeze()`` would also remove
    # the trajectory axis when there is a single trajectory, or the horizon
    # axis when horizon == 1, which silently corrupts the result.
    while rewards.dim() > 2 and rewards.shape[-1] == 1:
        rewards = rewards.squeeze(-1)
    if rewards.dim() < 2:
        rewards = rewards.reshape(1, -1)
    if not rewards.is_floating_point():
        rewards = rewards.float()

    horizon = rewards.shape[-1]
    cum_rewards = torch.empty_like(rewards)
    running = rewards.new_zeros(rewards.shape[:-1])

    # Backward recursion G_t = r_t + gamma * G_{t+1}. Unlike multiplying by
    # gamma**t and dividing it back out, this never divides by a power of
    # gamma, so it stays finite for gamma == 0 and for long horizons where
    # gamma**t underflows. It runs once per dataset, so the loop is cheap.
    for t in reversed(range(horizon)):
        running = rewards[..., t] + gamma * running
        cum_rewards[..., t] = running
    return cum_rewards

class Dataset_DIT(torch.utils.data.Dataset):
    """Trajectory dataset that also stores cumulative rewards.

    Every sample holds the context states, actions, next states and rewards
    of one trajectory, the cumulative rewards produced by
    ``calc_cum_rewards`` for that trajectory, and a randomly drawn
    ``query_idx``.
    """

    def __init__(self, path, config):
        """Load pickled trajectories and stack them into tensors.

        Args:
            path: A single pickle file path or a list of such paths. The
                contents of all files are concatenated into ``self.trajs``.
            config: Mapping that provides 'device', 'horizon' and 'gamma'.
        """
        self.device = config['device']
        self.horizon = config['horizon']
        self.gamma = config['gamma']

        # Accept a single path as well as a list of paths.
        paths = path if isinstance(path, list) else [path]
        self.trajs = []
        for file_path in paths:
            # NOTE: unpickling can execute arbitrary code; only load files
            # from a trusted source.
            with open(file_path, 'rb') as handle:
                self.trajs.extend(pickle.load(handle))

        field_names = (
            'context_states',
            'context_actions',
            'context_next_states',
            'context_rewards',
        )
        collected = {name: [] for name in field_names}
        for traj in self.trajs:
            for name in field_names:
                collected[name].append(traj[name])

        stacked = {name: np.array(values) for name, values in collected.items()}
        # Give the rewards a trailing axis when they were stored without one.
        if stacked['context_rewards'].ndim < 3:
            stacked['context_rewards'] = stacked['context_rewards'][:, :, None]

        self.dataset = {
            name: convert_to_tensor(array, device=self.device)
            for name, array in stacked.items()
        }
        self.dataset['cumulative_rewards'] = calc_cum_rewards(
            rewards=self.dataset['context_rewards'],
            gamma=self.gamma,
            device=self.device,
        )

    def __len__(self):
        """Return the number of stored trajectories."""
        states = self.dataset['context_states']
        return len(states)

    def __getitem__(self, index):
        """Return the sample at ``index`` plus a random ``query_idx``.

        ``query_idx`` is a 0-dim integer tensor drawn from
        ``[0, self.horizon)``.
        """
        sample = {'query_idx': torch.randint(0, self.horizon, ())}
        for name in (
            'context_states',
            'context_actions',
            'context_next_states',
            'context_rewards',
            'cumulative_rewards',
        ):
            sample[name] = self.dataset[name][index]
        return sample



class Dataset_DPT(torch.utils.data.Dataset):
    """Trajectory dataset with a query state and optimal action per sample.

    Every sample holds one trajectory's query state, optimal action, and its
    context states, actions, next states and rewards.
    """

    def __init__(self, path, config):
        """Load pickled trajectories and stack them into tensors.

        Args:
            path: A single pickle file path or a list of such paths. The
                contents of all files are concatenated into ``self.trajs``.
            config: Mapping that provides 'device' and 'horizon'.
        """
        self.device = config['device']
        self.horizon = config['horizon']

        # Accept a single path as well as a list of paths.
        paths = path if isinstance(path, list) else [path]
        self.trajs = []
        for file_path in paths:
            # NOTE: unpickling can execute arbitrary code; only load files
            # from a trusted source.
            with open(file_path, 'rb') as handle:
                self.trajs.extend(pickle.load(handle))

        # (key inside each trajectory, key used in self.dataset)
        key_map = (
            ('query_state', 'query_states'),
            ('optimal_action', 'optimal_actions'),
            ('context_states', 'context_states'),
            ('context_actions', 'context_actions'),
            ('context_next_states', 'context_next_states'),
            ('context_rewards', 'context_rewards'),
        )
        collected = {dataset_key: [] for _, dataset_key in key_map}
        for traj in self.trajs:
            for traj_key, dataset_key in key_map:
                collected[dataset_key].append(traj[traj_key])

        stacked = {name: np.array(values) for name, values in collected.items()}
        # Give the rewards a trailing axis when they were stored without one.
        if stacked['context_rewards'].ndim < 3:
            stacked['context_rewards'] = stacked['context_rewards'][:, :, None]

        self.dataset = {
            name: convert_to_tensor(array, device=self.device)
            for name, array in stacked.items()
        }

    def __len__(self):
        """Return the number of stored trajectories."""
        states = self.dataset['context_states']
        return len(states)

    def __getitem__(self, index):
        """Return the sample at ``index``.

        The optimal action is exposed under the singular key
        'optimal_action', while it is stored under 'optimal_actions'.
        """
        # (key in the returned sample, key in self.dataset)
        sample_layout = (
            ('query_states', 'query_states'),
            ('optimal_action', 'optimal_actions'),
            ('context_states', 'context_states'),
            ('context_actions', 'context_actions'),
            ('context_next_states', 'context_next_states'),
            ('context_rewards', 'context_rewards'),
        )
        return {
            sample_key: self.dataset[dataset_key][index]
            for sample_key, dataset_key in sample_layout
        }
    

class Dataset_AD(torch.utils.data.Dataset):
    """Dataset that stitches trajectories from different trajectory sets.

    Each file in ``path`` is loaded as its own trajectory set (the sets are
    kept separate, ordered as given). A sample joins the tail of a trajectory
    from one set with the head of a randomly chosen trajectory from another
    set, split at a random break point.
    """

    def __init__(self, path, config):
        """Load one trajectory set per pickle file and stack them.

        Args:
            path: A single pickle file path or a list of such paths. Each
                file becomes one entry along the first axis of the stored
                tensors.
            config: Mapping that provides 'device', 'horizon' and 'gamma'.
        """
        self.device = config['device']
        self.horizon = config['horizon']
        self.gamma = config['gamma']

        # Accept a single path as well as a list of paths.
        paths = path if isinstance(path, list) else [path]
        self.trajs = []
        for file_path in paths:
            # NOTE: unpickling can execute arbitrary code; only load files
            # from a trusted source.
            with open(file_path, 'rb') as handle:
                # Each trajectory set is kept separately, by policy quality.
                self.trajs.append(pickle.load(handle))

        field_names = (
            'context_states',
            'context_actions',
            'context_next_states',
            'context_rewards',
        )
        per_field = {name: [] for name in field_names}
        for traj_set in self.trajs:
            per_set = {name: [] for name in field_names}
            for traj in traj_set:
                for name in field_names:
                    per_set[name].append(traj[name])
            for name in field_names:
                per_field[name].append(per_set[name])

        stacked = {name: np.array(values) for name, values in per_field.items()}
        # With a (set, trajectory, step) layout the rewards are already 3-D,
        # so this check does not fire for per-step scalar rewards; the
        # per-sample feature axis is handled in __getitem__.
        if len(stacked['context_rewards'].shape) < 3:
            stacked['context_rewards'] = stacked['context_rewards'][:, :, None]

        self.dataset = {
            name: convert_to_tensor(array, device=self.device)
            for name, array in stacked.items()
        }

    def __len__(self):
        """Return the number of trajectories in the first trajectory set."""
        return len(self.dataset['context_states'][0])

    def __getitem__(self, index):
        """Build one stitched sample.

        The sample is the part of trajectory ``index`` (first set) from the
        break point onwards, followed by the part of a random trajectory
        (second set) before the break point. Rewards are returned with a
        trailing feature axis; it is added only when the stored rewards do
        not already have one.
        """
        # NOTE(review): (0, 1) is listed twice, so it is drawn twice as often
        # as (1, 2) and (0, 2) is never drawn. Confirm this weighting is
        # intended rather than a typo for (0, 2).
        traj_combo = random.choice([(0, 1), (1, 2), (0, 1)])
        second_index = np.random.randint(0, len(self.dataset['context_states'][0]))
        break_point = np.random.randint(0, self.horizon)
        first_set, second_set = traj_combo

        def stitch(name):
            # Tail of the first trajectory followed by head of the second.
            data = self.dataset[name]
            return torch.cat(
                [
                    data[first_set][index][break_point:],
                    data[second_set][second_index][:break_point],
                ],
                dim=0,
            )

        query_idx = torch.randint(0, self.horizon, ())

        rewards = stitch('context_rewards')
        # Unconditionally appending an axis would yield (horizon, 1, 1) when
        # the stored rewards already carry a feature axis.
        if rewards.dim() < 2:
            rewards = rewards[..., None]

        return {
            'query_idx': query_idx,
            'context_states': stitch('context_states'),
            'context_actions': stitch('context_actions'),
            'context_next_states': stitch('context_next_states'),
            'context_rewards': rewards,
        }