import os, sys
import torch
from torch import nn
from openfold.model.primitives import Linear, LayerNorm
from openfold.model.structure_module import (
    InvariantPointAttention,
    StructureModuleTransition,
    BackboneUpdate,
    AngleResnet
)
from openfold.model.heads import (
    DistogramHead
)
import math
from openfold.utils.rigid_utils import Rigid, Rotation
from openfold.utils.tensor_utils import add, one_hot, dict_multimap
from openfold.model.template import (
    TemplatePairStack,
    TemplatePointwiseAttention
)
from openfold.model.evoformer import (
    EvoformerStack
)
from functools import partial
from utils.funcs import calc_distogram, build_distogram_lower
import plotly.express as px
import plotly
import pandas as pd
import entity.entity_constants as ec

class GaussianFourierProjection(nn.Module):
    """Random Fourier-feature embedding of scalar noise levels / times.

    A fixed (non-trainable) frequency vector ``W`` of length
    ``embedding_size // 2`` is drawn as ``randn * scale``. An input ``x`` is
    mapped to ``sin(2*pi*x*W)`` and ``cos(2*pi*x*W)`` concatenated on the last
    axis.

    Adapted from
    https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/models/layerspp.py#L32
    """

    def __init__(self, embedding_size=256, scale=1.0):
        super().__init__()
        half_dim = embedding_size // 2
        frequencies = torch.randn(half_dim) * scale
        # Registered as a Parameter (so it is stored in the state dict) but frozen.
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        """Return Fourier features of ``x``.

        ``x`` gets a new axis inserted at position 1 before being multiplied by
        the frequencies; for a 1-D ``x`` of shape [B] the result has shape
        [B, 2 * (embedding_size // 2)] (sines first, then cosines).
        """
        phase = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * math.pi
        return torch.cat((phase.sin(), phase.cos()), dim=-1)


class FeatureEmbedder(nn.Module):
    """Embed per-token and pairwise input features into single and pair representations.

    The single embedding is a linear projection of the concatenated per-token
    features. The pair embedding is the sum of:
      * a relative-position encoding of ``seq_idx``;
      * row and column projections of the per-token features;
      * a projection of ``pair_feat``;
      * optionally, a projection of a distogram of ``sc_cb_pos``.
    """

    def __init__(
        self,
        *,
        t_emb_dim,
        seq_dim,
        pair_dim,
        c_z,
        c_m,
        relpos_k,
        distogram_params,
        enable_sc,
        enable_sequence,
        **kwargs
    ):
        """
        Args:
            t_emb_dim:
                Size of the Gaussian Fourier time embedding.
            seq_dim:
                Width of the concatenated per-token features (time embedding,
                extra_feat, entity_type, fixed_mask), excluding the token-type
                one-hot that is added when ``enable_sequence`` is true.
            pair_dim:
                Channel width of ``pair_feat``.
            c_z:
                Pair representation channel dimension.
            c_m:
                Single representation channel dimension.
            relpos_k:
                Relative positions are clipped to [-relpos_k, relpos_k].
            distogram_params:
                Keyword arguments for ``build_distogram_lower``.
            enable_sc:
                Whether to embed a distogram of ``sc_cb_pos`` into the pair rep.
            enable_sequence:
                Whether to include a one-hot of ``token_type`` in the token features.
            **kwargs:
                Ignored; lets config dicts with extra keys be splatted in.
        """
        super().__init__()
        self.time_projector = GaussianFourierProjection(
            embedding_size=t_emb_dim,
            scale=16.0
        )
        self.enable_sequence = bool(enable_sequence)
        if self.enable_sequence:
            seq_dim += ec.token_type_num
        self.linear_seq_z_i = Linear(seq_dim, c_z)
        self.linear_seq_z_j = Linear(seq_dim, c_z)
        self.linear_seq_s = Linear(seq_dim, c_m)
        self.linear_pair_z = Linear(pair_dim, c_z)
        distogram_lower = build_distogram_lower(**distogram_params)
        self.enable_sc = bool(enable_sc)
        if self.enable_sc:
            self.linear_sc_dist_z = Linear(len(distogram_lower), c_z)
            self.distogram_lowers = distogram_lower
        self.relpos_k = relpos_k
        self.no_bins = 2 * relpos_k + 1
        self.linear_relpos = Linear(self.no_bins, c_z)

    def relpos(self, ri: torch.Tensor):
        """
        Computes relative positional encodings

        Implements Algorithm 4.

        Args:
            ri:
                "residue_index" features of shape [*, N]
        Returns:
            [*, N, N, c_z] projection of the one-hot of the relative offset
            ri_i - ri_j, snapped to the nearest value in [-relpos_k, relpos_k].
        """
        d = ri[..., None] - ri[..., None, :]
        boundaries = torch.arange(
            start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
        )
        reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
        d = d[..., None] - reshaped_bins
        d = torch.abs(d)
        d = torch.argmin(d, dim=-1)
        d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
        d = d.to(ri.dtype)
        return self.linear_relpos(d)

    def forward(
        self,
        *,
        seq_idx,
        token_type,
        t,
        fixed_mask,
        extra_feat,
        pair_feat,
        entity_type,
        sc_cb_pos=None,
    ):
        """Build single and pair embeddings.

        Args:
            seq_idx: token indices; its shape defines the token grid (e.g. [B, N]).
            token_type: integer token types (used only when ``enable_sequence``).
            t: time / noise level with one value per batch element ([B] or
                [B, 1]); a single value is broadcast to the whole batch.
            fixed_mask: per-token mask, appended as one feature channel.
            extra_feat, entity_type: per-token feature tensors, concatenated.
            pair_feat: pairwise features projected into the pair rep.
            sc_cb_pos: coordinates for the distogram feature; required when
                ``enable_sc`` is true.
        Returns:
            (seq_emb, pair_emb)
        Raises:
            ValueError: if ``enable_sc`` is true and ``sc_cb_pos`` is None.
        """
        t_emb = self.time_projector(t)
        # One time embedding per batch element, broadcast to every token. The
        # reshape accepts t of shape [B] or [B, 1]. The previous torch.tile
        # call prepended a unit dim, which folded the batch axis into the token
        # axis for B > 1.
        t_emb = t_emb.reshape(
            -1, *([1] * (seq_idx.dim() - 1)), t_emb.shape[-1]
        ).expand(*seq_idx.shape, -1)

        token_feats = [t_emb]
        if self.enable_sequence:
            token_feats.append(
                torch.nn.functional.one_hot(token_type, ec.token_type_num).type(t_emb.dtype)
            )
        token_feats += [extra_feat, entity_type, fixed_mask[..., None]]
        seq_feat = torch.concat(token_feats, dim=-1)

        seq_emb_i = self.linear_seq_z_i(seq_feat)
        seq_emb_j = self.linear_seq_z_j(seq_feat)
        pair_emb = self.relpos(seq_idx.type(seq_emb_i.dtype))
        # Row term (depends on i) and column term (depends on j).
        pair_emb += seq_emb_i[..., None, :]
        pair_emb += seq_emb_j[..., None, :, :]
        pair_emb += self.linear_pair_z(pair_feat)
        if self.enable_sc:
            if sc_cb_pos is None:
                raise ValueError("sc_cb_pos is required when enable_sc is True")
            pair_emb += self.linear_sc_dist_z(
                calc_distogram(sc_cb_pos, self.distogram_lowers)
            )
        seq_emb = self.linear_seq_s(seq_feat)
        return seq_emb, pair_emb

class InputEmbedder(nn.Module):
    """Project a distogram of input coordinates into a pair feature of width ``c_t``."""

    def __init__(
        self,
        *,
        c_t,
        distogram_params,
        **kwargs
    ):
        """
        Args:
            c_t:
                Output channel width.
            distogram_params:
                Keyword arguments for ``build_distogram_lower``; the number of
                returned bin edges sets the linear layer's input width.
            **kwargs:
                Ignored.
        """
        super().__init__()
        bin_lower = build_distogram_lower(**distogram_params)
        self.linear_dist_t = Linear(len(bin_lower), c_t)
        self.lower = bin_lower

    def forward(
        self,
        *,
        input_pos,
        mask=None
    ):
        """Embed ``input_pos`` as a pair feature.

        When ``mask`` is given, distogram entries for any pair (i, j) where
        either token is masked out are zeroed before the linear projection
        (the projection's bias is still added).
        """
        features = calc_distogram(input_pos, self.lower)
        if mask is not None:
            pair_mask = mask[..., :, None] * mask[..., None, :]
            features = features * pair_mask[..., None]
        return self.linear_dist_t(features)
        

class IPAStructure(nn.Module):
    """Iterative backbone-frame refinement built from Invariant Point Attention.

    Starting from identity frames, each of ``no_blocks`` iterations does the
    following:
      * updates the single representation with IPA, conditioned on the pair
        representation and the current frames;
      * applies dropout, layer norm and a transition;
      * composes a predicted quaternion/translation update onto the frames.

    With ``shared_weight=True`` one set of modules is reused for every
    iteration; otherwise each iteration has its own modules.
    """

    def __init__(
        self,
        c_s,
        c_z,
        c_ipa,
        c_resnet,
        no_heads_ipa,
        no_qk_points,
        no_v_points,
        dropout_rate,
        no_blocks,
        no_transition_layers,
        no_resnet_blocks,
        trans_scale_factor,
        epsilon,
        inf,
        shared_weight,
        **kwargs,
    ):
        """
        Args:
            c_s:
                Single representation channel dimension
            c_z:
                Pair representation channel dimension
            c_ipa:
                IPA hidden channel dimension
            c_resnet:
                Angle resnet hidden channel dimension (stored only; the angle
                resnet is currently disabled)
            no_heads_ipa:
                Number of IPA heads
            no_qk_points:
                Number of query/key points to generate during IPA
            no_v_points:
                Number of value points to generate during IPA
            dropout_rate:
                Dropout rate used throughout the layer
            no_blocks:
                Number of structure module iterations
            no_transition_layers:
                Number of layers in the single representation transition
                (Alg. 23 lines 8-9)
            no_resnet_blocks:
                Number of angle resnet blocks (stored only; the angle resnet is
                currently disabled)
            trans_scale_factor:
                Factor applied to frame translations in the returned frames
            epsilon:
                Small number passed to IPA as ``eps``
            inf:
                Large number passed to IPA (used for attention masking)
            shared_weight:
                If True, one set of IPA / transition / backbone-update modules
                is reused for every iteration; otherwise each iteration gets
                its own modules
            **kwargs:
                Ignored; lets config dicts with extra keys be splatted in
        """
        super().__init__()

        self.c_s = c_s
        self.c_z = c_z
        self.c_ipa = c_ipa
        self.c_resnet = c_resnet
        self.no_heads_ipa = no_heads_ipa
        self.no_qk_points = no_qk_points
        self.no_v_points = no_v_points
        self.dropout_rate = dropout_rate
        self.no_blocks = no_blocks
        self.no_transition_layers = no_transition_layers
        self.no_resnet_blocks = no_resnet_blocks
        self.trans_scale_factor = trans_scale_factor
        self.epsilon = epsilon
        self.inf = inf
        self.shared_weight = shared_weight

        self.layer_norm_s = LayerNorm(self.c_s)
        self.layer_norm_z = LayerNorm(self.c_z)

        self.linear_in = Linear(self.c_s, self.c_s)

        # Attribute names are kept identical in both modes so checkpoints load
        # by name; in the unshared mode each attribute is a per-block ModuleList.
        if self.shared_weight:
            self.ipa = InvariantPointAttention(
                self.c_s,
                self.c_z,
                self.c_ipa,
                self.no_heads_ipa,
                self.no_qk_points,
                self.no_v_points,
                inf=self.inf,
                eps=self.epsilon,
            )
            self.ipa_dropout = nn.Dropout(self.dropout_rate)
            self.layer_norm_ipa = LayerNorm(self.c_s)
            self.transition = StructureModuleTransition(
                self.c_s,
                self.no_transition_layers,
                self.dropout_rate
            )
            self.bb_update = BackboneUpdate(self.c_s)
        else:
            self.ipa = nn.ModuleList([
                InvariantPointAttention(
                    self.c_s,
                    self.c_z,
                    self.c_ipa,
                    self.no_heads_ipa,
                    self.no_qk_points,
                    self.no_v_points,
                    inf=self.inf,
                    eps=self.epsilon,
                ) for _ in range(no_blocks)
            ])
            self.ipa_dropout = nn.ModuleList([
                nn.Dropout(self.dropout_rate) for _ in range(no_blocks)
            ])
            self.layer_norm_ipa = nn.ModuleList([
                LayerNorm(self.c_s) for _ in range(no_blocks)
            ])
            self.transition = nn.ModuleList([
                StructureModuleTransition(
                    self.c_s,
                    self.no_transition_layers,
                    self.dropout_rate,
                ) for _ in range(no_blocks)
            ])
            self.bb_update = nn.ModuleList([
                BackboneUpdate(self.c_s) for _ in range(no_blocks)
            ])

    def _block_modules(self, i):
        """Return (ipa, dropout, layer_norm, transition, bb_update) for iteration ``i``."""
        if self.shared_weight:
            return (
                self.ipa,
                self.ipa_dropout,
                self.layer_norm_ipa,
                self.transition,
                self.bb_update,
            )
        return (
            self.ipa[i],
            self.ipa_dropout[i],
            self.layer_norm_ipa[i],
            self.transition[i],
            self.bb_update[i],
        )

    def forward(
        self,
        s,
        z,
        mask=None,
        inplace_safe=False,
        _offload_inference=False,
    ):
        """Refine backbone frames from single and pair representations.

        Args:
            s:
                [*, N_res, C_s] single representation
            z:
                [*, N_res, N_res, C_z] pair representation
            mask:
                Optional [*, N_res] sequence mask; defaults to all ones
            inplace_safe:
                Forwarded to IPA
            _offload_inference:
                Forwarded to IPA
        Returns:
            A dictionary with
                "final_rigids": [*, N_res, 7] frames (quaternion + translation,
                    translation scaled by ``trans_scale_factor``) from the
                    last iteration
                "outputs": per-iteration tensors stacked on a new leading axis
                    of size ``no_blocks``: "frames" [no_blocks, *, N_res, 7]
                    and "states" [no_blocks, *, N_res, C_s]
        Raises:
            ValueError: if ``no_blocks`` < 1 (there would be no output).
        """
        if self.no_blocks < 1:
            raise ValueError(
                f"IPAStructure needs no_blocks >= 1, got {self.no_blocks}"
            )

        if mask is None:
            # [*, N]
            mask = s.new_ones(s.shape[:-1])

        # [*, N, C_s]
        s = self.layer_norm_s(s)

        # [*, N, N, C_z]
        z = self.layer_norm_z(z)

        # [*, N, C_s]
        s = self.linear_in(s)

        # [*, N]
        rigids = Rigid.identity(
            s.shape[:-1],
            s.dtype,
            s.device,
            self.training,
            fmt="quat",
        )
        outputs = []
        for i in range(self.no_blocks):
            ipa, ipa_dropout, layer_norm_ipa, transition, bb_update = (
                self._block_modules(i)
            )
            # [*, N, C_s]
            s = s + ipa(
                s,
                z,
                rigids,
                mask,
                inplace_safe=inplace_safe,
                _offload_inference=_offload_inference,
                _z_reference_list=None,
            )
            s = ipa_dropout(s)
            s = layer_norm_ipa(s)
            s = transition(s)
            # [*, N]
            rigids = rigids.compose_q_update_vec(bb_update(s))

            scaled_rigids = rigids.scale_translation(self.trans_scale_factor)
            outputs.append({
                "frames": scaled_rigids.to_tensor_7(),
                "states": s,
            })

            # As in AlphaFold's structure module, rotations are detached
            # between iterations.
            rigids = rigids.stop_rot_gradient()

        del z

        final_frames = outputs[-1]["frames"]
        outputs = dict_multimap(torch.stack, outputs)
        return {
            # [*, N, 7] frames of the last iteration (translation scaled)
            "final_rigids": final_frames,
            # per-iteration "frames" and "states", stacked on dim 0
            "outputs": outputs
        }


class Denoise(nn.Module):
    """Network that predicts backbone frames and distogram logits from a batch.

    Pipeline (see ``forward``):
      * ``FeatureEmbedder`` builds single/pair embeddings.
      * The (optionally scaled) input coordinates, plus optional ligand-hint
        coordinates, are embedded as a pair feature.
      * That feature is refined by ``input_pair_stack`` and merged into the
        pair representation with ``input_pointwise_attn``.
      * An Evoformer or a second pair stack processes the representations.
      * ``IPAStructure`` predicts frames and ``DistogramHead`` predicts
        distogram logits.
    """

    def __init__(self, cfg):
        super().__init__()
        self.feature_embedder = FeatureEmbedder(
            **cfg.model.feature_embedder
        )
        self.input_embedder = InputEmbedder(
            **cfg.model.input_embedder
        )
        if cfg.model.enable_ligand_hint:
            self.enable_ligand_hint = True
            self.ligand_hint_embedder = InputEmbedder(
                **cfg.model.input_embedder
            )
        else:
            self.enable_ligand_hint = False
        self.input_pair_stack = TemplatePairStack(
            **cfg.model.input_pair_stack
        )
        self.input_pointwise_attn = TemplatePointwiseAttention(
            **cfg.model.input_pointwise_attn
        )
        self.core_stack_sel = cfg.model.core_stack_sel
        if cfg.model.core_stack_sel == "evoformer":
            self.transformer_stack = EvoformerStack(
                **cfg.model.transformer_stack
            )
        else:
            self.z_pair_stack = TemplatePairStack(
                **cfg.model.z_pair_stack
            )
        self.ipa_structure = IPAStructure(
            **cfg.model.ipa_structure
        )
        self.distogram_head = DistogramHead(
            **cfg.model.distogram_head
        )
        self.enable_input_scale = cfg.enable_scale == "input"
        self.sigma_data = cfg.sigma_data
        self.use_deepspeed_evo_attention = cfg.use_deepspeed_evo_attention

    def scale_factor(self, t):
        """Return sigma_data / sqrt(t^2 + ((1 - t) * sigma_data)^2), elementwise in ``t``."""
        return self.sigma_data / torch.sqrt(t**2 + ((1-t)*self.sigma_data)**2)

    def _scale_input_pos(self, input_pos, t):
        """Multiply each batch element of ``input_pos`` by its own ``scale_factor(t)``.

        ``t`` holds one value per batch element ([B] or [B, 1]; a single value
        applies to all). The factor is reshaped to [B, 1, ..., 1] so it
        broadcasts over every non-batch axis of ``input_pos``. Multiplying by
        ``t`` of shape [B] directly would align it with the coordinate axis
        instead.
        """
        scale = self.scale_factor(t)
        scale = scale.reshape(-1, *([1] * (input_pos.dim() - 1)))
        return input_pos * scale

    def forward(self, batch):
        """Run the full network on ``batch``.

        Args:
            batch: mapping with keys "token_index", "token_type", "t",
                "fixed_mask", "sc_cb_pos", "extra_feat", "pair_feat",
                "entity_type", "seq_mask", "input_pos", and, when the ligand
                hint is enabled, "hint_pos" / "hint_mask". The shapes are
                presumably [B, N, ...] per-token tensors; confirm against the
                data pipeline.
        Returns:
            ``IPAStructure`` outputs ("final_rigids", "outputs") plus "psi"
            (a placeholder of ones, shape seq_mask.shape + (2,)) and
            "distogram_logits".
        """
        seq_mask = batch["seq_mask"]
        s, z = self.feature_embedder(
            seq_idx=batch["token_index"],
            token_type=batch["token_type"],
            t=batch["t"],
            fixed_mask=batch["fixed_mask"],
            sc_cb_pos=batch["sc_cb_pos"],
            extra_feat=batch["extra_feat"],
            pair_feat=batch["pair_feat"],
            entity_type=batch["entity_type"]
        )
        pair_mask = seq_mask[..., :, None] * seq_mask[..., None, :]

        input_pos = batch["input_pos"]
        if self.enable_input_scale:
            input_pos = self._scale_input_pos(input_pos, batch["t"])

        input_embedding = self.input_embedder(
            input_pos=input_pos
        )

        if self.enable_ligand_hint:
            input_embedding = input_embedding + self.ligand_hint_embedder(
                input_pos=batch["hint_pos"],
                mask=batch["hint_mask"]
            )

        # The template modules expect an explicit template axis
        # ([*, N_templ, N, N, C]). Without it the batch axis is read as the
        # template axis, and every sample attends over every other sample's
        # embedding. Insert a single-template axis.
        input_pair = self.input_pair_stack(
            input_embedding[..., None, :, :, :],
            mask=pair_mask[..., None, :, :],
            use_deepspeed_evo_attention=self.use_deepspeed_evo_attention,
            chunk_size=None
        )
        z = z + self.input_pointwise_attn(input_pair, z, chunk_size=None)
        if self.core_stack_sel == "evoformer":
            # The single rep is fed as a one-row MSA; the stack returns (m, z, s).
            _, z, s = self.transformer_stack(
                s[..., None, :, :],
                z,
                msa_mask=None,
                pair_mask=pair_mask,
                use_deepspeed_evo_attention=self.use_deepspeed_evo_attention,
                chunk_size=None
            )
        else:
            z = self.z_pair_stack(
                z,
                mask=pair_mask,
                use_deepspeed_evo_attention=self.use_deepspeed_evo_attention,
                chunk_size=None
            )
        del input_pair, pair_mask
        # Mask padded tokens out of IPA attention, matching the pair stacks above.
        structure_output = self.ipa_structure(s, z, mask=seq_mask)
        logits = self.distogram_head(z)

        return {
            **structure_output,
            # Placeholder psi (sin, cos) allocated on the inputs' device.
            "psi": torch.ones([*seq_mask.shape, 2], device=seq_mask.device),
            "distogram_logits": logits
        }