import torch
from torch.nn import functional as F
import torchvision.transforms.functional as TVF
from PIL import Image
import numpy as np
from glob import glob
from transformers import CLIPProcessor, CLIPModel
import argparse
import os
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure
import cv2 as cv
import ptlflow
from ptlflow.utils.io_adapter import IOAdapter
from ptlflow.utils import flow_utils
import matplotlib.pyplot as plt


@torch.no_grad()
def estimate_flow(frameA, frameB):
    """Predict optical flow from ``frameA`` to ``frameB`` with the RAFT model.

    Uses the module-level ``raftnet`` model and ``device`` string, which are
    created in this script's ``__main__`` section.

    Args:
        frameA: First image (anything ``TVF.to_tensor`` accepts, e.g. PIL).
        frameB: Second image; expected to match ``frameA``'s size.

    Returns:
        NumPy array holding the first flow prediction from the model's
        ``'flows'`` output, with the channel axis moved last.
    """
    def _to_hwc_bgr(frame):
        # to_tensor gives a CHW tensor in RGB order; reverse the channels to
        # BGR and move them last, which is the layout handed to IOAdapter.
        chw = TVF.to_tensor(frame)
        return chw[[2, 1, 0], :, :].permute(1, 2, 0)

    images = [_to_hwc_bgr(frameA), _to_hwc_bgr(frameB)]
    adapter = IOAdapter(raftnet, images[0].shape[:2])
    model_inputs = {
        name: tensor.to(device)
        for name, tensor in adapter.prepare_inputs(images).items()
    }
    first_flow = raftnet(model_inputs)['flows'][0, 0]
    return first_flow.cpu().detach().permute(1, 2, 0).numpy()


def _natural_sort_key(path):
    """Return a sort key that orders digit runs numerically ("f2" < "f10")."""
    import re

    # re.split with a capturing group alternates non-digit / digit chunks, so
    # odd indices are always digit runs. Keeping that alignment means two keys
    # only ever compare str with str and int with int.
    chunks = re.split(r"(\d+)", path)
    return [int(chunk) if idx % 2 else chunk for idx, chunk in enumerate(chunks)]


def _load_rgb_frames(paths):
    """Load each image path as an RGB PIL image and close the file handle."""
    frames = []
    for path in paths:
        with Image.open(path) as img:
            # convert() returns a fully loaded copy, so the file can close;
            # forcing RGB keeps RGBA/greyscale PNGs at 3 channels for LPIPS/SSIM.
            frames.append(img.convert("RGB"))
    return frames


def _backward_warp(image, flow, device):
    """Sample ``image`` at (pixel + flow) to warp it back onto the source grid.

    Args:
        image: Tensor of shape (1, C, H, W) — the frame the flow points into.
        flow: Array of shape (H, W, 2) of displacements, in the same (x, y)
            order as the pixel grid built below.
        device: Torch device for the sampling grid.

    Returns:
        Tensor of shape (1, C, H, W); samples falling outside are zero.
    """
    height, width = flow.shape[:2]
    xs, ys = np.meshgrid(np.arange(width), np.arange(height))
    coords = np.stack([xs, ys], axis=-1) + flow
    # With align_corners=True, grid_sample maps -1/+1 to the centres of the
    # first/last pixels, so normalise by (size - 1) per axis. The previous
    # hard-coded 512 only approximately fit 512x512 inputs.
    scale = np.array([max(width - 1, 1), max(height - 1, 1)], dtype=np.float32)
    grid = torch.from_numpy((coords / scale * 2 - 1).astype(np.float32)).to(device)
    return F.grid_sample(
        image,
        grid.unsqueeze(0),
        mode="bilinear",
        padding_mode="zeros",
        align_corners=True,
    )


@torch.no_grad()
def compute_frame_consistency(base_folder, regex_pattern, device):
    """Score temporal consistency between consecutive frames in a folder.

    For each consecutive pair (A, B) the following are recorded:
      * ``cosine``: cosine similarity of CLIP image embeddings;
      * ``L1pixel``: mean absolute pixel difference (a distance, lower = closer);
      * ``lpips`` / ``ssim``: LPIPS and MS-SSIM between A and B;
      * ``lpips_flow`` / ``ssim_flow``: the same, between A and B warped back
        onto A using optical flow from ``estimate_flow``;
      * ``cosine_weighted`` / ``lpips_flow_weighted`` / ``ssim_flow_weighted``:
        the corresponding metric multiplied by the pair's mean flow magnitude.

    Uses the module-level ``CLIP_model``, ``processor``, ``lpips_model`` and
    ``ssim`` objects created in this script's ``__main__`` section.

    Args:
        base_folder: Directory containing the frames; ``~`` is expanded.
        regex_pattern: Despite the name, a ``glob`` pattern (e.g. ``"*.png"``).
            Matches are ordered with numeric-aware sorting so unpadded frame
            numbers (``2``, ``10``) stay in temporal order.
        device: Torch device the models live on.

    Returns:
        Dict mapping each metric name to its mean over all pairs, rounded to
        4 decimals.

    Raises:
        ValueError: If fewer than two frames match, so no pair can be scored.
    """
    # glob() does not expand "~" itself; without this, "~/..." matches nothing.
    folder = os.path.expanduser(base_folder)
    paths = sorted(glob(os.path.join(folder, regex_pattern)), key=_natural_sort_key)
    if len(paths) < 2:
        raise ValueError(
            f"Need at least two frames matching {regex_pattern!r} in "
            f"{folder!r}, found {len(paths)}."
        )
    frames = _load_rgb_frames(paths)

    all_similarities = {
        "cosine": [],
        "L1pixel": [],
        "lpips": [],
        "ssim": [],
        "lpips_flow": [],
        "ssim_flow": [],
        "cosine_weighted": [],
        "lpips_flow_weighted": [],
        "ssim_flow_weighted": [],
    }
    for frameA, frameB in zip(frames[:-1], frames[1:]):
        pair = [frameA, frameB]
        img_a, img_b = (TVF.to_tensor(frame).unsqueeze(0).to(device) for frame in pair)

        # CLIP cosine similarity between the two image embeddings.
        clip_inputs = processor(images=pair, return_tensors="pt", padding=True).to(device)
        emb_a, emb_b = CLIP_model.get_image_features(**clip_inputs).unbind(0)
        metric_cosine_clip = F.cosine_similarity(emb_a, emb_b, dim=-1).mean().item()

        # Direct comparisons of A and B. LPIPS inputs are mapped from
        # to_tensor's [0, 1] range to [-1, 1].
        metric_L1_pixels = F.l1_loss(img_a, img_b).item()
        metric_lpips = lpips_model(img_a * 2 - 1, img_b * 2 - 1).mean().item()
        metric_ssim = ssim(img_a, img_b).mean().item()

        # Flow-compensated comparisons: warp B back onto A, then compare.
        flow = estimate_flow(frameA, frameB)
        frameA_hat = _backward_warp(img_b, flow, device)
        metric_lpips_flow = lpips_model(img_a * 2 - 1, frameA_hat * 2 - 1).mean().item()
        metric_ssim_flow = ssim(img_a, frameA_hat).mean().item()

        # Mean per-pixel flow magnitude, used to weight the *_weighted metrics.
        flow_magnitude = float(np.linalg.norm(flow, axis=-1).mean())

        all_similarities["cosine"].append(metric_cosine_clip)
        all_similarities["L1pixel"].append(metric_L1_pixels)
        all_similarities["lpips"].append(metric_lpips)
        all_similarities["ssim"].append(metric_ssim)
        all_similarities["lpips_flow"].append(metric_lpips_flow)
        all_similarities["ssim_flow"].append(metric_ssim_flow)
        all_similarities["cosine_weighted"].append(metric_cosine_clip * flow_magnitude)
        all_similarities["lpips_flow_weighted"].append(metric_lpips_flow * flow_magnitude)
        all_similarities["ssim_flow_weighted"].append(metric_ssim_flow * flow_magnitude)

    # Plain floats (not np.float64) so the printed dict stays readable.
    return {
        name: round(float(np.mean(values)), 4)
        for name, values in all_similarities.items()
    }


if __name__ == "__main__":
    # Everything below is bound at module level on purpose, not inside a
    # main() function: estimate_flow and compute_frame_consistency read
    # `device`, `CLIP_model`, `processor`, `lpips_model`, `ssim` and `raftnet`
    # as globals.

    # Keep the original cuda:2 default when that GPU exists; otherwise fall
    # back instead of crashing on machines with fewer GPUs.
    if torch.cuda.device_count() > 2:
        default_device = "cuda:2"
    elif torch.cuda.is_available():
        default_device = "cuda"
    else:
        default_device = "cpu"

    parser = argparse.ArgumentParser(
        description="Measure frame-to-frame consistency of generated videos."
    )
    parser.add_argument(
        "--device",
        default=default_device,
        help=f"Torch device to run all models on (default: {default_device}).",
    )
    args = parser.parse_args()
    device = args.device

    CLIP_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    CLIP_model = CLIP_model.to(device)

    lpips_model = LearnedPerceptualImagePatchSimilarity(reduction="mean").to(device)
    lpips_model.eval()

    ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    raftnet = ptlflow.get_model('raft_small', 'things')
    raftnet = raftnet.eval().to(device)

    # method label -> (glob pattern for its frames, folders to evaluate)
    experiments = {
        "Text-2-Video-Zero": (
            "*.png",
            [
                "~/Text2Video-Zero/test-output/dragons",
                "~/Text2Video-Zero/test-output/meltingman",
                "~/Text2Video-Zero/test-output/satellite",
                "~/Text2Video-Zero/test-output/earth",
                "~/Text2Video-Zero/test-output/birds",
            ],
        ),
        "Ours": (
            "*_2.png",
            [
                "~/MyDiffusers/test-output/video/final/dragons",
                "~/MyDiffusers/test-output/video/final/meltingman",
                "~/MyDiffusers/test-output/video/final/satellite",
                "~/MyDiffusers/test-output/video/final/earth",
                "~/MyDiffusers/test-output/video/final/birds",
            ],
        ),
    }

    for method, (pattern, folders) in experiments.items():
        for folder in folders:
            # glob() does not expand "~", so expand it before the lookup.
            folder = os.path.expanduser(folder)
            metrics = compute_frame_consistency(
                base_folder=folder, regex_pattern=pattern, device=device
            )
            # Label each result so the output is attributable.
            print(f"[{method}] {folder}: {metrics}")