import os
import torch
import numpy as np
from PIL import Image
import lpips
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import glob
from tqdm import tqdm
import cv2
import imageio
import json

def load_images_from_folder(folder_path):
    """Load every ``*.png`` file found directly in a folder, in sorted path order.

    Args:
        folder_path: Directory searched (non-recursively) for ``*.png`` files.

    Returns:
        ``np.array`` built from the list of decoded images. When no file
        matches, this is an empty array of shape ``(0,)``.
    """
    image_files = sorted(glob.glob(os.path.join(folder_path, "*.png")))
    images = []
    for img_path in image_files:
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the array, instead of relying on GC.
        with Image.open(img_path) as img:
            images.append(np.array(img))
    return np.array(images)

def load_video_frames(video_path, height, width):
    """Decode all frames of a video as RGB arrays resized to ``height`` x ``width``.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        height: Target frame height in pixels.
        width: Target frame width in pixels.

    Returns:
        ``np.array`` of the collected frames. It is empty when no frame could
        be read, which includes the case where the capture never opened.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; the rest of the pipeline works in RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # cv2.resize expects the target size as (width, height).
            frame = cv2.resize(frame, (width, height))
            frames.append(frame)
    finally:
        # Release the capture even if colour conversion or resizing raises.
        cap.release()
    return np.array(frames)

def calculate_psnr(img1, img2):
    """Return the peak signal-to-noise ratio of two images on a 0-255 scale."""
    psnr = peak_signal_noise_ratio(img1, img2, data_range=255)
    return psnr

def calculate_psnr_star(img1, img2):
    """Calculate PSNR after matching the prediction's mean brightness to the ground truth.

    The prediction is multiplied by ``mean(gt) / mean(pred)`` (means taken over
    all pixels and channels of the values scaled by 1/255), clamped to [0, 1],
    and both images are converted back to uint8 before PSNR is computed.

    Args:
        img1: Ground-truth image, channels-last (H, W, C), expected in 0-255.
        img2: Predicted image with the same layout as ``img1``.

    Returns:
        PSNR (data range 255) between the ground truth and the rescaled
        prediction.
    """
    pred = torch.from_numpy(img2 / 255.0).permute(2, 0, 1).unsqueeze(0)
    gt = torch.from_numpy(img1 / 255.0).permute(2, 0, 1).unsqueeze(0)
    gt_mean = gt.contiguous().view(gt.shape[0], -1).mean(dim=1)
    pred_mean = pred.contiguous().view(pred.shape[0], -1).mean(dim=1)

    # An all-black prediction has zero mean. Dividing by it would give inf/NaN,
    # and casting NaN to uint8 below is undefined, so such a prediction is left
    # unscaled (ratio 1) instead.
    has_signal = pred_mean != 0
    safe_pred_mean = torch.where(has_signal, pred_mean, torch.ones_like(pred_mean))
    ratio = torch.where(has_signal, gt_mean / safe_pred_mean, torch.ones_like(pred_mean))

    # One ratio per batch entry, broadcast over channels, height and width.
    pred = torch.clamp(pred * ratio.view(-1, 1, 1, 1), 0, 1)

    # Scale to 0-255 *before* casting: casting the [0, 1] floats to uint8 first
    # would zero out almost every pixel of both images.
    # NOTE: .byte() drops the fractional part (no rounding); kept unchanged so
    # scores stay comparable with earlier runs.
    pred = pred.squeeze(0).mul(255).byte().permute(1, 2, 0).numpy()
    gt = gt.squeeze(0).mul(255).byte().permute(1, 2, 0).numpy()
    return peak_signal_noise_ratio(gt, pred, data_range=255)

def calculate_ssim(img1, img2):
    """Return the structural similarity of two channels-last images on a 0-255 scale."""
    ssim = structural_similarity(
        img1,
        img2,
        channel_axis=2,
        data_range=255,
    )
    return ssim

def calculate_lpips(img1, img2, loss_fn):
    """Calculate the LPIPS distance between two images.

    Args:
        img1: First image as an (H, W, C) array, expected in 0-255.
        img2: Second image with the same layout as ``img1``.
        loss_fn: Callable that takes two (1, C, H, W) float tensors scaled to
            [-1, 1] and returns a single-element tensor (e.g. ``lpips.LPIPS``).

    Returns:
        The distance as a Python float.
    """
    # Run on whatever device holds the model weights, so a CPU model on a
    # CUDA-capable machine does not hit a device mismatch. Plain callables
    # without parameters keep the previous rule: CUDA when available.
    parameters = getattr(loss_fn, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def to_model_input(img):
        # HWC array -> 1xCxHxW float tensor, mapped from [0, 255] to [-1, 1].
        tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()
        return (tensor / 127.5 - 1).to(device)

    with torch.no_grad():
        lpips_value = loss_fn(to_model_input(img1), to_model_input(img2))

    return lpips_value.item()

def evaluate_metrics(gt_path, generated_video_path, save_path=None, visualize=False):
    """
    Evaluate PSNR, PSNR*, SSIM, and LPIPS between ground truth frames and a generated video.

    Args:
        gt_path: Path to the ground truth frames folder
        generated_video_path: Path to the generated video file
        save_path: Optional path to save metrics results; its extension is
            replaced with ``.json``
        visualize: If True, additionally write a side-by-side video
            (generated on the left, ground truth on the right) next to the
            JSON file, using the same base name with ``.mp4``. Requires
            ``save_path``.

    Returns:
        Dict with the keys "PSNR", "PSNR_STAR", "SSIM" and "LPIPS", each
        holding the mean of the per-frame values.

    Raises:
        ValueError: If ``visualize`` is set without ``save_path``, if no
            ground truth frame is found, or if the number of ground truth
            frames differs from the number of generated frames.
    """
    # Fail before the expensive metric loop rather than after it.
    if visualize and not save_path:
        raise ValueError("visualize=True requires save_path to derive the output video path")

    # Load images
    gt_frames = load_images_from_folder(gt_path)
    if len(gt_frames) == 0:
        raise ValueError(f"No ground truth frames found in {gt_path}")
    height, width = gt_frames[0].shape[:2]
    generated_frames = load_video_frames(generated_video_path, height, width)

    # Explicit check (not `assert`, which is stripped under -O): zip() below
    # would otherwise silently truncate to the shorter sequence.
    if len(gt_frames) != len(generated_frames):
        raise ValueError(
            f"Frame count mismatch: {len(gt_frames)} ground truth frames in {gt_path} "
            f"vs {len(generated_frames)} frames in {generated_video_path}"
        )

    # Initialize LPIPS on the GPU when one is present, otherwise on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    loss_fn = lpips.LPIPS(net='vgg').to(device)

    # Initialize metrics storage
    psnr_values = []
    ssim_values = []
    lpips_values = []
    psnr_star_values = []

    # Side-by-side frames are only kept when a video will actually be written.
    concat_frames = []

    # Calculate metrics for each frame
    for gt, gen in tqdm(zip(gt_frames, generated_frames), total=len(gt_frames), desc="Calculating metrics"):
        psnr_values.append(calculate_psnr(gt, gen))
        psnr_star_values.append(calculate_psnr_star(gt, gen))
        ssim_values.append(calculate_ssim(gt, gen))
        lpips_values.append(calculate_lpips(gt, gen, loss_fn))

        if visualize:
            concat_frames.append(np.concatenate([gen, gt], axis=1))

    # Calculate mean values
    metrics_dict = {
        "PSNR": np.mean(psnr_values),
        "PSNR_STAR": np.mean(psnr_star_values),
        "SSIM": np.mean(ssim_values),
        "LPIPS": np.mean(lpips_values),
    }

    if save_path:
        # Both outputs share one base name; only the extension differs.
        base_path = os.path.splitext(save_path)[0]
        output_dir = os.path.dirname(base_path)
        # dirname is '' for a bare filename, and os.makedirs('') would raise.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        if visualize:
            with imageio.get_writer(base_path + '.mp4', fps=24) as writer:
                for frame in concat_frames:
                    writer.append_data(frame)

        # Format the values to 4 decimal places
        formatted_metrics = {k: float(f"{v:.4f}") for k, v in metrics_dict.items()}

        with open(base_path + '.json', 'w') as f:
            json.dump(formatted_metrics, f, indent=4)

    return metrics_dict

if __name__ == "__main__":
    # Example usage: evaluate every test sequence against its generated video
    # and store the metrics averaged over all sequences.
    evaluation_dir = "./evaluations"
    gt_path = './data/V-SDE/test'
    save_path = os.path.join(evaluation_dir, "metrics_results.json")

    # Each sub-directory of gt_path is one sequence; stray files (e.g. OS
    # metadata) are skipped instead of being evaluated as sequences.
    test_seq = sorted(
        seq for seq in os.listdir(gt_path)
        if os.path.isdir(os.path.join(gt_path, seq))
    )
    if not test_seq:
        raise SystemExit(f"No test sequences found in {gt_path}")

    all_metrics = []
    for seq in test_seq:
        seq_video_path = os.path.join(evaluation_dir, 'generated', seq+'_generated.mp4')
        seq_gt_path = os.path.join(gt_path, seq, 'normal')

        metrics = evaluate_metrics(seq_gt_path, seq_video_path)
        print(f"\nMetrics Results for {seq}:")
        for metric, value in metrics.items():
            print(f"{metric}: {value:.4f}")
        all_metrics.append(metrics)

    # average metrics
    avg_metrics = {k: np.mean([d[k] for d in all_metrics]) for k in all_metrics[0]}
    print("\nAverage Metrics:")
    for metric, value in avg_metrics.items():
        print(f"{metric}: {value:.4f}")

    # save average metrics
    with open(save_path, 'w') as f:
        json.dump(avg_metrics, f, indent=4)

