import os
import torch
import kaolin
from tqdm import tqdm
import torch.nn.functional as F

from utils.loadconfig import load_config
from utils.dataset import MeshDataset
from utils.dataloaderx import DataLoaderX
from utils.logger import MeshHeadTrainRecorder
from models.mesh_head import MeshHeadModule

from pytorch3d.transforms import so3_exponential_map


def laplace_regularizer_const(mesh_verts, mesh_faces):
    """Uniformly weighted Laplacian smoothness penalty for a triangle mesh.

    For every face, each corner receives the sum of the two edge vectors
    pointing from it to the other two corners, together with a weight of 2.
    The accumulated vectors are divided by the accumulated weights (clamped
    to at least 1 so unreferenced vertices do not divide by zero), and the
    mean of the squared components is returned.

    Args:
        mesh_verts: Vertex positions, indexed along dim 0; expected (V, 3).
        mesh_faces: Integer vertex indices per triangle; columns 0..2 are
            read, expected (F, 3).

    Returns:
        Scalar tensor: mean squared component of the normalized Laplacian.
    """
    # Positions of the three corners of every face, one (F, 3) tensor each.
    corners = [mesh_verts[mesh_faces[:, k], :] for k in range(3)]

    laplacian = torch.zeros_like(mesh_verts)
    weight_sum = torch.zeros_like(mesh_verts[..., 0:1])
    # Only the first column is consumed by the (F, 1) index below.
    per_face_weight = torch.ones_like(corners[0]) * 2.0

    for k, here in enumerate(corners):
        # The two remaining corners, kept in ascending corner order.
        first, second = (c for j, c in enumerate(corners) if j != k)
        column = mesh_faces[:, k:k + 1]
        laplacian.scatter_add_(0, column.repeat(1, 3), (first - here) + (second - here))
        weight_sum.scatter_add_(0, column, per_face_weight)

    laplacian = laplacian / torch.clamp(weight_sum, min=1.0)

    return torch.mean(laplacian ** 2)

def train(dataloader, model, optimizer, logger, device, epochs):
    """Run the mesh-head optimization loop for ``epochs`` passes over ``dataloader``.

    For every batch the tensors named in ``device_keys`` are moved to
    ``device`` (in place inside the batch dict), the model is evaluated on
    the batch, a weighted sum of six loss terms is back-propagated, one
    optimizer step is taken, and the individual terms are handed to
    ``logger``.

    Args:
        dataloader: Sized iterable yielding dict batches that contain at
            least the keys listed in ``device_keys``.
        model: Module with ``train()``; ``model(data)`` must return the
            8-tuple unpacked in the loop body.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        logger: Object exposing ``log(dict)``; called once per batch.
        device: Target device for the batch tensors.
        epochs: Number of passes over ``dataloader``.
    """
    # Batch entries that must live on ``device``; constant, so built once
    # instead of once per batch.
    device_keys = (
        'images', 'masks', 'visibles', 'intrinsics', 'extrinsics', 'pose',
        'scale', 'exp_coeff', 'ear', 'landmarks_3d', 'exp_id',
    )

    model.train()
    for epoch in range(epochs):
        for i, data in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")):
            for key in device_keys:
                data[key] = data[key].to(device)

            # Move axis 2 to the end (presumably channels-first -> channels-last;
            # confirm against MeshDataset).
            images = data['images'].permute(0, 1, 3, 4, 2)
            masks = data['masks'].permute(0, 1, 3, 4, 2)
            visibles = data['visibles'].permute(0, 1, 3, 4, 2)
            resolution = images.shape[2]

            # Target landmarks in canonical space: x_can = R^T (x - T) / S,
            # with R built from pose[:, :3] and T taken from pose[:, 3:].
            R = so3_exponential_map(data['pose'][:, :3])
            T = data['pose'][:, 3:, None]
            S = data['scale'][:, :, None]
            landmarks_3d_can = (torch.bmm(R.permute(0, 2, 1), (data['landmarks_3d'].permute(0, 2, 1) - T)) / S).permute(0, 2, 1)

            (pred_landmarks_3d_can, sdf_landmarks_3d, render_images, render_soft_masks,
             exp_deform, pose_deform, verts_list, faces_list) = model(data)

            # Predicted vs. target canonical landmarks.
            loss_def = F.mse_loss(pred_landmarks_3d_can, landmarks_3d_can)
            # Mean magnitude of channel 0 of the per-landmark model output
            # (presumably the SDF value at each landmark; confirm in MeshHeadModule).
            loss_lmk = torch.abs(sdf_landmarks_3d[:, :, 0]).mean()
            # Photometric L1 on the first three render channels, restricted by ``visibles``.
            loss_rgb = F.l1_loss(render_images[:, :, :, :, 0:3] * visibles, images * visibles)
            # Silhouette term from kaolin's mask_iou on (N, resolution, resolution) masks.
            loss_sil = kaolin.metrics.render.mask_iou((render_soft_masks * visibles[:, :, :, :, 0]).view(-1, resolution, resolution), (masks * visibles).squeeze().view(-1, resolution, resolution))
            # Penalize the squared magnitude of both deformation fields.
            loss_offset = (exp_deform ** 2).sum(-1).mean() + (pose_deform ** 2).sum(-1).mean()
            # Laplacian smoothness, summed (not averaged) over the returned meshes.
            loss_lap = 0.0
            for verts, faces in zip(verts_list, faces_list):
                loss_lap += laplace_regularizer_const(verts, faces)

            loss = loss_rgb * 1e-1 + loss_sil * 1e-1 + loss_def * 1e0 + loss_offset * 1e-2 + loss_lmk * 1e-1 + loss_lap * 1e2

            # Gradients are cleared once per step, immediately before backward.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            logger.log({
                'data': data,
                'meshhead': model,
                'loss_rgb': loss_rgb,
                'loss_sil': loss_sil,
                'loss_def': loss_def,
                'loss_offset': loss_offset,
                'loss_lmk': loss_lmk,
                'loss_lap': loss_lap,
                'epoch': epoch,
                # Global step counter across epochs.
                'iter': i + epoch * len(dataloader),
                'loader_length': len(dataloader),
            })
                
if __name__ == '__main__':
    # Script entry point: read the config, build data pipeline and model,
    # restore or initialize the model, then train for 20 epochs.
    conf_path = 'configs/train_meshhead_mono.yaml'
    cfg = load_config(conf_path)

    device = torch.device(cfg['device'])
    torch.cuda.set_device(device)

    dataset = MeshDataset(cfg['dataset'])
    dataloader = DataLoaderX(
        dataset,
        batch_size=cfg['batch_size'],
        shuffle=True,
        pin_memory=True,
    )

    meshhead = MeshHeadModule(cfg['meshheadmodule'], dataset.init_landmarks_3d_neutral).to(device)

    # Resume from the checkpoint when the file exists; otherwise run the
    # model's own sphere pre-training for 300 iterations.
    checkpoint_path = cfg['load_meshhead_checkpoint']
    if not os.path.exists(checkpoint_path):
        meshhead.pre_train_sphere(300, device)
    else:
        # The map_location callback returns the storage unchanged.
        state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        meshhead.load_state_dict(state_dict)

    recorder = MeshHeadTrainRecorder(cfg['recorder'])

    # One parameter group for the neutral landmarks (own learning rate),
    # followed by one group per sub-network sharing the network learning rate.
    lmk_lr = float(cfg['lr_lmk'])
    net_lr = float(cfg['lr_net'])
    sub_networks = (
        meshhead.geo_mlp,
        meshhead.exp_color_mlp,
        meshhead.pose_color_mlp,
        meshhead.exp_deform_mlp,
        meshhead.pose_deform_mlp,
        meshhead.pts_embedding,
    )
    param_groups = [{'params': meshhead.landmarks_3d_neutral, 'lr': lmk_lr}]
    param_groups += [{'params': net.parameters(), 'lr': net_lr} for net in sub_networks]
    optimizer = torch.optim.Adam(param_groups)

    train(dataloader, meshhead, optimizer, recorder, device, epochs=20)
    

