import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50

class InterpolationModel(nn.Module):
    """Predict one scalar "gap" value for every pair of adjacent video frames.

    Each frame is encoded by a frozen ResNet-50; the network's final output
    vector is used as the frame feature.  For every adjacent pair
    ``(t, t + 1)`` the two features and their difference are concatenated and
    fed to a small MLP (``self.fc``) that emits a single value per pair.

    Args:
        resnet_ckpt: Path to a ResNet-50 ``state_dict`` checkpoint that is
            loaded into the backbone.
    """

    # Width of the ResNet-50 output vector used as the per-frame feature.
    _FEATURE_DIM = 1000

    def __init__(self, resnet_ckpt: str = "/path to resnet50-0676ba61.pth"):
        super().__init__()

        # Build the bare architecture; the weights come from the local
        # checkpoint.  map_location keeps loading independent of the device
        # the checkpoint was saved from.
        self.resnet = resnet50()
        self.resnet.load_state_dict(torch.load(resnet_ckpt, map_location="cpu"))

        # The backbone is only ever run under ``torch.no_grad()``, so mark it
        # as frozen explicitly: its parameters are not trainable and its
        # BatchNorm layers must not update their running statistics.
        self.resnet.requires_grad_(False)
        self.resnet.eval()

        # Input is [feat_t, feat_t+1, feat_t+1 - feat_t] -> 3 * feature width.
        self.fc = nn.Sequential(
            nn.Linear(3 * self._FEATURE_DIM, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping the frozen backbone in eval.

        ``torch.no_grad()`` does not stop BatchNorm layers from updating their
        running statistics in training mode, so the backbone is forced back
        to eval mode after every mode switch.
        """
        super().train(mode)
        self.resnet.eval()
        return self

    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        """Encode every frame of ``video`` with the frozen backbone.

        Args:
            video: Tensor of shape ``(B, T, C, H, W)``.

        Returns:
            Tensor of shape ``(B, T, D)`` where ``D`` is the backbone's
            output width.  No gradients are tracked for the backbone.

        Raises:
            ValueError: If ``video`` is not 5-dimensional.
        """
        if video.dim() != 5:
            raise ValueError(
                f"expected video of shape (B, T, C, H, W), got {tuple(video.shape)}"
            )
        B, T, C, H, W = video.shape
        # reshape (rather than view) also accepts non-contiguous inputs.
        frames = video.reshape(B * T, C, H, W)
        with torch.no_grad():
            encoded_frames = self.resnet(frames)  # (B*T, D)
        return encoded_frames.reshape(B, T, -1)   # (B, T, D)

    def forward_embedding(self, video_embedding: torch.Tensor) -> torch.Tensor:
        """Predict the gap for each adjacent frame pair from frame features.

        Args:
            video_embedding: Tensor of shape ``(B, T, D)``; ``3 * D`` must
                match the input width of ``self.fc``.

        Returns:
            Tensor of shape ``(B, T - 1)`` with one prediction per pair.
        """
        prev_frames = video_embedding[:, :-1, :]
        next_frames = video_embedding[:, 1:, :]
        # Per pair: both frame features plus their difference.
        frame_pairs = torch.cat(
            [prev_frames, next_frames, next_frames - prev_frames],
            dim=-1,
        )  # (B, T-1, 3*D)

        predicted_gap = self.fc(frame_pairs)  # (B, T-1, 1)
        return predicted_gap.squeeze(-1)

    def forward(self, video: torch.Tensor) -> torch.Tensor:
        """Encode ``video`` of shape ``(B, T, C, H, W)`` and predict the gaps.

        Returns:
            Tensor of shape ``(B, T - 1)``.
        """
        return self.forward_embedding(self.encode_video(video))
    
if __name__ == '__main__':
    # Smoke test: run every sample of the dataset through the model and
    # print the shapes of the input, the target and the prediction.
    from dataset import frame_interp_dataset
    from torch.utils.data import DataLoader

    loader = DataLoader(
        frame_interp_dataset(
            data_path="/path to libero_dataset/finetune_dataset/libero_10_rpd17"
        ),
        batch_size=1,
        shuffle=True,
    )
    net = InterpolationModel()

    # Total element count over all parameters that have requires_grad set.
    trainable_sizes = [param.numel() for param in net.parameters() if param.requires_grad]
    print(sum(trainable_sizes))

    for clip, gap_target in loader:
        print(clip.shape)
        print(gap_target.shape)
        print(net(clip).shape)