# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
import math
import requests
from io import BytesIO
from functools import partial
import pickle
from typing import Callable, Optional, Sequence, Tuple, List
import numpy as np
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
from torchvision import transforms
from torchvision.transforms import InterpolationMode

class GLU(nn.Module):
    """Gated feed-forward block preceded by a projection + LayerNorm + GELU.

    Computes ``W_out(silu(W_gate h) * W_up h)`` with
    ``h = gelu(layernorm(W_in x))``. Every linear layer is bias-free and the
    hidden expansion factor is 4.

    Args:
        hidden_size: feature size of both the input and the output.
    """

    def __init__(self, hidden_size):
        super().__init__()
        expanded = hidden_size * 4
        # Registration order is kept stable: it fixes state_dict key order
        # and the sequence of RNG draws during default initialisation.
        self.linear_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.act1 = nn.GELU()
        self.act2 = nn.functional.silu  # plain function, not a submodule
        self.dense_h_to_4h = nn.Linear(hidden_size, expanded, bias=False)
        self.gate_proj = nn.Linear(hidden_size, expanded, bias=False)
        self.dense_4h_to_h = nn.Linear(expanded, hidden_size, bias=False)

    def forward(self, x):
        hidden = self.act1(self.norm1(self.linear_proj(x)))
        gate = self.act2(self.gate_proj(hidden))
        up = self.dense_h_to_4h(hidden)
        return self.dense_4h_to_h(gate * up)
def swiglu(x):
    """SwiGLU activation: split the last dim in two and gate the second half.

    The first half goes through SiLU and is multiplied element-wise by the
    second half, so the output's last dimension is half the input's.
    """
    parts = x.chunk(2, dim=-1)
    gate, value = parts[0], parts[1]
    return F.silu(gate) * value

class GLU_new(nn.Module):
    """SwiGLU feed-forward block: Linear -> swiglu -> Linear -> Dropout.

    ``dense_h_to_4h`` emits ``2 * intermediate_size`` features, which
    ``swiglu`` splits into a gate half and a value half before
    ``dense_4h_to_h`` projects back to ``hidden_size``.

    Args:
        hidden_size: input and output feature size.
        dropout: dropout probability applied to the block output.
        intermediate_size: width of the gated hidden layer. Defaults to 1280,
            the value this class has always used, so existing checkpoints keep
            their parameter shapes. Pass ``None`` to derive it from
            ``hidden_size`` as ``int(4 * hidden_size * 2 / 3 / 64) * 64``
            (the usual 2/3 * 4h SwiGLU sizing rounded down to a multiple of 64).

    Raises:
        ValueError: if the resulting ``intermediate_size`` is not positive.
    """

    def __init__(self, hidden_size, dropout=0.1, intermediate_size=1280):
        super().__init__()
        if intermediate_size is None:
            intermediate_size = int((4 * hidden_size * 2 / 3) / 64) * 64
        if intermediate_size <= 0:
            raise ValueError(
                f"intermediate_size must be positive, got {intermediate_size} "
                f"(hidden_size={hidden_size})"
            )

        self.act = swiglu
        # Twice the width: swiglu consumes one half as gate, one as value.
        self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size * 2, bias=False)
        self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = self.dense_h_to_4h(x)
        x = self.act(x)
        x = self.dense_4h_to_h(x)
        x = self.dropout(x)
        return x


# Module-level constant, presumably a default count of learned query tokens.
# NOTE(review): no code in this file reads it — Resampler and
# StructureTransformer each take their own `n_queries` argument (defaults 64
# and 32). Confirm whether external code imports it before removing.
n_queries = 32
def get_abs_pos(abs_pos, tgt_size):
    """Resize a flattened square grid of absolute position embeddings.

    Args:
        abs_pos: tensor of shape (L, C) holding an ``sqrt(L) x sqrt(L)`` grid
            flattened row-major.
        tgt_size: target number of positions M; its integer square root is
            used as the target side length.

    Returns:
        ``abs_pos`` itself when the side lengths already match, otherwise a
        (side*side, C) tensor produced by bicubic interpolation (computed in
        float32, then cast back to ``abs_pos.dtype``).
    """
    src_side = int(math.sqrt(abs_pos.size(0)))
    tgt_side = int(math.sqrt(tgt_size))
    if src_side == tgt_side:
        return abs_pos

    orig_dtype = abs_pos.dtype
    # (L, C) -> (1, C, side, side): interpolate expects channels-first images.
    grid = abs_pos.float().reshape(1, src_side, src_side, -1).permute(0, 3, 1, 2)
    resized = F.interpolate(
        grid,
        size=(tgt_side, tgt_side),
        mode="bicubic",
        align_corners=False,
    )
    return resized.permute(0, 2, 3, 1).flatten(0, 2).to(dtype=orig_dtype)

from einops import rearrange, repeat

def get_1d_sincos_pos_embed(embed_dim, pos):
    """Build fixed 1-D sine/cosine positional embeddings.

    Args:
        embed_dim: output width D; must be even.
        pos: numpy array of positions (any shape; flattened to M entries).

    Returns:
        (M, D) array whose first D/2 columns are ``sin(pos * w)`` and last D/2
        columns are ``cos(pos * w)`` with ``w_i = 1 / 10000**(i / (D/2))``.
    """
    assert embed_dim % 2 == 0
    freqs = np.arange(embed_dim // 2, dtype=np.float32)
    freqs /= embed_dim / 2.
    freqs = 1. / 10000 ** freqs  # (D/2,)

    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2)
    return np.concatenate((np.sin(angles), np.cos(angles)), axis=1)

class Resampler(nn.Module):
    """Cross-attention resampler: fixed set of learned latents attends over an input sequence.

    The input is projected from ``kv_dim`` to ``embed_dim`` and layer-normed;
    ``n_queries`` learned latents (layer-normed) query it through a single
    ``nn.MultiheadAttention``; the result is layer-normed and passed through an
    output projection, giving ``n_queries`` output vectors per batch item.

    Args:
        kv_dim: feature size of the incoming ``encoder_out`` tensor.
        embed_dim: width of the latents and of the output.
        num_heads: number of attention heads.
        n_queries: number of latent queries (output length).
        max_seqlen: rows of the fixed sin-cos positional table; with
            positional embeddings enabled, inputs longer than this are rejected.
        perceiver_resampler_positional_emb: add fixed 1-D sin-cos embeddings to
            the inputs (first ``seqlen`` rows) and to the latents (every
            ``stride``-th row).
        use_GLU: use ``GLU_new`` as output projection instead of a single
            bias-free linear layer.
        bos_init: not implemented; raises ``NotImplementedError``.
        dropout: attention dropout (also passed to ``GLU_new`` when used).
    """

    def __init__(
        self,
        kv_dim,
        embed_dim,
        num_heads=8,
        n_queries=64,
        max_seqlen=1024,
        perceiver_resampler_positional_emb=True,
        use_GLU=False,
        bos_init=False,
        dropout=0.0
    ):
        super().__init__()
        self.perceiver_resampler_positional_emb = perceiver_resampler_positional_emb

        if self.perceiver_resampler_positional_emb:
            assert n_queries <= max_seqlen
            self.stride = max_seqlen // n_queries
            pos = np.arange(max_seqlen, dtype=np.float32)
            self.register_buffer(
                "pos_embed",
                torch.from_numpy(get_1d_sincos_pos_embed(embed_dim, pos)).float()
            )
        self.latents = nn.Parameter(torch.randn(n_queries, embed_dim))
        if bos_init:
            # The previous code called `self.latents.load('')`, which does not
            # exist on nn.Parameter and always failed with AttributeError.
            raise NotImplementedError(
                "Resampler(bos_init=True) is not implemented; load the latents "
                "from a checkpoint via load_state_dict instead."
            )
        nn.init.trunc_normal_(self.latents, std=1e-3)

        self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout)
        self.ln_q = nn.LayerNorm(embed_dim)
        self.ln_kv = nn.LayerNorm(embed_dim)
        self.ln_post = nn.LayerNorm(embed_dim)
        if use_GLU:
            print('GLU *********************************')
            self.proj = GLU_new(embed_dim, dropout=dropout)
        else:
            self.proj = nn.Linear(embed_dim, embed_dim, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Small truncated-normal weights / zero biases for Linear; identity LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=1e-3)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, struc_x):
        """
        Args:
            struc_x (dict): with keys
                ``"encoder_out"``: features of shape (B, L, kv_dim);
                ``"encoder_padding_mask"``: bool tensor (B, L) where True marks
                positions to attend to (it is inverted before being passed as
                ``key_padding_mask``, whose True means "ignore").
        Returns:
            torch.Tensor of shape (B, n_queries, embed_dim).
        Raises:
            ValueError: if positional embeddings are enabled and L exceeds
                ``max_seqlen``.
        """
        x = struc_x["encoder_out"]
        mask = struc_x["encoder_padding_mask"]

        # Missing features may arrive as NaN; zero them so they cannot
        # propagate through the projections and attention.
        nan_mask = torch.isnan(x)
        if nan_mask.any():
            x = x.masked_fill(nan_mask, 0.0)

        x = self.ln_kv(self.kv_proj(x))
        b, seqlen = x.shape[:2]

        latents = self.ln_q(self.latents)
        if self.perceiver_resampler_positional_emb:
            table_len = self.pos_embed.shape[0]
            if seqlen > table_len:
                raise ValueError(
                    f"input length {seqlen} exceeds max_seqlen {table_len}"
                )
            # The strided slice has ceil(max_seqlen / stride) rows, which exceeds
            # n_queries when max_seqlen is not a multiple of n_queries; keep
            # exactly one row per latent.
            n_latents = latents.shape[0]
            latent_pos = self.pos_embed[::self.stride][:n_latents]
            latents = latents + latent_pos.to(latents.device)
            x = x + self.pos_embed[:seqlen].unsqueeze(0).to(latents.device)

        # Broadcast the shared latents across the batch (a view, no copy).
        latents = latents.unsqueeze(0).expand(b, -1, -1)
        out = self.attn(latents, x, x, key_padding_mask=~mask)[0]

        out = self.ln_post(out)
        out = self.proj(out)

        return out

class mlp(nn.Module):
    """Per-position projection head: Linear -> LayerNorm -> GELU -> Linear.

    Unlike ``Resampler`` it keeps the sequence length and returns the padding
    mask unchanged alongside the projected features.

    Args:
        width: input feature size.
        output_dim: output feature size.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, width, output_dim, **kwargs):
        super().__init__(**kwargs)
        # Sequential indices (0, 1, 3 carry parameters) form the state_dict
        # keys, so the layer order must stay fixed.
        layers = [
            nn.Linear(width, output_dim),
            nn.LayerNorm(output_dim),
            nn.GELU(),
            nn.Linear(output_dim, output_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, struc_x):
        features, padding_mask = struc_x["encoder_out"], struc_x["encoder_padding_mask"]
        projected = self.mlp(features)
        return projected, padding_mask

class StructureTransformer(nn.Module):
    """Loads precomputed structure embeddings and projects them with ``attn_pool``.

    The embeddings for each sample are read from a pickle file whose relative
    path is encoded as character codes inside ``input_ids`` (see
    :meth:`encode`), padded to ``max_seqlen`` rows, and passed through either
    an ``mlp`` projector or a ``Resampler``.

    Args:
        width: feature size of the (concatenated) stored embeddings.
        n_queries: number of latents when ``projector`` is not ``'mlp'``.
        output_dim: projector output size.
        embedding_keys: keys to read from each loaded sample. Defaults to
            ``{"mpnn_emb"}``. A set is concatenated in sorted key order so the
            feature layout is reproducible across processes; a list/tuple is
            concatenated in the given order.
        max_seqlen: number of rows each sample is padded to.
        num_heads: attention heads for the ``Resampler``.
        structure_emb_path_prefix: directory joined in front of decoded paths.
        projector: ``'mlp'`` for the ``mlp`` projector, anything else for
            ``Resampler``.
        **kwargs: forwarded to the chosen projector.
    """

    def __init__(
            self,
            width: int = 640,
            n_queries: int = 32,
            output_dim: int = 4096,
            embedding_keys=None,
            max_seqlen: int=1024,
            num_heads: int=8,
            structure_emb_path_prefix='structure_emb',
            projector='mlp',
            **kwargs
    ):
        super().__init__()
        # A fresh set per instance instead of one shared mutable default.
        if embedding_keys is None:
            embedding_keys = {"mpnn_emb"}

        self.structure_emb_path_prefix = structure_emb_path_prefix
        self.embedding_keys = embedding_keys
        self.max_seqlen = max_seqlen
        self.width = width
        self.n_queries = n_queries
        if projector == 'mlp':
            self.attn_pool = mlp(
                width=width,
                output_dim=output_dim,
                **kwargs
                )
        else:
            self.attn_pool = Resampler(
                embed_dim=output_dim,
                kv_dim=width,
                n_queries=n_queries,
                max_seqlen=max_seqlen,
                num_heads=num_heads,
                **kwargs
            )

    def _ordered_embedding_keys(self):
        """Return ``embedding_keys`` in a deterministic order (sets are sorted)."""
        keys = self.embedding_keys
        # str hashing is randomised per process, so raw set iteration order
        # (and hence the concatenated feature layout) would vary between runs.
        if isinstance(keys, (set, frozenset)):
            return sorted(keys)
        return list(keys)

    def prepare_structure(self, sample):
        """Concatenate the configured embeddings of one sample and pad them.

        Args:
            sample: mapping from embedding key to a tensor, or to a list of
                tensors that are concatenated along dim 0. Non-list tensors are
                ``squeeze()``-d (NOTE(review): a length-1 sequence would also
                lose its row dim — confirm inputs never have a single row).

        Returns:
            ``(emb_pad, emb_mask)``: a float (max_seqlen, width) tensor with the
            embedding in the leading rows and zeros after, and a bool
            (max_seqlen,) mask that is True on the filled rows.

        Raises:
            ValueError: if none of the keys are present, or the embedding has
                more than ``max_seqlen`` rows.
        """
        emb_pad = torch.zeros((self.max_seqlen, self.width))
        emb_mask = torch.zeros((self.max_seqlen), dtype=bool)

        parts = []
        for ek in self._ordered_embedding_keys():
            if ek not in sample:
                continue
            value = sample[ek]
            if isinstance(value, list):
                parts.append(torch.cat(value))
            else:
                parts.append(value.squeeze())
        if not parts:
            raise ValueError(
                f"sample contains none of the embedding keys {sorted(map(str, self.embedding_keys))}"
            )
        emb = torch.cat(parts, dim=-1)

        if len(emb) > self.max_seqlen:
            raise ValueError(
                f"structure embedding has {len(emb)} rows, exceeds max_seqlen {self.max_seqlen}"
            )
        emb_pad[:len(emb)] = emb
        emb_mask[:len(emb)] = True
        return emb_pad, emb_mask

    def forward(self, x):
        """Apply the projector to a dict with ``encoder_out``/``encoder_padding_mask``."""
        return self.attn_pool(x)

    def encode(self, input_ids: torch.Tensor):
        """Decode embedding file paths from ``input_ids``, load, and project them.

        Each row of ``input_ids`` holds the relative path of a pickle file as
        character codes (codes > 32), optionally left-padded with zeros. The
        decoded path is joined onto ``structure_emb_path_prefix``.

        NOTE: ``input_ids`` is modified in place — a leading 1 in a row is set
        to 0, and afterwards every code > 32 is replaced by 1.

        Returns:
            ``(projector_output, input_ids)``, or ``None`` if any file is
            missing (after printing its path).
        """
        structure_embs = []
        structure_mask = []

        for structure_path in input_ids:
            if structure_path[0] == 1:  # bos token for bypassing DPO trainer
                structure_path[0] = 0
            if structure_path[0] == 0:  # left padding
                structure_path = structure_path[structure_path > 0]

            # Path characters are expected to be printable (> 32 in ASCII).
            path_length = (structure_path > 32).sum()
            rel_path = ''.join(
                chr(s) for s in structure_path[:path_length].int().tolist() if s > 0
            )
            structure_path = os.path.join(self.structure_emb_path_prefix, rel_path)

            if not os.path.exists(structure_path):
                print(structure_path)
                print('no structure found')
                return None

            # SECURITY: pickle.load can execute arbitrary code; only point
            # structure_emb_path_prefix at trusted, locally generated files.
            with open(structure_path, 'rb') as f:
                sample = pickle.load(f)
            structure, struc_mask = self.prepare_structure(sample)

            structure_embs.append(structure)
            structure_mask.append(struc_mask)

        input_ids[input_ids > 32] = 1  # change ascii code back to <|bos|>
        ref_param = next(self.attn_pool.parameters())
        structure_embs = torch.stack(structure_embs, dim=0).to(
            device=ref_param.device,
            dtype=ref_param.dtype)
        structure_mask = torch.stack(structure_mask, dim=0).to(
            device=ref_param.device)

        return self({
                'encoder_out': structure_embs,
                'encoder_padding_mask': structure_mask
            }), input_ids
