import torch.nn as nn
import torch.nn.functional as F
from timm.layers import Mlp
import torch
from torch.utils.checkpoint import checkpoint

class MlpResNetBlock(nn.Module):
    """Residual MLP block with a 4x hidden expansion.

    Computes ``x + dense2(ac_fn(dense1(norm1(dropout(x)))))``. Dropout and
    the optional LayerNorm act only on the residual branch; the skip
    connection carries ``x`` through untouched.

    Args:
        hidden_dim: Size of the last dimension of the input and the output.
        ac_fn: Callable applied between the two linear layers.
        use_layernorm: If true, ``norm1`` is an ``nn.LayerNorm``; otherwise
            it is an ``nn.Identity`` and no normalisation is applied.
        dropout_rate: Probability passed to ``nn.Dropout``.

    Note:
        ``norm1`` holds no parameters when ``use_layernorm`` is false, so a
        state dict saved by a version that always created the LayerNorm
        will contain extra ``norm1.*`` entries for such blocks.
    """

    def __init__(self, hidden_dim:int, ac_fn=F.gelu, use_layernorm=False, dropout_rate=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.use_layernorm = use_layernorm
        self.dropout = nn.Dropout(dropout_rate)
        # Respect the flag: previously LayerNorm was applied unconditionally,
        # which made ``use_layernorm=False`` (the default here) a no-op.
        self.norm1 = nn.LayerNorm(hidden_dim) if use_layernorm else nn.Identity()
        self.dense1 = nn.Linear(hidden_dim, hidden_dim * 4)
        self.ac_fn = ac_fn
        self.dense2 = nn.Linear(hidden_dim * 4, hidden_dim)

    def forward(self, x):
        """Return ``x`` plus the output of the MLP branch (same shape as ``x``)."""
        out = self.dropout(x)
        out = self.norm1(out)
        out = self.dense1(out)
        out = self.ac_fn(out)
        out = self.dense2(out)
        # Skip connection uses the original input, not the dropped-out one.
        return x + out

class MlpResNet(nn.Module):
    """Linear projection, a stack of ``MlpResNetBlock``s, activation, linear head.

    Args:
        num_blocks: Number of residual blocks between the two linear layers.
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width used by the input projection and every block.
        output_size: Size of the last dimension of the output.
        ac_fn: Activation callable; forwarded to each block and also applied
            once before the output layer.
        use_layernorm: Forwarded to each ``MlpResNetBlock``.
        dropout_rate: Forwarded to each ``MlpResNetBlock``.
    """

    def __init__(self, num_blocks:int, input_dim:int, hidden_dim:int, output_size:int,
                 ac_fn=F.gelu, use_layernorm=True, dropout_rate=0.1):
        super().__init__()

        self.dense1 = nn.Linear(input_dim, hidden_dim)
        self.dense2 = nn.Linear(hidden_dim, output_size)
        self.ac_fn = ac_fn
        self.mlp_res_blocks = nn.ModuleList(
            [
                MlpResNetBlock(hidden_dim, ac_fn, use_layernorm, dropout_rate)
                for _ in range(num_blocks)
            ]
        )

    def forward(self, x):
        """Map ``x`` (last dim ``input_dim``) to a tensor with last dim ``output_size``."""
        hidden = self.dense1(x)
        for block in self.mlp_res_blocks:
            hidden = block(hidden)
        # Activation is applied once after the residual stack, before the head.
        return self.dense2(self.ac_fn(hidden))

class LearnedPosEmb(nn.Module):
    """Fourier-feature embedding with a trainable frequency matrix.

    The input is scaled by ``2 * pi``, multiplied by the transpose of
    ``kernel`` and the cosine and sine of the result are concatenated along
    the last axis, giving ``2 * (output_size // 2)`` output features.

    Args:
        input_size: Size of the last dimension of the input.
        output_size: Requested embedding width; half of it (floor division)
            is the number of learned frequencies.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.output_size = output_size
        num_freqs = output_size // 2
        # Frequencies start as N(0, 1) samples scaled by 0.2.
        initial_kernel = torch.randn(num_freqs, input_size) * 0.2
        self.kernel = nn.Parameter(initial_kernel)

    def forward(self, x):
        """Return ``[cos(angles), sin(angles)]`` for ``x`` with last dim ``input_size``."""
        scaled = 2 * torch.pi * x
        angles = torch.matmul(scaled, self.kernel.T)
        return torch.cat((torch.cos(angles), torch.sin(angles)), dim=-1)

class MpScaleMlpResNet(nn.Module):
    """Residual MLP applied per token, conditioned on two scalar times.

    Each of the ``B * N`` tokens is processed independently: its feature
    vector is concatenated with its row of ``global_cond`` and with a time
    embedding shared by all tokens of the same batch element, then passed
    through ``dense1``, the residual blocks, ``ac_fn`` and ``dense2``.

    Args:
        num_blocks: Number of ``MlpResNetBlock``s.
        input_dim: Width of the concatenated input, i.e. the per-token
            ``global_cond`` width plus the time-embedding width plus the
            ``sample`` feature width ``D``.
        hidden_dim: Width of the residual stack; also passed to the time
            encoder ``Mlp`` as its third positional argument (presumably its
            output width — confirm against the installed timm version).
        output_size: Number of features produced per token.
        ac_fn: Activation used in the blocks and before ``dense2``.
        use_layernorm: Forwarded to each ``MlpResNetBlock``.
        dropout_rate: Forwarded to each ``MlpResNetBlock``.
        time_dim: Width of the Fourier time features.
        time_hidden_dim: Hidden width of the time encoder MLP.
    """

    def __init__(self, num_blocks: int, input_dim: int, hidden_dim: int, output_size: int,
                 ac_fn=F.gelu, use_layernorm=True, dropout_rate=0.1,
                 time_dim=32, time_hidden_dim=256):
        super().__init__()

        self.dense1 = nn.Linear(input_dim, hidden_dim)
        self.dense2 = nn.Linear(hidden_dim, output_size)
        self.ac_fn = ac_fn
        self.mlp_res_blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.mlp_res_blocks.append(MlpResNetBlock(hidden_dim, ac_fn, use_layernorm, dropout_rate))

        # One embedding + encoder pair is shared by both time inputs.
        self.time_process = LearnedPosEmb(1, time_dim)
        self.time_encoder = Mlp(time_dim, time_hidden_dim, hidden_dim, norm_layer=nn.LayerNorm)

    def forward(self, sample,timestep,global_cond,r):
        """Run the network on a batch of token sequences.

        Args:
            sample: Tensor of shape ``[B, N, D]``.
            timestep: Tensor with ``B`` elements (one per batch element), or
                a single element that is broadcast to the whole batch.
            global_cond: Tensor with ``B * N`` rows once flattened, e.g.
                ``[B, N, C]``.
            r: Second time input, same layout rules as ``timestep``.

        Returns:
            Tensor of shape ``[B, N, output_size]``.
        """
        B, N, D = sample.shape

        # reshape (rather than view) also accepts non-contiguous inputs.
        time_embedding = self.time_process(timestep.reshape(-1, 1))
        timestep_embed = self.time_encoder(time_embedding)

        rs_embedding = self.time_process(r.reshape(-1, 1))
        rs_embed = self.time_encoder(rs_embedding)

        # NOTE(review): both times go through the same weights and are summed,
        # so swapping `timestep` and `r` yields the same conditioning vector.
        timestep_embed = timestep_embed + rs_embed  # [B, T] (or [1, T])
        if timestep_embed.shape[0] != B:
            # Broadcast a single shared time to every batch element; expand
            # raises if the leading size is neither 1 nor B.
            timestep_embed = timestep_embed.expand(B, -1)

        sample = sample.reshape(B * N, D)  # [B*N, D]
        global_cond = global_cond.reshape(B * N, -1)  # [B*N, C]
        # Repeat each batch element's embedding for its N consecutive tokens.
        timestep_embed = timestep_embed.repeat_interleave(N, dim=0)  # [B*N, T]

        x = torch.cat([global_cond, timestep_embed, sample], dim=-1)  # [B*N, C+T+D]

        out = self.dense1(x)
        for mlp_res_block in self.mlp_res_blocks:
            out = mlp_res_block(out)
        out = self.ac_fn(out)
        out = self.dense2(out)  # [B*N, output_size]

        # Infer the feature width so output_size need not equal D.
        out = out.view(B, N, -1)
        return out

