import torch
import numpy as np

import pickle

from utils import convert_to_tensor


def calc_cum_rewards(rewards, gamma, device):
    """Compute the discounted reward-to-go for every step of every trajectory.

    For each time step ``t`` the result is ``sum_{k >= t} gamma**(k - t) * r_k``.

    Args:
        rewards: reward tensor of shape ``(num_trajectory, horizon)``, optionally
            with trailing singleton dims such as ``(num_trajectory, horizon, 1)``.
            Time runs along the last remaining axis; any leading batch shape
            (including a single trajectory) is supported.
        gamma: discount factor.
        device: device on which the computation is performed.

    Returns:
        A detached CPU tensor of the same dtype as ``rewards`` (integer rewards
        are promoted to float32), shaped like ``rewards`` with the trailing
        singleton dims removed, i.e. ``(num_trajectory, horizon)`` for the
        usual ``(num_trajectory, horizon, 1)`` input.
    """
    rewards = rewards.to(device)
    # Strip only trailing singleton dims. A bare ``squeeze()`` would also drop
    # the trajectory axis when num_trajectory == 1 (or the time axis when
    # horizon == 1) and silently produce a wrongly-shaped result.
    while rewards.dim() > 2 and rewards.shape[-1] == 1:
        rewards = rewards.squeeze(-1)
    if not rewards.is_floating_point():
        rewards = rewards.float()

    # Backward recursion G_t = r_t + gamma * G_{t+1}. The previous formulation
    # (cumsum of gamma**t-scaled rewards divided by gamma**t) yields 0/0 = NaN
    # for gamma == 0 and once gamma**t underflows on long horizons.
    returns = torch.empty_like(rewards)
    running = torch.zeros_like(rewards[..., 0])
    for t in reversed(range(rewards.shape[-1])):
        running = rewards[..., t] + gamma * running
        returns[..., t] = running

    return returns.detach().cpu()

class PrefDatasetBatch(torch.utils.data.Dataset):
    """Preference Dataset class.

    Loads pickled trajectory dicts, randomly keeps up to ``num_trajs`` of them
    (numpy global RNG, without replacement) and stacks every field into a
    tensor. Each sample pairs a context of preferred / non-preferred actions
    with a query state and its optimal action.

    Args:
        path: a single pickle file path or a list of them; each file must
            unpickle to a sequence of trajectory dicts with keys
            'context_states', 'pref_actions', 'non_pref_actions',
            'context_next_states', 'context_rewards', 'query_state' and
            'optimal_action'.
        config: dict providing 'shuffle', 'horizon', 'store_gpu',
            'state_dim' and 'action_dim'.
        device: device handed to ``convert_to_tensor``.
        gamma: accepted for signature parity with DatasetBatch; unused here
            (no discounted cumulative rewards are computed for this class).
        num_trajs: maximum number of trajectories kept after subsampling.
    """

    # (trajectory-dict key, self.dataset key), listed in self.dataset order.
    _FIELD_MAP = (
        ('query_state', 'query_states'),
        ('optimal_action', 'optimal_actions'),
        ('context_states', 'context_states'),
        ('pref_actions', 'context_pr_actions'),
        ('non_pref_actions', 'context_npr_actions'),
        ('context_next_states', 'context_next_states'),
        ('context_rewards', 'context_rewards'),
    )

    # Per-step context tensors; they share one permutation when shuffling.
    _CONTEXT_KEYS = (
        'context_states',
        'context_pr_actions',
        'context_npr_actions',
        'context_next_states',
        'context_rewards',
    )

    def __init__(self, path, config, device, gamma=0.8, num_trajs=50000):
        self.shuffle = config['shuffle']
        self.horizon = config['horizon']
        self.store_gpu = config['store_gpu']
        self.config = config
        self.device = device

        file_paths = path if isinstance(path, list) else [path]

        loaded = []
        for file_path in file_paths:
            # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
            with open(file_path, 'rb') as handle:
                loaded += pickle.load(handle)

        keep = min(num_trajs, len(loaded))
        chosen = np.random.choice(range(len(loaded)), keep, replace=False)
        self.trajs = [loaded[i] for i in chosen]

        arrays = {
            dst: np.array([traj[src] for traj in self.trajs])
            for src, dst in self._FIELD_MAP
        }
        if arrays['context_rewards'].ndim < 3:
            # Give per-step scalar rewards an explicit trailing axis of size 1.
            arrays['context_rewards'] = arrays['context_rewards'][:, :, None]

        self.dataset = {
            key: convert_to_tensor(value, self.device, store_gpu=self.store_gpu)
            for key, value in arrays.items()
        }

        # Zero vector; per the original author it is longer than strictly required.
        zeros_len = config['state_dim'] ** 2 + config['action_dim'] + 1
        self.zeros = convert_to_tensor(np.zeros(zeros_len), self.device, store_gpu=self.store_gpu)

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.dataset['query_states'])

    def __getitem__(self, index):
        'Generates one sample of data'
        sample = {key: self.dataset[key][index] for key in self._CONTEXT_KEYS}
        sample['query_states'] = self.dataset['query_states'][index]
        sample['optimal_actions'] = self.dataset['optimal_actions'][index]
        sample['zeros'] = self.zeros

        if self.shuffle:
            order = torch.randperm(self.horizon)
            for key in self._CONTEXT_KEYS:
                sample[key] = sample[key][order]

        return sample

class DatasetBatch(torch.utils.data.Dataset):
    """Dataset class.

    Loads pickled trajectory dicts, randomly keeps up to ``num_trajs`` of them
    (numpy global RNG, without replacement), stacks every field into a tensor
    and precomputes the discounted reward-to-go of every context step.

    Args:
        path: a single pickle file path or a list of them; each file must
            unpickle to a sequence of trajectory dicts with keys
            'context_states', 'context_actions', 'context_next_states',
            'context_rewards', 'query_state' and 'optimal_action'.
        config: dict providing 'shuffle', 'horizon', 'store_gpu',
            'state_dim' and 'action_dim'.
        device: device handed to ``convert_to_tensor`` and used for the
            reward-to-go computation.
        gamma: discount factor for ``context_cum_rewards``.
        num_trajs: maximum number of trajectories kept after subsampling.
    """

    # Per-step context tensors; they share one permutation when shuffling.
    _CONTEXT_KEYS = (
        'context_states',
        'context_actions',
        'context_next_states',
        'context_rewards',
        'context_cum_rewards',
    )

    def __init__(self, path, config, device, gamma=0.8, num_trajs=50000):
        self.shuffle = config['shuffle']
        self.horizon = config['horizon']
        self.store_gpu = config['store_gpu']
        self.config = config
        self.device = device

        # Accept a single path as well as a list of paths.
        if not isinstance(path, list):
            path = [path]

        self.trajs = []
        for p in path:
            # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
            with open(p, 'rb') as f:
                self.trajs += pickle.load(f)

        indices = np.random.choice(range(len(self.trajs)), min(num_trajs, len(self.trajs)), replace=False)
        self.trajs = [self.trajs[ind] for ind in indices]

        context_states = np.array([traj['context_states'] for traj in self.trajs])
        context_actions = np.array([traj['context_actions'] for traj in self.trajs])
        context_next_states = np.array([traj['context_next_states'] for traj in self.trajs])
        context_rewards = np.array([traj['context_rewards'] for traj in self.trajs])
        if context_rewards.ndim < 3:
            # Give per-step scalar rewards an explicit trailing axis of size 1.
            context_rewards = context_rewards[:, :, None]
        query_states = np.array([traj['query_state'] for traj in self.trajs])
        optimal_actions = np.array([traj['optimal_action'] for traj in self.trajs])

        def to_tensor(array):
            return convert_to_tensor(array, self.device, store_gpu=self.store_gpu)

        self.dataset = {
            'query_states': to_tensor(query_states),
            'optimal_actions': to_tensor(optimal_actions),
            'context_states': to_tensor(context_states),
            'context_actions': to_tensor(context_actions),
            'context_next_states': to_tensor(context_next_states),
            'context_rewards': to_tensor(context_rewards),
        }

        print('calculating coefficients')
        cum_rewards = calc_cum_rewards(gamma=gamma,
                                       rewards=self.dataset['context_rewards'],
                                       device=self.device)
        # calc_cum_rewards returns a CPU tensor; move it next to the rewards it
        # was derived from so every tensor in a sample lives on the same device
        # (previously it stayed on CPU even when store_gpu placed the rest elsewhere).
        self.dataset['context_cum_rewards'] = cum_rewards.to(self.dataset['context_rewards'].device)

        self.zeros = to_tensor(np.zeros(config['state_dim'] ** 2 + config['action_dim'] + 1))

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.dataset['query_states'])

    def __getitem__(self, index):
        'Generates one sample of data'
        res = {
            'context_states': self.dataset['context_states'][index],
            'context_actions': self.dataset['context_actions'][index],
            'context_next_states': self.dataset['context_next_states'][index],
            'context_rewards': self.dataset['context_rewards'][index],
            'query_states': self.dataset['query_states'][index],
            'optimal_actions': self.dataset['optimal_actions'][index],
            'zeros': self.zeros,
            'context_cum_rewards': self.dataset['context_cum_rewards'][index],
        }

        if self.shuffle:
            # One permutation for all per-step tensors keeps them aligned.
            perm = torch.randperm(self.horizon)
            for key in self._CONTEXT_KEYS:
                res[key] = res[key][perm]

        return res


class ImageDataset(DatasetBatch):
    """Dataset class for image-based data.

    Extends DatasetBatch with query images, which are transformed once in
    ``__init__``, and context images, which are read from per-trajectory
    files with ``np.load`` and transformed on every ``__getitem__`` call.

    Args:
        paths: a single pickle file path or a list of them (see DatasetBatch);
            trajectories additionally need 'context_images' (a path readable by
            ``np.load``) and 'query_image' keys.
        config: dataset config (see DatasetBatch). It is copied, and the copy
            has 'store_gpu' forced to False; the caller's dict is not modified.
        transform: callable applied to each query / context image; its output
            must support ``.float()`` and ``torch.stack``.
        device: forwarded to DatasetBatch.
        num_trajs: maximum number of trajectories kept after subsampling.
        gamma: discount factor for ``context_cum_rewards`` (forwarded to
            DatasetBatch).
    """

    def __init__(self, paths, config, transform, device, num_trajs=40000, gamma=0.8):
        # Copy so that forcing store_gpu off does not leak into the caller's config.
        config = dict(config)
        config['store_gpu'] = False
        # Pass by keyword: positionally, num_trajs used to land in DatasetBatch's
        # ``gamma`` slot (yielding inf/NaN cumulative rewards and the default
        # num_trajs of 50000).
        super().__init__(paths, config, device, gamma=gamma, num_trajs=num_trajs)
        self.transform = transform
        self.config = config

        context_filepaths = []
        query_images = []
        for traj in self.trajs:
            context_filepaths.append(traj['context_images'])
            query_images.append(self.transform(traj['query_image']).float())

        self.dataset.update({
            'context_filepaths': context_filepaths,
            'query_images': torch.stack(query_images),
        })

    def __getitem__(self, index):
        'Generates one sample of data'
        filepath = self.dataset['context_filepaths'][index]
        context_images = np.load(filepath)
        context_images = [self.transform(images) for images in context_images]
        context_images = torch.stack(context_images).float()

        query_images = self.dataset['query_images'][index]

        res = {
            'context_images': context_images,
            'context_states': self.dataset['context_states'][index],
            'context_actions': self.dataset['context_actions'][index],
            'context_next_states': self.dataset['context_next_states'][index],
            'context_rewards': self.dataset['context_rewards'][index],
            'query_images': query_images,
            'query_states': self.dataset['query_states'][index],
            'optimal_actions': self.dataset['optimal_actions'][index],
            'zeros': self.zeros,
            'context_cum_rewards': self.dataset['context_cum_rewards'][index],
        }

        if self.shuffle:
            # One permutation for all per-step tensors keeps them aligned.
            perm = torch.randperm(self.horizon)
            res['context_images'] = res['context_images'][perm]
            res['context_states'] = res['context_states'][perm]
            res['context_actions'] = res['context_actions'][perm]
            res['context_next_states'] = res['context_next_states'][perm]
            res['context_rewards'] = res['context_rewards'][perm]
            res['context_cum_rewards'] = res['context_cum_rewards'][perm]

        return res