import os, json
import argparse
import numpy as np
import random

### Paths and constants
# All data locations are resolved relative to this script's own directory
# (via __file__), not the current working directory.
current_path = os.path.abspath(__file__)
DATA_DIR = os.path.dirname(current_path)
# Screenshot images; build_test joins each step's `screenshot` filename onto this.
pic_base = os.path.join(DATA_DIR, 'screenshots')
# Per-episode annotation JSON files read by build_test.
anno_base = os.path.join(DATA_DIR, 'annotations')
# train_anno_base = os.path.join(DATA_DIR, 'train_anno')
# Output directory for generated test-set JSON files; make_his_idx scans it by default.
test_anno_base = os.path.join(DATA_DIR, 'test_anno')
# split_base = os.path.join(DATA_DIR, 'splits')

# NOTE(review): PROMPT is not referenced anywhere in this file's visible code;
# presumably consumed by another module - confirm before removing.
PROMPT = "I'm looking for guidance on how to "

###

def decode_action(action, info):
    """Convert a raw action record into its ground-truth answer string.

    Args:
        action: Action type: 'CLICK', 'LONG_PRESS', 'SCROLL', 'TEXT',
            'COMPLETE' or 'INCOMPLETE'.
        info: Action payload.
            - CLICK / LONG_PRESS: either a system-key name ('KEY_HOME',
              'KEY_BACK', 'KEY_APPSELECT'), or a list/tuple whose first
              element is the coordinate sequence of the target.
            - SCROLL: a [start, end] pair of (x, y) points.
            - TEXT: the text to type.

    Returns:
        The answer string, e.g. 'CLICK: (x, y)', 'SCROLL: UP', 'TYPE: ...',
        'PRESS_HOME', 'PRESS_BACK', 'PRESS_RECENT', 'COMPLETE' or
        'IMPOSSIBLE'.

    Raises:
        ValueError: For an unknown action type or an unrecognised click payload.
    """
    if action in ('CLICK', 'LONG_PRESS'):
        # System keys are recorded as click targets; map them to press actions.
        if info == 'KEY_HOME':
            return 'PRESS_HOME'
        if info == 'KEY_BACK':
            return 'PRESS_BACK'
        if info == 'KEY_APPSELECT':
            return 'PRESS_RECENT'
        # isinstance (not type(...) == list) so tuples / list subclasses work too.
        if isinstance(info, (list, tuple)):
            return f"{action}: {tuple(info[0])}"
        raise ValueError(f'Unknown click action {info}')

    if action == 'SCROLL':
        start, end = info[0], info[1]
        dx = end[0] - start[0]
        dy = end[1] - start[1]
        # The dominant axis decides the direction; ties go to the vertical
        # axis, and zero movement on the chosen axis reports RIGHT / DOWN.
        if abs(dx) > abs(dy):
            return f"SCROLL: {'LEFT' if dx < 0 else 'RIGHT'}"
        return f"SCROLL: {'UP' if dy < 0 else 'DOWN'}"

    if action == 'TEXT':
        return f'TYPE: {info}'
    if action == 'COMPLETE':
        return action
    if action == 'INCOMPLETE':
        return 'IMPOSSIBLE'
    raise ValueError(f'Unknown action {action}')


def _natural_key(fname: str):
    stem = os.path.splitext(fname)[0]
    return (0, int(stem)) if stem.isdigit() else (1, stem.lower())

def build_test(his_len=4, instr_level='high', out_name=None):
    """Flatten every annotated episode into per-step test samples and save them.

    Reads all ``*.json`` episode files in ``anno_base`` in natural filename
    order. Emits one sample per step, then writes the whole list as JSON to
    ``test_anno_base/<out_name>``. Episodes whose file cannot be read or
    parsed are skipped with a warning.

    Args:
        his_len: Accepted for interface/CLI compatibility but not used here.
            Every sample records the full history of its episode up to that
            step. NOTE(review): presumably truncated downstream - confirm.
        instr_level: 'high' uses the episode-level ``task_info.instruction``;
            'low' uses each step's ``low_level_instruction``.
        out_name: Output filename; defaults to ``'<instr_level>_all.json'``.

    Raises:
        ValueError: If ``instr_level`` is neither 'high' nor 'low'.
    """
    # Validate up front. Previously any other value left `instruction` unbound
    # and crashed mid-run with UnboundLocalError on the first step.
    if instr_level not in ('high', 'low'):
        raise ValueError(f"instr_level must be 'high' or 'low', got {instr_level!r}")

    os.makedirs(test_anno_base, exist_ok=True)
    files = sorted(
        (f for f in os.listdir(anno_base) if f.lower().endswith('.json')),
        key=_natural_key,
    )

    res = []
    for f in files:
        fp = os.path.join(anno_base, f)
        try:
            with open(fp, 'r', encoding='utf-8') as fh:
                data = json.load(fh)
        except (OSError, ValueError) as e:
            # Best-effort: skip unreadable or malformed episodes and keep going.
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            print(f"[Warn] 跳过无法解析的 JSON: {fp} ({e})")
            continue

        task_info = data.get('task_info', {})
        high_level_instruction = task_info.get('instruction', '')
        category = task_info.get('category', None)

        steps = data.get('steps', [])
        step_length = data.get('step_length', len(steps))

        # History is reset per episode and snapshotted with str() before the
        # current step is appended, so each sample sees only preceding steps.
        history_screenshot = []
        history_action = []

        for step in steps:
            img_abs_path = os.path.join(pic_base, step['screenshot'])
            if instr_level == 'high':
                instruction = high_level_instruction
            else:
                instruction = step['low_level_instruction']
            gt = decode_action(step['action'], step['info'])

            res.append({
                # Global running index across all episodes.
                'id': f'GUIOdyssey_{len(res)}',
                'image': img_abs_path,
                'question': instruction,
                'answer': gt,
                'category': category,
                'step_length': step_length,
                'history_action': str(history_action),
                'history_screenshot': str(history_screenshot),
                'sam2_bbox': step['sam2_bbox'],
            })

            history_screenshot.append(img_abs_path)
            history_action.append(gt)

    if out_name is None:
        out_name = f'{instr_level}_all.json'
    out_fp = os.path.join(test_anno_base, out_name)
    # ensure_ascii=False emits raw non-ASCII characters, so write UTF-8
    # explicitly instead of relying on the platform's locale encoding.
    with open(out_fp, 'w', encoding='utf-8') as fh:
        json.dump(res, fh, indent=4, ensure_ascii=False)
    print(f"[Done] 共写入 {len(res)} 条测试样本 -> {out_fp}")


def make_his_idx(test_base=test_anno_base, save_path='./his_index.json'):
    """Build an image -> screenshot-history index from generated test files.

    Scans every ``*.json`` file in ``test_base``; each is expected to hold a
    list of sample dicts such as those written by ``build_test``. Each
    sample's ``image`` path is mapped to the list parsed from its
    ``history_screenshot`` field, and the mapping is dumped to ``save_path``.

    Args:
        test_base: Directory containing the test-split JSON files.
        save_path: Output path for the index. Relative paths resolve against
            the current working directory.

    Raises:
        ValueError: If the same image appears with two different histories.
    """
    import ast  # local import: only needed here for safe literal parsing

    his_dict = {}
    # sorted() makes the scan order (and the output key order) deterministic.
    for subsplit in sorted(os.listdir(test_base)):
        subp = os.path.join(test_base, subsplit)
        # Skip directories and non-JSON files, which json.load would choke on.
        if not (subsplit.lower().endswith('.json') and os.path.isfile(subp)):
            continue
        with open(subp, 'r', encoding='utf-8') as fh:
            data_all = json.load(fh)

        for data in data_all:
            img = data.get('image')
            if img is None:
                continue

            history = data.get('history_screenshot', '[]')
            if isinstance(history, list):
                history_list = history
            else:
                # build_test stores the history as str(list). literal_eval parses
                # that without ever executing code, unlike eval().
                try:
                    history_list = ast.literal_eval(history)
                except (ValueError, TypeError, SyntaxError):
                    # Deliberate best-effort: an unparsable history is treated as empty.
                    history_list = []

            if img in his_dict:
                # The same screenshot may appear in several splits; the recorded
                # histories must agree. Raise explicitly, since assert is
                # stripped under `python -O`.
                if his_dict[img] != history_list:
                    raise ValueError(
                        f"Inconsistent history for {img}: "
                        f"{his_dict[img]!r} vs {history_list!r}"
                    )
            else:
                his_dict[img] = history_list

    print(f"[Info] his_index 中共 {len(his_dict)} 个键")
    with open(save_path, 'w', encoding='utf-8') as fh:
        json.dump(his_dict, fh, indent=4, ensure_ascii=False)
    print(f"[Done] 写入 {save_path}")
    
    
def main(args):
    """Run the pipeline: build the test annotations, then the history index.

    Args:
        args: Parsed CLI namespace. ``args.his_len`` and ``args.level`` are
            forwarded to ``build_test``; ``make_his_idx`` runs with its
            defaults.
    """
    build_test(
        instr_level=args.level,
        his_len=args.his_len,
        out_name=None,
    )
    make_his_idx()

if __name__ == "__main__":
    # Command-line entry point.
    #   --his_len : integer forwarded to build_test (default 4)
    #   --level   : instruction granularity, 'high' or 'low' (default 'low')
    parser = argparse.ArgumentParser()
    parser.add_argument('--his_len', default=4, type=int)
    parser.add_argument('--level', default='low', type=str, choices=['high', 'low'])
    main(parser.parse_args())