from MyDiffusers import StableDiffusionManager
import torch
from torch import nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import save_image
import torchvision.transforms.functional as TF
import os
from copy import deepcopy
from MyDiffusers import *
from flowsrepo import *
import time
from PIL import Image
import numpy as np
import cv2
from diffusers import StableDiffusionPipeline
import argparse
from glob import glob
import pandas as pd
import plotly.express as px
import seaborn as sns

# Device string handed to StableDiffusionManager when the model is created.
device = 'cuda:3'

# Farneback parameters that compute_flow forwards to
# cv2.calcOpticalFlowFarneback for inputs whose second-to-last dimension
# is not 64 (the full-resolution grayscale frames).
image_args = dict(
    pyr_scale=0.5,
    levels=3,
    winsize=15,
    iterations=3,
    poly_n=5,
    poly_sigma=1.2,
)

# Farneback parameters that compute_flow uses instead when the input's
# second-to-last dimension is 64 (the latent channel): a single pyramid
# level and a much smaller window than for the full-size frames.
latent_args = dict(
    pyr_scale=0.6,
    levels=1,
    winsize=6,
    iterations=2,
    poly_n=4,
    poly_sigma=2.5,
)

def compute_flow(frameA : np.ndarray, frameB : np.ndarray):
    """Estimate the Farneback optical flow from ``frameA`` to ``frameB``.

    The parameter set is chosen from the shape of ``frameA``: when its
    second-to-last dimension equals 64 the input is treated as a latent
    channel and ``latent_args`` is used, otherwise ``image_args``.

    Returns whatever ``cv2.calcOpticalFlowFarneback`` returns for the pair.
    """
    looks_like_latent = frameA.shape[-2] == 64
    if looks_like_latent:
        farneback_params = latent_args
    else:
        farneback_params = image_args

    # No initial flow estimate and no special flags are supplied.
    return cv2.calcOpticalFlowFarneback(
        prev=frameA,
        next=frameB,
        flow=None,
        flags=0,
        **farneback_params,
    )

@torch.no_grad()
def compute_score(frameA : Image.Image, frameB : Image.Image):
    """Compare the optical flow of two frames in image space and latent space.

    The image-space flow is estimated on the grayscale frames. For the
    latent-space flow both frames are encoded with the module-level ``SDM``
    manager, partially inverted (empty prompt, no guidance, 100 steps), and
    the flow is estimated on channel 1 of the last recorded latents after a
    joint min-max rescaling to ``uint8``. The latent flow is then bilinearly
    upsampled to the image-flow resolution and both fields are compared with
    a per-pixel cosine similarity in homogeneous coordinates.

    Args:
        frameA: First frame (RGB PIL image).
        frameB: Second frame (RGB PIL image), same size as ``frameA``.

    Returns:
        A tuple ``(score, extra)`` where ``score`` is the mean of the
        per-pixel cosine similarity and ``extra`` is a dict with the keys
        ``'flow_image'``, ``'flow_latent'`` (upsampled), ``'cos_sim'``,
        ``'frameA_latent'`` and ``'frameB_latent'`` (the ``uint8`` latent
        channels the flow was estimated on).

    Note:
        Relies on the global ``SDM``, which the script only creates when no
        cached results file exists.
    """
    # Image space: Farneback works on single-channel images.
    frameA_gray = cv2.cvtColor(np.array(frameA), cv2.COLOR_RGB2GRAY)
    frameB_gray = cv2.cvtColor(np.array(frameB), cv2.COLOR_RGB2GRAY)
    flow_image = compute_flow(frameA_gray, frameB_gray)

    # Latent space: encode both frames and invert them together as one batch.
    frameA_latent = SDM.image_to_latent(frameA)
    frameB_latent = SDM.image_to_latent(frameB)
    _, data = SDM.partial_inversion(
        z = torch.cat([frameA_latent, frameB_latent], dim=0),
        prompt='',
        guidance_scale=0.0,
        num_inference_steps=100,
    )
    # Last recorded latents; batch index 0 is frame A, index 1 is frame B.
    frameA_latent = data['latents'][-1,0].unsqueeze(0)
    frameB_latent = data['latents'][-1,1].unsqueeze(0)

    # Joint min-max normalisation to [0, 255] so both frames share one
    # intensity scale. The range is clamped away from zero so constant
    # latents yield zeros instead of NaNs.
    min_val = min(frameA_latent.min(), frameB_latent.min())
    max_val = max(frameA_latent.max(), frameB_latent.max())
    value_range = (max_val - min_val).clamp_min(torch.finfo(min_val.dtype).eps)

    def to_uint8(latent):
        normalised = (latent - min_val) / value_range
        return (normalised.cpu().detach() * 255).numpy().astype(np.uint8)

    # Choose the second channel for latent space flow estimation
    frameA_latent = to_uint8(frameA_latent)[0,1]
    frameB_latent = to_uint8(frameB_latent)[0,1]

    flow_latent = compute_flow(frameA_latent, frameB_latent)
    # Upsample the latent flow to the image-flow resolution (cv2 expects the
    # target size as (width, height)).
    # NOTE(review): only the grid is upsampled; the vectors stay in latent
    # pixel units (the plotting code multiplies them by 8 separately). With
    # the homogeneous 1 appended below, the similarity is sensitive to that
    # magnitude -- confirm this is intended.
    height, width = flow_image.shape[:2]
    flow_latent = cv2.resize(flow_latent, (width, height), interpolation=cv2.INTER_LINEAR)

    # Express the flows in homogeneous coordinates; the appended 1 keeps the
    # norms >= 1, so the division below is always well defined.
    tmp_flow_image = np.concatenate([flow_image, np.ones_like(flow_image[...,0:1])], axis=-1)
    tmp_flow_latent = np.concatenate([flow_latent, np.ones_like(flow_latent[...,0:1])], axis=-1)

    # Per-pixel cosine similarity between the two flows
    dot = (tmp_flow_image * tmp_flow_latent).sum(-1)
    norms = np.linalg.norm(tmp_flow_image, axis=-1) * np.linalg.norm(tmp_flow_latent, axis=-1)
    cos_sim = dot / norms

    extra = {
        'flow_image' : flow_image,
        'flow_latent' : flow_latent,
        'cos_sim' : cos_sim,
        'frameA_latent' : frameA_latent,
        'frameB_latent' : frameB_latent,
    }

    return cos_sim.mean(), extra

def preprocess_image(image):
    """Center-crop ``image`` to a square and resize it.

    The crop side is the smaller of the two values in ``image.size``; the
    cropped result is then passed to ``TF.resize`` with a target size of 512.
    """
    shortest_side = min(image.size)
    square = TF.center_crop(image, shortest_side)
    return TF.resize(square, 512)

def flow_to_rgb(flow):
    """Colour-code a flow field as an RGB ``uint8`` image.

    The flow direction is mapped to hue, saturation is fixed at its maximum
    and the flow magnitude (times 10) is mapped to brightness.

    Args:
        flow: Array whose last axis holds the x and y flow components
            (indices 0 and 1); the first two axes are the image grid.

    Returns:
        ``uint8`` array with the same grid size and 3 channels in RGB order,
        suitable for ``PIL.Image.fromarray``.
    """
    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])

    hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
    # OpenCV stores 8-bit hue as degrees / 2.
    hsv[..., 0] = ang * 180 / np.pi / 2
    hsv[..., 1] = 255
    # Fixed gain instead of per-image normalisation, so brightness is
    # comparable across images. Clip so large magnitudes saturate at full
    # brightness rather than overflowing the uint8 channel.
    hsv[..., 2] = np.clip(mag * 10, 0, 255)

    # Convert to RGB (not BGR): callers hand the result to PIL.
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)


# Spacing (in pixels) between the arrows drawn on the quiver overlays.
_QUIVER_STRIDE = 32
# Quiver scale used with scale_units='xy'; 1 draws arrows at their data length.
_QUIVER_SCALE = 1


def _load_frames(video_dir):
    """Load and preprocess every file in ``video_dir``, in sorted order."""
    frame_paths = sorted(glob(video_dir + '/*'))
    print(frame_paths)
    frames = []
    for path in frame_paths:
        # preprocess_image crops/resizes into a new in-memory image, so the
        # source file can be closed as soon as it has been processed.
        with Image.open(path) as raw:
            frames.append(preprocess_image(raw))
    return frames


def _show_frame(ax, frame, flow=None, **imshow_kwargs):
    """Draw ``frame`` on ``ax`` and optionally overlay ``flow`` as red arrows."""
    ax.imshow(frame, vmin=0, vmax=255, interpolation=None, **imshow_kwargs)
    if flow is not None:
        ax.quiver(
            np.arange(0, flow.shape[1], _QUIVER_STRIDE),
            np.arange(0, flow.shape[0], _QUIVER_STRIDE),
            flow[::_QUIVER_STRIDE, ::_QUIVER_STRIDE, 0],
            flow[::_QUIVER_STRIDE, ::_QUIVER_STRIDE, 1],
            color='r',
            scale=_QUIVER_SCALE,
            scale_units='xy',
            angles='xy',
            width=0.006,
        )
    ax.set_axis_off()
    ax.set_aspect(1)


def _save_pair_outputs(out_dir, n, frameA, frameB, extra):
    """Write the comparison figure and colour-coded flows for frame pair ``n``.

    Produces ``{n}_comparison.png``, ``{n}_flow_image.png`` and
    ``{n}_flow_latent.png`` inside ``out_dir`` (created if missing).
    ``extra`` is the dict returned by ``compute_score``.
    """
    os.makedirs(out_dir, exist_ok=True)

    flow_image = extra['flow_image']
    # Scale the latent flow vectors by 8 for display.
    # NOTE(review): presumably the latent-to-image resolution ratio -- confirm.
    flow_latent = extra['flow_latent']*8

    # Blow the latent channels up to the flow resolution without smoothing.
    display_size = (flow_image.shape[1], flow_image.shape[0])
    frameA_latent = cv2.resize(extra['frameA_latent'], display_size, interpolation=cv2.INTER_NEAREST)
    frameB_latent = cv2.resize(extra['frameB_latent'], display_size, interpolation=cv2.INTER_NEAREST)

    # Hide the similarity wherever the product of both flow magnitudes is
    # small; the flow direction is not meaningful there.
    magnitude_product = (
        np.linalg.norm(flow_image, ord=2, axis=-1)
        * np.linalg.norm(flow_latent, ord=2, axis=-1)
    )
    simmap = np.ma.masked_where(magnitude_product < 0.5, extra['cos_sim'])
    print(type(simmap), simmap.shape, simmap.min(), simmap.max())

    # Five equally wide panels plus a narrow axis for the colour bar.
    gridspec = {'width_ratios': [1, 1, 1, 1, 1, 0.1]}
    fig, ax = plt.subplots(1, 6, figsize=(5*5, 5), gridspec_kw=gridspec)

    _show_frame(ax[0], frameA)
    _show_frame(ax[1], frameB, flow=flow_image)
    _show_frame(ax[2], frameA_latent, cmap='gray')
    _show_frame(ax[3], frameB_latent, flow=flow_latent, cmap='gray')

    # Similarity map with the second frame blended on top for context.
    im = ax[4].imshow(simmap, vmin=-1, vmax=1, interpolation=None, cmap='RdYlGn')
    ax[4].imshow(frameB, vmin=0, vmax=255, interpolation=None, alpha=0.5, cmap='gray')
    ax[4].set_axis_off()
    ax[4].set_aspect(1)

    fig.colorbar(im, cax=ax[5])
    fig.tight_layout()
    fig.savefig(f'{out_dir}/{n}_comparison.png')
    plt.close(fig)

    Image.fromarray(flow_to_rgb(flow_image)).save(f'{out_dir}/{n}_flow_image.png')
    Image.fromarray(flow_to_rgb(flow_latent)).save(f'{out_dir}/{n}_flow_latent.png')


# Reuse cached scores when available; otherwise score every consecutive frame
# pair of every video directory and cache the result as CSV.
if os.path.exists('correlation_results.csv'):
    df = pd.read_csv('correlation_results.csv')
else:
    # compute_score reads this module-level manager.
    SDM = StableDiffusionManager(device=device, tau=400)
    rows = []
    # Alternative dataset: 'video_dataset/frames/*'
    for video in sorted(glob('video_dataset/8x_interpolation/*')):
        if not os.path.isdir(video):
            continue
        all_frames = _load_frames(video)
        # Directory names are expected to be integers (the video id).
        video_id = int(os.path.basename(video))
        for n, (frameA, frameB) in enumerate(zip(all_frames[:-1], all_frames[1:])):
            corr, extra = compute_score(frameA, frameB)
            rows.append({
                'video' : video_id,
                'frame' : n,
                'correlation' : corr
            })
            _save_pair_outputs(f'flows/{video}', n, frameA, frameB, extra)

    df = pd.DataFrame(rows)
    df.to_csv('correlation_results.csv', index=False)


print('Average correlation:', df['correlation'].mean())
sns.violinplot(x='video', y='correlation', data=df)
plt.savefig('correlation_violin.png')
plt.close()
