import numpy as np
import os, cv2
import torch
import torch.nn as nn
from run_nerf_helpers import *
from MV_run_nerf import *
from MV_mae_encoder import MaskedViTEncoder
from load_blender import load_test_data
from moviepy.editor import *

#####################################################
# User configuration.
# Folder that contains the checkpoint files (file names containing 'tar') and args.txt.
ckpt_folder_dir = 'ckpt_folder_dir'

# Episodes to evaluate; this list is passed to load_test_data as episode_idx_list.
episode_index = [119]

# Multi-view input (single_view_input=False): the views listed in ref_view_index are used as reference views.
# Single-view input (single_view_input=True): the reference input is built from two copies of the primary
# input view instead, so the views selected by ref_view_index do not end up in the reference input.
# The list is still embedded in the output directory and file names.
# NOTE(review): confirm whether the indices must nevertheless be valid view indices in single-view mode.
single_view_input = True
input_view_index = 1    # index of the primary input view
ref_view_index = [3, 4] # list of reference view indices (the list should contain exactly two entries)
#####################################################

def args_change_type(args):
    """Convert the string-valued options stored on ``args`` in place.

    Each listed attribute is read from ``args``, converted and written
    back. Integer options go through ``int``, floating-point options
    through ``float``, and flags become ``True`` only when their current
    value equals the string ``'True'``. Attributes that are not listed
    are left untouched.

    NOTE(review): ``batch_size`` is not converted here although this
    script passes ``args.batch_size`` to the encoder constructor --
    confirm that a string value is acceptable there.

    Args:
        args: namespace-like object carrying every option listed below.

    Returns:
        The same ``args`` object, after conversion.

    Raises:
        AttributeError: if one of the listed options is missing.
        ValueError: if a numeric option cannot be parsed.
    """
    def as_flag(text):
        return text == 'True'

    # (attribute name, converter) pairs, applied strictly in this order.
    conversions = (
        ('use_viewdirs', as_flag),
        ('multires', int),
        ('i_embed', int),
        ('multires_views', int),
        ('img_size', int),
        ('patch_size', int),
        ('embed_dim', int),
        ('vit_depth', int),
        ('vit_num_heads', int),
        ('num_view', int),
        ('time_interval', int),
        ('decoder_depth', int),
        ('decoder_num_heads', int),
        ('decoder_output_dim', int),
        ('netdepth', int),
        ('netwidth', int),
        ('perturb', float),
        ('N_samples', int),
        ('white_bkgd', as_flag),
        ('raw_noise_std', float),
        ('no_ndc', as_flag),
        ('lindisp', as_flag),
        ('netchunk', int),
        ('N_importance', int),
        ('netdepth_fine', int),
        ('netwidth_fine', int),
        ('lrate', float),
        ('half_res', as_flag),
        ('testskip', int),
        ('episode_num', int),
        ('chunk', int),
        ('vit_encoder_mlp_dim', int),
        ('vit_decoder_mlp_dim', int),
    )

    for option, convert in conversions:
        setattr(args, option, convert(getattr(args, option)))

    return args

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint discovery: every entry of ckpt_folder_dir whose name contains 'tar'
# is treated as a checkpoint. Names that also contain 'encoder' are encoder
# checkpoints, all others are NeRF checkpoints. The directory is listed once and
# sorted, so the last element of each list is the name that sorts highest.
ckpt_names = [f for f in sorted(os.listdir(ckpt_folder_dir)) if 'tar' in f]
ckpts_encoder = [os.path.join(ckpt_folder_dir, f) for f in ckpt_names if 'encoder' in f]
ckpts_nerf = [os.path.join(ckpt_folder_dir, f) for f in ckpt_names if 'encoder' not in f]

# Explicit error instead of an assert: asserts are stripped under `python -O`
# and would otherwise surface later as an IndexError on the empty list.
if not ckpts_encoder or not ckpts_nerf:
    raise FileNotFoundError(
        f'Expected at least one encoder checkpoint and one NeRF checkpoint in {ckpt_folder_dir!r}, '
        f'found {len(ckpts_encoder)} encoder and {len(ckpts_nerf)} NeRF checkpoint file(s)')

# load ckpt
ckpt_encoder_path = ckpts_encoder[-1]
ckpt_nerf_path = ckpts_nerf[-1]
print('Reloading from', ckpt_encoder_path)
print('Reloading from', ckpt_nerf_path)
# map_location remaps the stored tensors onto the selected device, so checkpoints
# saved on a GPU can still be loaded when this script falls back to the CPU.
ckpt_encoder = torch.load(ckpt_encoder_path, map_location=device)
ckpt_nerf = torch.load(ckpt_nerf_path, map_location=device)

# argument load
# Rebuild the option namespace from args.txt in the checkpoint folder. Every
# kept line must split into exactly three whitespace-separated tokens; the
# first token is the option name and the last one its value (kept as a string).
args = {}
with open(os.path.join(ckpt_folder_dir, 'args.txt')) as arguments:
    for line in arguments:
        tokens = line.split()
        # Blank lines carry no option; indexing tokens[0] on them would raise IndexError.
        if not tokens:
            continue
        # 'lang_goal' entries are skipped without any format check
        # (presumably because their value may contain spaces -- TODO confirm).
        if tokens[0] == 'lang_goal':
            continue
        # Explicit error instead of an assert, which is stripped under `python -O`.
        if len(tokens) != 3:
            raise ValueError(f'Malformed line in args.txt (expected 3 tokens): {line.rstrip()!r}')

        args[tokens[0]] = tokens[-1]

import configargparse
# Register every stored option with its stored value as the default, so that
# parse_args() returns those values unless the command line overrides them.
parser = configargparse.ArgumentParser()
for key, value in args.items():
    parser.add_argument('--' + key, default=value)

args = parser.parse_args()
# All values are strings at this point; cast the ones with a known type.
args = args_change_type(args)

# Embedder for the sampled inputs, plus an optional second embedder for the
# view directions (only created when use_viewdirs is set).
embed_fn, input_ch = get_embedder(args.multires, device=device, i=args.i_embed)
if args.use_viewdirs:
    embeddirs_fn, input_ch_views = get_embedder(args.multires_views, device=device, i=args.i_embed)
else:
    embeddirs_fn, input_ch_views = None, 0

# The NeRF networks get five output channels when importance sampling is on, four otherwise.
if args.N_importance > 0:
    output_ch = 5
else:
    output_ch = 4
skips = [4]

# load encoder ckpt
encoder_options = dict(
    img_size=args.img_size,
    patch_size=args.patch_size,
    embed_dim=args.embed_dim,
    depth=args.vit_depth,
    num_heads=args.vit_num_heads,
    num_view=args.num_view,
    device=device,
    time_interval=args.time_interval,
    decoder_depth=args.decoder_depth,
    decoder_num_heads=args.decoder_num_heads,
    decoder_output_dim=args.decoder_output_dim,
    batch_size=args.batch_size,
    vit_encoder_mlp_dim=args.vit_encoder_mlp_dim,
    vit_decoder_mlp_dim=args.vit_decoder_mlp_dim,
)
latent_embed = MaskedViTEncoder(**encoder_options).to(device)

# The latent size handed to the NeRF networks is read back from the encoder.
latent_dim = latent_embed.decoder_output_dim

grad_vars = [*latent_embed.parameters()]
# strict=False: keys that do not match between checkpoint and module are tolerated.
latent_embed.load_state_dict(ckpt_encoder, strict=False)

# load nerf ckpt
# Constructor arguments shared by the coarse and the fine network.
nerf_shared_options = dict(
    input_ch=input_ch,
    output_ch=output_ch,
    skips=skips,
    input_ch_views=input_ch_views,
    use_viewdirs=args.use_viewdirs,
    latent_dim=latent_dim,
)

model = NeRF(D=args.netdepth, W=args.netwidth, **nerf_shared_options).to(device)
grad_vars.extend(model.parameters())

# The fine network only exists when importance sampling is enabled.
if args.N_importance > 0:
    model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine, **nerf_shared_options).to(device)
    grad_vars.extend(model_fine.parameters())
else:
    model_fine = None


def network_query_fn(inputs, viewdirs, network_fn, latent):
    """Run ``network_fn`` through run_network with this script's embedders and chunk size."""
    return run_network(inputs, viewdirs, network_fn,
                       embed_fn=embed_fn,
                       embeddirs_fn=embeddirs_fn,
                       netchunk=args.netchunk,
                       latent=latent)


start = ckpt_nerf['global_step']

# Load model
for nerf_net, state_key in ((model, 'network_fn_state_dict'),
                            (model_fine, 'network_fine_state_dict')):
    if nerf_net is not None:
        nerf_net.load_state_dict(ckpt_nerf[state_key])

# Keyword arguments for rendering with the training-time settings.
render_kwargs_train = dict(
    network_query_fn=network_query_fn,
    perturb=args.perturb,
    N_importance=args.N_importance,
    network_fine=model_fine,
    N_samples=args.N_samples,
    network_fn=model,
    use_viewdirs=args.use_viewdirs,
    white_bkgd=args.white_bkgd,
    raw_noise_std=args.raw_noise_std,
)

# NDC only good for LLFF-style forward facing data
if args.dataset_type != 'llff' or args.no_ndc:
    print('Not ndc!')
    render_kwargs_train.update(ndc=False, lindisp=args.lindisp)

# The test-time settings are a shallow copy with perturbation and raw noise switched off.
render_kwargs_test = dict(render_kwargs_train)
render_kwargs_test.update(perturb=False, raw_noise_std=0.)

# Load data
episode_num = len(episode_index)

K = None

# Near bound per supported dataset type; the far bound is 3.0 for every one of them.
near_by_dataset = {
    'hammer': 0.02258,
    'push': 0.02258,
    'window': 0.02258,
    'stick': 0.02258,
    'peg': 0.02258,
    'drawer': 0.02343,
}

# Anything outside the table above is rejected before any data is loaded.
if args.dataset_type not in near_by_dataset:
    print('Unknown dataset type', args.dataset_type, 'exiting')
    raise NotImplementedError

images, poses, render_poses, hwf, i_split, semantics, depths = load_test_data(
    args.datadir, args.half_res, args.testskip,
    episode_idx_list=episode_index, num_view=args.num_view, dataset_type=args.dataset_type)
print('Loaded hammer data', images.shape, render_poses.shape, hwf, args.datadir)
i_train = i_split

i_test = [0]

near = near_by_dataset[args.dataset_type]
far = 3.

# With a white background and a fourth (alpha) channel present, blend the
# colours onto white; otherwise just keep the first three channels.
blend_on_white = args.white_bkgd and images.shape[-1] == 4
color_channels = images[..., :3]
if blend_on_white:
    alpha_channel = images[..., -1:]
    images = color_channels * alpha_channel + (1. - alpha_channel)
else:
    images = color_channels

# original images, poses, render_poses shape
# V, NumEpi*Length, H, W, C = images.shape
# V, NumEpi*Length, 4, 4 = poses.shape
# num_render_poses(40), 4, 4  = render_poses.shape
V, NL, H, W, C = images.shape

print(f'episode num : {episode_num}, time interval : {args.time_interval}')
print("Currently, assume all episode trajectory length is same as 120!")
# Every episode is taken to contribute the same number of frames.
episode_length = int(NL/episode_num)
images = np.transpose(
    images.reshape(V, episode_num, episode_length, H, W, C),  # [V, episode_num, length, H, W, C]
    (1, 2, 3, 0, 4, 5),
)  # [episode_num, length, H, V, W, C]

# Cast intrinsics to right types
H, W, focal = hwf
H = int(H)
W = int(W)
hwf = [H, W, focal]

# Default intrinsics matrix: focal length on the diagonal, principal point at the image centre.
K = np.array([
    [focal, 0, 0.5*W],
    [0, focal, 0.5*H],
    [0, 0, 1]
]) if K is None else K

global_step = start

# Both render configurations use the same near/far bounds.
bds_dict = dict(near=near, far=far)
for render_kwargs in (render_kwargs_train, render_kwargs_test):
    render_kwargs.update(bds_dict)

# Image-quality metrics used by the evaluation loop.
MSEloss = nn.MSELoss()
ssim = SSIM()
import lpips
# Place the LPIPS network on the device selected for this script rather than
# hard-coding CUDA, so the CPU fallback does not fail here.
LPIPS = lpips.LPIPS(net='alex').to(device)

# Evaluation. For every selected episode and every time step:
#   1. gather a window of `time_interval` frames ending at the current step,
#   2. encode the primary input view (and the reference views) into a latent,
#   3. render every view from that latent, score it against the ground-truth
#      frame and save the rendering as a PNG.
# Afterwards the per-view averages are appended to MSE_loss.txt and the saved
# PNGs are assembled into one .avi and one .gif per view.
with torch.no_grad():
    test_time_interval = args.time_interval
    
    for epi_index, epi in enumerate(episode_index):
        # Running sums per reconstructed view; turned into means after the time loop.
        rgb_MSE_loss_dict = dict()
        rgb_PSNR_dict = dict()
        rgb_SSIM_dict = dict()
        rgb_LPIPS_dict = dict()
        count = 0
        
        testsavedir = os.path.join(ckpt_folder_dir, f'Visualize_Episode{epi}_InputIndex{input_view_index}_RefViewIndex{ref_view_index}_SingleViewInput{single_view_input}')
        os.makedirs(testsavedir, exist_ok=True)       
        for time_index in range(episode_length):
            if time_index < (test_time_interval - 1):
                # if we don't have past {interval} time step images:
                # fill the missing leading slots with repeated copies of the
                # batch fetched at time_index=0 and append the frames that do exist.
                lack_interval = test_time_interval -1 - time_index
                lack_images, lack_semantics, _ = get_batch_images(images, 1, time_interval=1, episode_index=epi_index, time_index=0, semantics=semantics)
                lack_images = lack_images.to(device)
                lack_images = torch.tile(lack_images, (1,lack_interval,1,1,1,1))

                test_images, test_semantics, _ = get_batch_images(images, 1, test_time_interval-lack_interval, episode_index = epi_index,
                                                                    time_index = time_index-(test_time_interval-lack_interval-1), semantics=semantics)
                test_images = test_images.to(device)
                test_images = torch.cat((lack_images, test_images), axis=1)
            else:
                # bring past {interval} period images
                test_images, test_semantics, _ = get_batch_images(images, 1, test_time_interval, episode_index = epi_index,
                                                                    time_index = time_index-(test_time_interval-1), semantics=semantics)     
                test_images = test_images.to(device)
            B, T, H, V, W, C = test_images.shape
            # Keep only the primary input view (the view axis stays, with length one).
            test_input_images = test_images[:,:,:,input_view_index:input_view_index+1,:,:]
            test_input_semantics = None
            test_images_for_ViT, reshaped_test_images, test_semantics_for_ViT, reshaped_test_semantics \
                = get_batch_images_for_mae_encoder(test_input_images, test_input_semantics)
            test_images_for_ViT = test_images_for_ViT.float().to(device)
            
            if not single_view_input:
                # Multi-view input: the reference images come from the views in ref_view_index.
                # In single-view mode this step is skipped entirely, because its result would be
                # replaced below anyway; ref_view_index is therefore never used for indexing there.
                # NOTE(review): assumes get_batch_images_for_mae_encoder has no side effects -- confirm.
                ref_test_images_for_ViT = torch.stack([test_images[:,:,:,index] for index in ref_view_index], dim=3) # [B, T, H, ref_V, W, C]
                ref_test_images_for_ViT, _, _, _ = get_batch_images_for_mae_encoder(ref_test_images_for_ViT) # [B, T, H, ref_V*W, C]
                ref_test_images_for_ViT = ref_test_images_for_ViT.float().to(device)
            
            m = 0.0
            # Forward ViT encoder
            # input view
            test_latent, mask, ids_restore = latent_embed.SinCro_image_encoder(test_images_for_ViT, m, test_time_interval, is_ref=False)
                
            # reference view; masking_ratio = 0
            if single_view_input:
                # Single-view input: use two copies of the primary input view as the reference views.
                ref_test_images_for_ViT = torch.unsqueeze(test_images_for_ViT, 3).repeat(1, 1, 1, 2, 1, 1) # torch.Size([1, 3, 128, 2, 128, 3])
                B_ref, T_ref, H_ref, V_ref, W_ref, C_ref = ref_test_images_for_ViT.shape
                ref_test_images_for_ViT = ref_test_images_for_ViT.reshape((B_ref, T_ref, H_ref, V_ref*W_ref, C_ref)) # B, Time, H, VW, C = x.shape
            
            # The reference tensor must hold more than a single image along dim 3 before it is split.
            assert ref_test_images_for_ViT.shape[3] != args.img_size
            ref_test_images_for_ViT = torch.split(ref_test_images_for_ViT, args.img_size, dim=3)  # [B, T, H, W, C] * V
            
            ref_time_interval = args.time_interval
            
            ref_test_images_for_ViT = torch.cat(ref_test_images_for_ViT, dim=0) # [VB, T, H, W, C]
            ref, _, _ = latent_embed.SinCro_image_encoder(ref_test_images_for_ViT, 0, ref_time_interval, is_ref=True)
            # Drop the first token, then keep only the tokens of the last time step.
            ref = rearrange(ref[:,1:,:], 'b (ref_T hw) d -> b ref_T hw d', ref_T=ref_time_interval)[:,-1]   # [V*B, H'W', embed]
            ref = rearrange(ref, '(v b) hw d -> b (v hw) d', b=test_latent.shape[0]) # [B, VH'W', embed]
                                
            # Forward ViT decoder
            #[B(1), T,V, H',W', dim] 
            
            test_latent, mask, ids_restore = latent_embed.SinCro_state_encoder(test_latent, ref, mask, ids_restore) 

            #[B(1), TV, dim]
            reshaped_test_latent = test_latent.reshape(B, T, 1, -1) # [B(1), T, V(1), dim]
                
            # all poses used in training [V, 4, 4]
            test_input_poses = poses[:, 0].clone().detach().to(device)
            
            test_recon_view_index = np.arange(V)
            for v in test_recon_view_index:
                # Render view v from the latent of the last time step in the window.
                test_rgb, test_disp = render_path(test_input_poses[v:v+1], hwf, K, args.chunk, render_kwargs_test, gt_imgs=None, 
                                                                    latent=reshaped_test_latent[:,-1,0], args=args, test_mode=True)
                
                test_rgb = torch.from_numpy(test_rgb).to(device)
                
                # Ground truth is the last frame of the window for view v.
                gt_image = test_images[0,-1,:,v,:,:]
                gt_batch = test_images[:,-1,:,v,:,:].permute(0,3,1,2)
                pred_batch = test_rgb.permute(0,3,1,2)

                # MSE is computed once and reused for PSNR.
                mse = MSEloss(gt_image, test_rgb[0])
                mse_value = mse.item()
                psnr_value = mse2psnr(mse, device).item()
                ssim_value = ssim(pred_batch, gt_batch).item()
                # LPIPS receives both images rescaled by *2-1.
                lpips_value = LPIPS(gt_batch*2-1, pred_batch*2-1).item()

                metric_suffix = f'view{v}_from_latent{input_view_index}'
                frame_metrics = (
                    (rgb_MSE_loss_dict, f'rgb_MSE_loss_{metric_suffix}', mse_value),
                    (rgb_PSNR_dict, f'rgb_PSNR_{metric_suffix}', psnr_value),
                    (rgb_SSIM_dict, f'rgb_SSIM_{metric_suffix}', ssim_value),
                    (rgb_LPIPS_dict, f'rgb_LPIPS_{metric_suffix}', lpips_value),
                )
                # One accumulation path for both the first and every later time step.
                for metric_dict, metric_key, metric_value in frame_metrics:
                    metric_dict[metric_key] = metric_dict.get(metric_key, 0.) + metric_value
                count += 1

                test_rgb = test_rgb.cpu().numpy()
                # [1,H,W,C]
                test_rgb8 = to8b(test_rgb[0])
                
                test_filename = os.path.join(testsavedir, f'Recon_View{v}_by_LatentView{input_view_index}_at_Time{time_index}_Episode{epi}_TimeInterval{test_time_interval}_RefViewIndex{ref_view_index}.png')
                imageio.imwrite(test_filename, test_rgb8)
        
        # Turn the running sums into per-view means and append them, followed by
        # the mean over all views, to the results file (opened once, in append mode).
        total_MSE, total_PSNR, total_SSIM, total_LPIPS = 0, 0, 0, 0
        with open(os.path.join(testsavedir, 'MSE_loss.txt'), 'a') as txt:
            for v in test_recon_view_index:
                rgb_MSE_var_name = f'rgb_MSE_loss_view{v}_from_latent{input_view_index}'
                rgb_PSNR_var_name = f'rgb_PSNR_view{v}_from_latent{input_view_index}'
                rgb_SSIM_var_name = f'rgb_SSIM_view{v}_from_latent{input_view_index}'
                rgb_LPIPS_var_name = f'rgb_LPIPS_view{v}_from_latent{input_view_index}'
                if rgb_MSE_var_name in rgb_MSE_loss_dict:
                    # Every view must have been scored once per time step.
                    assert count == episode_length * len(test_recon_view_index)
                    rgb_MSE_loss_dict[rgb_MSE_var_name] /= episode_length
                    rgb_PSNR_dict[rgb_PSNR_var_name] /= episode_length
                    rgb_SSIM_dict[rgb_SSIM_var_name] /= episode_length
                    rgb_LPIPS_dict[rgb_LPIPS_var_name] /= episode_length
                    total_MSE += rgb_MSE_loss_dict[rgb_MSE_var_name]
                    total_PSNR += rgb_PSNR_dict[rgb_PSNR_var_name]
                    total_SSIM += rgb_SSIM_dict[rgb_SSIM_var_name]
                    total_LPIPS += rgb_LPIPS_dict[rgb_LPIPS_var_name]
                txt.write(f'rgb_MSE_loss_view{v}_from_latent{input_view_index}_at_episode{epi} : {rgb_MSE_loss_dict[rgb_MSE_var_name]:.4f}\n')
                txt.write(f'rgb_PSNR_view{v}_from_latent{input_view_index}_at_episode{epi} : {rgb_PSNR_dict[rgb_PSNR_var_name]:.4f}\n')
                txt.write(f'rgb_SSIM_view{v}_from_latent{input_view_index}_at_episode{epi} : {rgb_SSIM_dict[rgb_SSIM_var_name]:.4f}\n')
                txt.write(f'rgb_LPIPS_view{v}_from_latent{input_view_index}_at_episode{epi} : {rgb_LPIPS_dict[rgb_LPIPS_var_name]:.4f}\n')
            # NOTE(review): the number of views is hard-coded to six here -- confirm this is intended.
            assert len(test_recon_view_index) == 6
            total_MSE /= len(test_recon_view_index)
            total_PSNR /= len(test_recon_view_index)
            total_SSIM /= len(test_recon_view_index)
            total_LPIPS /= len(test_recon_view_index)
            txt.write(f'total_MSE_loss_from_latent{input_view_index}_at_episode{epi} : {total_MSE:.4f}\n')
            txt.write(f'total_PSNR_from_latent{input_view_index}_at_episode{epi} : {total_PSNR:.4f}\n')
            txt.write(f'total_SSIM_from_latent{input_view_index}_at_episode{epi} : {total_SSIM:.4f}\n')
            txt.write(f'total_LPIPS_from_latent{input_view_index}_at_episode{epi} : {total_LPIPS:.4f}\n')
    
        # load RGB images
        file_list = os.listdir(testsavedir) 
        for v in test_recon_view_index:
            rgb_arr = []
            for i in range(episode_length):
                # The trailing '_Episode' in the pattern keeps e.g. Time1 from matching Time10.
                rgb_filename_i = [os.path.join(testsavedir, f) for f in file_list
                                if f'Recon_View{v}_by_LatentView{input_view_index}_at_Time{i}_Episode' in f]
                
                # Exactly one saved frame is expected per view and time step.
                assert len(rgb_filename_i) == 1
                rgb = cv2.imread(rgb_filename_i[0])
                rgb_arr.append(rgb)
                
            # Write video file
            rgb_h, rgb_w = rgb_arr[0].shape[:2]
            rgb_video_name = os.path.join(testsavedir, f'Recon_View{v}_by_LatentView{input_view_index}_at_Episode{epi}')
            video_format = '.avi'
            fps = 20
            rgb_out = cv2.VideoWriter(rgb_video_name + video_format, cv2.VideoWriter_fourcc(*'MJPG'), fps, (rgb_w, rgb_h))
                
            for frame in rgb_arr:
                rgb_out.write(frame)
            
            rgb_out.release()
            # Convert the finished .avi into a .gif stored next to it.
            VideoFileClip(rgb_video_name + video_format).speedx(1).write_gif(rgb_video_name + '.gif')
        
print('end')