import os, sys
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import trimesh
from loguru import logger

from pyutils import get_directory, to_device
from utils import implicit_utils
from utils.geom_utils import embedded_deformation

import point_cloud_utils as pcu


def save_obj(fname, vertices, faces):
    """Write the mesh given by ``vertices`` / ``faces`` to ``fname``.

    ``process=False`` tells trimesh not to clean up the mesh on construction,
    so the vertices are written as given instead of being merged or reordered.
    """
    trimesh.Trimesh(vertices=vertices, faces=faces, process=False).export(fname)


def train_one_epoch(state_info, config, train_loader, model, lat_vecs, optimizer_train, writer, latents_all=None):
    """Run one training epoch over ``train_loader``.

    For every batch:
    1. Move the batch to the device.
    2. Optionally look up per-sample latent codes from ``lat_vecs``.
    3. Run the model.
    4. Average every returned entry whose key contains ``'loss'``.
    5. Back-propagate ``batch_dict["loss"]`` and step the optimizer.

    Args:
        state_info: mutable run-state dict. It reads ``'epoch'``, ``'device'``
            and, when logging, ``'len_train_loader'``. It writes ``'b'`` (the
            batch index) and the scalar value of every ``*loss*`` key.
        config: run configuration. Reads ``local_rank``, ``use_sdf_asap``,
            ``use_mesh_arap``, ``log.log_batch_interval``, ``rep`` and
            ``optimization[rep].batch_size``.
        train_loader: iterable of batch dicts.
        model: called as ``model(batch_vecs, batch_dict, config, state_info)``.
            It returns the batch dict updated with its outputs and losses.
        lat_vecs: optional embedding indexed by ``batch_dict['idx']``, or
            ``None`` when no per-sample latents are used.
        optimizer_train: optimizer, zeroed and stepped once per batch.
        writer: object exposing ``log_state_info`` and ``log_summary``.
        latents_all: optional value forwarded to the model as
            ``batch_dict['latents_all']``.

    Returns:
        The same ``state_info`` dict, updated in place.
    """
    model.train()

    epoch = state_info['epoch']
    device = state_info['device']
    is_main_process = config.local_rank == 0

    # ASAP
    if is_main_process and (config.use_sdf_asap or config.use_mesh_arap):
        logger.warning("use ARAP/ASAP loss")

    for b, batch_dict in enumerate(train_loader):
        state_info['b'] = b
        optimizer_train.zero_grad()
        batch_dict = to_device(batch_dict, device)

        batch_vecs = None
        if lat_vecs is not None:
            batch_vecs = lat_vecs(batch_dict['idx'])  # (B, latent_dim)
        if latents_all is not None:
            batch_dict['latents_all'] = latents_all

        # The model returns the batch dict augmented with predictions and losses.
        batch_dict = model(batch_vecs, batch_dict, config, state_info)
        # Reduce each loss to a scalar. Presumably this collapses per-device
        # values from a data-parallel wrapper; verify against the model.
        batch_dict.update({k: v.mean() for k, v in batch_dict.items() if 'loss' in k})
        state_info.update({k: v.item() for k, v in batch_dict.items() if 'loss' in k})

        loss = batch_dict["loss"]
        loss.backward()
        optimizer_train.step()

        if is_main_process and b % config.log.log_batch_interval == 0:
            # Step counter measured in samples seen, not in batches.
            global_step = (epoch * state_info['len_train_loader'] + b) * config.optimization[config.rep].batch_size
            writer.log_state_info(state_info)
            writer.log_summary(state_info, global_step, mode='train')

    return state_info


def _parallel_index_range(config, num_items):
    """Return the range of dataset indices assigned to this worker.

    When ``config.parallel_idx >= 0``, the indices are split into consecutive
    chunks of ``config.parallel_interval`` and only chunk ``config.parallel_idx``
    is returned. Otherwise every index in ``range(num_items)`` is returned.
    """
    indices = range(num_items)
    if config.parallel_idx >= 0:
        start = config.parallel_idx * config.parallel_interval
        indices = indices[start:start + config.parallel_interval]
    return indices


def _sdf_decode_bounds(config):
    """Return the ``(x_range, y_range, z_range)`` decoding bounds from ``config.loss``, with built-in defaults."""
    return (config.loss.get('x_range', [-1, 1]),
            config.loss.get('y_range', [-0.7, 1.7]),
            config.loss.get('z_range', [-1.1, 0.9]))


def _recon_sdf_item(config, model, recon_lat_vecs, batch_dict, device, fid, results_dir):
    """Get the latent for one batch-of-one, decode it to a mesh and save ``{results_dir}/{fid}.obj``."""
    batch_dict = to_device(batch_dict, device)

    if config.auto_decoder:
        # Auto-decoder: the latent code is looked up from the embedding table.
        assert recon_lat_vecs is not None
        latent_vec = recon_lat_vecs(batch_dict['idx'])  # (B, latent_dim)
    else:
        # Encoder: the latent code is predicted from the input batch.
        assert recon_lat_vecs is None
        latent_vec, _, _ = model(None, batch_dict, config, None, only_encoder_forward=True)
    assert latent_vec.shape[0] == 1  # batch_size == 1
    latent_vec = latent_vec[0]

    x_range, y_range, z_range = _sdf_decode_bounds(config)
    verts, faces = implicit_utils.sdf_decode_mesh_from_single_lat(
        model, latent_vec, resolution=128, voxel_size=None, max_batch=int(2 ** 17),
        offset=None, scale=None, x_range=x_range, y_range=y_range, z_range=z_range)

    save_obj(f"{results_dir}/{fid}.obj", verts, faces)


def _recon_mesh_item(config, recon_loader, model, recon_lat_vecs, batch_dict, device, fid, results_dir):
    """Run the mesh model on one batch-of-one and save ``{fid}.obj`` plus a ``{fid}.pkl`` with all vertex sets."""
    if not config.auto_decoder:
        raise NotImplementedError
    assert recon_lat_vecs is not None
    batch_dict = to_device(batch_dict, device)
    latent_vec = recon_lat_vecs(batch_dict['idx'])
    assert latent_vec.shape[0] == 1  # batch_size == 1
    batch_dict = model(latent_vec, batch_dict, config, state_info=None)

    # ``Tensor.numpy()`` raises for CUDA tensors, so bring the
    # de-normalization statistics to host memory first.
    data_std = model.data_std_gpu.detach().cpu().numpy()
    data_mean = model.data_mean_gpu.detach().cpu().numpy()

    def _first_as_numpy(key):
        # Batch size is 1, so take element 0 and move it to a host numpy array.
        return batch_dict[key][0].detach().cpu().numpy()

    verts_init = _first_as_numpy('verts_init_nml') * data_std + data_mean
    verts_pred = _first_as_numpy('mesh_verts_nml_pred') * data_std + data_mean
    verts_raw = _first_as_numpy('verts_raw')
    faces_raw = _first_as_numpy('faces_raw')
    template_faces = recon_loader.dataset.template_faces

    save_obj(f"{results_dir}/{fid}.obj", verts_pred, template_faces)
    with open(f"{results_dir}/{fid}.pkl", "wb") as f:
        dump_dict = {'verts_init': verts_init, 'verts_raw': verts_raw, 'faces_raw': faces_raw,
                     'verts_pred': verts_pred, 'template_faces': template_faces}
        pickle.dump(dump_dict, f)


def recon_from_lat_vecs(state_info, config, recon_loader, model, recon_lat_vecs, results_dir):
    '''Reconstruct one mesh per selected dataset item and write it to ``results_dir``.

    With ``config.rep == 'sdf'``, each item's latent code is decoded from the
    SDF network into ``{fid}.obj``. With ``config.rep == 'mesh'``, the mesh
    network's predicted vertices are saved as ``{fid}.obj`` on the template
    faces, plus a ``{fid}.pkl`` holding the initial, predicted and raw
    geometry. When ``config.parallel_idx >= 0``, only that worker's
    contiguous chunk of ``config.parallel_interval`` items is processed.

    Args:
        state_info: run-state dict; only ``'device'`` is read.
        config: reads ``rep``, ``auto_decoder``, ``parallel_idx``,
            ``parallel_interval`` and (sdf) ``loss`` bounds.
        recon_loader: loader of single-item batches (batch size 1 is asserted).
            Items are matched to ``dataset.fid_list`` by enumeration order,
            which presumably means the loader must not shuffle.
        model: the SDF or mesh network.
        recon_lat_vecs: Embedding(D, latent_dim) when ``config.auto_decoder``,
            otherwise ``None``.
        results_dir: directory the outputs are written into. It is not created here.

    Raises:
        NotImplementedError: for an unsupported ``config.rep``, or for
            ``rep == 'mesh'`` without ``config.auto_decoder``.
    '''
    model.eval()
    device = state_info['device']

    if config.rep == 'sdf':
        logger.info(" reconstruct mesh from sdf predicted by sdfnet ")
    elif config.rep == 'mesh':
        logger.info(" reconstruct mesh from mesh predicted by meshnet ")
    else:
        raise NotImplementedError

    num_items = len(recon_loader.dataset)
    parallel_idx_list = _parallel_index_range(config, num_items)
    logger.info(f"number of recon mesh: {len(parallel_idx_list)} in {num_items}")

    for i, batch_dict in enumerate(recon_loader):
        # The selected indices are one contiguous range, so nothing past its end is needed.
        if not parallel_idx_list or i > parallel_idx_list[-1]:
            break
        if i not in parallel_idx_list:
            continue
        fid = recon_loader.dataset.fid_list[i]
        print(f"i={i}, fid={fid}")

        if config.rep == 'sdf':
            _recon_sdf_item(config, model, recon_lat_vecs, batch_dict, device, fid, results_dir)
        else:
            _recon_mesh_item(config, recon_loader, model, recon_lat_vecs, batch_dict, device, fid, results_dir)


def interp_from_lat_vecs(state_info, config, interp_loader, model, interp_lat_vecs, results_dir, num_interp=10):
    '''Linearly interpolate between two shapes' latent codes and decode each step to a mesh.

    The two endpoints are ``config.interp_src_fid`` / ``config.interp_tgt_fid``,
    with defaults that depend on ``config.split``. Their latent codes come from
    ``interp_lat_vecs`` (auto-decoder) or from the model's encoder.

    Outputs:
        - ``num_interp + 1`` evenly spaced blends (both endpoints included) are
          decoded and saved as
          ``{results_dir}/{src_idx}_{tgt_idx}/{src_idx}_{tgt_idx}_{step:02d}.obj``.
          Here ``src_idx`` / ``tgt_idx`` are the endpoints' positions in
          ``interp_loader``.
        - The two raw input meshes are copied into the same directory as
          ``{idx}.{raw_mesh_file_type}``.

    Args:
        state_info: run-state dict; only ``'device'`` is read.
        config: reads ``split``, ``auto_decoder``, the optional
            ``interp_src_fid`` / ``interp_tgt_fid`` and ``loss`` bounds.
        interp_loader: loader of single-item batches (batch size 1 is asserted).
        model: the SDF network.
        interp_lat_vecs: Embedding(D, latent_dim) when ``config.auto_decoder``,
            otherwise ``None``.
        results_dir: root output directory.
        num_interp: number of interpolation intervals; must be >= 1.

    Raises:
        ValueError: if ``num_interp < 1`` or an endpoint fid is not found in
            ``interp_loader``.
        NotImplementedError: if ``config.split`` is neither 'train' nor 'test'.
    '''
    import shutil  # local import: only used to copy the raw endpoint meshes

    if num_interp < 1:
        raise ValueError(f"num_interp must be >= 1, got {num_interp}")

    model.eval()
    device = state_info['device']

    if config.split == 'train':
        src_fid = config.get('interp_src_fid', '50022-knees-knees.001582')
        tgt_fid = config.get('interp_tgt_fid', '50022-knees-knees.002101')
    elif config.split == 'test':
        src_fid = config.get('interp_src_fid', '50009-running_on_spot-running_on_spot.000366')
        tgt_fid = config.get('interp_tgt_fid', '50002-chicken_wings-chicken_wings.004011')
    else:
        raise NotImplementedError

    logger.info(" interpolate sdf predicted by sdfnet ")

    # fid -> latent code / loader position, filled only for the two endpoints.
    endpoint_latents, endpoint_indices = {}, {}
    for i, batch_dict in enumerate(interp_loader):
        fid = interp_loader.dataset.fid_list[batch_dict['idx'][0]]
        if fid not in (src_fid, tgt_fid):
            continue

        batch_dict = to_device(batch_dict, device)
        if config.auto_decoder:
            assert interp_lat_vecs is not None
            latent_vec = interp_lat_vecs(batch_dict['idx'])  # (B, latent_dim)
        else:
            assert interp_lat_vecs is None
            latent_vec, _, _ = model(None, batch_dict, config, None, only_encoder_forward=True)

        assert latent_vec.shape[0] == 1  # batch_size == 1
        endpoint_latents[fid] = latent_vec[0]
        endpoint_indices[fid] = i
        if src_fid in endpoint_latents and tgt_fid in endpoint_latents:
            break  # both endpoints found; the rest of the loader is not needed

    missing = [f for f in (src_fid, tgt_fid) if f not in endpoint_latents]
    if missing:
        raise ValueError(f"interpolation endpoint(s) not found in interp_loader: {missing}")

    latent_src, latent_tgt = endpoint_latents[src_fid], endpoint_latents[tgt_fid]
    src_idx, tgt_idx = endpoint_indices[src_fid], endpoint_indices[tgt_fid]

    logger.info(f"interpolate {src_fid} ({src_idx}-th) and {tgt_fid} ({tgt_idx}-th)")
    dump_dir = get_directory(f"{results_dir}/{src_idx}_{tgt_idx}")
    x_range = config.loss.get('x_range', [-1, 1])
    y_range = config.loss.get('y_range', [-0.7, 1.7])
    z_range = config.loss.get('z_range', [-1.1, 0.9])

    for i_interp in range(num_interp + 1):
        ri = i_interp / num_interp
        latent_interp = latent_src * (1 - ri) + latent_tgt * ri
        verts, faces = implicit_utils.sdf_decode_mesh_from_single_lat(
            model, latent_interp, resolution=128, max_batch=int(2 ** 17), offset=None, scale=None,
            x_range=x_range, y_range=y_range, z_range=z_range)
        save_obj(f"{dump_dir}/{src_idx}_{tgt_idx}_{i_interp:02d}.obj", verts, faces)

    # Copy the raw endpoint meshes next to the interpolation results. A fid such
    # as "a-b-c" maps to the raw file "a/b/c.<ext>" under raw_mesh_dir. This is
    # best-effort: a failed copy is logged instead of aborting the run.
    raw_ext = interp_loader.dataset.raw_mesh_file_type
    for fid, idx in ((src_fid, src_idx), (tgt_fid, tgt_idx)):
        raw_path = f"{interp_loader.dataset.raw_mesh_dir}/{fid.replace('-', '/')}.{raw_ext}"
        dst_path = f"{dump_dir}/{idx}.{raw_ext}"
        try:
            shutil.copy(raw_path, dst_path)
        except OSError as err:
            logger.warning(f"failed to copy raw mesh {raw_path} -> {dst_path}: {err}")



