from typing import *
import os
import time
import numpy as np
import random
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import glob
from PIL import Image
import argparse

import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
from torch.utils.data import Dataset, DataLoader

import mcubes
import trimesh
import cubvh

from tropical.stanford.dataset import StanfordDataset
from tropical.stanford.model import Net
import tropical.subpoly as sp
from tropical.utils.chamfer_distance import sample_surface_from_rays, chamfer_distance


# Command-line interface: which Stanford model, which seed, which network size.
parser = argparse.ArgumentParser(
    prog="python -m tropical.stanford.train",
    description="Polyhedral complex derivation from piecewise trilinear networks",
)
parser.add_argument(
    "-d", "--dataset",
    default="dragon",
    choices=["bunny", "dragon", "happy", "armadillo", "drill", "lucy"],
    help="Stanford 3D scanning model name",
)
parser.add_argument("-s", "--seed", type=int, default=45, help="Seed")
parser.add_argument(
    "-m", "--model_size",
    default="small",
    choices=["small", "medium", "large"],
    help="Model size",
)

args = parser.parse_args()
print(args)
seed = args.seed

# Seed the Python, NumPy and PyTorch RNGs, and force deterministic
# (non-autotuned) cuDNN kernels so that runs are reproducible.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Constants
CANVAS_SIZE = 1.2  # marching cubes samples the cube [-CANVAS_SIZE, CANVAS_SIZE]^3
training_data_R = 0.8  # extracted vertex coordinates are divided by this value
our_t = -1  # no timing is measured for "Ours" here; this value is printed as-is

# (r_min, r_max) passed to Net for each model size. argparse `choices`
# guarantees args.model_size is one of these keys.
r_min, r_max = {
    "small": (2, 32),
    "medium": (4, 64),
    "large": (8, 128),
}[args.model_size]


# Model and mesh paths. The checkpoint is resolved relative to this script's
# directory; the mesh path is relative to the current working directory.
model_path = os.path.join(
    os.path.dirname(__file__),
    f"models/{args.dataset}/{args.dataset}_sdf_{args.model_size}_{seed}.pth")
mesh_path = os.path.join(
    f"meshes/{args.dataset}", f'our_mesh_{args.model_size}_{seed}.ply')

# Check if files exist. A missing input is an error: report it on stderr and
# exit with a non-zero status (the interactive `exit()` builtin exits with 0,
# which made failures look like success to shells and job runners).
if not os.path.isfile(model_path):
    print(f"Model path is not found: {model_path}", file=sys.stderr)
    sys.exit(1)
if not os.path.isfile(mesh_path):
    print(f"Mesh path is not found: {mesh_path}", file=sys.stderr)
    sys.exit(1)

# Load the model and mesh. The architecture hyper-parameters must match those
# used when the checkpoint was saved, otherwise load_state_dict fails.
net = Net(num_layers=3, num_hidden=16, levels=4, r_min=r_min, r_max=r_max).cuda()
net.load_state_dict(torch.load(model_path, map_location=net.device()))
print(f"The pretrained model is loaded from {model_path}")
our_mesh = trimesh.load(mesh_path)
print(f"The mesh is loaded from {mesh_path}")
print(f"Ours: {our_mesh.vertices.shape}/{our_mesh.faces.shape}")


@torch.no_grad()
def run_marching_cubes(MC_SAMPLE_SIZE=100):
    """Extract a triangle mesh from the network's SDF with marching cubes.

    The SDF is evaluated on a regular ``MC_SAMPLE_SIZE``^3 grid spanning
    ``[-CANVAS_SIZE, CANVAS_SIZE]`` along every axis. The level-0 surface of
    the negated field is then extracted with PyMCubes.

    Args:
        MC_SAMPLE_SIZE: number of grid samples per axis.

    Returns:
        A ``trimesh.Trimesh`` (built with ``process=False``) whose vertices are
        float32 world coordinates divided by ``training_data_R`` and whose
        faces are int32 indices.
    """
    axis = torch.linspace(-CANVAS_SIZE, CANVAS_SIZE, MC_SAMPLE_SIZE)
    gx, gy, gz = torch.meshgrid(axis, axis, axis, indexing='ij')
    grid_shape = gx.shape

    # All grid points go through net.sdf in a single batch on the GPU; only the
    # first output column is used.
    # NOTE(review): at MC_SAMPLE_SIZE=512 this is ~134M points in one call.
    values = net.sdf(torch.stack([gx, gy, gz], dim=-1).reshape(-1, 3).cuda())[:, 0]
    field = values.reshape(*grid_shape).cpu().numpy()

    verts, faces = mcubes.marching_cubes(-field, 0)

    # Vertices come back in grid-index units; map index 0 -> -CANVAS_SIZE and
    # index N-1 -> +CANVAS_SIZE, then divide by training_data_R.
    verts = verts / (MC_SAMPLE_SIZE - 1.0) * 2 * CANVAS_SIZE - CANVAS_SIZE
    verts /= training_data_R

    return trimesh.Trimesh(
        verts.astype(np.float32), faces.astype(np.int32), process=False)


def get_rays(CD_SAMPLES_N=100000, *, uniform=False):
    """Sample random unit ray directions, all starting at the origin.

    Args:
        CD_SAMPLES_N: number of rays to generate.
        uniform: when False (default), reproduce the legacy sampling, which
            draws both spherical angles uniformly in [0, 2*pi). Drawing the
            polar angle uniformly is NOT uniform over the sphere's surface:
            the area element is proportional to sin(phi), so directions
            cluster near the +/-z poles. When True, draw z uniformly in
            [-1, 1] and the azimuth uniformly in [0, 2*pi), which gives an
            area-uniform distribution of directions.

    Returns:
        ``(rays_o, rays_d)``, both float tensors of shape ``(CD_SAMPLES_N, 3)``.
        ``rays_o`` is all zeros and every row of ``rays_d`` has unit length.
        Randomness comes from the global torch RNG.
    """
    theta = torch.rand(CD_SAMPLES_N) * 2 * torch.pi  # azimuth
    if uniform:
        z = 1 - 2 * torch.rand(CD_SAMPLES_N)
        # Clamp guards against tiny negative values from rounding when |z| ~ 1.
        radius = torch.sqrt(torch.clamp(1 - z * z, min=0))
        x = radius * torch.cos(theta)
        y = radius * torch.sin(theta)
    else:
        # Legacy sampling: the same RNG draws and arithmetic as before, so
        # previously reported numbers stay reproducible for a given seed.
        phi = torch.rand(CD_SAMPLES_N) * 2 * torch.pi  # polar angle
        x = torch.cos(theta) * torch.sin(phi)
        y = torch.sin(theta) * torch.sin(phi)
        z = torch.cos(phi)
    rays_d = torch.stack([x, y, z], 1)
    rays_o = torch.zeros_like(rays_d)
    return rays_o, rays_d


def angular_distance(x, y):
    """Return the mean and standard deviation, in degrees, of row-wise angles.

    The angle for each row is ``arccos(sum(x * y, axis=-1))``, with the dot
    product clipped to [-1, 1] so rounding cannot push it outside arccos's
    domain. The formula only yields true angles if the rows of ``x`` and ``y``
    are unit vectors; this is assumed, not checked.
    """
    dots = np.clip(np.sum(x * y, axis=-1), -1, 1)
    angles_deg = np.degrees(np.arccos(dots))
    return np.mean(angles_deg), np.std(angles_deg)


# The densest marching-cubes grid is processed first: its ray samples become
# the reference ("gt") that our mesh and every other resolution are scored
# against, so it must stay at the front of the resolution list.
GT_RESOLUTION = 512
MC_RESOLUTIONS = [GT_RESOLUTION, 16, 24, 32, 40, 48, 56, 64, 128, 192, 224, 225]

rays_o, rays_d = get_rays()

# Additionally, we could use the samples from inward rays.
# rays_o = torch.cat([rays_o, rays_d * 2], dim=0)
# rays_d = torch.cat([rays_d, -rays_d], dim=0)

our_samples, our_normals, our_mask = sample_surface_from_rays(rays_o, rays_d, our_mesh,
                                                              return_normal=True)

# Compute chamfer distance and angular distance
print("Marching Cubes Results:")
print("#samples, #vertices, CD, AD, time")
for i in MC_RESOLUTIONS:
    # Only mesh extraction is timed; ray sampling and metrics are excluded.
    t = time.time()
    mc_mesh = run_marching_cubes(i)
    t = time.time() - t
    try:
        mc_samples, mc_normals, mc_mask = \
            sample_surface_from_rays(rays_o, rays_d, mc_mesh, return_normal=True)
    except Exception as err:
        if i == GT_RESOLUTION:
            # Every later row is measured against these samples; carrying on
            # would only fail later with a NameError on gt_samples.
            raise RuntimeError(
                f"Sampling the reference marching-cubes mesh ({i}^3 grid) failed"
            ) from err
        # Best-effort: report a zero row for this resolution and move on.
        # (Catching Exception, not everything, lets Ctrl-C still abort the run.)
        print(f"{i:4d}, {0:5d}, {0:0.6f}, {0:4.1f}, {t:.2f}")
        continue
    if i == GT_RESOLUTION:
        gt_samples = mc_samples
        gt_normals = mc_normals
        gt_mask = mc_mask

        our_cd = chamfer_distance(our_samples, gt_samples, direction='bi')
        # Normals are compared only where both masks are set (presumably the
        # rays that hit both surfaces; see sample_surface_from_rays).
        common_mask = our_mask & gt_mask
        our_ad, our_ad_std = angular_distance(
            our_normals[common_mask], gt_normals[common_mask])
        print(f"Ours, {our_mesh.vertices.shape[0]:5d}, {our_cd:0.6f}, "
              f"{our_ad:4.1f}, {our_t:.2f}")

    mc_cd = chamfer_distance(mc_samples, gt_samples, direction='bi')
    common_mask = mc_mask & gt_mask
    mc_ad, mc_ad_std = angular_distance(mc_normals[common_mask], gt_normals[common_mask])
    print(f"{i:4d}, {mc_mesh.vertices.shape[0]:5d}, {mc_cd:0.6f}, {mc_ad:4.1f}, {t:.2f}")
    mc_mesh.export(os.path.join(f"meshes/{args.dataset}",
                                f'mc{i:03d}_mesh_{args.model_size}_{seed}.ply'))
