import torch
import torch.nn as nn
import torch.nn.functional as F
from modules import ConvSC

def stride_generator(N, reverse=False):
    """Return ``N`` strides following the alternating pattern ``1, 2, 1, 2, ...``.

    Args:
        N: Number of strides to produce. A non-positive value yields an
            empty list.
        reverse: When true, the generated list is returned in reverse order.

    Returns:
        A new list of ``N`` ints, each either 1 or 2.
    """
    # Generate the pattern on demand instead of slicing a fixed 20-entry
    # table, so requests for more than 20 strides are not silently truncated.
    strides = [1 + (i % 2) for i in range(N)]
    if reverse:
        strides.reverse()
    return strides

class Encoder(nn.Module):
    """Stack of ``N_S`` ``ConvSC`` layers applied one after another.

    The first layer maps ``C_in`` to ``C_hid`` channels; every following
    layer maps ``C_hid`` to ``C_hid``. Per-layer strides come from
    ``stride_generator(N_S)``. ``N_S`` must be at least 1, because the first
    stride is read by index.
    """

    def __init__(self, C_in, C_hid, N_S):
        super().__init__()
        strides = stride_generator(N_S)
        # Input layer first, then the hidden-to-hidden layers in stride order.
        stages = [ConvSC(C_in, C_hid, stride=strides[0])]
        for step in strides[1:]:
            stages.append(ConvSC(C_hid, C_hid, stride=step))
        self.enc = nn.Sequential(*stages)

    def forward(self, x):
        """Run all layers and also expose the first layer's output.

        Returns:
            ``(latent, enc1)`` where ``latent`` is the output of the last
            layer and ``enc1`` is the output of the first layer (consumed
            as a skip connection by ``Decoder``).
        """
        stages = list(self.enc)
        first_out = stages[0](x)
        out = first_out
        for stage in stages[1:]:
            out = stage(out)
        return out, first_out

class Decoder(nn.Module):
    """Stack of transposed ``ConvSC`` layers followed by a 1x1 readout conv.

    Strides are those of ``stride_generator(N_S)`` in reverse order. The
    last ``ConvSC`` takes ``2 * C_hid`` input channels because its input is
    the running feature map concatenated with the encoder skip tensor.
    """

    def __init__(self, C_hid, C_out, N_S):
        super().__init__()
        strides = stride_generator(N_S, reverse=True)
        stages = []
        for step in strides[:-1]:
            stages.append(ConvSC(C_hid, C_hid, stride=step, transpose=True))
        # Final stage consumes the channel-wise concat of features and skip.
        stages.append(ConvSC(2 * C_hid, C_hid, stride=strides[-1], transpose=True))
        self.dec = nn.Sequential(*stages)
        self.readout = nn.Conv2d(C_hid, C_out, 1)

    def forward(self, hid, enc1):
        """Decode ``hid`` using ``enc1`` as a skip connection.

        Args:
            hid: Feature map fed through all but the last layer.
            enc1: Tensor concatenated with the intermediate result along
                dim 1 before the last layer.

        Returns:
            Output of the 1x1 readout convolution.
        """
        stages = list(self.dec)
        for stage in stages[:-1]:
            hid = stage(hid)
        final_stage = stages[-1]
        fused = torch.cat([hid, enc1], dim=1)
        return self.readout(final_stage(fused))

class TitansVideoPrediction(nn.Module):
    """Memory-augmented attention block operating on a stacked-frame clip.

    The time and channel axes of the input are folded together, encoded by
    two 3x3 convolutions, flattened into one token per spatial position,
    passed through ``TitansCore`` with the sum of ``NeuralMemory`` and
    ``PersistentMemory`` outputs as keys/values, and decoded back with two
    3x3 transposed convolutions.

    Args:
        in_channels: Channel count after folding time into channels; the
            input to ``forward`` must satisfy ``T * C == in_channels``.
        embed_dim: Token width used by the encoder, memories and core.
        memory_size: Number of rows in the ``NeuralMemory`` buffer.
        H, W: Stored on the instance as ``self.H`` / ``self.W``. ``forward``
            reads the spatial size from its input instead of these values.
    """

    def __init__(self, in_channels, embed_dim, memory_size, H, W):
        super().__init__()
        self.embed_dim = embed_dim
        self.H, self.W = H, W
        self.in_channels = in_channels

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, embed_dim, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(embed_dim, embed_dim, 3, padding=1)
        )
        self.neural_memory = NeuralMemory(embed_dim, memory_size)
        self.persistent_memory = PersistentMemory(embed_dim)
        self.core = TitansCore(embed_dim)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, embed_dim, 3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(embed_dim, in_channels, 3, padding=1)
        )

    def forward(self, x):
        """Transform a clip of shape ``(B, T, C, H, W)`` into one of the same shape.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (for example a permuted channels-last clip) are accepted; for
        contiguous tensors it is the same zero-copy operation.
        """
        B, T, C, H, W = x.shape
        # Fold time into channels: (B, T, C, H, W) -> (B, T*C, H, W).
        x = x.reshape(B, T * C, H, W)

        x = self.encoder(x)
        # One token per spatial position: (B, embed_dim, H*W) -> (B, H*W, embed_dim).
        x = x.reshape(B, self.embed_dim, -1).transpose(1, 2)

        mem = self.neural_memory(x) + self.persistent_memory(x)
        x = self.core(x, mem)

        # Back to a feature map: (B, H*W, embed_dim) -> (B, embed_dim, H, W).
        x = x.transpose(1, 2).reshape(B, self.embed_dim, H, W)
        x = self.decoder(x)
        return x.reshape(B, T, C, H, W)

class NeuralMemory(nn.Module):
    """Running memory buffer that is updated on every forward call.

    The state is a ``(memory_size, embed_dim)`` buffer initialised to zeros.
    Each call blends a "surprise"-weighted summary of the input into the
    buffer and returns the buffer's row-mean broadcast to the input's shape.

    Args:
        embed_dim: Size of the last dimension of the inputs and of each
            memory row.
        memory_size: Number of rows in the memory buffer.
    """

    def __init__(self, embed_dim, memory_size):
        super().__init__()
        self.register_buffer('memory', torch.zeros(memory_size, embed_dim))
        self.momentum, self.decay = 0.9, 0.01

    def forward(self, x):
        """Update the memory from ``x`` and return it broadcast to ``x``'s shape.

        Args:
            x: Tensor whose last dimension is ``embed_dim``; it is averaged
                over dims 0 and 1 to form the per-call summary.

        Returns:
            The row-mean of the updated memory, viewed as ``(1, 1, embed_dim)``
            and expanded to ``x.size()``. The returned tensor stays connected
            to ``x`` in the autograd graph for the current call.
        """
        memory_mean = self.memory.mean(dim=0)
        # Distance of each token from the current memory summary.
        surprise = torch.norm(x - memory_mean, dim=-1, keepdim=True)
        weighted_x = torch.sigmoid(surprise) * x
        batch_memory = weighted_x.mean(dim=(0, 1))
        new_memory = self.momentum * memory_mean + (1 - self.momentum) * batch_memory
        updated = (1 - self.decay) * self.memory + self.decay * new_memory
        # Persist a detached copy: keeping the graph attached to the buffer
        # would make the next call's backward pass walk into this call's
        # already-freed graph ("backward through the graph a second time")
        # and would keep every past graph alive.
        self.memory = updated.detach()
        return updated.mean(dim=0).view(1, 1, -1).expand(x.size())

class PersistentMemory(nn.Module):
    """Learnable, input-independent memory vector.

    Holds a single ``(1, embed_dim)`` parameter drawn from a standard normal
    distribution at construction time.
    """

    def __init__(self, embed_dim):
        super().__init__()
        initial = torch.randn(1, embed_dim)
        self.memory = nn.Parameter(initial)

    def forward(self, x):
        """Return the memory vector broadcast to the shape of ``x``.

        Only the shape of ``x`` is used; its values are ignored.
        """
        # (1, embed_dim) -> (1, 1, embed_dim), then broadcast without copying.
        token = self.memory.unsqueeze(0)
        return token.expand_as(x)

class TitansCore(nn.Module):
    """Post-norm attention block: cross-attention followed by a feed-forward net.

    Args:
        embed_dim: Width of the token embeddings.
        num_heads: Number of attention heads passed to
            ``nn.MultiheadAttention``.
    """

    def __init__(self, embed_dim, num_heads=8):
        super().__init__()
        hidden_dim = embed_dim * 4
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x, mem):
        """Attend from ``x`` (queries) to ``mem`` (keys and values).

        Both sub-layers use a residual connection followed by LayerNorm.
        """
        attended, _ = self.attn(x, mem, mem)
        x = self.norm1(x + attended)
        transformed = self.ff(x)
        return self.norm2(x + transformed)

class SimVP(nn.Module):
    """Video prediction model: per-frame encoder, clip-level translator, per-frame decoder.

    Args:
        shape_in: ``(T, C, H, W)`` of a single input clip.
        hid_S: Channel width of the spatial encoder/decoder.
        embed_dim: Token width of the ``TitansVideoPrediction`` translator.
        memory_size: Row count of the translator's neural memory.
        N_S: Number of layers in both ``Encoder`` and ``Decoder``.
    """

    def __init__(self, shape_in, hid_S=16, embed_dim=128, memory_size=50, N_S=4):
        super().__init__()
        T, C, H, W = shape_in
        # NOTE(review): presumably each stride-2 ConvSC halves the spatial
        # size, giving N_S // 2 halvings overall -- confirm against ConvSC.
        downscale = 2 ** (N_S // 2)
        self.enc = Encoder(C, hid_S, N_S)
        self.hid = TitansVideoPrediction(T*hid_S, embed_dim, memory_size, H//downscale, W//downscale)
        self.dec = Decoder(hid_S, C, N_S)

    def forward(self, x_raw):
        """Map a clip batch ``(B, T, C, H, W)`` to an output of the same shape.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (for example time-major data transposed to batch-major) are accepted;
        for contiguous tensors it is the same zero-copy operation.
        """
        B, T, C, H, W = x_raw.shape
        # Frames are encoded independently: fold time into the batch axis.
        x = x_raw.reshape(B*T, C, H, W)

        embed, skip = self.enc(x)
        _, C_, H_, W_ = embed.shape

        # Restore the time axis for the clip-level translator.
        z = embed.reshape(B, T, C_, H_, W_)
        hid = self.hid(z)
        hid = hid.reshape(B*T, C_, H_, W_)

        Y = self.dec(hid, skip)
        return Y.reshape(B, T, C, H, W)

if __name__ == "__main__":
    # Smoke test: push one random clip through a default-configured model
    # and print the tensor shapes going in and coming out.
    clip_shape = (10, 1, 64, 64)  # (T, C, H, W)
    net = SimVP(shape_in=clip_shape)
    clip = torch.randn(1, *clip_shape)

    prediction = net(clip)

    print("Input:", clip.shape)
    print("Output:", prediction.shape)

