import time
import torch
import torch.nn.functional as F
from tqdm import tqdm
from np.np_dataset import Dataset
from np.np_utils import NeuralPullNetwork
import argparse
from pyhocon import ConfigFactory
import os
from shutil import copyfile
import numpy as np
import trimesh
from np_utils.extensions.chamfer_dist import ChamferDistanceL1, ChamferDistanceL2
import math
from pytorch3d.ops import knn_points
import mcubes
import warnings


class Runner:
    def __init__(self, path, pointcloud=None, iteration=9000):
        """Set up the device, an optional cached query dataset, the loss and the SDF network.

        Args:
            path: Experiment directory. A ``query_data`` sub-directory is created in
                it when ``pointcloud`` is given.
            pointcloud: When not ``None``, the dataset cached at
                ``<path>/query_data/dataset_iter{iteration}.pt`` is loaded and kept as
                ``self.dataset``. The value itself is not read here; it would only be
                needed to build a fresh dataset, which is not implemented.
            iteration: Number embedded in the cached dataset's file name.

        Raises:
            NotImplementedError: If ``pointcloud`` is given but no cached dataset file
                exists (building and caching a new ``Dataset`` is not supported yet).
        """
        self.device = torch.device('cuda')
        if pointcloud is not None:
            query_dir = os.path.join(path, 'query_data')
            os.makedirs(query_dir, exist_ok=True)
            dataset_path = os.path.join(query_dir, 'dataset_iter{}.pt'.format(iteration))
            if not os.path.exists(dataset_path):
                # TODO: build ``Dataset(pointcloud)`` and cache it at ``dataset_path`` with
                # ``torch.save``. The previous draft of this branch was unreachable (it
                # followed the ``raise``) and stored the result under ``'dataset' + str(i)``
                # with ``i`` undefined.
                raise NotImplementedError('This branch not debugged.')
            # Keep the loaded dataset; previously it was read and then discarded.
            # NOTE(review): torch.load unpickles arbitrary objects -- only load dataset
            # files from a trusted location. Newer torch versions default to
            # weights_only=True, which may refuse a pickled Dataset object; confirm.
            self.dataset = torch.load(dataset_path)
        self.ChamferDisL1 = ChamferDistanceL1().cuda()

        # Networks
        self.sdf_network = NeuralPullNetwork().to(self.device)



    def get_learning_rate_at_iteration(self, iter_step, max_iter=60050, *, warm_up=1000, init_lr=0.001):
        """Return the learning rate for ``iter_step`` under linear warm-up + cosine decay.

        For ``iter_step < warm_up`` the rate ramps linearly from 0 to ``init_lr``.
        After that it follows half a cosine from ``init_lr`` down to 0 at ``max_iter``,
        and it stays at 0 for any later step.

        Args:
            iter_step: Current training iteration.
            max_iter: Iteration at which the cosine decay reaches 0.
            warm_up: Length of the linear warm-up phase, in iterations (keyword-only).
            init_lr: Peak learning rate reached at the end of warm-up (keyword-only).

        Returns:
            The learning rate as a float.
        """
        if iter_step < warm_up:
            return iter_step / warm_up * init_lr
        decay_span = max_iter - warm_up
        if decay_span <= 0:
            # Degenerate schedule with no decay phase: the schedule is already over
            # (also avoids dividing by zero when max_iter == warm_up).
            return 0.0
        # Clamp so the rate does not climb back up once iter_step passes max_iter
        # (cos is increasing on [pi, 2*pi]).
        progress = min((iter_step - warm_up) / decay_span, 1.0)
        return 0.5 * (math.cos(progress * math.pi) + 1) * init_lr

    def load_checkpoint(self, path, checkpoint_name):
        """Restore ``self.sdf_network`` weights from ``<path>/checkpoints/<checkpoint_name>``.

        The checkpoint is mapped onto ``self.device``. Its full path is printed once the
        file has been read, and the ``'sdf_network'`` entry is loaded into the network.

        Args:
            path: Experiment directory containing a ``checkpoints`` sub-directory.
            checkpoint_name: File name of the checkpoint inside that directory.
        """
        ckpt_file = os.path.join(path, 'checkpoints', checkpoint_name)
        state = torch.load(ckpt_file, map_location=self.device)
        print(ckpt_file)
        self.sdf_network.load_state_dict(state['sdf_network'])

    def save_checkpoint(self, path, iter_step):
        """Write ``self.sdf_network``'s state dict into ``<path>/checkpoints``.

        The file is named ``ckpt_<iter_step>.pth``, with the step zero-padded to at
        least six digits. It holds a dict with the single key ``'sdf_network'``. The
        ``checkpoints`` directory is created if it does not exist.

        Args:
            path: Experiment directory.
            iter_step: Training iteration used in the file name.
        """
        payload = dict(sdf_network=self.sdf_network.state_dict())
        ckpt_dir = os.path.join(path, 'checkpoints')
        os.makedirs(ckpt_dir, exist_ok=True)
        file_name = 'ckpt_{:0>6d}.pth'.format(iter_step)
        torch.save(payload, os.path.join(ckpt_dir, file_name))
