import os
import json
import argparse
import json5
import re
import time
from pathlib import Path
from typing import Any, Dict
from qwen_agent.agents import Assistant
from qwen_agent.tools.base import BaseTool, register_tool
from qwen_agent.utils.output_beautify import typewriter_print
from tools_module_wrap import (
    denoise_patch_inference_tool,
    assess_image_path,
    template_match_tool,
    stem2cif_tool,
    reconstruct_from_denoised_img,
    property_prediction_tool
)
import warnings
from scipy.optimize import OptimizeWarning

# Silence scipy's OptimizeWarning for everything that runs after this import.
warnings.filterwarnings("ignore", category=OptimizeWarning)

# -----------------------------------------------------------------------
# Global defaults. main() uses these values as the argparse defaults for
# --weight_path, --label_dir and --metadata_csv, so each one can be
# overridden from the command line.
# -----------------------------------------------------------------------
# NOTE: every value is currently None even though the annotation says str.
_PIPELINE_DEFAULTS: Dict[str, str] = {
    "weight_path": None,
    "label_dir": None,
    "metadata_csv": None,
}


# ========== Tool registration ==========
@register_tool('denoise_patch_inference_tool')
class DenoisePatchTool(BaseTool):
    """Agent-facing wrapper around ``denoise_patch_inference_tool``."""

    # Tool metadata exposed to the agent (runtime strings, kept verbatim).
    description = '对STEM大图进行去噪和patch重建，返回重建图片路径'
    parameters = [
        {
            'name': 'image_path',
            'type': 'string',
            'description': 'STEM大图路径',
            'required': True,
        },
        {
            'name': 'weight_path',
            'type': 'string',
            'description': '去噪模型权重路径',
            'required': True,
        },
        {
            'name': 'work_root',
            'type': 'string',
            'description': '工作目录',
            'required': True,
        },
        {
            'name': 'device',
            'type': 'string',
            'description': '推理设备',
            'required': False,
            'default': 'cuda',
        },
    ]

    def call(self, params: str, **kwargs) -> str:
        """Run the wrapped denoising function.

        Args:
            params: JSON5-encoded object; its keys are forwarded verbatim as
                keyword arguments to ``denoise_patch_inference_tool``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The wrapped function's result serialized with ``json5.dumps``
            (non-ASCII characters are kept as-is).
        """
        tool_kwargs = json5.loads(params)
        return json5.dumps(denoise_patch_inference_tool(**tool_kwargs), ensure_ascii=False)


@register_tool('template_match_tool')
class TemplateMatchTool(BaseTool):
    """Agent-facing wrapper around ``template_match_tool``."""

    # Tool metadata exposed to the agent (runtime strings, kept verbatim).
    description = '对去噪后图片做模板匹配，返回最佳label路径和元素信息'
    parameters = [
        {
            'name': 'recon_png',
            'type': 'string',
            'description': '去噪后图片路径',
            'required': True,
        },
        {
            'name': 'label_dir',
            'type': 'string',
            'description': '模板匹配label目录',
            'required': True,
        },
        {
            'name': 'metadata_csv',
            'type': 'string',
            'description': '材料元素元数据CSV',
            'required': True,
        },
        {
            'name': 'user_message',
            'type': 'string',
            'description': '用户补充说明',
            'required': True,
        },
        {
            'name': 'work_root',
            'type': 'string',
            'description': '工作目录',
            'required': True,
        },
    ]

    def call(self, params: str, **kwargs) -> str:
        """Run the wrapped template-matching function.

        Args:
            params: JSON5-encoded object; its keys are forwarded verbatim as
                keyword arguments to ``template_match_tool``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The wrapped function's result serialized with ``json5.dumps``
            (non-ASCII characters are kept as-is).
        """
        tool_kwargs = json5.loads(params)
        return json5.dumps(template_match_tool(**tool_kwargs), ensure_ascii=False)


@register_tool('stem2cif_tool')
class Stem2CifTool(BaseTool):
    """Agent-facing wrapper around ``stem2cif_tool``."""

    # Tool metadata exposed to the agent (runtime strings, kept verbatim).
    description = '将label图片和元素类型转换为CIF结构，返回CIF路径'
    parameters = [
        {
            'name': 'label_path',
            'type': 'string',
            'description': 'label图片路径',
            'required': True,
        },
        {
            'name': 'elements',
            'type': 'array',
            'description': '元素类型列表',
            'required': True,
        },
        {
            'name': 'work_root',
            'type': 'string',
            'description': '工作目录',
            'required': True,
        },
        {
            'name': 'max_atoms',
            'type': 'integer',
            'description': '最大原子数',
            'required': False,
            'default': 50,
        },
        {
            'name': 'max_shrink_iter',
            'type': 'integer',
            'description': '最大缩减迭代次数',
            'required': False,
            'default': 4,
        },
    ]

    def call(self, params: str, **kwargs) -> str:
        """Run the wrapped label-to-CIF conversion function.

        Args:
            params: JSON5-encoded object; its keys are forwarded verbatim as
                keyword arguments to ``stem2cif_tool``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The wrapped function's result serialized with ``json5.dumps``
            (non-ASCII characters are kept as-is).
        """
        tool_kwargs = json5.loads(params)
        return json5.dumps(stem2cif_tool(**tool_kwargs), ensure_ascii=False)


@register_tool('property_prediction_tool')
class PropertyPredictionTool(BaseTool):
    """Agent-facing wrapper around ``property_prediction_tool``."""

    # Tool metadata exposed to the agent (runtime strings, kept verbatim).
    description = '对CIF结构进行物性预测，返回能量、力、应力等'
    parameters = [
        {
            'name': 'cif_path',
            'type': 'string',
            'description': 'CIF结构路径',
            'required': True,
        },
        {
            'name': 'work_root',
            'type': 'string',
            'description': '工作目录',
            'required': True,
        },
        {
            'name': 'noise_amp',
            'type': 'number',
            'description': '扰动幅度',
            'required': False,
            'default': 0.05,
        },
        {
            'name': 'relax_steps',
            'type': 'integer',
            'description': '松弛步数',
            'required': False,
            'default': 500,
        },
        {
            'name': 'device',
            'type': 'string',
            'description': '推理设备',
            'required': False,
            'default': 'cuda',
        },
    ]

    def call(self, params: str, **kwargs) -> str:
        """Run the wrapped property-prediction function.

        Args:
            params: JSON5-encoded object; its keys are forwarded verbatim as
                keyword arguments to ``property_prediction_tool``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The wrapped function's result serialized with ``json5.dumps``
            (non-ASCII characters are kept as-is).
        """
        tool_kwargs = json5.loads(params)
        return json5.dumps(property_prediction_tool(**tool_kwargs), ensure_ascii=False)


@register_tool('direct_reconstruct_tool')
class DirectReconstructTool(BaseTool):
    """Agent-facing wrapper around ``reconstruct_from_denoised_img``."""

    # Tool metadata exposed to the agent (runtime strings, kept verbatim).
    description = '绕过模板匹配，使用FFT→候选窗口→最小单胞重建的直通管线；输入去噪后图像+元素(+可选配位)'
    parameters = [
        {
            'name': 'denoised_img',
            'type': 'string',
            'description': '去噪后的STEM图像路径',
            'required': True,
        },
        {
            'name': 'elements',
            'type': 'array',
            'description': '元素列表，如 ["Zr","N","Cl"]',
            'required': True,
        },
        {
            'name': 'coord',
            'type': 'object',
            'description': '可选配位关系，如 {"Zr":1,"N":1,"Cl":1}',
            'required': False,
        },
        {
            'name': 'pixel_size',
            'type': 'number',
            'description': '像素尺寸(Å/px)，影响FFT物理单位',
            'required': False,
            'default': 0.10,
        },
        {
            'name': 'top_n',
            'type': 'integer',
            'description': '候选窗口数量(top-N)',
            'required': False,
            'default': 3,
        },
        {
            'name': 'out_dir',
            'type': 'string',
            'description': '输出目录',
            'required': False,
            'default': 'pipeline_out',
        },
    ]

    def call(self, params: str, **kwargs) -> str:
        """Run the direct reconstruction pipeline.

        Args:
            params: JSON5-encoded object. ``denoised_img`` and ``elements``
                are mandatory (a missing key raises ``KeyError``); ``coord``,
                ``pixel_size``, ``top_n`` and ``out_dir`` are optional and
                fall back to ``None``, ``0.10``, ``3`` and ``'pipeline_out'``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            JSON5 text of a dict with the keys ``final_cif``, ``cell`` and
            ``basis_atoms``, i.e. the three values returned by
            ``reconstruct_from_denoised_img``.
        """
        opts = json5.loads(params)
        # ``elements`` is passed on under the name ``user_elements``;
        # ``top_n`` is coerced to int before the call.
        final_cif, cell, basis_atoms = reconstruct_from_denoised_img(
            denoised_img=opts['denoised_img'],
            user_elements=opts['elements'],
            coord=opts.get('coord'),
            pixel_size=opts.get('pixel_size', 0.10),
            top_n=int(opts.get('top_n', 3)),
            out_dir=opts.get('out_dir', 'pipeline_out'),
        )
        payload = {'final_cif': final_cif, 'cell': cell, 'basis_atoms': basis_atoms}
        return json5.dumps(payload, ensure_ascii=False)


@register_tool('assess_image_path_tool')
class AssessImagePathTool(BaseTool):
    """Agent-facing wrapper around ``assess_image_path``."""

    # Tool metadata exposed to the agent (runtime strings, kept verbatim).
    description = '评估图像路径的有效性和图像质量，返回评估结果'
    parameters = [
        {
            'name': 'path',
            'type': 'string',
            'description': '图像路径',
            'required': True,
        },
    ]

    def call(self, params: str, **kwargs) -> str:
        """Run the wrapped image-assessment function.

        Args:
            params: JSON5-encoded object; its keys are forwarded verbatim as
                keyword arguments to ``assess_image_path``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The wrapped function's result serialized with ``json5.dumps``
            (non-ASCII characters are kept as-is).
        """
        tool_kwargs = json5.loads(params)
        return json5.dumps(assess_image_path(**tool_kwargs), ensure_ascii=False)


# ========== Statistics helpers ==========
class ImageProcessingStats:
    """Per-image counters: number of calls, successes and failure retries."""

    def __init__(self):
        # All three counters start at zero.
        self.call_count = self.success_count = self.retry_count = 0

    def increment_call(self):
        """Record one more call."""
        self.call_count = self.call_count + 1

    def increment_success(self):
        """Record one more success."""
        self.success_count = self.success_count + 1

    def increment_retry(self):
        """Record one more failure retry."""
        self.retry_count = self.retry_count + 1

    def get_stats(self):
        """Return the counters and the derived success rate as a dict.

        ``success_rate`` is ``success_count / call_count``; it is ``0`` when
        no call has been recorded, which avoids a division by zero.
        """
        if self.call_count > 0:
            rate = self.success_count / self.call_count
        else:
            rate = 0
        summary = {
            field: getattr(self, field)
            for field in ('call_count', 'success_count', 'retry_count')
        }
        summary['success_rate'] = rate
        return summary


# def extract_id_from_path(path):
#     import re
#     match = re.search(r'2dm-\d+', path)
#     if match:
#         return match.group(0)
#     return os.path.splitext(os.path.basename(path))[0]


def extract_id_from_path(path: str) -> str:
    """Derive a sample ID from an image file path.

    The directory part and the file extension are dropped, then everything
    from the first ``_iDPC`` marker to the end of the name (for example
    ``_iDPC_V3``) is removed.

    Example:
        ``orthogonal_2dm-2994_supercell_16x16x1_dose50000_sampling0.1_iDPC_V3.png``
        becomes
        ``orthogonal_2dm-2994_supercell_16x16x1_dose50000_sampling0.1``.
    """
    _, filename = os.path.split(path)
    stem, _extension = os.path.splitext(filename)
    # Strip the ``_iDPC...`` suffix; names without the marker pass through.
    return re.sub(r'_iDPC.*$', '', stem)


def _extract_number(label: str, text: str):
    """Return the first well-formed number that follows *label* in *text*.

    The label may be followed by underscores/spaces, an optional ``:`` or
    ``=`` and whitespace. Only a real numeric literal (optional sign,
    digits, optional fraction, optional exponent) is accepted, so stray
    characters such as a lone ``e``, ``-`` or a sentence-ending ``.`` are
    never returned as a value.

    Returns:
        The matched number as a string, or ``None`` when nothing matches.
    """
    pattern = (
        re.escape(label)
        + r'[_ ]*[:=]?\s*([-+]?(?:\d+(?:\.\d+)?|\.\d+)(?:[eE][-+]?\d+)?)'
    )
    match = re.search(pattern, text)
    return match.group(1) if match else None


def batch_agent_process(args):
    """Run the agent over every sample of a JSON dataset and log the results.

    For each sample a working directory ``<work_root>/<id>`` is created, the
    agent is run on an auto-generated prompt, its transcript and raw outputs
    are saved there, and one row is appended to the results CSV. A failure
    while processing a sample is recorded in ``error.txt`` inside that
    sample's directory and the batch continues with the next sample.

    Args:
        args: Namespace providing ``json_path`` (dataset file), ``results_csv``
            (CSV to append to), ``work_root`` (parent of the per-sample
            directories), ``bot`` (object whose ``run(messages=...)`` yields
            responses) and ``model`` (used in the tool-info file name).
            Each dataset entry must contain ``img_path`` and ``elements``
            (an iterable of strings); ``cif_path`` is optional.
    """
    import csv
    import traceback

    with open(args.json_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    results_csv = args.results_csv
    csv_fields = ['id', 'img_path', 'cif_path', 'elements', 'energy', 'energy_per_atom']
    # A file that exists but is still empty needs the header as well.
    write_header = not os.path.exists(results_csv) or os.path.getsize(results_csv) == 0
    with open(results_csv, 'a', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=csv_fields)
        if write_header:
            writer.writeheader()
        for sample in dataset:
            unique_id = extract_id_from_path(sample['img_path'])
            work_dir = os.path.join(args.work_root, unique_id)
            os.makedirs(work_dir, exist_ok=True)
            # Auto-generated user prompt: image path, work dir and elements.
            user_input = f"{sample['img_path']} {work_dir} 元素: {','.join(sample['elements'])}，**不得引用或检索任何外部数据库、文献直接查找答案**"
            messages = [
                {'role': 'user', 'content': user_input}
            ]
            agent_outputs = []
            tool_calls = []
            response_plain_text = ''
            start_time = time.time()
            stats = ImageProcessingStats()  # per-image processing counters
            try:
                for response in args.bot.run(messages=messages):
                    agent_outputs.append(response)
                    # Record the tool-call trace.
                    # NOTE(review): this check and the "error" test below only
                    # fire as intended when ``response`` is a dict; if run()
                    # yields lists of messages they are list-membership tests.
                    # Confirm against the agent's response type.
                    if isinstance(response, dict) and 'tool_calls' in response:
                        tool_calls.extend(response['tool_calls'])
                    # Accumulate the printable transcript.
                    response_plain_text = typewriter_print(response, response_plain_text)
                    # Update the counters.
                    stats.increment_call()
                    if "error" in response:
                        stats.increment_retry()
                    else:
                        stats.increment_success()
                # Persist the transcript and the raw agent outputs.
                with open(os.path.join(work_dir, 'agent_output.txt'), 'w', encoding='utf-8') as f:
                    f.write(response_plain_text)
                with open(os.path.join(work_dir, f"{args.model}_agent_tools_info.json"), 'w', encoding='utf-8') as f:
                    json.dump({'tool_calls': tool_calls, 'all_outputs': agent_outputs, 'user_input': user_input, 'start_time': start_time, 'end_time': time.time()}, f, indent=2, ensure_ascii=False)
                # Pull the energies out of the transcript (None when absent).
                row = {
                    'id': unique_id,
                    'img_path': sample['img_path'],
                    'cif_path': sample.get('cif_path'),
                    'elements': ','.join(sample['elements']),
                    'energy': _extract_number('energy_eV', response_plain_text),
                    'energy_per_atom': _extract_number('energy_per_atom', response_plain_text),
                }
                writer.writerow(row)
                # Flush per sample so finished rows survive a later crash.
                csvfile.flush()

                # Report the per-image counters.
                print(f"Image stats for {unique_id}: {stats.get_stats()}")
            except Exception:
                # Best-effort batch: keep going, but leave a full traceback
                # on disk and make the failure visible on the console.
                error_path = os.path.join(work_dir, 'error.txt')
                with open(error_path, 'w', encoding='utf-8') as f:
                    f.write(traceback.format_exc())
                print(f"Processing failed for {unique_id}; details in {error_path}")


# ========== Main entry point ==========

def main():
    """Parse the CLI options, build the agent and run the batch job.

    In batch mode ``--json_path`` and ``--work_root`` are mandatory; when
    either is missing the parser reports a usage error and exits instead of
    failing later with a ``TypeError`` inside the batch loop.
    """
    parser = argparse.ArgumentParser(description="Qwen3 多轮对话Agent（Qwen-Agent范式）")
    parser.add_argument('--model', type=str, default='qwen-plus-2025-04-28')
    # NOTE(review): --model_server is parsed but never forwarded to llm_cfg
    # below; confirm whether it should be passed to the LLM configuration.
    parser.add_argument('--model_server', type=str, default='http://localhost:8000/v1')
    parser.add_argument('--api_key', type=str, default=None)
    parser.add_argument('--weight_path', type=str, default=_PIPELINE_DEFAULTS["weight_path"])
    parser.add_argument('--label_dir', type=str, default=_PIPELINE_DEFAULTS["label_dir"])
    parser.add_argument('--metadata_csv', type=str, default=_PIPELINE_DEFAULTS["metadata_csv"])
    parser.add_argument('--work_root', type=str, default=None, help="一次性 pipeline 的结果保存目录")
    # NOTE(review): store_true combined with default=True means this flag is
    # always True, so batch mode is the only reachable path.
    parser.add_argument('--batch', action='store_true', default=True, help='批量处理模式')
    parser.add_argument('--json_path', type=str, default=None, help='批量json路径')
    parser.add_argument('--results_csv', type=str, default='results_qwen3.csv', help='批量结果CSV输出')
    args = parser.parse_args()

    if args.batch:
        # Fail early with a usage message when a required path is missing.
        required = (('--json_path', args.json_path), ('--work_root', args.work_root))
        missing = [flag for flag, value in required if not value]
        if missing:
            parser.error(f"batch mode requires {', '.join(missing)}")

        # Build the agent.
        llm_cfg = {
            'model': args.model,
            'api_key': args.api_key,
            'generate_cfg': {
                'top_p': 0.8,
                'enable_thinking': False,
            }
        }
        # The system prompt is runtime text; the last line injects the
        # global paths so the agent can fill them in as tool arguments.
        system_instruction = (
            "你是材料科学智能Agent，能够对STEM表征图像进行结构重建与物性分析。"
            "你可以调用如下工具,包括图像评估、去噪、模板匹配、结构重建、物性预测等工具。"
            "当用户提供去噪后图像与元素信息时，你可以直接根据情况选择不同方式进行重建；"
            "若需要先验或验证，也可选择模板匹配以辅助确定元素与结构先验；若无先验结构候选或者重建后的质量足够好时，也可直接进行最小单胞结构重建。"
            "每次用户输入后，你应根据需求自主选择是否以及如何组合调用工具，"
            "并在需要时自动选择合适的工具和参数。"
            "允许多轮对话，支持用户补充说明、追问、结果解释等。"
            f"全局参数：weight_path={args.weight_path}，label_dir={args.label_dir}，metadata_csv={args.metadata_csv}"
        )
        # Names of the tools registered via @register_tool in this module.
        tools = [
            'denoise_patch_inference_tool',
            'assess_image_path_tool',
            'template_match_tool',
            'stem2cif_tool',
            'property_prediction_tool',
            'direct_reconstruct_tool'
        ]
        bot = Assistant(
            llm=llm_cfg,
            system_message=system_instruction,
            function_list=tools
        )
        # batch_agent_process reads the agent from the namespace.
        args.bot = bot
        batch_agent_process(args)
        print('批量agent自动推理处理完成！')


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
