import os
import re
import torch
import numpy as np
import imageio
import json
import torch.nn.functional as F
import cv2
from pdb import set_trace as bp

def trans_t(t):
    """Return a 4x4 float32 homogeneous matrix translating by ``t`` along z."""
    matrix = [
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1],
    ]
    return torch.Tensor(matrix).float()

def rot_phi(phi):
    """Return a 4x4 float32 homogeneous rotation by ``phi`` radians about x."""
    c, s = np.cos(phi), np.sin(phi)
    matrix = [
        [1, 0, 0, 0],
        [0, c, -s, 0],
        [0, s, c, 0],
        [0, 0, 0, 1],
    ]
    return torch.Tensor(matrix).float()

def rot_theta(th):
    """Return a 4x4 float32 homogeneous rotation by ``th`` radians about y.

    Note the sign placement: ``-sin`` sits in row 0 and ``+sin`` in row 2,
    i.e. this is the transpose of the textbook right-handed R_y(th)
    (equivalently R_y(-th)).
    """
    c, s = np.cos(th), np.sin(th)
    matrix = [
        [c, 0, -s, 0],
        [0, 1, 0, 0],
        [s, 0, c, 0],
        [0, 0, 0, 1],
    ]
    return torch.Tensor(matrix).float()


def pose_spherical(theta, phi, radius, rotZ=True, wx=0.0, wy=0.0, wz=0.0):
    """Build a 4x4 camera-to-world matrix for a camera on a sphere.

    The camera is moved ``radius`` along z (``trans_t``), rotated by ``phi``
    degrees about x (``rot_phi``), then by ``theta`` degrees about y
    (``rot_theta``). With ``rotZ=True`` the y and z axes are swapped and x is
    negated (a proper rotation, so the frame stays right-handed), which makes
    ``theta`` spin around z instead of y. The result is finally translated by
    ``(wx, wy, wz)``, normally the scene center.

    Returns:
        float32 torch tensor of shape (4, 4).
    """
    pose = rot_theta(theta / 180. * np.pi) @ (rot_phi(phi / 180. * np.pi) @ trans_t(radius))

    if rotZ:
        # swap y <-> z and flip x: determinant +1, so handedness is preserved
        swap_yz = torch.Tensor(np.array([
            [-1, 0, 0, 0],
            [0, 0, 1, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 1],
        ]))
        pose = swap_yz @ pose

    center_shift = torch.Tensor([
        [1, 0, 0, wx],
        [0, 1, 0, wy],
        [0, 0, 1, wz],
        [0, 0, 0, 1],
    ]).float()
    return center_shift @ pose


def _read_video_frames(f_name, frame_num):
    """Decode the first ``frame_num`` frames of the video at ``f_name``.

    Returns the decoded frames as a list in index order. The ffmpeg reader is
    closed in a ``finally`` so a decode error part-way through does not leak it.
    """
    reader = imageio.get_reader(f_name, "ffmpeg")
    try:
        frames = []
        for frame_i in range(frame_num):
            reader.set_image_index(frame_i)
            frames.append(reader.get_next_data())
        return frames
    finally:
        reader.close()


def load_pinf_frame_data(basedir, half_res=False, split='train', frame_num_cutoff=-1):
    """Load multi-view video frames and scene/render metadata from ``basedir``.

    Reads ``<basedir>/info.json`` and the videos it lists, and builds a ring of
    ``sp_n = 120`` spherical render poses around ``render_center``. All videos
    must decode to the same number of frames and the same (H, W), otherwise
    the final ``np.stack`` fails.

    Args:
        basedir: dataset directory containing ``info.json`` and the videos.
        half_res: if True, frames are downsampled to (H // 2, W // 2) with
            ``cv2.INTER_AREA`` and the focal length is halved.
        split: ``'all'`` loads train + test videos; otherwise the
            ``'<split>_videos'`` list, falling back to only the first train
            video when that key is absent.
        frame_num_cutoff: if > 0, read exactly this many frames per video;
            otherwise each video's ``frame_num`` from ``info.json``.

    Returns:
        imgs: [T, V, H, W, C] frames divided by 255 (float32, or float64 when
            ``half_res`` because the resize buffer is ``np.zeros``' default dtype).
        poses: one float32 pose per video, stacked along axis 0 (presumably
            4x4 camera-to-world -- TODO confirm against info.json).
        hwfs: [V, 3] rows of (H, W, focal).
        render_poses: [120, 4, 4] float32 torch tensor from ``pose_spherical``.
        render_timesteps: ``np.arange(max_frame_num) / (sp_n - 1)``.
        voxel_tran: ``voxel_matrix`` as float32 with columns 0 and 2 swapped.
        voxel_scale: ``voxel_scale`` broadcast to shape [3].
        near, far: floats from ``info.json``.
        video_list: the video entries that were loaded.

    Raises:
        ValueError: if a video would be read with a non-positive frame count.
    """
    # Only the JSON parse needs the file handle; don't keep it open while decoding videos.
    with open(os.path.join(basedir, 'info.json'), 'r') as fp:
        meta = json.load(fp)

    # render settings
    near = float(meta['near'])
    far = float(meta['far'])
    radius = (near + far) * 0.5
    phi = float(meta['phi'])
    rotZ = (meta['rot'] == 'Z')
    r_center = np.float32(meta['render_center'])

    # scene data
    voxel_tran = np.float32(meta['voxel_matrix'])
    voxel_tran = np.stack([voxel_tran[:, 2], voxel_tran[:, 1], voxel_tran[:, 0], voxel_tran[:, 3]],
                          axis=1)  # swap_zx
    voxel_scale = np.broadcast_to(meta['voxel_scale'], [3])

    # all videos should be synchronized, having the same frame_rate and frame_num
    if split == 'all':
        video_list = list(meta['train_videos']) + list(meta['test_videos'])
    else:
        video_list = meta[split + '_videos'] if (split + '_videos') in meta else meta['train_videos'][0:1]

    all_imgs = []
    all_poses = []
    all_hwf = []
    max_frame_num = 0
    for train_video in video_list:
        frame_num = train_video['frame_num'] if frame_num_cutoff <= 0 else frame_num_cutoff
        if frame_num <= 0:
            # Otherwise H/W would come from an undefined (or a previous video's) frame.
            raise ValueError(
                f"video {train_video.get('file_name')!r} has no frames to load (frame_num={frame_num})")
        max_frame_num = max(max_frame_num, frame_num)

        frames = _read_video_frames(os.path.join(basedir, train_video['file_name']), frame_num)
        H, W = frames[-1].shape[:2]
        camera_angle_x = float(train_video['camera_angle_x'])
        Focal = .5 * W / np.tan(.5 * camera_angle_x)
        imgs = np.float32(frames)

        if half_res:
            H = H // 2
            W = W // 2
            Focal = Focal / 2.
            imgs_half_res = np.zeros((imgs.shape[0], H, W, imgs.shape[-1]))
            for i, img in enumerate(imgs):
                imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
            imgs = imgs_half_res

        all_hwf.append([H, W, Focal])
        all_imgs.append(imgs / 255.)

        # One pose per video. With a per-frame list, the entry for the last
        # loaded frame index (frame_num - 1) is used, as before.
        last_frame_i = frame_num - 1
        pose = (train_video['transform_matrix_list'][last_frame_i]
                if 'transform_matrix_list' in train_video else train_video['transform_matrix'])
        all_poses.append(np.array(pose).astype(np.float32))

    imgs = np.stack(all_imgs, 0)  # [V, T, H, W, C]
    imgs = np.transpose(imgs, [1, 0, 2, 3, 4])  # [T, V, H, W, C]
    poses = np.stack(all_poses, 0)  # [V, ...]
    hwfs = np.stack(all_hwf, 0)  # [V, 3]

    # render settings: sp_n evenly spaced azimuths over [-180, 180)
    sp_n = 120  # an even number! TODO:
    sp_poses = [
        pose_spherical(angle, phi, radius, rotZ, r_center[0], r_center[1], r_center[2])
        for angle in np.linspace(-180, 180, sp_n + 1)[:-1]
    ]
    render_poses = torch.stack(sp_poses, 0)
    # NOTE(review): length follows the frame count but the step is 1 / (sp_n - 1) -- kept as-is, TODO confirm.
    render_timesteps = np.arange(max_frame_num) / (sp_n - 1)
    return imgs, poses, hwfs, render_poses, render_timesteps, voxel_tran, voxel_scale, near, far, video_list


# def load_combine_data(args, split='train', num_frameFM=5, frame_num_cutoff=-1):
#     fm_infer_dir= args.out_frame_fm_dir
#     imgs, poses, hwf, render_poses, render_timesteps, voxel_tran, voxel_scale, near, far, video_list= \
#         load_pinf_frame_data(args.datadir, args.half_res, split=split, frame_num_cutoff=frame_num_cutoff)
#     # [T, V, H, W, 3]
#     all_imgs = []
    
#     max_frame_num = 0
#     for video in video_list:
#         if frame_num_cutoff <= 0:
#             frame_num = video['frame_num']
#         else:
#             frame_num = frame_num_cutoff
#         max_frame_num = max(max_frame_num + num_frameFM, frame_num + num_frameFM)
#         video_id = video['file_name'].split('.')[0]
#         imgs_fm = []
#         out_frame_dir = sorted(os.listdir(os.path.join(fm_infer_dir, video_id)), key=lambda x: int(re.findall("\d+", x)[1]))
#         print(out_frame_dir)
#         is_img = lambda x: x.endswith(('png', 'jpeg', 'jpg'))
#         for image_name in out_frame_dir:
#             if is_img(image_name):
#                 img_path = os.path.join(fm_infer_dir, video_id, image_name)
#                 frame = imageio.imread(img_path)
#                 H, W, = frame.shape[:2]
#                 imgs_fm.append(frame)
#         imgs_fm = (np.float32(imgs_fm) / 255.)
#         H, W = imgs.shape[2], imgs.shape[3]
#         imgs_half_res = np.zeros((imgs_fm.shape[0], H, W, imgs.shape[-1]))
#         for i, img in enumerate(imgs_fm):
#             imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
#         imgs_fm = imgs_half_res
#         all_imgs.append(imgs_fm)
#     fm_imgs_np = np.stack(all_imgs, 0) # [V, T, H, W, 3]
#     fm_imgs_np = np.transpose(fm_imgs_np, [1, 0, 2, 3, 4])  # [T, V, H, W, 3]
#     fm_imgs_np = fm_imgs_np[:num_frameFM]
#     print(imgs.shape, fm_imgs_np.shape, poses.shape)
#     sp_n = 120  ## all frames of a video
#     render_timesteps = np.arange(max_frame_num) / (sp_n - 1)
#     imgs_new = np.concatenate((imgs, fm_imgs_np), axis = 0)
#     return imgs_new, poses, hwf, render_poses, render_timesteps, voxel_tran, voxel_scale, near, far


def load_combine_data(args, split='train', num_frameFM=5, frame_num_cutoff=-1):
    """Load real video frames and append frame-model (FM) predicted frames after them.

    Real frames come from ``load_pinf_frame_data(args.datadir, args.half_res, ...)``.
    For every loaded video, the image files in
    ``<args.out_frame_fm_dir>/frames/<video_id>/`` whose last integer in the
    file name lies in ``[frame_num, frame_num + num_frameFM)`` are read in
    index order, scaled by 1/255, resized to the real frames' (W, H) with
    ``cv2.INTER_AREA``, and concatenated after the real frames on the time axis.

    Args:
        args: object with ``datadir``, ``half_res`` and, when
            ``num_frameFM != 0``, ``out_frame_fm_dir``.
        split: dataset split forwarded to ``load_pinf_frame_data``.
        num_frameFM: number of predicted frames to append; 0 returns the real
            data only.
        frame_num_cutoff: if > 0, number of real frames used per video;
            otherwise each video's ``frame_num``.

    Returns:
        (imgs, poses, hwf, render_poses, render_timesteps, voxel_tran,
        voxel_scale, near, far). ``imgs`` is [T + T_fm, V, H, W, C] with
        T_fm <= num_frameFM. ``render_timesteps`` is
        ``np.arange(max_frame_num) / 119`` where ``max_frame_num`` is the
        maximum over videos of ``frame_num + num_frameFM``.

    NOTE(review): FM images are assumed to have the same channel count as the
    real frames (e.g. no alpha channel) -- otherwise the assignment fails.
    """
    imgs, poses, hwf, render_poses, render_timesteps, voxel_tran, voxel_scale, near, far, video_list = \
        load_pinf_frame_data(args.datadir, args.half_res, split=split, frame_num_cutoff=frame_num_cutoff)
    if num_frameFM == 0:
        return imgs, poses, hwf, render_poses, render_timesteps, voxel_tran, voxel_scale, near, far

    # FM outputs live in <out_frame_fm_dir>/frames/<video_id>/; only needed past the early return.
    fm_infer_dir = os.path.join(args.out_frame_fm_dir, 'frames')
    H, W, C = imgs.shape[2], imgs.shape[3], imgs.shape[-1]  # imgs is [T, V, H, W, C]
    digits = re.compile(r"\d+")

    all_imgs = []
    max_frame_num = 0  # maximum over ALL videos, so initialise once outside the loop
    for video in video_list:
        frame_num = video['frame_num'] if frame_num_cutoff <= 0 else frame_num_cutoff
        max_frame_num = max(max_frame_num, frame_num + num_frameFM)
        video_id = video['file_name'].split('.')[0]
        video_dir = os.path.join(fm_infer_dir, video_id)

        # (frame index, file name) for image files only; the index is the last
        # integer in the name. Non-image files and names without digits are ignored.
        indexed = []
        for name in os.listdir(video_dir):
            if not name.endswith(('png', 'jpeg', 'jpg')):
                continue
            numbers = digits.findall(name)
            if numbers:
                indexed.append((int(numbers[-1]), name))
        indexed.sort(key=lambda item: item[0])  # stable: ties keep listdir order

        image_names = [name for i_frame, name in indexed
                       if frame_num <= i_frame < frame_num + num_frameFM]
        print(image_names)

        # Convert and resize one frame at a time so source sizes need not match.
        imgs_fm = np.zeros((len(image_names), H, W, C))
        for i, name in enumerate(image_names):
            frame = np.float32(imageio.imread(os.path.join(video_dir, name))) / 255.
            imgs_fm[i] = cv2.resize(frame, (W, H), interpolation=cv2.INTER_AREA)
        all_imgs.append(imgs_fm)

    fm_imgs_np = np.stack(all_imgs, 0)  # [V, T_fm, H, W, C]
    fm_imgs_np = np.transpose(fm_imgs_np, [1, 0, 2, 3, 4])  # [T_fm, V, H, W, C]
    fm_imgs_np = fm_imgs_np[:num_frameFM]
    sp_n = 120  # same spacing constant as load_pinf_frame_data
    render_timesteps = np.arange(max_frame_num) / (sp_n - 1)
    imgs_new = np.concatenate((imgs, fm_imgs_np), axis=0)
    return imgs_new, poses, hwf, render_poses, render_timesteps, voxel_tran, voxel_scale, near, far


def load_fm_features(args):
    """Load the frame-model feature array saved as ``all_features.npy``.

    Args:
        args: object whose ``out_frame_fm_dir`` is the frame-model output directory.

    Returns:
        The array stored in ``<args.out_frame_fm_dir>/all_features.npy``.
    """
    return np.load(os.path.join(args.out_frame_fm_dir, 'all_features.npy'))