#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
从结果JSON文件生成statistics_file（类似evaluate_test.py生成的格式）
"""

import os
import json
import argparse
from collections import defaultdict
from operator import itemgetter
from itertools import groupby


def get_action_type_mapping():
    """
    Return the action-type id -> action name mapping for the result files.

    Per the original note, this mirrors the mapping in result_preprocess.py
    (that file is not visible here). Ids are consecutive integers starting
    at 1, in the order listed below. A new dict is built on every call.
    """
    ordered_names = (
        "CLICK",
        "LONG_CLICK",
        "TYPE",
        "SCROLL",
        "PRESS_BACK",
        "PRESS_HOME",
        "OPENAPP",
        "WAIT",
        "COMPLETE",
        "ENTER",
        "PRESS_APPSELECT",
        "INCOMPLETE",
        "MOVETO",
        "PRESS_SPACE",
    )
    return dict(enumerate(ordered_names, start=1))


def compute_single_metrics(results, type_to_name):
    """Compute type-match / action-match statistics for one result set.

    Each record is a dict read with ``.get``:
      - ``is_success``    truthy when the predicted action matched (default False)
      - ``is_type_match`` truthy when the predicted action type matched (default False)
      - ``action_type``   ground-truth action type id (default 0)

    Every record counts towards ``"TOTAL"``. Records whose ``action_type`` is a
    key of *type_to_name* also count towards that action's bucket. Unknown or
    missing types count only in ``"TOTAL"``.

    If at least one record has an ``episode_id``, records are grouped by it.
    An episode is successful only when every record in it has a truthy
    ``is_success``. Records without an ``episode_id`` are pooled into one group
    (key ``None``) instead of raising ``KeyError``.

    Args:
        results: list of result dicts. It is not modified.
        type_to_name: mapping ``{action_type_id: action_name}``.

    Returns:
        dict: for every name in *type_to_name* plus ``"TOTAL"``, a dict with
        ``count``, ``type_match``, ``action_match``, ``TMR`` and ``AMR``
        (ratios rounded to 4 places, 0 when count is 0). When episode ids are
        present it also has ``episode_success_rate``, ``episode_count`` and
        ``success_episode_count``.
    """
    action_names = list(type_to_name.values()) + ["TOTAL"]
    action_counts = dict.fromkeys(action_names, 0)
    action_success = dict.fromkeys(action_names, 0)
    action_type_match = dict.fromkeys(action_names, 0)

    # Episode statistics are only reported when some record carries an id.
    has_episode_id = any("episode_id" in r for r in results)
    # episode_id -> True while every record seen so far for it succeeded.
    # Grouping with a dict needs no sort, so the caller's list is not reordered
    # and records missing step_id/episode_id do not break anything.
    episode_ok = {}

    for record in results:
        is_success = record.get("is_success", False)
        is_type_match = record.get("is_type_match", False)
        gt_type = record.get("action_type", 0)  # ground-truth action type id

        # TOTAL always counts; known types also get their per-action bucket.
        buckets = ["TOTAL"]
        if gt_type in type_to_name:
            buckets.append(type_to_name[gt_type])
        for name in buckets:
            action_counts[name] += 1
            if is_success:
                action_success[name] += 1
            if is_type_match:
                action_type_match[name] += 1

        if has_episode_id:
            episode_key = record.get("episode_id")
            episode_ok[episode_key] = episode_ok.get(episode_key, True) and bool(is_success)

    metrics = {}
    for name in action_names:
        count = action_counts[name]
        tmr = action_type_match[name] / count if count > 0 else 0
        amr = action_success[name] / count if count > 0 else 0
        metrics[name] = {
            "count": count,
            "type_match": action_type_match[name],
            "action_match": action_success[name],
            "TMR": round(tmr, 4),
            "AMR": round(amr, 4),
        }

    if has_episode_id:
        episode_count = len(episode_ok)
        success_episode_count = sum(episode_ok.values())
        episode_success_rate = success_episode_count / episode_count if episode_count > 0 else 0
        metrics["episode_success_rate"] = round(episode_success_rate, 4)
        metrics["episode_count"] = episode_count
        metrics["success_episode_count"] = success_episode_count

    return metrics


def _sorted_data_types(data_types):
    """Sort data_type keys, tolerating mixed non-comparable types (e.g. -1 and "A")."""
    try:
        return sorted(data_types)
    except TypeError:
        # Mixed types (int fallback -1 next to str types, None, ...): fall back
        # to a deterministic order grouped by type name, then by text.
        return sorted(data_types, key=lambda k: (type(k).__name__, str(k)))


def compute_metrics(all_results):
    """Compute evaluation metrics overall and, if available, per ``data_type``.

    Args:
        all_results: list of result dicts (see ``compute_single_metrics``).

    Returns:
        dict: ``{"overall": <metrics>, "by_datatype": {...}}``.
        ``by_datatype`` maps ``f"data_type_{dt}"`` to the metrics of the
        records with that ``data_type``. It is empty when no record has a
        ``data_type`` field. Records missing the field are grouped under -1.
        NOTE(review): distinct values with the same text (e.g. 1 and "1")
        share one output key, and the later one wins.
    """
    type_to_name = get_action_type_mapping()

    metrics_by_datatype = {}
    if any("data_type" in r for r in all_results):
        results_by_datatype = defaultdict(list)
        for item in all_results:
            results_by_datatype[item.get("data_type", -1)].append(item)

        for data_type in _sorted_data_types(results_by_datatype):
            metrics_by_datatype[f"data_type_{data_type}"] = compute_single_metrics(
                results_by_datatype[data_type], type_to_name
            )

    return {
        "overall": compute_single_metrics(all_results, type_to_name),
        "by_datatype": metrics_by_datatype,
    }


def _extract_records(data):
    """Return the record list held by a loaded raw-data JSON document, or []."""
    if isinstance(data, list):
        return data
    if isinstance(data, dict):
        # The first list-valued key among these names holds the records.
        for key in ("data", "detailed_results", "results"):
            if isinstance(data.get(key), list):
                return data[key]
    return []


def load_data_type_mapping(data_dir):
    """
    Load a ``(episode_id, step_id) -> data_type`` mapping from raw data files.

    Args:
        data_dir: either a directory, in which case every ``*.json`` entry
            directly inside it is read in sorted filename order, or a single
            JSON file path.

    Each file is either a JSON list of records, or a JSON object whose first
    list-valued key among ``"data"``, ``"detailed_results"`` and ``"results"``
    holds the records. For every dict record the key is
    ``(str(episode_id), step_id)``: ``episode_id`` is stringified (default
    ``""``) and ``step_id`` is used as-is (default 0). ``data_type`` defaults
    to -1. Non-dict records are skipped. On duplicate keys the later file (in
    sorted order) wins. A file that fails to load is reported and skipped.

    Returns:
        dict: ``{(episode_id, step_id): data_type}``
    """
    mapping = {}

    if os.path.isdir(data_dir):
        # Sorted so that "later file wins" on duplicate keys is deterministic.
        data_files = [
            os.path.join(data_dir, filename)
            for filename in sorted(os.listdir(data_dir))
            if filename.endswith('.json')
        ]
    elif os.path.isfile(data_dir):
        data_files = [data_dir]
    else:
        print(f"  警告: 原始数据路径不存在: {data_dir}")
        data_files = []

    print("正在从原始数据文件加载 data_type 映射...")
    for data_file in data_files:
        print(f"  读取: {data_file}")
        # Deliberate best-effort boundary: one bad file must not abort the run.
        try:
            with open(data_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            for item in _extract_records(data):
                if not isinstance(item, dict):
                    # Skip malformed entries instead of dropping the rest of the file.
                    continue
                key = (str(item.get("episode_id", "")), item.get("step_id", 0))
                mapping[key] = item.get("data_type", -1)
        except Exception as e:
            print(f"  警告: 读取文件失败 {data_file}: {e}")

    print(f"  已加载 {len(mapping)} 条 data_type 映射")
    return mapping


def _to_statistics_entry(metrics):
    """Reshape one compute_single_metrics() result into the statistics-file layout.

    Per-action metrics go under ``"metrics"``. The three episode fields are
    lifted to the top level and default to 0 when absent.
    """
    episode_keys = ("episode_success_rate", "episode_count", "success_episode_count")
    entry = {"metrics": {k: v for k, v in metrics.items() if k not in episode_keys}}
    for key in episode_keys:
        entry[key] = metrics.get(key, 0)
    return entry


def generate_statistics_file(result_json_path, output_dir=None, data_dir=None):
    """
    Generate a statistics file from a result JSON file.

    Args:
        result_json_path: path to the result JSON file. It must hold either
            ``{"detailed_results": [...]}`` or a bare list of result dicts.
        output_dir: output directory. When None, the result file's directory
            is used (the current directory for a bare filename).
        data_dir: raw data directory or file used to look up ``data_type``.
            When given, every result dict gets a ``"data_type"`` field set in
            place (-1 when its ``(str(episode_id), step_id)`` is not found).

    Returns:
        str: path of the written ``<name>_statistics_by_datatype.json`` file.

    Raises:
        ValueError: if the result file has an unrecognized top-level shape.
    """
    print(f"正在读取结果文件: {result_json_path}")

    with open(result_json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Check the type first so that `in` is never a list/substring membership test.
    if isinstance(data, dict) and "detailed_results" in data:
        all_results = data["detailed_results"]
    elif isinstance(data, list):
        all_results = data
    else:
        raise ValueError(f"无法识别结果文件格式: {result_json_path}")

    print(f"总共有 {len(all_results)} 条结果")

    if data_dir:
        data_type_mapping = load_data_type_mapping(data_dir)

        matched_count = 0
        for result in all_results:
            key = (str(result.get("episode_id", "")), result.get("step_id", 0))
            if key in data_type_mapping:
                result["data_type"] = data_type_mapping[key]
                matched_count += 1
            else:
                result["data_type"] = -1  # unmatched results are grouped under -1

        print(f"  匹配到 {matched_count}/{len(all_results)} 条结果的 data_type")
    else:
        print("  未提供原始数据目录，跳过 data_type 匹配")

    print("正在计算评估指标...")
    metrics = compute_metrics(all_results)

    # Per-data_type entries first, then "overall" last.
    statistics_output = {
        data_type_key: _to_statistics_entry(type_metrics)
        for data_type_key, type_metrics in metrics["by_datatype"].items()
    }
    statistics_output["overall"] = _to_statistics_entry(metrics["overall"])

    if output_dir is None:
        output_dir = os.path.dirname(result_json_path)
    # dirname() of a bare filename is "", and os.makedirs("") would raise.
    output_dir = output_dir or "."
    os.makedirs(output_dir, exist_ok=True)

    base_name = os.path.basename(result_json_path)
    if base_name.endswith('.json'):
        base_name = base_name[:-5]
    output_file = os.path.join(output_dir, f"{base_name}_statistics_by_datatype.json")

    print(f"正在保存统计文件: {output_file}")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(statistics_output, f, ensure_ascii=False, indent=2)

    print(f"统计文件已保存: {output_file}")

    print("\n" + "=" * 60)
    print("统计结果摘要")
    print("=" * 60)
    overall = statistics_output["overall"]
    total_metric = overall["metrics"].get("TOTAL", {})
    if total_metric.get("count", 0) > 0:
        print(f"总动作数: {total_metric['count']}")
        print(f"类型匹配率 (TMR): {total_metric['TMR']:.4f}")
        print(f"动作匹配率 (AMR): {total_metric['AMR']:.4f}")
    # The statistics entry always has the episode keys (defaulted to 0), so
    # check the raw metrics to see whether episodes were actually present.
    if "episode_success_rate" in metrics["overall"]:
        print(f"Episode成功率: {overall['episode_success_rate']:.4f}")
        print(f"Episode总数: {overall['episode_count']}")
        print(f"成功Episode数: {overall['success_episode_count']}")

    return output_file


def main():
    """CLI entry point: generate a statistics file for each given result file.

    When ``--data_dir`` is omitted, a ``data`` directory next to the parent
    directory of the first result file is used, if it exists. Missing result
    files are skipped with a warning. A failure on one file is printed with
    its traceback and does not stop the others.
    """
    import traceback

    parser = argparse.ArgumentParser(description='从结果JSON文件生成statistics_file')
    parser.add_argument(
        '--result_files', type=str, nargs='+', required=True,
        help='结果JSON文件路径（可以指定多个）',
    )
    parser.add_argument(
        '--output_dir', type=str, default=None,
        help='输出目录（如果为None，则使用结果文件所在目录）',
    )
    parser.add_argument(
        '--data_dir', type=str, default=None,
        help='原始数据目录路径（用于获取 data_type，如果为None则从结果文件所在目录的上级目录的data文件夹查找）',
    )
    args = parser.parse_args()

    if args.data_dir is None and args.result_files:
        first_dir = os.path.dirname(os.path.abspath(args.result_files[0]))
        candidate = os.path.join(os.path.dirname(first_dir), 'data')
        if os.path.exists(candidate):
            args.data_dir = candidate
            print(f"自动检测到原始数据目录: {args.data_dir}")

    for result_file in args.result_files:
        if not os.path.exists(result_file):
            print(f"警告: 文件不存在，跳过: {result_file}")
            continue
        try:
            generate_statistics_file(result_file, args.output_dir, args.data_dir)
        except Exception as exc:
            print(f"错误: 处理文件失败 {result_file}: {exc}")
            traceback.print_exc()
        # A blank separator line follows each processed file, success or failure.
        print()


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
