import os
import numpy as np
import torch
from tqdm import tqdm

from utils.loadconfig import load_config

from utils.dataset import ReenactmentDataset
from utils.dataloaderx import DataLoaderX
from utils.logger import ReenactmentRecorder
from models.gaussian_head import GaussianHeadModule
# from models.gaussian_head_without_pe import GaussianHeadModule
from modules.superresolution import SuperResolutionModule

def Reenactment(dataloader, gaussianhead, supres, logger, device, freeview):
    """Render every frame of a reenactment sequence and log the results.

    For each batch from ``dataloader``:

    1. Move the tensors named in ``to_cuda`` to ``device``.
    2. Temporally smooth the driving ``pose`` / ``exp_coeff``.
    3. Generate and render the Gaussians with ``gaussianhead``.
    4. Upscale the render with ``supres``.
    5. Pass the resulting ``data`` dict to ``logger.log``.

    Args:
        dataloader: iterable of dicts; each dict must contain every key listed
            in ``to_cuda`` with a value supporting ``.to(device)``.
        gaussianhead: object exposing ``generate(data)`` (must set
            ``data['xyz']``) and ``render.render_gaussian(data, 128)`` (must
            set ``data['render_images']``).
        supres: callable applied to ``data['render_images']``.
        logger: object with a ``log(dict)`` method; receives
            ``{'data': data, 'iter': idx}`` once per frame.
        device: torch device the inputs are moved to.
        freeview: if True, ``pose`` is replaced by zeros and only ``exp_coeff``
            is smoothed (0.5 previous / 0.5 current). Otherwise both ``pose``
            and ``exp_coeff`` are smoothed (0.4 previous / 0.6 current).

    Returns:
        A numpy array holding the per-frame ``data['xyz']`` tensors
        concatenated along dim 0, or ``None`` if ``dataloader`` yielded no
        batches.
    """
    # Keys transferred to ``device`` for every batch (loop-invariant).
    to_cuda = ['images', 'intrinsics', 'extrinsics', 'world_view_transform', 'projection_matrix',
               'full_proj_transform', 'camera_center', 'pose', 'scale', 'exp_coeff', 'ear', 'pose_code']

    xyz_sequence = []
    pose_last = None
    exp_last = None
    for idx, data in enumerate(tqdm(dataloader)):
        for data_item in to_cuda:
            data[data_item] = data[data_item].to(device)

        # The smoothed value is stored back as the "last" value, so this is an
        # exponential moving average over frames, not a two-frame average.
        if not freeview:
            if idx > 0:
                data['pose'] = pose_last * 0.4 + data['pose'] * 0.6
                data['exp_coeff'] = exp_last * 0.4 + data['exp_coeff'] * 0.6
            pose_last = data['pose']
            exp_last = data['exp_coeff']
        else:
            data['pose'] = torch.zeros_like(data['pose'])
            if idx > 0:
                data['exp_coeff'] = exp_last * 0.5 + data['exp_coeff'] * 0.5
            exp_last = data['exp_coeff']

        with torch.no_grad():
            data = gaussianhead.generate(data)
            # Keep the collected positions off-device so memory use on
            # ``device`` does not grow with the length of the sequence.
            xyz_sequence.append(data['xyz'].detach().cpu())

            # NOTE(review): 128 is passed through unchanged from the original;
            # its meaning (presumably a render resolution) should be confirmed.
            data = gaussianhead.render.render_gaussian(data, 128)
            render_images = data['render_images']
            supres_images = supres(render_images)
            data['supres_images'] = supres_images

        to_log = {
            'data': data,
            'iter': idx
        }
        logger.log(to_log)

    # torch.cat raises on an empty list; an empty sequence has nothing to return.
    if not xyz_sequence:
        return None
    return torch.cat(xyz_sequence, dim=0).numpy()



if __name__ == '__main__':
    import sys

    # Optional first CLI argument overrides the config path; with no
    # arguments the original hard-coded config is used.
    config_path = sys.argv[1] if len(sys.argv) > 1 else 'configs/reenactment_mono.yaml'
    cfg = load_config(config_path)

    dataset = ReenactmentDataset(cfg['dataset'])
    dataloader = DataLoaderX(dataset, batch_size=1, shuffle=False, pin_memory=True)

    device = torch.device(cfg['device'])
    # torch.cuda.set_device only accepts CUDA devices; calling it with e.g.
    # 'cpu' raises, so only select a GPU when one is configured.
    if device.type == 'cuda':
        torch.cuda.set_device(device)

    # map_location=lambda storage, loc: storage deserializes every tensor on
    # CPU; the modules are moved to ``device`` via ``.to(device)``.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    gaussianhead_state_dict = torch.load(cfg['load_gaussianhead_checkpoint'], map_location=lambda storage, loc: storage)
    gaussianhead = GaussianHeadModule(cfg['gaussianheadmodule'],
                                      xyz=gaussianhead_state_dict['xyz'],
                                      feature=gaussianhead_state_dict['feature'],
                                      landmarks_3d_neutral=gaussianhead_state_dict['landmarks_3d_neutral']).to(device)
    gaussianhead.load_state_dict(gaussianhead_state_dict)

    supres = SuperResolutionModule(cfg['supresmodule']).to(device)
    supres.load_state_dict(torch.load(cfg['load_supres_checkpoint'], map_location=lambda storage, loc: storage))
    # NOTE(review): the modules are not switched to .eval(); confirm whether
    # they contain training-mode-dependent layers (dropout / batch norm).

    recorder = ReenactmentRecorder(cfg['recorder'])

    Reenactment(dataloader, gaussianhead, supres, recorder, device, dataset.freeview)
