import torch
import numpy as np
import sys
sys.path.append("../")
import tensorboardX
import matplotlib.pyplot as plt
import os
import time
import argparse
import logging
from models.SDF_model import ImplicitNetwork
from utils import get_grid_new, get_logger
from skimage import measure
import os
from utils import set_seed_everywhere
import plotly.graph_objs as go
import plotly.offline as offline
import trimesh


def get_threed_scatter_trace(points, caption=None, colorscale=None, color=None):
    """Build a plotly ``Scatter3d`` marker trace from the first three columns of ``points``.

    Args:
        points: 2-D tensor whose columns 0, 1, 2 are used as x, y, z
            (each column is moved to CPU before plotting).
        caption: optional per-point hover text.
        colorscale: optional plotly colorscale for the markers.
        color: marker colour (a single colour or per-point values).

    Returns:
        A ``go.Scatter3d`` trace named 'projection'.
    """
    xs, ys, zs = (points[:, axis].cpu() for axis in range(3))

    marker_style = dict(
        size=3,
        line=dict(width=2),
        opacity=0.9,
        colorscale=colorscale,
        showscale=True,
        color=color,
    )

    return go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='markers',
        name='projection',
        marker=marker_style,
        text=caption,
    )


def get_surface_trace(points, model, resolution):
    """Extract the zero level set of ``model`` on a grid and build a plotly mesh for it.

    Args:
        points: point cloud handed to ``get_grid_new`` to build the evaluation
            grid (presumably used for the grid's extent — TODO confirm in utils).
        model: callable mapping a chunk of grid points to network values.
        resolution: grid resolution forwarded to ``get_grid_new``.

    Returns:
        dict with
          - "mesh_trace": list with one ``go.Mesh3d`` (empty if no surface), and
          - "mesh_export": the ``trimesh.Trimesh`` of the level set, or None when
            the sampled values never cross 0.
    """
    trace = []
    meshexport = None

    grid = get_grid_new(points, resolution)

    # Only the field values are needed, so skip building autograd graphs for
    # the (potentially millions of) grid points; chunk to bound memory.
    with torch.no_grad():
        z = [model(pnts).cpu().numpy()
             for pnts in torch.split(grid['grid_points'], 10000, dim=0)]
    z = np.concatenate(z, axis=0).astype(np.float64)

    # Marching cubes needs level 0 to lie within the sampled value range.
    if np.min(z) <= 0. <= np.max(z):
        # grid_points are laid out so that reshaping to (ny, nx, nz) and swapping
        # the first two axes yields an (x, y, z)-indexed volume.
        volume = z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],
                           grid['xyz'][2].shape[0]).transpose([1, 0, 2])
        verts, faces, normals, values = measure.marching_cubes(
            volume=volume,
            level=0.,
            spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],
                     grid['xyz'][1][2] - grid['xyz'][1][1],
                     grid['xyz'][2][2] - grid['xyz'][2][1]))

        # marching_cubes returns coordinates relative to the volume origin;
        # shift them to the grid's first sample position.
        verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])
        # marching_cubes yields one normal per *vertex*; passing them positionally
        # would bind them to trimesh's `face_normals` parameter instead.
        meshexport = trimesh.Trimesh(vertices=verts, faces=faces,
                                     vertex_normals=normals, vertex_colors=values)

        I, J, K = faces[:, 0], faces[:, 1], faces[:, 2]
        trace.append(go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
                               i=I, j=J, k=K, name='', color='orange', alphahull=5, opacity=0.4))
    return {"mesh_trace": trace, "mesh_export": meshexport}


class ConstraintTrainer:
    """Fit an implicit (SDF) network to a point cloud and plot its zero level set.

    The training loss is the mean absolute network value on the data points
    plus ``alpha_loss`` times an eikonal penalty ``(|grad f| - 1)^2`` evaluated
    on Gaussian-perturbed copies of those points.  Artefacts (log, tensorboard
    events, plots, checkpoints) are written to ``self.expdir``.
    """

    # Dataset name -> .npy file holding the point cloud (loaded as float32).
    _DATA_PATHS = {
        'bunny_whole': "./data/bunny/bunny_whole.npy",
        'spot_whole': "./data/spot/spot_whole.npy",
    }

    def __init__(self, args):
        """Create the experiment directory, logger, dataset and tensorboard writer.

        Args:
            args: parsed CLI namespace; reads ``dataset_name``, ``seed``,
                ``save_prefix`` and ``alpha_extend_sample`` here, and
                ``beta``, ``max_epochs`` and ``if_train`` later.
        """
        self.args = args
        self.dataset_name = self.args.dataset_name

        set_seed_everywhere(args.seed)

        # makedirs(exist_ok=True) also creates missing parents such as
        # ./constraint/exp, on which a plain os.mkdir would fail.
        exps_folder = f'./constraint/exp/{self.dataset_name}'
        os.makedirs(exps_folder, exist_ok=True)
        timestamp = time.strftime('%m%d-%H-%M-%S')
        self.expdir = os.path.join(exps_folder, f"{self.args.save_prefix}-{timestamp}")
        os.makedirs(self.expdir, exist_ok=True)

        # Keep the experiment logger on the instance so every method uses it.
        # A local name here would leave get_data()/train() calling the root
        # `logging` module, which only reaches the experiment log if
        # get_logger happened to configure the root logger.
        self.logger = get_logger(self.expdir)

        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.logger.info(f"Found {os.cpu_count()} total number of CPUs.")
        if self.device == torch.device('cuda'):
            self.logger.info(f"Found {torch.cuda.device_count()} CUDA devices.")
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                self.logger.info(f"{props.name} \t Memory: {props.total_memory / (1024 ** 3):.2f}GB")
        self.logger.info("Using device: {}".format(self.device))

        self.dataset = self.get_data()
        self.tb_logger = tensorboardX.SummaryWriter(log_dir=self.expdir)

        # Std of the Gaussian noise used to build the eikonal sample set.
        self.alpha_extend_sample = self.args.alpha_extend_sample
        self.hidden_layers = [128, 128, 128]

    def get_data(self):
        """Load the point cloud for ``self.dataset_name`` as a float32 CPU tensor.

        Raises:
            NotImplementedError: if the dataset name has no entry in ``_DATA_PATHS``.
        """
        npy_path = self._DATA_PATHS.get(self.dataset_name)
        if npy_path is None:
            raise NotImplementedError(f"Unknown dataset: {self.dataset_name}")
        point_set = torch.tensor(np.load(npy_path), dtype=torch.float32)
        self.logger.info(f"The size of the datasets {self.dataset_name}: {point_set.shape[0]}.")
        return point_set

    def model_grad(self, samples):
        """Return d(sum of model outputs)/d(samples), keeping the graph for a second backward.

        Note: switches ``requires_grad`` on for ``samples`` in place.  Summing
        the outputs gives per-row gradients as long as each output row depends
        only on its own input row (presumably true for a point-wise MLP —
        TODO confirm for ImplicitNetwork).
        """
        samples.requires_grad_(True)
        return torch.autograd.grad(
            outputs=self.model(samples).sum(),
            inputs=samples,
            create_graph=True,
            retain_graph=True)[0]

    def plot_surface(self, points, eik_points, filename, resolution=100):
        """Save ``<filename>.png`` and ``<filename>.html`` showing data points,
        eikonal points (hover text shows model value and gradient norm) and
        the extracted zero-level-set mesh."""
        pnts_val = self.model(points)
        pnts_grad = self.model_grad(points)
        points = points.detach()
        caption = ["model: {0}, grad: {1}".format(val.item(), g.item()) for val, g in
                   zip(pnts_val.squeeze(), pnts_grad.norm(2, dim=1))]
        trace_pnts = get_threed_scatter_trace(points, caption=caption, color='green')

        eik_pnts_val = self.model(eik_points)
        eik_pnts_grad = self.model_grad(eik_points)
        eik_points = eik_points.detach()
        caption = ["model: {0}, grad: {1}".format(val.item(), g.item()) for val, g in
                   zip(eik_pnts_val[:, 0:1].squeeze(), eik_pnts_grad.norm(2, dim=1))]
        trace_eik_pnts = get_threed_scatter_trace(eik_points, caption=caption, color='red')

        surface = get_surface_trace(points, self.model, resolution)
        trace_surface = surface["mesh_trace"]

        fig1 = go.Figure(data=[trace_pnts, trace_eik_pnts] + trace_surface)

        fig1.write_image(f"{filename}.png")
        offline.plot(fig1, filename=f"{filename}.html", auto_open=False)

    def plot_sample(self, samples, savefig=None):
        """Scatter-plot ``samples``; when ``savefig`` is given, write
        ``<expdir>/Samples_<savefig>.html`` and ``.png``."""
        samples = samples.detach().cpu().numpy()
        trace = [go.Scatter3d(x=samples[:, 0], y=samples[:, 1], z=samples[:, 2], mode='markers', marker=dict(size=3))]

        layout = go.Layout(title='3D Scatter plot')
        fig = go.Figure(data=trace, layout=layout)
        if savefig is not None:
            filename = self.expdir + f"/Samples_{savefig}"
            offline.plot(fig, filename=f'{filename}.html', auto_open=False)
            fig.write_image(f"{filename}.png")

    def get_sample_extend(self, sample):
        """Return ``sample`` perturbed by Gaussian noise with std ``alpha_extend_sample``."""
        X_normal = sample + torch.randn_like(sample, device=self.device) * self.alpha_extend_sample
        return X_normal

    def get_loss(self, sample, sample_extend):
        """Return ``(total_loss, eikonal_loss)``.

        total = mean |f(sample)| + alpha_loss * mean (||grad f(sample_extend)|| - 1)^2
        """
        grad = self.model_grad(sample_extend)
        grad_loss = ((grad.norm(dim=1) - 1) ** 2).mean()
        loss_fn = (torch.abs(self.model(sample))).mean() + self.alpha_loss * grad_loss
        return loss_fn, grad_loss

    @staticmethod
    def _interval(total, divisions):
        """Return ``int(total / divisions)`` clamped to at least 1, so it is safe as a modulus."""
        return max(1, int(total / divisions))

    def train(self):
        """Optimise ``self.model`` with Adam for ``args.max_epochs + 1`` steps.

        Logs scalars to tensorboard every 50 steps, prints progress about 50
        times, and plots the surface and loss curve about 5 times per run.
        """
        self.batch_size = 512
        self.alpha_loss = 0.1
        self.learning_rate = 1e-4
        self.plot_every = 2000

        max_epochs = self.args.max_epochs
        # Clamped so that max_epochs < 50 (or < 5) no longer hits a modulo by zero.
        log_every = self._interval(max_epochs, 50)
        surface_every = self._interval(max_epochs, 5)

        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.loss_list = []
        for t in range(max_epochs + 1):
            random_idx = torch.randperm(self.dataset.shape[0])[:self.batch_size]
            S = torch.index_select(self.dataset, 0, random_idx).to(self.device)
            X = self.get_sample_extend(S)

            loss, grad_loss = self.get_loss(S, X)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if t % 50 == 0:
                self.tb_logger.add_scalar('loss', loss.item(), global_step=t)
                self.tb_logger.add_scalar('grad_loss', grad_loss.item(), global_step=t)
                self.loss_list.append(loss.item())
            if t % log_every == 0:
                self.logger.info(f'step: {t}/{max_epochs}, loss: {loss.item():.6f}, grad_loss: {grad_loss.item():.6f}')

            if t % surface_every == 0:
                self.plot_surface(points=S.detach(),
                                  eik_points=X.detach(),
                                  filename=f'{self.expdir}/iteration_{t}',
                                  resolution=100)

                fig = plt.figure()
                plt.plot(np.log(np.array(self.loss_list)), c='b')
                plt.title('log of mse loss')
                plt.savefig(os.path.join(self.expdir, "loss_vs_t_log.png"), dpi=300, bbox_inches='tight')
                plt.close(fig)

        self.logger.info("Finish training.")

    def _load_model(self, path):
        """Load a whole pickled model written by ``torch.save(self.model, ...)``.

        The checkpoint is produced locally by ``run``; unpickling is only safe
        for trusted files.
        """
        try:
            # PyTorch >= 2.6 defaults to weights_only=True, which rejects pickled modules.
            return torch.load(path, map_location=self.device, weights_only=False)
        except TypeError:
            # PyTorch < 1.13 does not know the weights_only argument.
            return torch.load(path, map_location=self.device)

    def run(self):
        """Plot the data, train (or load) the model, then plot the final surface."""
        self.plot_sample(self.dataset, savefig='data')

        model_dir = './constraint/model'
        if self.args.if_train:
            self.model = ImplicitNetwork(hidden_layers=self.hidden_layers,
                                         beta=self.args.beta)
            self.model.to(self.device)

            self.train()
            torch.save(self.model, os.path.join(self.expdir, f"{self.dataset_name}_sdf.pt"))
            # The shared model directory may not exist yet on a fresh checkout.
            os.makedirs(model_dir, exist_ok=True)
            torch.save(self.model, os.path.join(model_dir, f"{self.dataset_name}_sdf.pt"))
        else:
            self.model = self._load_model(os.path.join(model_dir, f"{self.dataset_name}_sdf.pt"))

        # Visualise on (at most) 2000 random data points plus their perturbed copies.
        S = self.dataset[torch.randperm(self.dataset.shape[0])[:2000]].to(self.device)
        X = self.get_sample_extend(S)
        self.plot_surface(points=S.detach(),
                          eik_points=X.detach(),
                          filename=f'{self.expdir}/Surface_final',
                          resolution=100)


def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` maps every non-empty string (including
    ``"False"``) to True, so boolean flags need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Required: without it the trainer would create ./constraint/exp/None and then fail in get_data.
    parser.add_argument('--dataset_name', type=str, required=True)
    parser.add_argument('--save_prefix', type=str, default="-")
    parser.add_argument('--alpha_extend_sample', type=float, default=0.05)
    parser.add_argument('--beta', type=float, default=10.)
    parser.add_argument('--max_epochs', type=int, default=200000)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--if_train', type=str2bool, default=True)
    args = parser.parse_args()

    trainer = ConstraintTrainer(args)
    trainer.run()



