"""
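Evaluation entry point for the end-to-end framework: rolls out the particle transition model and scores the renderer on the test views.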
"""

import os
import joblib
import json
import numpy as np
import os.path as osp
from tqdm import tqdm
import imageio

import torch

from trainer.basetrainer import BaseTrainer
from models.renderer import RenderNet
from models.transmodel import ParticleNet
from datasets.dataset import BlenderDataset
from utils.particles_utils import record2obj
from utils.point_eval import FluidErrors
from utils import eval_utils


def img2mse(x, y):
    """Return the mean squared error between tensors ``x`` and ``y``."""
    return torch.mean((x - y) ** 2)


def mse2psnr(x):
    """Convert an MSE tensor to PSNR in dB, i.e. ``-10 * log10(x)``.

    The ``log(10)`` denominator is created on ``x``'s own device instead of a
    hard-coded CUDA device, so CPU inputs (and non-default GPUs) work as well.
    """
    return -10. * torch.log(x) / torch.log(torch.tensor([10.], device=x.device))


def to8b(x):
    """Clip ``x`` to [0, 1] and scale it to a ``uint8`` array in [0, 255] (fractions truncated)."""
    return (255 * np.clip(x, 0, 1)).astype(np.uint8)

class Evaluator(BaseTrainer):
    """Evaluate a trained particle transition model together with the renderer.

    ``eval`` rolls the transition model out autoregressively over the test
    sequence (each prediction is fed back as the next input), measures the
    particle error against the ground truth of every frame, renders every test
    view from the predicted particles and scores it with PSNR / SSIM /
    LPIPS (VGG). Particles (.obj), images, a video and a metric summary are
    written below ``self.particlepath``, ``self.imgpath`` and ``self.exppath``
    (NOTE(review): these paths and ``self.device`` presumably come from
    ``BaseTrainer`` — confirm).
    """

    def init_fn(self):
        """Set up counters, the test dataset, the models and optional start positions.

        NOTE(review): presumably invoked by ``BaseTrainer.__init__``; confirm.
        If ``options.TRAIN.init_particle_path`` is set, the ``'particles'`` array
        loaded from it with ``np.load`` seeds the rollout instead of the first
        sample's ground-truth positions.
        """
        self.start_step = 0
        self.eval_count = 0
        self.build_dataloader()
        self.build_model()
        init_particle_path = self.options.TRAIN.init_particle_path
        if init_particle_path:
            print('---> Initial position', init_particle_path)
            self.init_pos = torch.Tensor(np.load(init_particle_path)['particles']).to(self.device)
        else:
            self.init_pos = None

    def build_dataloader(self):
        """Create the test dataset for the configured test views at test resolution."""
        self.test_viewnames = self.options['test'].views
        print('\033[1;35mtest view:\033[0m', self.test_viewnames)
        # NOTE(review): the 'train' split is loaded for evaluation — confirm this is intended.
        self.test_dataset = BlenderDataset(self.options.train.path, self.options,
                                            imgW=self.options.TEST.imgW, imgH=self.options.TEST.imgH,
                                            imgscale=self.options.TEST.scale, viewnames=self.test_viewnames, split='train')
        self.test_dataset_length = len(self.test_dataset)
        print('---> dataloader has been build')

    def build_model(self):
        """Instantiate the transition model (with configured gravity) and the renderer on ``self.device``."""
        gravity = self.options.gravity
        print('---> set gravity', gravity)
        self.transition_model = ParticleNet(gravity=gravity).to(self.device)
        self.renderer = RenderNet(self.options.RENDERER, near=self.options.near, far=self.options.far).to(self.device)

    def resume(self, ckpt_file):
        """Load renderer and transition-model weights from ``ckpt_file`` (strict key matching).

        ``map_location`` places the tensors directly on ``self.device``, so a
        checkpoint saved on another device (or on GPU, loaded on CPU) still loads.
        """
        checkpoint = torch.load(ckpt_file, map_location=self.device)
        self.renderer.load_state_dict(checkpoint['renderer_state_dict'], strict=True)
        self.transition_model.load_state_dict(checkpoint['transition_model_state_dict'], strict=True)
        print('---> model has been resumed from {}\n'.format(ckpt_file))

    def eval(self):
        """Roll out the transition model over the test sequence and evaluate it.

        Per frame: advance the particle state one step, compare it with the
        next frame's ground-truth positions, and render/score every test view.
        The image metrics of a frame are averaged over the test views (a single
        view yields exactly that view's value). The rendered frames of the last
        test view are written to ``video_fine.rgb.mp4`` and the metric summary
        is printed and saved. Both models are switched back to train mode
        afterwards, even if evaluation fails.

        Raises:
            ValueError: if no test views are configured.
        """
        if not self.test_viewnames:
            raise ValueError('no test views configured (options.test.views is empty)')
        self.transition_model.eval()
        self.renderer.eval()
        try:
            with torch.no_grad():
                self._evaluate_sequence()
        finally:
            self.transition_model.train()
            self.renderer.train()

    def _evaluate_sequence(self):
        """Run the rollout over all test frames, then save the video and the metric summary."""
        dist_pred2gt_all = []
        psnrs, ssims, lpips_vgg = [], [], []
        rgbs_8 = []
        self.fluid_error = FluidErrors()
        keys = ['box', 'box_normals', 'particles_vel', 'cw_1', 'rgb_1', 'rays_1', 'focal', 'particles_pos_1']
        pos_for_next_step = vel_for_next_step = None
        for data_idx in tqdm(range(self.test_dataset_length), total=self.test_dataset_length):
            frame_idx = data_idx + 1  # the prediction is compared with the '*_1' (next-frame) fields
            sample = self.test_dataset[data_idx]
            data = {k: self._to_device(sample[k]) for k in keys}

            if data_idx == 0:
                # Seed the rollout. 'particles_pos' is not in `keys`, so it is
                # read from the raw sample (only needed here).
                if self.init_pos is not None:
                    pos_for_next_step = self.init_pos
                else:
                    pos_for_next_step = self._to_device(sample['particles_pos'])
                vel_for_next_step = data['particles_vel']

            pred_pos, pred_vel, _ = self.transition_model(
                pos_for_next_step, vel_for_next_step, data['box'], data['box_normals'])
            # Autoregressive rollout: this step's prediction is the next step's input.
            pos_for_next_step, vel_for_next_step = pred_pos.clone(), pred_vel.clone()

            # --- transition model: particle error vs. ground truth of the next frame ---
            pos_t1 = data['particles_pos_1']
            dist_pred2gt_all.append(
                self.fluid_error.cal_errors(pred_pos.cpu().numpy(), pos_t1.cpu().numpy(), frame_idx))
            self._save_particles(pred_pos, pos_t1, frame_idx)

            # --- renderer: image metrics over all test views ---
            psnr, ssim, lpips, pred_rgb8 = self._evaluate_views(data, pred_pos, frame_idx)
            rgbs_8.append(pred_rgb8)
            psnrs.append(psnr)
            ssims.append(ssim)
            lpips_vgg.append(lpips)
            print(psnrs)
            print(ssims)
            print(lpips_vgg)

        imageio.mimwrite(os.path.join(self.exppath, 'video_fine.rgb.mp4'), np.array(rgbs_8), fps=24, quality=8)
        self._report_metrics(dist_pred2gt_all, psnrs, ssims, lpips_vgg)

    def _to_device(self, value):
        """Move ``value`` to ``self.device`` if it is a tensor; return anything else unchanged."""
        return value.to(self.device) if isinstance(value, torch.Tensor) else value

    def _save_particles(self, pred_pos, gt_pos, frame_idx):
        """Write the predicted (red) and ground-truth particles of one frame as ``%04d.obj`` files."""
        for subdir, positions, color in (('Pred', pred_pos, [255, 0, 0]),
                                         ('GT', gt_pos, [3, 168, 158])):
            out_dir = osp.join(self.particlepath, subdir)
            os.makedirs(out_dir, exist_ok=True)
            with open(osp.join(out_dir, '%04d.obj' % frame_idx), 'w') as fp:
                record2obj(positions, fp, color=color)

    def _evaluate_views(self, data, pred_pos, frame_idx):
        """Render ``pred_pos`` from every test view and score it against the GT images.

        Uses the fine output ``pred_rgbs_1`` when ``N_importance > 0``; otherwise
        falls back to the coarse output ``pred_rgbs_0`` (NOTE(review): assumes
        the renderer always returns it — confirm).

        Returns:
            ``(psnr, ssim, lpips_vgg, pred_rgb8)`` — the metrics averaged over
            the test views, and the 8-bit rendering of the last view.
        """
        use_fine = self.options.RENDERER.ray.N_importance > 0
        view_psnrs, view_ssims, view_lpips = [], [], []
        pred_rgb8 = None
        for view_idx, view_name in enumerate(self.test_viewnames):
            cw = data['cw_1'][view_idx]
            ro = self.renderer.set_ro(cw)
            focal_length = data['focal'][view_idx]
            gt_rgbs = data['rgb_1'][view_idx]
            rays = data['rays_1'][view_idx].view(-1, 6)

            render_ret = self.render_image(pred_pos, rays.shape[0], ro, rays, focal_length, cw, iseval=True)
            if use_fine:
                pred_rgbs, prefix = render_ret['pred_rgbs_1'], f'fine/{view_name}'
            else:
                pred_rgbs, prefix = render_ret['pred_rgbs_0'], f'coarse/{view_name}'
            pred_image, pred_rgb8 = self.visualization(pred_rgbs, gt_rgbs, prefix, frame_idx)
            # Reshape GT with the same (H, W, 3) layout as the prediction.
            gt_image = self.vis_rgbs(gt_rgbs)

            view_psnrs.append(-10. * np.log10(np.mean(np.square(pred_image - gt_image))))
            view_ssims.append(eval_utils.rgb_ssim(pred_image, gt_image, max_val=1))
            view_lpips.append(eval_utils.rgb_lpips(gt_image.astype('float32'), pred_image.astype('float32'),
                                                   net_name='vgg', device=self.device))
        return (float(np.mean(view_psnrs)), float(np.mean(view_ssims)),
                float(np.mean(view_lpips)), pred_rgb8)

    @staticmethod
    def _to_jsonable(obj):
        """``json`` fallback: convert NumPy scalars/arrays and tensors to built-in Python types."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().tolist()
        raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')

    def _report_metrics(self, dist_pred2gt_all, psnrs, ssims, lpips_vgg):
        """Print the metric summary and save it to ``pred2gt.pt`` and ``mean.json``.

        Indices 0-48 (frames 1-49) form the "trained" segment and the last 10
        entries the rollout segment, so the sequence needs at least 49 frames.
        """
        train_dist, rollout_dist = dist_pred2gt_all[0:49], dist_pred2gt_all[-10:]
        summary = {
            'Pred2GT': np.mean(train_dist),
            'Pred2GT-10': np.mean(dist_pred2gt_all[:10]),
            'Pred2GT-end': dist_pred2gt_all[48],
            'rollout-Pred2GT': np.mean(rollout_dist),
            # NOTE(review): [-5] is a single frame; use [-5:] if the mean of the last 5 was intended.
            'rollout-Pred2GT-5': np.mean(dist_pred2gt_all[-5]),
            'rollout-Pred2GT-end': dist_pred2gt_all[-1],
            'avg_psnrs': np.mean(psnrs[0:49]),
            'avg_ssims': np.mean(ssims[0:49]),
            'avg_lpips (vgg)': np.mean(lpips_vgg[0:49]),
            'rollout-avg_psnrs': np.mean(psnrs[-10:]),
            'rollout-avg_ssims': np.mean(ssims[-10:]),
            'rollout-avg_lpips (vgg)': np.mean(lpips_vgg[-10:]),
        }

        print('----------------- trained 50 steps ------------------------')
        print('Pred2GT:', summary['Pred2GT'])
        print('Pred2GT-10:', summary['Pred2GT-10'])
        print('Pred2GT-end:', summary['Pred2GT-end'])
        print('avg_psnrs:', summary['avg_psnrs'])
        print('avg_ssims', summary['avg_ssims'])
        print('avg_lpips (vgg)', summary['avg_lpips (vgg)'])

        print('\n----------------- rollout 10 steps ------------------------')
        print('Pred2GT:', summary['rollout-Pred2GT'])
        print('Pred2GT-5:', summary['rollout-Pred2GT-5'])
        print('Pred2GT-end:', summary['rollout-Pred2GT-end'])
        print('rollout-avg_psnrs', summary['rollout-avg_psnrs'])
        print('rollout-avg_ssims', summary['rollout-avg_ssims'])
        print('rollout-avg_lpips (vgg)', summary['rollout-avg_lpips (vgg)'])

        joblib.dump({'dist': dist_pred2gt_all}, osp.join(self.exppath, 'pred2gt.pt'))

        info = dict(summary)
        info['psnrs'] = psnrs
        info['ssims'] = ssims
        info['lpips (vgg)'] = lpips_vgg
        # Per-frame errors get their own key so the 'Pred2GT' mean is kept.
        info['Pred2GT-all'] = dist_pred2gt_all
        # Serialize first so an unsupported value cannot leave a truncated file.
        text = json.dumps(info, indent=4, default=self._to_jsonable)
        with open(os.path.join(self.exppath, 'mean.json'), 'w') as f:
            f.write(text)

    def visualization(self, pred_rgbs, gt_rgbs, prefix=None, data_idx=0):
        """Save the rendered image of one view/frame as ``<imgpath>/<prefix>/Pred/%05d.png``.

        Args:
            pred_rgbs: flat per-pixel predicted colours (reshaped by ``vis_rgbs``).
            gt_rgbs: ground-truth colours; currently unused (GT images are not saved).
            prefix: sub-directory below ``self.imgpath``, e.g. ``'fine/<view>'``.
            data_idx: frame number used in the file name.

        Returns:
            ``(pred_image, pred_rgb8)`` — the float (H, W, 3) image and its uint8 version.
        """
        pred_image = self.vis_rgbs(pred_rgbs)
        # Create the Pred folder itself; an existing '<prefix>' alone is not enough.
        os.makedirs(osp.join(self.imgpath, prefix, 'Pred'), exist_ok=True)
        pred_rgb8 = to8b(pred_image)
        filename = '{}/{}/Pred/{:05d}.png'.format(self.imgpath, prefix, data_idx)
        imageio.imwrite(filename, pred_rgb8)
        return pred_image, pred_rgb8

    def vis_rgbs(self, rgbs, channel=3):
        """Reshape flat per-pixel values to a NumPy ``(H, W, channel)`` image at test resolution."""
        imgW = int(self.options.TEST.imgW // self.options.TEST.scale)
        imgH = int(self.options.TEST.imgH // self.options.TEST.scale)
        image = rgbs.reshape(imgH, imgW, channel).detach().cpu().numpy()
        return image

   
if __name__ == '__main__':
    # From here on, newly created tensors default to CUDA float tensors.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    from configs import dataset_config, end2end_training_config

    # Build the dataset table first, then the end-to-end config, and merge the
    # entry for the dataset that the end-to-end config selects into it.
    all_dataset_cfgs = dataset_config()
    e2e_cfg = end2end_training_config()
    e2e_cfg.update(all_dataset_cfgs[e2e_cfg.dataset])

    # NOTE(review): resume() is not called here; presumably BaseTrainer restores
    # the checkpoint during construction — confirm before trusting the metrics.
    Evaluator(e2e_cfg).eval()
    