# -*- coding: utf-8 -*-
import argparse
import json
import os
import pickle

import lpips as lpips_lib
import torch
import torchvision
import tqdm
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from datasets.dataset_factory import get_dataset
from gaussian_renderer import render_predicted
from scene.gaussian_predictor import GaussianSplatPredictor
from utils.loss_utils import ssim as ssim_fn


class Metricator:
    """Computes PSNR, SSIM and LPIPS between a rendered image and its target."""

    def __init__(self, device):
        """Build the VGG-based LPIPS network once and move it to ``device``."""
        self.lpips_net = lpips_lib.LPIPS(net='vgg').to(device)

    def compute_metrics(self, image, target):
        """
        Compare ``image`` against ``target`` and return ``(psnr, ssim, lpips)``.

        All three values are returned as Python floats (via ``.item()``).
        Args:
            image: rendered image tensor; the MSE is averaged over dims 0, 1, 2
            target: ground-truth tensor compared element-wise with ``image``
        """
        # LPIPS is fed x * 2 - 1 with a leading batch axis
        # (presumably inputs are in [0, 1] so this maps to [-1, 1] - TODO confirm).
        batched_image = image.unsqueeze(0) * 2 - 1
        batched_target = target.unsqueeze(0) * 2 - 1
        lpips_value = self.lpips_net(batched_image, batched_target).item()

        # PSNR written as -10 * log10(MSE).
        squared_error = (image - target) ** 2
        mse = torch.mean(squared_error, dim=[0, 1, 2])
        psnr_value = -10 * torch.log10(mse).item()

        ssim_value = ssim_fn(image, target).item()
        return psnr_value, ssim_value, lpips_value


def _average_metrics(rows):
    """Column-wise mean of (psnr, ssim, lpips) tuples; all-NaN when ``rows`` is empty."""
    if not rows:
        return (float("nan"),) * 3
    return tuple(sum(column) / len(rows) for column in zip(*rows))


@torch.no_grad()
def evaluate_dataset(model, dataloader, device, model_cfg, save_vis=0, save_pkl=0, out_folder=None, epoch='best'):
    """
    Runs evaluation on the dataset passed in the dataloader.
    Computes and saves PSNR, SSIM, LPIPS, separately for conditioning
    views (render index < model_cfg.data.input_images) and novel views.
    Args:
        model: callable producing the reconstruction; its last element is a
            dict of tensors that is passed to render_predicted
        dataloader: yields dicts of tensors; only batch element 0 is rendered,
            and dataloader.dataset.get_example_id(idx) must be available
        device: device the batch, the background colour and LPIPS are moved to
        model_cfg: config whose ``data`` section selects the model inputs
        save_vis: how many examples will have visualisations saved
        save_pkl: how many examples will have their reconstruction pickled
        out_folder: output directory (created if missing). When None, only the
            score file is written, into the current working directory
        epoch: prefix of the per-example score file name
    Returns:
        dict with PSNR/SSIM/LPIPS averaged over examples, for the ``_cond``
        and ``_novel`` views. Examples without any scored view of a kind are
        left out of that average; a metric is NaN if no example had one.
    Raises:
        ValueError: if renders or pickles are requested without out_folder
    """
    # Fail before any model work instead of with a TypeError from os.path.join.
    if out_folder is None and (save_vis > 0 or save_pkl > 0):
        raise ValueError("out_folder must be provided when save_vis > 0 or save_pkl > 0")
    # The score file is always written into out_folder, so it has to exist
    # even when no visualisations are requested.
    if out_folder is not None:
        os.makedirs(out_folder, exist_ok=True)

    # NOTE: the fallback name has no underscore between epoch and "scores";
    # it is kept unchanged so existing outputs keep their name.
    score_fp = f"{out_folder}/{epoch}_scores.txt" if out_folder is not None else f"{epoch}scores.txt"
    # Truncate any previous score file; one row per example is appended below.
    with open(score_fp, "w+") as f:
        f.write("")

    bg_color = [1, 1, 1] if model_cfg.data.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device=device)

    # instantiate metricator
    metricator = Metricator(device)
    num_inputs = model_cfg.data.input_images

    # Per-example (psnr, ssim, lpips) averages.
    examples_cond = []
    examples_novel = []

    for d_idx, data in enumerate(tqdm.tqdm(dataloader)):

        # Per-render (psnr, ssim, lpips) tuples for this example.
        renders_cond = []
        renders_novel = []

        data = {k: v.to(device) for k, v in data.items()}

        rot_transform_quats = data["source_cv2wT_quat"][:, :num_inputs]
        source_view_to_world = data["view_to_world_transforms"][:, :num_inputs, ...]

        input_images = data["norm_imgs"][:, :num_inputs, ...]
        if model_cfg.data.origin_distances:
            # Origin distances are concatenated to the images along dim 2.
            input_images = torch.cat([input_images,
                                      data["origin_distances"][:, :num_inputs, ...]],
                                     dim=2)

        # NOTE(review): masks are sliced with coarse_stage_input_images while
        # every other input uses input_images - confirm this is intended.
        if model_cfg.data.use_mask:
            input_masks = data["fg_masks"][:, :model_cfg.data.coarse_stage_input_images, ...]
        else:
            input_masks = None

        if model_cfg.data.use_plucker_emb:
            plucker_emb = data["plucker_emb"][:, :num_inputs, ...]
        else:
            plucker_emb = None

        if model_cfg.data.mod_camera_dec:
            input_cameras = data["source_camera"][:, :num_inputs, ...]
        else:
            input_cameras = None

        example_id = dataloader.dataset.get_example_id(d_idx)
        example_tag = "{}_".format(d_idx) + example_id
        save_renders = d_idx < save_vis
        if save_renders:
            out_example_gt = os.path.join(out_folder, example_tag + "_gt")
            out_example = os.path.join(out_folder, example_tag)

            os.makedirs(out_example_gt, exist_ok=True)
            os.makedirs(out_example, exist_ok=True)

        # The first num_inputs views of the example are the conditioning views.
        reconstruction = model(input_images,
                               input_masks,
                               None,  # input intrinsics are not used here
                               source_view_to_world,
                               rot_transform_quats,
                               plucker_emb=plucker_emb,
                               input_cameras=input_cameras,
                               unnorm_imges=data["gt_images"][:, :num_inputs, ...],
                               source_cameras_view_to_world_coarse=source_view_to_world,
                               source_cv2wT_quat_coarse=rot_transform_quats)

        # Only batch element 0 is used; built once and shared by the pickle
        # dump and every render of this example.
        gaussians = {k: v[0].contiguous() for k, v in reconstruction[-1].items()}

        if d_idx < save_pkl:
            save_dict = {key: value.cpu().numpy() for key, value in gaussians.items()}
            out_pkl = os.path.join(out_folder, 'pkl', example_tag)
            os.makedirs(f'{out_folder}/pkl', exist_ok=True)
            with open(f'{out_pkl}_recons.pkl', 'wb') as file:
                pickle.dump(save_dict, file)

        for r_idx in range(data["gt_images"].shape[1]):
            gt_image = data["gt_images"][0, r_idx, ...]
            image = render_predicted(gaussians,
                                     data["world_view_transforms"][0, r_idx],
                                     data["full_proj_transforms"][0, r_idx],
                                     data["camera_centers"][0, r_idx],
                                     background,
                                     model_cfg,
                                     focals_pixels=None)["render"]

            if save_renders:
                file_name = '{0:05d}'.format(r_idx) + ".png"
                torchvision.utils.save_image(image, os.path.join(out_example, file_name))
                torchvision.utils.save_image(gt_image, os.path.join(out_example_gt, file_name))

            # exclude non-foreground (all-zero) images from metric computation
            if torch.all(gt_image == 0):
                continue
            metrics = metricator.compute_metrics(image, gt_image)
            if r_idx < num_inputs:
                renders_cond.append(metrics)
            else:
                renders_novel.append(metrics)

        # An example may have no scored view of one kind (e.g. no novel views);
        # it is then skipped in that aggregate instead of dividing by zero.
        if renders_cond:
            examples_cond.append(_average_metrics(renders_cond))
        novel_psnr, novel_ssim, novel_lpips = _average_metrics(renders_novel)
        if renders_novel:
            examples_novel.append((novel_psnr, novel_ssim, novel_lpips))

        with open(score_fp, "a+") as f:
            f.write(example_tag +
                    " " + str(novel_psnr) +
                    " " + str(novel_ssim) +
                    " " + str(novel_lpips) + "\n")

    psnr_cond, ssim_cond, lpips_cond = _average_metrics(examples_cond)
    psnr_novel, ssim_novel, lpips_novel = _average_metrics(examples_novel)
    scores = {"PSNR_cond": psnr_cond,
              "SSIM_cond": ssim_cond,
              "LPIPS_cond": lpips_cond,
              "PSNR_novel": psnr_novel,
              "SSIM_novel": ssim_novel,
              "LPIPS_novel": lpips_novel}
    return scores


@torch.no_grad()
def main(dataset_name, experiment_path, device_idx, split='test', save_vis=0, save_pkl=0, out_folder=None, epoch='best', num_views=4):
    """
    Loads a trained model from ``experiment_path`` and evaluates it.
    Args:
        dataset_name: written into cfg.data.category before the dataset is built
        experiment_path: folder holding .hydra/config.yaml and either
            model_<epoch>.pth or the DeepSpeed checkpoint folder model_<epoch>/
        device_idx: index of the CUDA device to run on
        split: dataset split handed to get_dataset
        save_vis: number of examples for which renders are saved
        save_pkl: number of examples for which the reconstruction is pickled
        out_folder: output folder prefix; "_<category>" is appended to it
        epoch: checkpoint name suffix, as in model_<epoch>
        num_views: written into cfg.data.input_images
    Returns:
        the score dict produced by evaluate_dataset
    Raises:
        ValueError: if experiment_path is None
    """
    # Without this check the failure is an opaque TypeError inside os.path.join.
    if experiment_path is None:
        raise ValueError("experiment_path must point to an experiment folder; "
                         "loading a model without one is not supported")

    # set device
    device = torch.device("cuda:{}".format(device_idx))
    torch.cuda.set_device(device)

    cfg_path = os.path.join(experiment_path, ".hydra", "config.yaml")
    single_card_ckpt = os.path.join(experiment_path, f"model_{epoch}.pth")
    # Anything that is not a single-file checkpoint is assumed to be DeepSpeed.
    deepspeed_ckpt = not os.path.exists(single_card_ckpt)
    if deepspeed_ckpt:
        print("loading model ckpt from deepspeed ckpt")
        model_path = os.path.join(experiment_path, f"model_{epoch}/checkpoint/mp_rank_00_model_states.pt")
    else:
        print("loading model ckpt from single card ckpt")
        model_path = single_card_ckpt

    # load cfg and override it with the evaluation settings
    training_cfg = OmegaConf.load(cfg_path)
    training_cfg.data.input_images = num_views
    training_cfg.data.category = dataset_name

    model = GaussianSplatPredictor(training_cfg)
    ckpt_loaded = torch.load(model_path, map_location=device)
    # The two checkpoint formats store the weights under different keys.
    state_dict_key = "module" if deepspeed_ckpt else "model_state_dict"
    model.load_state_dict(ckpt_loaded[state_dict_key])
    model = model.to(device)
    model.eval()
    print('Loaded model!')

    # override dataset in cfg if testing objaverse model
    if training_cfg.data.category == "objaverse" and split in ["test", "vis"]:
        training_cfg.data.category = "gso"
    # instantiate dataset loader
    dataset = get_dataset(training_cfg, split)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False,
                            persistent_workers=True, pin_memory=True, num_workers=1)
    # NOTE: a None out_folder is formatted literally, giving "None_<category>".
    out_folder = f'{out_folder}_{training_cfg.data.category}'
    os.makedirs(out_folder, exist_ok=True)
    torch.cuda.empty_cache()

    scores = evaluate_dataset(model, dataloader, device, training_cfg, save_vis=save_vis, save_pkl=save_pkl, out_folder=out_folder, epoch=epoch)
    print(scores)
    return scores


def parse_arguments(argv=None):
    """
    Parse the command-line options of the evaluation script.
    Args:
        argv: list of argument strings to parse; when None (the default)
            argparse reads sys.argv[1:]
    Returns:
        argparse.Namespace with dataset_name, experiment_path, split,
        out_folder, save_vis, save_pkl, epoch and num_views
    """
    parser = argparse.ArgumentParser(description='Evaluate model')
    parser.add_argument('dataset_name', type=str, help='Dataset to evaluate on',
                        choices=['objaverse', 'gso', 'objaverse_new', 'gso_new'])
    # The previous help text promised a pretrained-model download, which this
    # script does not implement.
    parser.add_argument('--experiment_path', type=str, default=None,
                        help='Path to the experiment folder containing .hydra/config.yaml '
                             'and the model_<epoch> checkpoint')
    parser.add_argument('--split', type=str, default='test', choices=['test', 'val', 'vis', 'train', 'val_vis_test'],
                        help='Split to evaluate on (default: test). \
                        Using vis renders loops and does not return scores - to be used for visualisation. \
                        You can also use this to evaluate on the training or validation splits.')
    parser.add_argument('--out_folder', type=str, default='out', help='Output folder to save renders (default: out)')
    parser.add_argument('--save_vis', type=int, default=0, help='Number of examples for which to save renders (default: 0)')
    parser.add_argument('--save_pkl', type=int, default=0, help='Number of examples for which to save save_pkl for point clouds (default: 0)')
    parser.add_argument('--epoch', type=str, default='best', help='Name of epoch that wants to do infer on')
    parser.add_argument('--num_views', type=int, default=4, help='Number of views in the inference process')
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: parse the CLI, report the chosen settings, run the
    # evaluation on CUDA device 0 and store the scores as JSON.
    args = parse_arguments()

    dataset_name = args.dataset_name
    experiment_path = args.experiment_path
    split = args.split
    out_folder = args.out_folder
    save_vis = args.save_vis
    save_pkl = args.save_pkl

    print(f"Evaluating on dataset {dataset_name}")
    print("Will load a model released with the paper." if experiment_path is None
          else "Loading a local model according to the experiment path")
    if split == 'vis':
        print("Will not print or save scores. Use a different --split to return scores.")
    print('Saving output in ', out_folder)
    if save_vis == 0:
        print("Not saving any renders (only computing scores). To save renders use flag --save_vis")

    scores = main(dataset_name,
                  experiment_path,
                  0,
                  split=split,
                  save_vis=save_vis,
                  save_pkl=save_pkl,
                  out_folder=out_folder,
                  epoch=args.epoch,
                  num_views=args.num_views)

    # Scores go next to the experiment when a path was given, otherwise into
    # the working directory; the "vis" split never writes a score file.
    if split != "vis":
        if experiment_path is None:
            score_out_path = f"{dataset_name}_{split}_scores.json"
        else:
            score_out_path = os.path.join(experiment_path, f"{split}_scores.json")
        with open(score_out_path, "w+") as f:
            json.dump(scores, f, indent=4)
