# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from einops import rearrange

class ResBlock3D(nn.Module):
    """Residual block built from two 3D convolutions with GroupNorm and SiLU.

    The forward pass computes ``SiLU(x + net(x))`` where ``net`` is
    ``Conv3d -> GroupNorm -> SiLU -> Conv3d -> GroupNorm``. Both convolutions
    map ``channels`` to ``channels``, so the channel count is preserved.

    The residual addition requires ``net(x)`` to have the same shape as ``x``;
    the default ``kernel_size``/``padding`` pair satisfies this. If you change
    one, change the other accordingly.

    Args:
        channels: Number of input and output channels.
        kernel_size: Kernel size forwarded to both ``nn.Conv3d`` layers.
        padding: Padding forwarded to both ``nn.Conv3d`` layers.
        num_groups: Number of groups for both ``nn.GroupNorm`` layers. It must
            divide ``channels`` (``nn.GroupNorm`` enforces this). Defaults to
            8, the value that used to be hard-coded.
    """

    def __init__(self, channels, kernel_size=(3, 3, 3), padding=(1, 1, 1),
                 num_groups=8):
        super().__init__()
        # Keep the layer order (and therefore the ``net.<idx>`` parameter
        # names in the state dict) stable: conv, norm, act, conv, norm.
        self.net = nn.Sequential(
            nn.Conv3d(channels, channels, kernel_size, padding=padding),
            nn.GroupNorm(num_groups=num_groups, num_channels=channels),
            nn.SiLU(),
            nn.Conv3d(channels, channels, kernel_size, padding=padding),
            nn.GroupNorm(num_groups=num_groups, num_channels=channels),
        )
        # NOTE: despite the attribute name, the output activation is SiLU.
        # The name is kept so existing attribute access keeps working.
        self.relu = nn.SiLU()

    def forward(self, x):
        """Return ``SiLU(x + net(x))``; the output has the same shape as ``x``."""
        return self.relu(x + self.net(x))


class DeepConvIDM(nn.Module):
    """3D-convolutional network mapping a pair of structure tensors to actions.

    The model takes the element-wise difference ``structure_next -
    structure_prev``, runs it through a convolutional stem, four residual
    stages each followed by a spatially strided convolution, and finally a
    two-layer MLP head applied independently at every time step.

    Every downsampling convolution uses kernel 3, padding 1 and stride
    ``(1, 2, 2)``: the time axis keeps its length while each spatial axis of
    size ``n`` becomes ``floor((n - 1) / 2) + 1``. The head's first linear
    layer expects ``base_channels * 8`` features, so the spatial grid must
    have collapsed to 1x1 after the fourth downsampling (the shape notes in
    ``forward`` assume a 16x16 input grid).

    Args:
        structure_channels: Size of the trailing (channel) axis of the inputs.
        action_dim: Size of the last axis of the output.
        base_channels: Width of the first stage; later stages use 2x, 4x and
            8x this value. The stem's ``GroupNorm`` uses 8 groups, so this
            must be divisible by 8.
    """

    def __init__(self, structure_channels=32, action_dim=256, base_channels=64):
        super().__init__()

        # Channel widths of the four stages.
        width1 = base_channels
        width2 = base_channels * 2
        width3 = base_channels * 4
        width4 = base_channels * 8

        def spatial_down(in_ch, out_ch):
            # Halves H and W (stride 2) while leaving the time axis untouched.
            return nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=(1, 2, 2), padding=1)

        # Submodules are created in the same order as they are applied, which
        # keeps seeded parameter initialisation reproducible.
        self.stem = nn.Sequential(
            nn.Conv3d(structure_channels, width1, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.GroupNorm(num_groups=8, num_channels=width1),
            nn.SiLU(),
        )

        self.stage1 = nn.Sequential(ResBlock3D(width1), ResBlock3D(width1))
        self.down1 = spatial_down(width1, width2)

        self.stage2 = nn.Sequential(ResBlock3D(width2), ResBlock3D(width2))
        self.down2 = spatial_down(width2, width3)

        self.stage3 = nn.Sequential(ResBlock3D(width3), ResBlock3D(width3))
        self.down3 = spatial_down(width3, width4)

        # The last stage has a single residual block and its downsampling
        # keeps the channel count.
        self.stage4 = nn.Sequential(ResBlock3D(width4))
        self.down4 = spatial_down(width4, width4)

        self.head = nn.Sequential(
            nn.Linear(width4, 1024),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(1024, action_dim),
        )

    def forward(self, structure_prev, structure_next):
        """Predict per-time-step action vectors from two structure tensors.

        Args:
            structure_prev: Tensor laid out as ``(b, t, h, w, c)``.
            structure_next: Tensor laid out as ``(b, t, h, w, c)``.

        Returns:
            The head's output, with ``action_dim`` as its last axis; with a
            1x1 final spatial grid this is ``(b, t, action_dim)``.
        """
        delta = structure_next - structure_prev
        # Move channels in front of (t, h, w) as nn.Conv3d expects.
        feats = self.stem(rearrange(delta, 'b t h w c -> b c t h w'))

        # Shapes below assume defaults and a 16x16 grid:
        #   stem/stage1 [B, 64, T, 16, 16] -> down1 [B, 128, T, 8, 8]
        #   stage2      [B, 128, T, 8, 8]  -> down2 [B, 256, T, 4, 4]
        #   stage3      [B, 256, T, 4, 4]  -> down3 [B, 512, T, 2, 2]
        #   stage4      [B, 512, T, 2, 2]  -> down4 [B, 512, T, 1, 1]
        trunk = (
            self.stage1, self.down1,
            self.stage2, self.down2,
            self.stage3, self.down3,
            self.stage4, self.down4,
        )
        for layer in trunk:
            feats = layer(feats)

        # Flatten whatever spatial extent is left, together with the channels,
        # into one feature vector per time step.
        tokens = rearrange(feats, 'b c t h w -> b t (h w c)')
        return self.head(tokens)