import os
import sys
import numpy as np
import torch
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append("{0}".format(os.path.dirname(BASE_DIR)))
from GANet import Model


def test():
    """Run the completion model on every demo sample and save the results.

    For each ``.txt`` file directly inside ``demo_data/partial_input`` this
    reads the partial cloud and the same-named file in ``demo_data/gt``.
    It then runs ``net`` on the input and writes ``fine[0]`` to
    ``output/<sample>.txt`` (comma-delimited). Finally it renders the fine
    output, the ground truth and the input to ``vis/fine``, ``vis/gt`` and
    ``vis/input`` as ``<sample>.png``. All paths are relative to ``BASE_DIR``.

    Raises:
        FileNotFoundError: if the checkpoint, the input directory or a
            matching ground-truth file is missing.
    """
    save_completion_path = "{0}/vis/fine".format(BASE_DIR)
    save_input_path = "{0}/vis/input".format(BASE_DIR)
    save_gt_path = "{0}/vis/gt".format(BASE_DIR)
    # np.savetxt does not create parent directories, so the text-output
    # directory has to be created alongside the visualisation directories.
    save_output_path = "{0}/output".format(BASE_DIR)
    for path in (save_completion_path, save_input_path, save_gt_path, save_output_path):
        os.makedirs(path, exist_ok=True)

    model_file = "{0}/model/best_cd_t_network.pth".format(BASE_DIR)
    net = Model()
    net.eval()
    # The network is never moved to a GPU in this script, so load the
    # checkpoint onto the CPU; a GPU-saved checkpoint would otherwise fail
    # to deserialize on a CPU-only machine.
    checkpoint = torch.load(model_file, map_location="cpu")
    net.load_state_dict(checkpoint['net_state_dict'])

    # NOTE(review): plot_single_pcd is neither defined nor imported in this
    # file; it must be provided elsewhere (e.g. a project utils module) --
    # confirm, otherwise the calls below raise NameError.
    def run(inputs, gt, sample):
        """Predict one sample, save the fine output as text and render all three clouds."""
        with torch.no_grad():
            inputs = inputs.float()
            # Only the final (fine) prediction is used; coarse outputs are discarded.
            _coarse_raw, _coarse, _coarse_high, fine = net(inputs)
            fine_points = fine[0].cpu().numpy()
            np.savetxt("{0}/{1}.txt".format(save_output_path, sample), fine_points, delimiter=",")
            pic = "{0}.png".format(sample)
            plot_single_pcd(fine_points, os.path.join(save_completion_path, pic))
            # read_data yields (1, C, N); transpose back to (N, C) and pass
            # numpy arrays consistently for all three plots.
            plot_single_pcd(gt.transpose(2, 1)[0].cpu().numpy(), os.path.join(save_gt_path, pic))
            plot_single_pcd(inputs.transpose(2, 1)[0].cpu().numpy(), os.path.join(save_input_path, pic))

    input_dir = "{0}/demo_data/partial_input".format(BASE_DIR)
    gt_dir = "{0}/demo_data/gt".format(BASE_DIR)
    # Only the top level is scanned: each sample is looked up by bare file
    # name in both input_dir and gt_dir, so nested files could not be
    # resolved anyway. Sorting makes the processing order deterministic.
    for file in sorted(os.listdir(input_dir)):
        sample, ext = os.path.splitext(file)
        sample_file = os.path.join(input_dir, file)
        if ext != ".txt" or not os.path.isfile(sample_file):
            continue  # skip subdirectories and stray non-point files
        gt_file = os.path.join(gt_dir, file)
        inputs = read_data(sample_file)
        gt = read_data(gt_file)
        run(inputs, gt, sample)

def read_data(file):
    """Load a comma-delimited point file as a batched, channel-first tensor.

    Each row of *file* is one point and each column one coordinate/feature.

    Args:
        file: path to a text file parseable by ``np.loadtxt`` with ``,``
            as the delimiter.

    Returns:
        A contiguous ``torch.float64`` tensor of shape ``(1, C, N)``, where
        N is the number of rows and C the number of columns in *file*.
    """
    # ndmin=2 keeps a single-row (or single-column) file two-dimensional;
    # without it loadtxt squeezes the result to 1-D and the transpose fails.
    points = np.loadtxt(file, delimiter=",", ndmin=2)
    return torch.from_numpy(points).unsqueeze(0).transpose(2, 1).contiguous()
    

if __name__ == "__main__":
    # Run the demo only when this file is executed directly, not when imported.
    test()




