import torch
import os
import pickle
from tqdm import tqdm
import argparse
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from data.util import read_obj
from model.meshunwarpAE import MeshUnwarpAutoencoder
from data.dataset import ObjaverseUnwarp

def _rows(values):
    """Return *values* as (nested) Python lists.

    Tensors are converted with a single ``tolist()`` call: one bulk
    device-to-host copy instead of one ``.item()`` sync per scalar. Any other
    sequence is returned unchanged.
    """
    return values.tolist() if isinstance(values, torch.Tensor) else values


def write_obj(path, vertices, normals, uvs, faces):
    """Write a triangle mesh with per-vertex UVs to a Wavefront OBJ file.

    Args:
        path: Destination file path; an existing file is overwritten.
        vertices: Tensor or nested sequence; the first three entries of each
            row are written as a ``v`` line.
        normals: Tensor or nested sequence; the first three entries of each
            row are written as a ``vn`` line.
        uvs: Tensor or nested sequence; the first two entries of each row are
            written as a ``vt`` line.
        faces: Tensor or nested sequence of 0-based vertex indices; the first
            three entries of each row form an ``f`` line.

    Each face corner with 0-based index ``k`` is emitted as ``k+1/k+1``
    (OBJ indices are 1-based). That means the UV at position ``k`` is taken to
    belong to vertex ``k``. The ``vn`` lines are written but not referenced
    by the faces.
    """
    lines = [f"v {v[0]} {v[1]} {v[2]}\n" for v in _rows(vertices)]
    lines += [f"vn {n[0]} {n[1]} {n[2]}\n" for n in _rows(normals)]
    lines += [f"vt {t[0]} {t[1]}\n" for t in _rows(uvs)]
    for face in _rows(faces):
        a, b, c = face[0] + 1, face[1] + 1, face[2] + 1
        lines.append(f"f {a}/{a} {b}/{b} {c}/{c}\n")
    with open(path, "w") as fh:
        fh.writelines(lines)


def parse_args():
    """Parse the command-line options of this evaluation script."""
    parser = argparse.ArgumentParser()
    # These three paths are always used below, so fail early with a clear
    # argparse error instead of a TypeError on ``None``.
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--save_path', type=str, required=True)
    parser.add_argument('--test_path', type=str, required=True)
    parser.add_argument('--save_obj', action='store_true', help='If set, the obj results will be saved.')
    parser.add_argument('--num_workers', type=int, default=16, help='Number of DataLoader worker processes.')
    return parser.parse_args()


def main():
    """Evaluate a MeshUnwarpAutoencoder checkpoint and report the mean recon loss.

    With ``--save_obj``, one ``<uid>_pred.obj`` per sample is also written to
    ``<save_path>/pred`` using the predicted UVs.
    """
    args = parse_args()
    os.makedirs(args.save_path, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    test_dataset = ObjaverseUnwarp(args.test_path, return_pivot=False, augment=False)
    test_dataloader = DataLoader(test_dataset, batch_size=1, num_workers=args.num_workers,
                                 shuffle=False, drop_last=True)

    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    checkpoint = torch.load(args.model_path, map_location=device)
    model = MeshUnwarpAutoencoder()
    model.load_state_dict(checkpoint['model'])
    # The inputs are moved to `device` below, so the weights must live there too.
    model.to(device)
    print("load model success")

    pred_dir = os.path.join(args.save_path, "pred")
    if args.save_obj:
        os.makedirs(pred_dir, exist_ok=True)

    input_keys = ('vertices', 'faces', 'uvs', 'normals', 'curve', 'degree', 'blender_uvs')
    total_val_recon_loss = 0.
    num_batches = 0

    model.eval()
    with torch.no_grad():
        progress = tqdm(test_dataloader, desc="Testing Progress",
                        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{rate_fmt}]")
        for data in progress:
            forward_kwargs = {key: data[key].to(device) for key in input_keys}
            _, test_recon_loss, pred_uv, _, _, _ = model(**forward_kwargs)

            # Accumulate on every batch, not only when saving OBJs.
            # float() assumes the loss is a scalar (tensor or number).
            total_val_recon_loss += float(test_recon_loss)
            num_batches += 1

            if not args.save_obj:
                continue
            for i in range(pred_uv.shape[0]):
                # NOTE(review): assumes each dataset item carries a 'data_path'
                # string whose parent directory name is the object uid.
                uid = data['data_path'][i].split("/")[-2]
                write_obj(
                    os.path.join(pred_dir, uid + '_pred.obj'),
                    forward_kwargs['vertices'][i],
                    forward_kwargs['normals'][i],
                    pred_uv[i],
                    forward_kwargs['faces'][i],
                )

    # Guard against an empty test set instead of dividing by zero.
    total_val_recon_loss = total_val_recon_loss / num_batches if num_batches else float('nan')

    if args.save_obj:
        print(f"obj_result save to:{args.save_path}")

    print(f'valid recon loss: {total_val_recon_loss}')


if __name__ == '__main__':
    main()
