import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
from sklearn.cluster import DBSCAN
from tqdm import tqdm
import torch.nn.functional as F
import torch.nn as nn
from scipy.ndimage import zoom
import os
import pandas as pd
import pyarrow.parquet as pq
import ast
import json
from typing import List, Dict
from itertools import combinations
import base64
from PIL import Image
from io import BytesIO
import io
from scipy.stats import entropy
from scipy.ndimage import gaussian_filter
from scipy.ndimage import uniform_filter
from scipy.ndimage import median_filter
import torch

def accept_process(start_k, end_k, attention, inputs, img_start, img_end, sig):
    """
    Build noise-corrected, normalized attention maps over image tokens for a range of token positions.

    For every position k in range(start_k, end_k-2), the attention that position k
    pays to each image's token span is averaged into one 2D map per image. The first
    `noise_token_num` positions (counted from the adjusted `start_k`, see below) are
    only used to collect baseline ("noise") maps. For later positions, the image with
    the largest peak attention is selected, its map is normalized, the mean baseline
    map of that image is subtracted, negatives are clipped and the result is
    re-normalized to [0, 1].

    Args:
        start_k: first token position to process. It is reassigned inside the loop
            to `end + 1` whenever it is smaller than an image's `end` index.
        end_k: upper bound; positions up to end_k-3 (inclusive) are processed.
        attention: indexable collection; `attention[i][0]` is iterated row by row and
            `row[k]` is read from each row. Presumably one entry per layer with rows
            being attention heads -- TODO confirm against the caller.
        inputs: mapping providing `inputs["image_grid_thw"]`, one entry per image;
            elements [1] and [2] of each entry are halved to get the map height/width.
        img_start: per-image start index into the key/token axis.
        img_end: per-image end index (exclusive in the slice) into the key/token axis.
        sig: Gaussian smoothing sigma; smoothing is applied only when sig > 0.

    Returns:
        dict: {img_idx: {k: map}} where each map has shape (1, H, W) with values in [0, 1].
    """
    accept_att = {}
    # Number of leading positions whose maps are treated as the noise baseline.
    noise_token_num = 8
    # noise_mean[img][slot] starts as the int 0 and is later overwritten by a baseline map.
    noise_mean = [[0 for k in range(noise_token_num)] for i in range(len(inputs["image_grid_thw"]))]
    # The range is evaluated once, with the start_k value passed by the caller;
    # later reassignments of start_k do not change which positions are visited.
    for k in range(start_k,end_k-2):
        max_att_sum = 0  # unused
        per_img_attention = []
        for img_idx in range(len(inputs["image_grid_thw"])):
            image_grid_thw = inputs["image_grid_thw"][img_idx]
            start = img_start[img_idx]
            end = img_end[img_idx]
            # Push start_k past this image's token span. This mutates the parameter and
            # shifts the noise-window checks below.
            if start_k < end:
                start_k = end+1
            layer_sum = []  # unused
            layer_mean = []
            for i in range(len(attention)):
                # Collect entry k of every row of attention[i][0] into a 2D array.
                k_att_map = np.array([row[k] for row in attention[i][0]])
                # Keep only this image's span, fold it into (-1, H, W) maps and average
                # over the leading axis, giving one (H, W) map for this attention entry.
                att_map = k_att_map[:,start:end].reshape(-1, image_grid_thw[1]//2,image_grid_thw[2]//2).mean(axis=0)
                layer_mean.append(att_map)
            # Average over all attention entries; keepdims leaves shape (1, H, W).
            per_img_attention.append(np.array(layer_mean).mean(axis=0,keepdims=True))
        max_att_get = 0
        for i in range(len(per_img_attention)):
            # Despite the name, this is the peak (max) value of the map, not a sum.
            sum_per_img_att = per_img_attention[i].max()
            if sum_per_img_att > max_att_get:
                max_att_get = sum_per_img_att
                # Select the image with the largest peak. If no peak exceeds 0, img_idx
                # keeps the last value left by the per-image loop above.
                img_idx = i
            if k < start_k+noise_token_num:
                # Noise window: store the (optionally smoothed) min-max normalized map.
                per_att = per_img_attention[i]
                if sig > 0:
                    per_att = gaussian_filter(per_att, sigma=sig)
                per_att = per_att - per_att.min()
                # NOTE(review): a constant map makes per_att.max() zero here (0/0 -> NaN) -- confirm this cannot occur.
                per_att = per_att / per_att.max()
                # NOTE(review): k-start_k is negative whenever k is below the adjusted
                # start_k, so Python negative indexing applies -- confirm this is intended.
                noise_mean[i][k-start_k] = per_att
        # Positions inside the noise window only feed the baseline; nothing is emitted.
        if k < start_k+noise_token_num: continue
        if not img_idx in accept_att:
            accept_att[img_idx] = {}
        accept_s = per_img_attention[img_idx]
        if sig > 0:
            accept_s = gaussian_filter(accept_s, sigma=sig)
        # Min-max normalize the selected image's map.
        accept_s = accept_s - accept_s.min()
        accept_s = accept_s / accept_s.max()
        if noise_token_num > 0:
            # Subtract the mean baseline map and clip negatives to zero.
            # NOTE(review): slots never filled still hold the int 0 placeholder; mixing
            # them with arrays in np.array(...) may not behave as intended -- verify.
            accept_s -= np.array(noise_mean[img_idx]).mean(axis=0)
            accept_s[accept_s<0] = 0
        # Nothing remains above the baseline: skip this position.
        if accept_s.max() == 0: continue
        # Re-normalize the residual map to [0, 1].
        accept_s = accept_s - accept_s.min()
        accept_s = accept_s / accept_s.max()
        accept_att[img_idx][k]=accept_s
    return accept_att

def create_directory(path):
    """
    Create the directory at the given path, including any missing parent directories.

    A status line is printed in both the success and the failure case; errors are
    reported on stdout instead of being raised.

    :param path: full directory path string
    """
    try:
        os.makedirs(path, exist_ok=True)
    except Exception as err:
        # Best-effort helper: report the problem and carry on.
        print(f"Failed to create directory at {path}: {err}")
    else:
        print(f"Directory created successfully at {path}")

def crop_show_attbox(image,Ambbox,imgidx,word,boxidx,upscale_factor=2):
    """
    Crop a box given in relative coordinates out of an image array, upscale it and
    return it as a base64 data-URI string.

    :param image: image array indexed as [row, column, ...]; shape[0] is used as the
        height and shape[1] as the width
    :param Ambbox: (x_min, y_min, x_max, y_max) as fractions of the image width/height
    :param imgidx: not used by this function
    :param word: not used by this function
    :param boxidx: not used by this function
    :param upscale_factor: factor applied to both crop dimensions before encoding
    :return: result of pil_to_base64 for the resized crop
    """
    height, width = image.shape[0], image.shape[1]

    # Scale the relative box to pixel coordinates, rounding to the nearest pixel.
    x_min = round(Ambbox[0] * width)
    y_min = round(Ambbox[1] * height)
    x_max = round(Ambbox[2] * width)
    y_max = round(Ambbox[3] * height)

    region = image[y_min:y_max, x_min:x_max]

    # PIL expects the size as (width, height).
    target_size = (
        round(region.shape[1] * upscale_factor),
        round(region.shape[0] * upscale_factor),
    )
    enlarged = Image.fromarray(region).resize(target_size, Image.BILINEAR)
    return pil_to_base64(enlarged)

def Add_box_border(mbbox, radius=0.05):
    """
    Grow a box by `radius` on every side and clamp the result to the unit square.

    :param mbbox: box as (x0, y0, x1, y1)
    :param radius: margin subtracted from x0/y0 and added to x1/y1
    :return: tuple (x0, y0, x1, y1) with the lower corner clamped at 0 and the upper corner clamped at 1
    """
    left, top = mbbox[0] - radius, mbbox[1] - radius
    right, bottom = mbbox[2] + radius, mbbox[3] + radius
    # The candidate value is passed first so it is returned unchanged unless it is
    # strictly outside the bound.
    return (max(left, 0), max(top, 0), min(right, 1), min(bottom, 1))


from collections import defaultdict

def load_json_to_list(json_path: str) -> List[Dict]:
    """
    Load a JSON file whose top-level value is a list and return that list.

    Args:
        json_path (str): path of the JSON file (read as UTF-8)

    Returns:
        List[Dict]: the parsed top-level list

    Raises:
        ValueError: if the top-level JSON value is not a list
    """
    with open(json_path, 'r', encoding='utf-8') as handle:
        parsed = json.load(handle)

    if isinstance(parsed, list):
        return parsed

    raise ValueError("JSON 文件内容不是一个列表")

def serialize_dict(my_dict, file_path):
    """
    Serialize a dictionary as a single JSON line and append it to a .jsonl file.

    Each call appends exactly one line (no pretty-printing), so the output file
    stays valid JSON Lines: one complete JSON document per line.

    Args:
        my_dict: dictionary to write; may contain nested dicts, lists, tuples,
            numpy arrays and numpy scalars (np.int64, np.float32, np.bool_, ...)
        file_path: path of the output .jsonl file (opened in append mode, UTF-8)

    Raises:
        TypeError: if a value cannot be converted to a JSON-serializable object
    """
    def serialize_obj(obj):
        # Recursively convert numpy containers/scalars into plain Python objects.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # np.generic is the base class of every numpy scalar type.
            return obj.item()
        if isinstance(obj, dict):
            # numpy scalar keys are rejected by json.dumps, so unwrap them as well.
            return {
                (key.item() if isinstance(key, np.generic) else key): serialize_obj(value)
                for key, value in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [serialize_obj(item) for item in obj]
        return obj

    # Serialize before opening the file so a failure does not touch the output.
    # No `indent`: an indented document would span several lines and break JSONL.
    line = json.dumps(serialize_obj(my_dict), ensure_ascii=False)

    # Append exactly one line per call.
    with open(file_path, 'a', encoding='utf-8') as f:
        f.write(line + '\n')

import base64

def image_to_base64(file_path):
    """
    Read a file in binary mode and return its content as a base64 data-URI string.

    :param file_path: path of the file to encode
    :return: string of the form "data:image;base64,<payload>"
    """
    with open(file_path, "rb") as image_file:
        raw_bytes = image_file.read()
    payload = base64.b64encode(raw_bytes).decode("utf-8")
    return "data:image;base64," + payload

def pil_to_base64(pil_img, format="PNG"):
    """
    Encode an image object as a base64 data-URI string.

    The image is saved into an in-memory buffer using its own `format` attribute
    when that is set, otherwise using the `format` argument.

    :param pil_img: image object exposing `.format` and `.save(fp, format=...)`
    :param format: fallback format name used when `pil_img.format` is empty
    :return: string of the form "data:image;base64,<payload>"
    """
    # Prefer the image's own format; fall back to the supplied default.
    chosen_format = pil_img.format or format

    memory_file = BytesIO()
    pil_img.save(memory_file, format=chosen_format)

    payload = base64.b64encode(memory_file.getvalue()).decode("utf-8")
    return "data:image;base64," + payload

def swap_and_rebuild_dict(nested_dict):
    """
    Swap the outer and inner keys of a two-level nested dictionary.

    Input:
        nested_dict: shaped like {outer_key: {inner_key: value}}
    Output:
        new dict shaped like {inner_key: {outer_key: value}}, with its top-level
        entries in sorted order
    """
    swapped = {}
    for outer_key, inner_dict in nested_dict.items():
        for inner_key, value in inner_dict.items():
            # Create the bucket for this inner key on first sight, then fill it.
            swapped.setdefault(inner_key, {})[outer_key] = value

    ordered_items = sorted(swapped.items())
    return dict(ordered_items)

def split_and_resize_image(img_base64, grid_size=2, upscale_factor=2):
    """
    Split a base64-encoded image into a grid_size x grid_size grid and upscale each tile.

    Images whose array has no channel axis (e.g. grayscale) are supported as well
    as multi-channel images.

    Args:
        img_base64 (str): base64-encoded image data; if the string contains a comma
            (data-URI form), only the part after the first comma is decoded
        grid_size (int): number of tiles per side, default 2 (2x2 grid)
        upscale_factor (int): scale factor applied to each tile, default 2

    Returns:
        List[str]: grid_size^2 base64 data-URI strings (PNG-encoded), ordered
        row by row from top-left to bottom-right
    """
    # Step 1: decode the base64 image
    if ',' in img_base64:
        img_base64 = img_base64.split(',')[1]
    image_data = base64.b64decode(img_base64)
    image = Image.open(BytesIO(image_data))
    img_array = np.array(image)

    # Read only the first two axes so arrays without a channel axis
    # (2D, e.g. grayscale) do not fail on unpacking.
    H, W = img_array.shape[:2]

    # Tile size; floor division drops any remainder rows/columns at the far edges.
    h_step = H // grid_size
    w_step = W // grid_size

    resized_images = []

    # Step 2: cut out each tile and upscale it
    for i in range(grid_size):
        for j in range(grid_size):
            # Slice only rows and columns; a trailing channel axis, if any, is kept whole.
            part = img_array[i*h_step : (i+1)*h_step, j*w_step : (j+1)*w_step]

            # Upscale the tile; PIL expects the size as (width, height).
            pil_img = Image.fromarray(part)
            new_size = (round(part.shape[1] * upscale_factor), round(part.shape[0] * upscale_factor))
            resized_img = pil_img.resize(new_size, Image.BILINEAR)

            # Encode the tile as PNG (lossless) into an in-memory buffer.
            buffered = BytesIO()
            resized_img.save(buffered, format="PNG")

            base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            resized_images.append(f"data:image;base64,{base64_str}")

    return resized_images

import jsonlines
def load_dataset_Vstar_json(path):
    """
    Load a JSON file of records and remap each record's fields to a new set of keys.

    Mapping per record: "id" -> "id", "question" -> "Text", "labels" -> "Ground truth",
    "image_path" -> "image", "category" -> "category"; "box_json" is copied only
    when the record has it.

    :param path: path of the JSON file (read as UTF-8)
    :return: list of remapped dictionaries, in file order
    """
    with open(path, 'r', encoding='utf-8') as handle:
        raw_records = json.load(handle)

    converted = []
    for record in raw_records:
        entry = {
            "id": record["id"],
            "Text": record["question"],
            "Ground truth": record["labels"],
            "image": record["image_path"],
        }
        # Optional field: carried over only when present in the source record.
        if "box_json" in record:
            entry["box_json"] = record["box_json"]
        entry["category"] = record["category"]
        converted.append(entry)
    return converted

def flatten_and_sort_boxes(img_merged_boxes, y_threshold=10):
    """
    Flatten img_merged_boxes and order the boxes top-to-bottom, then left-to-right.

    Boxes are sorted by center y (ties by center x) and grouped into rows: a box
    joins the current row when its center y is within `y_threshold` of the
    previous box's center y. Each row is then ordered by center x.

    Args:
        img_merged_boxes: dict shaped like {word: {imgidx: [bbox1, bbox2, ...]}},
            each bbox being (x0, y0, x1, y1)
        y_threshold: float, maximum center-y distance to the previous box for
            staying in the same row

    Returns:
        sorted_list: list of tuple (word, imgidx, boxidx)
    """
    # Step 1: flatten into records (center_x, center_y, word, imgidx, boxidx)
    records = []
    for word, per_image in img_merged_boxes.items():
        for imgidx, bboxes in per_image.items():
            for boxidx, bbox in enumerate(bboxes):
                x0, y0, x1, y1 = bbox
                records.append(((x0 + x1) / 2, (y0 + y1) / 2, word, imgidx, boxidx))

    # Step 2: order by center y first, center x second
    ordered = sorted(records, key=lambda rec: (rec[1], rec[0]))

    # Step 3: walk the ordered records and cut a new row whenever the vertical
    # gap to the previous record exceeds the threshold
    rows = []
    previous_y = None
    for rec in ordered:
        cy = rec[1]
        same_row = previous_y is None or abs(cy - previous_y) <= y_threshold
        if same_row and rows:
            rows[-1].append(rec)
        else:
            rows.append([rec])
        previous_y = cy

    # Step 4: order each row by center x and emit (word, imgidx, boxidx)
    return [
        (rec[2], rec[3], rec[4])
        for row in rows
        for rec in sorted(row, key=lambda item: item[0])
    ]