import torch
import numpy as np


def numpy_to_torch(a: np.ndarray):
    """Turn a 3-axis NumPy array into a float32 tensor with a leading batch axis.

    The array's last axis is moved to the front (e.g. H x W x C -> C x H x W)
    and a size-1 dimension is prepended, giving shape (1, C, H, W).
    """
    tensor = torch.from_numpy(a)
    as_float = tensor.float()
    # Axis 2 becomes axis 0; the remaining axes keep their relative order.
    channel_first = as_float.permute(2, 0, 1)
    return torch.unsqueeze(channel_first, 0)


import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import math, cv2
import os
import numpy as np

def draw_attn(seq_name, frame_num, x, attn, save_name,
              font_path="/usr/share/texmf/fonts/opentype/public/tex-gyre/texgyreheros-bold.otf"):
    """Save an attention heat map blended over a frame as a PNG.

    The map is bilinearly resized to the frame's height/width, rendered with
    the 'jet' colormap, blended over ``x`` with weight 0.5, resized to
    540x540, labelled with ``#<frame_num + 1>`` and written to
    ``vis/attn/<seq_name>/<frame_num + 1:08d>_<save_name>.png``.

    Args:
        seq_name: Sub-directory (under ``vis/attn``) the image is written to.
        frame_num: Frame index; the file name and the label use ``frame_num + 1``.
        x: Frame to draw on. Presumably an (H, W, 3) uint8 RGB array, since it
            is blended by ``cv2.addWeighted`` with an RGB heat map of that
            kind -- TODO confirm against callers.
        attn: 4-D tensor accepted by bilinear ``F.interpolate``,
            e.g. shape (1, 1, 16, 16).
        save_name: Suffix of the output file name.
        font_path: Font file used for the frame-number label.

    Raises:
        OSError: If ``font_path`` cannot be opened by PIL.
    """
    # Upsample the map to the frame size. detach() lets maps that still carry
    # autograd history be converted to NumPy by matplotlib.
    attn_L = F.interpolate(
        attn, size=(x.shape[0], x.shape[1]), mode='bilinear', align_corners=False
    ).squeeze().detach().cpu()

    save_path = os.path.join("vis/attn", seq_name)
    os.makedirs(save_path, exist_ok=True)
    out_file = os.path.join(save_path, f"{frame_num + 1:08d}_{save_name}.png")

    plt.figure(2)
    try:
        pred_frame = plt.gca()
        plt.imshow(attn_L, 'jet')
        pred_frame.axes.get_yaxis().set_visible(False)
        pred_frame.axes.get_xaxis().set_visible(False)
        for side in ('top', 'bottom', 'left', 'right'):
            pred_frame.spines[side].set_visible(False)

        # Render the bare heat map to disk, then read it back as pixels so it
        # can be blended with the frame.
        plt.savefig(out_file, bbox_inches='tight', pad_inches=0, dpi=150)
        img = cv2.imread(out_file)

        # cv2.imread yields BGR; match the frame's size and convert to RGB.
        image = cv2.cvtColor(cv2.resize(img, (x.shape[1], x.shape[0])), cv2.COLOR_BGR2RGB)
        blended_image = cv2.resize(cv2.addWeighted(x, 1, image, 0.5, 0), (540, 540))

        # Stamp the frame number onto the overlay.
        font_pil = ImageFont.truetype(font_path, 77)
        img_pil = Image.fromarray(blended_image)
        draw = ImageDraw.Draw(img_pil)
        draw.text((10, -5), f"#{frame_num+1}", font=font_pil, fill=(0, 255, 255))
        blended_image = np.array(img_pil)

        plt.imshow(blended_image)
        plt.axis("off")

        # Overwrites the intermediate heat-map file with the final overlay.
        plt.savefig(out_file, bbox_inches='tight', pad_inches=0, dpi=150)
    finally:
        # Always release figure 2, even when a step above fails, so the next
        # call does not draw onto stale content.
        plt.close(2)

def draw_bbox(seq_name, frame_num, img, bbox,
              font_path="/usr/share/texmf/fonts/opentype/public/tex-gyre/texgyreheros-bold.otf"):
    """Draw a bounding box and frame label on a frame and save it as a PNG.

    The result is written to ``vis/bbox/<seq_name>/<frame_num + 1:08d>.png``.
    The caller's ``img`` is left untouched; drawing happens on a copy.

    Args:
        seq_name: Sub-directory (under ``vis/bbox``) the image is written to.
        frame_num: Frame index; the file name and the label use ``frame_num + 1``.
        img: Frame to annotate. It is treated as RGB: the annotated image is
            converted RGB -> BGR right before ``cv2.imwrite``.
        bbox: ``(x, y, w, h)``; negative ``w``/``h`` are normalized so the box
            is always drawn from its top-left corner.
        font_path: Font file used for the frame-number label.

    Raises:
        OSError: If ``font_path`` cannot be opened by PIL.
    """
    x, y, w, h = bbox
    # Normalize boxes given with negative width/height.
    x_min, y_min = min(x, x + w), min(y, y + h)
    top_left = (int(x_min), int(y_min))
    bottom_right = (int(x_min + abs(w)), int(y_min + abs(h)))

    # cv2.rectangle draws in place; work on a private contiguous copy so the
    # caller's frame is not modified as a side effect.
    canvas = np.array(img, order="C")
    # NOTE(review): with an RGB frame, (0, 0, 255) renders as blue in the
    # saved file -- confirm that is the intended box color.
    cv2.rectangle(canvas, top_left, bottom_right, (0, 0, 255), 5)

    # Stamp the frame number onto the image.
    font_pil = ImageFont.truetype(font_path, 66)
    img_pil = Image.fromarray(canvas)
    draw = ImageDraw.Draw(img_pil)
    draw.text((20, -5), f"#{frame_num+1}", font=font_pil, fill=(0, 255, 255))

    # cv2.imwrite expects BGR channel order.
    annotated = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)

    save_path = os.path.join("vis/bbox", seq_name)
    os.makedirs(save_path, exist_ok=True)
    cv2.imwrite(os.path.join(save_path, f"{frame_num+1:08d}.png"), annotated)
        



    


############## used for visualizing eliminated tokens #################
def get_keep_indices(decisions, num_stages=3):
    """Compose per-stage selections into indices relative to the first stage.

    ``decisions[0]`` is taken as-is. Every later ``decisions[i]`` is used to
    index the previous stage's result, so each returned entry is expressed in
    the index space of ``decisions[0]``.

    Args:
        decisions: Indexable collection with at least ``num_stages`` entries;
            each entry must be usable as an index into the previous result
            (e.g. NumPy integer arrays).
        num_stages: Number of leading stages to process. Defaults to 3.

    Returns:
        A list of ``num_stages`` index collections, one per stage.

    Raises:
        IndexError: If ``decisions`` has fewer than ``num_stages`` entries.
    """
    keep_indices = []
    for stage in range(num_stages):
        selection = decisions[stage]
        if keep_indices:
            # Map this stage's selection back through the previous stage.
            selection = keep_indices[-1][selection]
        keep_indices.append(selection)
    return keep_indices


def gen_masked_tokens(tokens, indices, alpha=0.2):
    """Return a copy of ``tokens`` with the selected entries faded toward white.

    Only the first row of ``indices`` is used (cast to int). Each selected
    token becomes ``alpha * token + (1 - alpha) * 255``; the input array is
    not modified.
    """
    selected = indices[0].astype(int)
    faded = tokens.copy()
    # Blend the chosen tokens with white (255), keeping `alpha` of the original.
    faded[selected] = alpha * faded[selected] + (1 - alpha) * 255
    return faded


def recover_image(tokens, H, W, Hp, Wp, patch_size):
    """Reassemble a grid of square patches into a single image.

    Args:
        tokens: Patches holding ``Hp * Wp * patch_size * patch_size * C``
            values in row-major patch order, e.g. shape
            ``(Hp * Wp, patch_size, patch_size, C)``.
        H: Output image height (``Hp * patch_size``).
        W: Output image width (``Wp * patch_size``).
        Hp: Number of patch rows.
        Wp: Number of patch columns.
        patch_size: Side length of each square patch.

    Returns:
        The image with shape ``(H, W, C)``. The channel count ``C`` is
        inferred from ``tokens`` rather than fixed to 3.
    """
    # Lay the patches out on the (Hp, Wp) grid; -1 infers the channel count.
    grid = tokens.reshape(Hp, Wp, patch_size, patch_size, -1)
    channels = grid.shape[-1]
    # (Hp, Wp, p, p, C) -> (Hp, p, Wp, p, C) so rows/columns merge correctly.
    image = grid.swapaxes(1, 2).reshape(H, W, channels)
    return image


def pad_img(img):
    """Append an 8-pixel-wide white (255) strip to the right edge of an image.

    ``img`` must have three axes (height, width, channels). The result is a
    new float64 array of shape (height, width + 8, channels).
    """
    rows, cols, depth = img.shape
    # Start from an all-white float canvas, then paste the image at the left.
    canvas = np.full((rows, cols + 8, depth), 255.0)
    canvas[:, :cols] = img
    return canvas


def gen_visualization(image, mask_indices, patch_size=16):
    """Build a side-by-side strip showing which patches are masked per stage.

    The strip contains the original image followed by one copy per stage in
    which the masked patches are faded toward white. Masks are cumulative:
    stage ``i`` shows the indices of stages ``0..i`` concatenated along
    axis 1. Each panel is padded on the right by ``pad_img`` before the
    panels are concatenated horizontally.

    Args:
        image: Array-like of shape (H, W, C), e.g. (224, 224, 3). H and W
            must be multiples of ``patch_size``.
        mask_indices: Sequence of per-stage index arrays with at least two
            axes (they are joined along axis 1). The sequence is not modified.
        patch_size: Side length of the square patches.

    Returns:
        The concatenated visualization array.

    Raises:
        ValueError: If H or W is not a multiple of ``patch_size``.
    """
    # Accumulate into a new list instead of assigning back into
    # `mask_indices`, so repeated calls with the same list give the same
    # result and the caller's data stays intact.
    cumulative = []
    for stage_indices in mask_indices:
        if cumulative:
            stage_indices = np.concatenate([cumulative[-1], stage_indices], axis=1)
        cumulative.append(stage_indices)

    image = np.asarray(image)
    H, W, C = image.shape
    if H % patch_size or W % patch_size:
        raise ValueError(
            f"image size ({H}, {W}) is not a multiple of patch_size={patch_size}"
        )
    Hp, Wp = H // patch_size, W // patch_size

    # Split the image into Hp * Wp patches of shape (patch_size, patch_size, C).
    image_tokens = (
        image.reshape(Hp, patch_size, Wp, patch_size, C)
        .swapaxes(1, 2)
        .reshape(Hp * Wp, patch_size, patch_size, C)
    )

    stages = [
        recover_image(gen_masked_tokens(image_tokens, indices), H, W, Hp, Wp, patch_size)
        for indices in cumulative
    ]
    panels = [pad_img(panel) for panel in [image] + stages]
    viz = np.concatenate(panels, axis=1)
    return viz
