import torch
import numpy as np
import torch.nn.functional as F
from tools import builder
from utils.logger import *
from utils import misc


def run_net(args, config, train_writer=None, val_writer=None):
    """Run text-guided part generation on one point cloud and save the result.

    Builds the model described by ``config.model``, loads a point cloud from
    ``args.pts_path``, resamples it to ``config.model.npoints`` points, calls
    ``base_model.part_generation`` with ``args.query`` and writes both the
    resampled input and the generated output to ``edited_data.npy`` in the
    current working directory.

    Args:
        args: Namespace providing ``log_name``, ``query``, ``pts_path``,
            ``use_gpu`` and ``local_rank``.
        config: Configuration object providing ``model`` (passed to
            ``builder.model_builder``) and ``model.npoints``.
        train_writer: Unused; accepted for signature compatibility.
        val_writer: Unused; accepted for signature compatibility.

    Raises:
        ValueError: If the loaded array is neither 2-D nor 3-D.
    """
    # The logger is not used below, but the call is kept because it may
    # configure logging as a side effect -- TODO confirm and drop if not.
    logger = get_logger(args.log_name)
    text_query = args.query
    pts_path = args.pts_path

    # build model
    base_model = builder.model_builder(config.model)
    if args.use_gpu:
        base_model.to(args.local_rank)
    base_model.zero_grad()
    base_model.eval()

    # Keep the input on the same device as the model weights. Calling
    # ``.cuda()`` unconditionally breaks when ``use_gpu`` is false and can pick
    # a different GPU than ``local_rank`` when the current device differs.
    first_param = next(base_model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if args.use_gpu else 'cpu')

    # Fixed masking settings forwarded verbatim to ``part_generation``.
    # NOTE(review): their exact meaning is defined by the model -- confirm there.
    mask_axis = 1
    mask_direct = True
    mask_pos = 0

    # Cast to float32 once so that both resampling branches (and the model)
    # see the same dtype regardless of how the file was saved.
    ori_points = torch.from_numpy(np.load(pts_path)).float().to(device)
    dims = ori_points.dim()
    if dims == 2:
        # Single cloud without a batch axis: add one.
        ori_points = ori_points.unsqueeze(0)
    elif dims != 3:
        raise ValueError(
            f'Expected a 2-D or 3-D point array in {pts_path}, got {dims} dimension(s)'
        )

    num_points = ori_points.shape[1]
    target_points = config.model.npoints
    if num_points > target_points:
        # Too many points: subsample with the project's FPS helper.
        ori_points = misc.fps(ori_points, target_points)
    elif num_points < target_points:
        # Too few points: linearly interpolate along the point axis.
        # ``F.interpolate`` resizes the last axis, hence the transposes.
        ori_points = F.interpolate(
            ori_points.transpose(1, 2), size=target_points, mode='linear'
        ).transpose(1, 2)

    with torch.no_grad():
        _, edited_points = base_model.part_generation(
            ori_points, mask_axis, mask_direct, mask_pos, text_query
        )

    edited_data = {
        'ori_points': ori_points.cpu().numpy(),
        'edited_points': edited_points.cpu().numpy()
    }
    # A dict is stored as a pickled object array; read it back with
    # ``np.load('edited_data.npy', allow_pickle=True).item()``.
    np.save('edited_data.npy', edited_data)
    print('Successfully save completion data to edited_data.npy')
