import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch_geometric.nn import MessagePassing, radius_graph
from torch_scatter import scatter
import re
from typing import Dict, List, Tuple, Optional, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType
from dataclasses import dataclass
import math


def disabled_train(self, mode=True):
    """No-op stand-in for a ``train``-style method: ignore ``mode``, return ``self``.

    NOTE(review): no caller is visible in this part of the file; presumably
    this is assigned over a module's ``train`` method to keep it from
    switching modes -- confirm against the call sites.
    """
    return self


@dataclass
class PhysicsConstraints:
    """Numeric limits and tolerances consulted by the refinement loop.

    Within this file only ``max_bond_force`` (per-atom force-norm cap) and
    ``force_convergence_threshold`` (early-stop test on the total force norm)
    are read; the remaining fields are declared but not referenced here.
    """

    # Upper bound applied to the norm of each atom's force vector.
    max_bond_force: float = 5.0
    # Not referenced in this file.
    max_vdw_force: float = 2.0
    # Not referenced in this file.
    min_bond_length: float = 0.8
    # The loop stops early only when the total force norm is below this value
    # (together with the coordinate-update test).
    force_convergence_threshold: float = 1e-4
    # Not referenced in this file.
    coordinate_convergence_threshold: float = 1e-3


class RadialBasisFunction(nn.Module):
    """Expand scalar distances on a fixed grid of Gaussian-shaped basis functions.

    ``num_rbf`` centres are spaced evenly over ``[start, cutoff]``.  All basis
    functions share one decay coefficient, ``(cutoff - start) / num_rbf``,
    which multiplies the squared offset inside the exponential.  Both the
    centres and the coefficient are registered as (non-trainable) buffers.
    """

    def __init__(self, num_rbf: int, cutoff: float, start: float = 0.0):
        super().__init__()
        self.num_rbf, self.cutoff, self.start = num_rbf, cutoff, start

        grid = torch.linspace(start, cutoff, num_rbf)
        decay = torch.tensor((cutoff - start) / num_rbf)
        self.register_buffer('centers', grid)
        self.register_buffer('widths', decay)

    def forward(self, distances):
        """Return ``exp(-widths * (d - centers)^2)`` with a new trailing axis of size ``num_rbf``."""
        offsets = distances[..., None] - self.centers
        return torch.exp(-self.widths * offsets.pow(2))


class PaiNNInteraction(nn.Module):
    """Message-passing block over scalar features ``s`` and vector features ``v``.

    Expected shapes (taken from how the tensors are indexed below):
    ``s`` is ``(N, H)``, ``v`` is ``(N, 3, H)``, ``edge_index`` is ``(2, E)``,
    ``edge_rbf`` is ``(E, num_rbf)`` and ``edge_vector`` is ``(E, 3)``.

    Two properties are maintained on purpose:

    * The directional part of the vector message is gated by
      ``s_j + |v_j|``.  Gating it by ``|v_j|`` alone makes every vector term
      proportional to ``v``; with vectors initialised to zero they would then
      stay exactly zero in every layer.
    * The vector contribution to the scalar update is the inner product
      ``<v_i, m_i>`` over the spatial axis, which does not change when the
      coordinate frame is rotated (a plain sum of x/y/z components does).
    """

    def __init__(self, hidden_channels: int, num_rbf: int):
        super().__init__()
        self.hidden_channels = hidden_channels

        # Distance-dependent filters: three H-wide chunks per edge.
        self.interatomic_context_net = nn.ModuleDict({
            'phi_rbf': nn.Sequential(
                nn.Linear(num_rbf, hidden_channels),
                nn.SiLU(),
                nn.Linear(hidden_channels, 3 * hidden_channels)
            )
        })

        # Per-atom mixing coefficients computed from [s, |v|].
        self.intraatomic_context_net = nn.ModuleDict({
            'phi_m': nn.Sequential(
                nn.Linear(2 * hidden_channels, hidden_channels),
                nn.SiLU(),
                nn.Linear(hidden_channels, 3 * hidden_channels)
            )
        })

    def forward(self, s, v, edge_index, edge_rbf, edge_vector):
        """Return the residually updated ``(s, v)`` pair.

        Messages flow from ``edge_index[1]`` (source ``j``) to
        ``edge_index[0]`` (receiver ``i``) and are summed per receiver.
        """
        num_nodes = s.size(0)
        row, col = edge_index

        filters = self.interatomic_context_net['phi_rbf'](edge_rbf)
        W_s, W_v_lin, W_v_dir = torch.split(filters, self.hidden_channels, dim=-1)

        s_j = s[col]  # (E, H)
        v_j = v[col]  # (E, 3, H)

        message_s = W_s * s_j

        v_j_magnitude = torch.norm(v_j, dim=1, keepdim=True)  # (E, 1, H)
        # Rotation-invariant gate; the s_j term keeps the directional message
        # alive while the vectors are still zero.
        direction_gate = s_j.unsqueeze(1) + v_j_magnitude      # (E, 1, H)
        message_v = (
            W_v_lin.unsqueeze(1) * v_j
            + W_v_dir.unsqueeze(1) * direction_gate * edge_vector.unsqueeze(-1)
        )  # (E, 3, H)

        # Sum aggregation per receiving atom (same result as a scatter-add
        # with dim_size=num_nodes).
        aggr_message_s = message_s.new_zeros(num_nodes, self.hidden_channels)
        aggr_message_s.index_add_(0, row, message_s)
        aggr_message_v = message_v.new_zeros(num_nodes, 3, self.hidden_channels)
        aggr_message_v.index_add_(0, row, message_v)

        v_magnitude = torch.norm(v, dim=1)  # (N, H)
        s_cat = torch.cat([s, v_magnitude], dim=-1)  # (N, 2H)
        mixing = self.intraatomic_context_net['phi_m'](s_cat)
        a_ss, a_sv, a_vv = torch.split(mixing, self.hidden_channels, dim=-1)

        # <v_i, m_i> over the spatial axis: a scalar per channel that is
        # unchanged by rotations of the input geometry.
        vector_coupling = torch.sum(v * aggr_message_v, dim=1)  # (N, H)

        delta_s = a_ss * aggr_message_s + a_sv * vector_coupling
        delta_v = a_vv.unsqueeze(1) * aggr_message_v + a_sv.unsqueeze(1) * v

        return s + delta_s, v + delta_v


class PaiNNUpdate(nn.Module):
    """Per-atom residual update that couples scalar and vector features.

    ``s`` is ``(N, H)`` and ``v`` is ``(N, 3, H)``.  All three linear maps are
    bias-free and act on the channel axis only.
    """

    def __init__(self, hidden_channels: int):
        super().__init__()
        self.hidden_channels = hidden_channels

        def channel_mixer():
            return nn.Linear(hidden_channels, hidden_channels, bias=False)

        self.U = channel_mixer()
        self.V = channel_mixer()
        # Dedicated linear map for the vector channels, applied explicitly so
        # that no implicit broadcasting is involved.
        self.Wv = channel_mixer()

    def forward(self, s, v):
        """Return ``(s + ds, v + dv)``."""
        gate_logits = self.V(s)                       # (N, H)
        squared_norm = v.pow(2).sum(dim=1)            # (N, H)

        # Scalars: mixed squared vector norms, modulated by tanh(V s).
        scalar_residual = self.U(squared_norm) * torch.tanh(gate_logits)

        # Vectors: mix channels for every (atom, axis) row, then gate with
        # sigmoid(V s) shared across the three spatial axes.
        rows = v.reshape(-1, self.hidden_channels)    # (N*3, H)
        mixed = self.Wv(rows).reshape(v.shape)        # (N, 3, H)
        vector_residual = mixed * torch.sigmoid(gate_logits).unsqueeze(1)

        return s + scalar_residual, v + vector_residual


class PaiNNGeometricEncoder(nn.Module):
    """Stack of interaction/update layers that summarises a molecule into two context vectors.

    Atoms are embedded from ``z`` (indices into an ``nn.Embedding(max_z, H)``),
    vector features start at zero, and ``num_layers`` pairs of
    ``PaiNNInteraction`` / ``PaiNNUpdate`` refine them.  Mean-pooled scalar
    features and a mean-pooled 3-vector readout are concatenated and passed
    through two independent heads ("chemical" and "geometric"), each producing
    128 values.

    ``readout['scalar_out']`` is kept as a submodule (so existing checkpoints
    still load) but its output is not part of anything ``forward`` returns.
    """

    def __init__(
        self,
        hidden_channels: int = 256,
        num_layers: int = 6,
        num_rbf: int = 20,
        cutoff: float = 10.0,
        max_z: int = 100,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.cutoff = cutoff

        self.atom_embedding = nn.Embedding(max_z, hidden_channels)
        self.rbf = RadialBasisFunction(num_rbf, cutoff)

        self.interactions = nn.ModuleList([
            PaiNNInteraction(hidden_channels, num_rbf) for _ in range(num_layers)
        ])
        self.updates = nn.ModuleList([
            PaiNNUpdate(hidden_channels) for _ in range(num_layers)
        ])

        self.readout = nn.ModuleDict({
            'scalar_out': nn.Sequential(
                nn.Linear(hidden_channels, hidden_channels // 2),
                nn.SiLU(),
                nn.Linear(hidden_channels // 2, 1)
            ),
            'vector_out': nn.Linear(hidden_channels, 1, bias=False)
        })

        # Two independent (decoupled) summary heads: chemical and geometric.
        ctx_in_dim = hidden_channels + 3  # pooled s + pooled v(3)
        self.chem_head = nn.Sequential(
            nn.Linear(ctx_in_dim, hidden_channels),
            nn.SiLU(),
            nn.Linear(hidden_channels, 128)
        )
        self.geom_head = nn.Sequential(
            nn.Linear(ctx_in_dim, hidden_channels),
            nn.SiLU(),
            nn.Linear(hidden_channels, 128)
        )

    def forward(self, z, pos, edge_index=None, batch=None):
        """Encode one molecule or a batch of molecules.

        Args:
            z: ``(N,)`` integer indices for the atom embedding.
            pos: ``(N, 3)`` coordinates.
            edge_index: optional ``(2, E)`` edges; when omitted a radius graph
                with radius ``cutoff`` is built from ``pos`` (and ``batch``).
            batch: optional ``(N,)`` graph assignment of every atom.

        Returns:
            ``(chemical_context, geometric_context, s, v)``.  The contexts are
            ``(B, 128)`` when ``batch`` is given (including ``B == 1``) and
            ``(128,)`` when it is not.  ``s`` is ``(N, H)`` and ``v`` is
            ``(N, 3, H)``.
        """
        if edge_index is None:
            edge_index = radius_graph(pos, r=self.cutoff, batch=batch)

        row, col = edge_index
        edge_vector = pos[col] - pos[row]
        edge_length = torch.norm(edge_vector, dim=1, keepdim=True)
        edge_vector = edge_vector / (edge_length + 1e-8)
        # squeeze(-1), not squeeze(): a graph with exactly one edge must keep
        # its edge axis, otherwise the filters lose a dimension.
        edge_rbf = self.rbf(edge_length.squeeze(-1))

        s = self.atom_embedding(z)
        v = torch.zeros(z.size(0), 3, self.hidden_channels, device=z.device, dtype=s.dtype)

        for interaction, update in zip(self.interactions, self.updates):
            s, v = interaction(s, v, edge_index, edge_rbf, edge_vector)
            s, v = update(s, v)

        vector_features = self.readout['vector_out'](v).squeeze(-1)          # (N, 3)

        if batch is not None:
            pooled_s = scatter(s, batch, dim=0, reduce='mean')                # (B, H)
            pooled_v3 = scatter(vector_features, batch, dim=0, reduce='mean')  # (B, 3)
        else:
            pooled_s = s.mean(dim=0, keepdim=True)                            # (1, H)
            pooled_v3 = vector_features.mean(dim=0, keepdim=True)             # (1, 3)

        ctx_input = torch.cat([pooled_s, pooled_v3], dim=-1)                  # (B, H+3) or (1, H+3)
        chemical_context = self.chem_head(ctx_input)
        geometric_context = self.geom_head(ctx_input)

        if batch is None:
            # Only the un-batched call drops the leading axis; a batch that
            # happens to contain one graph keeps its (1, 128) shape.
            chemical_context = chemical_context.squeeze(0)
            geometric_context = geometric_context.squeeze(0)

        # Node-level representations are returned as well in case they are needed.
        return chemical_context, geometric_context, s, v


class SymbolicForceParser:
    """Convert bracket-tagged force instructions found in free text into per-atom forces.

    ``force_patterns`` maps a tag to the regular expression of one instruction.
    Only ``BOND_STRETCH`` and ``VDW_REPULSION`` matches are turned into forces;
    ``ANGLE_BEND`` and ``RING_PUCKER`` patterns are declared but contribute
    nothing.

    The text is expected to be machine-generated and possibly malformed, so an
    instruction that cannot be interpreted (bad number, unknown atom label,
    index outside the force table, vector without exactly three components)
    is skipped as a whole instead of raising or being half-applied.
    """

    def __init__(self):
        self.force_patterns = {
            'BOND_STRETCH': r'\[BOND_STRETCH\]\s+(\w+)-(\w+):\s+magnitude=([0-9.]+),\s+direction=(\w+),\s+force_vector=(\w+)<([^>]+)>\s+(\w+)<([^>]+)>',
            'ANGLE_BEND': r'\[ANGLE_BEND\]\s+(\w+)-(\w+)-(\w+):\s+magnitude=([0-9.]+),\s+direction=(\w+)',
            'VDW_REPULSION': r'\[VDW_REPULSION\]\s+(\w+)-(\w+):\s+magnitude=([0-9.]+),\s+direction=(\w+)',
            'RING_PUCKER': r'\[RING_PUCKER\]\s+(\w+)\s+ring:\s+magnitude=([0-9.]+),\s+direction=(\w+)'
        }

    @staticmethod
    def _to_float(text: str) -> Optional[float]:
        """Parse ``text`` as a float; ``None`` for strings such as ``"1.2.3"`` that the regex lets through."""
        try:
            return float(text)
        except ValueError:
            return None

    @staticmethod
    def _row_index(label: str, atom_mapping: Dict[str, int], num_rows: int) -> Optional[int]:
        """Return the row for ``label``, or ``None`` if the label is unknown or the row does not exist."""
        index = atom_mapping.get(label)
        if index is None or not -num_rows <= index < num_rows:
            return None
        return index

    def _apply_bond_stretch(self, match, forces: torch.Tensor, atom_mapping: Dict[str, int],
                            current_coords: torch.Tensor) -> None:
        """Add ``magnitude * vector`` to each of the two atoms named in ``force_vector=``."""
        groups = match.groups()
        if len(groups) < 8:
            return
        _, _, magnitude, _, atom1_id, vector1, atom2_id, vector2 = groups[:8]

        mag = self._to_float(magnitude)
        if mag is None:
            return

        idx1 = self._row_index(atom1_id, atom_mapping, forces.shape[0])
        idx2 = self._row_index(atom2_id, atom_mapping, forces.shape[0])
        if idx1 is None or idx2 is None:
            return

        try:
            comps1 = [float(x) for x in vector1.split(',')]
            comps2 = [float(x) for x in vector2.split(',')]
        except ValueError:
            return
        if len(comps1) != 3 or len(comps2) != 3:
            return

        forces[idx1] += mag * torch.tensor(comps1, device=forces.device, dtype=forces.dtype)
        forces[idx2] += mag * torch.tensor(comps2, device=forces.device, dtype=forces.dtype)

    def _apply_vdw(self, match, forces: torch.Tensor, atom_mapping: Dict[str, int],
                   current_coords: torch.Tensor) -> None:
        """Push the two atoms apart (``repulsion``) or together (any other direction) along their axis."""
        groups = match.groups()
        if len(groups) < 4:
            return
        atom1, atom2, magnitude, direction = groups[:4]

        mag = self._to_float(magnitude)
        if mag is None:
            return

        # The index has to be valid for both the force table and the coordinates.
        limit = min(forces.shape[0], current_coords.shape[0])
        idx1 = self._row_index(atom1, atom_mapping, limit)
        idx2 = self._row_index(atom2, atom_mapping, limit)
        if idx1 is None or idx2 is None:
            return

        separation = current_coords[idx2] - current_coords[idx1]
        unit_vector = separation / (torch.norm(separation) + 1e-8)

        # ``push`` points from atom 1 towards atom 2; repulsion moves atom 1
        # against it and atom 2 along it, attraction does the opposite.
        push = mag * unit_vector
        if direction != 'repulsion':
            push = -push
        forces[idx1] -= push
        forces[idx2] += push

    def parse_force_description(self, force_text: str, atom_mapping: Dict[str, int], current_coords: torch.Tensor) -> torch.Tensor:
        """Accumulate all recognised instructions in ``force_text``.

        Args:
            force_text: text containing zero or more tagged instructions.
            atom_mapping: atom label (as used in the text) -> row index.
            current_coords: coordinates used for the inter-atomic axis of
                ``VDW_REPULSION`` instructions; also fixes device and dtype.

        Returns:
            A ``(len(atom_mapping), 3)`` tensor of summed forces.
        """
        num_atoms = len(atom_mapping)
        forces = torch.zeros(num_atoms, 3, device=current_coords.device, dtype=current_coords.dtype)

        handlers = {
            'BOND_STRETCH': self._apply_bond_stretch,
            'VDW_REPULSION': self._apply_vdw,
        }

        for force_type, pattern in self.force_patterns.items():
            handler = handlers.get(force_type)
            if handler is None:
                # No numerical interpretation for this tag.
                continue
            for match in re.finditer(pattern, force_text):
                handler(match, forces, atom_mapping, current_coords)

        return forces


class SymbolicToNumericalTranslator(nn.Module):
    """Parse symbolic force text and add a small learned correction to the result.

    The parsed per-atom forces are reduced to their magnitudes, embedded,
    passed through one self-attention + feed-forward block over the atoms, and
    decoded into a magnitude (softplus) and a unit direction.  The returned
    force is ``parsed + 0.1 * magnitude * direction``.

    Every atom currently gets type id ``0``, so both the type and the
    direction embedding always look up row 0.
    """

    def __init__(self, hidden_dim: int = 128):
        super().__init__()
        self.hidden_dim = hidden_dim
        half_dim = hidden_dim // 2
        wide_dim = hidden_dim * 4

        self.force_embedding = nn.ModuleDict({
            'type_embedding': nn.Embedding(10, hidden_dim),
            'magnitude_projection': nn.Linear(1, hidden_dim),
            'direction_embedding': nn.Embedding(20, hidden_dim)
        })

        self.force_processor = nn.ModuleDict({
            'attention': nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True),
            'norm1': nn.LayerNorm(hidden_dim),
            'ffn': nn.Sequential(
                nn.Linear(hidden_dim, wide_dim),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(wide_dim, hidden_dim)
            ),
            'norm2': nn.LayerNorm(hidden_dim)
        })

        self.force_decoder = nn.ModuleDict({
            'magnitude_head': nn.Sequential(
                nn.Linear(hidden_dim, half_dim),
                nn.ReLU(),
                nn.Linear(half_dim, 1),
                nn.Softplus()
            ),
            'direction_head': nn.Sequential(
                nn.Linear(hidden_dim, half_dim),
                nn.ReLU(),
                nn.Linear(half_dim, 3),
                nn.Tanh()
            )
        })

        self.parser = SymbolicForceParser()

    def forward(self, symbolic_forces: str, atom_mapping: Dict[str, int], current_coords: torch.Tensor) -> torch.Tensor:
        """Return ``(len(atom_mapping), 3)`` forces: parsed values plus the scaled learned correction."""
        parsed = self.parser.parse_force_description(symbolic_forces, atom_mapping, current_coords)

        atom_count = len(atom_mapping)
        type_ids = torch.zeros(atom_count, dtype=torch.long, device=current_coords.device)
        magnitudes = torch.norm(parsed, dim=1, keepdim=True)          # (N, 1)

        embed = self.force_embedding
        type_part = embed['type_embedding'](type_ids)
        magnitude_part = embed['magnitude_projection'](magnitudes)
        direction_part = embed['direction_embedding'](type_ids % 20)
        tokens = type_part + magnitude_part + direction_part          # (N, D)

        # The atoms form one sequence of length N.  Query, key and value are
        # three separate views of the same embeddings.
        block = self.force_processor
        query, key, value = (tokens.unsqueeze(0) for _ in range(3))
        attended, _ = block['attention'](query, key, value)
        hidden = block['norm1'](tokens.unsqueeze(0) + attended)
        hidden = block['norm2'](hidden + block['ffn'](hidden)).squeeze(0)

        decoder = self.force_decoder
        strength = decoder['magnitude_head'](hidden)                  # (N, 1)
        heading = F.normalize(decoder['direction_head'](hidden), dim=1)  # (N, 3)

        correction = strength * heading                               # (N, 3)
        return parsed + 0.1 * correction


class LLMSymbolicForceGenerator(nn.Module):
    """Causal-LLM wrapper that writes textual ("symbolic") force corrections.

    The constructor loads a tokenizer and a causal language model (optionally
    4-bit quantised, optionally wrapped with LoRA adapters).  ``forward``
    builds a text prompt from compressed context vectors and pair distances
    and either scores a supplied target text (supervised fine-tuning loss) or
    samples a continuation.
    """

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-3.1-8B-Instruct",
        device_map: str = "auto",
        torch_dtype=torch.float16,
        use_4bit: bool = True,
        enable_lora: bool = False,
        lora_r: int = 16,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
        lora_target_modules: Optional[List[str]] = None,
    ):
        super().__init__()
        self.model_name = model_name
        self.enable_lora = enable_lora
        if lora_target_modules is None:
            lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

        if use_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch_dtype,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
        else:
            quantization_config = None

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            # Fall back to EOS so that padding-related calls have a valid token.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # NOTE(review): trust_remote_code=True executes code shipped with the
        # checkpoint, and flash_attention_2 is requested unconditionally --
        # confirm both are intended for every deployment target.
        self.llm = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=quantization_config,
            device_map=device_map,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
            attn_implementation="flash_attention_2"
        )

        # Optional LoRA (QLoRA) adapters, used for SFT on symbolic force text.
        if enable_lora:
            peft_cfg = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                target_modules=lora_target_modules
            )
            self.llm = get_peft_model(self.llm, peft_cfg)
            try:
                self.llm.print_trainable_parameters()
            except Exception:
                # Purely informational output; never let it abort construction.
                pass

        # Context encoders.  They are not used by any method of this class;
        # kept so they remain available for further fusion work.
        self.chemistry_context_encoder = nn.ModuleDict({
            'graph_encoder': nn.TransformerEncoder(
                nn.TransformerEncoderLayer(256, 8, 1024, batch_first=True),
                num_layers=4
            ),
            'geometry_encoder': nn.TransformerEncoder(
                nn.TransformerEncoderLayer(256, 8, 1024, batch_first=True),
                num_layers=4
            ),
            'fusion_layer': nn.MultiheadAttention(256, 8, batch_first=True)
        })

    def compress_ctx(self, x: torch.Tensor, k: int = 16) -> str:
        """Format the first ``k`` values of the flattened tensor as ``"v1,v2,..."`` with four decimals."""
        x = x.detach().float().flatten()
        if x.numel() > k:
            x = x[:k]
        return ",".join(f"{v:.4f}" for v in x.cpu().tolist())

    def build_prompt(self, chemical_context: torch.Tensor, geometric_context: torch.Tensor,
                    distances: torch.Tensor, atom_types: List[str], history: Optional[List] = None) -> str:
        """Assemble the text prompt.

        Args:
            chemical_context / geometric_context: tensors whose first 16
                flattened values are embedded in the prompt.
            distances: square pair-distance matrix; only ``i < j`` entries are
                inspected.
            atom_types: one label per atom; atoms are named ``<label><1-based index>``.
            history: optional earlier correction summaries; only the last
                three are included.

        Returns:
            The prompt as newline-joined lines, ending with ``"Force corrections:"``.
        """
        chem_str = self.compress_ctx(chemical_context, 16)
        geom_str = self.compress_ctx(geometric_context, 16)

        prompt_parts = [
            "You are a molecular physics expert. Analyze the current molecular conformation and generate specific force field instructions.",
            f"Molecule contains {len(atom_types)} atoms: {', '.join(atom_types)}",
            f"Chemical context (compressed): [{chem_str}]",
            f"Geometric context (compressed): [{geom_str}]",
            "Current geometric violations detected:",
        ]

        # One host transfer for the whole matrix instead of one .item() call
        # (and, on GPU, one synchronisation) per atom pair.
        num_atoms = distances.shape[0]
        rows = distances.detach().tolist()
        for i in range(num_atoms):
            for j in range(i + 1, num_atoms):
                dist = rows[i][j]
                if dist < 2.0:
                    prompt_parts.append(f"- Atoms {atom_types[i]}{i+1}-{atom_types[j]}{j+1}: distance {dist:.3f}Å (too close)")
                elif dist > 5.0:
                    prompt_parts.append(f"- Atoms {atom_types[i]}{i+1}-{atom_types[j]}{j+1}: distance {dist:.3f}Å (check bonding)")

        if history:
            prompt_parts.append("Previous corrections:")
            prompt_parts.extend(f"- {h}" for h in history[-3:])

        prompt_parts.extend([
            "Generate structured force corrections in the following format:",
            "[BOND_STRETCH] C1-C2: magnitude=0.08, direction=compress, force_vector=C1<-0.15,0.0,0.0> C2<+0.15,0.0,0.0>",
            "[VDW_REPULSION] C2-C5: magnitude=0.12, direction=repulsion",
            "[ANGLE_BEND] C1-C2-C3: magnitude=0.05, direction=decrease",
            "Force corrections:"
        ])

        return "\n".join(prompt_parts)

    def forward(self,
                chemical_context: torch.Tensor,
                geometric_context: torch.Tensor,
                distances: torch.Tensor,
                atom_types: List[str],
                history: Optional[List] = None,
                target_text: Optional[str] = None,
                return_loss: bool = False) -> Dict[str, Optional[torch.Tensor]]:
        """Score a target text or sample a continuation for the built prompt.

        * ``target_text`` given and ``return_loss=True``: run supervised
          fine-tuning (cross-entropy on the target tokens only) and return
          ``{'loss': <tensor>, 'text': target_text}``.
        * Otherwise: sample from the model without gradients and return
          ``{'loss': None, 'text': <generated text>}``.
        """
        prompt = self.build_prompt(chemical_context, geometric_context, distances, atom_types, history)

        if return_loss and target_text is not None:
            # Tokenise the prompt the same way the sampling branch does
            # (tokenizer-default special tokens, truncated to 2048 tokens) so
            # the model is trained on the prompt form it sees at inference.
            prompt_ids = self.tokenizer(prompt, truncation=True, max_length=2048).input_ids
            target_ids = self.tokenizer(target_text, add_special_tokens=False).input_ids + [self.tokenizer.eos_token_id]
            input_ids = torch.tensor([prompt_ids + target_ids], device=self.llm.device)
            # Only the target segment is supervised; prompt labels are -100.
            labels = torch.tensor([[-100] * len(prompt_ids) + target_ids], device=self.llm.device)
            outputs = self.llm(input_ids=input_ids, labels=labels, return_dict=True)
            return {'loss': outputs.loss, 'text': target_text}

        with torch.no_grad(), autocast():
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048
            ).to(self.llm.device)
            outputs = self.llm.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.1,
                pad_token_id=self.tokenizer.eos_token_id
            )
            # Drop the echoed prompt tokens before decoding.
            prompt_length = inputs['input_ids'].shape[1]
            gen_text = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return {'loss': None, 'text': gen_text.strip()}

    def save_adapter(self, path: str):
        """Best-effort ``save_pretrained`` of the wrapped model (LoRA adapter when peft-wrapped); failures are printed."""
        try:
            self.llm.save_pretrained(path)
        except Exception as e:
            print(f"Warning: failed to save adapter/model to {path}: {e}")

    def load_adapter(self, path: str):
        """Best-effort load of trainable LoRA adapter weights from ``path``; failures are printed."""
        try:
            from peft import PeftModel
            self.llm = PeftModel.from_pretrained(self.llm, path, is_trainable=True)
        except Exception as e:
            print(f"Warning: failed to load adapter from {path}: {e}")


class SymForceEncoder(nn.Module):
    """Iterative coordinate refinement driven by LLM-written force instructions.

    Each iteration encodes the current geometry, asks the symbolic force
    generator for a textual correction, translates the text into per-atom
    forces, constrains them, and moves the coordinates by an adaptive step.
    """

    def __init__(
        self,
        geometric_encoder_config: dict,
        llm_config: dict,
        translator_config: dict,
        physics_constraints: Optional["PhysicsConstraints"] = None,
        max_iterations: int = 20,
        convergence_threshold: float = 1e-3,
    ):
        super().__init__()

        self.max_iterations = max_iterations
        self.convergence_threshold = convergence_threshold
        self.physics_constraints = physics_constraints or PhysicsConstraints()

        self.geometric_encoder = PaiNNGeometricEncoder(**geometric_encoder_config)
        self.symbolic_force_generator = LLMSymbolicForceGenerator(**llm_config)
        self.translator = SymbolicToNumericalTranslator(**translator_config)

        # Neither attribute is read by any method of this class.
        self.step_size_controller = nn.Parameter(torch.tensor(0.01))
        self.momentum_buffer = {}

    def compute_distances(self, coordinates: torch.Tensor) -> torch.Tensor:
        """Return the ``(N, N)`` matrix of Euclidean distances between rows of ``coordinates``."""
        diff = coordinates.unsqueeze(1) - coordinates.unsqueeze(0)
        return torch.norm(diff, dim=2)

    def compute_molecular_energy(self, coordinates: torch.Tensor, bonds: List[Tuple[int, int]],
                               bond_params: Dict = None) -> torch.Tensor:
        """Simplified energy: harmonic bonds plus a short-range repulsion.

        * Each bond ``(i, j)`` adds ``k * (|r_j - r_i| - r0)^2`` with ``k`` and
          ``r0`` taken from ``bond_params[(i, j)]`` when present (defaults
          ``1.0`` and ``1.5``).
        * Each pair ``(i, j)`` with ``j >= i + 2`` that is not bonded (in
          either orientation) and is closer than ``3.0`` adds
          ``1 / (d + 0.1)^12``.  Pairs of index-adjacent atoms are never
          included in this term.

        Returns a 0-dim tensor on the device/dtype of ``coordinates``.
        """
        energy = torch.tensor(0.0, device=coordinates.device, dtype=coordinates.dtype)
        bond_list = list(bonds or [])

        for i, j in bond_list:
            bond_length = torch.norm(coordinates[j] - coordinates[i])
            params = bond_params.get((i, j), {}) if bond_params else {}
            k_bond = params.get('k', 1.0)
            r0 = params.get('r0', 1.5)
            energy = energy + k_bond * (bond_length - r0) ** 2

        # Simplified non-bonded repulsion term, evaluated for all candidate
        # pairs at once (no per-pair Python loop or host synchronisation).
        num_atoms = coordinates.shape[0]
        if num_atoms < 3:
            return energy

        index = torch.arange(num_atoms, device=coordinates.device)
        eligible = (index.unsqueeze(0) - index.unsqueeze(1)) >= 2  # [i, j]: j >= i + 2
        if bond_list:
            first = torch.tensor([int(i) for i, _ in bond_list], dtype=torch.long, device=coordinates.device)
            second = torch.tensor([int(j) for _, j in bond_list], dtype=torch.long, device=coordinates.device)
            bonded = torch.zeros(num_atoms, num_atoms, dtype=torch.bool, device=coordinates.device)
            bonded[first, second] = True
            bonded[second, first] = True
            eligible = eligible & ~bonded

        pair_i, pair_j = torch.nonzero(eligible, as_tuple=True)
        pair_dist = torch.norm(coordinates[pair_j] - coordinates[pair_i], dim=1)
        close = pair_dist < 3.0
        energy = energy + torch.sum(1.0 / (pair_dist[close] + 0.1) ** 12)

        return energy

    def apply_physics_constraints(self, forces: torch.Tensor) -> torch.Tensor:
        """Cap each atom's force norm at ``max_bond_force`` and then remove the mean force."""
        force_magnitudes = torch.norm(forces, dim=1)

        max_allowed = self.physics_constraints.max_bond_force
        # Scale factor <= 1: rows above the cap shrink onto it, others are unchanged.
        scale_factors = torch.clamp(max_allowed / (force_magnitudes + 1e-8), max=1.0)
        constrained_forces = forces * scale_factors.unsqueeze(1)

        # Subtract the mean so that the forces sum to zero over all atoms.
        total_force = torch.sum(constrained_forces, dim=0)
        constrained_forces = constrained_forces - total_force / constrained_forces.shape[0]

        return constrained_forces

    def adaptive_step_size(self, forces: torch.Tensor, iteration: int, prev_energy: float, current_energy: float) -> float:
        """Return a step in ``[1e-5, 0.1]``.

        The step is ``0.01 / max_atom_force``, halved when the energy rose
        since the previous iteration, and decayed by ``0.95 ** iteration``.
        """
        max_force = torch.max(torch.norm(forces, dim=1))
        base_step = 0.01 / (max_force + 1e-6)

        energy_increased = prev_energy is not None and current_energy > prev_energy
        decay_factor = 0.5 if energy_increased else 1.0

        adaptive_step = base_step * decay_factor * (0.95 ** iteration)
        # Clamp the existing tensor directly; re-wrapping it in torch.tensor()
        # would only copy it and emit a UserWarning.
        return float(torch.clamp(adaptive_step, min=1e-5, max=0.1))

    def forward(self, molecular_graph: dict, initial_coordinates: torch.Tensor,
                atom_types: List[str], bonds: List[Tuple[int, int]] = None,
                symbolic_target: Optional[str] = None, lambda_symbolic: float = 0.0) -> dict:
        """Run up to ``max_iterations`` refinement steps.

        Args:
            molecular_graph: dict; ``'atomic_numbers'`` is used as encoder
                input when present.
            initial_coordinates: ``(N, 3)`` starting coordinates (not modified).
            atom_types: one label per atom; labels plus 1-based index form the
                atom names used in the symbolic text.
            bonds: optional bond list for the energy estimate.
            symbolic_target: optional target text; together with
                ``lambda_symbolic > 0`` it switches the generator to its
                loss-returning mode.
            lambda_symbolic: weight of the symbolic loss in the logged
                energy trajectory.

        Returns:
            Dict with ``final_coordinates``, ``optimization_history``,
            ``energy_trajectory``, ``force_trajectory``, ``symbolic_text``,
            ``symbolic_loss`` (float), ``converged`` (0-dim bool tensor: last
            coordinate update below ``convergence_threshold``) and
            ``iterations`` (number of executed iterations).
        """
        coordinates = initial_coordinates.clone()
        history: List[str] = []
        energies: List[float] = []
        force_history: List[torch.Tensor] = []

        atom_mapping = {f"{label}{idx + 1}": idx for idx, label in enumerate(atom_types)}
        bond_list = bonds or []
        want_loss = symbolic_target is not None and lambda_symbolic > 0

        # The encoder input does not change between iterations, so build it once.
        # NOTE(review): the fallback uses the atom *index* as embedding index --
        # confirm that callers always provide 'atomic_numbers'.
        z = molecular_graph.get('atomic_numbers')
        if z is None:
            z = torch.arange(len(atom_types), device=coordinates.device, dtype=torch.long)
        if isinstance(z, torch.Tensor):
            z = z.to(coordinates.device).long()

        prev_energy: Optional[float] = None
        symbolic_forces_text: str = ""

        # Defined up front so the result dict is well-formed even when the
        # loop body never runs (max_iterations == 0).
        symbolic_loss = torch.tensor(0.0, device=coordinates.device)
        update_magnitude = torch.tensor(float('inf'), device=coordinates.device)
        iteration = -1

        for iteration in range(self.max_iterations):
            chemical_context, geometric_context, s_nodes, v_nodes = self.geometric_encoder(
                z, coordinates
            )

            distances = self.compute_distances(coordinates)

            sym_out = self.symbolic_force_generator(
                chemical_context=chemical_context,
                geometric_context=geometric_context,
                distances=distances,
                atom_types=atom_types,
                history=history[-3:] if history else None,
                target_text=symbolic_target,
                return_loss=want_loss
            )
            symbolic_forces_text = sym_out['text']
            symbolic_loss = sym_out['loss'] if sym_out['loss'] is not None else torch.tensor(0.0, device=coordinates.device)
            loss_value = symbolic_loss.item() if isinstance(symbolic_loss, torch.Tensor) else 0.0

            numerical_forces = self.translator(symbolic_forces_text, atom_mapping, coordinates)
            constrained_forces = self.apply_physics_constraints(numerical_forces)

            energy_value = self.compute_molecular_energy(coordinates, bond_list).item()
            step_size = self.adaptive_step_size(constrained_forces, iteration, prev_energy, energy_value)

            coordinate_update = step_size * constrained_forces
            update_magnitude = torch.norm(coordinate_update)
            force_magnitude = torch.norm(constrained_forces)

            coordinates = coordinates + coordinate_update

            hist_line = f"Iteration {iteration}: force_mag={force_magnitude:.4f}, coord_update={update_magnitude:.4f}"
            if lambda_symbolic > 0 and isinstance(symbolic_loss, torch.Tensor):
                hist_line += f", sym_loss={loss_value:.4f}"
            history.append(hist_line)
            energies.append(energy_value + lambda_symbolic * loss_value)
            force_history.append(constrained_forces.clone())

            # Early stop requires both a small step and a small total force.
            if update_magnitude < self.convergence_threshold and force_magnitude < self.physics_constraints.force_convergence_threshold:
                break

            prev_energy = energy_value

        return {
            'final_coordinates': coordinates,
            'optimization_history': history,
            'energy_trajectory': energies,
            'force_trajectory': force_history,
            'symbolic_text': symbolic_forces_text,
            'symbolic_loss': symbolic_loss.detach().item() if isinstance(symbolic_loss, torch.Tensor) else 0.0,
            'converged': update_magnitude < self.convergence_threshold,
            'iterations': iteration + 1
        }


def compute_conformation_loss(
    predicted_coords: torch.Tensor,
    target_coords: torch.Tensor,
    predicted_forces: List[torch.Tensor],
    reference_forces: Optional[List[torch.Tensor]] = None,
    lambda_coord: float = 1.0,
    lambda_force: float = 0.1,
    lambda_physics: float = 0.05
) -> dict:
    """Weighted sum of coordinate, force and physics penalties.

    * ``coord_loss``: MSE between predicted and target coordinates.
    * ``force_loss``: mean over the *paired* entries of ``predicted_forces``
      and ``reference_forces`` of their MSE (``0`` when either list is
      missing or empty).  If the lists differ in length only the common
      prefix is compared, and the mean is taken over that prefix.
    * ``physics_loss``: squared norm of the summed last predicted force plus
      ``max(0, 1 - d_min)^2``, where ``d_min`` is the smallest distance
      between two different predicted atoms.

    Returns a dict with ``total_loss``, ``coord_loss``, ``force_loss`` and
    ``physics_loss``.
    """
    coord_loss = F.mse_loss(predicted_coords, target_coords)

    force_loss = torch.tensor(0.0, device=predicted_coords.device, dtype=coord_loss.dtype)
    if predicted_forces and reference_forces:
        num_pairs = 0
        for pred_f, ref_f in zip(predicted_forces, reference_forces):
            force_loss = force_loss + F.mse_loss(pred_f, ref_f)
            num_pairs += 1
        # Average over the pairs that were actually compared; zip() stops at
        # the shorter list.
        force_loss = force_loss / max(num_pairs, 1)

    last_forces = predicted_forces[-1] if predicted_forces else torch.zeros_like(predicted_coords)
    total_force = torch.sum(last_forces, dim=0)
    momentum_loss = torch.sum(total_force ** 2)

    num_atoms = len(predicted_coords)
    if num_atoms > 1:
        # A large value on the diagonal keeps self-distances out of the minimum.
        diagonal_mask = torch.eye(num_atoms, device=predicted_coords.device, dtype=predicted_coords.dtype) * 1000
        min_dist = torch.min(torch.cdist(predicted_coords, predicted_coords) + diagonal_mask)
        collision_loss = torch.clamp(1.0 - min_dist, min=0.0) ** 2
    else:
        # Fewer than two atoms: there is no pair that could collide.
        collision_loss = torch.zeros((), device=predicted_coords.device, dtype=coord_loss.dtype)

    physics_loss = momentum_loss + collision_loss

    total_loss = lambda_coord * coord_loss + lambda_force * force_loss + lambda_physics * physics_loss

    return {
        'total_loss': total_loss,
        'coord_loss': coord_loss,
        'force_loss': force_loss,
        'physics_loss': physics_loss
    }