import torch
import numpy as np


def numpy_to_torch(a: np.ndarray):
    """Convert a 3-D NumPy array into a batched float tensor.

    The array's last axis is moved to the front (axes ``(2, 0, 1)``) and a
    leading dimension of size 1 is prepended, so an input of shape
    ``(d0, d1, d2)`` yields a tensor of shape ``(1, d2, d0, d1)``.
    Presumably used for HWC images -> NCHW batches; confirm with callers.
    """
    as_float = torch.from_numpy(a).float()
    channels_first = as_float.permute(2, 0, 1)
    # Indexing with None inserts the new leading axis, same as unsqueeze(0).
    return channels_first[None]


import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import math, cv2
import os
import numpy as np

def draw_attn(seq_name, frame_num, x, attn):
    """Overlay an attention map on a frame and save the result as a PNG.

    The attention map is bilinearly resized to the first two dimensions of
    ``x``, rendered through matplotlib with the ``'jet'`` colormap, blended
    onto ``x`` with weight 0.5, resized to 540x540, captioned with
    ``#<frame_num + 1>`` and written to
    ``vis/attn/<seq_name>/<frame_num + 1, zero-padded to 8 digits>.png``.

    Args:
        seq_name: Sub-directory name under ``vis/attn``.
        frame_num: Zero-based frame index; the file name and caption use
            ``frame_num + 1``.
        x: Image array indexed as ``(height, width, ...)``. Presumably a
            uint8 RGB frame, since it is blended with a uint8 RGB heatmap
            and handed to PIL -- confirm with callers.
        attn: 4-D tensor accepted by bilinear ``F.interpolate``; the
            original author's note gives its shape as ``(1, 1, 16, 16)``.

    Raises:
        OSError: If the hard-coded font file cannot be opened.
    """
    height, width = x.shape[0], x.shape[1]

    # detach() so a map that still tracks gradients can be converted by
    # matplotlib; the tensor is moved to the CPU exactly once.
    heatmap = F.interpolate(
        attn, size=(height, width), mode='bilinear', align_corners=False
    ).squeeze().detach().cpu()

    save_path = os.path.join("vis/attn", seq_name)
    os.makedirs(save_path, exist_ok=True)
    out_file = os.path.join(save_path, f"{frame_num + 1:08d}.png")

    plt.figure(2)
    try:
        pred_frame = plt.gca()
        plt.imshow(heatmap, 'jet')
        pred_frame.get_yaxis().set_visible(False)
        pred_frame.get_xaxis().set_visible(False)
        for side in ('top', 'bottom', 'left', 'right'):
            pred_frame.spines[side].set_visible(False)

        # First pass: write the bare colour-mapped heatmap, then read it back
        # so it can be blended with the frame as an ordinary image.
        plt.savefig(out_file, bbox_inches='tight', pad_inches=0, dpi=150)
        rendered = cv2.imread(out_file)

        # cv2 loads BGR; convert to RGB after matching the frame's size.
        heat_rgb = cv2.cvtColor(
            cv2.resize(rendered, (width, height)), cv2.COLOR_BGR2RGB
        )
        blended_image = cv2.resize(
            cv2.addWeighted(x, 1, heat_rgb, 0.5, 0), (540, 540)
        )

        font_path = "/usr/share/texmf/fonts/opentype/public/tex-gyre/texgyreheros-bold.otf"  # path to your font
        font_pil = ImageFont.truetype(font_path, 77)  # font size

        img_pil = Image.fromarray(blended_image)
        draw = ImageDraw.Draw(img_pil)
        draw.text((10, -5), f"#{frame_num+1}", font=font_pil, fill=(0, 255, 255))  # draw the caption
        blended_image = np.array(img_pil)

        # Second pass: overwrite the file with the captioned overlay.
        plt.imshow(blended_image)
        plt.axis("off")
        plt.savefig(out_file, bbox_inches='tight', pad_inches=0, dpi=150)
    finally:
        # Always release figure 2, otherwise a failure above would leave
        # stale axes that the next call draws on top of.
        plt.close(2)

def draw_bbox(seq_name, frame_num, img, bbox):
    """Draw a bounding box on a frame and save it as a PNG.

    The rectangle is drawn in place on ``img`` with colour ``(0, 0, 255)``
    and thickness 5, then the image is converted with ``COLOR_RGB2BGR`` and
    written to
    ``vis/bbox/<seq_name>/<frame_num + 1, zero-padded to 8 digits>.png``.

    Args:
        seq_name: Sub-directory name under ``vis/bbox``.
        frame_num: Zero-based frame index; the file name uses
            ``frame_num + 1``.
        img: Image array accepted by ``cv2.rectangle``. Presumably RGB,
            given the RGB->BGR conversion before saving -- confirm with
            callers. It is modified in place.
        bbox: Iterable ``(x, y, w, h)``; negative ``w``/``h`` are handled by
            normalising to the top-left corner.
    """
    x, y, w, h = bbox

    # Normalise so that a negative width/height still yields a proper
    # top-left / bottom-right pair.
    x_min, y_min = min(x, x + w), min(y, y + h)
    top_left = (int(x_min), int(y_min))
    bottom_right = (int(x_min + abs(w)), int(y_min + abs(h)))

    # NOTE: this draws directly on the caller's array.
    cv2.rectangle(img, top_left, bottom_right, (0, 0, 255), 5)

    # No caption is drawn here, so no font is loaded and the image does not
    # need a round trip through PIL; convert straight to BGR for imwrite.
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    save_path = os.path.join("vis/bbox", seq_name)
    os.makedirs(save_path, exist_ok=True)
    cv2.imwrite(os.path.join(save_path, f"{frame_num+1:08d}.png"), bgr)
        



    