import torch
import copy
from torch import nn
import math
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def sinusoidal_embedding(timesteps, dim):
    """Return sinusoidal embeddings of width ``dim`` for a batch of timesteps.

    The first ``dim // 2`` columns are ``sin(t * f_i)`` and the next
    ``dim // 2`` columns are ``cos(t * f_i)``, where the frequencies ``f_i``
    are spaced geometrically from 1 down to 1/10000. When ``dim`` is odd a
    trailing zero column is appended so the result is always ``dim`` wide.

    Args:
        timesteps: 1-D tensor of timesteps (indexed as ``timesteps[:, None]``);
            cast to float before use.
        dim: requested embedding width.

    Returns:
        Float tensor of shape ``(len(timesteps), dim)`` on ``timesteps.device``.
    """
    half_dim = dim // 2
    # With half_dim == 1 (dim of 2 or 3) the plain ``half_dim - 1`` denominator
    # is zero and 0 / 0 turns the whole embedding into NaN; clamp it to 1.
    denom = max(half_dim - 1.0, 1.0)
    exponent = -math.log(10000) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32)
    exponent = exponent / denom

    freqs = torch.exp(exponent).to(device=timesteps.device)
    angles = timesteps[:, None].float() * freqs[None, :]

    emb = torch.cat([angles.sin(), angles.cos()], dim=-1)
    if dim % 2 == 1:
        # sin/cos halves only cover 2 * (dim // 2) columns; pad the last one.
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb

def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

# classes

class FeedForward(nn.Module):
    """Pre-normalised MLP: LayerNorm, Linear, GELU, Dropout, Linear, Dropout.

    Args:
        dim: size of the last input axis; the output has the same size.
        hidden_dim: width of the intermediate linear layer.
        dropout: dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # Order matters: it fixes both the state_dict keys (net.0, net.1, ...)
        # and the order in which parameters are initialised.
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result (no residual added here)."""
        out = self.net(x)
        return out

class Attention(nn.Module):
    """Multi-head self-attention with a LayerNorm applied to the input first.

    Args:
        dim: size of the last input axis.
        heads: number of attention heads.
        dim_head: width of each head; queries/keys/values are ``heads * dim_head`` wide.
        dropout: dropout probability on the attention weights and, when an
            output projection exists, on the projected output.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        # A single head that is already ``dim`` wide needs no output projection.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # One fused projection producing q, k and v side by side.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Run self-attention over ``x`` (``b n dim``) and return a ``b n dim`` tensor."""
        normed = self.norm(x)

        # Split the fused projection into q, k, v and move heads to their own axis.
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in self.to_qkv(normed).chunk(3, dim = -1)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        # Weighted sum of values, then merge the heads back into one axis.
        merged = rearrange(torch.matmul(weights, v), 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of residual attention / feed-forward layers with a timestep bias.

    Each layer applies, in order and each with a residual connection:
    self-attention, a feed-forward block whose hidden width equals ``dim``,
    the projected time embedding (added directly), and a feed-forward block
    with hidden width ``mlp_dim``. A final LayerNorm is applied to the output.

    Args:
        temb_dim: size of the last axis of the time embedding passed to ``forward``.
        dim: token width.
        depth: number of layers.
        heads: attention heads per layer.
        dim_head: width of each attention head.
        mlp_dim: hidden width of the second feed-forward block of each layer.
        dropout: dropout probability forwarded to the attention and
            feed-forward blocks.
    """

    def __init__(self, temb_dim, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, dim, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

        # A single projection shared by every layer.
        self.time_emb_proj = nn.Sequential(
            nn.SiLU(),
            torch.nn.Linear(temb_dim, dim)
        )

    def forward(self, x, temb):
        """Run all layers on ``x`` and return the layer-normalised result.

        Args:
            x: token tensor whose last axis is ``dim``.
            temb: time embedding whose last axis is ``temb_dim``; its
                projection must broadcast against ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if len(self.layers) == 0:
            # Nothing consumes the time embedding without layers.
            return self.norm(x)

        # The projection is shared and does not depend on the layer index, so
        # compute it once instead of once per layer.
        time_bias = self.time_emb_proj(temb)

        for attn, ff1, ff2 in self.layers:
            x = attn(x) + x
            x = ff1(x) + x
            x = time_bias + x
            x = ff2(x) + x

        return self.norm(x)

class SDT(nn.Module):
    """Two-branch transformer producing PK and DTI predictions.

    Both branches build a token sequence whose first token is the embedded
    target (``pk`` or ``dti``), followed by conditioning tokens, add a learned
    positional embedding, run a time-conditioned ``Transformer`` and project
    the pooled output back to the target width.

    * PK branch tokens: ``[pk_y, (drug_dki), drug patches]``.
    * DTI branch tokens: ``[dti_y, (drug_dki), DTI drug patches,
      re-projected detached PK drug patches, (target_dki), protein patches]``.

    Args:
        time_dim: width of the sinusoidal timestep embedding; the time MLP
            expands it to ``4 * time_dim``.
        drug_cond_size: length of the flat drug conditioning vector.
        prot_cond_size: length of the flat protein conditioning vector.
        pk_y_dim: width of the PK target / prediction.
        dti_y_dim: width of the DTI target / prediction.
        patch_size: number of consecutive conditioning entries per token.
        dim: token width.
        depth, heads, mlp_dim, dim_head, dropout: forwarded to both
            ``Transformer`` instances.
        drug_dki_size: width of the optional ``drug_dki`` vector.
        target_dki_size: width of the optional ``target_dki`` vector.
        pool: ``'mean'`` averages over tokens; any other value takes token 0.
        emb_dropout: probability for ``self.dropout``.
            NOTE(review): ``self.dropout`` is constructed but never applied in
            ``forward`` - confirm whether that is intended.
    """

    def __init__(
        self,
        time_dim,
        drug_cond_size,
        prot_cond_size,
        pk_y_dim,
        dti_y_dim,
        patch_size,
        dim,
        depth,
        heads,
        mlp_dim,
        drug_dki_size = 768,
        target_dki_size = 256,
        pool = 'cls',
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.
    ):
        super().__init__()

        self.timestep_input_dim = time_dim
        self.time_embed_dim = self.timestep_input_dim * 4

        # MLP applied on top of the sinusoidal timestep embedding.
        self.time_embedding = nn.Sequential(
            nn.Linear(self.timestep_input_dim, self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(self.time_embed_dim, self.time_embed_dim),
        )

        # Flat conditioning vectors are cut into patch_size-wide tokens.
        num_drug_patches = (drug_cond_size // patch_size)
        drug_patch_dim = patch_size

        num_prot_patches = (prot_cond_size // patch_size)
        prot_patch_dim = patch_size

        self.to_pk_drug_patch_embedding = nn.Sequential(
            Rearrange('b (h p) -> b h p', p = patch_size),
            nn.LayerNorm(drug_patch_dim),
            nn.Linear(drug_patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # Maps (detached) PK drug tokens into extra tokens for the DTI branch.
        self.pk_to_dti_drug_emb = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
        )

        # Not referenced by forward(); left in place so the module's
        # parameters and state_dict keys are unchanged.
        self.dti_to_pk_drug_emb = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
        )

        self.to_dti_drug_patch_embedding = nn.Sequential(
            Rearrange('b (h p) -> b h p', p = patch_size),
            nn.LayerNorm(drug_patch_dim),
            nn.Linear(drug_patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.to_drug_pk_dki_embedding = nn.Sequential(
            nn.Linear(drug_dki_size, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
        )

        self.to_drug_dti_dki_embedding = nn.Sequential(
            nn.Linear(drug_dki_size, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
        )

        self.to_prot_patch_embedding = nn.Sequential(
            Rearrange('b (h p) -> b h p', p = patch_size),
            nn.LayerNorm(prot_patch_dim),
            nn.Linear(prot_patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.to_drug_target_embedding = nn.Sequential(
            nn.Linear(target_dki_size, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
        )

        self.to_pk_y = nn.Sequential(
            nn.Linear(pk_y_dim, dim),
            nn.SiLU(),
            nn.LayerNorm(dim)
        )

        self.to_dti_y = nn.Sequential(
            nn.Linear(dti_y_dim, dim),
            nn.SiLU(),
            nn.LayerNorm(dim)
        )

        # Longest possible sequence:
        # y_in + pk_drug_cond + dti_drug_cond + drug_dki + prot_cond + prot_dki
        # Shorter sequences use a prefix of the table.
        self.pk_pos_embedding = nn.Parameter(torch.randn(1, 2*num_drug_patches + num_prot_patches + 3, dim))
        self.dti_pos_embedding = nn.Parameter(torch.randn(1, 2*num_drug_patches + num_prot_patches + 3, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.pk_transformer = Transformer(self.time_embed_dim, dim, depth, heads, dim_head, mlp_dim, dropout)
        self.dti_transformer = Transformer(self.time_embed_dim, dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool

        # Regression heads
        self.pk_mlp_head = nn.Linear(dim, pk_y_dim)
        self.dti_mlp_head = nn.Linear(dim, dti_y_dim)

    def forward(
        self,
        timesteps,
        drug_cond,
        pk=None,
        dti=None,
        prot_cond=None,
        drug_dki=None,
        target_dki=None,
    ):
        """Predict PK and/or DTI outputs for the given timestep.

        Args:
            timesteps: Python number or tensor; flattened and broadcast to the
                batch size of ``drug_cond``.
            drug_cond: ``(batch, drug_cond_size)`` drug conditioning; required.
            pk: ``(batch, pk_y_dim)`` PK input; enables the PK branch.
            dti: ``(batch, dti_y_dim)`` DTI input; the DTI branch runs only
                when both ``dti`` and ``prot_cond`` are given.
            prot_cond: ``(batch, prot_cond_size)`` protein conditioning.
            drug_dki: optional ``(batch, drug_dki_size)`` vector, added as one
                extra token in each active branch.
            target_dki: optional ``(batch, target_dki_size)`` vector, added as
                one extra token in the DTI branch.

        Returns:
            ``(pred_pk, pred_dti)``; an entry is ``None`` when its branch did
            not run, otherwise ``(batch, pk_y_dim)`` / ``(batch, dti_y_dim)``.
        """
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps],
                                     dtype=torch.long,
                                     device=drug_cond.device)
        timesteps = torch.flatten(timesteps)
        timesteps = timesteps.broadcast_to(drug_cond.shape[0])

        t_emb = sinusoidal_embedding(timesteps, self.timestep_input_dim)
        # (batch, 1, time_embed_dim) so it broadcasts over the token axis.
        t_emb = self.time_embedding(t_emb).unsqueeze(1)

        drug_pk_x = self.to_pk_drug_patch_embedding(drug_cond)

        pred_dti = None
        if prot_cond is not None and dti is not None:
            # Only the DTI branch needs the DTI-specific drug tokens.
            drug_dti_x = self.to_dti_drug_patch_embedding(drug_cond)

            # The DTI branch sees a detached, re-projected copy of the PK drug
            # tokens. It is stored under its own name: rebinding ``drug_pk_x``
            # here would make the PK branch below consume this detached copy
            # whenever both branches run, changing PK predictions and cutting
            # the gradient to the PK patch embedding.
            drug_pk_for_dti = self.pk_to_dti_drug_emb(drug_pk_x.clone().detach())
            drug_x = torch.cat((drug_dti_x, drug_pk_for_dti), dim=1)
            if drug_dki is not None:
                drug_embedding = self.to_drug_dti_dki_embedding(drug_dki).unsqueeze(1)
                drug_x = torch.cat((drug_embedding, drug_x), dim=1)

            prot_x = self.to_prot_patch_embedding(prot_cond)
            if target_dki is not None:
                target_embedding = self.to_drug_target_embedding(target_dki).unsqueeze(1)
                prot_x = torch.cat((target_embedding, prot_x), dim=1)

            # Sequence layout: [dti_y, drug tokens, protein tokens]
            dti_y = self.to_dti_y(dti).unsqueeze(1)
            dti_x = torch.cat((dti_y, drug_x, prot_x), dim=1)

            # Positional embedding, truncated to the actual sequence length.
            _, n, _ = dti_x.shape
            dti_x += self.dti_pos_embedding[:,:n]
            dti_x = self.dti_transformer(dti_x, t_emb)
            dti_x = dti_x.mean(dim = 1) if self.pool == 'mean' else dti_x[:, 0]
            dti_x = self.dti_mlp_head(dti_x)

            pred_dti = dti_x

        pred_pk = None
        # drug_cond has already been dereferenced above, so only pk decides
        # whether this branch runs.
        if pk is not None:
            # The PK branch uses only its own drug tokens.
            drug_x = drug_pk_x
            if drug_dki is not None:
                drug_embedding = self.to_drug_pk_dki_embedding(drug_dki).unsqueeze(1)
                drug_x = torch.cat((drug_embedding, drug_x), dim=1)

            # Sequence layout: [pk_y, (drug_dki), drug tokens]
            pk_y = self.to_pk_y(pk).unsqueeze(1)
            drug_x = torch.cat((pk_y, drug_x), dim=1)

            # Positional embedding, truncated to the actual sequence length.
            _, n, _ = drug_x.shape
            drug_x += self.pk_pos_embedding[:, :n]
            pk_x = self.pk_transformer(drug_x, t_emb)
            pk_x = pk_x.mean(dim = 1) if self.pool == 'mean' else pk_x[:, 0]
            pk_x = self.pk_mlp_head(pk_x)

            pred_pk = pk_x

        return pred_pk, pred_dti

class EMA:
    """Keep an exponential-moving-average copy of a model.

    ``ema_model`` is a deep copy of ``model`` with gradients disabled. Each
    call to ``update_params(gamma)`` moves every floating-point parameter and
    buffer of the copy towards the online model:
    ``ema = gamma * ema + (1 - gamma) * online``.

    Args:
        model: the online (trained) module.
        base_gamma: decay returned by ``update_gamma`` at step 0.
        total_steps: step at which the cosine schedule reaches 1.
    """

    def __init__(self, model, base_gamma, total_steps):
        super().__init__()
        self.online_model = model

        self.ema_model = copy.deepcopy(self.online_model)
        self.ema_model.requires_grad_(False)

        self.base_gamma = base_gamma
        self.total_steps = total_steps

    def update_params(self, gamma):
        """Blend the online weights into the EMA copy with decay ``gamma``.

        Every floating-point dtype is updated (float16, bfloat16, float32,
        float64). A fixed dtype whitelist would silently leave, e.g.,
        bfloat16 or float64 weights frozen at their initial values.
        Non-floating tensors (such as integer counters) are left untouched.
        """
        with torch.no_grad():
            for pairs in (self._get_params(), self._get_buffers()):
                for online, target in pairs:
                    if online.is_floating_point() and target.is_floating_point():
                        target.data.lerp_(online.data, 1. - gamma)

    def _get_params(self):
        """Pair up online and EMA parameters in registration order."""
        return zip(self.online_model.parameters(),
                   self.ema_model.parameters())

    def _get_buffers(self):
        """Pair up online and EMA buffers in registration order."""
        return zip(self.online_model.buffers(),
                   self.ema_model.buffers())

    def update_gamma(self, current_step):
        """Return the cosine-scheduled decay for ``current_step``.

        The value rises from ``base_gamma`` at step 0 to 1 at ``total_steps``:
        ``1 - (1 - base_gamma) * (cos(pi * k / K) + 1) / 2`` with k the current
        step and K the total number of steps. The step is not clamped, so
        values past ``total_steps`` follow the cosine back down.
        """
        k = torch.tensor(current_step, dtype=torch.float32)
        K = torch.tensor(self.total_steps, dtype=torch.float32)

        tau = 1 - (1 - self.base_gamma) * (torch.cos(torch.pi * k / K) + 1) / 2
        return tau.item()