import os, re, json, shutil, time, warnings
import sys
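# Make `utils` and other project-root packages importable by putting the project
# root (two levels above this file) on sys.path. This file's own directory is
# NOT added, so sibling modules (e.g. `test_large_image_model_batch`) rely on it
# already being on sys.path (as it is when this file is run as a script).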
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.abspath(os.path.join(_THIS_DIR, '..', '..'))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
import cv2
import uuid
from pathlib import Path
from typing import Dict, Any, List, Optional

import numpy as np
import torch
from ase.io import read, write
from ase.units import GPa
from pymatgen.core import Structure
from typing import Optional, List, Tuple

# ===== 请确认以下自定义模块可 import =====
from test_large_image_model_batch import inference_large_image_cv2_pil
from utils.fft_convert import load_gray_image, analyze_image
from structure_paired_reconstruction.atom_ai_analysis_fft import generate_candidate_windows
from structure_paired_reconstruction.utils_without_template import reconstruct_minimal_cell
from structure_paired_reconstruction.batch_structure_paired import match_one, load_metadata, extract_material_id
from structure_paired_reconstruction.batch_atoms_analysis import process_image, shrink_once
from structure_paired_reconstruction.justify_img_quality import imread_gray, assess_image
from mattersim.forcefield          import MatterSimCalculator
from mattersim.applications.relax  import Relaxer
# ========================================

warnings.filterwarnings("ignore", category=FutureWarning, module=r"mattersim")

# ---------- 辅助函数 ---------- #
_ELEMENT_PATTERN = re.compile(r'(?:elements?|元素)\s*[:=]\s*([A-Za-z,\s]+)', re.I)


def assess_image_path(
    path: str,
    min_periodicity: float = 0.40,
    min_snr: float = 2.0,
    min_cnr: float = 1.5,
) -> Dict[str, float | str]:
    """Load a (denoised) STEM image from ``path`` and score its quality.

    The grayscale image is read with ``imread_gray`` and handed to
    ``assess_image`` together with the three thresholds, unchanged.

    Returns:
        A new dict whose first key is ``"image"`` (``path`` as a string),
        followed by every entry produced by ``assess_image``.
    """
    gray = imread_gray(path)
    metrics = assess_image(
        gray,
        min_periodicity=min_periodicity,
        min_snr=min_snr,
        min_cnr=min_cnr,
    )
    report = {"image": str(path)}
    report.update(metrics)
    return report

def shrink_or_window(sup_cif, out_cif_path, MAX_NUM_ITER=4, MAX_ATOMS_NUM=50):
    """Shrink a reconstructed (super)cell step by step and save it as CIF.

    Reads ``sup_cif`` with pymatgen and repeatedly applies ``shrink_once``.
    The loop stops when any of these happens:

    * the structure has at most ``MAX_ATOMS_NUM`` atoms;
    * ``shrink_once`` returns ``None`` or a structure that is not smaller;
    * ``MAX_NUM_ITER`` shrink steps have been performed.

    The current structure is then written to ``out_cif_path``, even if it
    still has more than ``MAX_ATOMS_NUM`` atoms.

    Args:
        sup_cif: Input structure file (str or path-like).
        out_cif_path: Output CIF path (str or path-like).
        MAX_NUM_ITER: Maximum number of shrink steps.
        MAX_ATOMS_NUM: Target upper bound on the number of atoms.
    """
    cur = Structure.from_file(str(sup_cif))
    # Steps are numbered 1..MAX_NUM_ITER inclusive, so MAX_NUM_ITER is the
    # actual step budget.
    for step in range(1, MAX_NUM_ITER + 1):
        if len(cur) <= MAX_ATOMS_NUM:
            print(f"≤{MAX_ATOMS_NUM} atoms，停止。")
            break
        nxt = shrink_once(cur)
        # Stop if shrinking failed or did not reduce the atom count.
        if nxt is None or len(nxt) >= len(cur):
            print("已无法进一步缩减。")
            break
        print(f"Step {step}: {len(cur)} → {len(nxt)} atoms")
        cur = nxt
    cur.to(filename=str(out_cif_path))
    
    
def parse_elements_from_text(text: str) -> Optional[List[str]]:
    """Extract a user-specified element list from free text.

    Looks for a declaration such as ``"elements: Mo, S"`` (matched by
    ``_ELEMENT_PATTERN``). It returns the symbols capitalised and in the
    order given, e.g. ``["Mo", "S"]``.

    Only the first line after the separator is used. Tokens longer than two
    letters are dropped, because no named element has a longer symbol; this
    keeps trailing prose ("... please reconstruct it") out of the list.
    Two-letter words that happen to be valid symbols (e.g. "in", "as")
    cannot be told apart and are kept.

    Returns:
        The list of symbols. Returns ``None`` if ``text`` is empty or
        ``None``, or if no plausible symbol was found.
    """
    if not text:
        return None
    m = _ELEMENT_PATTERN.search(text)
    if not m:
        return None
    # The pattern's character class includes whitespace, so the capture can
    # run past a line break into unrelated text; keep only the first line.
    lines = m.group(1).splitlines()
    first_line = lines[0] if lines else ""
    # Split on anything that is not a letter (commas, spaces, other separators).
    tokens = [tok for tok in re.split(r'[^A-Za-z]+', first_line) if tok]
    elems = [tok.capitalize() for tok in tokens if len(tok) <= 2]
    return elems or None


def span_min_axis(atoms):
    """Return the coordinate axis along which ``atoms`` is thinnest.

    For each column of ``atoms.positions`` (presumably an (N, 3) Cartesian
    array), the span is max - min. The index of the smallest span is
    returned as a plain ``int``; ties resolve to the lowest index, as with
    ``np.argmin``. For a 2D sheet this is presumably the out-of-plane axis.

    Uses the ``np.ptp`` function because the ``ndarray.ptp`` method was
    removed in NumPy 2.0.
    """
    spans = np.ptp(np.asarray(atoms.positions), axis=0)
    return int(np.argmin(spans))


def refine_top1(top_matches, user_elements, metadata):
    """Pick the best label from ranked template-matching results.

    Args:
        top_matches: Sequence of ``(label_name, distance)`` pairs, best first.
            Only the order is used; distances are ignored.
        user_elements: Element symbols given by the user, or a falsy value.
        metadata: Mapping from material id (as returned by
            ``extract_material_id``) to a collection of element symbols.

    Returns:
        The highest-ranked label whose material's element set equals
        ``set(user_elements)``. If ``user_elements`` is falsy, or no
        candidate matches, the overall top-1 label name is returned.

    Raises:
        ValueError: If ``top_matches`` is empty.
    """
    if not top_matches:
        raise ValueError("refine_top1: no template matches to choose from")
    if not user_elements:
        return top_matches[0][0]

    target = set(user_elements)
    for name, _dist in top_matches:
        # Coerce to set so the comparison works whether metadata stores sets,
        # frozensets, lists or tuples (a list never compares equal to a set).
        elems = set(metadata.get(extract_material_id(name), ()))
        if elems == target:
            return name
    return top_matches[0][0]


def denoise_patch_inference_tool(image_path, weight_path, work_root, device="cuda"):
    """Step 1: reconstruct (denoise) a large image by patch-wise inference.

    The image is processed by ``inference_large_image_cv2_pil`` using
    128-px crops, stride 64 and batch size 32. The result is written to
    ``<work_root>/01_recon/<image stem>_recon.png``.

    Returns:
        ``{"success": True, "recon_png": <path>}`` on success, otherwise
        ``{"success": False, "error": <message>}``. Errors are reported in
        the returned dict rather than raised.
    """
    try:
        img_p = Path(image_path).expanduser().resolve(strict=True)
        d_recon = Path(work_root).expanduser().resolve() / "01_recon"
        d_recon.mkdir(parents=True, exist_ok=True)
        recon_arr = inference_large_image_cv2_pil(
            str(img_p), weight_path,
            crop_size=128, stride=64, batch_size=32, device=device
        )
        recon_png = d_recon / f"{img_p.stem}_recon.png"
        # cv2.imwrite can signal failure (e.g. an unwritable destination) by
        # returning False instead of raising; don't report that as success.
        if not cv2.imwrite(str(recon_png), recon_arr):
            return {"success": False, "error": f"cv2.imwrite failed to write {recon_png}"}
        return {"success": True, "recon_png": str(recon_png)}
    except Exception as e:
        return {"success": False, "error": str(e)}


def template_match_tool(recon_png, label_dir, metadata_csv, user_message, work_root):
    """Step 2: choose the best-matching label image and determine the elements.

    The reconstructed image is matched against the label library in
    ``label_dir`` (top 3 candidates). If the user message declares elements
    (see ``parse_elements_from_text``), ``refine_top1`` prefers the first
    candidate whose material has exactly those elements; otherwise the top-1
    candidate is used. The chosen label is copied into
    ``<work_root>/02_label``.

    If the user gave no elements, they are looked up in the metadata CSV
    through the label's material id. The result is sorted, and empty if the
    id is unknown.

    Returns:
        ``{"success": True, "label_path": ..., "elements": [...]}``, or
        ``{"success": False, "error": <message>}`` on any failure.
    """
    try:
        out_dir = Path(work_root).expanduser().resolve() / "02_label"
        out_dir.mkdir(parents=True, exist_ok=True)

        meta = load_metadata(Path(metadata_csv))
        elements = parse_elements_from_text(user_message)

        label_root = Path(label_dir)
        candidates = match_one(recon_png, label_root, topk=3, min_area=5, max_dist=None, bin_width=5.0)
        chosen = refine_top1(candidates, elements, meta)

        copied = out_dir / chosen
        shutil.copy2(label_root / chosen, copied)

        # No user-provided elements: fall back to the metadata of the chosen label.
        if not elements:
            elements = sorted(meta.get(extract_material_id(chosen), []))
    except Exception as exc:
        return {"success": False, "error": str(exc)}
    return {"success": True, "label_path": str(copied), "elements": elements}


def stem2cif_tool(label_path, elements, work_root, max_atoms=50, max_shrink_iter=4):
    """Step 3: convert a labelled STEM image into CIF structures.

    1. Atoms are extracted with ``process_image(label_path, elements)``.
    2. The resulting cell is written to
       ``<work_root>/03_recon_cif/<mid>_reconstructed.cif``. ``mid`` is the
       ``2dm-<digits>`` id found (case-insensitively) in the label file name,
       lower-cased; if none is found it is a random ``tmp-xxxxxx`` placeholder.
    3. That cell is shrunk with ``shrink_or_window``
       (``MAX_NUM_ITER=max_shrink_iter``, ``MAX_ATOMS_NUM=max_atoms``) and
       saved as ``<work_root>/03_recon_cif/output_final.cif``.

    Returns:
        ``{"success": True, "cif_path": <final CIF>,
        "supercell_cif_path": <reconstructed CIF>}`` on success, or
        ``{"success": False, "error": <message>}``. Only file paths are
        returned, never the atoms object itself.
    """
    try:
        d_cif = Path(work_root).expanduser().resolve() / "03_recon_cif"
        d_cif.mkdir(parents=True, exist_ok=True)
        atoms = process_image(label_path, elements)
        if atoms is None or len(atoms) == 0:
            return {"success": False, "error": "process_image failed to detect atoms!"}
        mid_match = re.search(r"(2dm-\d+)", Path(label_path).stem, re.I)
        mid = mid_match.group(1).lower() if mid_match else f"tmp-{uuid.uuid4().hex[:6]}"
        cif_super = d_cif / f"{mid}_reconstructed.cif"
        write(cif_super, atoms, format="cif", wrap=False)
        cif_final = d_cif / "output_final.cif"
        shrink_or_window(cif_super, cif_final, MAX_NUM_ITER=max_shrink_iter, MAX_ATOMS_NUM=max_atoms)
        return {
            "success": True,
            "cif_path": str(cif_final),
            "supercell_cif_path": str(cif_super),
        }
    except Exception as e:
        return {"success": False, "error": str(e)}


def property_prediction_tool(cif_path, work_root, noise_amp=0.05, relax_steps=500, device="cuda",
                             seed=None, model_path="MatterSim-v1.0.0-5M.pth"):
    """Step 4: relax a structure with MatterSim and report basic properties.

    Before relaxation, every atom is displaced along the axis of smallest
    spatial extent (``span_min_axis``). The displacement is ``noise_amp``
    times a standard-normal draw, presumably to break the perfect flatness
    of the reconstructed sheet. The structure is then relaxed with BFGS (no
    cell filter, no symmetry constraint) and written to
    ``<work_root>/04_relax/relaxed.cif``.

    Args:
        cif_path: Input structure file, read with ``ase.io.read``.
        work_root: Output root directory.
        noise_amp: Amplitude of the displacement noise.
        relax_steps: Maximum optimizer steps passed to ``Relaxer.relax``.
        device: Device string forwarded to ``MatterSimCalculator``.
        seed: Optional seed for the displacement noise. ``None`` (the
            default) draws from NumPy's global RNG as before; an int makes
            the perturbation reproducible.
        model_path: Checkpoint passed as ``load_path`` to
            ``MatterSimCalculator``.

    Returns:
        A dict with ``success``, ``relaxed_cif``, ``energy_eV``,
        ``energy_per_atom``, ``force_first_atom`` (force on atom 0),
        ``stress_xx_GPa`` (xx component of the 3x3 stress, divided by
        ``ase.units.GPa``) and ``converged``. On failure it returns
        ``{"success": False, "error": <message>}``.
    """
    try:
        d_relax = Path(work_root).expanduser().resolve() / "04_relax"
        d_relax.mkdir(parents=True, exist_ok=True)
        atoms_relax = read(cif_path)
        axis_min = span_min_axis(atoms_relax)
        n_atoms = len(atoms_relax)
        if seed is None:
            noise = np.random.randn(n_atoms)
        else:
            noise = np.random.default_rng(seed).standard_normal(n_atoms)
        atoms_relax.positions[:, axis_min] += noise_amp * noise
        atoms_relax.calc = MatterSimCalculator(load_path=model_path, device=device)
        relaxer = Relaxer(optimizer="BFGS", filter=None, constrain_symmetry=False)
        converged, atoms_relaxed = relaxer.relax(atoms_relax, steps=relax_steps)
        cif_relaxed = d_relax / "relaxed.cif"
        write(cif_relaxed, atoms_relaxed)
        E = atoms_relaxed.get_potential_energy()
        F0 = atoms_relaxed.get_forces()[0]
        sxx = atoms_relaxed.get_stress(voigt=False)[0][0]
        return {
            "success": True,
            "relaxed_cif": str(cif_relaxed),
            "energy_eV": float(E),
            "energy_per_atom": float(E/len(atoms_relaxed)),
            "force_first_atom": F0.tolist(),
            "stress_xx_GPa": float(sxx/GPa),
            "converged": bool(converged)
        }
    except Exception as e:
        return {"success": False, "error": str(e)}

def reconstruct_from_denoised_img(
    denoised_img: str,
    user_elements: List[str],
    coord: Optional[dict] = None,
    pixel_size: Optional[float] = 0.1,
    top_n: int = 3,
    out_dir: str = "pipeline_out",
    snap_hex: int = 1,
    auto_basis: int = 1,
    vec_tol: float = 0.02,
    vec_maxlen: float = 0.45,
    vec_cross_min: float = 0.0002,
    dedup_xy: float = 0.02,
    dedup_z: float = 0.05
) -> Tuple[str, dict, list]:
    """End-to-end minimal-cell reconstruction from a denoised STEM image.

    Pipeline:
      1. FFT analysis (``analyze_image``) yields the in-plane lengths ``a``
         and ``b`` (``None`` when unavailable) and the angle ``alpha``.
         FFT outputs are saved under ``<out_dir>/fft``. The angle is
         forwarded downstream as ``gamma``.
      2. ``generate_candidate_windows`` produces up to ``top_n`` candidate
         window CIFs with prefix ``<out_dir>/candidate_window``. A missing
         (None or zero) ``a``/``b`` falls back to 3.0 for this step only.
      3. The first three candidates are passed to
         ``reconstruct_minimal_cell``, repeating the first if fewer exist.
         They go in with the raw FFT lattice values and
         ``out=<out_dir>/minimal_cell_from_top3_pipeline.cif``.

    Returns:
        ``(final_cif_path, new_cell_dict, basis_atoms)`` as produced by
        ``reconstruct_minimal_cell``.

    Raises:
        RuntimeError: If no candidate window CIF was generated.
    """
    def _as_float_or_none(value):
        return None if value is None else float(value)

    os.makedirs(out_dir, exist_ok=True)

    # ---- Stage 1: FFT -> in-plane lattice parameters ----
    gray = load_gray_image(denoised_img)
    fft_dir = os.path.join(out_dir, 'fft')
    os.makedirs(fft_dir, exist_ok=True)
    fft_info = analyze_image(gray, pixel_size=pixel_size, show=False, save=True, out_dir=fft_dir)
    a_len = _as_float_or_none(fft_info.a_physical)
    b_len = _as_float_or_none(fft_info.b_physical)
    # For a 2D basal plane the FFT angle alpha plays the role of gamma.
    gamma_deg = float(fft_info.alpha_deg)

    # ---- Stage 2: top-N candidate window CIFs ----
    windows = generate_candidate_windows(
        image_path=denoised_img,
        elements=user_elements,
        coord=coord,
        # The 3.0 fallback (presumably Angstrom) applies only here; stage 3
        # still receives the raw, possibly-None, lengths.
        lattice_params=(a_len or 3.0, b_len or 3.0, gamma_deg),
        top_n=top_n,
        out_prefix=os.path.join(out_dir, 'candidate_window'),
    )
    if not windows:
        raise RuntimeError("未能生成候选窗口 CIF，无法继续最小单胞重建")

    # Exactly three windows are needed; pad with the best one if necessary.
    first, second, third = (windows[i] if len(windows) > i else windows[0] for i in range(3))

    # ---- Stage 3: minimal-cell reconstruction ----
    basis_spec = None
    if coord is not None:
        basis_spec = ", ".join(f"{key}:{val}" for key, val in coord.items())

    final_path, new_cell, basis_atoms = reconstruct_minimal_cell(
        win1=first,
        win2=second,
        win3=third,
        vec_tol=vec_tol,
        vec_maxlen=vec_maxlen,
        vec_cross_min=vec_cross_min,
        dedup_xy=dedup_xy,
        dedup_z=dedup_z,
        snap_hex=snap_hex,
        basis=basis_spec,
        out=os.path.join(out_dir, 'minimal_cell_from_top3_pipeline.cif'),
        auto_basis=auto_basis,
        lat_a=a_len,
        lat_b=b_len,
        lat_gamma=gamma_deg,
    )
    return final_path, new_cell, basis_atoms

