import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SinusoidalPositionEmbeddings(nn.Module):
    """Encode scalar positions as fixed (non-learned) sinusoidal features.

    The output's last axis holds ``dim // 2`` sine channels followed by
    ``dim // 2`` cosine channels. For ``dim >= 4`` the frequencies are
    geometrically spaced from 1 down to 1/10000. When ``dim`` is odd, a
    single zero channel is appended so the last axis always has exactly
    ``dim`` entries.

    Args:
        dim: size of the last axis of the returned embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Return the sinusoidal embedding of ``time``.

        Args:
            time: tensor of positions. A new axis is inserted at index 1 and
                the result is broadcast against a ``(1, dim // 2)`` frequency
                row, so a 1-D input of shape ``(B,)`` yields ``(B, dim)``.
                Higher-rank inputs need a trailing axis that broadcasts with
                ``dim // 2`` (e.g. size 1).

        Returns:
            Tensor whose last axis has size ``dim``.
        """
        half_dim = self.dim // 2
        # Clamp the denominator: with half_dim == 1 (dim of 2 or 3) the
        # unguarded ``half_dim - 1`` would raise ZeroDivisionError. The single
        # frequency is exp(0) == 1 regardless of the step, so clamping is safe.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(
            torch.arange(half_dim, device=time.device) * -log_step
        )
        angles = time[:, None] * frequencies[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover 2 * (dim // 2) channels; pad one zero
            # channel so consumers sized for ``dim`` inputs keep working.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings

class FF(nn.Module):
    """Feed-forward predictor over a fixed-length history of frames.

    Each history frame is described by an embedded feature vector, an
    embedding of the start age and an embedding of that frame's age delta.
    The flattened history is compressed by a dynamics MLP, joined with the
    embedding of the last age delta, and mapped by a prediction MLP to
    ``out_dim`` softmax-normalised coefficients.

    Args:
        n_frames: number of history frames fed to the dynamics MLP.
        n_archetypes: size of each frame's input feature vector.
        hidden_dim: width of the feature / dynamics / prediction layers.
        out_dim: number of output coefficients.
        dim: width of the sinusoidal embedding and of the time MLPs.
    """

    def __init__(self, n_frames, n_archetypes, hidden_dim, out_dim, dim):
        super().__init__()

        self.n_frames = n_frames
        time_dim = dim

        def build_time_embedder():
            # Sinusoidal encoding followed by a two-layer MLP.
            return nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )

        # Start age: one embedding per sample, shared by every frame.
        self.start_age_mlp = build_time_embedder()
        # Age deltas: one embedding per delta (history frames plus the target).
        self.delta_age_mlp = build_time_embedder()

        # Per-frame feature embedding.
        self.feature_mlp = nn.Sequential(
            nn.Linear(n_archetypes, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # History dynamics: consumes all frames flattened into one vector,
        # each frame contributing [features, start-age emb, delta-age emb].
        per_frame_width = hidden_dim + 2 * time_dim
        self.dynamics_mlp = nn.Sequential(
            nn.Linear(n_frames * per_frame_width, hidden_dim * 2),
            nn.GELU(),
            nn.Linear(hidden_dim * 2, hidden_dim),
        )

        # Prediction head: history encoding joined with one delta embedding.
        self.prediction_mlp = nn.Sequential(
            nn.Linear(hidden_dim + time_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, baselines, start_age, age_deltas):
        """Predict softmax-normalised coefficients from a history window.

        Args:
            baselines: per-frame features; last axis must be ``n_archetypes``.
                Presumably shaped (B, n_frames, n_archetypes) - confirm
                against the caller.
            start_age: start ages; presumably shaped (B, 1) so the embedding
                is (B, 1, time_dim) before being tiled - confirm.
            age_deltas: age deltas for the history frames plus one final
                entry. NOTE(review): the ``view`` below needs a 4-D embedding
                with a singleton axis 1, which with the sinusoidal embedder
                implies an input like (B, n_frames + 1, 1) - confirm.

        Returns:
            Tensor of non-negative coefficients normalised to sum to 1 along
            axis 1.
        """
        # Tile the start-age embedding n_frames times along axis 1.
        start_emb = self.start_age_mlp(start_age).repeat(1, self.n_frames, 1)

        # Embed all deltas, then drop axis 1 of the 4-D embedding.
        delta_emb = self.delta_age_mlp(age_deltas)
        delta_emb = delta_emb.view(
            delta_emb.shape[0], delta_emb.shape[2], delta_emb.shape[3]
        )
        # All but the last delta belong to the history; the last one is the
        # query for the prediction.
        past_deltas = delta_emb[:, :-1, :]
        future_delta = delta_emb[:, -1, :]

        frame_feats = self.feature_mlp(baselines)

        # Per-frame layout: [feature, start_age, delta_age], then flatten
        # every frame of a sample into a single vector.
        per_frame = torch.cat((frame_feats, start_emb, past_deltas), dim=-1)
        flat_history = per_frame.view(per_frame.shape[0], -1)

        history_encoding = self.dynamics_mlp(flat_history)

        # Join the encoded history with the future-delta embedding.
        head_input = torch.cat((history_encoding, future_delta), dim=-1)
        logits = self.prediction_mlp(head_input)

        # Coefficients should be positive and sum to 1.
        return F.softmax(logits, dim=1)

class FF_MoCap(nn.Module):
    """Plain MLP mapping a flattened history to per-step coefficients.

    The network takes inputs whose last axis has size
    ``n_frames * n_archetypes`` and produces ``out_dim`` values, which are
    regrouped into ``(n_horizon, n_archetypes)`` grids and softmax-normalised
    over the archetype axis.

    Args:
        n_frames: number of history frames (sets the input width).
        n_horizon: number of predicted steps per sample.
        n_archetypes: number of coefficients per frame / predicted step.
        hidden_dim: width of the four hidden layers.
        out_dim: width of the final linear layer. NOTE(review): the reshape
            in ``forward`` only recovers one grid per sample when this
            equals ``n_horizon * n_archetypes`` - not validated here.
    """

    def __init__(self, n_frames, n_horizon, n_archetypes, hidden_dim, out_dim):
        super().__init__()

        self.n_frames = n_frames
        self.n_horizon = n_horizon
        self.n_archetypes = n_archetypes

        # Four GELU-activated layers followed by a linear read-out.
        widths = [n_frames * n_archetypes] + [hidden_dim] * 4
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, out_dim))
        self.prediction_mlp = nn.Sequential(*layers)

    def forward(self, history):
        """Return coefficients of shape ``(-1, n_horizon, n_archetypes)``.

        Args:
            history: input whose last axis has size
                ``n_frames * n_archetypes`` (the first layer's input width).

        Returns:
            Tensor whose last axis is non-negative and sums to 1.
        """
        logits = self.prediction_mlp(history)
        per_step = logits.view(-1, self.n_horizon, self.n_archetypes)
        # Coefficients should be positive and sum to 1 for each step.
        return F.softmax(per_step, dim=-1)