import os
import numpy as np
import imageio
import json
import cv2
import pdb
import pickle
import json
import random
import time

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

from utils import rand_int






class DynamicsDataset(Dataset):
    """Windows of consecutive timesteps sampled from recorded rollouts.

    Each item covers ``n_his + n_roll`` consecutive timesteps of a single rollout.

    * With ``args.prestored == 0`` an item is built from JPEG frames under
      ``<dataf>/<rollout>/<view>/<frame>.jpg``. Camera poses and actions come from
      ``<dataf>/<rollout>/info.p``.
    * Otherwise it is built from pickled embeddings under
      ``<storef>/<rollout>/<timestep>.p``, plus actions.

    Rollout ids are tied to a fixed on-disk layout. The training split takes
    ``args.n_rollout // 4`` ids from each block starting at 0, 100, 200 and 300.
    The validation split takes consecutive ids starting at 400.
    """

    def __init__(self, args, phase, dec_eval=False):
        """Build the rollout index for one split and load per-rollout metadata.

        Args:
            args: experiment config namespace. Reads ``prestored``, ``dataf``,
                ``frame_jump``, ``train_valid_ratio``, ``n_rollout``, ``phase``,
                ``n_his``, ``n_roll``, ``n_frames``, ``time_step`` and ``n_view``.
                In decoder training it also reads ``n_timestep_for_dec``, and
                ``load_metadata`` reads ``act_dim``.
            phase: ``'train'`` or ``'valid'``.
            dec_eval: only relevant when ``args.phase == 'dec'``. If false, the
                dataset is a random subset of single-frame windows (decoder
                training); if true, it covers every window.

        Raises:
            ValueError: if ``phase`` is neither ``'train'`` nor ``'valid'``.
        """
        self.args = args
        self.phase = phase
        self.prestored = args.prestored
        self.dec_eval = dec_eval

        self.data_dir = os.path.join(args.dataf)

        self.frame_jump = args.frame_jump
        self.n_his = args.n_his
        self.n_roll = args.n_roll
        self.n_frames = args.n_frames
        self.time_step = args.time_step
        self.n_view = args.n_view

        n_train = int(args.n_rollout * args.train_valid_ratio)
        if phase == 'train':
            self.n_rollout = n_train
            # NOTE(review): the per-block count uses the *total* args.n_rollout
            # rather than the train share. Only the first self.n_rollout entries
            # of this list are ever indexed, so the last block is under-sampled.
            # Kept as-is to preserve which rollouts are used; confirm intent.
            n_roll_per_scene = int(args.n_rollout) // 4
            self.rollout_list = [
                rollout
                for scene in range(4)
                for rollout in range(scene * 100, scene * 100 + n_roll_per_scene)
            ]
        elif phase == 'valid':
            self.n_rollout = args.n_rollout - n_train
            self.rollout_list = list(range(400, 400 + self.n_rollout))
        else:
            raise ValueError("phase must be 'train' or 'valid', got %r" % (phase,))

        if self._is_dec_train():
            # Decoder training works on single-frame windows.
            assert args.n_his == 1
            assert args.n_roll == 0

            n_train_ts = int(args.n_timestep_for_dec * args.train_valid_ratio)
            n_valid_ts = args.n_timestep_for_dec - n_train_ts

            # Draw flat window ids from exactly the range __getitem__ decodes
            # (windows per rollout x rollouts). Sampling from args.time_step
            # instead let the decoded rollout index exceed n_rollout whenever
            # frame_jump > 1.
            self.sample_ts_idx = np.random.choice(
                self._windows_per_rollout() * self.n_rollout,
                size=n_train_ts if phase == 'train' else n_valid_ts,
                replace=False)

        self.load_metadata()

    def _is_dec_train(self):
        """Return True when items come from the random decoder-training subset."""
        return self.args.phase == 'dec' and not self.dec_eval

    def _windows_per_rollout(self):
        """Return the number of distinct window start timesteps in one rollout."""
        return self.time_step // self.frame_jump - self.n_his - self.n_roll + 1

    def load_metadata(self):
        """Load per-rollout actions and camera matrices from ``<data_dir>/<i>/info.p``.

        Fills ``self.action``, ``self.viewMatrix`` and ``self.projMatrix``. Each
        is a list indexed by rollout id. ``self.focal`` is set from the projection
        matrix of the last ``info.p`` found.

        Rollouts without an ``info.p`` get zero arrays; the camera-matrix
        placeholders use a hard-coded 20 views. If no ``info.p`` exists at all,
        ``self.focal`` is never set and the rendered-frame path of
        ``__getitem__`` fails.
        """
        args = self.args
        self.action = []
        self.viewMatrix = []
        self.projMatrix = []
        print("=== Loading metadata for %s ..." % self.phase)
        # __getitem__ indexes these lists by rollout id. Validation ids start at
        # 400, which can exceed args.n_rollout, so load far enough to cover
        # every id in rollout_list.
        n_meta = max(int(args.n_rollout), max(self.rollout_list, default=-1) + 1)
        for i in range(n_meta):
            path = os.path.join(self.data_dir, '%d' % i, 'info.p')
            if os.path.isfile(path):
                with open(path, 'rb') as f:
                    metadata = pickle.load(f)
                self.action.append(metadata['action'])
                self.viewMatrix.append(metadata['viewMatrix'])
                self.projMatrix.append(metadata['projMatrix'])
                self.focal = metadata['projMatrix'][0, 0, 0, 0]
            else:
                # Placeholder zeros keep the lists aligned with rollout ids.
                self.action.append(np.zeros((args.time_step, args.act_dim)))
                self.viewMatrix.append(np.zeros((20, args.time_step, 4, 4)))
                self.projMatrix.append(np.zeros((20, args.time_step, 4, 4)))

        print('action', np.array(self.action).shape)
        print('viewMatrix', np.array(self.viewMatrix).shape)
        print('projMatrix', np.array(self.projMatrix).shape)

    def __len__(self):
        """Return the subset size in decoder training, else every window of every rollout."""
        if self._is_dec_train():
            return len(self.sample_ts_idx)
        return self.n_rollout * self._windows_per_rollout()

    def __getitem__(self, idx):
        """Return the window for dataset index ``idx``.

        With ``args.prestored == 0`` this delegates to ``_load_rendered``;
        otherwise it returns ``(state_embeds, actions)`` from ``_load_prestored``.
        """
        idx_rollout, idx_timestep = self._locate(idx)
        if self.args.prestored == 0:
            return self._load_rendered(idx, idx_rollout, idx_timestep)
        return self._load_prestored(idx_rollout, idx_timestep)

    def _locate(self, idx):
        """Map a dataset index to ``(rollout id, window start timestep)``."""
        offset = self._windows_per_rollout()
        flat = self.sample_ts_idx[idx] if self._is_dec_train() else idx
        return self.rollout_list[flat // offset], flat % offset

    def _negative_frame_path(self, idx_rollout, idx_view, i):
        """Return the path of a random frame of the same rollout/view far from step ``i``.

        The frame is at least ``15 // frame_jump`` frame-jumped steps from ``i``.
        NOTE(review): this loops forever if ``rand_int(0, time_step // frame_jump)``
        can never produce such a step (e.g. very short rollouts).
        """
        while True:
            idx_neg = rand_int(0, self.time_step // self.frame_jump)
            if np.abs(idx_neg - i) >= (15 // self.frame_jump):
                break
        return os.path.join(self.data_dir, '%d/%d/%d.jpg' % (
            idx_rollout, idx_view, idx_neg * self.frame_jump))

    def _load_rendered(self, idx, idx_rollout, idx_timestep):
        """Build an item from JPEG frames, camera poses and actions.

        Returns:
            ``(imgs, poses, idx, actions, hwf)`` during decoder training and
            ``(imgs, poses, actions, hwf)`` otherwise. ``idx`` is the dataset index
            wrapped in a 1-element LongTensor, and ``hwf`` is ``[H, W, focal]``.
        """
        args = self.args
        focal = self.focal
        n_steps = self.n_his + self.n_roll

        imgs = []
        poses = []
        actions = []
        for i in range(idx_timestep, idx_timestep + n_steps):
            frame = i * args.frame_jump
            for j in range(self.n_view):
                # Training sees random views; validation uses view j for slot j.
                if self.phase == 'train':
                    idx_view = rand_int(0, self.n_frames)
                else:
                    idx_view = j

                img_path = os.path.join(self.data_dir, '%d/%d/%d.jpg' % (
                    idx_rollout, idx_view, frame))
                imgs.append(imageio.imread(img_path))

                if args.phase == 'ae':
                    # Each view is followed by a negative frame from the same view.
                    imgs.append(imageio.imread(
                        self._negative_frame_path(idx_rollout, idx_view, i)))

                poses.append(np.linalg.inv(np.transpose(
                    self.viewMatrix[idx_rollout][idx_view, frame])))

            actions.append(self.action[idx_rollout][frame])

        imgs = (np.array(imgs) / 255.).astype(np.float32)
        poses = np.array(poses).astype(np.float32)
        actions = np.array(actions).astype(np.float32)

        H, W = imgs[0].shape[:2]
        # Presumably projMatrix[0, 0] = 2f / W (OpenGL-style), so this yields the
        # focal length in pixels -- confirm against the renderer.
        focal = focal * 0.5 * W

        if args.half_res:
            H = H // 2
            W = W // 2
            focal = focal / 2.

            imgs_half_res = np.zeros((imgs.shape[0], H, W, 3))
            for k, img in enumerate(imgs):
                imgs_half_res[k] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
            imgs = imgs_half_res

        n_img_per_step = self.n_view * 2 if args.phase == 'ae' else self.n_view
        imgs = imgs.reshape((n_steps, n_img_per_step, H, W, 3))
        poses = poses.reshape((n_steps, self.n_view, 4, 4))

        # imgs: n_steps x n_img_per_step x 3 x H x W
        # poses: n_steps x n_view x 4 x 4
        # actions: n_steps x action_dim
        imgs = torch.FloatTensor(imgs).permute(0, 1, 4, 2, 3)
        poses = torch.FloatTensor(poses)
        actions = torch.FloatTensor(actions)
        hwf = torch.FloatTensor([H, W, focal])

        if self._is_dec_train():
            return imgs, poses, torch.LongTensor(np.array([idx])), actions, hwf
        return imgs, poses, actions, hwf

    def _load_prestored(self, idx_rollout, idx_timestep):
        """Build an item from pickled embeddings under ``<storef>/<rollout>/<t>.p``.

        Returns:
            ``(state_embeds, actions)`` as float tensors with one row per timestep.
        """
        args = self.args
        state_embeds = []
        actions = []
        for i in range(idx_timestep, idx_timestep + self.n_his + self.n_roll):
            path = os.path.join(args.storef, '%d/%d.p' % (idx_rollout, i))
            with open(path, 'rb') as f:
                state_embeds.append(pickle.load(f)['embed'])
            # NOTE(review): unlike _load_rendered, timesteps here are not scaled
            # by frame_jump. Presumably embeddings were stored per frame-jumped
            # step -- confirm before using frame_jump > 1.
            actions.append(self.action[idx_rollout][i])

        state_embeds = torch.FloatTensor(np.array(state_embeds).astype(np.float32))
        actions = torch.FloatTensor(np.array(actions).astype(np.float32))
        return state_embeds, actions





