from data.base_dataset import BaseDataset, get_params, get_transform
from PIL import Image, ImageFilter
import util.util as util
import os
import h5py
import torch
import numpy as np


#### Positional embedding code adapted from the nerf-pytorch NeRF implementation: https://github.com/yenchenlin/nerf-pytorch
class Embedder():
    """Frequency-based positional embedding driven by a settings dictionary.

    The settings dictionary must provide these keys:

    * ``input_dims``    -- size added to ``out_dim`` for every embedding term.
    * ``include_input`` -- if truthy, the raw input is the first term.
    * ``max_freq_log2`` -- log2 of the highest frequency.
    * ``num_freqs``     -- number of frequency bands.
    * ``log_sampling``  -- if truthy, bands are ``2 ** linspace(0, max_freq_log2)``;
      otherwise they are spaced linearly between ``2 ** 0`` and
      ``2 ** max_freq_log2``.
    * ``periodic_fns``  -- callables applied to ``input * band`` for each band.

    After construction, ``embed_fns`` holds the list of term functions and
    ``out_dim`` holds ``input_dims`` accumulated once per term.
    """

    def __init__(self, multires, embed_kwargs):
        # ``multires`` is accepted for signature compatibility but is not
        # read here; every setting comes from ``embed_kwargs``.
        self.kwargs = embed_kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Build ``self.embed_fns`` and ``self.out_dim`` from ``self.kwargs``."""
        cfg = self.kwargs
        width = cfg['input_dims']
        term_fns = []
        total_width = 0

        if cfg['include_input']:
            # Identity term: keeps the un-embedded input in the output.
            term_fns.append(lambda x: x)
            total_width += width

        top_exponent = cfg['max_freq_log2']
        band_count = cfg['num_freqs']

        if cfg['log_sampling']:
            bands = 2. ** torch.linspace(0., top_exponent, steps=band_count)
        else:
            bands = torch.linspace(2. ** 0., 2. ** top_exponent, steps=band_count)

        for band in bands:
            for periodic in cfg['periodic_fns']:
                # Bind the current function and band as defaults so each
                # lambda keeps its own pair instead of the loop's last one.
                term_fns.append(
                    lambda x, p_fn=periodic, freq=band: p_fn(x * freq)
                )
                total_width += width

        self.embed_fns = term_fns
        self.out_dim = total_width

    def embed(self, inputs):
        """Apply every term function to ``inputs`` and concatenate on the last axis."""
        pieces = []
        for term in self.embed_fns:
            pieces.append(term(inputs))
        return torch.cat(pieces, -1)


class RLDataset(BaseDataset):
    """Dataset of consecutive image frames and state vectors read from HDF5.

    Each item is built from five datasets of the opened file:
    ``image_observations_tm1``, ``image_observations``,
    ``image_observations_tp1``, ``observations`` and ``next_observations``.
    The dataset length is the length of ``image_observations``.

    Subclasses must override :meth:`get_paths` to supply the HDF5 file
    location, and may override :meth:`postprocess` to adjust each item.
    """

    # Substring of ``opt.env_type`` -> length of the state vector.
    # Entries are tested in this order and the first match wins.
    _STATE_DIMS = (
        ('cheetah', 17),
        ('walker', 24),
        ('ballincup', 8),
        ('cartpole', 5),
        ('finger', 9),
        ('reacher', 6),
    )

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register this dataset's command-line options on ``parser`` and return it.

        ``is_train`` is accepted but not used. The ``--no_pairing_check``
        flag is registered here; no method of this class reads it.
        """
        parser.add_argument('--no_pairing_check', action='store_true',
                            help='If specified, skip sanity check of correct label-image file pairing')
        return parser

    def initialize(self, opt):
        """Open the HDF5 file and build the state embedder.

        Reads ``opt.env_type`` to pick the state vector length and passes
        ``opt`` to :meth:`get_paths` to locate the data file.

        Raises:
            ValueError: if ``opt.env_type`` contains none of the known
                environment names. Previously this surfaced later as an
                opaque ``TypeError`` inside ``Embedder``.
        """
        self.opt = opt

        data_paths = self.get_paths(opt)

        # Resolve the state size before touching the file, so a bad
        # env_type fails fast and does not leave an HDF5 handle open.
        self.state_num = next(
            (dim for key, dim in self._STATE_DIMS if key in opt.env_type),
            None,
        )
        if self.state_num is None:
            known = ', '.join(key for key, _ in self._STATE_DIMS)
            raise ValueError(
                "Unsupported env_type {!r}; expected it to contain one of: {}".format(
                    opt.env_type, known))

        # The handle is kept open for the lifetime of the dataset object;
        # __getitem__ reads from it lazily.
        self.data_pickle = h5py.File(data_paths, 'r')
        self.dataset_size = len(self.data_pickle['image_observations'])

        # Embedding settings: raw input plus sin/cos at 10 log-spaced
        # frequencies (2**0 .. 2**9).
        multires = 10
        embed_kwargs = {
            'include_input': True,
            'input_dims': self.state_num,
            'max_freq_log2': multires - 1,
            'num_freqs': multires,
            'log_sampling': True,
            'periodic_fns': [torch.sin, torch.cos]
        }

        self.embed = Embedder(multires, embed_kwargs)

    def get_paths(self, opt):
        """Return the location of the HDF5 data file; subclasses must override.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        # An explicit raise (instead of ``assert False``) still fires when
        # Python runs with -O, which strips assert statements.
        raise NotImplementedError(
            "A subclass of RLDataset must override self.get_paths(self, opt)")

    def paths_match(self, path1, path2):
        """Return True when both paths have the same base filename, ignoring extension."""
        filename1_without_ext = os.path.splitext(os.path.basename(path1))[0]
        filename2_without_ext = os.path.splitext(os.path.basename(path2))[0]
        return filename1_without_ext == filename2_without_ext

    def __getitem__(self, index):
        """Return the item at ``index`` as a dictionary.

        Keys: ``im_m1``, ``im``, ``im_p1`` (the three transformed frames),
        ``state``, ``state_p1`` (float tensors of the stored state vectors)
        and ``state_embed``, ``state_p1_embed`` (their positional embeddings).
        """
        frame_keys = (
            'image_observations_tm1',
            'image_observations',
            'image_observations_tp1',
        )
        frames = [Image.fromarray(self.data_pickle[key][index])
                  for key in frame_keys]

        # One set of transform parameters, derived from the middle frame's
        # size, is shared by all three frames so they are processed alike.
        params = get_params(self.opt, frames[1].size)
        transform_image = get_transform(self.opt, params)

        image_tensor_m1, image_tensor, image_tensor_p1 = [
            transform_image(frame) for frame in frames
        ]

        state = torch.Tensor(np.array(self.data_pickle['observations'][index]))
        state_p1 = torch.Tensor(np.array(self.data_pickle['next_observations'][index]))

        state_embed = self.embed.embed(state)
        state_p1_embed = self.embed.embed(state_p1)

        input_dict = {'im_m1': image_tensor_m1,
                      'im': image_tensor,
                      'im_p1': image_tensor_p1,
                      'state': state,
                      'state_p1': state_p1,
                      'state_embed': state_embed,
                      'state_p1_embed': state_p1_embed
                      }

        # Give subclasses a chance to modify the final output. The return
        # value is ignored, so overrides must mutate ``input_dict`` in place.
        self.postprocess(input_dict)

        return input_dict

    def postprocess(self, input_dict):
        """Hook for subclasses to adjust an item; the base version returns it unchanged."""
        return input_dict

    def __len__(self):
        """Return the number of items, as recorded by ``initialize``."""
        return self.dataset_size
