import os
import torch
import numpy as np
import torch.nn.functional as F
import cv2
import pickle

def trans_t(t):
    """Return a 4x4 float32 homogeneous matrix translating by ``t`` along z.

    The matrix is the identity except for entry ``[2, 3]``, which holds ``t``.
    """
    rows = [
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1],
    ]
    return torch.Tensor(rows).float()

def rot_phi(phi):
    """Return a 4x4 float32 homogeneous rotation about the x axis.

    ``phi`` is an angle in radians (it is passed straight to ``np.cos`` /
    ``np.sin``). The y/z sub-block is ``[[cos, -sin], [sin, cos]]``.
    """
    cos_phi = np.cos(phi)
    sin_phi = np.sin(phi)
    rows = [
        [1, 0, 0, 0],
        [0, cos_phi, -sin_phi, 0],
        [0, sin_phi, cos_phi, 0],
        [0, 0, 0, 1],
    ]
    return torch.Tensor(rows).float()

def rot_theta(th):
    """Return a 4x4 float32 homogeneous rotation about the y axis.

    ``th`` is an angle in radians (it is passed straight to ``np.cos`` /
    ``np.sin``). The x/z sub-block is ``[[cos, -sin], [sin, cos]]``, i.e.
    ``-sin`` sits in entry ``[0, 2]`` and ``+sin`` in entry ``[2, 0]``.
    """
    cos_th = np.cos(th)
    sin_th = np.sin(th)
    rows = [
        [cos_th, 0, -sin_th, 0],
        [0, 1, 0, 0],
        [sin_th, 0, cos_th, 0],
        [0, 0, 0, 1],
    ]
    return torch.Tensor(rows).float()


def pose_spherical(theta, phi, radius):
    """Compose a 4x4 pose matrix from spherical coordinates.

    The result is ``axis_swap @ rot_theta @ rot_phi @ trans_t`` (the original
    code names it ``c2w``): a translation by ``radius`` along z, then a
    rotation by ``phi`` about x, then a rotation by ``theta`` about y, and
    finally a fixed matrix that negates x and swaps the y and z rows.

    Args:
        theta: Angle in degrees, converted to radians for ``rot_theta``.
        phi: Angle in degrees, converted to radians for ``rot_phi``.
        radius: Translation distance handed to ``trans_t``.

    Returns:
        A 4x4 ``torch.Tensor``.
    """
    # Keep the exact "x / 180. * pi" form so the floating-point result is
    # bit-identical to the historical behaviour.
    phi_rad = phi / 180. * np.pi
    theta_rad = theta / 180. * np.pi

    axis_swap = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]]))

    translated = trans_t(radius)
    tilted = rot_phi(phi_rad) @ translated
    spun = rot_theta(theta_rad) @ tilted
    return axis_swap @ spun


def load_mae_data(basedir, half_res=False, testskip=1, episode_num = 100, num_view=0, dataset_type=None):
    """Load the first ``episode_num`` multi-view episodes from ``basedir``.

    Episodes are read from ``basedir/e_000.pkl``, ``basedir/e_001.pkl``, ...
    Each pickle must provide ``'observations'``, ``'poses'``,
    ``'depth_observations'`` and ``'infos'``; ``'semantic_observations'`` is
    optional. Six camera views are selected from every episode, the episodes
    are concatenated along their first axis, and the view axis is moved to
    the front.

    Args:
        basedir: Directory containing the ``e_XXX.pkl`` files.
        half_res: If True, images are downsampled by 2 with ``cv2.resize``
            (``INTER_AREA``) and the returned H, W and focal are halved.
        testskip: Unused; kept for interface compatibility.
        num_view: Must be non-zero; otherwise unused by this loader.
        episode_num: Number of consecutive episodes (starting at 0) to load.
        dataset_type: One of 'hammer', 'drawer', 'window', 'push', 'peg',
            'stick'. 'peg' selects views 0-5, every other type selects
            views 0, 1, 2, 3, 5, 6.

    Returns:
        A tuple ``(imgs, poses, render_poses, [H, W, focal], i_split,
        semantics, depths)``:

        * ``imgs``: float32 tensor ``[V, NumEpi*Length, H, W, C]`` scaled by
          1/255.
        * ``poses``: float32 tensor ``[V, NumEpi*Length, 4, 4]``.
        * ``render_poses``: 40 poses from ``pose_spherical`` at phi=-30,
          radius=4, with theta evenly spaced over [-180, 180).
        * ``i_split``: ``[0, 1, ..., NumEpi]`` (the number of episodes
          loaded so far, recorded after each episode).
        * ``semantics``: float32 tensor scaled by 1/50, or ``None`` if no
          episode contains ``'semantic_observations'``.
        * ``depths``: float32 tensor ``[V, NumEpi*Length, H, W]``.

    Raises:
        AssertionError: If ``num_view`` is 0 or ``dataset_type`` is unknown.
        ValueError: If ``episode_num`` selects no episodes.
    """
    assert num_view != 0
    assert dataset_type in ['hammer', 'drawer', 'window', 'push', 'peg', 'stick']
    splits = ['e_%03d'%(i) for i in range(episode_num)]
    if not splits:
        raise ValueError('episode_num must be positive, got {}'.format(episode_num))

    # The view selection only depends on dataset_type, so build it once.
    if dataset_type == 'peg':
        view_indices = np.array([0,1,2,3,4,5])
    else:
        view_indices = np.array([0,1,2,3,5,6])

    all_imgs = []
    all_poses = []
    all_semantics = []
    all_depths = []
    counts = [0]
    meta = None

    for episode in splits:
        # SECURITY NOTE: pickle.load can execute arbitrary code; only point
        # basedir at trusted data.
        with open(os.path.join(basedir, '{}.pkl'.format(episode)), 'rb') as fp:
            meta = pickle.load(fp)

        # Division happens in float64 before the float32 cast, exactly as in
        # the historical implementation.
        imgs = (np.asarray(meta['observations'])[:,view_indices] / 255.).astype(np.float32)
        poses = np.asarray(meta['poses'])[:,view_indices].astype(np.float32)
        depths = np.asarray(meta['depth_observations'])[:,view_indices].astype(np.float32)

        all_imgs.append(torch.from_numpy(imgs))
        all_poses.append(torch.from_numpy(poses))
        all_depths.append(torch.from_numpy(depths))
        if 'semantic_observations' in meta:
            semantics = np.asarray(meta['semantic_observations'])[:,view_indices].astype(np.float32)/50
            all_semantics.append(torch.from_numpy(semantics))

        # NOTE(review): this records the number of episodes loaded so far,
        # not a cumulative frame count -- confirm that callers of i_split
        # expect episode indices.
        counts.append(len(all_imgs))

    imgs = torch.cat(all_imgs, 0).permute(1,0,2,3,4)    # [V, NumEpi*Length, H, W, C]
    poses = torch.cat(all_poses, 0).permute(1,0,2,3)    # [V, NumEpi*Length, 4, 4]
    # Gate on what was actually collected rather than on a loop variable
    # that may be left over from an earlier episode.
    semantics = torch.cat(all_semantics, 0).permute(1,0,2,3) if all_semantics else None
    depths = torch.cat(all_depths, 0).permute(1,0,2,3)   # [V, NumEpi*Length, H, W]

    i_split = counts

    H, W = imgs.shape[2:4]
    # Camera intrinsics are taken from the last episode that was loaded.
    camera_angle_x = float(meta['infos'][0,0]['camera_angle_x'])
    focal = .5 * W / np.tan(.5 * camera_angle_x)

    render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)

    if half_res:
        H = H//2
        W = W//2
        focal = focal/2.

        # cv2.resize only accepts NumPy arrays, so resize through a NumPy
        # view of the tensor and size the buffer from the real channel count.
        num_views, num_frames = imgs.shape[:2]
        channels = imgs.shape[-1]
        imgs_np = imgs.numpy()
        imgs_half_res = np.zeros((num_views, num_frames, H, W, channels), dtype=np.float32)
        for i in range(num_views):
            for j in range(num_frames):
                resized = cv2.resize(imgs_np[i,j], (W, H), interpolation=cv2.INTER_AREA)
                # cv2 drops a trailing singleton channel axis; restore it.
                imgs_half_res[i,j] = resized.reshape(H, W, channels)
        # Return a tensor so both resolution paths yield the same type.
        imgs = torch.from_numpy(imgs_half_res)
        # NOTE(review): depths and semantics are left at full resolution,
        # as before -- confirm whether they should be downsampled as well.

    return imgs, poses, render_poses, [H, W, focal], i_split, semantics, depths


def load_test_data(basedir, half_res=False, testskip=1, episode_idx_list=None, num_view=0, dataset_type=None):
    """Load the multi-view episodes listed in ``episode_idx_list``.

    Episode ``i`` is read from ``basedir/e_{i:03d}.pkl``. Each pickle must
    provide ``'observations'``, ``'poses'``, ``'depth_observations'`` and
    ``'infos'``; ``'semantic_observations'`` is optional. Six camera views
    are selected from every episode, the episodes are concatenated along
    their first axis, and the view axis is moved to the front.

    Args:
        basedir: Directory containing the ``e_XXX.pkl`` files.
        half_res: If True, images are downsampled by 2 with ``cv2.resize``
            (``INTER_AREA``) and the returned H, W and focal are halved.
        testskip: Unused; kept for interface compatibility.
        episode_idx_list: Iterable of episode indices to load. Defaults to
            ``[0]`` when omitted or ``None``.
        num_view: Must be non-zero; otherwise unused by this loader.
        dataset_type: One of 'hammer', 'drawer', 'window', 'push', 'peg',
            'stick'. 'peg' selects views 0-5, every other type selects
            views 0, 1, 2, 3, 5, 6.

    Returns:
        A tuple ``(imgs, poses, render_poses, [H, W, focal], i_split,
        semantics, depths)``:

        * ``imgs``: float32 tensor ``[V, NumEpi*Length, H, W, C]`` scaled by
          1/255.
        * ``poses``: float32 tensor ``[V, NumEpi*Length, 4, 4]``.
        * ``render_poses``: 40 poses from ``pose_spherical`` at phi=-30,
          radius=4, with theta evenly spaced over [-180, 180).
        * ``i_split``: ``[0, 1, ..., NumEpi]`` (the number of episodes
          loaded so far, recorded after each episode).
        * ``semantics``: float32 tensor scaled by 1/50, or ``None`` if no
          episode contains ``'semantic_observations'``.
        * ``depths``: float32 tensor ``[V, NumEpi*Length, H, W]``.

    Raises:
        AssertionError: If ``num_view`` is 0 or ``dataset_type`` is unknown.
        ValueError: If ``episode_idx_list`` is empty.
    """
    assert num_view != 0
    assert dataset_type in ['hammer', 'drawer', 'window', 'push', 'peg', 'stick']
    # A None sentinel replaces the former mutable default ``[0]``.
    if episode_idx_list is None:
        episode_idx_list = [0]
    splits = ['e_%03d'%(i) for i in episode_idx_list]
    if not splits:
        raise ValueError('episode_idx_list must contain at least one episode index')

    # The view selection only depends on dataset_type, so build it once.
    if dataset_type == 'peg':
        view_indices = np.array([0,1,2,3,4,5])
    else:
        view_indices = np.array([0,1,2,3,5,6])

    all_imgs = []
    all_poses = []
    all_semantics = []
    all_depths = []
    counts = [0]
    meta = None

    for episode in splits:
        # SECURITY NOTE: pickle.load can execute arbitrary code; only point
        # basedir at trusted data.
        with open(os.path.join(basedir, '{}.pkl'.format(episode)), 'rb') as fp:
            meta = pickle.load(fp)

        # Division happens in float64 before the float32 cast, exactly as in
        # the historical implementation.
        imgs = (np.asarray(meta['observations'])[:,view_indices] / 255.).astype(np.float32)
        poses = np.asarray(meta['poses'])[:,view_indices].astype(np.float32)
        depths = np.asarray(meta['depth_observations'])[:,view_indices].astype(np.float32)

        all_imgs.append(torch.from_numpy(imgs))
        all_poses.append(torch.from_numpy(poses))
        all_depths.append(torch.from_numpy(depths))
        if 'semantic_observations' in meta:
            semantics = np.asarray(meta['semantic_observations'])[:,view_indices].astype(np.float32)/50
            all_semantics.append(torch.from_numpy(semantics))

        # NOTE(review): this records the number of episodes loaded so far,
        # not a cumulative frame count -- confirm that callers of i_split
        # expect episode indices.
        counts.append(len(all_imgs))

    imgs = torch.cat(all_imgs, 0).permute(1,0,2,3,4)    # [V, NumEpi*Length, H, W, C]
    poses = torch.cat(all_poses, 0).permute(1,0,2,3)    # [V, NumEpi*Length, 4, 4]
    # Gate on what was actually collected rather than on a loop variable
    # that may be left over from an earlier episode.
    semantics = torch.cat(all_semantics, 0).permute(1,0,2,3) if all_semantics else None
    depths = torch.cat(all_depths, 0).permute(1,0,2,3)   # [V, NumEpi*Length, H, W]

    i_split = counts

    H, W = imgs.shape[2:4]
    # Camera intrinsics are taken from the last episode that was loaded.
    camera_angle_x = float(meta['infos'][0,0]['camera_angle_x'])
    focal = .5 * W / np.tan(.5 * camera_angle_x)

    render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)

    if half_res:
        H = H//2
        W = W//2
        focal = focal/2.

        # cv2.resize only accepts NumPy arrays, so resize through a NumPy
        # view of the tensor and size the buffer from the real channel count.
        num_views, num_frames = imgs.shape[:2]
        channels = imgs.shape[-1]
        imgs_np = imgs.numpy()
        imgs_half_res = np.zeros((num_views, num_frames, H, W, channels), dtype=np.float32)
        for i in range(num_views):
            for j in range(num_frames):
                resized = cv2.resize(imgs_np[i,j], (W, H), interpolation=cv2.INTER_AREA)
                # cv2 drops a trailing singleton channel axis; restore it.
                imgs_half_res[i,j] = resized.reshape(H, W, channels)
        # Return a tensor so both resolution paths yield the same type.
        imgs = torch.from_numpy(imgs_half_res)
        # NOTE(review): depths and semantics are left at full resolution,
        # as before -- confirm whether they should be downsampled as well.

    return imgs, poses, render_poses, [H, W, focal], i_split, semantics, depths