# This file contains modified code from Boltz-1 (https://github.com/jwohlwend/boltz)
# Original code Copyright (c) 2024 Jeremy Wohlwend, Gabriele Corso, Saro Passaro
# Licensed under the MIT License

import torch
from torch import Tensor, nn
from typing import Any, Optional

import coarsebind_public.coarsebind.model.layers.initialize as init
from coarsebind_public.coarsebind.data import const
from coarsebind_public.coarsebind.model.modules.encoders import RelativePositionEncoder
from coarsebind_public.coarsebind.model.modules.trunk import (
    DistogramModule,
    InputEmbedder,
)
from coarsebind_public.coarsebind.model.layers.pairformer import PairformerNoSeqModule


class CoarseBindPF(nn.Module):
    """Pairformer-only trunk that produces a pair representation and a distogram.

    The module embeds the input features, builds an initial pair representation
    from them (plus relative-position and token-bond terms), refines it with
    ``PairformerNoSeqModule`` over several recycling rounds, and feeds the final
    pair representation to ``DistogramModule``.

    Many constructor arguments are only stored on the instance and are not
    otherwise used by the code in this class.
    """

    # Hyperparameters that ``__init__`` requires; ``load_from_checkpoint``
    # refuses to build a model when any of them is absent from the checkpoint.
    _REQUIRED_HPARAMS = (
        "atom_s",
        "atom_z",
        "token_s",
        "token_z",
        "num_bins",
        "training_args",
        "validation_args",
        "embedder_args",
        "msa_args",
        "pairformer_args",
        "score_model_args",
        "diffusion_process_args",
        "diffusion_loss_args",
        "confidence_model_args",
    )

    # Hyperparameters forwarded to ``__init__`` only when the checkpoint has
    # them; otherwise the ``__init__`` defaults apply.
    _OPTIONAL_HPARAMS = (
        "atom_feature_dim",
        "confidence_prediction",
        "confidence_imitate_trunk",
        "alpha_pae",
        "structure_prediction_training",
        "atoms_per_window_queries",
        "atoms_per_window_keys",
        "nucleotide_rmsd_weight",
        "ligand_rmsd_weight",
        "no_msa",
        "no_atom_encoder",
        "ema",
        "ema_decay",
        "min_dist",
        "max_dist",
        "predict_args",
        "steering_args",
    )

    def __init__(  # noqa: PLR0915, C901, PLR0912
        self,
        atom_s: int,
        atom_z: int,
        token_s: int,
        token_z: int,
        num_bins: int,
        training_args: dict[str, Any],
        validation_args: dict[str, Any],
        embedder_args: dict[str, Any],
        msa_args: dict[str, Any],
        pairformer_args: dict[str, Any],
        score_model_args: dict[str, Any],
        diffusion_process_args: dict[str, Any],
        diffusion_loss_args: dict[str, Any],
        confidence_model_args: dict[str, Any],
        atom_feature_dim: int = 128,
        confidence_prediction: bool = False,
        confidence_imitate_trunk: bool = False,
        alpha_pae: float = 0.0,
        structure_prediction_training: bool = True,
        atoms_per_window_queries: int = 32,
        atoms_per_window_keys: int = 128,
        compile_pairformer: bool = False,
        compile_structure: bool = False,
        compile_confidence: bool = False,
        nucleotide_rmsd_weight: float = 5.0,
        ligand_rmsd_weight: float = 10.0,
        no_msa: bool = False,
        no_atom_encoder: bool = False,
        ema: bool = False,
        ema_decay: float = 0.999,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        predict_args: Optional[dict[str, Any]] = None,
        steering_args: Optional[dict[str, Any]] = None,
        use_kernels: bool = False,
    ) -> None:
        """Build the embedding, pairformer and distogram sub-modules.

        Every argument is stored on the instance under its own name. Only the
        arguments listed below influence the layers built here; the remaining
        ones are kept as plain configuration attributes.

        Args:
            atom_s: Forwarded to ``InputEmbedder``.
            atom_z: Forwarded to ``InputEmbedder``.
            token_s: Forwarded to ``InputEmbedder``.
            token_z: Channel width of the pair representation; used for the
                pair-init projections, relative position encoder, token-bond
                projection, layer norm, recycling projection, pairformer and
                distogram head.
            num_bins: Forwarded to ``DistogramModule``.
            embedder_args: Extra keyword arguments for ``InputEmbedder``. They
                are merged last, so they override the explicit arguments on a
                key collision.
            pairformer_args: Keyword arguments for ``PairformerNoSeqModule``.
            atom_feature_dim: Forwarded to ``InputEmbedder``.
            structure_prediction_training: Gates gradient tracking in
                ``forward`` together with ``self.training``.
            atoms_per_window_queries: Forwarded to ``InputEmbedder``.
            atoms_per_window_keys: Forwarded to ``InputEmbedder``.
            compile_pairformer: When true, wraps the pairformer in
                ``torch.compile`` and raises the dynamo cache size limits.
            no_atom_encoder: Forwarded to ``InputEmbedder``.
            use_kernels: Passed to the pairformer on every forward call; may be
                switched off later by ``setup``.
        """
        super().__init__()

        # Store configuration
        self.atom_s = atom_s
        self.atom_z = atom_z
        self.token_s = token_s
        self.token_z = token_z
        self.num_bins = num_bins
        self.training_args = training_args
        self.validation_args = validation_args
        self.embedder_args = embedder_args
        self.msa_args = msa_args
        self.pairformer_args = pairformer_args
        self.score_model_args = score_model_args
        self.diffusion_process_args = diffusion_process_args
        self.diffusion_loss_args = diffusion_loss_args
        self.confidence_model_args = confidence_model_args
        self.atom_feature_dim = atom_feature_dim
        self.confidence_prediction = confidence_prediction
        self.confidence_imitate_trunk = confidence_imitate_trunk
        self.alpha_pae = alpha_pae
        self.structure_prediction_training = structure_prediction_training
        self.atoms_per_window_queries = atoms_per_window_queries
        self.atoms_per_window_keys = atoms_per_window_keys
        self.compile_pairformer = compile_pairformer
        self.compile_structure = compile_structure
        self.compile_confidence = compile_confidence
        self.nucleotide_rmsd_weight = nucleotide_rmsd_weight
        self.ligand_rmsd_weight = ligand_rmsd_weight
        self.no_msa = no_msa
        self.no_atom_encoder = no_atom_encoder
        self.ema = ema
        self.ema_decay = ema_decay
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.predict_args = predict_args
        self.steering_args = steering_args
        self.use_kernels = use_kernels

        self.is_pairformer_compiled = False

        # Input dimension
        # NOTE(review): 1280 presumably matches the last dimension of
        # feats["input_token_embeds"], which forward() concatenates onto the
        # embedder output — confirm against the data pipeline.
        s_input_dim = const.num_tokens + len(const.pocket_contact_info) + 1280  # esm dimension

        # Initialize layers: two projections whose outputs are broadcast-added
        # in forward() to form the initial pair representation.
        self.z_init_1 = nn.Linear(s_input_dim, token_z, bias=False)
        self.z_init_2 = nn.Linear(s_input_dim, token_z, bias=False)

        # Input embeddings (entries of embedder_args win on key collisions)
        full_embedder_args = {
            "atom_s": atom_s,
            "atom_z": atom_z,
            "token_s": token_s,
            "token_z": token_z,
            "atoms_per_window_queries": atoms_per_window_queries,
            "atoms_per_window_keys": atoms_per_window_keys,
            "atom_feature_dim": atom_feature_dim,
            "no_atom_encoder": no_atom_encoder,
            **embedder_args,
        }
        self.input_embedder = InputEmbedder(**full_embedder_args)

        self.rel_pos = RelativePositionEncoder(token_z)
        self.token_bonds = nn.Linear(1, token_z, bias=False)

        # Normalization layers
        self.z_norm = nn.LayerNorm(token_z)

        # Recycling projections
        self.z_recycle = nn.Linear(token_z, token_z, bias=False)
        init.gating_init_(self.z_recycle.weight)

        self.pairformer_module = PairformerNoSeqModule(token_z, **pairformer_args)

        if compile_pairformer:
            self.is_pairformer_compiled = True
            # These are process-wide dynamo settings, not per-module ones.
            torch._dynamo.config.cache_size_limit = 512
            torch._dynamo.config.accumulated_cache_size_limit = 512
            self.pairformer_module = torch.compile(
                self.pairformer_module,
                dynamic=False,
                fullgraph=False,
            )

        # Output modules
        self.distogram_module = DistogramModule(token_z, num_bins)

    def setup(self, device: torch.device) -> None:
        """Disable custom kernels when ``device`` cannot run them.

        ``use_kernels`` is left untouched only when CUDA is available, the
        device is a CUDA device, and its compute capability major version is at
        least 8. In every other case it is set to ``False``; this method never
        turns kernels on.

        Args:
            device: Device the model will run on. ``None`` defers to torch's
                current CUDA device when CUDA is available.
        """
        kernels_supported = torch.cuda.is_available()
        if kernels_supported and device is not None:
            # get_device_properties() rejects non-CUDA devices with an error,
            # so rule them out before querying the compute capability.
            kernels_supported = torch.device(device).type == "cuda"
        if kernels_supported:
            kernels_supported = (
                torch.cuda.get_device_properties(device).major >= 8  # noqa: PLR2004
            )

        if not kernels_supported:
            self.use_kernels = False

    def forward(
        self,
        feats: dict[str, Tensor],
        recycling_steps: int = 3,
        num_sampling_steps: Optional[int] = None,
        multiplicity_diffusion_train: int = 1,
        diffusion_samples: int = 1,
        max_parallel_samples: Optional[int] = None,
        run_confidence_sequentially: bool = False,
    ) -> dict[str, Tensor]:
        """Run the embedding, recycled pairformer stack and distogram head.

        Args:
            feats: Feature dictionary. This method reads
                ``"input_token_embeds"``, ``"token_bonds"``,
                ``"token_pad_mask"`` and ``"token_h_mask"`` directly and passes
                the whole dictionary to the input embedder and the relative
                position encoder.
            recycling_steps: Number of extra recycling rounds; the pairformer
                runs ``recycling_steps + 1`` times in total.
            num_sampling_steps: Unused here; accepted for interface
                compatibility.
            multiplicity_diffusion_train: Unused here; accepted for interface
                compatibility.
            diffusion_samples: Unused here; accepted for interface
                compatibility.
            max_parallel_samples: Unused here; accepted for interface
                compatibility.
            run_confidence_sequentially: Unused here; accepted for interface
                compatibility.

        Returns:
            A dictionary with ``"pdistogram"`` (output of the distogram
            module), ``"s_inputs"`` (embedder output concatenated with
            ``feats["input_token_embeds"]``) and ``"z"`` (final pair
            representation).
        """
        # Gradients for everything outside the recycling loop are tracked only
        # when training with structure_prediction_training enabled.
        with torch.set_grad_enabled(self.training and self.structure_prediction_training):
            # Compute input embeddings and append the per-token embeddings
            # supplied with the features.
            s_inputs = self.input_embedder(feats)
            res_inputs = feats["input_token_embeds"]
            s_inputs = torch.cat([s_inputs, res_inputs], dim=-1)

            # Initialize the pairwise embeddings as a broadcast sum of two
            # projections of the same per-token inputs.
            # assumes s_inputs is (batch, tokens, features) — TODO confirm
            z_init = self.z_init_1(s_inputs)[:, :, None] + self.z_init_2(s_inputs)[:, None, :]
            relative_position_encoding = self.rel_pos(feats)
            z_init = z_init + relative_position_encoding
            z_init = z_init + self.token_bonds(feats["token_bonds"].float())

            # The recycled state starts at zero for the first round.
            z = torch.zeros_like(z_init)

            # Compute pairwise mask: a pair is kept only when both of its
            # tokens are set in both masks.
            mask = feats["token_pad_mask"].float() * feats["token_h_mask"].float()
            pair_mask = mask[:, :, None] * mask[:, None, :]

            for i in range(recycling_steps + 1):
                is_last_round = i == recycling_steps
                # Only the final round tracks gradients, and only in training.
                with torch.set_grad_enabled(self.training and is_last_round):
                    # Fixes an issue with unused parameters in autocast
                    if self.training and is_last_round and torch.is_autocast_enabled():
                        torch.clear_autocast_cache()

                    # Apply recycling
                    z = z_init + self.z_recycle(self.z_norm(z))

                    # Revert to uncompiled version for validation
                    if self.is_pairformer_compiled and not self.training:
                        pairformer_module = self.pairformer_module._orig_mod  # noqa: SLF001
                    else:
                        pairformer_module = self.pairformer_module

                    z = pairformer_module(
                        z,
                        pair_mask=pair_mask,
                        use_kernels=self.use_kernels,
                    )

            pdistogram = self.distogram_module(z)

        return {
            "pdistogram": pdistogram,
            "s_inputs": s_inputs,
            "z": z,
        }

    @staticmethod
    def _extract_model_state_dict(ckpt: dict[str, Any]) -> dict[str, Any]:
        """Return the model weights from ``ckpt`` with wrapper prefixes removed."""
        if "state_dict" not in ckpt:
            # Raw state dict
            state_dict = ckpt
        else:
            state_dict = ckpt["state_dict"]
            # New format: keys have a 'model.' prefix (after refactor). Keep
            # only those entries and drop the prefix. Old-format checkpoints
            # have no such keys and are used as-is.
            prefix = "model."
            if any(key.startswith(prefix) for key in state_dict):
                state_dict = {
                    key.removeprefix(prefix): value
                    for key, value in state_dict.items()
                    if key.startswith(prefix)
                }

        # Modules wrapped by torch.compile save their weights under an extra
        # '_orig_mod' path component. load_from_checkpoint always builds an
        # uncompiled model, so those keys would otherwise never match.
        return {
            ".".join(part for part in key.split(".") if part != "_orig_mod"): value
            for key, value in state_dict.items()
        }

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path: Any,
        map_location: Any = "cpu",
        use_kernels: bool = True,
        strict: bool = False,
    ) -> "CoarseBindPF":
        """Load model from Lightning checkpoint.

        This handles checkpoints saved by the Lightning wrapper, extracting
        the state_dict entries with a 'model.' prefix, as well as checkpoints
        whose keys carry no prefix. Key components added by ``torch.compile``
        ('_orig_mod') are removed, because the model is always rebuilt with
        compilation disabled.

        Args:
            checkpoint_path: Path to checkpoint file or file object
            map_location: Device to map tensors to
            use_kernels: Whether to use custom CUDA kernels
            strict: Whether to strictly enforce state_dict keys match

        Returns:
            Instance of this class with loaded weights

        Raises:
            ValueError: If the checkpoint has no ``"hyper_parameters"`` entry,
                or that entry lacks a hyperparameter ``__init__`` requires.
        """
        import warnings

        # Load checkpoint (torch.load accepts both paths and file objects).
        # SECURITY: weights_only=False unpickles arbitrary Python objects, so
        # only load checkpoints from trusted sources.
        ckpt = torch.load(checkpoint_path, map_location=map_location, weights_only=False)

        # Extract state_dict (handle both old and new checkpoint formats)
        state_dict = cls._extract_model_state_dict(ckpt)

        # Extract hyperparameters from checkpoint
        if "hyper_parameters" not in ckpt:
            raise ValueError("No hyperparameters found in checkpoint")
        hparams = ckpt["hyper_parameters"]

        # Fail early with the offending names instead of letting __init__
        # raise a generic "missing positional argument" error.
        missing_hparams = [key for key in cls._REQUIRED_HPARAMS if key not in hparams]
        if missing_hparams:
            raise ValueError(
                f"Checkpoint hyperparameters are missing required keys: {missing_hparams}"
            )

        # Build kwargs from hparams, only including keys that exist.
        # Let __init__ defaults handle any missing optional values.
        kwargs = {
            key: hparams[key]
            for key in (*cls._REQUIRED_HPARAMS, *cls._OPTIONAL_HPARAMS)
            if key in hparams
        }

        # Override use_kernels and disable compilation
        kwargs["use_kernels"] = use_kernels
        kwargs["compile_pairformer"] = False
        kwargs["compile_structure"] = False
        kwargs["compile_confidence"] = False

        model = cls(**kwargs)

        # Load weights
        load_result = model.load_state_dict(state_dict, strict=strict)

        # With strict=False a partial load succeeds silently and leaves the
        # unmatched parameters at their random initialization, so report it.
        # Unexpected keys are not reported: checkpoints from a larger wrapper
        # are expected to contain weights this class does not have.
        if load_result.missing_keys:
            preview = ", ".join(load_result.missing_keys[:5])
            warnings.warn(
                f"{len(load_result.missing_keys)} model parameter(s) were not found "
                f"in the checkpoint and keep their initial values (first: {preview})",
                stacklevel=2,
            )

        return model
