import math, os
import numpy as np
import scipy.ndimage as ndi
import unicodedata
import json
import re

from datasets import load_dataset, Dataset
from os.path import basename, dirname, join
from PIL import Image


# Prompt template for a grounding query. It has three named placeholders:
# {question} (the target description), {w} and {h} (the image width and
# height). The model is asked to put its reasoning in <think> tags and the
# final "(x,y)" coordinate in <answer> tags, which is the format that
# extract_xy() parses.
# NOTE(review): presumably filled in with str.format(question=..., w=..., h=...)
# by callers outside this file - confirm.
QUESTION_TMPL = (
    "What is the coordinate of [{question}] in the image?\n"
    "The size of image is ({w},{h}).\n"
    "Output the thinking process in <think> </think> and "
    "final answer (coordinate (x,y)) in <answer> </answer> tags."
)


def extract_xy(txt):
    """Pull the integer ``(x, y)`` pair out of an ``<answer>`` tag in *txt*.

    The tag match is case-insensitive; the surrounding parentheses and any
    whitespace around the numbers are optional. Only unsigned integers are
    recognised. The first matching tag in the text is used.

    Returns:
        ``(x, y)`` as ints, or ``(None, None)`` when no well-formed answer
        tag is found.
    """
    pattern = r"<answer>\s*\(?\s*(\d+)\s*,\s*(\d+)\s*\)?\s*</answer>"
    match = re.search(pattern, txt, flags=re.I)
    if match is None:
        return (None, None)
    return tuple(int(group) for group in match.groups())


def in_box(xy, bbox, img_size=None):
    """Tell whether point *xy* lies inside *bbox* (edges inclusive).

    Args:
        xy: ``(x, y)`` point; a point whose x is ``None`` is never inside.
        bbox: ``[x0, y0, x1, y1]``.
        img_size: optional ``(W, H)``. When given (and truthy), *xy* is
            divided by the image size before being compared, i.e. *bbox* is
            treated as relative coordinates. Otherwise *xy* and *bbox* are
            compared as-is.
    """
    if xy[0] is None:
        return False

    if not img_size:
        # Absolute mode: compare the raw point against the box.
        x_inside = bbox[0] <= xy[0] <= bbox[2]
        return x_inside and bbox[1] <= xy[1] <= bbox[3]

    # Relative mode: scale the point by the image size first.
    width, height = img_size
    rel_x = xy[0] / width
    rel_y = xy[1] / height
    x_inside = bbox[0] <= rel_x <= bbox[2]
    return x_inside and bbox[1] <= rel_y <= bbox[3]


def to_list(xy):
    """Convert an ``(x, y)`` point to ``[int(x), int(y)]``.

    A point whose x is ``None`` (the "no answer" marker used in this file)
    becomes an empty list.
    """
    if xy[0] is None:
        return []
    return [int(xy[0]), int(xy[1])]


def compute_KDE(pts_xy, img_size, sigma_frac=0.01, truncate=4.0):
    """Return the densest location among a set of predicted points.

    Every point adds a unit impulse to a bitmap, the bitmap is blurred with
    a Gaussian, and the position of the maximum of the blurred map is
    returned. Only the tight bounding box of the points is rasterised, so
    the cost depends on the spread of the points rather than the image size.

    Args:
        pts_xy: iterable of ``(x, y)`` points. Points that are missing
            (empty, or with a ``None`` coordinate, e.g. the ``(None, None)``
            produced by a failed ``extract_xy``) are ignored. Coordinates
            are truncated to ``int``.
        img_size: ``(W, H)`` of the image; used for the blur radius and to
            clamp the result into the image.
        sigma_frac: Gaussian sigma as a fraction of ``min(W, H)``.
        truncate: kernel cut-off in sigmas, forwarded to
            ``scipy.ndimage.gaussian_filter``.

    Returns:
        ``(x, y)`` as Python ints, or ``(None, None)`` when there is no
        usable point.
    """
    W, H = img_size
    if pts_xy is None:
        return (None, None)

    # Drop failed predictions and coerce to ints so the values can be used
    # as array shape / indices.
    pts = [
        (int(p[0]), int(p[1]))
        for p in pts_xy
        if len(p) >= 2 and p[0] is not None and p[1] is not None
    ]
    if not pts:
        return (None, None)

    # find tight bbox
    xs, ys = zip(*pts)
    xmin, xmax = min(xs), max(xs)
    ymin, ymax = min(ys), max(ys)

    # make a fit bmp: one unit impulse per point (duplicates accumulate)
    w_sub = xmax - xmin + 1
    h_sub = ymax - ymin + 1
    bmp_sub = np.zeros((h_sub, w_sub), dtype=np.float32)
    for x, y in pts:
        bmp_sub[y - ymin, x - xmin] += 1.0

    # blur with zero-padding: outside the tight bbox there are no impulses,
    # so padding with zeros matches blurring the full-size bitmap
    sigma = sigma_frac * min(W, H)
    heat_sub = ndi.gaussian_filter(
        bmp_sub,
        sigma=sigma,
        truncate=truncate,
        mode='constant',
        cval=0.0,
    )

    # find peak in subwindow, translate back to global coords and clamp
    # into the image
    yi_sub, xi_sub = np.unravel_index(int(heat_sub.argmax()), heat_sub.shape)
    xi_glob = min(max(xmin + int(xi_sub), 0), W - 1)
    yi_glob = min(max(ymin + int(yi_sub), 0), H - 1)
    return (xi_glob, yi_glob)


# load and index annotations
def normalize_key(key: str) -> str:
    """Canonicalise a file-name key so differently-encoded names compare equal.

    Applies Unicode NFKC normalisation, then replaces any narrow no-break
    space (U+202F) with a plain space.
    """
    return unicodedata.normalize('NFKC', key).replace('\u202f', ' ')


def load_screenspot():
    """Load the ``test`` split of ``rootsautomation/ScreenSpot`` via ``load_dataset``."""
    return load_dataset('rootsautomation/ScreenSpot', split='test')


def load_screenspot_pro():
    """Build a ScreenSpot-Pro dataset by joining images with local annotations.

    The images come from the ``train`` split of ``likaixin/ScreenSpot-Pro``.
    For every class label a JSON file ``<class name>.json`` is read from the
    ``screenspot_pro_annotations`` directory (relative to the working
    directory). Each annotation record is looked up by its normalised
    ``img_filename``; images are matched through the normalised
    ``"<parent dir>/<file name>"`` of their file path. Images with no
    matching annotation are skipped silently.

    Returns:
        A ``Dataset`` whose rows have ``image``, ``instruction``, ``bbox``,
        ``ui_type``, ``id``, ``application`` and ``platform``.
    """
    annotations_dir = 'screenspot_pro_annotations'
    images = load_dataset('likaixin/ScreenSpot-Pro', split='train')

    # label id -> {normalised image key -> annotation record}
    by_label = {}
    for label_id, class_name in enumerate(images.features['label'].names):
        ann_path = join(annotations_dir, f"{class_name}.json")
        with open(ann_path, 'r', encoding='utf-8') as fh:
            records = json.load(fh)
        by_label[label_id] = {
            normalize_key(record['img_filename']): record for record in records
        }

    # merge image-folder dataset with the annotation records
    merged = []
    for example in images:
        label_id = example['label']
        parent, file_name = os.path.split(example['image'].filename)
        folder = basename(parent)  # e.g. 'common_windows'
        lookup = normalize_key(f"{folder}/{file_name}")

        record = by_label.get(label_id, {}).get(lookup)
        if record is None:
            # skip examples that don't have matching annotations
            continue

        row = {
            'image': example['image'],
            'instruction': record['instruction'],
            'bbox': record['bbox'],  # absolute pixels
        }
        for optional in ('ui_type', 'id', 'application', 'platform'):
            row[optional] = record.get(optional)
        merged.append(row)

    return Dataset.from_list(merged)


def _load_jsonl_with_screenshots(image_dir, query_path, with_ans_block=False):
    """Read a JSONL query file and attach the referenced screenshot to each record.

    Shared implementation behind ``load_android_control``,
    ``load_multimodal_mind2web`` and ``load_omniact``.

    Each non-blank line of *query_path* is parsed as JSON. The record's
    ``image`` value is normalised with ``normalize_key`` and joined onto
    *image_dir*. Records with no ``image`` value, or whose image file does
    not exist, are reported with a printed message and skipped.

    Args:
        image_dir: directory the records' ``image`` paths are relative to.
        query_path: path of the UTF-8 JSONL file.
        with_ans_block: when true, also store ``ans_block``: the image file
            name without extension, parsed as ``int``.

    Returns:
        A ``Dataset`` of the original records plus a ``screenshot`` RGB image
        (and ``ans_block`` when requested).

    Raises:
        ValueError: if *with_ans_block* is set and an image file stem is not
            an integer.
    """
    entries = []
    with open(query_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # tolerate blank lines instead of failing in json.loads
                continue
            rec = json.loads(line)
            rel = rec.get('image')
            if not rel:
                print("No key")
                continue
            rel_norm = normalize_key(rel)
            img_path = join(image_dir, rel_norm)
            if not os.path.isfile(img_path):
                print("No image")
                continue
            # convert() returns an independent copy, so the file can be
            # closed as soon as the conversion is done.
            with Image.open(img_path) as src:
                img = src.convert('RGB')

            entry = dict(rec)
            entry['screenshot'] = img
            if with_ans_block:
                stem = os.path.splitext(os.path.basename(rel_norm))[0]
                entry['ans_block'] = int(stem)
            entries.append(entry)

    return Dataset.from_list(entries)


def load_android_control(image_dir, query_path):
    """Load AndroidControl queries from *query_path* with their screenshots.

    Every JSONL record is returned with an extra ``screenshot`` field holding
    the RGB image found at ``image_dir/<record['image']>``. Records without an
    ``image`` value or without an existing image file are skipped.
    """
    return _load_jsonl_with_screenshots(image_dir, query_path)


def load_multimodal_mind2web(image_dir, query_path):
    """Load Multimodal-Mind2Web queries from *query_path* with their screenshots.

    Same as the other JSONL loaders, plus an ``ans_block`` field: the integer
    parsed from the image file name (without extension).
    """
    return _load_jsonl_with_screenshots(image_dir, query_path, with_ans_block=True)


def load_omniact(image_dir, query_path):
    """Load OmniACT queries from *query_path* with their screenshots.

    Every JSONL record is returned with an extra ``screenshot`` field holding
    the RGB image found at ``image_dir/<record['image']>``. Records without an
    ``image`` value or without an existing image file are skipped.
    """
    return _load_jsonl_with_screenshots(image_dir, query_path)


def load_screenspot_agent(query_path):
    """Attach ScreenSpot screenshots to the records of a JSONL query file.

    The ``test`` split of ``rootsautomation/ScreenSpot`` is indexed by
    ``(normalised file_name, instruction)``. Each non-blank line of
    *query_path* is parsed as JSON and matched against that index through its
    ``image`` and ``instruction`` values. Records with no ``image`` value or
    with no matching dataset row are reported with a printed message and
    skipped.

    Args:
        query_path: path of the UTF-8 JSONL file.

    Returns:
        A ``Dataset`` of the original records plus a ``screenshot`` image.
    """
    ds = load_dataset('rootsautomation/ScreenSpot', split='test')

    # Read the two text columns directly. Indexing ds[i] fetches a whole row,
    # image included, which is wasteful when only the instruction is needed.
    file_names = ds['file_name']
    if 'instruction' in ds.column_names:
        instructions = ds['instruction']
    else:
        instructions = [None] * ds.num_rows

    # On duplicate (file, instruction) pairs the last row wins.
    pair_map = {}
    for i, (fn, instr) in enumerate(zip(file_names, instructions)):
        pair_map[(normalize_key(fn), instr)] = i

    entries = []
    with open(query_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # tolerate blank lines instead of failing in json.loads
                continue
            rec = json.loads(line)
            rel = rec.get('image')
            if not rel:
                print("No key in query JSON:", rec)
                continue

            key = normalize_key(rel)
            qry_instr = rec.get('instruction')
            idx = pair_map.get((key, qry_instr))
            if idx is None:
                print(f"No HF entry for (file={key!r}, instr={qry_instr!r})")
                continue

            img_field = ds[idx]['image']
            if isinstance(img_field, Image.Image):
                img = img_field
            else:
                # convert() returns an independent copy, so the source can
                # be closed right after the conversion.
                with Image.open(img_field) as src:
                    img = src.convert('RGB')

            entry = dict(rec)
            entry['screenshot'] = img
            entries.append(entry)

    return Dataset.from_list(entries)


def dump(path, recs):
    """Write *recs* to *path* as UTF-8 JSON Lines, one record per line.

    Non-ASCII characters are written as-is rather than escaped. An existing
    file at *path* is overwritten.
    """
    with open(path, 'w', encoding='utf-8') as out:
        out.writelines(
            json.dumps(record, ensure_ascii=False) + '\n' for record in recs
        )