from lib2to3.pgen2.tokenize import TokenError
import sys, os

import argparse
from datetime import datetime
import torch
from tensorboardX import SummaryWriter
import random
import numpy as np
import signal
import math
import json
import torch.nn as nn
from timeit import default_timer as timer
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

import config as cfg
from dataset.dataset_deform4d_animal_abitraryflow import MeshDataset as Dataset
from utils import gradient_utils
from utils.time_statistics import TimeStatistics
from nnutils.geometry import augment_grid
from nnutils.learningrate import adjust_learning_rate, StepLearningRateSchedule, get_learning_rates

from flow.model import Flow
from flow.model_arbitrary import FlowArbitrary
from flow.loss_arbitrary import LossFlow
from flow.evaluate_arbitrary import evaluate
import open3d as o3d
import trimesh
from utils.viz_utils import vis_error_map
from utils.time_statistics import AverageMeter
from nnutils.eval_metric import chamfer_distance, normal_consistency, compute_dist_square



def _parse_args():
    """Parse the command line of the mesh-deformation evaluation script.

    Returns:
        argparse.Namespace with the parsed options.

    Exits (via ``parser.error``) when ``--load_obj`` is not given: the evaluation
    loop needs the mesh vertices and faces, which are only unpacked from the
    dataset in that mode.
    """
    parser = argparse.ArgumentParser()
    # --data / --experiment are used to build paths; without them the script cannot run.
    parser.add_argument('--data', action='store', dest='data', required=True, help='Provide a subfolder with training data')
    parser.add_argument('--experiment', action='store', dest='experiment', required=True, help='Provide an experiment name')
    parser.add_argument('--pose_model_path', action='store', dest='pose_model_path', help='Provide a path to pre-trained graph model')
    parser.add_argument('--load_obj', action='store_true', help='If use mesh structure ')
    parser.add_argument('--interval', type=int, default=3, help='the interval of pair')
    parser.add_argument('--num_val_data', type=int, default=0, help='numbe of val data')
    # Kept for command-line compatibility; with --load_obj the batch size is always 1.
    parser.add_argument('--ngpus', type=int, default=1, help='the number of gpus')
    parser.add_argument('--partial_range', type=float, default=0.1, help='partial range')
    parser.add_argument('--num_surf_samples', type=int, default=5000, help='the number of sampled point cloud as input')
    parser.add_argument('--unseen_iden', action='store_true', help='unseen identities')

    parser.add_argument('--model_cano', type=str, default='', help='')
    parser.add_argument('--model_deform', type=str, default='', help='')
    parser.add_argument('--use_normals', action='store_true', help='if use surface normals')
    parser.add_argument('--no_transformer', action='store_true', help='if not use transformer as encoder in airnet model')
    parser.add_argument('--interp_dec', action='store_true', help='if use interpolate decoder in airnet model')
    parser.add_argument('--global_field', action='store_true', help='if only use global latent code in airnet model')
    args = parser.parse_args()

    if not args.load_obj:
        parser.error("--load_obj is required: evaluation needs mesh vertices and faces")
    return args


def _build_val_dataset(args, data_dir):
    """Create the validation ``MeshDataset`` for the selected split.

    ``--unseen_iden`` selects the unseen-identity split; otherwise the
    unseen-motion split (on seen identities) is used. Augmentation and caching
    are disabled.
    """
    if args.unseen_iden:
        iden_split, split = cfg.iden_unseen_split, cfg.val_unseen_iden_clean_split
    else:
        iden_split, split = cfg.iden_seen_split, cfg.val_unseen_motion_clean_split
    return Dataset(
        data_dir, cfg.flow_num_point_samples,
        cache_data=False, use_augmentation=False,
        partial_range=args.partial_range,
        iden_split=iden_split,
        split=split,
        interval=args.interval,
        load_obj=args.load_obj,
        num_data=args.num_val_data,
        use_normals=args.use_normals,
        num_surf_samples=args.num_surf_samples,
    )


def _load_flow(checkpoint_path, no_input_corr, args):
    """Build a CUDA ``Flow`` network from the CLI architecture flags and load its weights.

    The checkpoint at ``checkpoint_path`` must be a dict whose
    ``'model_state_dict'`` entry holds the network weights.
    """
    flow = Flow(use_normals=args.use_normals, no_input_corr=no_input_corr,
                no_transformer=args.no_transformer, interp_dec=args.interp_dec,
                global_field=args.global_field).cuda()
    flow.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])
    return flow


def _ensure_dir(path):
    """Create ``path`` (including missing parents) if necessary and return it."""
    os.makedirs(path, exist_ok=True)
    return path


def main():
    """Evaluate a pretrained FlowArbitrary model on validation mesh pairs.

    For every source/target pair this exports the source, canonical, predicted
    (colored by per-vertex error), ground-truth and handle-only meshes as OBJ
    files under ``<experiments_dir>/<experiment>/<unseen_iden|unseen_motion>_ratio<r>/``.
    It prints per-pair and running-average L2, face-normal-consistency and
    Chamfer-L1 metrics. Finally it writes ``meta_info.json``, which maps each
    batch index to the pair it describes.
    """
    torch.set_num_threads(cfg.num_threads)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    args = _parse_args()
    data = args.data
    experiment_name = args.experiment
    pose_model_path = args.pose_model_path
    partial_range = args.partial_range
    use_normals = args.use_normals

    print("Will initialize from provided checkpoint")
    print()

    # Print hyperparameters
    cfg.print_hyperparams(data, experiment_name)
    print()

    #####################################################################################
    # Output folders
    #####################################################################################
    data_dir = os.path.join(cfg.data_deform4d_seq_root_dir, data)
    experiment_dir = os.path.join(cfg.experiments_dir, experiment_name)

    generation_dirname = "unseen_iden" if args.unseen_iden else "unseen_motion"
    generation_dirname += "_ratio%.2f" % partial_range
    output_dir = os.path.join(experiment_dir, generation_dirname)
    # makedirs also creates the experiment directory if it does not exist yet.
    os.makedirs(output_dir, exist_ok=True)

    #####################################################################################
    # Dataset and dataloader
    #####################################################################################
    val_dataset = _build_val_dataset(args, data_dir)

    # With --load_obj, pairs are evaluated one per batch; sample_idx.item() and the
    # per-pair mesh export below rely on a batch size of exactly 1.
    cfg.flow_batch_size = 1
    cfg.shuffle = False
    print('Real flow batch size', cfg.flow_batch_size)
    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset, batch_size=cfg.flow_batch_size, shuffle=cfg.shuffle,
        num_workers=cfg.num_worker_threads, pin_memory=False,
    )

    print("Num. training samples: {0}".format(len(val_dataset)))
    print()

    if len(val_dataset) < cfg.flow_batch_size:
        print()
        print("Reduce the batch_size, since we only have {} training samples but you indicated a batch_size of {}".format(
            len(val_dataset), cfg.flow_batch_size)
        )
        sys.exit(1)

    #####################################################################################
    # Model
    #####################################################################################
    model_canonicalize = _load_flow(args.model_cano, no_input_corr=True, args=args)
    print('load model canonicalize from :', args.model_cano)

    model_deform = _load_flow(args.model_deform, no_input_corr=False, args=args)
    print('load model deform from :', args.model_deform)
    model = FlowArbitrary(model_canonicalize=model_canonicalize, model_deform=model_deform).cuda()

    # Optionally overwrite all weights with a full FlowArbitrary checkpoint.
    if pose_model_path and os.path.exists(pose_model_path):
        print(f"Initializing from model: {pose_model_path}")
        print()
        model.load_state_dict(torch.load(pose_model_path)['model_state_dict'])
    elif pose_model_path:
        print(f"Warning: pose model {pose_model_path} not found, keeping the separately loaded weights")

    # Count parameters.
    n_all_model_params = sum(p.numel() for p in model.parameters())
    n_trainable_model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Number of parameters: {0} / {1}".format(n_trainable_model_params, n_all_model_params))
    print()

    output_template_dir = _ensure_dir(os.path.join(output_dir, "obj_meshes_source"))
    output_canonical_dir = _ensure_dir(os.path.join(output_dir, "obj_meshes_canonical"))
    output_deform_dir = _ensure_dir(os.path.join(output_dir, "obj_meshes_deformed"))
    output_deform_gt_dir = _ensure_dir(os.path.join(output_dir, "obj_meshes_target"))
    output_handle_dir = _ensure_dir(os.path.join(output_dir, "obj_meshes_handles"))

    #####################################################################################
    # Evaluation loop
    #####################################################################################
    pointcloud_size = 30000  # number of surface points compared by the Chamfer distance
    l2_error_meter, face_normals_consistency_meter, chamfer_l1_meter = AverageMeter('l2'), AverageMeter('fnc'), AverageMeter('cd_l1')
    meta_info_dict = {}

    model.eval()
    for i, batch in enumerate(val_dataloader):
        # The 6th entry is ignored and the 7th is used as the surface normals.
        # NOTE(review): confirm against the dataset which of the two actually holds the normals.
        (flow_faces, flow_edges, flow_vertices, flow_faces_adjacency, surface_samples, _,
         surface_normals, rotated2gaps, bbox_lower, bbox_upper, temp_idx, sample_idx) = batch
        surface_normals = val_dataset.unpack(surface_normals).cuda()
        flow_vertices = val_dataset.unpack(flow_vertices).cuda()
        surface_samples = val_dataset.unpack(surface_samples).cuda()
        # Faces are only used on the CPU for mesh export and metrics.
        triangles = val_dataset.unpack(flow_faces).squeeze().cpu().numpy()

        meta_data = val_dataset.get_metadata(sample_idx.item())
        (identity_idx, identity_model_name, identity_seq_idx,
         model_name0, seq_idx0, model_name1, seq_idx1) = meta_data["pair_info"]
        print(identity_idx, identity_model_name, identity_seq_idx, model_name0, seq_idx0, model_name1, seq_idx1)
        meta_info_dict[i] = {"identity_model_name": identity_model_name, "identity_seq_idx": identity_seq_idx,
                             "source_model_name": model_name0, "source_seq_idx": seq_idx0,
                             "target_model_name": model_name1, "target_seq_idx": seq_idx1}

        # Forward pass.
        # NOTE(review): not wrapped in torch.no_grad(); if FlowArbitrary does not need
        # autograd internally, doing so would save GPU memory -- confirm.
        model_inputs = [flow_vertices, surface_samples[:, 0:2], surface_samples[:, 2]]
        if use_normals:
            model_inputs.append(surface_normals)
        _, _, flow_vertices1_deformed = model(*model_inputs)

        # Source mesh, with handle vertices (positive first channel of flow_vertices[:, 2]) in red.
        vertices_template = flow_vertices[:, 0].squeeze().cpu().numpy()
        handle_vert_idx = (flow_vertices[:, 2, :, 0:1] > 0).squeeze().cpu().numpy()
        vert_colors = np.full(vertices_template.shape, 0.75, dtype=np.float32)
        vert_colors[handle_vert_idx] = (1.0, 0.0, 0.0)
        vert_colors_uint8 = (vert_colors * 255).astype('uint8')
        source_file = os.path.join(output_template_dir, "%s_%04d.obj" % (model_name0, seq_idx0))
        trimesh.Trimesh(vertices=vertices_template, faces=triangles,
                        vertex_colors=vert_colors_uint8, process=False).export(source_file)

        # Canonical mesh (flow_vertices[:, 3]) with the same coloring.
        vertices_canonical = flow_vertices[:, 3].squeeze().cpu().numpy()
        canonical_file = os.path.join(output_canonical_dir, "%s_%04d.obj" % (identity_model_name, identity_seq_idx))
        trimesh.Trimesh(vertices=vertices_canonical, faces=triangles,
                        vertex_colors=vert_colors_uint8, process=False).export(canonical_file)

        # Prediction (colored by per-vertex error) and ground truth (last slot of flow_vertices).
        pair_name = "%s_%04d_%s_%04d.obj" % (model_name0, seq_idx0, model_name1, seq_idx1)
        vertices_deformed = flow_vertices1_deformed.squeeze().cpu().detach().numpy()
        vertices_deformed_gt = flow_vertices[:, -1].squeeze().cpu().numpy()

        mesh_deformed_trimesh = trimesh.Trimesh(vertices_deformed, triangles, process=False)
        per_vertex_error = np.sqrt(((vertices_deformed - vertices_deformed_gt) ** 2).sum(-1))
        mesh_deformed = vis_error_map(vertices_deformed, triangles, per_vertex_error)
        o3d.io.write_triangle_mesh(os.path.join(output_deform_dir, pair_name), mesh_deformed)

        mesh_deformed_gt = trimesh.Trimesh(vertices_deformed_gt, triangles, process=False)
        mesh_deformed_gt.export(os.path.join(output_deform_gt_dir, pair_name))

        # Corresponding surface points: faces picked by sampling the prediction, plus random
        # barycentric weights, are applied to both meshes so points[k] matches points_gt[k].
        face_normals = mesh_deformed_trimesh.face_normals.astype(np.float32)
        face_normals_gt = mesh_deformed_gt.face_normals.astype(np.float32)
        _, face_idx = mesh_deformed_trimesh.sample(pointcloud_size, return_index=True)
        alpha = np.random.dirichlet((1,) * 3, pointcloud_size)
        points = (alpha[:, :, None] * vertices_deformed[triangles[face_idx]]).sum(axis=1)
        points_gt = (alpha[:, :, None] * vertices_deformed_gt[triangles[face_idx]]).sum(axis=1)

        # L2
        l2_error = compute_dist_square(vertices_deformed, vertices_deformed_gt)
        l2_error_meter.update(l2_error)
        # Values above 1.0 (and NaN, which compares False) are kept out of the average.
        face_normals_consistency = normal_consistency(face_normals, face_normals_gt)
        if face_normals_consistency <= 1.0:
            face_normals_consistency_meter.update(face_normals_consistency)
        # CD
        chamfer_l1 = chamfer_distance(points, points_gt)
        chamfer_l1_meter.update(chamfer_l1)
        print(i, "L2_error:", l2_error, "face_nc:", face_normals_consistency, "chamfer_l1:", chamfer_l1)
        print("L2_error_avg:", l2_error_meter.avg, "face_nc_avg:", face_normals_consistency_meter.avg, "chamfer_l1_avg:", chamfer_l1_meter.avg)

        # Handle-only sub-mesh of the ground truth: keep faces whose three vertices are all handles.
        face_select_idx = handle_vert_idx[triangles].all(axis=1)
        mesh_deformed_gt.update_faces(face_select_idx)
        mesh_deformed_gt.export(os.path.join(output_handle_dir, pair_name))

    json_file = os.path.join(output_dir, "meta_info.json")
    # NOTE(review): json.dump fails on numpy scalars; confirm pair_info holds plain Python types.
    with open(json_file, 'w') as fp:
        json.dump(meta_info_dict, fp)
    print("write meta info dict to:", json_file)
    print()
    print("I'm done")


# Script entry point: run the evaluation only when executed directly, not when imported.
if __name__ == "__main__":
    main()