from sklearn.metrics.pairwise import cosine_similarity
from PIL import Image
import torch.nn.functional as F
import cv2
import imageio
import argparse
import sys
import torch
import clip
import warnings
import numpy as np
sys.path.append('/share/home/xiajunhao/MM/Experiment/Baselines/RAVE/RAFT')
sys.path.append('/share/home/xiajunhao/MM/Experiment/Baselines/RAVE/RAFT/core')
from core.raft import RAFT
from core.utils.utils import InputPadder
from skimage.metrics import structural_similarity
import lpips
from torchvision import transforms
import utils.clip_loss as cl
from utils.ZSSGAN_text_templates import imagenet_templates, part_templates, imagenet_templates_small

def check_pil_list_len(pil_list, source_pil_list):
    """Align two frame lists in length and in per-frame size.

    When the lists differ in length, or the edited list has more than 1000
    frames, both lists are sub-sampled with an integer stride and truncated
    to ``min(len(pil_list), len(source_pil_list), 800)`` frames. Afterwards,
    every frame pair whose ``size`` differs is resized (LANCZOS) to a common
    size.

    The caller's lists are never modified; new lists are returned.

    Args:
        pil_list: Edited frames (objects exposing ``size`` and ``resize``).
        source_pil_list: Source frames.

    Returns:
        Tuple ``(pil_list, source_pil_list)`` of two equally long lists.
        Both are empty when exactly one input is empty.
    """
    # Work on copies so the per-frame resize below cannot leak into the
    # caller's lists.
    pil_list = list(pil_list)
    source_pil_list = list(source_pil_list)

    if len(pil_list) != len(source_pil_list) or len(pil_list) > 1000:
        final_frame_len = min(len(pil_list), len(source_pil_list), 800)
        if final_frame_len == 0:
            # No frame can be paired; also avoids dividing by zero below.
            return [], []
        pil_stride = len(pil_list) // final_frame_len
        pil_list = pil_list[::pil_stride][:final_frame_len]
        source_stride = len(source_pil_list) // final_frame_len
        source_pil_list = source_pil_list[::source_stride][:final_frame_len]

    # Make sure each pair of frames shares the same size.
    for i, (pil_img, source_img) in enumerate(zip(pil_list, source_pil_list)):
        if pil_img.size == source_img.size:
            continue
        # NOTE: min() on (width, height) tuples is lexicographic, so the frame
        # with the smaller width wins (height only breaks ties).
        target_size = min(pil_img.size, source_img.size)
        if pil_img.size != target_size:
            pil_list[i] = pil_img.resize(target_size, Image.Resampling.LANCZOS)
        if source_img.size != target_size:
            source_pil_list[i] = source_img.resize(target_size, Image.Resampling.LANCZOS)
    return pil_list, source_pil_list

def video_to_pil_list(video_path):
    """Decode a ``.mp4`` or ``.gif`` file into a list of PIL images.

    Args:
        video_path: Path string; its extension (case-insensitive) selects
            the decoder (OpenCV for mp4, imageio for gif).

    Returns:
        List of ``PIL.Image`` frames in decoding order. mp4 frames are
        converted from OpenCV's BGR channel order to RGB.

    Raises:
        ValueError: If the extension is neither ``.mp4`` nor ``.gif``.
    """
    lowered = video_path.lower()

    if lowered.endswith('.mp4'):
        vidcap = cv2.VideoCapture(video_path)
        pil_list = []
        try:
            while True:
                success, image = vidcap.read()
                if not success:
                    break
                # OpenCV returns BGR; PIL interprets 3-channel arrays as RGB,
                # so swap the channels before wrapping the frame.
                pil_list.append(
                    Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
        finally:
            # Always free the decoder, even if a frame conversion fails.
            vidcap.release()
        return pil_list

    if lowered.endswith('.gif'):
        gif = imageio.get_reader(video_path)
        try:
            return [Image.fromarray(frame) for frame in gif]
        finally:
            gif.close()

    raise ValueError(
        f'Unsupported video format: {video_path!r} (expected .mp4 or .gif)')


def coords_grid(b, h, w, homogeneous=False, device=None):
    """Build a batch of pixel-coordinate grids.

    Args:
        b: Batch size (the grid is repeated ``b`` times).
        h: Grid height.
        w: Grid width.
        homogeneous: When true, append a channel of ones after (x, y).
        device: Optional device the result is moved to.

    Returns:
        Float tensor of shape [B, 2, H, W] (or [B, 3, H, W] when
        ``homogeneous``); channel 0 holds x (column) indices and channel 1
        holds y (row) indices.
    """
    rows, cols = torch.meshgrid(torch.arange(h), torch.arange(w))  # each [H, W]

    channels = [cols, rows]  # x first, then y
    if homogeneous:
        channels.append(torch.ones_like(cols))

    single = torch.stack(channels, dim=0).float()  # [2 or 3, H, W]
    batched = single.unsqueeze(0).repeat(b, 1, 1, 1)  # [B, 2 or 3, H, W]

    return batched if device is None else batched.to(device)


def bilinear_sample(img,
                    sample_coords,
                    mode='bilinear',
                    padding_mode='zeros',
                    return_mask=False):
    """Sample ``img`` at absolute pixel coordinates via ``F.grid_sample``.

    Args:
        img: Tensor of shape [B, C, H, W].
        sample_coords: Pixel-space (x, y) coordinates, [B, 2, H, W]; an input
            whose dim 1 is not 2 is treated as [B, H, W, 2] and permuted.
        mode: Interpolation mode forwarded to ``grid_sample``.
        padding_mode: Padding mode forwarded to ``grid_sample``.
        return_mask: When true, also return which coordinates fall inside
            the normalized [-1, 1] range.

    Returns:
        The sampled tensor, or ``(sampled, mask)`` with a boolean [B, H, W]
        mask when ``return_mask`` is true.
    """
    # Channel-last coordinates are moved to channel-first layout.
    if sample_coords.size(1) != 2:
        sample_coords = sample_coords.permute(0, 3, 1, 2)

    _, _, height, width = sample_coords.shape

    # grid_sample expects coordinates normalized to [-1, 1].
    norm_x = 2 * sample_coords[:, 0] / (width - 1) - 1
    norm_y = 2 * sample_coords[:, 1] / (height - 1) - 1
    sampling_grid = torch.stack([norm_x, norm_y], dim=-1)  # [B, H, W, 2]

    sampled = F.grid_sample(img,
                            sampling_grid,
                            mode=mode,
                            padding_mode=padding_mode,
                            align_corners=True)

    if not return_mask:
        return sampled

    inside_x = (norm_x >= -1) & (norm_x <= 1)
    inside_y = (norm_y >= -1) & (norm_y <= 1)
    return sampled, inside_x & inside_y  # mask is [B, H, W]


def flow_warp_rerender(feature,
              flow,
              mask=False,
              mode='bilinear',
              padding_mode='zeros'):
    """Warp ``feature`` by sampling it at positions displaced by ``flow``.

    Args:
        feature: Tensor of shape [B, C, H, W].
        flow: Per-pixel displacement with 2 channels in dim 1; it is added
            to the pixel grid, so it must broadcast against [B, 2, H, W].
        mask: Forwarded as ``return_mask`` to ``bilinear_sample``.
        mode: Interpolation mode forwarded to ``bilinear_sample``.
        padding_mode: Padding mode forwarded to ``bilinear_sample``.

    Returns:
        Whatever ``bilinear_sample`` returns for the displaced grid.
    """
    batch, _, height, width = feature.size()
    assert flow.size(1) == 2

    base_grid = coords_grid(batch, height, width).to(flow.device)
    displaced = base_grid + flow  # [B, 2, H, W]

    return bilinear_sample(feature,
                           displaced,
                           mode=mode,
                           padding_mode=padding_mode,
                           return_mask=mask)


def clip_text(pil_list, text_prompt, preprocess, device, model):
    """Average cosine similarity between a text prompt and every frame.

    Args:
        pil_list: Frames to score.
        text_prompt: Prompt tokenized with ``clip.tokenize``.
        preprocess: Callable turning one frame into a model input tensor.
        device: Device the inputs are moved to.
        model: Model exposing ``encode_text`` and ``encode_image``.

    Returns:
        Mean of the per-frame text/image cosine similarities (Python float).
    """
    tokens = clip.tokenize([text_prompt]).to(device)

    with torch.no_grad():
        text_features = model.encode_text(tokens)
        batch = torch.cat(
            [preprocess(frame).unsqueeze(0).to(device) for frame in pil_list])
        frame_features = model.encode_image(batch)
        per_frame = [
            torch.cosine_similarity(text_features, feature).item()
            for feature in frame_features
        ]

    return sum(per_frame) / len(per_frame)

def clip_frame(pil_list, preprocess, device, model):
    """Mean pairwise cosine similarity between the embeddings of all frames.

    Args:
        pil_list: Frames to compare.
        preprocess: Callable turning one frame into a model input tensor.
        device: Device the inputs are moved to.
        model: Model exposing ``encode_image``.

    Returns:
        Sum of the off-diagonal similarities divided by ``n * (n - 1)``,
        where ``n`` is the number of frames.
    """
    with torch.no_grad():
        batch = torch.cat(
            [preprocess(frame).unsqueeze(0).to(device) for frame in pil_list])
        embeddings = model.encode_image(batch)

    similarity = cosine_similarity(embeddings.cpu().numpy())
    # Self-similarity must not contribute to the average.
    np.fill_diagonal(similarity, 0)

    frame_count = len(pil_list)
    return similarity.sum() / (frame_count * (frame_count - 1))

def pick_score_func(frames, prompt, model, processor, device):
    """Mean scaled text/image similarity of ``frames`` against ``prompt``.

    Args:
        frames: Images handed to ``processor(images=...)``.
        prompt: Text handed to ``processor(text=...)``.
        model: Model exposing ``get_image_features``, ``get_text_features``
            and ``logit_scale`` (presumably a PickScore/CLIP-style model —
            confirm against the caller).
        processor: Callable building the model inputs.
        device: Device the inputs are moved to.

    Returns:
        Mean over frames of ``logit_scale.exp() * <text, image>`` computed on
        L2-normalized embeddings (NumPy scalar).
    """
    shared_kwargs = dict(padding=True, truncation=True, max_length=77, return_tensors="pt")
    image_inputs = processor(images=frames, **shared_kwargs).to(device)
    text_inputs = processor(text=prompt, **shared_kwargs).to(device)

    with torch.no_grad():
        image_embs = model.get_image_features(**image_inputs)
        image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)

        text_embs = model.get_text_features(**text_inputs)
        text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)

        # Only the first text row is used.
        logits = model.logit_scale.exp() * (text_embs @ image_embs.T)[0]
        per_frame = logits.detach().cpu().numpy()

    return per_frame.mean()

def lpips_score_func(frames, src_frames, model):
    """Mean LPIPS distance between edited frames and their source frames.

    Args:
        frames: Edited frames (PIL images of identical size).
        src_frames: Source frames, paired index-by-index with ``frames``.
        model: LPIPS network; called as
            ``model.forward(x, y, normalize=True)``.

    Returns:
        Average distance over all frame pairs as a Python float.
    """
    to_tensor = transforms.ToTensor()  # PIL image -> float tensor in [0, 1]
    tensor_frames = torch.stack([to_tensor(frame) for frame in frames])
    tensor_src_frames = torch.stack([to_tensor(frame) for frame in src_frames])

    if torch.cuda.is_available():
        tensor_frames = tensor_frames.cuda()
        tensor_src_frames = tensor_src_frames.cuda()

    # This is a pure evaluation: skip autograd bookkeeping, like the other
    # metric functions in this file do.
    with torch.no_grad():
        # The tensors are in [0, 1]; LPIPS expects [-1, 1] unless
        # ``normalize=True`` asks it to rescale the inputs itself.
        dist01 = model.forward(tensor_frames, tensor_src_frames, normalize=True)
    return float(dist01.reshape(-1).mean(axis=0))

def interpolation_error_PSNR_fuc(frames):
    """Measure temporal smoothness via midpoint-interpolation error.

    Each interior frame is compared with the average of its previous and
    next frame.

    Args:
        frames: Sequence of at least three PIL images of identical size.

    Returns:
        Tuple ``(mean_error, mean_psnr)`` averaged over all interior frames.
        ``mean_error`` is the root-mean-square error (the square root is
        applied before averaging) and ``mean_psnr`` is
        ``10 * log10(1 / rmse ** 2)`` for pixel values in [0, 1].

    Raises:
        ValueError: If fewer than three frames are given, because no frame
            has two neighbours to interpolate from.
    """
    num_pairs = len(frames) - 2
    if num_pairs < 1:
        raise ValueError(
            f'Need at least 3 frames to interpolate, got {len(frames)}.')

    to_tensor = transforms.ToTensor()  # PIL image -> float tensor in [0, 1]
    frames = torch.stack([to_tensor(frame) for frame in frames])

    total_rmse = 0
    total_psnr = 0
    t = 0.5  # midpoint between the two neighbours
    for i in range(1, num_pairs + 1):
        interpolated_frame = (1 - t) * frames[i - 1] + t * frames[i + 1]
        # Root-mean-square error between the real and the interpolated frame.
        rmse = torch.sqrt(torch.mean((frames[i] - interpolated_frame) ** 2))
        # Identical frames give rmse == 0 and therefore an infinite PSNR.
        psnr = 10 * np.log10(1 / (rmse) ** 2)
        total_rmse += rmse.item()
        total_psnr += psnr

    return total_rmse / num_pairs, float(total_psnr) / num_pairs

def DiCLIP_get_text_features(model, class_str: str, templates=imagenet_templates, norm: bool = True, device: str = 'cuda') -> torch.Tensor:
    """Encode ``class_str`` under every prompt template.

    Args:
        model: Model exposing ``encode_text``.
        class_str: Text inserted into each template via ``str.format``.
        templates: Iterable of format strings.
        norm: When true, L2-normalize each feature along the last dim.
        device: Device the tokens are moved to.

    Returns:
        Detached text features, one row per template.
    """
    prompts = [template.format(class_str) for template in templates]
    tokens = clip.tokenize(prompts).to(device)

    text_features = model.encode_text(tokens).detach()
    if not norm:
        return text_features

    text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features

def DiCLIP_compute_text_direction(model, source_class: str, target_class: str) -> torch.Tensor:
    """Unit-length mean direction from the source text to the target text.

    Args:
        model: Model forwarded to ``DiCLIP_get_text_features``.
        source_class: Source description.
        target_class: Target description.

    Returns:
        Tensor with a leading dim of 1 holding the normalized mean of
        ``target_features - source_features`` over dim 0.
    """
    source_features = DiCLIP_get_text_features(model, source_class)
    target_features = DiCLIP_get_text_features(model, target_class)

    mean_shift = (target_features - source_features).mean(axis=0, keepdim=True)
    return mean_shift / mean_shift.norm(dim=-1, keepdim=True)

def DiCLIP_get_image_features(model, preprocess, img: torch.Tensor, norm: bool = True, device: str = 'cuda') -> torch.Tensor:
    """Encode an image tensor with ``model.encode_image``.

    Args:
        model: Model exposing ``encode_image``.
        preprocess: Compose-like object with a ``transforms`` list.
        img: Image tensor fed through the reduced preprocessing pipeline.
        norm: When true, L2-normalize the features along the last dim.
        device: Device the preprocessed images are moved to.

    Returns:
        Image features (normalized in place when ``norm`` is true).
    """
    # Keep the first two steps and everything from index 4 on; indices 2 and
    # 3 are skipped because the input is already a tensor.
    # NOTE(review): assumes the CLIP layout (resize, crop, to-RGB, to-tensor,
    # normalize) — confirm against the ``preprocess`` actually passed in.
    kept_steps = preprocess.transforms[:2] + preprocess.transforms[4:]
    pipeline = transforms.Compose(kept_steps)

    batch = pipeline(img).to(device)
    image_features = model.encode_image(batch)

    if norm:
        image_features /= image_features.clone().norm(dim=-1, keepdim=True)

    return image_features


def clip_directional_loss(model, preprocess, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor:
    """Cosine similarity between the image edit direction and the text direction.

    Args:
        model: Model used for both text and image encoding.
        preprocess: Preprocessing pipeline forwarded to the image encoder helper.
        src_img: Source image tensor.
        source_class: Text describing the source.
        target_img: Edited image tensor.
        target_class: Text describing the target.

    Returns:
        Scalar tensor: mean cosine similarity between the normalized
        ``target - source`` image-feature difference and the text direction.
    """
    text_direction = DiCLIP_compute_text_direction(model, source_class, target_class)

    src_encoding = DiCLIP_get_image_features(model, preprocess, src_img)
    target_encoding = DiCLIP_get_image_features(model, preprocess, target_img)
    edit_direction = target_encoding - src_encoding

    if edit_direction.sum() == 0:
        # Degenerate difference: re-encode a slightly perturbed target so the
        # normalization below does not divide by zero.
        target_encoding = DiCLIP_get_image_features(model, preprocess, target_img + 1e-6)
        edit_direction = target_encoding - src_encoding

    edit_direction /= edit_direction.clone().norm(dim=-1, keepdim=True)
    similarity = torch.nn.CosineSimilarity()
    return similarity(edit_direction, text_direction).mean()

def directional_clip_func(model, preprocess, tar_frames, tar_text: str, src_frames, src_text: str):
    """Directional CLIP score between source and edited frames.

    Args:
        model: Model used for text and image encoding.
        preprocess: Preprocessing pipeline forwarded to ``clip_directional_loss``.
        tar_frames: Edited frames (PIL images of identical size).
        tar_text: Text describing the edited video.
        src_frames: Source frames (PIL images of identical size).
        src_text: Text describing the source video.

    Returns:
        The directional similarity as a Python float.
    """
    to_tensor = transforms.ToTensor()  # PIL image -> float tensor in [0, 1]
    tensor_tar_frames = torch.stack([to_tensor(frame) for frame in tar_frames])
    tensor_src_frames = torch.stack([to_tensor(frame) for frame in src_frames])

    # Only a float leaves this function, so gradients can never be used;
    # disabling autograd avoids keeping the graph for every frame alive.
    with torch.no_grad():
        directional_scores = clip_directional_loss(
            model, preprocess, tensor_src_frames, src_text, tensor_tar_frames, tar_text)
    return float(directional_scores.mean(dim=0))




def prepare_raft_model(device, model_path='/share/home/xiajunhao/MM/Experiment/PretrainedModels/RAFT/raft-things.pth'):
    """Build a RAFT model, load its checkpoint and put it in eval mode.

    Args:
        device: Device the weights are loaded onto and the model is moved to.
        model_path: Checkpoint file; defaults to the previously hard-coded
            location.

    Returns:
        The bare ``RAFT`` module (unwrapped from ``DataParallel``) on
        ``device`` in eval mode.
    """
    args = argparse.Namespace(
        model=model_path,
        small=False,
        mixed_precision=False,
        alternate_corr=False,
    )

    # The state dict is loaded through a DataParallel wrapper (presumably
    # because the checkpoint keys carry its "module." prefix), then unwrapped.
    model = torch.nn.DataParallel(RAFT(args))
    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # load on the requested one instead of failing.
    model.load_state_dict(torch.load(args.model, map_location=device))

    model = model.module
    model.to(device)
    model.eval()

    return model

def flow_warp(img: np.ndarray,
              flow: np.ndarray,
              filling_value: int = 0,
              interpolate_mode: str = 'nearest'):
    '''Warp an image with a dense optical flow field.

    Args:
        img (ndarray): Image to warp, indexed as (row, col, channel).
        flow (ndarray): 3-D flow array; channel 0 is added to the column
            index and channel 1 to the row index.
        filling_value (int): Value written to pixels with no valid source.
        interpolate_mode (str): 'nearest' or 'bilinear'.

    Returns:
        ndarray: Warped image of shape (flow height, flow width, channels)
        with the dtype of ``img``.

    Raises:
        NotImplementedError: For any other ``interpolate_mode``.
    '''
    warnings.warn('This function is just for prototyping and cannot '
                  'guarantee the computational efficiency.')
    assert flow.ndim == 3, 'Flow must be in 3D arrays.'
    height, width = flow.shape[0], flow.shape[1]
    channels = img.shape[2]

    output = np.ones(
        (height, width, channels), dtype=img.dtype) * filling_value

    # Absolute (row, col) source position of every output pixel.
    rows, cols = np.indices((height, width))
    dx = rows + flow[:, :, 1]
    dy = cols + flow[:, :, 0]

    # A pixel is valid when its floored source lies strictly before the
    # last row/column, leaving room for the "+1" neighbour.
    sx = np.floor(dx).astype(int)
    sy = np.floor(dy).astype(int)
    valid = (sx >= 0) & (sx < height - 1) & (sy >= 0) & (sy < width - 1)

    if interpolate_mode == 'nearest':
        src_rows = dx[valid].round().astype(int)
        src_cols = dy[valid].round().astype(int)
        output[valid, :] = img[src_rows, src_cols, :]
    elif interpolate_mode == 'bilinear':
        # A tiny offset keeps floor and ceil distinct at integer positions.
        eps_ = 1e-6
        vx = dx[valid] + eps_
        vy = dy[valid] + eps_

        x_lo, x_hi = np.floor(vx), np.ceil(vx)
        y_lo, y_hi = np.floor(vy), np.ceil(vy)
        top, bottom = x_lo.astype(int), x_hi.astype(int)
        left, right = y_lo.astype(int), y_hi.astype(int)

        w_top = (x_hi - vx)[:, None]
        w_bottom = (vx - x_lo)[:, None]
        w_left = (y_hi - vy)[:, None]
        w_right = (vy - y_lo)[:, None]

        left_top_ = img[top, left, :] * w_top * w_left
        left_down_ = img[bottom, left, :] * w_bottom * w_left
        right_top_ = img[top, right, :] * w_top * w_right
        right_down_ = img[bottom, right, :] * w_bottom * w_right
        output[valid, :] = left_top_ + left_down_ + right_top_ + right_down_
    else:
        raise NotImplementedError(
            'We only support interpolation modes of nearest and bilinear, '
            f'but got {interpolate_mode}.')
    return output.astype(img.dtype)

def calculate_flow(pil_list, model, DEVICE):
    """Estimate optical flow between every pair of consecutive frames.

    Args:
        pil_list: Frames convertible to (H, W, C) arrays.
        model: Flow network called as
            ``model(image1, image2, iters=20, test_mode=True)``; the second
            returned value is used as the flow.
        DEVICE: Device the frames are moved to.

    Returns:
        List of ``len(pil_list) - 1`` NumPy arrays in (H, W, 2) layout, one
        per consecutive pair, cropped back to the un-padded frame size.
    """
    def load_image(imfile, DEVICE):
        # (H, W, C) uint8 -> float tensor of shape [1, C, H, W] on DEVICE.
        img = np.array(imfile).astype(np.uint8)
        img = torch.from_numpy(img).permute(2, 0, 1).float()
        return img[None].to(DEVICE)

    flow_up_list = []
    with torch.no_grad():
        for imfile1, imfile2 in zip(pil_list[:-1], pil_list[1:]):
            image1 = load_image(imfile1, DEVICE)
            image2 = load_image(imfile2, DEVICE)

            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1, image2)

            _, flow_up = model(image1, image2, iters=20, test_mode=True)

            # The network ran on padded frames; crop the flow back so it
            # lines up pixel-for-pixel with the original frame.
            flow_up = padder.unpad(flow_up)

            # Index the batch dim explicitly: a bare squeeze() would also
            # drop a height or width of 1.
            flow_hw2 = flow_up[0].permute(1, 2, 0).cpu().numpy()
            flow_up_list.append(np.ascontiguousarray(flow_hw2))
    return flow_up_list

def rerender_warp(img, flow, mode='bilinear'):
    """Warp a NumPy image with a NumPy flow field via ``flow_warp_rerender``.

    Args:
        img: Image array, either (H, W) or (H, W, C).
        flow: Flow array in (H, W, 2) layout.
        mode: Interpolation mode forwarded to ``flow_warp_rerender``.

    Returns:
        Warped image with the same number of dimensions as ``img``, cast
        back to the dtype of ``img``.
    """
    is_grayscale = len(img.shape) == 2
    if is_grayscale:
        # Add a channel axis so the tensor code sees (H, W, 1).
        img = np.expand_dims(img, 2)

    img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
    original_dtype = img_tensor.dtype
    flow_tensor = torch.from_numpy(flow).permute(2, 0, 1).unsqueeze(0)

    warped = flow_warp_rerender(img_tensor.to(torch.float), flow_tensor, mode=mode)
    warped = warped.to(original_dtype)[0].cpu().permute(1, 2, 0).numpy()

    return warped[:, :, 0] if is_grayscale else warped

def opencv_warp(img, flow):
    """Warp ``img`` with ``cv2.remap`` using a dense flow field.

    Args:
        img: Image array to sample from.
        flow: (H, W, 2) array; channel 0 is the x displacement and channel 1
            the y displacement. The array is left untouched.

    Returns:
        Image in which each pixel (y, x) is sampled from ``img`` at
        ``(x + flow[y, x, 0], y + flow[y, x, 1])`` with linear interpolation.
    """
    h, w = flow.shape[:2]
    # Build the absolute sampling map in a fresh float32 array so the
    # caller's flow is not modified and can be reused.
    sample_map = np.array(flow, dtype=np.float32, order='C')
    sample_map[:, :, 0] += np.arange(w)
    sample_map[:, :, 1] += np.arange(h)[:, np.newaxis]
    return cv2.remap(img, sample_map, None, cv2.INTER_LINEAR)

def rearrange(x):
    """Return ``x`` as a column vector of shape (-1, 1), divided by 255."""
    scaled = np.array(x) / 255
    return scaled.reshape(-1, 1)

def warp_video(edit_pil_list, source_pil_list, raft_model, device, distance_func):
    """Score edited frames against flow-warped versions of their predecessors.

    Flow is computed on the source frames; edited frame ``i`` is warped with
    flow ``i`` and compared with edited frame ``i + 1``. The first edited
    frame is compared with itself.

    Args:
        edit_pil_list: Edited frames (PIL images).
        source_pil_list: Source frames used for flow estimation.
        raft_model: Flow model forwarded to ``calculate_flow``.
        device: Device forwarded to ``calculate_flow``.
        distance_func: Per-pair metric. ``structural_similarity`` receives
            arrays plus ``channel_axis=2``; any other callable receives the
            two PIL images.

    Returns:
        Mean of the per-frame metric values.
    """
    flow_up_list = calculate_flow(source_pil_list, raft_model, device)

    # Frame 0 has no predecessor, so it stands in for its own warp.
    res_list = [edit_pil_list[0]]
    res_list.extend(
        Image.fromarray(opencv_warp(np.array(frame), flow_up_list[idx]))
        for idx, frame in enumerate(edit_pil_list[:-1])
    )

    if distance_func == structural_similarity:
        scores = [
            distance_func(np.array(edited), np.array(warped), channel_axis=2)
            for edited, warped in zip(edit_pil_list, res_list)
        ]
    else:
        scores = [
            distance_func(edited, warped)
            for edited, warped in zip(edit_pil_list, res_list)
        ]
    return np.mean(np.array(scores))
    

def ssim_func(edit_pil_list, source_pil_list):
    """Mean SSIM between each edited frame and the source frame at the same index.

    Args:
        edit_pil_list: Edited frames; its length decides how many pairs are scored.
        source_pil_list: Source frames, indexed in parallel.

    Returns:
        Mean of ``structural_similarity(..., channel_axis=2)`` over all pairs.
    """
    per_frame = [
        structural_similarity(np.array(edited), np.array(source_pil_list[idx]), channel_axis=2)
        for idx, edited in enumerate(edit_pil_list)
    ]
    return np.mean(np.array(per_frame))