
# -*- coding: utf-8 -*-
import json
from collections import defaultdict
from typing import Any, Dict, List, Tuple

from preprocess_visual import file_paths


class TextDataProcessor:
    """Select the samples that line up with successful visual-action predictions.

    For every dataset listed in ``dataset_names`` the processor loads a
    test-data JSON file and the matching result JSON file, then keeps the
    entries of both files that correspond to successful predictions whose
    action type is one of ``VISUAL_ACTION_TYPES``.
    """

    # Action-type codes that make a successful result entry eligible.
    # NOTE(review): presumably the benchmark's "visual" action ids -- confirm.
    VISUAL_ACTION_TYPES = (4, 5, 6, 8, 9)

    def __init__(self, model_name: str, dataset_type: str,
                 file_paths: Dict[str, Dict[str, Dict[str, str]]]):
        """Pick the path table for one ``dataset_type`` / ``model_name`` pair.

        Args:
            model_name: Second-level key into ``file_paths``.
            dataset_type: First-level key into ``file_paths``.
            file_paths: Nested mapping ``dataset_type -> model_name -> paths``
                where ``paths`` holds ``"<dataset>_test_data"`` and
                ``"<dataset>_res_data"`` entries for every dataset name.

        Raises:
            KeyError: If ``dataset_type`` or ``model_name`` is not present.
        """
        self.model_name = model_name
        self.dataset_type = dataset_type
        self.file_paths = file_paths[dataset_type][model_name]
        self.dataset_names = ['AndroidControl', 'AITZ', 'GUI_Odyssey']

    @staticmethod
    def read_json(path: str) -> Any:
        """Load and return the JSON document stored at ``path`` (UTF-8)."""
        with open(path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def get_paired_data(self) -> List[tuple]:
        """Return one ``(test_data, res_data)`` pair per entry of ``dataset_names``.

        The pairs are produced in ``dataset_names`` order, so they can be
        zipped with that list without any index bookkeeping.
        """
        return [
            (
                self.read_json(self.file_paths[f"{name}_test_data"]),
                self.read_json(self.file_paths[f"{name}_res_data"]),
            )
            for name in self.dataset_names
        ]

    @staticmethod
    def get_visual_data(
        test_data: List[Dict],
        res_data: Dict,
        dataset_name: str,
        action_types: Tuple[int, ...] = VISUAL_ACTION_TYPES,
    ) -> Tuple[List[Dict], List[Dict]]:
        """Filter test and result entries by the successful actions in ``res_data``.

        Args:
            test_data: Test samples; each needs ``images`` (first element is
                compared) and ``step_id``.
            res_data: Result document with a ``detailed_results`` list whose
                entries carry ``action_type``, ``is_success``, ``image_path``
                (first element is compared) and ``step_id``.
            dataset_name: Value stored under ``"dataset_name"`` in every
                returned entry.
            action_types: Action-type codes that qualify a successful result.

        Returns:
            ``(mask_test_data, mask_res_data)``: shallow copies of the matching
            entries, each extended with ``"dataset_name"``.
        """
        detailed_results = res_data['detailed_results']
        successful = [
            item for item in detailed_results
            if item['action_type'] in action_types and item['is_success']
        ]
        # Sets give O(1) membership tests; the lists used before made the two
        # filters below quadratic in the number of entries.
        image_paths = {item['image_path'][0] for item in successful}
        step_ids = {item['step_id'] for item in successful}

        # NOTE(review): image path and step id are checked independently, not
        # as a matched (image, step) pair -- confirm this is the intended rule.
        mask_test_data = [
            {**item, "dataset_name": dataset_name}
            for item in test_data
            if item['images'][0] in image_paths and item['step_id'] in step_ids
        ]
        mask_res_data = [
            {**item, "dataset_name": dataset_name}
            for item in detailed_results
            if item['image_path'][0] in image_paths and item['step_id'] in step_ids
        ]
        return mask_test_data, mask_res_data

    def process(self) -> Tuple[List[Dict], List[Dict]]:
        """Filter every dataset and return the concatenated results.

        Returns:
            ``(test_entries, result_entries)`` accumulated over all datasets in
            ``dataset_names`` order.
        """
        all_test_entries: List[Dict] = []
        all_res_entries: List[Dict] = []
        for dataset_name, (test_data, res_data) in zip(self.dataset_names, self.get_paired_data()):
            mask_test_data, mask_res_data = self.get_visual_data(test_data, res_data, dataset_name)
            all_test_entries.extend(mask_test_data)
            all_res_entries.extend(mask_res_data)
        return all_test_entries, all_res_entries

    def saveJson(self, result, output_path):
        """Write ``result`` to ``output_path`` as indented UTF-8 JSON and log the count."""
        with open(output_path, 'w', encoding='utf-8') as file:
            json.dump(result, file, ensure_ascii=False, indent=2)
        print(f"[INFO] Saved {len(result)} items to {output_path}")

    def save(self, output_path1: str, output_path2: str) -> Tuple[List[Dict], List[Dict]]:
        """Run ``process`` and save both result lists.

        Args:
            output_path1: Destination for the filtered test entries.
            output_path2: Destination for the filtered result entries.

        Returns:
            The same ``(test_entries, result_entries)`` tuple that was written.
        """
        result1, result2 = self.process()
        self.saveJson(result1, output_path1)
        self.saveJson(result2, output_path2)
        return result1, result2


class ReplaceDataProcessor:
    """Mask one sentence inside model-specific prompt samples.

    Every public method knows where one prompt layout keeps the sentence to
    mask. When ``mode`` equals ``'visual_shortcuts'`` the sentence is replaced
    by ``"[]"`` in place; for any other mode the items are left untouched.
    In both cases ``data`` is then written to ``output_file``.
    """

    _MASK_MODE = 'visual_shortcuts'
    _PLACEHOLDER = "[]"
    # Raised when an item does not follow the expected layout (missing key or
    # marker, wrong container type); such items are skipped, not fatal.
    _LAYOUT_ERRORS = (LookupError, TypeError, AttributeError)

    def UI_TARS(self, data: List[Dict[str, Any]], output_file: str, mode):
        """Mask the text after ``'Thought: '`` on the first line of the last message."""
        def mask_item(item):
            part = item['messages'][-1]['content'][0]
            sentence = part['text'].split('\n')[0].split('Thought: ')[1]
            part['text'] = self._mask(part['text'], sentence)

        self._mask_items(data, output_file, mode, mask_item)

    def GUI_Owl(self, data: List[Dict[str, Any]], output_file: str, mode: str):
        """Mask the text after ``'The user query: '`` on the first line of message 1."""
        def mask_item(item):
            part = item['messages'][1]['content'][0]
            sentence = part['text'].split('\n')[0].split('The user query: ')[1]
            part['text'] = self._mask(part['text'], sentence)

        self._mask_items(data, output_file, mode, mask_item)

    def OS_Genesis(self, data: List[Dict[str, Any]], output_file: str, mode: str):
        """Mask the text after ``'Low-level thought: '`` on the third-to-last line of ``question``."""
        def mask_item(item):
            sentence = item['question'].split('\n')[-3].split('Low-level thought: ')[1]
            item['question'] = self._mask(item['question'], sentence)

        self._mask_items(data, output_file, mode, mask_item)

    def Aguvis(self, data: List[Dict[str, Any]], output_file: str, mode: str):
        """Mask the text after ``'low_level_instruction: '`` in the last paragraph."""
        def mask_item(item):
            # NOTE(review): 'messages' is indexed by key here, unlike the other
            # layouts which index a list -- confirm against the Aguvis data; if
            # it is a list every item is skipped.
            part = item['messages']['content'][1]
            sentence = part['text'].split("\n\n")[-1].split("low_level_instruction: ")[1]
            part['text'] = self._mask(part['text'], sentence)

        self._mask_items(data, output_file, mode, mask_item)

    def Agent_CPM(self, data: List[Dict[str, Any]], output_file: str, mode: str):
        """Mask the text between ``<Question>`` and ``</Question>`` in the first message."""
        def mask_item(item):
            content = item['messages'][0]['content']
            sentence = content[0].split("<Question>")[-1].split("</Question>")[0]
            content[0] = self._mask(content[0], sentence)

        self._mask_items(data, output_file, mode, mask_item)

    @classmethod
    def _mask(cls, text: str, sentence: str) -> str:
        """Return ``text`` with every occurrence of ``sentence`` replaced by the placeholder.

        An empty ``sentence`` leaves ``text`` unchanged: ``str.replace`` with an
        empty pattern would insert the placeholder between every character.
        """
        if not sentence:
            return text
        return text.replace(sentence, cls._PLACEHOLDER)

    def _mask_items(self, data, output_file, mode, mask_item):
        """Apply ``mask_item`` to each item when masking is enabled, then save ``data``.

        Items whose layout does not match are left as they are and counted so
        the run reports how many samples were not masked.
        """
        if mode == self._MASK_MODE:
            skipped = 0
            for item in data:
                try:
                    mask_item(item)
                except self._LAYOUT_ERRORS:
                    skipped += 1
            if skipped:
                print(f"[WARN] ReplaceDataProcessor: {skipped} of {len(data)} items "
                      f"did not match the expected layout and were left unmasked")
        self.save(output_file, data)

    def save(self, output_file, data):
        """Write ``data`` to ``output_file`` as indented UTF-8 JSON and log the count."""
        with open(output_file, 'w', encoding='utf-8') as outfile:
            json.dump(data, outfile, ensure_ascii=False, indent=2)

        print(f"[INFO] ReplaceDataProcessor: saved {len(data)} items to {output_file}")


if __name__ == "__main__":
    # Run configuration.
    model = "UI-TARS-7B-SFT"
    model_type = "UI_TARS"
    dataset_type = "Low"
    mode = 'action_shortcuts'

    # Both output files live in the same mode-specific directory.
    output_dir = f"/Agent_ScanKit/datasets/json/structure_mask/{mode}"
    masked_path = f"{output_dir}/{model}.json"
    raw_path = f"{output_dir}/{model}_raw.json"

    # 1) Build the visual data and write the filtered test / result entries.
    visual_processor = TextDataProcessor(model, dataset_type, file_paths)
    mask_test_data, mask_res_data = visual_processor.save(masked_path, raw_path)

    # 2) Run the model-specific replacer over the filtered test entries. It
    #    writes to the same path as the first file saved above, overwriting it.
    #    An unknown model_type is a no-op.
    replacer = ReplaceDataProcessor()
    handlers = {
        'UI_TARS': replacer.UI_TARS,
        'GUI_Owl': replacer.GUI_Owl,
        'OS_Genesis': replacer.OS_Genesis,
        'Aguvis': replacer.Aguvis,
        'Agent_CPM': replacer.Agent_CPM,
    }
    handler = handlers.get(model_type)
    if handler is not None:
        handler(mask_test_data, masked_path, mode)

