import argparse
import math
import os
import cv2

import numpy as np
import imageio.v3 as iio

import torch

import trimesh

import models_class_cond, models_ae
from models.models_lp import KLAutoEncoder
from datasets.shapenet import category_ids
import util.misc as misc

from pathlib import Path
from loguru import logger
from omegaconf import OmegaConf
from einops import repeat, rearrange

from util.geom_utils import build_grid2D, get_W2C_uniform, fusion
from gs import GaussianModel, gs_render

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


if __name__ == "__main__":

    parser = argparse.ArgumentParser('', add_help=False)
    # parser.add_argument('--ae', type=str, required=True) # 'kl_d512_m512_l16'
    parser.add_argument("--config_ae", required=True, help='config file path fo ae')
    parser.add_argument('--ae_pth', type=str, required=True) # 'output/ae/kl_d512_m512_l16/checkpoint-199.pth'
    parser.add_argument('--dm', type=str, required=True) # 'kl_d512_m512_l16_edm'
    parser.add_argument('--dm_pth', type=str, required=True) # 'output/uncond_dm/kl_d512_m512_l16_edm/checkpoint-999.pth'
    args = parser.parse_args()
    print(args)

    dump_dir = os.path.dirname(args.dm_pth).split('ckpt')[0] + "/results/"
    print(dump_dir)
    if not os.path.exists(dump_dir):
        os.makedirs(dump_dir)

    logger.add(f"{dump_dir}/log_sample.txt", level="DEBUG")
    git_env, run_command = misc.get_run_env()
    logger.info(git_env)
    logger.info(run_command)

    config_ae = OmegaConf.load(args.config_ae)
    OmegaConf.resolve(config_ae)

    device = torch.device('cuda:0')

    # ae = models_ae.__dict__[args.ae]()
    ae = KLAutoEncoder(config_ae)
    ae.eval()
    ae.load_state_dict(torch.load(args.ae_pth)['model'])
    ae.to(device)

    model = models_class_cond.__dict__[args.dm]()
    model.eval()

    model.load_state_dict(torch.load(args.dm_pth)['model'])
    model.to(device)

    N = config_ae.model.num_lp * 4
    num_samples = config_ae.loss.num_samples
    queries_grid = build_grid2D(vmin=0., vmax=1., res=int(np.sqrt(num_samples)), device=device).reshape(-1, 2)
    queries_grid = repeat(queries_grid, 's d -> b n s d', b=1, n=N)  # [1, N, S, 2]

    gaussian_model = GaussianModel(config_ae.model.gs)

    total = 2500  # 5000
    iters = 100

    # render settings
    all_eval_pose = torch.load('./all_eval_pose.pt')  # [22, 25]
    # intrinsic
    assert all_eval_pose[0][16:].reshape(3, 3)[0, 0] == 525 / 512
    fov_rad = 2 * np.arctan(512 / (2 * 525))
    # extrinsic
    cam2worlds = all_eval_pose[:, :16].reshape(1, -1, 4, 4).cuda()  # [1, V, 4, 4]
    world2cams = torch.linalg.inv(cam2worlds)

    W2C_uniform = get_W2C_uniform(n_views=100, radius=1.2, device=device)  # [100, 4, 4]

    assert len(config_ae.dataset.categories) == 1
    with torch.no_grad():
        for cid in config_ae.dataset.categories:
            category_id = category_ids[cid]
            print(category_id)
            for i in range(total//iters):
                sampled_array = model.sample(cond=torch.Tensor([category_id]*iters).long().to(device), batch_seeds=torch.arange(i*iters, (i+1)*iters).to(device)).float()

                print(sampled_array.shape, sampled_array.max(), sampled_array.min(), sampled_array.mean(), sampled_array.std())

                for j in range(sampled_array.shape[0]):
                    
                    outputs_dict = ae.decode(sampled_array[j:j+1], queries_grid)
                    gaussians_render = gaussian_model(outputs_dict['gs'])  # [B, N, S, 14], the activated values

                    # render
                    bg_color = torch.ones(3, dtype=torch.float32, device=device)

                    render_dict = gs_render(
                        gaussians=rearrange(gaussians_render,  'b n s d -> b (n s) d'),
                        R=world2cams[:, :, :3, :3],
                        T=world2cams[:, :, :3, 3],
                        fov_rad=fov_rad,
                        output_size=512,
                        bg_color=bg_color,
                    )

                    if not os.path.exists(f'{dump_dir}/gaussians/{category_id:02d}'):
                        os.makedirs(f'{dump_dir}/gaussians/{category_id:02d}')
                    np.save(f'{dump_dir}/gaussians/{category_id:02d}/{i*iters+j:05d}.npy', gaussians_render.detach().cpu().numpy())

                    if not os.path.exists(f"{dump_dir}/images/{category_id:02d}"):
                        os.makedirs(f"{dump_dir}/images/{category_id:02d}")
                    assert render_dict['images'].shape[0] == 1
                    for v, img in enumerate(render_dict['images'][0]):
                        img = img.permute(1, 2, 0).detach().cpu().numpy()  # [H, W, 3]
                        img = cv2.resize(img,  (128, 128), interpolation=cv2.INTER_LANCZOS4)
                        img = np.clip(img, 0, 1)
                        img_uint8 = (img * 255).astype(np.uint8)
                        iio.imwrite(f"{dump_dir}/images/{category_id:02d}/{i*iters+j:05d}_{v}.png", img_uint8)


