#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import torch
import numpy as np
from PIL import Image
import torchvision
import os

def mse(img1, img2):
    """Return the per-sample mean squared error between two image batches.

    The first dimension is treated as the batch axis; every other dimension is
    flattened and averaged.

    Args:
        img1: tensor whose first dimension indexes samples.
        img2: tensor broadcast-compatible with ``img1``.

    Returns:
        Tensor of shape ``(img1.shape[0], 1)`` holding one MSE per sample.
    """
    squared_error = (img1 - img2) ** 2
    # reshape (not view) so non-contiguous inputs, e.g. permuted/transposed
    # tensors, are accepted instead of raising a RuntimeError.
    return squared_error.reshape(img1.shape[0], -1).mean(1, keepdim=True)

def psnr(img1, img2, data_range=1.0):
    """Return the per-sample peak signal-to-noise ratio in decibels.

    Computes ``20 * log10(data_range / sqrt(MSE))`` where the MSE is averaged
    over every dimension except the first (batch) one.

    Args:
        img1: tensor whose first dimension indexes samples.
        img2: tensor broadcast-compatible with ``img1``.
        data_range: maximum possible pixel value (the "MAX" term of PSNR).
            Defaults to 1.0, i.e. images normalized to [0, 1].

    Returns:
        Tensor of shape ``(img1.shape[0], 1)``. Identical inputs give ``inf``
        (the MSE is 0).
    """
    # Named `err` rather than `mse` so the module-level mse() is not shadowed.
    # reshape (not view) so non-contiguous inputs are accepted.
    err = ((img1 - img2) ** 2).reshape(img1.shape[0], -1).mean(1, keepdim=True)
    return 20 * torch.log10(data_range / torch.sqrt(err))

def easy_cmap(x: torch.Tensor, min=None, max=None):
    """Map a 2-D tensor to a 3-channel grayscale image in [0, 1].

    Values are linearly normalized with ``(x - min) / (max - min)``, clamped to
    [0, 1], and copied into all three channels.

    Args:
        x: 2-D tensor of shape ``(H, W)``.
        min: lower bound of the normalization range; defaults to ``x.min()``.
        max: upper bound of the normalization range; defaults to ``x.max()``.
            (``min``/``max`` shadow the builtins but are kept for caller
            compatibility.)

    Returns:
        float32 tensor of shape ``(3, H, W)`` on ``x.device``.
    """
    x_rgb = torch.zeros((3, x.shape[0], x.shape[1]), dtype=torch.float32, device=x.device)
    x_min = x.min() if min is None else min
    x_max = x.max() if max is None else max
    span = x_max - x_min
    if span == 0:
        # Degenerate range (e.g. a constant image): the division would give
        # 0/0 = NaN. Threshold instead: values above the single bound become 1,
        # everything else 0. This matches what ±inf turned into after clamping.
        x_normalize = (x > x_max).to(torch.float32)
    else:
        x_normalize = torch.clamp((x - x_min) / span, 0, 1)
    # Broadcast the (H, W) map into all three channels at once.
    x_rgb[:] = x_normalize
    return x_rgb

import imageio
from tqdm import tqdm
def _load_image_frames(paths, folder=None):
    """Read image files with PIL and stack them into a uint8 array (T, ...).

    Each path is joined onto ``folder`` when ``folder`` is given; otherwise it
    is used as-is. All images must share the same shape to be stacked.
    """
    if not paths:
        raise ValueError("save_video received an empty list of image paths")
    images = []
    for path in paths:
        full_path = path if folder is None else os.path.join(folder, path)
        # The context manager closes the file handle once the pixels are copied out.
        with Image.open(full_path) as img:
            images.append(np.array(img))
    return np.stack(images, axis=0).astype(np.uint8, copy=False)


def _to_uint8_frames(data, folder=None):
    """Convert ``data`` to a uint8 numpy array whose first axis indexes frames."""
    if isinstance(data, np.ndarray):
        tensor = torch.from_numpy(data)
    elif isinstance(data, torch.Tensor):
        tensor = data.detach().cpu()
    elif isinstance(data, list):
        return _load_image_frames(data, folder)
    else:
        raise TypeError(
            "save_video expects a numpy array, torch tensor or list of image paths, "
            f"got {type(data).__name__}"
        )
    # Values are expected in [0, 1]. Clamp before scaling so out-of-range values
    # saturate instead of wrapping around when cast to uint8.
    return (tensor.clamp(0, 1) * 255).to(torch.uint8).numpy()


def save_video(data, images_path, target_fps=30, folder=None):
    """Encode a sequence of frames into a video file with libx264.

    Args:
        data: one of
            * ``np.ndarray`` or ``torch.Tensor`` with values in [0, 1]. Values
              are clamped, scaled by 255 and truncated to uint8.
            * ``list`` of image file paths, read with PIL and stacked.
            Frames are taken along the first axis and handed to imageio
            unchanged, so the layout must be channels-last:
            ``(T, H, W)`` or ``(T, H, W, C)``.
        images_path: path of the output video file.
        target_fps: frame rate of the output video.
        folder: directory prepended to each path when ``data`` is a list;
            ``None`` uses the paths as given.

    Raises:
        TypeError: if ``data`` is not an ndarray, tensor or list.
        ValueError: if ``data`` is an empty list.
    """
    frames = _to_uint8_frames(data, folder)
    writer = imageio.get_writer(
        images_path,
        fps=target_fps,
        codec='libx264',
        ffmpeg_log_level='error',
        pixelformat='yuv444p',  # no chroma subsampling, so odd frame sizes work
        ffmpeg_params=['-crf', '10'],  # low CRF -> high quality output
        macro_block_size=None,  # keep the original frame size (no resize to 16-multiples)
    )
    # Close the writer even if encoding a frame raises, so the output is not leaked.
    try:
        for frame in tqdm(frames, desc="Writing video"):
            writer.append_data(frame)
    finally:
        writer.close()