import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from transformers import AutoTokenizer, LlamaForCausalLM, PreTrainedModel, GenerationMixin
from peft import get_peft_model, LoraConfig, TaskType
import numpy as np
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass
import json
from collections import defaultdict
from tqdm import tqdm
import random

try:
    from rdkit import Chem
    from rdkit.Chem import AllChem

    RDKIT_AVAILABLE = True
except ImportError:
    RDKIT_AVAILABLE = False

try:
    from openbabel import pybel

    OPENBABEL_AVAILABLE = True
except ImportError:
    OPENBABEL_AVAILABLE = False

from torch_geometric.data import Data, Batch
from symforce_encoder import SymForceEncoder, PhysicsConstraints, compute_conformation_loss


def disabled_train(self, mode=True):
    """Ignore ``mode`` and return ``self`` unchanged.

    NOTE(review): presumably intended to be bound over a module's ``train``
    method so that the module cannot be switched between train/eval modes;
    nothing in this file installs it, so confirm against callers.
    """
    return self


@dataclass
class SymForceConfig:
    """Construction settings read by ``SymForce.__init__``.

    The first six fields are forwarded to ``SymForceEncoder``; the remaining
    ones select the base LLM and configure the LoRA adapter wrapped around it.

    NOTE(review): this is a plain dataclass, yet it is registered as
    ``config_class`` of a ``PreTrainedModel`` subclass and passed to
    ``PreTrainedModel.__init__`` — confirm that transformers accepts it there.
    """
    # --- forwarded as keyword arguments to SymForceEncoder ---
    geometric_encoder_config: dict
    llm_config: dict
    translator_config: dict
    physics_constraints: PhysicsConstraints
    max_iterations: int = 20
    convergence_threshold: float = 1e-3

    # --- base language model ---
    # Model id/path handed to LlamaForCausalLM.from_pretrained.
    llm_model: str = "meta-llama/Llama-3.1-8B-Instruct"
    # "bfloat16" or "float16"; any other value makes SymForce fall back to float32.
    torch_dtype: str = "float16"
    # When true, SymForce requests attn_implementation="flash_attention_2".
    enable_flash: bool = True

    # --- LoRA adapter (passed to peft.LoraConfig) ---
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    # None is replaced in __post_init__ by the default projection list below.
    lora_target_modules: Optional[List[str]] = None

    def __post_init__(self):
        # A dataclass field cannot use a list literal as its default, so the
        # default target list is filled in here when the caller left it unset.
        if self.lora_target_modules is None:
            self.lora_target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']


class SymForcePreTrainedModel(PreTrainedModel):
    """Base class that carries the transformers class-level metadata for SymForce.

    It defines no behaviour of its own; it only sets the attributes below.

    NOTE(review): ``config_class`` points at ``SymForceConfig``, which is a plain
    dataclass in this file rather than a transformers config class — confirm
    that ``PreTrainedModel`` accepts it.
    """
    config_class = SymForceConfig
    base_model_prefix = 'symforce'
    supports_gradient_checkpointing = True
    # Regex patterns; presumably keys matching these may be absent from a
    # checkpoint without a missing-key warning (transformers convention — verify
    # against the installed transformers version).
    _keys_to_ignore_on_load_missing = [
        r"position_ids",
        r"encoder\..*",
        r"llm\..*"
    ]


class ConformationGenerator3D:
    """Generate one heavy-atom 3D conformer for a SMILES string.

    RDKit is tried first; Open Babel is the fallback. Every generator returns
    ``(atom_symbols, coordinates)`` where ``coordinates`` is a float32 array
    with one row per returned atom, or ``(None, None)`` when the backend is not
    installed or anything goes wrong. Failures are deliberately swallowed so
    that a single bad molecule cannot abort a batch.
    """

    @staticmethod
    def generate_rdkit_conformation(smiles: str) -> Tuple[Optional[List[str]], Optional[np.ndarray]]:
        """Embed ``smiles`` with ETKDGv3, attempt an MMFF refinement, strip hydrogens.

        Returns:
            ``(symbols, coords)`` for the hydrogen-stripped molecule, or
            ``(None, None)`` if RDKit is unavailable, the SMILES does not parse,
            embedding yields no conformer, the atom count after removing
            hydrogens differs from the parsed molecule, or any exception is
            raised along the way.
        """
        if not RDKIT_AVAILABLE:
            return None, None

        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return None, None

            # Remember the atom count before hydrogens are added so the
            # round-trip through AddHs/RemoveHs can be sanity-checked below.
            num_atoms = mol.GetNumAtoms()
            mol = Chem.AddHs(mol)

            embed_params = AllChem.ETKDGv3()
            embed_params.randomSeed = 42
            embed_params.numThreads = 8
            embed_params.pruneRmsThresh = 1.0
            # EmbedParameters exposes the per-conformer attempt cap as
            # `maxIterations`. The previous `maxAttempts` assignment is not an
            # attribute of that class, so the intended cap was never applied.
            embed_params.maxIterations = 10000
            embed_params.useRandomCoords = False

            conf_ids = AllChem.EmbedMultipleConfs(mol, numConfs=1, params=embed_params)

            if len(conf_ids) == 0:
                return None, None

            try:
                AllChem.MMFFOptimizeMoleculeConfs(mol, numThreads=8)
            except Exception:
                # Best effort: if the force-field step fails, keep the
                # unrefined embedded geometry instead of dropping the molecule.
                pass

            mol = Chem.RemoveHs(mol)

            if num_atoms != mol.GetNumAtoms():
                return None, None

            if mol.GetNumConformers() == 0:
                return None, None

            atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]
            coordinates = np.array(mol.GetConformer().GetPositions(), dtype=np.float32)

            return atoms, coordinates

        except Exception:
            # Any RDKit failure is reported as "no conformer" so the caller can
            # fall back to another backend.
            return None, None

    @staticmethod
    def generate_openbabel_conformation(smiles: str) -> Tuple[Optional[List[str]], Optional[np.ndarray]]:
        """Build a 3D structure with Open Babel (MMFF94) and drop hydrogens.

        Element symbols come from RDKit's periodic table when RDKit is
        importable; otherwise a small built-in table is used and any element
        outside it is reported as ``'X'``.

        Returns:
            ``(symbols, coords)`` or ``(None, None)`` if Open Babel is
            unavailable or any exception is raised.
        """
        if not OPENBABEL_AVAILABLE:
            return None, None

        try:
            mol = pybel.readstring('smi', smiles)
            mol.make3D(forcefield='mmff94', steps=10000)
            mol.OBMol.DeleteHydrogens()

            atomic_nums = [atom.atomicnum for atom in mol.atoms]

            if RDKIT_AVAILABLE:
                pt = Chem.GetPeriodicTable()
                atoms = [pt.GetElementSymbol(atomic_num) for atomic_num in atomic_nums]
            else:
                element_symbols = {1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F', 15: 'P', 16: 'S', 17: 'Cl', 35: 'Br',
                                   53: 'I'}
                atoms = [element_symbols.get(atomic_num, 'X') for atomic_num in atomic_nums]

            coordinates = np.array([atom.coords for atom in mol.atoms], dtype=np.float32)

            return atoms, coordinates

        except Exception:
            return None, None

    @classmethod
    def generate_conformation(cls, smiles: str) -> Tuple[Optional[List[str]], Optional[np.ndarray]]:
        """Return the RDKit conformer if one is produced, else the Open Babel one.

        The result is ``(None, None)`` only when both backends fail.
        """
        atoms, coordinates = cls.generate_rdkit_conformation(smiles)

        if atoms is None or coordinates is None:
            atoms, coordinates = cls.generate_openbabel_conformation(smiles)

        return atoms, coordinates


class SMILES2GraphConverter:
    """Turn a SMILES string into node/edge feature tensors via RDKit.

    Each atom is described by 9 integers and each bond by 4. When RDKit is
    missing or the SMILES cannot be parsed, a placeholder graph with a single
    all-zero node and no edges is returned instead.
    """

    @staticmethod
    def atom_features(atom):
        """Return the 9-entry feature list for an RDKit atom (zeros without RDKit)."""
        if RDKIT_AVAILABLE:
            return [
                atom.GetAtomicNum(),
                atom.GetDegree(),
                atom.GetImplicitValence(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetFormalCharge(),
                int(atom.IsInRing()),
                int(atom.GetChiralTag()),
                atom.GetTotalNumHs()
            ]
        return [0] * 9

    @staticmethod
    def bond_features(bond):
        """Return the 4-entry feature list for an RDKit bond (zeros without RDKit)."""
        if RDKIT_AVAILABLE:
            return [
                int(bond.GetBondType()),
                int(bond.GetIsConjugated()),
                int(bond.IsInRing()),
                int(bond.GetStereo())
            ]
        return [0] * 4

    @classmethod
    def smiles_to_graph(cls, smiles: str) -> Dict:
        """Build ``{'node_feat', 'edge_index', 'edge_feat'}`` tensors for ``smiles``.

        Every bond contributes two directed edges (one per direction) that
        share the same feature row.
        """
        # Without RDKit there is nothing to parse; treat it like a parse failure.
        mol = Chem.MolFromSmiles(smiles) if RDKIT_AVAILABLE else None
        if mol is None:
            return {
                'node_feat': torch.zeros(1, 9),
                'edge_index': torch.zeros(2, 0, dtype=torch.long),
                'edge_feat': torch.zeros(0, 4)
            }

        node_rows = [cls.atom_features(atom) for atom in mol.GetAtoms()]

        directed_pairs = []
        bond_rows = []
        for bond in mol.GetBonds():
            src, dst = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            feats = cls.bond_features(bond)
            directed_pairs += [[src, dst], [dst, src]]
            bond_rows += [feats, feats]

        if directed_pairs:
            edge_index = torch.tensor(directed_pairs, dtype=torch.long).T
        else:
            edge_index = torch.zeros(2, 0, dtype=torch.long)

        if bond_rows:
            edge_feat = torch.tensor(bond_rows, dtype=torch.float)
        else:
            edge_feat = torch.zeros(0, 4)

        return {
            'node_feat': torch.tensor(node_rows, dtype=torch.float),
            'edge_index': edge_index,
            'edge_feat': edge_feat
        }


class MolecularDataProcessor:
    """Convert SMILES strings into per-molecule tensors plus a batched graph.

    Molecules for which no conformer can be generated, or which exceed
    ``max_atoms``, are reported with ``print`` and skipped, so the returned
    lists may be shorter than the input. The ``'smiles'`` entry of the result
    records which inputs survived, in input order.
    """

    # Element symbols ordered by atomic number (index 0 is hydrogen, Z = 1).
    _ELEMENT_SYMBOLS = (
        "H He Li Be B C N O F Ne Na Mg Al Si P S Cl Ar K Ca "
        "Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr Rb Sr Y Zr "
        "Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe Cs Ba La Ce Pr Nd "
        "Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu Hf Ta W Re Os Ir Pt Au Hg "
        "Tl Pb Bi Po At Rn Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm "
        "Md No Lr Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og"
    ).split()
    # Built once at class creation instead of on every lookup.
    _ATOMIC_NUMBERS = {symbol: z for z, symbol in enumerate(_ELEMENT_SYMBOLS, start=1)}

    def __init__(self, max_atoms: int = 100):
        """Store the atom-count limit and create the converter/conformer helpers."""
        self.max_atoms = max_atoms
        self.converter = SMILES2GraphConverter()
        self.conformer_generator = ConformationGenerator3D()

    def process_smiles_list(self, smiles_list: List[str], device: torch.device) -> Dict:
        """Generate conformers and graphs for ``smiles_list`` and move them to ``device``.

        Returns:
            An empty dict when no molecule could be processed. Otherwise a dict
            with ``'graph_batch'`` (a torch_geometric ``Batch``) and the
            per-molecule lists ``'coordinates'``, ``'atomic_numbers'``,
            ``'atom_types'`` and ``'smiles'``, all restricted to the molecules
            that were kept.
        """
        processed_data = defaultdict(list)

        for idx, smiles in enumerate(tqdm(smiles_list, desc='Processing molecules')):
            atoms, coordinates = self.conformer_generator.generate_conformation(smiles)

            if atoms is None or coordinates is None:
                print(f"Failed to generate conformation for SMILES {idx}: {smiles}")
                continue

            if len(atoms) > self.max_atoms:
                print(f"Molecule {idx} has {len(atoms)} atoms, exceeding limit of {self.max_atoms}")
                continue

            graph_data = self.converter.smiles_to_graph(smiles)

            processed_data['coordinates'].append(torch.tensor(coordinates, dtype=torch.float32))
            processed_data['atom_types'].append(atoms)
            processed_data['atomic_numbers'].append(
                torch.tensor([self._get_atomic_number(atom) for atom in atoms], dtype=torch.long))
            processed_data['graphs'].append(Data(
                x=graph_data['node_feat'],
                edge_index=graph_data['edge_index'],
                edge_attr=graph_data['edge_feat']
            ))
            processed_data['smiles'].append(smiles)

        batched_data = {}
        if processed_data['graphs']:
            batched_data['graph_batch'] = Batch.from_data_list(processed_data['graphs']).to(device)
            # Keep these as lists: molecules have different atom counts, so
            # stacking them into one tensor would fail.
            batched_data['coordinates'] = [t.to(device) for t in processed_data['coordinates']]
            batched_data['atomic_numbers'] = [an.to(device) for an in processed_data['atomic_numbers']]
            batched_data['atom_types'] = processed_data['atom_types']
            batched_data['smiles'] = processed_data['smiles']

        return batched_data

    def _get_atomic_number(self, element: str) -> int:
        """Return the atomic number for an element symbol, or 0 if unrecognised.

        The lookup covers the whole periodic table. The earlier 10-element
        table mapped every other element (e.g. ``'Si'``, ``'B'``, ``'Se'``) to
        0, making real atoms indistinguishable from the ``'X'`` placeholder the
        Open Babel fallback emits for unknown elements.
        """
        return self._ATOMIC_NUMBERS.get(element, 0)


class SymForce(SymForcePreTrainedModel):
    def __init__(self, config: SymForceConfig, vocab_size: int = None):
        """Assemble the encoder, the LoRA-wrapped Llama model and the coordinate projector.

        Args:
            config: source of every encoder, LLM and LoRA setting used below.
            vocab_size: when truthy, the LLM token embeddings are resized to
                this size before the LoRA adapter is attached.
        """
        super().__init__(config)

        self.encoder = SymForceEncoder(
            geometric_encoder_config=config.geometric_encoder_config,
            llm_config=config.llm_config,
            translator_config=config.translator_config,
            physics_constraints=config.physics_constraints,
            max_iterations=config.max_iterations,
            convergence_threshold=config.convergence_threshold
        )

        # Any dtype name other than the two listed here falls back to float32.
        dtype_by_name = {"bfloat16": torch.bfloat16, "float16": torch.float16}
        llm_kwargs = {
            'torch_dtype': dtype_by_name.get(config.torch_dtype, torch.float32),
            'trust_remote_code': True
        }
        if config.enable_flash:
            llm_kwargs['attn_implementation'] = "flash_attention_2"

        base_llm = LlamaForCausalLM.from_pretrained(config.llm_model, **llm_kwargs)
        if vocab_size:
            base_llm.resize_token_embeddings(vocab_size)

        lora_settings = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=config.lora_target_modules
        )
        self.llm = get_peft_model(base_llm, lora_settings)

        try:
            # Purely informational; a failure to print must not abort construction.
            self.llm.print_trainable_parameters()
        except Exception:
            pass

        hidden_size = self.llm.config.hidden_size
        # Maps each atom's xyz triple to an LLM-width vector. The position
        # embedding table is created here but not referenced elsewhere in this class.
        coord_encoder = nn.Sequential(
            nn.Linear(3, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, hidden_size)
        )
        self.coordinate_projection = nn.ModuleDict({
            'coord_encoder': coord_encoder,
            'position_embedding': nn.Embedding(1000, hidden_size)
        })

        self.molecular_processor = MolecularDataProcessor()

    def forward(self,
                smiles_batch: List[str],
                text_batch: dict,
                target_coordinates: Optional[List[torch.Tensor]] = None,
                symbolic_targets_batch: Optional[List[str]] = None,
                lambda_symbolic: float = 0.0) -> dict:
        """Refine each molecule's conformation and optionally run the LLM on text.

        Args:
            smiles_batch: one SMILES string per sample.
            text_batch: optional tokenised text. ``'input_ids'`` enables the LLM
                pass; ``'attention_mask'``, ``'labels'`` and
                ``'mol_token_positions'`` (per-sample position or list of
                positions to overwrite with the molecule embedding) are optional.
            target_coordinates: optional per-sample reference coordinates,
                indexed like ``smiles_batch``.
            symbolic_targets_batch: optional per-sample symbolic targets,
                indexed like ``smiles_batch``.
            lambda_symbolic: weight of the encoder's symbolic loss.

        Returns:
            Dict with ``'loss'``, ``'optimization_results'`` (one encoder result
            per processed molecule), ``'molecular_data'`` and
            ``'source_indices'`` (for each processed molecule, its position in
            ``smiles_batch``).
        """
        device = next(self.parameters()).device
        molecular_data = self.molecular_processor.process_smiles_list(smiles_batch, device)

        if not molecular_data:
            # Same keys as the normal return so callers can read them unconditionally.
            return {
                'loss': torch.tensor(0.0, device=device),
                'optimization_results': [],
                'molecular_data': molecular_data,
                'source_indices': []
            }

        # The processor silently drops molecules it cannot handle, so position k
        # in molecular_data is not necessarily position k in smiles_batch.
        # Recover the original index of every kept molecule by walking both
        # sequences in order (kept SMILES are an ordered subsequence of the
        # input). Per-sample targets and text rows are then looked up by that
        # original index instead of the shifted one.
        source_indices: List[int] = []
        cursor = 0
        for kept_smiles in molecular_data['smiles']:
            while cursor < len(smiles_batch) and smiles_batch[cursor] != kept_smiles:
                cursor += 1
            source_indices.append(cursor)
            cursor += 1

        optimization_results: List[Dict[str, Any]] = []
        total_loss = torch.tensor(0.0, device=device)

        for src, coords, atom_types, atomic_nums in zip(
                source_indices,
                molecular_data['coordinates'],
                molecular_data['atom_types'],
                molecular_data['atomic_numbers']
        ):
            molecular_graph = {
                'atomic_numbers': atomic_nums,
                'coordinates': coords
            }

            symbolic_target = None
            if symbolic_targets_batch and src < len(symbolic_targets_batch):
                symbolic_target = symbolic_targets_batch[src]

            result = self.encoder(
                molecular_graph=molecular_graph,
                initial_coordinates=coords,
                atom_types=atom_types,
                bonds=[],
                symbolic_target=symbolic_target,
                lambda_symbolic=lambda_symbolic
            )

            optimization_results.append(result)

            if target_coordinates is not None and src < len(target_coordinates):
                loss_components = compute_conformation_loss(
                    predicted_coords=result['final_coordinates'],
                    target_coords=target_coordinates[src].to(device),
                    predicted_forces=result['force_trajectory'],
                    reference_forces=None
                )
                conformation_loss = loss_components['total_loss']

                symbolic_loss = result.get('symbolic_loss', 0.0)
                if torch.is_tensor(symbolic_loss):
                    # Keep the tensor attached to the autograd graph; wrapping
                    # it in torch.tensor(...) would copy and detach it.
                    symbolic_loss = symbolic_loss.to(device=device, dtype=conformation_loss.dtype)
                else:
                    symbolic_loss = torch.tensor(symbolic_loss, device=device, dtype=conformation_loss.dtype)

                total_loss = total_loss + conformation_loss + lambda_symbolic * symbolic_loss

        if text_batch and 'input_ids' in text_batch:
            # One vector per molecule: mean of the per-atom projections, shape (1, H).
            coordinate_embeddings = [
                self.coordinate_projection['coord_encoder'](result['final_coordinates']).mean(dim=0, keepdim=True)
                for result in optimization_results
            ]

            if coordinate_embeddings:
                mol_embeddings = torch.cat(coordinate_embeddings, dim=0)            # (M, H)

                input_ids = text_batch['input_ids'].to(device)
                attention_mask = text_batch.get('attention_mask')
                if attention_mask is not None:
                    attention_mask = attention_mask.to(device)
                labels = text_batch.get('labels')
                if torch.is_tensor(labels):
                    labels = labels.to(device)
                inputs_embeds = self.llm.get_input_embeddings()(input_ids).clone()  # (B, T, H)

                # Overwrite the <mol> placeholder position(s) of each text row
                # with the embedding of the molecule that came from that row.
                if 'mol_token_positions' in text_batch:
                    positions_per_sample = text_batch['mol_token_positions']
                    B, T, _ = inputs_embeds.size()
                    for k, row in enumerate(source_indices):
                        if row >= B:
                            continue
                        row_positions = positions_per_sample[row]
                        if not isinstance(row_positions, (list, tuple)):
                            row_positions = [row_positions]
                        for pos in row_positions:
                            if 0 <= pos < T:
                                inputs_embeds[row, pos, :] = mol_embeddings[k]

                llm_outputs = self.llm(
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    labels=labels,
                    return_dict=True
                )

                if llm_outputs.loss is not None:
                    total_loss = total_loss + llm_outputs.loss

        # NOTE(review): the LLM loss is divided by the molecule count together
        # with the summed conformation losses — confirm this scaling is intended.
        return {
            'loss': total_loss / max(len(optimization_results), 1),
            'optimization_results': optimization_results,
            'molecular_data': molecular_data,
            'source_indices': source_indices
        }

    @torch.no_grad()
    def generate_conformations(
            self,
            smiles_list: List[str],
            max_iterations: int = None,
            convergence_threshold: float = None
    ) -> List[Dict]:
        """Run the encoder on every processable molecule without tracking gradients.

        ``max_iterations`` and ``convergence_threshold`` temporarily replace the
        encoder's own settings for this call; the previous values are restored
        even if the encoder raises.

        Returns:
            One encoder result dict per molecule that survived preprocessing,
            each extended with ``'smiles'`` and ``'initial_coordinates'``.
            An empty list when nothing could be processed.
        """
        device = next(self.parameters()).device
        molecular_data = self.molecular_processor.process_smiles_list(smiles_list, device)
        if not molecular_data:
            return []

        encoder = self.encoder
        saved_settings = (encoder.max_iterations, encoder.convergence_threshold)

        if max_iterations is not None:
            encoder.max_iterations = max_iterations
        if convergence_threshold is not None:
            encoder.convergence_threshold = convergence_threshold

        results = []
        try:
            per_molecule = zip(
                molecular_data['coordinates'],
                molecular_data['atom_types'],
                molecular_data['atomic_numbers'],
                molecular_data['smiles']
            )
            for coords, atom_types, atomic_nums, smiles in per_molecule:
                outcome = encoder(
                    molecular_graph={'atomic_numbers': atomic_nums, 'coordinates': coords},
                    initial_coordinates=coords,
                    atom_types=atom_types,
                    bonds=[]
                )
                outcome['smiles'] = smiles
                outcome['initial_coordinates'] = coords
                results.append(outcome)
        finally:
            # Always put the encoder back the way it was found.
            encoder.max_iterations, encoder.convergence_threshold = saved_settings

        return results

    @torch.no_grad()
    def generate_with_text(
            self,
            smiles_list: List[str],
            text_batch: dict,
            generation_kwargs: dict = None
    ) -> dict:
        """Refine conformations, inject them into the prompt embeddings and generate text.

        Args:
            smiles_list: one SMILES string per prompt row.
            text_batch: tokenised prompts. ``'input_ids'`` is required for text
                generation; ``'attention_mask'`` and ``'mol_token_positions'``
                are optional.
            generation_kwargs: overrides merged on top of the sampling defaults
                defined below.

        Returns:
            Dict with ``'generated_ids'`` (an empty tensor when no text was
            generated), ``'conformations'`` and ``'molecular_data'``.
        """
        device = next(self.parameters()).device
        molecular_data = self.molecular_processor.process_smiles_list(smiles_list, device)

        if not molecular_data:
            return {
                'generated_ids': torch.empty(0),
                'conformations': [],
                'molecular_data': molecular_data
            }

        # Molecules that failed preprocessing are missing from molecular_data,
        # which shifts later molecules to lower positions. Recover each kept
        # molecule's original position in smiles_list (kept SMILES form an
        # ordered subsequence of the input) so that its embedding is injected
        # into the prompt row it belongs to rather than into a shifted row.
        source_indices: List[int] = []
        cursor = 0
        for kept_smiles in molecular_data['smiles']:
            while cursor < len(smiles_list) and smiles_list[cursor] != kept_smiles:
                cursor += 1
            source_indices.append(cursor)
            cursor += 1

        conformation_results = []
        coordinate_embeddings = []

        for coords, atom_types, atomic_nums in zip(
                molecular_data['coordinates'],
                molecular_data['atom_types'],
                molecular_data['atomic_numbers']
        ):
            molecular_graph = {
                'atomic_numbers': atomic_nums,
                'coordinates': coords
            }

            result = self.encoder(
                molecular_graph=molecular_graph,
                initial_coordinates=coords,
                atom_types=atom_types,
                bonds=[]
            )

            conformation_results.append(result)

            # One vector per molecule: mean of the per-atom projections, shape (1, H).
            coord_emb = self.coordinate_projection['coord_encoder'](result['final_coordinates'])
            coordinate_embeddings.append(coord_emb.mean(dim=0, keepdim=True))

        if coordinate_embeddings and text_batch and 'input_ids' in text_batch:
            mol_embeddings = torch.cat(coordinate_embeddings, dim=0)               # (M, H)
            input_ids = text_batch['input_ids'].to(device)
            attention_mask = text_batch.get('attention_mask')
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
            inputs_embeds = self.llm.get_input_embeddings()(input_ids).clone()     # (B, T, H)

            if 'mol_token_positions' in text_batch:
                positions_per_sample = text_batch['mol_token_positions']
                B, T, _ = inputs_embeds.size()
                for k, row in enumerate(source_indices):
                    if row >= B:
                        continue
                    row_positions = positions_per_sample[row]
                    if not isinstance(row_positions, (list, tuple)):
                        row_positions = [row_positions]
                    for pos in row_positions:
                        if 0 <= pos < T:
                            inputs_embeds[row, pos, :] = mol_embeddings[k]

            default_gen_kwargs = {
                'max_new_tokens': 512,
                'do_sample': True,
                'temperature': 0.7,
                'top_p': 0.9,
                'repetition_penalty': 1.1
            }

            if generation_kwargs:
                default_gen_kwargs.update(generation_kwargs)

            generated_ids = self.llm.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                **default_gen_kwargs
            )

            return {
                'generated_ids': generated_ids,
                'conformations': conformation_results,
                'molecular_data': molecular_data
            }

        return {
            'generated_ids': torch.empty(0),
            'conformations': conformation_results,
            'molecular_data': molecular_data
        }

    def load_from_checkpoint(self, checkpoint_path: str, strict: bool = False):
        """Load weights from ``checkpoint_path`` into this model.

        The file may hold either a raw state dict or a dict with a
        ``'state_dict'`` entry. A leading ``'symforce.'`` is stripped from key
        names before loading. In non-strict mode, missing keys under the LLM
        and encoder prefixes listed below are expected and not reported.
        """
        print(f"Loading SymForce checkpoint from: {checkpoint_path}")

        try:
            payload = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
        except TypeError:
            # Older torch releases do not accept the weights_only argument.
            payload = torch.load(checkpoint_path, map_location='cpu')

        raw_state = payload['state_dict'] if 'state_dict' in payload else payload

        prefix = 'symforce.'
        remapped_state = {
            (name[len(prefix):] if name.startswith(prefix) else name): tensor
            for name, tensor in raw_state.items()
        }

        missing_keys, unexpected_keys = self.load_state_dict(remapped_state, strict=strict)

        if not strict:
            tolerated_prefixes = ('llm.', 'encoder.geometric_encoder.', 'encoder.symbolic_force_generator.llm.')
            filtered_missing = [key for key in missing_keys if not key.startswith(tolerated_prefixes)]
            if filtered_missing:
                print(f"Warning: Unexpected missing keys: {filtered_missing}")

        if unexpected_keys:
            print(f"Warning: Unexpected keys in checkpoint: {unexpected_keys}")

        print("Successfully loaded SymForce checkpoint")

    @staticmethod
    def set_seed(seed: int = 42):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    @staticmethod
    def _best_rmsd_rdkit(ref_xyz: np.ndarray, pred_xyz: np.ndarray, atomic_nums: List[int]) -> float:
        """
        计算对齐后的 heavy-atom RMSD（Kabsch）。若 RDKit 可用可扩展为更强的 BestRMS（含对称性）。
        要求：ref_xyz (N,3), pred_xyz (N,3)，按相同原子顺序。
        """
        heavy = np.array([z > 1 for z in atomic_nums], dtype=bool)
        X = ref_xyz[heavy]; Y = pred_xyz[heavy]
        if X.shape[0] == 0:
            return float('nan')
        Xc = X - X.mean(0); Yc = Y - Y.mean(0)
        C = np.dot(Yc.T, Xc)
        V, S, Wt = np.linalg.svd(C)
        d = (np.linalg.det(V) * np.linalg.det(Wt)) < 0.0
        if d:
            V[:, -1] = -V[:, -1]
        U = np.dot(V, Wt)
        Y2 = np.dot(Yc, U)
        return float(np.sqrt(((Xc - Y2) ** 2).sum() / len(X)))

    @staticmethod
    def compute_amr(reference_confs: List[np.ndarray],
                    generated_confs: List[np.ndarray],
                    atomic_nums: List[int]) -> float:
        if len(reference_confs) == 0 or len(generated_confs) == 0:
            return float("inf")
        per_ref = []
        for ref in reference_confs:
            best = min(SymForce._best_rmsd_rdkit(ref, gen, atomic_nums) for gen in generated_confs)
            per_ref.append(best)
        return float(np.nanmean(per_ref))