import os
import torch
import numpy as np
import cv2
from tqdm import tqdm
import lpips
import argparse
from torchmetrics.image import StructuralSimilarityIndexMeasure

def extract_prompts_from_videos(video_folder):
    """
    Collect the unique prompt names encoded in the video file names.

    File names are expected to look like ``{prompt}-{index}.mp4``. Only the
    part before the last ``-`` is kept, so prompts that themselves contain
    dashes are preserved. ``.mp4`` files without any ``-`` in their stem are
    ignored, as are files with any other extension.

    Returns the prompt names as a sorted list.
    """
    suffix = '.mp4'
    # Stems of every .mp4 entry in the folder (file name minus the extension).
    stems = (
        entry[:-len(suffix)]
        for entry in os.listdir(video_folder)
        if entry.endswith(suffix)
    )
    # Strip the trailing "-<index>" part; the set removes duplicates.
    unique_prompts = {stem.rsplit('-', 1)[0] for stem in stems if '-' in stem}
    return sorted(unique_prompts)

def load_video_frames(path, resize_to=None):
    """
    Decode every frame of a video file into one RGB array.

    Args:
        path: Path of the video file, handed to ``cv2.VideoCapture``.
        resize_to: Optional ``(width, height)`` tuple. When given, each frame
            is resized with ``cv2.resize`` before the colour conversion.

    Returns:
        A NumPy array of shape ``(N, H, W, 3)`` holding the frames in RGB
        channel order (OpenCV decodes to BGR; each frame is converted).

    Raises:
        ValueError: If the video cannot be opened or yields no frames.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {path}")
        while True:
            ret, img = cap.read()
            if not ret:
                break
            if resize_to is not None:
                img = cv2.resize(img, resize_to)
            frames.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture handle even if decoding or resizing fails.
        cap.release()

    if not frames:
        # Fail with a readable message instead of NumPy's generic
        # "need at least one array" error.
        raise ValueError(f"No frames could be read from video: {path}")
    return np.stack(frames)

class VideoEvaluator:
    """
    Computes full-reference video quality metrics (PSNR, SSIM, LPIPS)
    between a reference video and a generated video.
    """
    def __init__(self, device="cuda"):
        # Use the requested device only when CUDA is available; otherwise
        # fall back to the CPU.
        self.device = device if torch.cuda.is_available() else "cpu"

        # Build the SSIM and LPIPS models on the *resolved* device, so that a
        # CPU-only host never tries to move them to CUDA.
        self.ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device)
        self.lpips_fn = lpips.LPIPS(net="alex", spatial=True).to(self.device)

    def compute_video_metrics(self, frames_gt, frames_gen):
        """
        Compute PSNR, SSIM and LPIPS between two videos.

        Args:
            frames_gt: Reference frames as an array of shape ``(N, H, W, 3)``
                with values in ``[0, 255]`` (RGB, as produced by
                ``load_video_frames``).
            frames_gen: Generated frames with exactly the same shape.

        Returns:
            A dict with float entries ``"psnr"``, ``"ssim"`` and ``"lpips"``.
            PSNR is computed from the MSE over the whole video with a data
            range of 1.0, so two identical videos give ``inf``.

        Raises:
            ValueError: If the two frame arrays do not have the same shape
                (different frame count or resolution).
        """
        gt_arr = np.asarray(frames_gt)
        gen_arr = np.asarray(frames_gen)

        # The metrics compare frames pairwise, so frame count and resolution
        # must match exactly.
        if gt_arr.shape != gen_arr.shape:
            raise ValueError(
                f"Frame arrays differ in shape: reference {gt_arr.shape} "
                f"vs generated {gen_arr.shape}"
            )

        # No gradients are needed for evaluation.
        with torch.no_grad():
            # Convert to tensors [N,3,H,W], normalized to [0,1].
            gt_t = torch.from_numpy(gt_arr).float().to(self.device).permute(0, 3, 1, 2).div_(255).contiguous()
            gen_t = torch.from_numpy(gen_arr).float().to(self.device).permute(0, 3, 1, 2).div_(255).contiguous()

            # PSNR (data_range=1.0): -10 * log10(mse)
            mse = torch.mean((gt_t - gen_t) ** 2)
            psnr = -10.0 * torch.log10(mse)

            # SSIM over the batch of frames.
            ssim_val = self.ssim_metric(gen_t, gt_t)

            # LPIPS expects inputs in [-1,1]; average the returned values.
            lpips_val = self.lpips_fn(gt_t * 2.0 - 1.0, gen_t * 2.0 - 1.0).mean()

        metrics = {
            "psnr": psnr.item(),
            "ssim": ssim_val.item(),
            "lpips": lpips_val.item(),
        }

        return metrics

if __name__ == "__main__":
    # Script entry point: compare every "{prompt}-{i}.mp4" in the "Original"
    # folder against the same-named file in the folder selected by --mode,
    # then print per-prompt and overall averages as a Markdown table.
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, default="Scaling", help="Evaluation mode (e.g., Scaling or Taylor)")
    parser.add_argument("--video_folder", type=str, default="./vbench_video_path/", 
                       help="Path to the parent folder containing Original and other video folders")
    parser.add_argument("--num_videos_per_prompt", type=int, default=5,
                       help="Number of videos per prompt to evaluate (indices 0..N-1)")
    args = parser.parse_args()

    # Reference videos live in "<video_folder>/Original"; the videos under
    # test live in "<video_folder>/<mode>".
    origin_video_prefix = os.path.join(args.video_folder, "Original")
    test_video_prefix = os.path.join(args.video_folder, args.mode)

    # Prompt names are derived from the file names in the Original folder.
    prompts = extract_prompts_from_videos(origin_video_prefix)

    evaluator = VideoEvaluator()

    # Per-prompt lists of per-video metric values.
    prompt_metrics = {p: {"psnr": [], "ssim": [], "lpips": []} for p in prompts}

    # Wrap the list itself (not enumerate) so tqdm knows the total.
    for prompt in tqdm(prompts, desc="Evaluating prompts"):
        for i in range(args.num_videos_per_prompt):
            video_name = f"{prompt}-{i}.mp4"
            origin_path = os.path.join(origin_video_prefix, video_name)
            test_path = os.path.join(test_video_prefix, video_name)

            # Pairs where either side is missing are skipped.
            if not os.path.exists(origin_path) or not os.path.exists(test_path):
                continue

            # Best effort: one broken video must not abort the whole run.
            try:
                frames_gt = load_video_frames(origin_path)
                frames_gen = load_video_frames(test_path)
                video_metrics = evaluator.compute_video_metrics(frames_gt, frames_gen)
            except Exception as e:
                print(f"[Error] {video_name}: {e}")
                continue

            for key, values in prompt_metrics[prompt].items():
                values.append(video_metrics[key])

    # Emit the Markdown table.
    print("\n### Per-Prompt Video Quality Metrics")
    print("| Prompt | PSNR (↑) | SSIM (↑) | LPIPS (↓) |")
    print("|--------|----------|----------|-----------|")

    # Values of all evaluated videos, pooled for the overall average.
    global_psnr, global_ssim, global_lpips = [], [], []

    for prompt, metrics in prompt_metrics.items():
        # Only prompts with at least one evaluated video get a row.
        if not metrics["psnr"]:
            continue

        psnr_mean = np.mean(metrics["psnr"])
        ssim_mean = np.mean(metrics["ssim"])
        lpips_mean = np.mean(metrics["lpips"])

        global_psnr.extend(metrics["psnr"])
        global_ssim.extend(metrics["ssim"])
        global_lpips.extend(metrics["lpips"])

        print(f"| {prompt} | {psnr_mean:.4f} | {ssim_mean:.4f} | {lpips_mean:.4f} |")

    # Overall average across all evaluated videos (not an average of the
    # per-prompt means).
    if global_psnr:
        psnr_avg = np.mean(global_psnr)
        ssim_avg = np.mean(global_ssim)
        lpips_avg = np.mean(global_lpips)
        print(f"| **Overall Avg** | {psnr_avg:.4f} | {ssim_avg:.4f} | {lpips_avg:.4f} |")