import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusion import GaussianDiffusionSampler

from SA import SAEncoder
from DiT_model import DiT

class Encoder(nn.Module):
    """Run ``SAEncoder`` on an image batch and package its outputs.

    The encoder returns the per-component latents unchanged, the component
    masks resized to the input resolution, and the input image repeated
    (as a view) once per component.
    """

    def __init__(self, config):
        """Build the wrapped ``SAEncoder``.

        Args:
            config: Configuration object forwarded as-is to ``SAEncoder``.
        """
        super().__init__()

        self.SA = SAEncoder(config)

    def forward(self, x, num_slots):
        """Encode ``x`` into component latents and full-resolution masks.

        Args:
            x: Image batch of shape ``(B, C, H, W)``.
            num_slots: Not used by this method; kept so existing callers
                keep working.

        Returns:
            dict with keys:
                ``'obj_latents'``: first output of ``self.SA(x)``, untouched.
                ``'infer_mask'``: second output of ``self.SA(x)`` resized
                    with nearest-neighbour sampling and viewed as
                    ``(B, K, 1, H, W)``.
                ``'image'``: ``x`` expanded (no copy) to ``(B, K, C, H, W)``,
                    where ``K`` is ``infer_mask.shape[1]``.
        """
        batch_size, _, H, W = x.shape  # B 3 H W

        obj_lat, full_m = self.SA(x)

        # Resize to the exact input resolution rather than by a fixed 4x
        # factor: for masks at 1/4 resolution the result is identical, and
        # for any other resolution the view below can no longer silently
        # regroup pixels into a wrong number of components.
        # NOTE(review): assumes full_m is (B, K, 1, h, w) -- confirm
        # against SAEncoder.
        full_m = F.interpolate(
            full_m.flatten(end_dim=1), size=(H, W), mode='nearest'
        ).view(batch_size, -1, 1, H, W)

        # One (non-copying) view of the image per mask component.
        image = x.unsqueeze(1).expand(-1, full_m.shape[1], -1, -1, -1)  # B 1+N 3 H W

        return {
            'obj_latents': obj_lat,
            'infer_mask': full_m,
            'image': image,
        }
    
class Decoder(nn.Module):
    """Slot-conditioned ``DiT`` denoiser with mask-weighted noise mixing.

    ``DiT`` returns one prediction and one mask logit map per component;
    the masks are normalised across components and used to blend the
    per-component predictions into a single noise estimate.
    """

    def __init__(self, config):
        """Build the ``DiT`` network and its ``GaussianDiffusionSampler``.

        Args:
            config: Mapping read for the keys ``'image_shape'``, ``'DiT'``,
                ``'beta_1'``, ``'beta_T'``, ``'T'``, ``'mean_type'`` and
                ``'var_type'``.

        Note:
            ``config['DiT']['input_size']`` is overwritten in place with
            ``image_size // 8``, so the caller's config is modified.
        """
        super().__init__()

        self.img_size = config['image_shape'][-1]

        self.register_buffer('eps', torch.tensor(1e-5).float())

        # DiT is configured for a grid 8x smaller than the image.
        config['DiT']['input_size'] = self.img_size // 8
        self.DiT = DiT(**config['DiT'])

        self.sampler = GaussianDiffusionSampler(
            self.DiT,
            config['beta_1'],
            config['beta_T'],
            config['T'],
            self.img_size,
            config['mean_type'],
            config['var_type'],
        )

    def forward(self, result_infer, uncondition=False):
        """Predict the noise and the per-component masks.

        Args:
            result_infer: dict providing ``'obj_latents'``, ``'infer_mask'``
                (5-D, ``(B, N, 1, H, W)``), ``'noised_image'`` and ``'t'``.
            uncondition: Not used by this method; kept so existing callers
                keep working.

        Returns:
            dict with ``'predict_noise'`` (mask-weighted sum of the
            per-component predictions over dim 1) and ``'mask'`` (the
            normalised masks resized to ``(B, N, 1, H, W)``), updated with
            every entry of ``result_infer``. On a key clash the value from
            ``result_infer`` wins.
        """
        # Only the batch size and target resolution are taken from the
        # inferred masks; their values are not used here.
        B, _, _, H, W = result_infer['infer_mask'].shape  # B 1+N 1 H W

        # NOTE(review): per the original comments, DiT returns tensors of
        # shape (B, num_slots(+1), d_VQGAN, H_diff, W_diff) and
        # (B, num_slots(+1), 1, H_diff, W_diff) -- confirm against DiT.
        apc, mask = self.DiT(
            result_infer['noised_image'],
            result_infer['t'],
            result_infer['obj_latents'],  # B 1+N slot_size
        )

        H_diff, W_diff = apc.shape[-2:]

        # Normalise the mask logits across components (dim 1).
        mask = F.softmax(mask, dim=1)  # B, num_slots, 1, H_diff, W_diff
        predict_noise = (mask * apc).sum(dim=1)  # B, d_VQGAN, H_diff, W_diff

        # Resize to the inferred-mask resolution instead of a fixed 8x
        # factor: identical when H == 8 * H_diff, and it keeps the reshape
        # below valid for any other ratio.
        mask = mask.reshape(-1, 1, H_diff, W_diff)  # B*num_slots, 1, H_diff, W_diff
        mask = F.interpolate(mask, size=(H, W), mode='nearest')
        mask = mask.reshape(B, -1, 1, H, W)  # B 1+N 1 H W

        output = {'mask': mask, 'predict_noise': predict_noise}

        output.update(result_infer)
        return output