from rdkit import RDLogger

RDLogger.DisableLog("rdApp.*")
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import autocast


from coarsebind_public.mol_encoder.data.tokenizer.trie_tokenizer import _tokenize_smiles
from coarsebind_public.mol_encoder.models.encoder_3d.attentive_allegro import AttentiveAllegro
from coarsebind_public.mol_encoder.models.loose_modules.activations import SwiGLU
from coarsebind_public.mol_encoder.models.loose_modules.simple_nets import SwiGLUResNet
from coarsebind_public.mol_encoder.models.encoder_graph.graph_transformer import AttentiveGCN
from coarsebind_public.mol_encoder.models.transformer.rotary_smiles_xformer import (
    SmilesTransformerConfig,
    RotarySmilesTransformer,
)


class loss_fn(nn.Module):
    """
    Symmetric contrastive cross-entropy between two batches of features.

    Row ``i`` of the first batch is the positive match for row ``i`` of the
    second batch; all other rows of the batch act as negatives.

    TODO: implement https://arxiv.org/abs/2303.15343
    """

    def __init__(self):
        super().__init__()

    def forward(self, smiles_features, conformer_features, bad_rows=None):
        """
        Average the SMILES->conformer and conformer->SMILES cross-entropies.

        Args:
            smiles_features: 2D tensor, one feature row per molecule.
            conformer_features: 2D tensor with the same number of rows.
            bad_rows: optional boolean mask; rows flagged True get the label -1
                and are skipped by both cross-entropy terms.

        Returns:
            The averaged loss with a leading dimension of size 1.
        """
        smiles_vs_conformer = smiles_features @ conformer_features.T
        conformer_vs_smiles = conformer_features @ smiles_features.T

        batch_size = smiles_vs_conformer.shape[0]
        targets = torch.arange(batch_size, device=smiles_features.device, dtype=torch.long)
        if bad_rows is not None:
            # -1 doubles as the ignore_index of both cross-entropy calls below.
            targets = torch.where(bad_rows, torch.full_like(targets, -1), targets)

        loss_smiles = F.cross_entropy(smiles_vs_conformer, targets, ignore_index=-1)
        loss_conformer = F.cross_entropy(conformer_vs_smiles, targets, ignore_index=-1)
        mean_loss = (loss_smiles + loss_conformer) / 2
        return mean_loss.unsqueeze(0)  # for dataparallel.


class MolEnc(nn.Module):
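    """
    Multi-modal molecule encoder combining a SMILES transformer, a molecular
    graph encoder and a 3D point encoder, each followed by a projection into
    a shared representation.

    - Adds a graph encoder next to the SMILES and 3D encoders.
    - Allows more than one [stop] token to be decoded per unit of encoder;
      these are concatenated into vectors of size dim_encoder * n_out_tokens.
      Each such vector becomes n_out_tokens injected tokens.
    """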

    def __init__(
        self,
        # Xformer parameters
        n_layer_xformer=16,
        n_hidden_xformer=256,
        embed_dim=256,
        n_head=16,
        n_seq=80,
        mlp_dropout=0.0,
        # parameters related to injection or contrast.
        n_out_tokens=1,
        enc_to_proj="swiglu_mlp",
        n_direct_clr=64,  # n_dim to take from the representation for the directCLR loss.
        # 3d parameters.
        n_layer_3d=3,
        n_hidden_3d=256,
        encoder_3d="allegro",  # only "allegro_xform2" is accepted by the check below.
        code_chirality=True,
        msg_cutoff_3d=5.0,
        scalar_tensor_ratio=8,
        irreps_edge_sh_="1x0e+1x1o+1x2e",
        irreps_out_="1x0e+1x0o+1x1e+1x1o+1x2e+1x2o",
        # graph parameters.
        max_node_types=50,
        max_edge_types=50,
        n_layer_graph=2,
        graph_xformer_type="1",  # only "J" is accepted by the check below.
        do_lap_pe=False,
        n_tok=4,  # this is a hack to pickle num toks processed during training.
        biases=True,
        device=torch.device("cpu"),
        dtype=torch.float,
        dim_node=256,
        dim_edge=256,
    ):
        """
        Build the graph, 3D and SMILES encoders plus their projection heads.

        NOTE: the defaults of `graph_xformer_type` ("1") and `encoder_3d`
        ("allegro") are both rejected below, so callers must pass
        `graph_xformer_type="J"` and `encoder_3d="allegro_xform2"` explicitly.

        Args:
            n_layer_xformer, n_hidden_xformer, n_head, n_seq, n_tok, biases,
            device, dtype: forwarded to `SmilesTransformerConfig`.
            embed_dim, code_chirality: stored on the instance only.
            mlp_dropout: dropout of the "swiglu_resnet" projection heads.
            n_out_tokens: number of output/injected tokens per molecule; scales
                the raw embedding width of every encoder.
            enc_to_proj: projection head type, one of "linear", "swiglu_mlp",
                "swiglu_resnet".
            n_direct_clr: number of leading feature dimensions fed to the
                contrastive loss.
            n_layer_3d, n_hidden_3d, msg_cutoff_3d, scalar_tensor_ratio,
            irreps_edge_sh_, irreps_out_: forwarded to `AttentiveAllegro`.
            max_node_types, max_edge_types: forwarded to both the graph and
                the 3D encoder.
            n_layer_graph, do_lap_pe, dim_node, dim_edge: forwarded to
                `AttentiveGCN`.

        Raises:
            ValueError: if `enc_to_proj` is not a supported projection type.
            Exception: if `graph_xformer_type` or `encoder_3d` is unsupported.
        """
        super().__init__()

        # Fail fast: without this check an unknown value would silently leave
        # the projection heads undefined and only surface later as an
        # AttributeError inside the encode_* methods.
        supported_proj = ("linear", "swiglu_mlp", "swiglu_resnet")
        if enc_to_proj not in supported_proj:
            raise ValueError(
                f"Unknown enc_to_proj {enc_to_proj!r}; expected one of {supported_proj}"
            )

        self.embed_dim = embed_dim
        self.code_chirality = code_chirality
        self.encoder_3d = encoder_3d
        self.enc_to_proj = enc_to_proj
        self.n_direct_clr = n_direct_clr
        self.n_out_tokens = n_out_tokens

        # --- graph encoder ---
        if graph_xformer_type == "J":
            self.gxformer = AttentiveGCN(
                n_layers=n_layer_graph,
                dim_node=dim_node,
                dim_edge=dim_edge,
                dim_global=n_hidden_xformer,
                max_node_types=max_node_types,
                max_edge_types=max_edge_types,
                n_out_tokens=n_out_tokens,
                do_lap_pe=do_lap_pe,
            )
            self.nemb_raw_graph = n_hidden_xformer * n_out_tokens
        elif graph_xformer_type == "M2":
            raise Exception(
                "M2 no longer supported. never worked well. M now uses the same setup as M2 and old M is deprecated"
            )
        else:
            raise Exception("unk graph enc type")

        # --- 3D point encoder ---
        if self.encoder_3d == "allegro_xform2":
            print("dim_scalar", n_hidden_3d)
            print("dim_equivariant", n_hidden_3d // scalar_tensor_ratio)
            print("cutoff", msg_cutoff_3d)
            print("irreps_edge_sh_", irreps_edge_sh_)
            print("irreps_out_", irreps_out_)
            print("max_node_types", max_node_types)
            print("max_edge_types", max_edge_types)
            print("n_out_tokens", n_out_tokens)
            self.point_encoder = AttentiveAllegro(
                layers_allegro=n_layer_3d,
                dim_scalar=n_hidden_3d,
                dim_equivariant=n_hidden_3d // scalar_tensor_ratio,
                cutoff=msg_cutoff_3d,
                irreps_edge_sh_=irreps_edge_sh_,
                irreps_out_=irreps_out_,
                max_node_types=max_node_types,
                max_edge_types=max_edge_types,
                n_out_tokens=n_out_tokens,
            )
            self.nemb_raw_3d = n_hidden_3d * n_out_tokens
        else:
            raise Exception("Unknown encoder type", self.encoder_3d)

        # --- SMILES transformer ---
        kwargs = {
            "n_layer": n_layer_xformer,
            "n_embd": n_hidden_xformer,
            "n_head": n_head,
            "n_seq": n_seq,
            "n_tok": n_tok,
            "device": device,
            "dtype": dtype,
            "biases": biases,
            "n_stop_tokens": n_out_tokens,
        }
        self.xformer_config = SmilesTransformerConfig(**kwargs)
        self.xformer = RotarySmilesTransformer(self.xformer_config)
        self.device = device

        # --- projections of every modality into the common representation ---
        self.nemb_raw_xformer = n_out_tokens * n_hidden_xformer
        if enc_to_proj == "linear":
            self.point_to_proj = nn.Sequential(
                nn.LayerNorm(self.nemb_raw_3d),
                nn.Linear(self.nemb_raw_3d, self.nemb_raw_xformer),
            )
            self.smiles_to_proj = nn.Sequential(
                nn.LayerNorm(self.nemb_raw_xformer),
                nn.Linear(self.nemb_raw_xformer, self.nemb_raw_xformer),
            )
            self.graph_to_proj = nn.Sequential(
                nn.LayerNorm(self.nemb_raw_graph),
                nn.Linear(self.nemb_raw_graph, self.nemb_raw_xformer),
            )
        elif enc_to_proj == "swiglu_mlp":
            # The first Linear doubles the width because SwiGLU is followed by
            # a Linear that takes nemb_raw_xformer inputs.
            self.point_to_proj = nn.Sequential(
                nn.LayerNorm(self.nemb_raw_3d),
                nn.Linear(self.nemb_raw_3d, 2 * self.nemb_raw_xformer),
                SwiGLU(),
                nn.Linear(self.nemb_raw_xformer, self.nemb_raw_xformer),
            )
            # NOTE(review): this head sizes its input with xformer.n_embd while
            # the other two branches use nemb_raw_xformer — confirm they agree
            # when n_out_tokens > 1.
            self.smiles_to_proj = nn.Sequential(
                nn.LayerNorm(self.xformer.n_embd),
                nn.Linear(self.xformer.n_embd, 2 * self.nemb_raw_xformer),
                SwiGLU(),
                nn.Linear(self.nemb_raw_xformer, self.nemb_raw_xformer),
            )
            self.graph_to_proj = nn.Sequential(
                nn.LayerNorm(self.nemb_raw_graph),
                nn.Linear(self.nemb_raw_graph, 2 * self.nemb_raw_xformer),
                SwiGLU(),
                nn.Linear(self.nemb_raw_xformer, self.nemb_raw_xformer),
            )
        elif enc_to_proj == "swiglu_resnet":
            self.point_to_proj = SwiGLUResNet(
                self.nemb_raw_3d, self.nemb_raw_xformer, dropout=mlp_dropout
            )
            self.smiles_to_proj = SwiGLUResNet(
                self.nemb_raw_xformer, self.nemb_raw_xformer, dropout=mlp_dropout
            )
            self.graph_to_proj = SwiGLUResNet(
                self.nemb_raw_graph, self.nemb_raw_xformer, dropout=mlp_dropout
            )

        # self.proj_to_token = SwiGLUResNet( self.nemb_raw_xformer, self.nemb_raw_xformer)
        self.proj_to_token = torch.nn.Identity()

        # Parameter counts cover the three encoders only (projection heads excluded).
        n_params_3d = sum(p.numel() for p in self.point_encoder.parameters())
        n_params_graph = sum(p.numel() for p in self.gxformer.parameters())
        n_params_smiles = sum(p.numel() for p in self.xformer.parameters())
        n_params = n_params_3d + n_params_smiles + n_params_graph
        print(
            f"number of parameters 3d: {n_params_3d/1e6:.2f}M graph: {n_params_graph/1e6:.2f}M xformer: {n_params_smiles/1e6:.2f}M Total: {n_params/1e6:.2f}M "
        )
        self.loss_fn = loss_fn()
        self.to(self.device)

    def encode_tokens(self, token_indices, tokenizer):
        """Encode a 2D tensor of token indices with the SMILES transformer, then project it."""
        assert token_indices.dim() == 2
        raw_embedding = self.xformer.encode(token_indices, tokenizer)
        return self.smiles_to_proj(raw_embedding)

    def encode_from_emb(self, emb, token_indices, tokenizer):
        """Encode pre-computed embeddings with the SMILES transformer, then project the result."""
        raw_embedding = self.xformer.encode_from_emb(emb, token_indices, tokenizer)
        return self.smiles_to_proj(raw_embedding)

    def encode_graph(self, atoms, nodes, edges, graph_tokenizer):
        """Run the graph encoder on (atoms, nodes, edges) and project its output."""
        raw_embedding = self.gxformer(atoms, nodes, edges, graph_tokenizer)
        return self.graph_to_proj(raw_embedding)

    def encode_points(self, atoms, coords, nodes=None, edges=None, distance_gradient=False):
        """
        Encode a batch of 3D structures and project them to the common representation.

        Notice the agg. of the point encoder was previously done here.
        Now is done within the encoder itself (except for the "allegro" /
        "allegro_xform" encoders, whose output is averaged over dim 1 here).

        Args:
            atoms: 2D tensor of atom types.
            coords: 3D tensor of atom coordinates.
            nodes, edges: graph inputs, only passed to the "allegro_xform2" /
                "allegro_xform3" encoders.
            distance_gradient: only passed to the "e3gnn" encoder.

        Returns:
            The output of `self.point_to_proj` applied to the encoder output.

        Raises:
            ValueError: if `self.encoder_3d` is not a known encoder type
                (previously this silently returned None).
        """
        assert atoms.dim() == 2
        assert coords.dim() == 3
        encoder_kind = self.encoder_3d
        if encoder_kind in ("allegro_xform2", "allegro_xform3"):
            encoded = self.point_encoder(atoms, coords, nodes, edges)
        elif encoder_kind in ("allegro", "allegro_xform"):
            # These encoders return per-atom features; average over the atom axis.
            encoded = self.point_encoder(atoms, coords).mean(1)
        elif encoder_kind == "e3gnn":
            encoded = self.point_encoder(atoms, coords, distance_gradient=distance_gradient)
        else:
            raise ValueError(f"Unknown encoder type: {encoder_kind!r}")
        return self.point_to_proj(encoded)

    def points_to_2d(
        self,
        atoms,
        coords,
        tokenizer,
        nodes=None,
        edges=None,
        fill_in_from="[SMILES]",
        noise_scale=0.0,
        inv_temp=2,
        k=100,
    ):
        """
        Generate a SMILES (or GRAPH) string from atoms and coordinates.

        The 3D encoding is (optionally) perturbed with Gaussian noise of
        standard deviation `noise_scale`, and its first row is injected at the
        [UNK] position of the prompt before top-k sampling. Intended for testing.
        """
        assert fill_in_from in ("[SMILES]", "[GRAPH]")
        h_proj = self.encode_points(atoms, coords, nodes=nodes, edges=edges)
        if noise_scale > 0:
            jitter = torch.normal(
                mean=torch.zeros_like(h_proj),
                std=noise_scale * torch.ones_like(h_proj),
            )
            h_proj += jitter
        h_token = self.proj_to_token(h_proj)

        # Prompt: [CLIP][UNK] + one [SPACE] per additional output token.
        space_padding = "[SPACE]" * (self.n_out_tokens - 1)
        prompt = "[CLIP][UNK]" + space_padding + fill_in_from + "[SUFFIX][MIDDLE]"
        token_prebatch = tokenizer.tokenize_text(prompt, pad=False)

        generation = self.xformer.generate_topk_with_inj(
            prefix=token_prebatch,
            stop_token=tokenizer.stop_token,
            inv_temp=inv_temp,
            k=k,
            inj_token=tokenizer.unk_token,
            inj_payload=h_token[0],
        )
        if fill_in_from != "[SMILES]":
            return tokenizer.decode(generation)
        return tokenizer.decode(generation, special=False)

    def h_proj_to_2d(
        self,
        h_proj,
        tokenizer,
        fill_in_from="[SMILES]",
        noise_scale=0.0,
        do_suffix=False,
        inv_temp=2,
        k=100,
    ):
        """
        Generate a SMILES (or GRAPH) string from an already-projected embedding.

        Args:
            h_proj: projected embedding to inject at the [UNK] position. It is
                not modified, even when noise is added.
            tokenizer: tokenizer providing `tokenize_text`, `decode`,
                `stop_token` and `unk_token`.
            fill_in_from: "[SMILES]" or "[GRAPH]".
            noise_scale: standard deviation of Gaussian noise added to `h_proj`
                (no noise when <= 0).
            do_suffix: append "[SUFFIX][MIDDLE]" to the prompt.
            inv_temp, k: sampling parameters forwarded to the transformer.

        Returns:
            The decoded generation; special tokens are dropped for "[SMILES]".
        """
        assert fill_in_from == "[SMILES]" or fill_in_from == "[GRAPH]"
        if noise_scale > 0:
            # Out-of-place add: `h_proj += ...` would overwrite the caller's tensor.
            h_proj = h_proj + torch.normal(
                mean=torch.zeros_like(h_proj),
                std=noise_scale * torch.ones_like(h_proj),
            )
        h_token = self.proj_to_token(h_proj)

        # Prompt: [CLIP][UNK] + one [SPACE] per additional output token.
        suffstr = "[SUFFIX][MIDDLE]" if do_suffix else ""
        clip_string = "[CLIP][UNK]" + "[SPACE]" * (self.n_out_tokens - 1)
        token_prebatch = tokenizer.tokenize_text(clip_string + fill_in_from + suffstr, pad=False)
        generation = self.xformer.generate_topk_with_inj(
            prefix=token_prebatch,
            stop_token=tokenizer.stop_token,
            inv_temp=inv_temp,
            k=k,
            inj_token=tokenizer.unk_token,
            inj_payload=h_token,
        )
        if fill_in_from == "[SMILES]":
            return tokenizer.decode(generation, special=False)
        return tokenizer.decode(generation)

    def h_proj_to_2d_batch(
        self,
        h_proj: torch.Tensor,
        tokenizer,
        fill_in_from: str = "[SMILES]",
        noise_scale: float = 0.0,
        inv_temp: float = 2,
        k: int = 100,
        do_suffix=False,
        keep_special: bool = False,
        return_tokens: bool = False,
    ):
        """
        Generate one SMILES (or GRAPH) string per row of projected embeddings.

        Args:
            h_proj: 2D tensor whose last dimension equals
                `self.xformer.n_embd * self.n_out_tokens` (asserted below). It
                is not modified, even when noise is added.
            tokenizer: tokenizer providing `tokenize_text`, `decode`,
                `stop_token`, `pad_token` and `unk_token`.
            fill_in_from: modality token placed after the injection slots.
            noise_scale: standard deviation of Gaussian noise added to `h_proj`
                (no noise when <= 0).
            inv_temp, k: sampling parameters forwarded to the transformer.
            do_suffix: append "[SUFFIX][MIDDLE]" to the prompt.
            keep_special: passed to `tokenizer.decode` as `special`.
            return_tokens: also return the raw generated token sequences.

        Returns:
            A list of decoded strings, or `(strings, generation)` when
            `return_tokens` is True.
        """
        # assert k > 1
        if noise_scale > 0:
            # Out-of-place add: `h_proj += ...` would overwrite the caller's tensor.
            h_proj = h_proj + torch.normal(
                mean=torch.zeros_like(h_proj),
                std=noise_scale * torch.ones_like(h_proj),
            )
        h_token = self.proj_to_token(h_proj)

        # Prompt: [CLIP][UNK] + one [SPACE] per additional output token.
        suffstr = "[SUFFIX][MIDDLE]" if do_suffix else ""
        clip_string = "[CLIP][UNK]" + "[SPACE]" * (self.n_out_tokens - 1)
        token_prebatch = tokenizer.tokenize_text(clip_string + fill_in_from + suffstr, pad=False)

        assert h_token.dim() == 2
        assert h_token.shape[-1] == self.xformer.n_embd * self.n_out_tokens
        generation = self.xformer.generate_top_k_with_inj_batch(
            prefix=token_prebatch,
            stop_token=tokenizer.stop_token,
            inv_temp=inv_temp,
            k=k,
            pad_token=tokenizer.pad_token,
            inj_token=tokenizer.unk_token,
            inj_payload=h_token,
        )
        smiles_list = [
            tokenizer.decode(token_out, special=keep_special) for token_out in generation
        ]

        if return_tokens:
            return smiles_list, generation

        return smiles_list

    def points_to_2d_batch(
        self,
        atom_batch,
        coords_batch,
        tokenizer,
        nodes=None,
        edges=None,
        fill_in_from="[SMILES]",
        noise_scale=0.0,
        do_suffix=False,
        inv_temp=2,
        k=100,
        keep_special=False,
    ):
        """
        Generate one SMILES (or GRAPH) string per structure in a batch.

        The batch is encoded with `encode_points`, optionally perturbed with
        Gaussian noise of standard deviation `noise_scale`, and injected at the
        [UNK] position of the prompt before top-k sampling.

        The prompt reserves one [SPACE] per additional output token, exactly
        like `points_to_2d`, `h_proj_to_2d` and `h_proj_to_2d_batch` do; for
        `n_out_tokens == 1` it is the plain "[CLIP][UNK]" prompt.

        Returns:
            A list of decoded strings, one per generated sequence.
        """
        h_proj = self.encode_points(atom_batch, coords_batch, nodes=nodes, edges=edges)
        if noise_scale > 0:
            h_proj = h_proj + torch.normal(
                mean=torch.zeros_like(h_proj),
                std=noise_scale * torch.ones_like(h_proj),
            )
        h_token = self.proj_to_token(h_proj)

        suffstr = "[SUFFIX][MIDDLE]" if do_suffix else ""
        clip_string = "[CLIP][UNK]" + "[SPACE]" * (self.n_out_tokens - 1)
        token_prebatch = tokenizer.tokenize_text(clip_string + fill_in_from + suffstr, pad=False)
        generation = self.xformer.generate_top_k_with_inj_batch(
            prefix=token_prebatch,
            stop_token=tokenizer.stop_token,
            inv_temp=inv_temp,
            k=k,
            pad_token=tokenizer.pad_token,
            inj_token=tokenizer.unk_token,
            inj_payload=h_token,
        )
        return [tokenizer.decode(token_out, special=keep_special) for token_out in generation]

    def h_proj_and_tokens_to_likelihood(
        self, h_proj, smiles, tokenizer, prefix="[CLIP][UNK][SMILES][SUFFIX][MIDDLE]"
    ):
        """
        Score how likely `h_proj` is to decode to the given `smiles`.

        Returns the summed next-token cross-entropy of the sequence
        "[CLIP]" + n_out_tokens * "[UNK]" + "[SMILES][SUFFIX][MIDDLE]" + smiles + "[STOP]",
        ignoring targets that are clip / pad / smiles / unk / suffix / middle
        tokens. The `prefix` argument is accepted but not used.
        """
        header = "[CLIP]" + "[UNK]" * self.n_out_tokens + "[SMILES][SUFFIX][MIDDLE]"
        token_ids = tokenizer.tokenize_text(header + smiles + "[STOP]", pad=False)
        tokens = torch.tensor(token_ids, device=h_proj.device, dtype=torch.long).unsqueeze(0)

        # Next-token targets: the sequence shifted left by one; the final slot stays 0.
        targets = torch.zeros_like(tokens)
        targets[:, :-1] = tokens[:, 1:]
        ignored_tokens = (
            tokenizer.clip_token,
            tokenizer.pad_token,
            tokenizer.smiles_token,
            tokenizer.unk_token,
            tokenizer.suffix_token,
            tokenizer.middle_token,
        )
        for special_token in ignored_tokens:
            targets[targets == special_token] = -1

        payload = self.proj_to_token(h_proj.unsqueeze(0))
        logits = self.xformer.forward_with_replacement(tokens, payload, tokenizer)

        per_token_loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=-1,
            reduction="none",
        ).reshape(tokens.shape)
        per_token_loss[targets == -1] = 0
        return per_token_loss.sum(-1)

    def complete_batch(
        self, prefixes, tokenizer, inv_temp=2, k=100, keep_special=False, de_fim=True
    ):
        """
        Testing generation of SMILES
        from atoms and coords
        """
        # create a 'batch' to infer smiles.
        tokens = [tokenizer.tokenize_text(p, pad=False) for p in prefixes]
        generation = self.xformer.generate_topk_batch(
            prefix=tokens,
            stop_token=tokenizer.stop_token,
            pad_token=tokenizer.pad_token,
            inv_temp=inv_temp,
            k=k,
        )
        smiles_list = [
            tokenizer.decode(token_out, special=keep_special, de_fim=de_fim)
            for token_out in generation
        ]
        return smiles_list

    def batch_smiles_to_s2s_likelihood(self, smiles, tokenizer):
        """
        Compute the likelihood that SMILES->h_proj->SMILES decodes for all SMILES in `smiles`.

        Args:
            smiles: list of SMILES strings.
            tokenizer: tokenizer providing `n_seq` and the special token ids.

        Returns:
            (ar_loss_, tokenizes_mask): the summed next-token cross-entropy for
            every SMILES that could be tokenized, and a boolean mask over the
            input list marking which entries those are. When nothing could be
            tokenized (including an empty input) the loss tensor is empty.

        NOTE(review): the decoding prefix hard-codes one [UNK] and one [SPACE],
        which matches the generation prompts of this class only for
        n_out_tokens == 2 — confirm against callers.
        """
        # make tokens from '{smi}[STOP]'
        _tokens = [
            _tokenize_smiles(
                smi,
                tokenizer,
                prefix="",
                suffix="[STOP][SPACE]",
                device=self.device,
                max_size=(tokenizer.n_seq - 4),
            )
            for smi in smiles
        ]
        tokenizes_mask = torch.tensor(
            [t is not None for t in _tokens],
            dtype=torch.bool,
            device=self.device,
        )
        valid_tokens = [t for t in _tokens if t is not None]
        if not valid_tokens:
            # torch.stack rejects an empty list; report "no losses" instead of crashing.
            return torch.zeros(0, device=self.device), tokenizes_mask
        _tokens = torch.stack(valid_tokens).to(self.device)

        # make embeddings from '[SMILES]{smi}[STOP]'
        h_proj_tokens = torch.zeros(
            _tokens.shape[0],
            _tokens.shape[1] + 1,  # leave space for [SMILES]
            dtype=torch.long,
            device=self.device,
        )
        h_proj_tokens[:, 0] = tokenizer.smiles_token
        h_proj_tokens[:, 1:] = _tokens
        h_proj = self.encode_tokens(h_proj_tokens, tokenizer)

        # make logits from '[CLIP][UNK][SPACE][SMILES]{smi}[STOP]' and h_proj from '[SMILES]{smi}[STOP]'
        tokens = torch.zeros(
            _tokens.shape[0],
            _tokens.shape[1] + 4,  # leave space for [CLIP][UNK][SPACE][SMILES]
            dtype=torch.long,
            device=self.device,
        )
        tokens[:, 0] = tokenizer.clip_token
        tokens[:, 1] = tokenizer.unk_token
        tokens[:, 2] = tokenizer.space_token
        tokens[:, 3] = tokenizer.smiles_token
        tokens[:, 4:] = _tokens

        logits = self.xformer.forward_with_replacement(  # pred token
            tokens, self.proj_to_token(h_proj), tokenizer
        )

        # calculate cross entropy on pred token and actual next token
        mask_val = -1  # to mask loss for special tokens
        next_tokens = torch.zeros_like(tokens)
        next_tokens[:, : (tokens.shape[1] - 1)] = tokens[:, 1:].clone()
        # After the shift the first three targets are [UNK][SPACE][SMILES]
        # ([CLIP] is never a target because it was first in tokens).
        next_tokens[:, :3] = mask_val
        # The shift leaves the last target column at its initial 0; mask it
        # explicitly so it never contributes, whatever the pad token id is.
        next_tokens[:, -1] = mask_val
        next_tokens[next_tokens == tokenizer.pad_token] = mask_val

        # find ar_loss per SMILES
        ar_loss_ = (
            torch.nn.functional.cross_entropy(
                logits.view(-1, logits.shape[2]),
                next_tokens.view(-1),
                ignore_index=mask_val,
                reduction="none",
            )
            .view(next_tokens.shape[0], next_tokens.shape[1])
            .sum(axis=1)
        )
        return ar_loss_, tokenizes_mask

    def forward(
        self,
        raw_tokens,
        augmented_tokens,
        atoms,
        coords,
        nodes,
        edges,
        tokenizer,
        gtokenizer,
        p_emb_graph=1.0 / 3,
        p_emb_3d=1.0 / 3,
    ):
        """
        Encode all three modalities and decode `augmented_tokens` with one of them injected.

        Variant used for DistributedDataParallel training: unlike
        `forward_mono` it returns `bad_rows` instead of the contrastive loss.

        For every row, one modality embedding is drawn at random and injected
        into the transformer: the 3D embedding with probability `p_emb_3d`,
        the graph embedding with probability `p_emb_graph`, and the SMILES
        embedding otherwise.

        Args:
            raw_tokens: 2D token tensor encoded by the SMILES encoder.
            augmented_tokens: token tensor decoded by the transformer.
            atoms, coords, nodes, edges: inputs of the 3D and graph encoders.
            tokenizer, gtokenizer: SMILES and graph tokenizers.
            p_emb_graph: probability of injecting the graph embedding.
            p_emb_3d: probability of injecting the 3D embedding.

        Returns:
            (h_3d, h_smiles, h_graph, logits, bad_rows), where `bad_rows`
            flags rows of `augmented_tokens` whose token sum is below 1.

        Raises:
            AssertionError: if the 3D and SMILES batches differ in size.
            Exception: if any of the three embeddings contains NaN.
        """
        with autocast(enabled=False, device_type="cuda"):
            h_3d = self.encode_points(atoms, coords, nodes, edges)
            h_smiles = self.encode_tokens(raw_tokens, tokenizer)
            h_graph = self.encode_graph(atoms, nodes, edges, gtokenizer)

            if h_3d.shape[0] != h_smiles.shape[0]:
                # Report every involved shape to make batch mismatches debuggable.
                raise AssertionError(
                    "batch size mismatch between 3d and smiles embeddings: "
                    f"raw_tokens={tuple(raw_tokens.shape)} "
                    f"augmented_tokens={tuple(augmented_tokens.shape)} "
                    f"atoms={tuple(atoms.shape)} coords={tuple(coords.shape)} "
                    f"h_3d={tuple(h_3d.shape)} h_smiles={tuple(h_smiles.shape)}"
                )

            point_proj_token = self.proj_to_token(h_3d)
            smiles_proj_token = self.proj_to_token(h_smiles)
            graph_proj_token = self.proj_to_token(h_graph)

            # One uniform draw per row, broadcast over the feature dimension.
            noise = torch.rand((h_3d.shape[0],), device=atoms.device).unsqueeze(-1)

            # [0, p_emb_3d) -> 3D, [p_emb_3d, p_emb_3d + p_emb_graph) -> graph,
            # the remainder -> SMILES, so each probability matches its name.
            proj_token = torch.where(
                noise < p_emb_3d,
                point_proj_token,
                torch.where(noise < p_emb_3d + p_emb_graph, graph_proj_token, smiles_proj_token),
            )

        if torch.isnan(h_3d).any():
            raise Exception("bad 3d")
        if torch.isnan(h_graph).any():
            raise Exception("bad graph")
        if torch.isnan(h_smiles).any():
            raise Exception("bad smiles")

        logits = self.xformer.forward_with_replacement(augmented_tokens, proj_token, tokenizer)
        bad_rows = augmented_tokens.sum(-1) < 1
        return h_3d, h_smiles, h_graph, logits, bad_rows

    def forward_mono(
        self,
        raw_tokens,
        augmented_tokens,
        atoms,
        coords,
        nodes,
        edges,
        tokenizer,
        gtokenizer,
        p_emb_graph=1.0 / 3,
        p_emb_3d=1.0 / 3,
    ):
        """
        Non-distributed forward pass that also evaluates the contrastive loss.

        For every row, one modality embedding is drawn at random and injected
        into the transformer: the 3D embedding with probability `p_emb_3d`,
        the graph embedding with probability `p_emb_graph`, and the SMILES
        embedding otherwise.

        Args:
            raw_tokens: 2D token tensor encoded by the SMILES encoder.
            augmented_tokens: token tensor decoded by the transformer.
            atoms, coords, nodes, edges: inputs of the 3D and graph encoders.
            tokenizer, gtokenizer: SMILES and graph tokenizers.
            p_emb_graph: probability of injecting the graph embedding.
            p_emb_3d: probability of injecting the 3D embedding.

        Returns:
            (h_3d, h_smiles, h_graph, logits, loss), where `loss` is
            `self.loss_fn` applied to the first `n_direct_clr` features of the
            SMILES and 3D embeddings, with rows of `augmented_tokens` whose
            token sum is below 1 flagged as bad rows.
        """
        with autocast(enabled=False, device_type="cuda"):
            h_3d = self.encode_points(atoms, coords, nodes, edges)
            h_smiles = self.encode_tokens(raw_tokens, tokenizer)
            # encode_graph takes (atoms, nodes, edges, graph_tokenizer); the
            # previous call omitted `atoms` and raised a TypeError.
            h_graph = self.encode_graph(atoms, nodes, edges, gtokenizer)

            assert h_3d.shape[0] == h_smiles.shape[0]
            point_proj_token = self.proj_to_token(h_3d)
            smiles_proj_token = self.proj_to_token(h_smiles)
            graph_proj_token = self.proj_to_token(h_graph)

            # One uniform draw per row, broadcast over the feature dimension.
            noise = torch.rand((h_3d.shape[0],), device=atoms.device).unsqueeze(-1)

            # [0, p_emb_3d) -> 3D, [p_emb_3d, p_emb_3d + p_emb_graph) -> graph,
            # the remainder -> SMILES, so each probability matches its name.
            proj_token = torch.where(
                noise < p_emb_3d,
                point_proj_token,
                torch.where(noise < p_emb_3d + p_emb_graph, graph_proj_token, smiles_proj_token),
            )

        logits = self.xformer.forward_with_replacement(augmented_tokens, proj_token, tokenizer)
        bad_rows = augmented_tokens.sum(-1) < 1
        contrastive_loss = self.loss_fn(
            h_smiles[:, : self.n_direct_clr], h_3d[:, : self.n_direct_clr], bad_rows
        )
        return (
            h_3d,
            h_smiles,
            h_graph,
            logits,
            contrastive_loss,
        )
