from typing import Dict, List, Tuple, Any, Callable, Optional

import torch
from torch import nn, Tensor
import torch.nn.functional as F

from coarsebind_public.mol_encoder.util.model_io import CPU_Unpickler
from coarsebind_public.mol_encoder.util.s3.s3_io import cache_read

from coarsebind_public.coarsebind.model.layers.pairformer import PairformerNoSeqModule
import coarsebind_public.coarsebind.model.layers.initialize as init
from coarsebind_public.coarsebind.model.models.base_epinet import Basic_Epinet


class Transition(nn.Module):
    """Layer-normalised, gated two-layer MLP.

    Computes ``fc3(silu(fc1(h)) * fc2(h))`` with ``h = LayerNorm(x)``. In eval
    mode the hidden dimension can optionally be processed in slices so that
    no intermediate wider than ``chunk_size`` is materialised.
    """

    def __init__(
        self,
        dim: int = 128,
        hidden: int = 512,
        out_dim: Optional[int] = None,
    ) -> None:
        """Initialize the Transition module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128
        hidden: int
            The dimension of the hidden layer, default 512
        out_dim: Optional[int]
            The dimension of the output; falls back to ``dim`` when None

        """
        super().__init__()
        if out_dim is None:
            out_dim = dim

        self.norm = nn.LayerNorm(dim, eps=1e-5)
        # fc1 feeds the SiLU gate, fc2 is the linear branch it multiplies.
        self.fc1 = nn.Linear(dim, hidden, bias=False)
        self.fc2 = nn.Linear(dim, hidden, bias=False)
        self.fc3 = nn.Linear(hidden, out_dim, bias=False)
        self.silu = nn.SiLU()
        self.hidden = hidden

        # Project-specific initialisers (defined in layers.initialize).
        init.bias_init_one_(self.norm.weight)
        init.bias_init_zero_(self.norm.bias)

        init.lecun_normal_init_(self.fc1.weight)
        init.lecun_normal_init_(self.fc2.weight)
        init.final_init_(self.fc3.weight)

    def forward(self, x: Tensor, chunk_size: Optional[int] = None) -> Tensor:
        """Perform a forward pass.

        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (..., dim)
        chunk_size: Optional[int]
            Number of hidden units processed per slice. Only honoured in eval
            mode; ignored while training or when None.

        Returns
        -------
        x: torch.Tensor
            The output data of shape (..., out_dim)

        Raises
        ------
        ValueError
            If chunking is requested with a non-positive ``chunk_size``.

        """
        x = self.norm(x)

        if chunk_size is None or self.training:
            return self.fc3(self.silu(self.fc1(x)) * self.fc2(x))

        if chunk_size <= 0:
            # A negative step would make the loop below empty and leave the
            # result unbound; fail with a clear message instead.
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

        # Slice the hidden dimension: each slice's contribution to the output
        # is independent, so the partial products can simply be summed.
        x_out = None
        for start in range(0, self.hidden, chunk_size):
            stop = start + chunk_size
            fc1_slice = self.fc1.weight[start:stop, :]
            fc2_slice = self.fc2.weight[start:stop, :]
            fc3_slice = self.fc3.weight[:, start:stop]
            x_chunk = self.silu(x @ fc1_slice.T) * (x @ fc2_slice.T)
            contribution = x_chunk @ fc3_slice.T
            x_out = contribution if x_out is None else x_out + contribution

        if x_out is None:
            # Degenerate hidden == 0: nothing to accumulate, output is zero.
            x_out = x.new_zeros((*x.shape[:-1], self.fc3.out_features))
        return x_out


class PairwiseConditioning(nn.Module):
    """Algorithm 21.

    Concatenates the trunk pair representation with extra per-pair features,
    projects the result back to ``token_z`` channels and refines it with a
    stack of residual ``Transition`` blocks.
    """

    def __init__(
        self,
        token_z,
        dim_token_rel_pos_feats,
        num_transitions=2,
        transition_expansion_factor=2,
    ):
        super().__init__()

        # Width of the concatenated (trunk + extra features) channel axis.
        concat_dim = token_z + dim_token_rel_pos_feats
        self.dim_pairwise_init_proj = nn.Sequential(
            nn.LayerNorm(concat_dim),
            nn.Linear(concat_dim, token_z, bias=False),
        )

        hidden_dim = transition_expansion_factor * token_z
        self.transitions = nn.ModuleList(
            [Transition(dim=token_z, hidden=hidden_dim) for _ in range(num_transitions)]
        )

    def forward(
        self,
        z_trunk,  # Float['b n n tz'],
        token_rel_pos_feats,  # Float['b n n f'], f == dim_token_rel_pos_feats
    ):  # -> Float['b n n tz']:
        # Join along the channel axis, then squeeze back to token_z channels.
        joined = torch.cat((z_trunk, token_rel_pos_feats), dim=-1)
        pair = self.dim_pairwise_init_proj(joined)

        # Each transition is applied as a residual update.
        for block in self.transitions:
            pair = block(pair) + pair

        return pair


class PotencyModule(nn.Module):
    """Predict quantitative and binary affinity from pair-level features.

    The module scatters sparse pair features onto a dense (B, N, N, C) grid,
    adds single-token and molecule-encoder embeddings, refines the grid with
    a pairwise conditioner and a pairformer stack, pools it under three masks
    (all pairs, ligand-residue pairs, ligand-ligand pairs) and feeds the
    pooled vector to two MLP heads.
    """

    def __init__(
        self,
        pair_dim: int,
        input_s_dim: int,
        s_dim: int,
        pairformer_kwargs: Dict[str, Any],
        num_bins: int = 64,
        mol_enc_dim: int = 768,
    ):
        """Build the sub-modules.

        Parameters
        ----------
        pair_dim: int
            Channel width of the pair representation.
        input_s_dim: int
            Channel width of ``feats["s_inputs"]``.
        s_dim: int
            Hidden width of the output MLP heads.
        pairformer_kwargs: Dict[str, Any]
            Extra keyword arguments forwarded to ``PairformerNoSeqModule``.
        num_bins: int
            Channel width of ``feats["bin_probs"]``, default 64.
        mol_enc_dim: int
            Channel width of ``feats["mol_enc"]``, default 768.

        """
        super().__init__()

        self.pair_dim = pair_dim
        self.s_dim = s_dim
        self.input_s_dim = input_s_dim
        self.pairformer_kwargs = pairformer_kwargs
        self.num_bins = num_bins
        self.mol_enc_dim = mol_enc_dim

        self.z_norm = nn.LayerNorm(pair_dim)
        self.z_linear = nn.Linear(pair_dim, pair_dim, bias=False)

        # Two projections of the single representation, broadcast over rows
        # and columns respectively to form a pair embedding.
        self.s_to_z_proj_1 = nn.Linear(self.input_s_dim, pair_dim, bias=False)
        self.s_to_z_proj_2 = nn.Linear(self.input_s_dim, pair_dim, bias=False)

        self.mol_enc_norm = nn.LayerNorm(mol_enc_dim)
        self.mol_enc_proj = nn.Linear(mol_enc_dim, pair_dim, bias=False)

        self.bins_linear = nn.Linear(num_bins, num_bins, bias=False)

        self.pairwise_conditioner = PairwiseConditioning(
            token_z=pair_dim,
            dim_token_rel_pos_feats=num_bins,
            num_transitions=2,
        )

        self.pairformer_stack = PairformerNoSeqModule(pair_dim, **pairformer_kwargs)

        self.affinity_out_mlp = nn.Sequential(
            nn.Linear(pair_dim * 3, s_dim),
            nn.SiLU(),
            nn.Linear(s_dim, s_dim),
            nn.SiLU(),
        )

        self.to_quant_affinity = nn.Sequential(
            nn.Linear(s_dim, s_dim),
            nn.SiLU(),
            nn.Linear(s_dim, s_dim),
            nn.SiLU(),
            nn.Linear(s_dim, 1),
        )

        # NOTE(review): the trailing Linear(1, 1) is left untouched so that
        # the parameter names/shapes of this head stay exactly as they were.
        self.to_binary_affinity = nn.Sequential(
            nn.Linear(s_dim, s_dim),
            nn.SiLU(),
            nn.Linear(s_dim, s_dim),
            nn.SiLU(),
            nn.Linear(s_dim, 1),
            nn.Linear(1, 1),
        )

    @staticmethod
    def _scatter_pairs(values: Tensor, pair_index: Tensor) -> Tensor:
        """Place rows of ``values`` at the True positions of ``pair_index``; zeros elsewhere."""
        dense = torch.zeros(
            (*pair_index.shape, values.shape[-1]),
            dtype=values.dtype,
            device=values.device,
        )
        dense[pair_index] = values
        return dense

    @staticmethod
    def _masked_mean(z: Tensor, mask: Tensor) -> Tensor:
        """Average ``z`` over dims 1 and 2 under ``mask`` (eps avoids 0/0 for empty masks)."""
        return torch.sum(z * mask, dim=(1, 2)) / (torch.sum(mask, dim=(1, 2)) + 1e-7)

    def forward(
        self,
        feats: Dict[str, torch.Tensor],
        use_kernels: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Run the affinity prediction.

        Parameters
        ----------
        feats: Dict[str, torch.Tensor]
            Reads the keys ``potency_ligand_mask``, ``valid_mask``, ``z``,
            ``z_mask``, ``bin_probs``, ``s_inputs`` and ``mol_enc``.
            Assumes the two token masks are boolean (B, N) tensors and that
            ``z_mask`` selects exactly as many entries as there are
            ligand-ligand plus ligand-residue pairs -- TODO confirm against
            the featurizer.
        use_kernels: bool
            Forwarded unchanged to the pairformer stack.

        Returns
        -------
        Dict[str, torch.Tensor]
            ``pred_quant``, ``pred_binary`` (sigmoid of the logits),
            ``pred_binary_logits`` and ``epinet_latent`` (the pooled feature
            vector, detached from the graph).

        """
        lig_mask = feats["potency_ligand_mask"]
        res_mask = ~feats["potency_ligand_mask"] * feats["valid_mask"]

        lig_lig_mask = lig_mask[:, :, None] * lig_mask[:, None, :]
        lig_res_mask = (
            lig_mask[:, :, None] * res_mask[:, None, :]
            + res_mask[:, :, None] * lig_mask[:, None, :]
        )
        all_pair_mask = lig_lig_mask + lig_res_mask

        lig_lig_mask = lig_lig_mask.to(torch.int32)
        lig_res_mask = lig_res_mask.to(torch.int32)
        all_pair_mask = all_pair_mask.to(torch.int32)
        pair_index = all_pair_mask.bool()

        # Process z features through linear layers first to get correct dtype
        z_processed = self.z_linear(self.z_norm(feats["z"]))[feats["z_mask"]]
        z = self._scatter_pairs(z_processed, pair_index)

        # Process bin_probs through linear layers first to get correct dtype
        bins_processed = self.bins_linear(feats["bin_probs"])[feats["z_mask"]]
        dist_bins = self._scatter_pairs(bins_processed, pair_index)

        s_pair_embeds = (
            self.s_to_z_proj_1(feats["s_inputs"])[:, :, None, :]
            + self.s_to_z_proj_2(feats["s_inputs"])[:, None, :, :]
        )

        # Add mol_enc_embed to lig-lig pairs only. Note diff batch indexes can
        # have diff num lig tokens, hence the per-batch mask. Broadcasting
        # (B, 1, 1, C) against (B, N, N, 1) avoids materialising a repeated
        # (B, N, N, C) copy before masking.
        mol_enc_embed = self.mol_enc_proj(self.mol_enc_norm(feats["mol_enc"]))
        mol_enc_embed = mol_enc_embed[:, None, None, :] * lig_lig_mask.unsqueeze(-1)

        z = z + s_pair_embeds + mol_enc_embed

        z = z + self.pairwise_conditioner(z, dist_bins)

        z = self.pairformer_stack(
            z,
            pair_mask=all_pair_mask,
            use_kernels=use_kernels,
        )

        # Masked mean pooling over the pair grid under the three masks.
        g_1 = self._masked_mean(z, all_pair_mask.unsqueeze(-1))
        g_2 = self._masked_mean(z, lig_res_mask.unsqueeze(-1))
        g_3 = self._masked_mean(z, lig_lig_mask.unsqueeze(-1))

        g = torch.cat([g_1, g_2, g_3], dim=-1)

        g = self.affinity_out_mlp(g)

        pred_quant = self.to_quant_affinity(g).squeeze(-1)
        pred_binary_logits = self.to_binary_affinity(g).squeeze(-1)

        return {
            "pred_quant": pred_quant,
            "pred_binary": F.sigmoid(pred_binary_logits),
            "pred_binary_logits": pred_binary_logits,
            # Detached so consumers of the latent do not backprop into this module.
            "epinet_latent": g.detach(),
        }


class CoarseBindAffinity(torch.nn.Module):
    """Affinity model wrapping a ``PotencyModule`` with an optional epinet.

    When ``train_epinet`` is set, a ``Basic_Epinet`` is attached, every
    non-epinet parameter is frozen, and the potency model is kept in eval
    mode so that only the epinet is trained.
    """

    def __init__(
        self,
        base_pairformer: str,
        pair_dim: int,
        input_s_dim: int,
        s_dim: int,
        pairformer_kwargs: Dict[str, Any],
        potency_cutoff_dist: float = 15.0,
        chunk_size: int = 256,
        mol_enc_uri: str = "",
        epinet_kwargs: Optional[Dict[str, Any]] = None,
        train_epinet: bool = False,
    ):
        """Build the model.

        Parameters
        ----------
        base_pairformer: str
            Stored on the instance; not otherwise used in this class.
        pair_dim, input_s_dim, s_dim, pairformer_kwargs
            Forwarded to ``PotencyModule``.
        potency_cutoff_dist: float
            Stored on the instance; not otherwise used in this class.
        chunk_size: int
            Stored on the instance; not otherwise used in this class.
        mol_enc_uri: str
            Stored on the instance; not otherwise used in this class.
        epinet_kwargs: Optional[Dict[str, Any]]
            Keyword arguments for ``Basic_Epinet``; only read when
            ``train_epinet`` is True. None is treated as no arguments.
        train_epinet: bool
            Attach an epinet and freeze everything else.

        """
        super().__init__()

        self.base_pairformer = base_pairformer
        self.mol_enc_uri = mol_enc_uri
        self.potency_cutoff_dist = potency_cutoff_dist
        self.pair_dim = pair_dim
        self.chunk_size = chunk_size
        self.train_epinet = train_epinet

        # Sub-modules are created in the same order as before (epinet first)
        # so registration order and RNG consumption are unchanged.
        if self.train_epinet:
            self.epinet_model = Basic_Epinet(**(epinet_kwargs or {}))
        else:
            self.epinet_model = None

        self.potency_model = PotencyModule(
            pair_dim=pair_dim,
            s_dim=s_dim,
            input_s_dim=input_s_dim,
            pairformer_kwargs=pairformer_kwargs,
        )

        if self.train_epinet:
            # Freezing must happen after the potency model exists; doing it
            # earlier iterates over an empty parameter list and freezes nothing.
            self._freeze_non_epinet_parameters()

    def _freeze_non_epinet_parameters(self) -> None:
        """Disable gradients for every parameter whose name does not start with ``epinet``."""
        for name, param in self.named_parameters():
            if not name.startswith("epinet"):
                param.requires_grad = False

    def train(self, mode: bool = True):
        """Sets the module in training mode.

        With ``train_epinet`` enabled the potency model is forced back to
        eval mode and the non-epinet parameters are (re-)frozen, regardless
        of ``mode``.
        """
        super().train(mode=mode)
        if self.train_epinet:
            # set non-epinet model parts to eval mode
            self.potency_model.eval()
            self._freeze_non_epinet_parameters()

        return self

    def forward(
        self,
        feats: Dict[str, torch.Tensor],
        epinet_samples: Optional[int] = None,
        residual_scale: float = 1.5,
    ):
        """Run the potency model and optionally add epinet residuals.

        Parameters
        ----------
        feats: Dict[str, torch.Tensor]
            Passed unchanged to the potency model.
        epinet_samples: Optional[int]
            If greater than 1 (and an epinet is attached), that many residual
            samples are drawn and ``pred_quant`` becomes the transposed
            (base + residual) tensor. If None while training with an epinet,
            a single epinet residual is added. Otherwise ``pred_quant`` is
            the plain potency-model prediction.
        residual_scale: float
            Forwarded to ``epinet_model.sample_n`` in the multi-sample branch.

        Returns
        -------
        dict
            The potency model's output dict with ``pred_quant`` replaced as
            described above, plus ``epi_rng_std`` when training with an epinet.

        """
        output = self.potency_model(feats)
        base_quant = output["pred_quant"]
        epinet = self.epinet_model

        if epinet is not None and epinet_samples is None and self.training:
            # random epi index sample
            pred_quant = base_quant + epinet(output["epinet_latent"])
        elif epinet is not None and epinet_samples is not None and epinet_samples > 1:
            residual = epinet.sample_n(
                output["epinet_latent"],
                n_samples=epinet_samples,
                residual_scale=residual_scale,
            )
            # expand gives the same (n_samples, B) operand as repeat without copying.
            quant_mean = base_quant.unsqueeze(0).expand(epinet_samples, -1)
            pred_quant = (quant_mean + residual).T
        else:
            pred_quant = base_quant

        if epinet is not None and self.training:
            # Spread of the epinet on random-normal latents, exposed to the
            # caller under "epi_rng_std".
            rng_epi_latents = torch.randn(
                (pred_quant.shape[0], epinet.feature_dim),
                device=pred_quant.device,
            )
            epi_samples = epinet.sample_n(
                rng_epi_latents, n_samples=100, residual_scale=1.0, use_sobol=False
            )
            output["epi_rng_std"] = epi_samples.T.std(dim=1)

        # TODO epinet as sep output?
        output["pred_quant"] = pred_quant

        return output

    @classmethod
    def from_artifact(cls, artifact_uri: str) -> "CoarseBindAffinity":
        """
        Instantiate a CoarseBindAffinity model from a model document pickle file path.

        The document must provide ``trainer_args["model"]`` (constructor
        keyword arguments, optionally with a ``_target_`` entry that is
        dropped) and ``model_state_dict``. The model is returned in eval mode.
        """
        # NOTE(security): this unpickles the artifact; unpickling can execute
        # arbitrary code, so only load artifacts from trusted locations.
        with cache_read(artifact_uri, "rb") as f:
            model_doc = CPU_Unpickler(f, encoding="UTF-8").load()

        model_kwargs = model_doc["trainer_args"]["model"]
        # "_target_" is config metadata, not a constructor argument; tolerate
        # documents that do not carry it.
        model_kwargs.pop("_target_", None)

        model = cls(**model_kwargs)

        model.load_state_dict(model_doc["model_state_dict"])
        model.eval()
        return model
