import os
import pickle
import argparse
import mediapy as media
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.autoencoder_kl_cogvideox import AutoencoderKLCogVideoXCustom as AutoencoderKLCogVideoX
from models.unet3d import UNet3D
from models.unet import UNet
from models.mlp import Mlp
from models.resnet import VAEEncoderadaptor, VAEDecoderadaptor
from utils.gaussian_splatting import gs_render_batch

import open3d as o3d


def save_pc(pc, color, path):
    """Write the first frame of a batched point cloud (with colors) to ``path`` via Open3D.

    Args:
        pc: array indexed as ``pc[0, :, 0]`` -- first batch item, all channels, first
            frame. The channel axis is reshaped to 3 rows, so it must hold x, y, z.
        color: array whose first item ``color[0]`` holds per-point colors; it is divided
            by 255 before being handed to Open3D (assumes 0-255 input -- confirm).
        path: output file path passed to ``o3d.io.write_point_cloud``.
    """
    first_frame = pc[0, :, 0]
    xyz = first_frame.reshape(3, -1).T  # (3, P) -> (P, 3) point list
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(xyz)
    cloud.colors = o3d.utility.Vector3dVector(color[0] / 255)
    o3d.io.write_point_cloud(path, cloud)


def parse_args(argv=None):
    """Parse command-line options for VAE reconstruction inference.

    Args:
        argv: optional list of argument strings; ``None`` (the default) reads
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Inference reconstructed results of VAE")

    parser.add_argument(
        "--vae_model_path",
        type=str,
        default="/xxx",
        help="Pretrained CogVideoX model path; the VAE is loaded from its 'vae' subfolder.",
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default="/xxx",
        help="Directory holding the trained encoder_prompt/decoder_prompt (and optional vae) "
             "checkpoints; visualizations are written to <ckpt_dir>/vis.",
    )
    parser.add_argument("--validation_samples", type=str, default=None, nargs="+")
    parser.add_argument(
        "--num_frames",
        type=int,
        default=17,
        help="Number of leading frames of each track to reconstruct.",
    )
    parser.add_argument(
        "--prompt_type",
        type=str,
        default="resnet",
        choices=["unet3d", "unet", "mlp", "resnet"],
        help="Architecture of the encoder/decoder prompt adaptors.",
    )
    parser.add_argument(
        "--normalize_track",
        action="store_true",
        help="Whether to normalize the absolute track coordinates by their maximum magnitude "
             "(instead of using displacements from the first frame).",
    )
    parser.add_argument(
        "--normalize_track_first_frame",
        action="store_true",
        help="Whether to scale the track coordinates by the spatial extent of the first frame.",
    )

    return parser.parse_args(argv)


def _build_prompt_modules(prompt_type, normalize_track):
    """Construct the (encoder_prompt, decoder_prompt) adaptor pair for ``prompt_type``.

    Raises:
        NotImplementedError: if ``prompt_type`` is not a supported adaptor type.
    """
    if prompt_type == "unet3d":
        # NOTE(review): unlike "unet"/"mlp", this decoder ignores normalize_track (no tanh) -- confirm intended.
        encoder_prompt = UNet3D(in_channels=3, out_channels=3, final_activation="sigmoid", upsample="trilinear", f_maps=[32, 64, 128, 256])
        decoder_prompt = UNet3D(in_channels=3, out_channels=3, final_activation=None, upsample="trilinear", f_maps=[32, 64, 128, 256])
    elif prompt_type == "unet":
        encoder_prompt = UNet(in_channels=3, out_channels=3, final_activation="sigmoid")
        decoder_prompt = UNet(in_channels=3, out_channels=3, final_activation="tanh" if normalize_track else None)
    elif prompt_type == "mlp":
        encoder_prompt = Mlp(in_features=3, out_features=3, hidden_features=256, act_layer=nn.GELU, drop=0.0, final_activation="sigmoid")
        decoder_prompt = Mlp(
            in_features=3,
            out_features=3,
            hidden_features=256,
            act_layer=nn.GELU,
            drop=0.0,
            final_activation="tanh" if normalize_track else None,
        )
    elif prompt_type == "resnet":
        encoder_prompt = VAEEncoderadaptor()
        decoder_prompt = VAEDecoderadaptor()
    else:
        raise NotImplementedError(f"Prompt type {prompt_type} not implemented!")
    return encoder_prompt, decoder_prompt


def _load_state(module, ckpt_dir, subfolder):
    """Load ``<ckpt_dir>/<subfolder>/pytorch_model.bin`` into ``module``.

    Weights are mapped to CPU first; the module is moved to CUDA afterwards, so this
    does not depend on which device the checkpoint was saved from.
    """
    state = torch.load(os.path.join(ckpt_dir, subfolder, "pytorch_model.bin"), map_location="cpu")
    module.load_state_dict(state)


def _load_models(args):
    """Build and load the encoder prompt, decoder prompt and VAE; return them in eval mode on CUDA."""
    encoder_prompt, decoder_prompt = _build_prompt_modules(args.prompt_type, args.normalize_track)
    vae = AutoencoderKLCogVideoX.from_pretrained(args.vae_model_path, subfolder="vae")
    _load_state(encoder_prompt, args.ckpt_dir, "encoder_prompt")
    _load_state(decoder_prompt, args.ckpt_dir, "decoder_prompt")
    # A fine-tuned VAE checkpoint is optional; otherwise the pretrained weights are kept.
    if os.path.exists(os.path.join(args.ckpt_dir, "vae", "pytorch_model.bin")):
        _load_state(vae, args.ckpt_dir, "vae")
    for module in (encoder_prompt, decoder_prompt, vae):
        module.eval()
        module.to("cuda")
    return encoder_prompt, decoder_prompt, vae


def _build_camera(H, W, H_ori, W_ori):
    """Return ``(intrinsics, extrinsic)`` for rendering at ``H x W`` from an ``H_ori x W_ori`` source.

    The extrinsic is the 4x4 identity. The intrinsics use a normalized principal point
    (0.5, 0.5); one focal length is fixed to 1 and the other is scaled by the ratio
    between the source and target aspect ratios.
    """
    extrinsic = torch.eye(4)
    if W_ori / W > H_ori / H:
        # Source is relatively wider than the target: width spans the full view, H is flatter.
        fx = 1
        fy = W_ori / H_ori / (W / H)
    else:
        fy = 1
        fx = H_ori / W_ori / (H / W)
    intrinsics = torch.Tensor([
        [fx, 0, 0.5],
        [0, fy, 0.5],
        [0, 0, 1],
    ])
    return intrinsics, extrinsic


def _collect_sample_paths(validation_samples, default_dir="/xxx/dt3d_render"):
    """Expand ``--validation_samples`` entries into a list of sample file paths.

    Each entry may be a file (used as-is) or a directory (its regular files, sorted).
    When no entries are given, ``default_dir`` is listed instead.
    """
    if not validation_samples:
        validation_samples = [default_dir]
    elif isinstance(validation_samples, str):
        validation_samples = [validation_samples]
    paths = []
    for entry in validation_samples:
        if os.path.isdir(entry):
            candidates = (os.path.join(entry, fname) for fname in sorted(os.listdir(entry)))
            paths.extend(p for p in candidates if os.path.isfile(p))
        else:
            paths.append(entry)
    return paths


def _to_cuda_float(value):
    """Convert a numpy array or torch tensor to a float32 tensor on CUDA."""
    if isinstance(value, np.ndarray):
        value = torch.from_numpy(value)
    return value.to("cuda").float()


def _prepare_tracks(coords, args, H, W):
    """Turn raw track coordinates into the encoder input.

    Returns:
        targets: ``(1, 3, T, H, W)`` tensor, T <= ``args.num_frames``. These are
            displacements from frame 0 unless ``args.normalize_track``; they are
            optionally rescaled.
        first_frame: ``(1, 3, 1, H, W)`` frame-0 positions.
        diff: ``(3,)`` scale divided out when ``args.normalize_track_first_frame``, else ``None``.
    """
    targets = _to_cuda_float(coords)
    if targets.dim() == 4:
        targets = targets[0]  # drop a leading batch dimension
    first_frame = targets[0:1].clone()
    if not args.normalize_track:
        print(abs(targets).max(), abs(targets).mean())
        targets = targets - first_frame

    targets = targets.reshape(targets.shape[0], H, W, 3)[:args.num_frames].permute(3, 0, 1, 2).unsqueeze(0)

    diff = None
    if args.normalize_track_first_frame:
        # Scale by the largest per-axis extent (max - min) of the frame-0 positions.
        frame0 = first_frame.reshape(-1, 3)
        max_vals = frame0.max(dim=0).values
        min_vals = frame0.min(dim=0).values
        diff = (max_vals - min_vals).max().repeat(3)
        diff[diff == 0] = 1.0  # avoid division by zero for a degenerate frame
        targets = targets / diff.view(3, 1, 1, 1)
    if args.normalize_track:
        track_max = torch.abs(targets).max()
        if track_max > 0:  # an all-zero track would otherwise become NaN
            targets = targets / track_max

    first_frame = first_frame.reshape(first_frame.shape[0], H, W, 3).permute(3, 0, 1, 2).unsqueeze(0)
    return targets, first_frame, diff


def _to_uint8_video(frames):
    """Convert a rendered batch to a ``uint8`` video array for ``media.write_video``.

    Assumes ``frames`` is ``(1, T, C, H, W)`` with values in [0, 1] (TODO confirm
    against ``gs_render_batch``); the result is ``(T, H, W, C)``. Values are clipped so
    overshoot does not wrap around in the integer cast.
    """
    video = frames.squeeze(0).permute(0, 2, 3, 1).detach().cpu().numpy() * 255
    return np.clip(video, 0, 255).astype(np.uint8)


def _sample_name(data_path):
    """Derive the output name for a sample file.

    For ``.../<scene>/dense_3d_track.pkl`` the name is ``<scene>``. Otherwise it is the
    part of the file name before the first underscore.
    """
    file_name = os.path.basename(data_path)
    if file_name == "dense_3d_track.pkl":
        return os.path.basename(os.path.dirname(data_path))
    return file_name.split("_")[0]


def main(args):
    """Reconstruct each validation track through encoder prompt -> VAE -> decoder prompt.

    For every sample, the ground-truth and reconstructed point tracks are rendered with
    Gaussian splatting and written as ``<ckpt_dir>/vis/<name>_gt.mp4`` and
    ``<name>_pred.mp4``. Samples come from ``args.validation_samples`` (files or
    directories) or, when that is not given, from the default sample directory.
    """
    encoder_prompt, decoder_prompt, vae = _load_models(args)

    scale = 0.0001  # isotropic Gaussian scale passed to the renderer
    H, W = 384, 512  # render resolution; tracks must contain H*W points per frame
    H_ori, W_ori = 720, 960  # source video resolution (540x960 was used previously)
    intrinsics, extrinsic = _build_camera(H, W, H_ori, W_ori)

    vis_dir = os.path.join(args.ckpt_dir, "vis")
    os.makedirs(vis_dir, exist_ok=True)

    def render(points, color):
        # Fixed scale, identity rotation quaternion and unit opacity for every point.
        return gs_render_batch(
            intrinsics,
            extrinsic,
            [H, W],
            points,
            torch.Tensor([scale, scale, scale]),
            torch.Tensor([0.0, 0.0, 0.0, 1.0]),
            color / 255,
            torch.ones((H * W,)),
        )

    for data_path in _collect_sample_paths(args.validation_samples):
        # NOTE: pickle.load can execute arbitrary code -- only run on trusted sample files.
        with open(data_path, "rb") as f:
            data = pickle.load(f)

        targets, first_frame, diff = _prepare_tracks(data["coords"], args, H, W)
        color = _to_cuda_float(data["colors"])
        if color.dim() == 2:
            color = color.unsqueeze(0)

        with torch.no_grad():
            pseudo_video = encoder_prompt(targets) * 2 - 1  # map [0, 1] output to [-1, 1] for the VAE
            recon_video = vae(pseudo_video).sample
            recon_flow = decoder_prompt(recon_video)
            if args.normalize_track_first_frame:
                print(targets.max(), targets.min())
                print(((targets - recon_flow) ** 2).mean())
                # Undo the first-frame scaling on both prediction and ground truth.
                recon_flow = recon_flow * diff.view(3, 1, 1, 1)
                targets = targets * diff.view(3, 1, 1, 1)

            # Displacement tracks are shifted back to absolute positions before rendering.
            render_pc_gt = targets if args.normalize_track else targets + first_frame
            render_pc_pred = recon_flow if args.normalize_track else recon_flow + first_frame
            rendered_videos_gt = _to_uint8_video(render(render_pc_gt, color))
            rendered_videos = _to_uint8_video(render(render_pc_pred, color))

        name = _sample_name(data_path)
        media.write_video(os.path.join(vis_dir, f"{name}_gt.mp4"), rendered_videos_gt, fps=8)
        media.write_video(os.path.join(vis_dir, f"{name}_pred.mp4"), rendered_videos, fps=8)
        print(name)
            

if __name__ == "__main__":
    # Script entry point: parse the command line and run reconstruction inference.
    main(parse_args())
    