"""
Unified MolEnc Inferrer.

Molecular encoder supporting dual-mode inference.

Supports two inference modes:
1. Simple mode: SMILES → embeddings (SMILES-based)
2. Allegro mode: SMILES → 3D conformer → embeddings (3D conformer-based)
"""

from typing import List, Optional, Dict
import pickle
import json

import torch
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from tqdm import tqdm
import smart_open

from coarsebind_public.mol_encoder.util.s3.s3_io import cache_read
from coarsebind_public.mol_encoder.util.model_io import CPU_Unpickler
from coarsebind_public.mol_encoder.data.tokenizer.mol_graph import GraphTokenizer
from coarsebind_public.mol_encoder.model.mol_enc import MolEnc
from coarsebind_public.mol_encoder.data.tokenizer.trie_tokenizer import TrieTokenizer
from coarsebind_public.mol_encoder.io_schema import MolEncIOSchema
from coarsebind_public.coarsebind.utils import pad_to_max


def chunk_indexable(iterable, n=1, l=None):
    """
    Yield consecutive slices of an indexable sequence.

    Args:
        iterable: Any object supporting ``len()`` and slice indexing
            (list, str, array, ...).
        n (int, optional): Chunk size. Defaults to 1.
        l (int, optional): Number of leading items to cover. Defaults to
            ``len(iterable)`` when None.

    Yields:
        Slices ``iterable[start:stop]`` of at most ``n`` items each; the
        final slice is shorter when ``l`` is not a multiple of ``n``.
    """
    total = l if l is not None else len(iterable)
    for start in range(0, total, n):
        stop = start + n
        if stop > total:
            # Clamp the last chunk so it never reaches past `total`.
            stop = total
        yield iterable[start:stop]


def simple_canonicalize(smi: str, allow_error: bool = False) -> Optional[str]:
    """
    Round-trip a SMILES string through RDKit to obtain its canonical form.

    Args:
        smi: SMILES string to canonicalize.
        allow_error: When True, an unparsable SMILES yields ``None`` instead
            of raising.

    Returns:
        The output of ``Chem.MolToSmiles`` for the parsed molecule, or
        ``None`` if parsing failed and ``allow_error`` is True.

    Raises:
        ValueError: If ``smi`` cannot be parsed and ``allow_error`` is False.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is not None:
        return Chem.MolToSmiles(mol)
    if allow_error:
        return None
    raise ValueError(f"{smi} cannot be parsed. ")


def tokenize_smiles(
    smi,
    tokenizer,
    prefix="[SMILES]",
    extra_spaces=0,
    pad_to_length=None,
    range_check=False,
):
    """
    Canonicalize a SMILES string and tokenize it as
    ``prefix + canonical_smiles + "[STOP]" + "[SPACE]" * extra_spaces``.

    Returns:
        - ``None`` when ``smi`` is None, or when canonicalization/tokenization
          raises ``KeyError`` or ``ValueError``.
        - The raw token sequence when ``pad_to_length`` is None.
        - Otherwise a CPU ``torch.long`` tensor of length ``pad_to_length``,
          zero-padded on the right, or ``None`` if the token sequence is
          longer than ``pad_to_length``.
    """
    if smi is None:
        return None

    # Terminator: one [STOP] followed by the requested number of [SPACE]s.
    suffix = f"[STOP]{'[SPACE]' * extra_spaces}"

    try:
        text = prefix + simple_canonicalize(smi) + suffix
        token_ids = tokenizer.tokenize_text(text, pad=False, range_check=range_check)
    except (KeyError, ValueError):
        return None

    if pad_to_length is None:
        return token_ids

    n_tokens = len(token_ids)
    if n_tokens > pad_to_length:
        # Too long to fit the fixed-size window.
        return None

    padded = torch.zeros(pad_to_length, dtype=torch.long, device="cpu")
    padded[:n_tokens] = torch.tensor(token_ids)
    return padded


def tokenize_c3_smiles(smi: str, tokenizer):
    """
    Tokenize one SMILES string with the "c3" scheme.

    The "c3" scheme appends a single [SPACE] token after [STOP] and pads the
    result to ``tokenizer.n_seq``. Returns whatever ``tokenize_smiles``
    returns for those settings (a padded tensor, or None on failure).
    """
    c3_options = dict(
        extra_spaces=1,
        pad_to_length=tokenizer.n_seq,
        range_check=False,
    )
    return tokenize_smiles(smi, tokenizer, **c3_options)


def masked_smiles_embed(smiles, encoder, tokenizer, device: str = "cpu", tokenization="c3"):
    """
    Tokenize a list of SMILES strings and encode only those that tokenize.

    Args:
        smiles: Sequence of SMILES strings (entries may be None).
        encoder: Object exposing ``encode_tokens(tokens, tokenizer=...)``.
        tokenizer: Tokenizer exposing ``n_seq`` and ``tokenize_text``.
        device: Device the token batch, the mask and the output live on.
        tokenization: "c2" (no extra [SPACE] after [STOP]) or "c3" (one
            extra [SPACE]).

    Returns:
        ``(embeddings, mask)`` where ``mask`` is a bool tensor with one entry
        per input and ``embeddings`` has one row per input; rows of inputs
        that failed to tokenize are all zeros. Returns ``(None, None)`` when
        no input tokenized successfully (including an empty input list).

    Raises:
        ValueError: If ``tokenization`` is neither "c2" nor "c3".
    """
    # The two schemes differ only in the number of trailing [SPACE] tokens.
    extra_spaces_by_scheme = {"c2": 0, "c3": 1}
    if tokenization not in extra_spaces_by_scheme:
        raise ValueError("tokenization must be either 'c2' or 'c3'")
    extra_spaces = extra_spaces_by_scheme[tokenization]

    tokenized = [
        tokenize_smiles(
            smi,
            tokenizer,
            extra_spaces=extra_spaces,
            pad_to_length=tokenizer.n_seq,
            range_check=False,
        )
        for smi in smiles
    ]

    smi_mask = torch.tensor([t is not None for t in tokenized], dtype=torch.bool, device=device)
    valid_tokens = [t for t in tokenized if t is not None]
    if not valid_tokens:
        return None, None

    with torch.no_grad():
        _smiles_emb = encoder.encode_tokens(
            torch.stack(valid_tokens).to(device), tokenizer=tokenizer
        )

    # Allocate the output in the encoder's dtype: masked assignment requires
    # source and destination dtypes to match, so a float32 buffer would fail
    # for a half/bfloat16 encoder.
    smiles_emb = torch.zeros(
        smi_mask.shape[0], _smiles_emb.shape[1], dtype=_smiles_emb.dtype, device=device
    )
    smiles_emb[smi_mask] = _smiles_emb

    return smiles_emb, smi_mask


def run_smiles_embeds(
    encoder,
    tokenizer,
    smiles,
    batch_size=128,
    tokenization="c3",
    prog_bar: bool = False,
):
    """
    Embed a sequence of SMILES strings in batches.

    Args:
        encoder: Encoder passed to ``masked_smiles_embed``; its ``device``
            attribute selects where each batch is run.
        tokenizer: Tokenizer passed to ``masked_smiles_embed``.
        smiles: Indexable sequence of SMILES strings.
        batch_size: Number of SMILES per batch.
        tokenization: "c2" or "c3" (see ``masked_smiles_embed``).
        prog_bar: Show a tqdm progress bar when True.

    Returns:
        ``(smiles_emb, smi_mask)``: two lists aligned with ``smiles``.
        ``smiles_emb`` holds one numpy row per input (zeros for inputs that
        failed to tokenize); ``smi_mask`` holds the matching validity flags.
        If no input at all could be embedded, the embedding dimension is
        unknown and every entry of ``smiles_emb`` is ``None``.
    """
    smiles_emb = []
    smi_mask = []

    # Ceiling division so a partial final batch is counted in the bar.
    n_batches = -(-len(smiles) // batch_size)

    for batch in tqdm(
        chunk_indexable(smiles, batch_size),
        total=n_batches,
        desc="run_smiles_embeds",
        disable=not prog_bar,
    ):
        batch_emb, batch_mask = masked_smiles_embed(
            batch, encoder, tokenizer, device=encoder.device, tokenization=tokenization
        )
        if batch_emb is None:
            # Nothing in this batch tokenized. Keep outputs aligned with the
            # inputs by inserting placeholders; they are back-filled below.
            smiles_emb.extend([None] * len(batch))
            smi_mask.extend(np.zeros(len(batch), dtype=bool))
            continue
        smiles_emb.extend(batch_emb.detach().cpu().numpy())
        smi_mask.extend(batch_mask.detach().cpu().numpy())

    # Replace placeholders with zero rows shaped like a real embedding, which
    # is what masked_smiles_embed emits for invalid entries in a mixed batch.
    template = next((emb for emb in smiles_emb if emb is not None), None)
    if template is not None:
        smiles_emb = [np.zeros_like(template) if emb is None else emb for emb in smiles_emb]

    return smiles_emb, smi_mask


def load_mol_enc(
    doc_url: str,
    device: str = "cpu",
    freeze: bool = True,
    strict: bool = False,
    old_architecture=False,
    override_args=None,  # hopefully not needed, but you never know.
    print_debug=False,
    cache_file=False,  # save from s3 to disk
    force_cpu=False,  # needed to deserialize on some cpu-only machines
    model_type="default",
    smiles_tokenizer_path: str = None,
    graph_tokenizer_path: str = None,
):
    """
    Load a pickled MolEnc checkpoint together with its two tokenizers.

    Args:
        doc_url: Location of the pickled model document. The document must
            contain the keys "model_kwargs", "model" and "train_args".
        device: Device the model is moved to; also written into the model
            kwargs when they already contain a "device" entry.
        freeze: When True, set ``requires_grad = False`` on every parameter.
        strict: Forwarded to ``load_state_dict``.
        old_architecture: Accepted for backward compatibility; not used here.
        override_args: Optional dict merged over the stored model kwargs.
        print_debug: Accepted for backward compatibility; not used here.
        cache_file: Open ``doc_url`` with ``cache_read`` instead of
            ``smart_open.open``.
        force_cpu: Deserialize with ``CPU_Unpickler`` instead of
            ``pickle.loads``.
        model_type: Only "default" (or None) is supported.
        smiles_tokenizer_path: JSON file with the SMILES tokenizer vocabulary.
        graph_tokenizer_path: Vocabulary path for the graph tokenizer.

    Returns:
        ``(model, tokenizer, graph_tokenizer, None)``. The fourth element is
        always None.

    Raises:
        ValueError: If a tokenizer path is missing or ``model_type`` is not
            supported.
    """
    if smiles_tokenizer_path is None:
        raise ValueError(
            "smiles_tokenizer_path is required. Please provide path to tokenizer vocabulary."
        )
    if graph_tokenizer_path is None:
        raise ValueError(
            "graph_tokenizer_path is required. Please provide path to graph tokenizer vocabulary."
        )
    # Validate before any I/O so a bad argument does not cost a full
    # checkpoint download and deserialization.
    if model_type is not None and model_type != "default":
        raise ValueError(
            f"Unknown model type {model_type!r}. The only supported option is 'default'."
        )

    print(f"Loading model from {doc_url}")

    # NOTE(security): unpickling executes arbitrary code embedded in the
    # file. Only point doc_url at trusted model artifacts.
    _open = cache_read if cache_file else smart_open.open
    with _open(doc_url, "rb") as f_in:
        if force_cpu:
            model_doc = CPU_Unpickler(f_in, encoding="UTF-8").load()
        else:
            model_doc = pickle.loads(f_in.read(), encoding="UTF-8")
    model_kwargs = model_doc["model_kwargs"]

    # Keys starting with "module." have that text removed. str.replace drops
    # every occurrence of "module." in such a key, not only the prefix.
    state_dict = {
        (name.replace("module.", "") if name.startswith("module.") else name): tensor
        for name, tensor in model_doc["model"].items()
    }

    gtokenizer = GraphTokenizer(vocab_path=graph_tokenizer_path)
    gtokenizer.set_vocab()

    if "device" in model_kwargs:
        model_kwargs["device"] = device

    if override_args:
        model_kwargs.update(override_args)

    model = MolEnc(**model_kwargs)
    model.load_state_dict(state_dict, strict=strict)
    model.to(device)
    model.device = device

    with smart_open.open(smiles_tokenizer_path, "r") as f_in:
        tokenizer_vocab = json.load(f_in)

    tokenizer = TrieTokenizer(n_seq=model_kwargs["n_seq"], **tokenizer_vocab)
    print(model_doc["train_args"])

    if freeze:
        print("Freezing encoder")
        n_params = 0
        for param in model.parameters():
            param.requires_grad = False
            n_params += param.numel()
        print(f"{n_params } params frozen!")

    return model, tokenizer, gtokenizer, None


def compute_3d_conformer(mol: Chem.Mol, version: str = "v3") -> bool:
    """Embed a 3D conformer on ``mol`` with ETKDG and relax it with UFF.

    Taken from `pdbeccdutils.core.component.Component`.

    Parameters
    ----------
    mol: Mol
        The RDKit molecule to process; the conformer is added in place.
    version: str, optional
        The ETKDG version, defaults to v3. Any value other than "v3"
        selects the ETKDGv2 parameters.

    Returns
    -------
    bool
        True when a conformer id was obtained (even if the force-field step
        raised), False otherwise.
    """
    params = AllChem.ETKDGv3() if version == "v3" else AllChem.ETKDGv2()

    # Fixed seed; existing conformers on the molecule are kept.
    params.randomSeed = 0xF00D
    params.clearConfs = False

    embedded_id = -1
    try:
        embedded_id = AllChem.EmbedMolecule(mol, params)

        if embedded_id == -1:
            print(
                f"WARNING: RDKit ETKDGv3 failed to generate a conformer for molecule "
                f"{Chem.MolToSmiles(AllChem.RemoveHs(mol))}, so the program will start with random coordinates. "
                f"Note that the performance of the model under this behaviour was not tested."
            )
            # Second attempt, starting from random coordinates.
            params.useRandomCoords = True
            embedded_id = AllChem.EmbedMolecule(mol, params)

        AllChem.UFFOptimizeMolecule(mol, confId=embedded_id, maxIters=1000)

    except (RuntimeError, ValueError):
        # RuntimeError: force field issue; ValueError: sanitization issue.
        pass

    if embedded_id == -1:
        return False

    conformer = mol.GetConformer(embedded_id)
    conformer.SetProp("name", "Computed")
    conformer.SetProp("coord_generation", f"ETKDG{version}")
    return True


def batch_to_device(batch: Dict[str, torch.Tensor], device: str) -> Dict[str, torch.Tensor]:
    """
    Copy every tensor of a batch onto ``device``.

    Args:
        batch: Mapping from feature name to tensor; ``None`` values are
            allowed and are left out of the result.
        device: The target device (e.g., 'cpu', 'cuda:0').

    Returns:
        A new dictionary holding the non-None entries, each moved with
        ``.to(device, non_blocking=True)``. The input mapping is not modified.
    """
    return {
        name: tensor.to(device, non_blocking=True)
        for name, tensor in batch.items()
        if tensor is not None
    }


def crop_hydrogens(io_schema: MolEncIOSchema):
    """
    Remove hydrogen atoms from embeddings and atom features, in place.

    Rows whose atomic number is 1 are dropped from ``e3nn_embed``, ``atoms``
    and ``atom_toks``; bonds with a hydrogen endpoint are dropped from
    ``bonds``; ``triu`` is reset to None. Nothing happens when
    ``e3nn_embed`` is None.

    Args:
        io_schema: Input/output schema with embeddings and atom features
    """
    if io_schema.e3nn_embed is None:
        return

    atoms = io_schema.atoms
    is_hydrogen = atoms == 1

    bonds = io_schema.bonds
    # A molecule without bonds yields an empty 1-D float array from
    # np.array([]), which cannot be used as an index; there is nothing to
    # filter in that case, so leave `bonds` untouched.
    if bonds is not None and len(bonds) > 0:
        is_hydrogen_bond = (atoms[bonds] == 1).any(axis=1)
        # NOTE(review): surviving bond indices are not renumbered after the
        # atom rows are dropped — confirm downstream users expect that.
        io_schema.bonds = bonds[~is_hydrogen_bond]

    io_schema.e3nn_embed = io_schema.e3nn_embed[~is_hydrogen]
    io_schema.atoms = atoms[~is_hydrogen]
    io_schema.atom_toks = io_schema.atom_toks[~is_hydrogen]
    io_schema.triu = None


class MolEncInferrer:
    """
    Unified Inferrer for dual-mode inference.

    ``predict`` generates both embeddings:
    - smiles_embed: Fast SMILES-based embedding (no 3D conformer)
    - e3nn_embed: High-quality 3D conformer-based embedding (point encoder)

    ``predict_smiles`` and ``predict_allegro`` run the two modes separately.
    The model is loaded lazily on first use.
    """

    # Side length of the bond-type adjacency matrix built for every molecule.
    # A bond referencing an atom index >= MAX_ATOMS is reported as an error
    # on that row.
    MAX_ATOMS = 180

    def __init__(
        self,
        model_uri: str,
        device: str,
        smiles_tokenizer_path: str,
        graph_tokenizer_path: str,
    ):
        """
        Initialize inferrer.

        Args:
            model_uri: S3 path or local path to model artifacts
            device: Device to run inference on ('cpu', 'cuda', etc.)
            smiles_tokenizer_path: Path to SMILES tokenizer vocabulary (JSON file)
            graph_tokenizer_path: Path to graph tokenizer vocabulary (PKL file)
        """
        self.model_uri = model_uri
        self.device = device
        self.smiles_tokenizer_path = smiles_tokenizer_path
        self.graph_tokenizer_path = graph_tokenizer_path

        # Model components (initialized lazily)
        self.encoder_simple = None
        self.encoder_allegro = None
        self.tokenizer = None
        self.graph_tokenizer = None
        self.xform_routine = None

    def _load_model(self):
        """Load the MolEnc model components for both simple and allegro modes."""
        if self.encoder_simple is not None and self.encoder_allegro is not None:
            return  # Already loaded

        # Load model once for both encoders
        encoder, tokenizer, gtokenizer, xform_routine = load_mol_enc(
            self.model_uri,
            print_debug=False,
            device=self.device,
            cache_file=True,
            force_cpu=True,
            smiles_tokenizer_path=self.smiles_tokenizer_path,
            graph_tokenizer_path=self.graph_tokenizer_path,
        )

        # Set up simple encoder
        self.encoder_simple = encoder.eval()
        self.tokenizer = tokenizer
        self.graph_tokenizer = gtokenizer
        self.xform_routine = xform_routine

        # Set up allegro encoder (point encoder from the same model)
        self.encoder_allegro = encoder.point_encoder.to(self.device)
        self.encoder_allegro.eval()

    def predict_smiles(
        self,
        smiles: List[str],
        batch_size: int = 128,
        prog_bar: bool = True,
    ) -> List[MolEncIOSchema]:
        """
        Run SMILES-based inference.

        Args:
            smiles: List of SMILES strings
            batch_size: Batch size for inference
            prog_bar: Whether to show progress bar

        Returns:
            List of MolEncIOSchema objects with smiles_embed and embed_mask
            populated. For a SMILES that RDKit cannot parse, ``canon_smiles``
            is None.
        """
        if self.encoder_simple is None or self.tokenizer is None:
            self._load_model()

        smiles_emb, smi_mask = run_smiles_embeds(
            self.encoder_simple,
            self.tokenizer,
            smiles,
            batch_size=batch_size,
            tokenization="c3",
            prog_bar=prog_bar,
        )

        # Convert to IO schema
        results = []
        for smi, emb, mask in zip(smiles, smiles_emb, smi_mask):
            # Invalid SMILES are already tolerated by the embedding step (via
            # the mask); do not let canonicalization crash the whole call.
            canon = None if smi is None else simple_canonicalize(smi, allow_error=True)
            results.append(
                MolEncIOSchema(
                    smiles=smi,
                    canon_smiles=canon,
                    smiles_embed=emb,
                    embed_mask=mask,
                    artifact_s3=self.model_uri,
                )
            )

        return results

    def _initial_checks(self, row: MolEncIOSchema):
        """
        Record the model URI on the row and canonicalize its SMILES.

        Raises:
            ValueError: If ``row.smiles`` cannot be parsed.
        """
        row.artifact_s3 = self.model_uri
        row.canon_smiles = simple_canonicalize(row.smiles)

    def _assign_conformer(self, row: MolEncIOSchema):
        """
        Generate a 3D conformer (with explicit hydrogens) for the row.

        Raises:
            ValueError: If conformer generation fails.
        """
        mol = Chem.MolFromSmiles(row.canon_smiles)
        mol = Chem.AddHs(mol)

        success = compute_3d_conformer(mol, version="v3")

        if not success:
            raise ValueError("3D conformer generation failed")

        row.rdkit_mol = mol

    @staticmethod
    def _bond_type_triu(bonds, bond_types, num_atoms):
        """
        Build a symmetric ``num_atoms x num_atoms`` bond-type matrix and
        return its strict upper triangle as a flat array.

        Raises:
            ValueError: If a bond references an atom index >= ``num_atoms``.
        """
        adjacency = np.zeros((num_atoms, num_atoms), dtype=int)
        upper = np.triu_indices(num_atoms, k=1)

        if len(bonds) == 0:
            return adjacency[upper]

        # Ensure bond indices are within bounds
        max_index_in_bonds = bonds[:, :2].max()
        if max_index_in_bonds >= num_atoms:
            raise ValueError(f"Atom index {max_index_in_bonds} exceeds maximum {num_atoms - 1}")

        # Populate both halves so the matrix is symmetric
        for (i, j), bond_type in zip(bonds, bond_types):
            adjacency[i, j] = bond_type
            adjacency[j, i] = bond_type

        return adjacency[upper]

    def _assign_mol_enc_features(self, row: MolEncIOSchema):
        """Extract atom tokens, atomic numbers, bonds and the bond-type triangle."""
        atom_toks, bond_toks, bonds = self.graph_tokenizer.tokens_from_mol(
            row.rdkit_mol, return_coords=False, ignore_light=False
        )
        bonds = np.array(bonds)
        atom_toks = np.array(atom_toks)
        bond_toks = np.array(bond_toks)
        atoms = np.array([a.GetAtomicNum() for a in row.rdkit_mol.GetAtoms()])

        triu = self._bond_type_triu(bonds, bond_toks, self.MAX_ATOMS)

        row.atom_toks = atom_toks
        row.atoms = atoms
        row.bonds = bonds
        row.triu = triu

    def _featurize(self, io_schema: List[MolEncIOSchema]):
        """
        Featurize molecules for Allegro inference.

        Any failure is recorded on the row (``error`` / ``error_msg``)
        instead of being raised, so one bad molecule does not abort the batch.
        """
        for row in io_schema:
            try:
                self._initial_checks(row)
                self._assign_conformer(row)
                self._assign_mol_enc_features(row)
            except Exception as e:
                row.error = True
                row.error_msg = str(e)

    def _collate_fn(self, batch: List[MolEncIOSchema]) -> Dict[str, torch.Tensor]:
        """Collate a batch of featurized rows into padded tensors."""
        atoms = [torch.from_numpy(row.atoms) for row in batch]
        coords = [
            torch.from_numpy(row.rdkit_mol.GetConformer().GetPositions()).to(torch.float32)
            for row in batch
        ]

        allegro_atom_toks = [torch.from_numpy(row.atom_toks) for row in batch]
        allegro_triu = [torch.from_numpy(row.triu) for row in batch]

        atoms, _ = pad_to_max(atoms, value=0)
        coords, _ = pad_to_max(coords, value=0.0)
        allegro_atom_toks, _ = pad_to_max(allegro_atom_toks, value=0)
        allegro_triu, _ = pad_to_max(allegro_triu, value=0)

        features = {
            "atoms": atoms,
            "coords": coords,
            "atom_toks": allegro_atom_toks,
            "triu": allegro_triu,
        }

        return features

    @torch.no_grad()
    def predict_allegro(
        self,
        io_schema: List[MolEncIOSchema],
        batch_size: int = 128,
        crop_hydrogens_flag: bool = True,
        prog_bar: bool = True,
    ) -> List[MolEncIOSchema]:
        """
        Run Allegro 3D conformer-based inference.

        Args:
            io_schema: List of MolEncIOSchema objects with SMILES
            batch_size: Batch size for inference
            crop_hydrogens_flag: Whether to remove hydrogens from output
            prog_bar: Whether to show progress bar

        Returns:
            The same list, updated in place. Rows that featurized
            successfully get e3nn_embed populated; rows that failed carry
            ``error`` / ``error_msg`` and are skipped.
        """
        if self.encoder_allegro is None or self.graph_tokenizer is None:
            self._load_model()

        self._featurize(io_schema)

        # Only rows that featurized cleanly are sent through the encoder.
        orig_idxs = [i for i, row in enumerate(io_schema) if not row.error]

        for i in tqdm(range(0, len(orig_idxs), batch_size), disable=not prog_bar):
            batch_idxs = orig_idxs[i : i + batch_size]
            batch_rows = [io_schema[idx] for idx in batch_idxs]

            features = self._collate_fn(batch_rows)
            features = batch_to_device(features, self.device)

            Y, H_even, H_odd = self.encoder_allegro(
                features["atoms"],
                features["coords"],
                features["atom_toks"],
                features["triu"],
                apply_stop_decode=False,
            )
            embeds = torch.cat([Y, H_even, H_odd], -1)
            embeds = embeds.cpu().numpy()

            for _row, _embed in zip(batch_rows, embeds):
                # Strip the padding rows added by _collate_fn.
                num_atoms = len(_row.atoms)
                _row.e3nn_embed = _embed[:num_atoms]
                if crop_hydrogens_flag:
                    crop_hydrogens(_row)

        return io_schema

    def predict(
        self,
        input_data: List[str] | List[MolEncIOSchema],
        batch_size: int = 128,
        crop_hydrogens_flag: bool = True,
        prog_bar: bool = True,
    ) -> List[MolEncIOSchema]:
        """
        Run dual-mode inference to generate both embeddings.

        Args:
            input_data: Either list of SMILES strings or list of MolEncIOSchema objects
            batch_size: Batch size for inference
            crop_hydrogens_flag: Whether to remove hydrogens from allegro embeddings
            prog_bar: Whether to show progress bar

        Returns:
            List of MolEncIOSchema objects with both embeddings:
            - smiles_embed: SMILES-based embedding (fast, no 3D)
            - e3nn_embed: 3D conformer-based embedding (slower, geometric)
            An empty input yields an empty list.
        """
        # Nothing to do; also avoids indexing input_data[0] below.
        if len(input_data) == 0:
            return []

        # Convert SMILES strings to IO schema if needed
        if isinstance(input_data[0], str):
            smiles = input_data
            io_schema = [MolEncIOSchema(smiles=smi) for smi in smiles]
        else:
            io_schema = input_data
            smiles = [row.smiles for row in io_schema]

        # Run simple mode first
        print("Running simple mode...")
        simple_results = self.predict_smiles(smiles, batch_size, prog_bar)

        # Transfer smiles_embed to io_schema
        for src, dst in zip(simple_results, io_schema):
            dst.smiles_embed = src.smiles_embed
            dst.embed_mask = src.embed_mask
            if dst.canon_smiles is None:
                dst.canon_smiles = src.canon_smiles
            if dst.artifact_s3 is None:
                dst.artifact_s3 = src.artifact_s3

        # Run allegro mode
        print("Running allegro mode...")
        return self.predict_allegro(io_schema, batch_size, crop_hydrogens_flag, prog_bar=prog_bar)
