import torch
import models
import sys
from collections import OrderedDict
from options.test_options import TestOptions
import data
from util.util import tensor2im
import matplotlib.pyplot as plt
import h5py
import numpy as np
import torch.nn.functional as F
from PIL import Image
from data.base_dataset import BaseDataset, get_params, get_transform
import os

#### Positional embedding code adapted from the NeRF PyTorch reimplementation: https://github.com/yenchenlin/nerf-pytorch
class Embedder:
    """NeRF-style positional encoder.

    ``embed_kwargs`` must provide the keys ``input_dims``, ``include_input``,
    ``max_freq_log2``, ``num_freqs``, ``log_sampling`` and ``periodic_fns``.
    On construction a list of per-feature functions is built; :meth:`embed`
    applies each one to the input and concatenates the results along the last
    axis.
    """

    def __init__(self, embed_kwargs):
        self.kwargs = embed_kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Populate ``self.embed_fns`` and ``self.out_dim`` from ``self.kwargs``."""
        cfg = self.kwargs
        dims = cfg['input_dims']

        # Optional identity term comes first in the output.
        fns = [lambda x: x] if cfg['include_input'] else []

        top = cfg['max_freq_log2']
        count = cfg['num_freqs']
        if cfg['log_sampling']:
            bands = 2. ** torch.linspace(0., top, steps=count)
        else:
            bands = torch.linspace(2. ** 0., 2. ** top, steps=count)

        # Outer loop over frequencies, inner over periodic functions; the
        # default arguments freeze the current (p_fn, freq) pair per lambda.
        fns.extend(
            (lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
            for freq in bands
            for p_fn in cfg['periodic_fns']
        )

        self.embed_fns = fns
        # Every function keeps the last-axis width, so each contributes `dims`.
        self.out_dim = sum(dims for _ in fns)

    def embed(self, inputs):
        """Return all embedding features of ``inputs`` concatenated on the last axis."""
        return torch.cat(tuple(map(lambda f: f(inputs), self.embed_fns)), dim=-1)


opt = TestOptions().parse()


# Map the requested environment to its state-vector size and dataset file
# stem. Checks are substring matches evaluated in this order; the first hit
# wins.
if 'cheetah' in opt.env_type:
    state_num = 17
    dataset = 'cheetah'
elif 'walker' in opt.env_type:
    state_num = 24
    if 'run' in opt.env_type:
        dataset = 'walker_run'
    else:
        dataset = 'walker'
elif 'ballincup' in opt.env_type:
    state_num = 8
    dataset = 'ballincup'
elif 'finger' in opt.env_type:
    state_num = 9
    dataset = 'finger'
elif 'cartpole' in opt.env_type:
    state_num = 5
    dataset = 'cartpole'
elif 'reacher' in opt.env_type:
    state_num = 6
    dataset = 'reacher'
else:
    # Everything below needs both values (the embedder's input size and the
    # dataset/checkpoint file names), so fail here with a clear message
    # instead of a TypeError deep inside the script.
    raise ValueError(
        'Unsupported env_type %r; expected it to contain one of: '
        'cheetah, walker, ballincup, finger, cartpole, reacher' % (opt.env_type,)
    )



def generate_image(image, state_p1, out_size=(100, 100)):
    """Synthesize the next frame from the current frame and the next state.

    Uses the module-level ``opt`` (preprocessing options), ``embed``
    (positional encoder for the state) and ``netG`` (generator) set up by
    this script. Inputs and the generator are placed on CUDA.

    Args:
        image: current frame as a numpy array [h, w, c]; it is converted with
            ``PIL.Image.fromarray`` (so presumably uint8 — confirm).
        state_p1: state vector for the next time step (length ``state_num``).
        out_size: size the generator output is bilinearly resized to before
            conversion (int or (height, width)). Defaults to (100, 100).

    Returns:
        numpy uint8 array of shape [out_h, out_w, c].
    """
    state_tensor = torch.Tensor(state_p1)
    pil_image = Image.fromarray(image)

    # Same preprocessing pipeline as the dataset loader.
    params = get_params(opt, pil_image.size)
    transform_image = get_transform(opt, params)
    image_tensor = transform_image(pil_image).unsqueeze(dim=0).cuda()

    state_embed = embed.embed(state_tensor).unsqueeze(dim=0).cuda()

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        image_tensor_p1 = netG(state_embed, image_tensor)
        image_tensor_p1 = F.interpolate(
            image_tensor_p1, size=out_size, mode='bilinear', align_corners=False
        )

    # The generator output is assumed to lie in [-1, 1]; map to [0, 255] and
    # clip so any out-of-range values saturate instead of wrapping in the
    # uint8 cast.
    image_p1 = ((image_tensor_p1 + 1.0) / 2.0).permute(0, 2, 3, 1).cpu().numpy()
    image_p1 = np.clip(image_p1 * 255, 0, 255).astype(np.uint8)
    return image_p1.squeeze(0)  # [1, h, w, c] -> [h, w, c]



# Number of frequency bands used by the positional encoding of the state.
multires = 10

# Encoder for the state vector: identity term plus sin/cos at
# log-spaced frequencies 2**0 .. 2**(multires - 1).
embed = Embedder(embed_kwargs=dict(
    include_input=True,
    input_dims=state_num,
    max_freq_log2=multires - 1,
    num_freqs=multires,
    log_sampling=True,
    periodic_fns=[torch.sin, torch.cos],
))


# Open the evaluation data and restore the trained generator.
data = h5py.File(os.path.join(opt.dataroot, dataset + '.hdf5'), 'r')
netG = models.networks.define_G(opt)
ckpt = torch.load('./checkpoints/%s_%d.pth' % (dataset, opt.which_epoch))
netG.load_state_dict(ckpt)
netG.eval()

print('Dataset size for simple evaluation:%d' % len(data['image_observations']))
seq_len = opt.seq_len
idx = opt.start_idx

# Starting frame: shown in column 0 of both rows and used as the first
# generator input.
img = data['image_observations'][idx]

plt.subplot(2, seq_len + 1, 1)
plt.title('GT t=0')
plt.imshow(img)
plt.axis('off')

plt.subplot(2, seq_len + 1, seq_len + 1 + 1)
plt.title('GT t=%d' % 0)
plt.imshow(img)
plt.axis('off')

for i in range(seq_len):
    state_p1 = data['next_observations'][idx + i]
    gt = data['image_observations_tp1'][idx + i]

    img_new = generate_image(img, state_p1)

    # Top row: ground-truth frame at t = i + 1.
    plt.subplot(2, seq_len + 1, i + 2)
    plt.title('GT t=%d' % (i + 1))
    plt.imshow(gt)
    plt.axis('off')

    # Bottom row: synthesized frame at t = i + 1.
    plt.subplot(2, seq_len + 1, seq_len + 1 + i + 2)
    plt.title('Syn t=%d' % (i + 1))
    plt.imshow(img_new)
    plt.axis('off')

    # Autoregressive rollout: condition the next step on the frame we just
    # synthesized.
    img = img_new

plt.show()

# Release the HDF5 file handle once plotting is done.
data.close()





