import torch
import numpy as np

import pickle

from .utils import convert_to_tensor

def calc_cum_rewards(rewards, gamma, device):
    """Compute the discounted reward-to-go for every step of every trajectory.

    For each time step ``t`` the result is ``G_t = sum_{k >= t} gamma**(k - t) * r_k``,
    computed along the last axis.

    Args:
        rewards: Reward tensor. After ``squeeze()`` it is expected to be
            (num_trajectory, horizon); a single trajectory squeezes to (horizon,).
        gamma: Discount factor.
        device: Device on which the computation is performed.

    Returns:
        A detached CPU tensor with the same shape as the squeezed ``rewards``.
    """
    rewards = rewards.squeeze().to(device)
    if not torch.is_floating_point(rewards):
        # Integer rewards would otherwise truncate the discounted sums.
        rewards = rewards.to(torch.get_default_dtype())

    # Backward recursion G_t = r_t + gamma * G_{t+1}. Unlike dividing a cumulative sum
    # by gamma**t, this neither underflows to 0 (-> inf/nan) for long horizons nor
    # divides by zero when gamma == 0, and it broadcasts over any leading dimensions.
    returns = torch.empty_like(rewards)
    running = rewards.new_zeros(rewards.shape[:-1])
    for t in range(rewards.shape[-1] - 1, -1, -1):
        running = rewards[..., t] + gamma * running
        returns[..., t] = running

    return returns.detach().cpu()

class NormalDatasetBatch(torch.utils.data.Dataset):
    """Map-style dataset of single in-context trajectories loaded from pickle files.

    Each pickle file at ``path`` must hold a list of dicts with the keys
    ``context_states``, ``context_actions``, ``context_next_states``,
    ``context_rewards``, ``query_state`` and ``optimal_action``. Only the first
    ``num_trajectory`` entries (in load order) are kept.

    ``__getitem__`` returns a dict with the keys ``context_states``,
    ``context_actions``, ``context_next_states``, ``context_rewards``,
    ``query_states`` and ``optimal_actions``, each moved to ``device``.
    """

    # Per-step context fields; shuffling permutes them with one shared permutation.
    _CONTEXT_KEYS = ('context_states', 'context_actions', 'context_next_states', 'context_rewards')

    def __init__(self, path, config, device, num_trajectory=20000):
        """Load trajectories from disk and convert them to tensors.

        Args:
            path: A file path, or a list of file paths, to pickled trajectory lists.
            config: Object exposing ``shuffle``, ``horizon`` and ``store_gpu`` attributes.
            device: Device passed to ``convert_to_tensor`` and used in ``__getitem__``.
            num_trajectory: Maximum number of trajectories to keep.

        Raises:
            ValueError: If no trajectories remain after loading and truncation.
        """
        self._config = config
        self._device = device

        # When true, __getitem__ randomly permutes the time steps of the context.
        self._shuffle = self._config.shuffle
        self._horizon = self._config.horizon
        self._store_gpu = self._config.store_gpu

        if not isinstance(path, list):
            path = [path]

        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        self.trajs = []
        for p in path:
            with open(p, 'rb') as f:
                self.trajs += pickle.load(f)

        # Truncate every field to the same trajectories so they stay index-aligned and
        # no tensors are built for rows that __len__ makes unreachable.
        kept = self.trajs[:num_trajectory]
        if not kept:
            raise ValueError(
                f"No trajectories available from {path} (num_trajectory={num_trajectory})"
            )

        arrays = {key: np.array([traj[key] for traj in kept]) for key in self._CONTEXT_KEYS}
        if arrays['context_rewards'].ndim < 3:
            # Add a trailing axis so rewards become 3-D, presumably
            # (num_trajectory, context_length, 1).
            arrays['context_rewards'] = arrays['context_rewards'][:, :, None]
        arrays['query_states'] = np.array([traj['query_state'] for traj in kept])
        arrays['optimal_actions'] = np.array([traj['optimal_action'] for traj in kept])

        # convert_to_tensor (from .utils) receives device/store_gpu; whether data lives on
        # the GPU is decided there.
        self.dataset = {
            key: convert_to_tensor(value, device=self._device, store_gpu=self._store_gpu)
            for key, value in arrays.items()
        }

    def __len__(self):
        """Return the number of stored trajectories."""
        return len(self.dataset['context_states'])

    def __getitem__(self, index):
        """Return the fields of trajectory ``index`` on ``self._device``.

        If shuffling is enabled, the context time steps are permuted with a single
        permutation shared by all context fields so (s, a, s', r) tuples stay aligned.
        """
        res = {key: tensor[index].to(self._device) for key, tensor in self.dataset.items()}
        if self._shuffle:
            # Use the actual context length rather than config.horizon so a mismatch can
            # neither raise an out-of-range index nor silently drop time steps.
            perm = torch.randperm(res['context_states'].shape[0])
            for key in self._CONTEXT_KEYS:
                res[key] = res[key][perm]

        return res
        

class PreferenceDatasetBatch(torch.utils.data.Dataset):
    """Map-style dataset of trajectory pairs labelled with a preference.

    Input: each pickle file at ``path`` holds a list of dicts of the form::

        {
            'traj_1': {'context_states': ..., 'context_actions': ...,
                       'context_next_states': ..., 'context_rewards': ...},
            'traj_2': {... same keys ...},
            'preference': int,  # 0 -> traj_1 preferred, any other value -> traj_2
        }

    Other keys in a pair (e.g. ``'preference_probs'``) are not read. While loading,
    each pair is reordered so that ``traj_1`` is always the preferred trajectory and
    ``traj_2`` the other one.

    Output: ``__getitem__`` returns ``{'traj_1': {...}, 'traj_2': {...}}`` where each
    inner dict holds the four context fields with a leading axis of size 1.
    ``batch_collate_fn`` concatenates items along that axis; the resulting tensors
    are presumably (batch_size, context_length, dim), with rewards
    (batch_size, context_length, 1) — confirm against the stored data.
    """

    # Per-step fields stored for each trajectory of a pair.
    _CONTEXT_KEYS = ('context_states', 'context_actions', 'context_next_states', 'context_rewards')
    # Output slots: 'traj_1' is the preferred trajectory, 'traj_2' the other one.
    _TRAJ_NAMES = ('traj_1', 'traj_2')

    def __init__(self, path, config, device, num_pairs=50000):
        """Load, subsample and preference-order trajectory pairs.

        Args:
            path: A file path, or a list of file paths, to pickled lists of pairs.
            config: Object exposing ``shuffle``, ``horizon`` and ``store_gpu`` attributes.
            device: Device passed to ``convert_to_tensor`` and used in ``__getitem__``.
            num_pairs: Maximum number of pairs kept; a random subset (without
                replacement, drawn from NumPy's global RNG) is used when more exist.
        """
        self._config = config
        self._device = device

        # When true, __getitem__ randomly permutes the time steps of each context.
        self._shuffle = self._config.shuffle
        self._horizon = self._config.horizon
        self._store_gpu = self._config.store_gpu

        if not isinstance(path, list):
            path = [path]

        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        self.traj_pairs = []
        for p in path:
            with open(p, 'rb') as f:
                self.traj_pairs += pickle.load(f)

        num_kept = min(num_pairs, len(self.traj_pairs))
        indices = np.random.choice(len(self.traj_pairs), num_kept, replace=False)
        self.traj_pairs = [self.traj_pairs[ind] for ind in indices]

        columns = {
            f'{name}_{key}': [] for name in self._TRAJ_NAMES for key in self._CONTEXT_KEYS
        }
        for traj_pair in self.traj_pairs:
            # preference == 0 means traj_1 is preferred; otherwise traj_2 is preferred.
            if traj_pair['preference'] == 0:
                preferred, other = traj_pair['traj_1'], traj_pair['traj_2']
            else:
                preferred, other = traj_pair['traj_2'], traj_pair['traj_1']
            for key in self._CONTEXT_KEYS:
                columns[f'traj_1_{key}'].append(preferred[key])
                columns[f'traj_2_{key}'].append(other[key])

        # convert_to_tensor (from .utils) receives device/store_gpu; whether data lives on
        # the GPU is decided there.
        self.dataset = {
            name: convert_to_tensor(np.array(values), device=self._device, store_gpu=self._store_gpu)
            for name, values in columns.items()
        }

    def __len__(self):
        """Return the number of stored trajectory pairs."""
        return len(self.dataset['traj_1_context_states'])

    def __getitem__(self, index):
        """Return pair ``index`` as ``{'traj_1': {...}, 'traj_2': {...}}`` on ``self._device``.

        Every tensor gets a leading axis of size 1 so ``batch_collate_fn`` can
        concatenate along dim 0. With shuffling enabled, each trajectory's time steps
        are permuted by one permutation shared across its four fields, sized by the
        trajectory's actual context length.
        """
        res = {}
        for name in self._TRAJ_NAMES:
            traj = {
                key: self.dataset[f'{name}_{key}'][index][None].to(self._device)
                for key in self._CONTEXT_KEYS
            }
            if self._shuffle:
                perm = torch.randperm(traj['context_states'].shape[1])
                traj = {key: value[:, perm] for key, value in traj.items()}
            res[name] = traj

        return res

    @staticmethod
    def batch_collate_fn(batch):
        """Collate a list of ``__getitem__`` outputs by concatenating along dim 0."""
        return {
            name: {
                key: torch.cat([b[name][key] for b in batch], dim=0)
                for key in PreferenceDatasetBatch._CONTEXT_KEYS
            }
            for name in PreferenceDatasetBatch._TRAJ_NAMES
        }