from torch.nn import ModuleDict
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_batch
from models.dit import DiT
from models.pyg_att import Matformer
from models.pyg_att import MatformerConfig
from models.flow_matching import FlowMatchingInterpolant
from torch import nn
import torch
import pdb

class LatentDiffusion(nn.Module):
    """Flow-matching diffusion model operating on Matformer latents.

    A ``Matformer`` encoder (weights loaded from ``autoencoder_ckpt``) maps the
    input batch to latents ``x_1``. These are padded into a dense batch with
    ``to_dense_batch``, corrupted by a ``FlowMatchingInterpolant``, and passed
    to a ``DiT`` denoiser. ``forward`` returns the denoiser prediction together
    with the corrupted batch.
    """

    def __init__(
        self,
        autoencoder_ckpt: str,
        config: "MatformerConfig | None" = None,
    ) -> None:
        """Build the encoder, denoiser and interpolant.

        Args:
            autoencoder_ckpt: Path to a file loadable with ``torch.load``. Its
                contents are passed to ``Matformer.load_state_dict``
                (presumably a plain state dict -- confirm against the code
                that saves it).
            config: Matformer configuration. ``None`` (the default) builds a
                fresh ``MatformerConfig(name="matformer")`` for this instance,
                rather than sharing one default object across all instances.
        """
        super().__init__()
        if config is None:
            config = MatformerConfig(name="matformer")

        self.autoencoder_ckpt = autoencoder_ckpt
        # Load onto CPU: the freshly built Matformer lives on CPU and
        # load_state_dict copies values into it, so this gives the same module
        # while not requiring the device the checkpoint was saved from.
        # NOTE(review): torch.load unpickles -- only load trusted checkpoints.
        state_dict = torch.load(autoencoder_ckpt, map_location="cpu")

        self.net = Matformer(config)
        # strict=False: keys that do not match Matformer's parameters are
        # silently skipped, and unmatched parameters keep their freshly
        # initialized values.
        self.net.load_state_dict(state_dict, strict=False)

        # NOTE(review): d_x=128 presumably has to equal the encoder's output
        # width -- confirm against `config`.
        self.denoiser = DiT(d_x=128, d_model=768, nhead=12, num_layers=12, num_datasets=1)

        # Attribute name kept as-is (sic) for callers that reference it.
        self.interpolent = FlowMatchingInterpolant(
            min_t=1e-2,
            corrupt=True,
            num_timesteps=100,
            self_condition=True,
            self_condition_prob=0.5,
        )

    def forward(self, batch):
        """Encode, corrupt and denoise one batch.

        Args:
            batch: A 2-element sequence ``(data, ldata)``. ``data.batch`` is
                used as the node-to-graph index for padding, and the whole
                ``batch`` is passed unchanged to the Matformer encoder.
                (Presumably a pair of PyG batches -- confirm against the
                data loader.)

        Returns:
            dict with keys ``"pred_x"`` (the denoiser output) and
            ``"noisy_dense_encoded_batch"`` (the dict returned by
            ``corrupt_batch`` for a batch holding ``x_1``, ``token_mask`` and
            ``diffuse_mask``).
        """
        # ldata is not used here; the unpack still enforces exactly 2 items.
        data, _ldata = batch

        x_1 = self.net(batch)

        # Convert from a PyG (flat, per-node) layout to a padded dense batch;
        # `mask` marks the real (non-padding) rows.
        x_1, mask = to_dense_batch(x_1, data.batch)

        dense_encoded_batch = {"x_1": x_1, "token_mask": mask, "diffuse_mask": mask}

        # The interpolant reads its device from an attribute, so keep it in
        # sync with the data on every call.
        self.interpolent.device = x_1.device
        noisy_dense_encoded_batch = self.interpolent.corrupt_batch(dense_encoded_batch)

        # NOTE(review): the interpolant is built with self_condition=True and
        # self_condition_prob=0.5, but no self-conditioning input is ever
        # produced here -- x_sc is always None. Confirm this is intended.
        x_sc = None

        pred_x = self.denoiser(
            x=noisy_dense_encoded_batch["x_t"],
            t=noisy_dense_encoded_batch["t"],
            dataset_idx=None,
            spacegroup=None,
            mask=mask,
            x_sc=x_sc,
        )

        return {
            "pred_x": pred_x,
            "noisy_dense_encoded_batch": noisy_dense_encoded_batch,
        }