import os
import pdb
import cv2
import math
import shutil
import pandas as pd
from skimage.metrics import structural_similarity as ssim
import subprocess
from tqdm import tqdm
import numpy as np
import imagehash
from PIL import Image
from pytorch_msssim import ms_ssim
from pytorch_msssim import ssim as torch_ssim
import torch

def count_png_files(directory):
    """Return how many entries in ``directory`` have a name ending in ``.png``.

    Only the immediate entries of ``directory`` are inspected; the match is
    case-sensitive.
    """
    return sum(1 for entry in os.listdir(directory) if entry.endswith('.png'))

def convert_to_h264(input_video, output_video):
    """Re-encode ``input_video`` with libx264 (preset ``fast``) via ffmpeg.

    Args:
        input_video: Path of the video to read.
        output_video: Path of the file to write; an existing file is
            overwritten (``-y``).

    Raises:
        SystemExit: With status 1 when ffmpeg exits with a non-zero status.
    """
    command = [
        'ffmpeg',
        # Global options first, matching the ffmpeg synopsis.
        '-y',
        '-loglevel', 'error',
        '-i', input_video,
        '-c:v', 'libx264',
        '-preset', 'fast',
        output_video,
    ]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"{e}")
        # A failed conversion must not be reported as a successful exit.
        raise SystemExit(1) from e


def read_frames(video_path):
    """Yield the frames of a video one at a time using OpenCV.

    This is a generator: the video is opened on the first ``next()`` call,
    and the capture handle is released when the generator finishes, is
    closed early, or is interrupted by an exception.

    Args:
        video_path: Path of the video file to read.

    Yields:
        Each frame returned by ``cv2.VideoCapture.read`` until reading fails.

    Raises:
        ValueError: If the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError("Error: Could not open video.")

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            yield frame
    finally:
        # Runs on normal exhaustion, on generator.close() and on errors,
        # so the capture handle is never leaked.
        cap.release()

def calculate_frame_ssim(frame1, frame2):
    """Return the SSIM of two BGR frames after converting both to grayscale.

    The ``data_range`` passed to ``ssim`` is the spread (max minus min) of
    the pixel values observed in the second frame's grayscale image.
    """
    first_gray, second_gray = (
        cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) for frame in (frame1, frame2)
    )
    observed_range = second_gray.max() - second_gray.min()
    return ssim(first_gray, second_gray, data_range=observed_range)

def extract_keyframes(input_video, output_folder):
    """Dump the I-frames of ``input_video`` as PNG files and count them.

    ``output_folder`` is deleted first if it exists and then recreated, so
    it only ever contains the frames of the current run. Frames are written
    as ``keyframe_001.png``, ``keyframe_002.png``, ...

    Args:
        input_video: Path of the video to read.
        output_folder: Directory that receives the PNG files (wiped first).

    Returns:
        The number of ``.png`` files in ``output_folder`` after extraction.

    Raises:
        SystemExit: With status 1 when ffmpeg exits with a non-zero status.
    """
    if os.path.exists(output_folder):
        shutil.rmtree(output_folder)

    os.makedirs(output_folder, exist_ok=True)
    command = [
        'ffmpeg',
        '-i', input_video,
        # Keep only frames whose picture type is I (key frames).
        '-vf', 'select=eq(pict_type\\,I)',
        '-vsync', 'vfr',
        '-y', '-loglevel', 'error',
        f'{output_folder}/keyframe_%03d.png'
    ]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"{e}")
        # A failed extraction must not be reported as a successful exit.
        raise SystemExit(1) from e

    return count_png_files(output_folder)

def extract_keyframes_and_last_frame(input_video, output_folder):
    """Dump the I-frames of a video as PNGs, adding the last frame if needed.

    ``output_folder`` is deleted first if it exists and then recreated.
    I-frames are written as ``keyframe_NNN.png``. When exactly one key frame
    was produced, the video's final frame (index taken from ffprobe's packet
    count) is also written under the next ``keyframe_NNN.png`` name, so the
    folder holds two frames to compare.

    The ffmpeg/ffprobe exit statuses are not checked (best effort).

    Args:
        input_video: Path of the video to read.
        output_folder: Directory that receives the PNG files (wiped first).

    Returns:
        The number of ``.png`` files in ``output_folder`` after extraction.

    Raises:
        ValueError: If the last frame is needed but ffprobe's output cannot
            be parsed as an integer packet count.
    """
    if os.path.exists(output_folder):
        shutil.rmtree(output_folder)
    os.makedirs(output_folder, exist_ok=True)
    command_keyframes = [
        'ffmpeg',
        '-i', input_video,
        '-vf', 'select=eq(pict_type\\,I)',
        '-vsync', 'vfr', '-y',
        '-loglevel', 'error',
        f'{output_folder}/keyframe_%03d.png'
    ]
    subprocess.run(command_keyframes)

    key_frames = sorted([f for f in os.listdir(output_folder) if f.startswith('keyframe_') and f.endswith('.png')])
    if len(key_frames) == 1:
        # "keyframe_001.png" -> 1; the extra frame takes the next index.
        last_index = int(key_frames[-1].split('_')[1].split('.')[0])
        command_count_frames = [
            'ffprobe',
            '-v', 'error',
            '-select_streams', 'v:0',
            '-count_packets',
            '-show_entries', 'stream=nb_read_packets',
            '-loglevel', 'error',
            '-of', 'csv=p=0',
            input_video
        ]
        probe = subprocess.run(command_count_frames, capture_output=True, text=True)
        try:
            total_frames = int(probe.stdout.strip())
        except ValueError as err:
            raise ValueError(
                f"ffprobe did not report a packet count for {input_video!r}: "
                f"{probe.stdout!r}"
            ) from err

        command_last_frame = [
            'ffmpeg',
            '-i', input_video,
            '-vf', f"select='eq(n\\,{total_frames-1})'",
            '-vframes', '1',
            '-loglevel', 'error',
            f'{output_folder}/keyframe_{last_index+1:03d}.png'
        ]
        subprocess.run(command_last_frame)

    # Report how many frames ended up on disk, like extract_keyframes does.
    return count_png_files(output_folder)


def get_image_files(directory):
    """Return the paths of the ``.png`` files in ``directory``, sorted by name.

    ``os.listdir`` yields entries in arbitrary order; sorting makes the
    result deterministic so that zero-padded frame files
    (``keyframe_001.png``, ``keyframe_002.png``, ...) come back in frame
    order.

    Args:
        directory: Directory to scan (not recursive).

    Returns:
        A list of ``directory``-joined paths, sorted by file name.
    """
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.endswith('.png')
    ]

def load_images(image_folder):
    """Load every PNG in ``image_folder`` as a grayscale image.

    Paths come from ``get_image_files`` and are read in the order it returns
    them. Files for which ``cv2.imread`` returns ``None`` are skipped.
    """
    decoded = (
        cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        for path in get_image_files(image_folder)
    )
    return [image for image in decoded if image is not None]

def calculate_ssim(frames):
    """Return the mean SSIM over each pair of consecutive frames.

    For every pair, ``data_range`` is the spread (max minus min) of the
    later frame's values. Raises ``ZeroDivisionError`` when ``frames`` holds
    fewer than two items, because there are no pairs to average.
    """
    pair_scores = [
        ssim(
            frames[idx - 1],
            frames[idx],
            data_range=frames[idx].max() - frames[idx].min(),
        )
        for idx in range(1, len(frames))
    ]
    return sum(pair_scores) / len(pair_scores)

def calculate_psnr(img1, img2):
    """Return the PSNR in dB between two images, with a peak value of 255.

    The difference is computed in float64. Subtracting and squaring uint8
    arrays directly wraps around modulo 256 and gives a wrong mean squared
    error.

    Args:
        img1: First image (array-like of pixel values).
        img2: Second image, same shape as ``img1``.

    Returns:
        ``20 * log10(255 / sqrt(mse))``, or ``float('inf')`` when the images
        are identical (mse is 0).
    """
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float('inf')
    max_pixel = 255.0
    psnr = 20 * np.log10(max_pixel / np.sqrt(mse))
    return psnr


def stat_inter_keyframe_ssim(video_path, key_frame_folder):
    """Return the mean SSIM between consecutive key frames of a video.

    The video is first re-encoded to H.264 under ``outputs/<basename>``, its
    key frames (plus the last frame when there is only one key frame) are
    written to ``key_frame_folder``, and the resulting images are compared
    pairwise.

    Args:
        video_path: Path of the source video.
        key_frame_folder: Scratch directory for extracted frames; it is
            wiped by the extraction step.

    Returns:
        The value returned by ``calculate_ssim`` for the extracted frames.
    """
    # ffmpeg does not create missing parent directories for its output.
    os.makedirs('outputs', exist_ok=True)
    output_path = os.path.join('outputs', os.path.basename(video_path))
    convert_to_h264(video_path, output_path)
    extract_keyframes_and_last_frame(output_path, output_folder=key_frame_folder)
    images = load_images(key_frame_folder)
    return calculate_ssim(images)

def calculate_phash(image):
    """Return the ``imagehash`` perceptual hash of an array-backed image."""
    pil_image = Image.fromarray(image)
    return imagehash.phash(pil_image)

def calculate_whash(image):
    """Return the ``imagehash`` wavelet hash of an array-backed image."""
    pil_image = Image.fromarray(image)
    return imagehash.whash(pil_image)

def stat_ssim_with_path(video_path):
    """Average frame-to-frame similarity statistics for a video file.

    Every frame is compared with the one before it using SSIM, PSNR and the
    differences between their perceptual and wavelet hashes.

    Args:
        video_path: Path of the video to analyse.

    Returns:
        A tuple ``(mean_ssim, mean_psnr, mean_phash_diff, mean_whash_diff)``,
        or ``None`` when the video has fewer than two frames and there is
        nothing to compare. The PSNR mean is ``inf`` if any two consecutive
        frames are identical.
    """
    frame_generator = read_frames(video_path)
    prev_frame = next(frame_generator, None)
    if prev_frame is None:
        return None
    prev_phash = calculate_phash(prev_frame)
    prev_whash = calculate_whash(prev_frame)

    ssims = []
    psnrs = []
    phashs = []
    whashs = []
    for current_frame in frame_generator:
        ssims.append(calculate_frame_ssim(prev_frame, current_frame))
        psnrs.append(calculate_psnr(prev_frame, current_frame))

        current_phash = calculate_phash(current_frame)
        current_whash = calculate_whash(current_frame)
        phashs.append(current_phash - prev_phash)
        whashs.append(current_whash - prev_whash)

        prev_frame = current_frame
        prev_phash = current_phash
        prev_whash = current_whash

    if not ssims:
        # Single-frame video: no pairs, so avoid dividing by zero and use
        # the same "no result" signal as the empty-video case.
        return None

    pair_count = len(ssims)
    return (
        sum(ssims) / pair_count,
        sum(psnrs) / pair_count,
        sum(phashs) / pair_count,
        sum(whashs) / pair_count,
    )


def stat_ssim(video_data, topk=0.1):
    """Average SSIM and perceptual-hash difference between consecutive frames.

    Args:
        video_data: Indexable, sliceable sequence of frames already in
            memory (frames are passed to ``calculate_frame_ssim`` and
            ``calculate_phash``).
        topk: Accepted for signature compatibility; currently not used.

    Returns:
        A tuple ``(mean_ssim, mean_phash_diff)`` over all consecutive pairs.

    Raises:
        IndexError: If ``video_data`` is empty.
        ZeroDivisionError: If ``video_data`` holds a single frame.
    """
    prev_frame = video_data[0]
    prev_phash = calculate_phash(prev_frame)

    ssims = []
    phashs = []
    for current_frame in video_data[1:]:
        ssims.append(calculate_frame_ssim(prev_frame, current_frame))

        current_phash = calculate_phash(current_frame)
        phashs.append(current_phash - prev_phash)

        prev_frame = current_frame
        prev_phash = current_phash
    return sum(ssims) / len(ssims), sum(phashs) / len(phashs)


def stat_phash(video_data, topk=None):
    """Average perceptual-hash difference between consecutive frames.

    With ``topk=None`` the mean is taken over every consecutive pair.
    Otherwise only the ``topk`` largest differences are averaged.
    """
    previous_hash = calculate_phash(video_data[0])

    distances = []
    for frame in video_data[1:]:
        frame_hash = calculate_phash(frame)
        distances.append(frame_hash - previous_hash)
        previous_hash = frame_hash

    if topk is not None:
        largest = sorted(distances, reverse=True)[:topk]
        return sum(largest) / len(largest)
    return sum(distances) / len(distances)


import piqa

def cal_ssim_dist(video_data, org_videos, video_names, topk=0.1):
    """Compute per-video SSIM, MS-SSIM and pHash statistics between frames.

    For each video, frame ``t`` is compared with frame ``t + 1``.

    * ``topk is None``: SSIM and MS-SSIM are the piqa scores over all pairs
      with inputs clamped to [0, 1], and the pHash value is the mean
      difference over all pairs.
      NOTE(review): this branch returns similarities while the top-k branch
      returns ``1 - similarity`` -- confirm this is intended.
    * otherwise: ``topk`` is a fraction of the frame count. The
      ``ceil(frames * topk)`` largest values of ``1 - SSIM`` and
      ``1 - MS-SSIM`` (the most dissimilar pairs) are averaged, and the same
      number of largest pHash differences is averaged.

    Args:
        video_data: Sequence of tensors, one per video, frames on the first
            axis (assumes values in [0, 1] -- TODO confirm).
        org_videos: Sequence of tensors in the same order, converted to
            NumPy and hashed frame by frame.
        video_names: Not used by this function.
        topk: Fraction of frame pairs to keep, or ``None`` for all pairs.

    Returns:
        Three lists ``(ssims, msssims, phashs)`` with one entry per video.
    """
    ssims = []
    msssims = []
    phashs = []

    if topk is None:
        # The piqa metrics are only needed by the all-pairs branch.
        device = video_data[0].device
        piqa_ssim = piqa.SSIM().to(device)
        piqa_msssim = piqa.MS_SSIM().to(device)

    for video, org_video in zip(video_data, org_videos):
        earlier, later = video[:-1], video[1:]
        if topk is None:
            phash = stat_phash(org_video.cpu().numpy())
            ssim_value = piqa_ssim(earlier.clamp(0, 1), later.clamp(0, 1)).item()
            msssim_value = piqa_msssim(earlier.clamp(0, 1), later.clamp(0, 1)).item()
        else:
            # There are only frames - 1 pairs; asking topk() for more raises.
            L = min(int(math.ceil(video.shape[0] * topk)), video.shape[0] - 1)
            phash = stat_phash(org_video.cpu().numpy(), L)
            ssim_dist = 1 - torch_ssim(earlier, later, data_range=1.0, size_average=False)
            # Parenthesised like the SSIM line: take the top-k of the
            # distances, not of the similarities.
            msssim_dist = 1 - ms_ssim(earlier, later, data_range=1.0, size_average=False)
            ssim_value = ssim_dist.topk(L, largest=True)[0].mean().item()
            msssim_value = msssim_dist.topk(L, largest=True)[0].mean().item()
        ssims.append(ssim_value)
        msssims.append(msssim_value)
        phashs.append(phash)

    return ssims, msssims, phashs


if __name__ == "__main__":
    # Walk a folder of .mp4 files, compute inter-frame statistics for each
    # one and write them to a CSV file.
    video_folder = '/path/to/video'
    key_frame_folder = '.temp_key_frame_folder'
    output_csv =  'ssim_inter_frames.csv'
    scores = []
    for file in tqdm(os.listdir(video_folder)):
        # Skip non-mp4 files and "._" files.
        if (not file.endswith('.mp4')) or file.startswith('._'):
            continue
        video_path = os.path.join(video_folder, file)
        # stat_ssim_with_path is the variant that takes a file path and
        # returns the four statistics unpacked below; stat_ssim expects
        # in-memory frames and returns only two values.
        stats = stat_ssim_with_path(video_path)
        if stats is None:
            # No statistics could be computed for this video.
            continue
        ssim_score, psnr, phash, whash = stats
        # NOTE(review): the 'phase' column holds the pHash value; the name
        # is kept as-is so existing CSV consumers keep working.
        scores.append({'Video File': video_path, 'ssim': ssim_score, 'psnr': psnr, 'phase': phash, 'whash': whash})
    df = pd.DataFrame(scores)
    df.to_csv(output_csv, index=False)
