import torch.nn.functional as F
import torch
import torch.nn as nn
import numpy as np
from functools import partial

from collections import OrderedDict

try:
    import torch_geometric
    from torch_geometric.nn import Sequential, GraphConv, DenseGraphConv
except:
    print("Graph conv. compositing not available as torch geometric not installed")

from einops import repeat,rearrange
from einops.layers.torch import Rearrange

from pdb import set_trace as pdb #debugging

from scipy.spatial.transform import Rotation as R
from copy import deepcopy

import torchvision
import util

import custom_layers
import geometry
import hyperlayers

import conv_modules

from torchgeometry import rtvec_to_pose

from torch.nn.functional import normalize


import math
def positionalencoding2d(d_model, height, width):
    """
    Build a fixed 2D sinusoidal positional encoding.

    The first half of the channels encodes the column (width) position and
    the second half encodes the row (height) position. Within each half,
    even channels hold the sine and odd channels hold the cosine.

    :param d_model: dimension of the model; must be divisible by 4
    :param height: height of the positions
    :param width: width of the positions
    :return: d_model*height*width position matrix
    :raises ValueError: if d_model is not divisible by 4
    """
    if d_model % 4 != 0:
        # Each spatial axis gets d_model/2 channels, which are in turn split
        # into sin/cos pairs - hence the divisibility-by-4 requirement.
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "dimension not divisible by 4 (got dim={:d})".format(d_model))
    pe = torch.zeros(d_model, height, width)
    # Each spatial axis uses half of d_model
    half = d_model // 2
    div_term = torch.exp(torch.arange(0., half, 2) *
                         -(math.log(10000.0) / half))
    pos_w = torch.arange(0., width).unsqueeze(1)
    pos_h = torch.arange(0., height).unsqueeze(1)
    # Width encoding: constant along the height axis.
    pe[0:half:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:half:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    # Height encoding: constant along the width axis.
    pe[half::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[half + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

    return pe

# SlotLFN but learning from video
class CLFN(nn.Module):
    """Slot-conditioned light field network (coarse stage active).

    The first context image is encoded and passed to SlotAttention. Slot 0 is
    fed to ``hyper_bg`` and the remaining slots to ``hyper_fg``; both
    hypernetworks produce weights for the shared MLP ``self.phi``, which is
    queried with 6-D Plucker ray coordinates. Per-slot features are decoded to
    RGB by ``pix_gen`` and blended with the weights from :meth:`compositor`.

    NOTE(review): the fine stage in :meth:`forward` is switched off by a
    hard-coded flag and relies on ``fine_phi`` / ``hyper_fine_fg`` /
    ``hyper_fine_bg``, which are currently disabled in ``__init__``.
    """

    def __init__(self, phi_latent=64, num_phi=1, phi_out_latent=128, 
                    phi_num_layers=3, img_feat_dim=64,sato_cpu=False):
        """
        :param phi_latent: slot dimension and hypernetwork input size
        :param num_phi: number of slots passed to SlotAttention
        :param phi_out_latent: output feature size of ``self.phi``
        :param phi_num_layers: accepted but not used by this class
        :param img_feat_dim: feature size requested from the image encoder
        :param sato_cpu: if True, ``sato_wrap`` runs the wrapped module on the
            CPU and moves the result back with ``.cuda()``
        """
        super().__init__()

        self.num_phi=num_phi

        num_hidden_units_phi = 256

        self.sato_cpu = sato_cpu
        self.sato_wrap = (lambda mod,x: mod.cpu()(x.cpu()).cuda())\
                                 if sato_cpu else (lambda mod,x:mod(x))
        # Shared light-field MLP; its weights are supplied per slot by the
        # hypernetworks below.
        self.phi = custom_layers.FCBlock(
                                hidden_ch=num_hidden_units_phi,
                                num_hidden_layers=1,
                                in_features=6,
                                out_features=phi_out_latent,
                                outermost_linear=True,)
        self.hyper_fg = hyperlayers.HyperNetwork(
                              hyper_in_features=phi_latent,
                              hyper_hidden_layers=1,
                              hyper_hidden_features=num_hidden_units_phi,
                              hypo_module=self.phi)
        self.hyper_bg = hyperlayers.HyperNetwork(
                              hyper_in_features=phi_latent,
                              hyper_hidden_layers=1,
                              hyper_hidden_features=num_hidden_units_phi,
                              hypo_module=self.phi)
        # Disabled fine-stage modules (kept as an inert string literal).
        """
        self.fine_phi = custom_layers.FCBlock(
                                hidden_ch=num_hidden_units_phi,
                                num_hidden_layers=5,
                                in_features=6,
                                out_features=phi_out_latent,
                                outermost_linear=True,)
        self.hyper_fine_bg = hyperlayers.HyperNetwork(
                              hyper_in_features=phi_latent,
                              hyper_hidden_layers=1,
                              hyper_hidden_features=num_hidden_units_phi,
                              hypo_module=self.fine_phi,)
        self.hyper_fine_fg = hyperlayers.HyperNetwork(
                              hyper_in_features=phi_latent,
                              hyper_hidden_layers=2,
                              hyper_hidden_features=num_hidden_units_phi,
                              hypo_module=self.fine_phi,)
        """
        self.num_hidden_units_phi = num_hidden_units_phi

        # Maps pixels to features for SlotAttention
        self.img_encoder = nn.Sequential(
                conv_modules.UnetEncoder(bottom=True,z_dim=img_feat_dim),
                Rearrange("b c x y -> b (x y) c")
        )

        print("NOTECHANGESLOTATTNBACKfrom learned emb")
        self.slot_encoder = custom_layers.SlotAttention(self.num_phi,
                                                       in_dim=img_feat_dim,
                                                       fg_slot_dim=phi_latent,
                                                       bg_slot_dim=phi_latent,
                                                       max_slot_dim=phi_latent)
        # Channel count of the 2D positional encoding concatenated to the
        # image features in the fine stage.
        self.num_pos_enc=32
        self.slot_encoder_fine = custom_layers.SlotAttention(self.num_phi,
                                                       in_dim=img_feat_dim+self.num_pos_enc,
                                                       fg_slot_dim=phi_latent,
                                                       bg_slot_dim=phi_latent,
                                                       max_slot_dim=phi_latent)

        # Depth heads used by the compositor (separate sets for fine/coarse).
        self.fine_feat_to_depth  = custom_layers.FCBlock(
                        hidden_ch=128, num_hidden_layers=3, in_features=phi_out_latent,
                        out_features=1, outermost_linear=True,
                        norm='layernorm_na')
        self.fine_depth_spreader = custom_layers.FCBlock(
                        hidden_ch=128, num_hidden_layers=3, in_features=2,
                        out_features=1, outermost_linear=True,
                        norm='layernorm_na')
        self.feat_to_depth  = custom_layers.FCBlock(
                        hidden_ch=128, num_hidden_layers=3, in_features=phi_out_latent,
                        out_features=1, outermost_linear=True,
                        norm='layernorm_na')
        self.depth_spreader = custom_layers.FCBlock(
                        hidden_ch=128, num_hidden_layers=3, in_features=2,
                        out_features=1, outermost_linear=True,
                        norm='layernorm_na')

        # Maps features to rgb
        self.fine_pix_gen = nn.Sequential( custom_layers.FCBlock(
                        hidden_ch=128, num_hidden_layers=3, in_features=phi_out_latent,
                        out_features=3, outermost_linear=True,
                        norm='layernorm_na'), nn.Tanh() )
        self.pix_gen = nn.Sequential( custom_layers.FCBlock(
                        hidden_ch=128, num_hidden_layers=3, in_features=phi_out_latent,
                        out_features=3, outermost_linear=True,
                        norm='layernorm_na'), nn.Tanh() )
        print(self)

    def compositor(self,feats,fine=False,initial_depth=None):
        """Turn per-slot features into per-pixel blending weights.

        :param feats: features laid out as ``p b q pix l``
        :param fine: select the fine depth heads instead of the coarse ones
        :param initial_depth: accepted but not used
        :return: weights laid out as ``p b q pix 1``; a softmax over the slot
            axis with 1e-9 added
        """
        feat_to_depth=self.fine_feat_to_depth if fine else self.feat_to_depth
        depth_spreader=self.fine_depth_spreader if fine else self.depth_spreader
        depth = feat_to_depth(feats)
        nodes = rearrange(depth,"p b q pix 1 -> (b q pix) p 1")
        # Smallest predicted depth over the slots, broadcast back to each slot.
        min_depth = nodes.min(1,keepdim=True)[0].expand(-1,feats.size(0),1)
        # The spreader sees (offset from the minimum depth, raw depth).
        attn = rearrange(depth_spreader(torch.cat((min_depth-nodes,nodes),-1)),
                            "(b q pix) p 1 -> p b q pix 1",
                            p=feats.size(0),b=feats.size(1),q=feats.size(2))
        return attn.softmax(0)+1e-9

    def forward(self,input,anneal=0):
        """Render query views from the first context image.

        The coarse stage runs under the caller's autograd mode, which is
        restored on exit; the previous implementation force-enabled gradients
        globally and so silently overrode an enclosing ``torch.no_grad()``.

        :param input: dict read via ``input["query"]["uv"]``,
            ``input["context"]["rgb"]``, ``input["context"]["cam2world"]`` and
            ``util.get_query_cam(input)``
        :param anneal: accepted but not used
        :return: dict with ``coarse_rgbs``, ``coarse_rgb``, ``coarse_seg``,
            ``slot_attn`` / ``all_A_attn`` / ``A_attn`` (all the same
            attention tensor) and ``fg_slot_z``; plus ``fine_*`` entries when
            the fine stage is enabled
        """

        # Unpack input.

        query = input['query']
        b, n_qry = query["uv"].shape[:2]
        cam2world, query_intrinsics, query_uv = util.get_query_cam(input)
        # One copy of the query rays per foreground slot.
        phi_intrinsics,phi_uv = [x.unsqueeze(0).expand(self.num_phi-1,-1,-1,-1)
                                        for x in (query_intrinsics,query_uv)]

        # Encode all imgs. The pixel axis is flattened, so the side length is
        # recovered as its square root (square images).
        imsize = int(input["context"]["rgb"].size(-2)**(1/2))
        rgb_A = input["context"]["rgb"][:,0].permute(0,2,1).unflatten(-1,(imsize,imsize))

        # Create fg images: img_encoding -> slot attn -> compositor,rgb
        coarse=True
        grad_was_enabled = torch.is_grad_enabled()

        # Gradients flow through the coarse stage only while it is being
        # trained AND the caller has not disabled autograd.
        with torch.set_grad_enabled(grad_was_enabled and coarse):
            imfeats = self.sato_wrap(self.img_encoder,rgb_A)
            slots_A, attn_A = self.slot_encoder(imfeats,iters=3) # b phi l 

            context_cam = input["context"]["cam2world"][:,0]

            # Foreground rays are expressed in the context camera frame,
            # background rays in the world frame.
            world2contextcam = repeat(context_cam.inverse(),"b x y -> (b q) x y",q=n_qry)
            pose_A = repeat(world2contextcam @ cam2world,"bq x y -> p bq x y",p=self.num_phi-1)
            coords_A=geometry.plucker_embedding(pose_A,phi_uv,phi_intrinsics)
            bg_coords = geometry.plucker_embedding(cam2world,query_uv,query_intrinsics)
            coords = torch.cat((bg_coords[None],coords_A)).flatten(0,1)

            # Create phi: background parameters first, then foreground, to
            # match the ordering of `coords`.
            coarse_slots_fg_rep=repeat(slots_A[:,1:],"b p l -> p (b q) l",q=n_qry)
            coarse_slots_bg_rep=repeat(slots_A[:,:1],"b p l -> p (b q) l",q=n_qry)
            coarse_fg_params = self.hyper_fg(coarse_slots_fg_rep)
            coarse_bg_params = self.hyper_bg(coarse_slots_bg_rep)
            coarse_phi_params=OrderedDict()
            for k in coarse_bg_params.keys(): 
                coarse_phi_params[k]=torch.cat([coarse_bg_params[k],coarse_fg_params[k]])
            coarse_feats = self.phi(coords,params=coarse_phi_params)
            coarse_feats = rearrange(coarse_feats, "(p b q) pix l -> p b q pix l",
                                                    p=self.num_phi,b=b,q=n_qry)
            coarse_rgbs = self.pix_gen(coarse_feats)# p b q pix 3
            coarse_seg = self.compositor(coarse_feats)
            coarse_rgb  = (coarse_rgbs*coarse_seg).sum(0) # b q pix 3

        if not coarse:
            # NOTE(review): unreachable while `coarse` is hard-coded to True.
            # It needs hyper_fine_fg / hyper_fine_bg / fine_phi, which
            # __init__ does not currently create, and it calls the compositor
            # without fine=True - confirm both before enabling.
            pos_enc = positionalencoding2d(self.num_pos_enc,imsize,imsize
                                ).flatten(1,2).T[None].expand(b,-1,-1).cuda()
            fine_slots,fine_attn=self.slot_encoder_fine(torch.cat((imfeats,pos_enc),-1),slot=slots_A)
            fine_slots_fg_rep=repeat(fine_slots[:,1:],"b p l -> p (b q) l",q=n_qry)
            fine_slots_bg_rep=repeat(fine_slots[:,:1],"b p l -> p (b q) l",q=n_qry)
            fine_fg_params = self.hyper_fine_fg(fine_slots_fg_rep)
            fine_bg_params = self.hyper_fine_bg(fine_slots_bg_rep)
            fine_phi_params=OrderedDict()
            for k in fine_bg_params.keys(): 
                fine_phi_params[k]=torch.cat([fine_bg_params[k],fine_fg_params[k]])
            fine_feats = self.fine_phi(coords,params=fine_phi_params)
            fine_feats = rearrange(fine_feats, "(p b q) pix l -> p b q pix l",
                                                    p=self.num_phi,b=b,q=n_qry)
            fine_rgbs = self.fine_pix_gen(fine_feats)# p b q pix 3
            fine_seg = self.compositor(fine_feats)
            fine_rgb  = (fine_rgbs*fine_seg).sum(0) # b q pix 3

        # Packup for loss fns and summaries
        out_dict = {
            "coarse_rgbs":   coarse_rgbs,
            "coarse_rgb":    coarse_rgb,
            "coarse_seg":    coarse_seg,
            "slot_attn": attn_A,
            "all_A_attn": attn_A,
            "A_attn": attn_A,
            "fg_slot_z":torch.cat([x.permute(1,0,2) for x in [slots_A]],1),
        }
        if not coarse:
            out_dict.update({
                "fine_rgbs":fine_rgbs,
                "fine_rgb": fine_rgb,
                "fine_seg": fine_seg,
                "fine_attn":fine_attn,
            })

        return out_dict


# MV training but FG is always queried in identity view for each view BG in canonical for both
class SVFGMVBGLFN(nn.Module):
    """Multi-view LFN: identity-view foreground, world-frame background.

    Every query image is encoded into slots. Slot 0 goes through
    ``bg_bottleneck`` and ``hyper_bg`` and is queried with rays posed by
    ``cam2world``; the remaining slots go through ``hyper_fg`` and are queried
    with rays posed by the identity matrix. Both hypernetworks generate
    weights for the same MLP ``self.phi``.
    """

    def __init__(self, latent_dim, num_phi=1, bg_latent=None, img_feat_dim=64, sato_cpu=False):
        """
        :param latent_dim: output feature size of ``self.phi``
        :param num_phi: number of slots passed to SlotAttention
        :param bg_latent: size of the bottlenecked background latent
        :param img_feat_dim: accepted but not used by this class
        :param sato_cpu: if True, ``sato_wrap`` runs the wrapped module on the
            CPU and moves the result back with ``.cuda()``
        """
        super().__init__()
        print(f"num phi:{num_phi}")

        self.latent_dim = latent_dim
        self.bg_latent  = bg_latent
        self.num_phi    = num_phi

        if sato_cpu:
            self.sato_wrap = lambda mod,x: mod.cpu()(x.cpu()).cuda()
        else:
            self.sato_wrap = lambda mod,x: mod(x)

        # Shared light-field MLP whose weights come from the hypernetworks.
        self.phi = custom_layers.FCBlock(
            hidden_ch=128, num_hidden_layers=3, in_features=6,
            out_features=latent_dim, outermost_linear=True,
            norm='layernorm_na')

        shared_hyper_kwargs = dict(hyper_hidden_layers=1,
                                   hyper_hidden_features=128,
                                   hypo_module=self.phi)
        self.hyper_bg = hyperlayers.HyperNetwork(
            hyper_in_features=bg_latent, **shared_hyper_kwargs)
        self.bg_bottleneck = nn.Linear(64, bg_latent)
        self.hyper_fg = hyperlayers.HyperNetwork(
            hyper_in_features=64, **shared_hyper_kwargs)

        # Pixels -> features for SlotAttention.
        self.img_encoder = conv_modules.UnetEncoder(bottom=False, z_dim=127)

        # Image features -> set of latents.
        self.slot_encoder = custom_layers.SlotAttention(self.num_phi)

        # Conv stack that slides along the slot axis and emits one logit per
        # slot: three 64->64 conv/ReLU pairs followed by a 64->1 conv.
        conv_layers = []
        for _ in range(3):
            conv_layers.append(nn.Conv1d(in_channels=64, out_channels=64,
                                         kernel_size=3, stride=1, padding=1))
            conv_layers.append(nn.ReLU(inplace=True))
        conv_layers.append(nn.Conv1d(in_channels=64, out_channels=1,
                                     kernel_size=3, stride=1, padding=1))
        self.composition_conv = nn.Sequential(*conv_layers)
        self.composition_conv.apply(custom_layers.init_weights_normal)

        # Features -> rgb.
        rgb_head = custom_layers.FCBlock(
            hidden_ch=128, num_hidden_layers=3, in_features=latent_dim,
            out_features=3, outermost_linear=True,
            norm='layernorm_na')
        self.pix_gen = nn.Sequential(rgb_head, nn.Tanh())
        print(self)

    def compositor(self,feats):
        """Blend weights over the slot axis from per-slot features.

        :param feats: features laid out as ``p b q pix l``
        :return: weights laid out as ``p b q pix 1``; a softmax over the slot
            axis with 1e-9 added
        """
        n_slots, n_batch, n_query = feats.size(0), feats.size(1), feats.size(2)
        # Conv1d wants (N, C, L): channels are features, length is slots.
        per_pixel = rearrange(feats,"p b q pix l -> (b q pix) l p")
        logits = self.sato_wrap(self.composition_conv, per_pixel)
        logits = rearrange(logits,"(b q pix) 1 p -> p b q pix 1",
                           p=n_slots, b=n_batch, q=n_query)
        return logits.softmax(0)+1e-9

    def forward(self,input):
        """Render each query view from its own input image.

        :param input: dict read via ``input["query"]["uv"]``,
            ``input["query"]["input_rgb"]`` and ``util.get_query_cam(input)``
        :return: dict with ``rgbs``, ``rgb``, ``soft_seg``, ``slot_attn`` and
            ``fg_slot_z``
        """
        query = input['query']
        b, n_qry, _ = query["uv"].shape[:3]
        cam2world, query_intrinsics, query_uv = util.get_query_cam(input)

        # Foreground rays use an identity pose, background rays the real one.
        eye = torch.eye(4).cuda()
        identity  = repeat(eye,"x y -> bq x y",bq=b*n_qry)
        fg_coords = geometry.plucker_embedding(identity, query_uv,query_intrinsics)
        bg_coords = geometry.plucker_embedding(cam2world,query_uv,query_intrinsics)
        fg_stack  = repeat(fg_coords,"bq x y -> p bq x y",p=self.num_phi-1)
        coords    = torch.cat((bg_coords[None], fg_stack))

        # Every query image is encoded; the pixel axis is flattened, so the
        # side length is its square root (square images).
        flat_rgb = query["input_rgb"]
        side = int(flat_rgb.size(-2)**(1/2))
        context_rgb = flat_rgb.flatten(0,1).permute(0,2,1)
        context_rgb = context_rgb.unflatten(-1,(side,side))

        # img_encoding -> slot attn
        enc = self.sato_wrap(self.img_encoder, context_rgb)
        enc = rearrange(enc, "b c x y -> b (x y) c")
        slot_z, slot_attn = self.slot_encoder( enc )

        # BG latent is shared across queries: query 1 gets query 0's latent.
        slot_z = slot_z.permute(1,0,2).unflatten(1,(b,n_qry)) # p b q l
        slot_z[0,:,1] = slot_z[0,:,0]
        slot_z = slot_z.flatten(1,2)

        # Stack background then foreground weights to match `coords`.
        bg_weights = self.hyper_bg(self.bg_bottleneck(slot_z[:1]))
        fg_weights = self.hyper_fg(slot_z[1:])
        phi_params = OrderedDict(
            (name, torch.cat([bg_weights[name], fg_weights[name]]))
            for name in bg_weights.keys())

        raw_feats = self.phi(coords.flatten(0,1), params=phi_params)
        feats = rearrange(raw_feats, "(p b q) pix l -> p b q pix l",
                          p=self.num_phi, b=b, q=n_qry)

        # Composite over the slot dimension (first) to yield a single scene.
        soft_seg = self.compositor(feats)
        rgbs     = self.pix_gen(feats)
        rgb      = (rgbs*soft_seg).sum(0)

        # Packup for loss fns and summaries
        return {
            "rgbs":      rgbs,
            "rgb":       rgb,
            "soft_seg":  soft_seg.unsqueeze(2),
            "slot_attn": slot_attn.permute(1,0,2).unsqueeze(2),
            "fg_slot_z": slot_z[1:],
        }


# Repeating the same first successful experiment on simple chair
class RedoLFN(nn.Module):
    """Slot-conditioned LFN: background in world frame, foreground in the
    frame of the first query camera.

    The first query image is encoded, concatenated with a positional
    embedding and passed to SlotAttention. Slot 0 goes through
    ``bg_bottleneck`` and ``hyper_bg``, the remaining slots through
    ``hyper_fg``; both generate weights for the shared MLP ``self.phi``.
    """

    def __init__(self, latent_dim,num_phi=1,unet=True,slot_dim=64, fg_latent=None, 
                        bg_latent=None, img_feat_dim=64, sato_cpu=False):
        """
        :param latent_dim: output feature size of ``self.phi``
        :param num_phi: number of slots
        :param unet: use ``UnetEncoder`` if True, else ``FeaturePyramidEncoder``
            followed by a linear projection
        :param slot_dim: background / maximum slot dimension
        :param fg_latent: foreground slot dimension
        :param bg_latent: size of the bottlenecked background latent
        :param img_feat_dim: feature size produced by the image encoder
        :param sato_cpu: if True, ``sato_wrap`` runs the wrapped module on the
            CPU and moves the result back with ``.cuda()``
        """
        super().__init__()
        print(f"num phi:{num_phi}")

        self.latent_dim = latent_dim
        self.bg_latent  = bg_latent
        self.num_phi    = num_phi
        self.slot_dim   = slot_dim
        self.unet   = unet

        self.sato_wrap = (lambda mod,x: mod.cpu()(x.cpu()).cuda())\
                                 if sato_cpu else (lambda mod,x:mod(x))

        # Shared light-field MLP whose weights come from the hypernetworks.
        self.phi = custom_layers.FCBlock(
                                hidden_ch=128,
                                num_hidden_layers=3,
                                in_features=6,
                                out_features=latent_dim,
                                outermost_linear=True,
                                norm='layernorm_na')
        self.hyper_bg = hyperlayers.HyperNetwork(
                              hyper_in_features=bg_latent,
                              hyper_hidden_layers=1,
                              hyper_hidden_features=128,
                              hypo_module=self.phi)
        self.bg_bottleneck = nn.Linear(slot_dim,bg_latent)
        self.hyper_fg = hyperlayers.HyperNetwork(
                              hyper_in_features=fg_latent,
                              hyper_hidden_layers=1,
                              hyper_hidden_features=128,
                              hypo_module=self.phi)

        # Maps pixels to features for SlotAttention
        if self.unet:
            self.img_encoder = nn.Sequential( 
                conv_modules.UnetEncoder(bottom=False,z_dim=img_feat_dim),
                Rearrange("b c x y -> b (x y) c")
            )
        else:
            self.img_encoder = nn.Sequential( 
                conv_modules.FeaturePyramidEncoder(feature_scale=2),
                Rearrange("b c x y -> b (x y) c"),
                nn.Linear(512,img_feat_dim)
            )

        # Maps image features to set of latents.
        # NOTE(review): the +22 presumably matches the channel count added by
        # sin_emb in forward - confirm.
        self.slot_encoder = custom_layers.SlotAttention(num_slots=self.num_phi, 
                in_dim=img_feat_dim+22, bg_slot_dim=slot_dim, fg_slot_dim=fg_latent, max_slot_dim=slot_dim,)

        # Depth heads used by the compositor. The Conv1d-based
        # `composition_conv` found in sibling classes is not created here.
        self.feat_to_depth = custom_layers.FCBlock(
                        hidden_ch=128, num_hidden_layers=3, in_features=latent_dim,
                        out_features=1, outermost_linear=True,
                        norm='layernorm_na')
        self.depth_spreader = custom_layers.FCBlock(
                        hidden_ch=128, num_hidden_layers=3, in_features=1,
                        out_features=1, outermost_linear=True,
                        norm='layernorm_na')

        # Maps features to rgb
        self.pix_gen = nn.Sequential( custom_layers.FCBlock(
                        hidden_ch=128, num_hidden_layers=3, in_features=latent_dim,
                        out_features=3, outermost_linear=True,
                        norm='layernorm_na'), nn.Tanh() )
        print(self)

    def compositor(self,feats):
        """Blend weights over the slot axis from per-slot features.

        NOTE(review): the previous body used an undefined ``conv_in`` and a
        ``composition_conv`` module that ``__init__`` never creates, so it
        always raised. This version is reconstructed from the depth heads that
        ``__init__`` does create (features -> scalar depth -> logit); confirm
        it matches the intended design.

        :param feats: per-slot features with the slot axis first
        :return: same leading layout as ``feats`` with a trailing size of 1;
            a softmax over the slot axis with 1e-9 added
        """
        depth = self.feat_to_depth(feats)
        attn  = self.depth_spreader(depth)
        return attn.softmax(0)+1e-9

    def forward(self,input,anneal=0):
        """Render query views from the first query image.

        :param input: dict read via ``input["query"]`` (keys ``uv``,
            ``cam2world``, ``input_rgb``) and ``util.get_query_cam(input)``
        :param anneal: forwarded as the second positional argument of the slot
            encoder
        :return: dict with ``rgbs``, ``rgb``, ``soft_seg``, ``slot_attn`` and
            ``fg_slot_z``
        """

        # Unpack input

        query = input['query']
        b, n_qry = query["uv"].shape[:2]
        cam2world, query_intrinsics, query_uv = util.get_query_cam(input)

        # Make query coords - first (bg) in world, second (fg) in context cam
        world2contextcam = repeat(query["cam2world"][:,0].inverse(),
                                  "b x y -> (b q) x y",q=n_qry)
        query_pose_fg = world2contextcam @ cam2world
        query_pose_bg = cam2world
        query_pose = torch.cat((query_pose_bg[None],
                               repeat(query_pose_fg, "n x y-> p n x y",p=self.num_phi-1)))
        coords = geometry.plucker_embedding(query_pose,
                repeat(query_uv,         "b n c -> phi b n c",phi=self.num_phi),
                repeat(query_intrinsics, "b n c -> phi b n c",phi=self.num_phi))
        coords.requires_grad_(True)

        # Context rgb is first query. The pixel axis is flattened, so the side
        # length is its square root (square images).
        img_sl = int(query["input_rgb"].size(-2)**(1/2))
        context_rgb = query["input_rgb"][:,0].permute(0,2,1).unflatten(-1,(img_sl,img_sl))

        # Create fg images: img_encoding -> slot attn -> compositor,rgb

        img_feats = self.sato_wrap(self.img_encoder,context_rgb)
        # NOTE(review): sin_emb is not defined in the visible part of this
        # file - confirm it is available at module level.
        pos_emb = sin_emb(query["uv"][:,0])
        img_feats = torch.cat((img_feats,pos_emb),-1)

        slot_z, slot_attn = self.slot_encoder( img_feats, anneal )
        # One copy of every slot per query view, slot axis first.
        slot_z = repeat(slot_z,"b p l -> p (b q) l",q=n_qry)

        # Make and query lfn - stacking weights of bg and fg lfn
        param_bg = self.hyper_bg(self.bg_bottleneck(slot_z[:1,...,:self.slot_dim]))
        param_fg = self.hyper_fg(slot_z[1:])
        phi_params=OrderedDict()
        for k in param_bg.keys(): phi_params[k]=torch.cat([param_bg[k],param_fg[k]])

        feats = rearrange( self.phi(coords.flatten(0,1),params=phi_params),
                      "(p b q) pix l -> p b q pix l",p=self.num_phi,b=b,q=n_qry)

        # Composite over phi dimension (first) to yield single scene
        soft_seg = self.compositor(feats)
        rgbs     = self.pix_gen(feats)
        rgb      = (rgbs*soft_seg).sum(0)

        # Packup for loss fns and summaries
        out_dict = {
            "rgbs":      rgbs,
            "rgb":       rgb,
            "soft_seg":  soft_seg.unsqueeze(2),
            "slot_attn": slot_attn.permute(1,0,2).unsqueeze(2),
            "fg_slot_z": slot_z[1:],
        }

        return out_dict


class FGLFN(nn.Module):
    """Slot-conditioned LFN in which every slot uses the same hypernetwork.

    Each ray is described by 12 numbers: its Plucker embedding in the frame
    of the first query camera concatenated with its Plucker embedding in the
    world frame. All slots are mapped by ``hyper_phi`` to weights of the
    shared MLP ``self.phi``.
    """

    def __init__(self, latent_dim, num_phi=1, slot_dim=64,
                 img_feat_dim=64, sato_cpu=False):
        """
        :param latent_dim: output feature size of ``self.phi``
        :param num_phi: number of slots
        :param slot_dim: slot dimension and hypernetwork input size
        :param img_feat_dim: SlotAttention input size (image features plus
            the two channels concatenated in ``forward``)
        :param sato_cpu: if True, ``sato_wrap`` runs the wrapped module on the
            CPU and moves the result back with ``.cuda()``
        """
        super().__init__()
        print(f"num phi:{num_phi}")

        self.latent_dim = latent_dim
        self.slot_dim   = slot_dim
        self.num_phi    = num_phi

        if sato_cpu:
            self.sato_wrap = lambda mod,x: mod.cpu()(x.cpu()).cuda()
        else:
            self.sato_wrap = lambda mod,x: mod(x)

        # Light-field MLP over the 12-D ray descriptor.
        self.phi = custom_layers.FCBlock(
            hidden_ch=128, num_hidden_layers=3, in_features=12,
            out_features=latent_dim, outermost_linear=True,
            norm='layernorm_na')
        self.hyper_phi = hyperlayers.HyperNetwork(
            hyper_in_features=slot_dim, hyper_hidden_layers=1,
            hyper_hidden_features=128, hypo_module=self.phi)

        # Pixels -> features for SlotAttention; two channels are left free
        # for what forward() concatenates in front of the image features.
        self.img_encoder = conv_modules.FeaturePyramidEncoder(
            num_layers=3, use_first_pool=False, feature_scale=2,)
        self.img_feat_downscale = nn.Linear(256, img_feat_dim-2)

        # Image features -> set of latents.
        self.slot_encoder = custom_layers.SlotAttentionFG(
            self.num_phi, in_dim=img_feat_dim, slot_dim=slot_dim,)

        # Conv stack that slides along the slot axis and emits one logit per
        # slot.
        conv_cfg = dict(kernel_size=3, stride=1, padding=1)
        self.composition_conv = nn.Sequential(
            nn.Conv1d(in_channels=latent_dim, out_channels=64, **conv_cfg),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=64, out_channels=64, **conv_cfg),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=64, out_channels=64, **conv_cfg),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=64, out_channels=1, **conv_cfg),
        )
        self.composition_conv.apply(custom_layers.init_weights_normal)

        # Features -> rgb.
        rgb_head = custom_layers.FCBlock(
            hidden_ch=128, num_hidden_layers=3, in_features=latent_dim,
            out_features=3, outermost_linear=True,
            norm='layernorm_na')
        self.pix_gen = nn.Sequential(rgb_head, nn.Tanh())
        print(self)

    def compositor(self,feats):
        """Blend weights over the slot axis from per-slot features.

        :param feats: features laid out as ``p b q pix l``
        :return: weights laid out as ``p b q pix 1``; a softmax over the slot
            axis with 1e-9 added
        """
        n_slots, n_batch, n_query = feats.size(0), feats.size(1), feats.size(2)
        # Conv1d wants (N, C, L): channels are features, length is slots.
        per_pixel = rearrange(feats,"p b q pix l -> (b q pix) l p")
        logits = self.sato_wrap(self.composition_conv, per_pixel)
        logits = rearrange(logits,"(b q pix) 1 p -> p b q pix 1",
                           p=n_slots, b=n_batch, q=n_query)
        return logits.softmax(0)+1e-9

    def forward(self,input):
        """Render query views from the first query image.

        :param input: dict read via ``input["query"]`` (keys ``uv``,
            ``cam2world``, ``input_rgb``, ``context_uv``) and
            ``util.get_query_cam(input)``
        :return: dict with ``rgbs``, ``rgb``, ``soft_seg``, ``slot_attn`` and
            ``fg_slot_z``
        """
        query = input['query']
        b, n_qry, _ = query["uv"].shape[:3]
        cam2world, query_intrinsics, query_uv = util.get_query_cam(input)

        # Ray descriptor: context-camera-frame embedding followed by the
        # world-frame embedding.
        first_cam_inv = query["cam2world"][:,0].inverse()
        world2contextcam = repeat(first_cam_inv, "b x y -> (b q) x y", q=n_qry)
        pose_in_context = world2contextcam@cam2world
        cam_crds = geometry.plucker_embedding(pose_in_context,
                                              query_uv, query_intrinsics)
        world_crds = geometry.plucker_embedding(cam2world,
                                                query_uv, query_intrinsics)
        crds = torch.cat((cam_crds,world_crds),-1)
        crds.requires_grad_(True)

        # Context rgb is the first query; the pixel axis is flattened, so the
        # side length is its square root (square images).
        flat_rgb = query["input_rgb"]
        side = int(flat_rgb.size(-2)**(1/2))
        context_rgb = flat_rgb[:,0].permute(0,2,1).unflatten(-1,(side,side))

        # img_encoding -> slot attn
        enc = self.sato_wrap(self.img_encoder, context_rgb)
        enc = rearrange(enc, "b c x y -> b (x y) c")
        enc = self.img_feat_downscale(enc)
        enc = torch.cat((query["context_uv"][:,0], enc), -1)
        slot_z, slot_attn = self.slot_encoder( enc )
        # One copy of every slot per query view, slot axis first.
        slot_z = repeat(slot_z,"b p l -> p (b q) l",q=n_qry)

        # Every slot queries the same rays with its own generated weights.
        phi_params = self.hyper_phi(slot_z)
        phi_crds   = repeat(crds,"bq pix c -> (p bq) pix c",p=self.num_phi)
        raw_feats  = self.phi(phi_crds, params=phi_params)
        feats = rearrange(raw_feats, "(p b q) pix l -> p b q pix l",
                          p=self.num_phi, b=b, q=n_qry)

        # Composite over the slot dimension (first) to yield a single scene.
        soft_seg = self.compositor(feats)
        rgbs     = self.pix_gen(feats)
        rgb      = (rgbs*soft_seg).sum(0)

        # Packup for loss fns and summaries
        return {
            "rgbs":      rgbs,
            "rgb":       rgb,
            "soft_seg":  soft_seg.unsqueeze(2),
            "slot_attn": slot_attn.permute(1,0,2).unsqueeze(2),
            "fg_slot_z": slot_z,
        }

class AttnCompLFN(nn.Module):
    """LFN background composited with flat-coloured foreground slots.

    Slot 0 drives a hypernetwork-generated MLP that renders the background
    directly as RGB. Each remaining slot is rendered as a single colour: the
    mean over pixels of the query RGB multiplied by that slot's attention.
    The slot attention itself is used as the compositing weight.
    """

    def __init__(self, latent_dim, num_phi=1,
                    max_slot_dim=64,fg_latent=64, bg_latent=None,use_gan=False,
                    img_feat_dim=64, sato_cpu=False):
        """
        :param latent_dim: stored on the instance; not otherwise used here
        :param num_phi: number of slots
        :param max_slot_dim: forwarded to SlotAttention
        :param fg_latent: foreground slot dimension
        :param bg_latent: background slot dimension and hypernetwork input size
        :param use_gan: stored on the instance; not otherwise used here
        :param img_feat_dim: SlotAttention input size (image features plus
            the two uv channels concatenated in ``forward``)
        :param sato_cpu: if True, ``sato_wrap`` runs the wrapped module on the
            CPU and moves the result back with ``.cuda()``
        """
        super().__init__()
        print(f"num phi:{num_phi}")

        self.latent_dim     = latent_dim
        self.bg_latent      = bg_latent
        self.max_slot_dim   = max_slot_dim
        self.num_phi        = num_phi
        self.use_gan        = use_gan

        if sato_cpu:
            self.sato_wrap = lambda mod,x: mod.cpu()(x.cpu()).cuda()
        else:
            self.sato_wrap = lambda mod,x: mod(x)

        # Background light field: 6-D ray in, 3 channels out.
        self.phi = custom_layers.FCBlock(
            hidden_ch=128, num_hidden_layers=3, in_features=6,
            out_features=3, outermost_linear=True,
            norm='layernorm_na')
        self.hyper_bg = hyperlayers.HyperNetwork(
            hyper_in_features=bg_latent, hyper_hidden_layers=1,
            hyper_hidden_features=128, hypo_module=self.phi)

        # Pixels -> features for SlotAttention; two channels are left free
        # for the uv coordinates concatenated in forward().
        self.img_encoder = conv_modules.FeaturePyramidEncoder(
            num_layers=3, use_first_pool=False, feature_scale=1,)
        self.img_feat_downscale = nn.Linear(256, img_feat_dim-2)

        # Image features -> set of latents.
        self.slot_encoder = custom_layers.SlotAttention(
            self.num_phi, in_dim=img_feat_dim, max_slot_dim=max_slot_dim,
            fg_slot_dim=fg_latent, bg_slot_dim=bg_latent,)

        print(self)

    def compositor(self,feats):
        """Placeholder; compositing uses the slot attention in ``forward``."""
        return None

    def forward(self,input):
        """Render the query view from the first query image.

        :param input: dict read via ``input["query"]`` (keys ``uv``,
            ``input_rgb``, ``rgb``) and ``util.get_query_cam(input)``
        :return: dict with ``rgbs``, ``rgb``, ``soft_seg``, ``slot_attn`` and
            ``fg_slot_z``
        """
        query = input['query']
        b, n_qry, _ = query["uv"].shape[:3]
        cam2world, query_intrinsics, query_uv = util.get_query_cam(input)

        # Context rgb is the first query; the pixel axis is flattened, so the
        # side length is its square root (square images).
        flat_rgb = query["input_rgb"]
        side = int(flat_rgb.size(-2)**(1/2))
        context_rgb = flat_rgb[:,0].permute(0,2,1).unflatten(-1,(side,side))

        enc = self.sato_wrap(self.img_encoder, context_rgb)
        enc = rearrange(enc, "b c x y -> b (x y) c")
        enc = self.img_feat_downscale(enc)
        enc = torch.cat((query["uv"].squeeze(1), enc), -1)

        slot_z, slot_attn = self.slot_encoder( enc )
        # One copy of every slot per query view, slot axis first.
        slot_z = repeat(slot_z,"b p l -> p (b q) l",q=n_qry)

        # Render bg img from slot 0.
        coords_bg = geometry.plucker_embedding(cam2world, query_uv,query_intrinsics,)
        coords_bg.requires_grad_(True)
        bg_weights = self.hyper_bg(slot_z[0])
        bg_rgb = self.phi(coords_bg, params=bg_weights).tanh()
        bg_rgb = rearrange(bg_rgb,"b pix c -> 1 b 1 pix c")

        # FG img: one colour per slot, the pixel-mean of attention * rgb,
        # repeated over all pixels.
        fg_attn = slot_attn[:,1:]
        rgb_per_slot = repeat(query["rgb"],"b 1 pix c ->b p pix c",p=self.num_phi-1)
        fg_mean_color = (rgb_per_slot * fg_attn[:,:,:,None]).mean(-2)
        n_pixels = query["rgb"].size(2)
        fg_rgbs = repeat(fg_mean_color,"b p c -> p b 1 pix c", pix = n_pixels)

        # Composition is attention
        comp = rearrange(slot_attn,"b p pix -> p b 1 pix 1")
        rgbs = torch.cat((bg_rgb,fg_rgbs))
        rgb  = (rgbs*comp).sum(0)

        # Packup for loss fns and summaries
        return {
            "rgbs":      rgbs,
            "rgb":       rgb,
            "soft_seg":  comp.unsqueeze(2),
            "slot_attn": slot_attn.permute(1,0,2).unsqueeze(2),
            "fg_slot_z": slot_z[1:],
        }



# Shallow conv maps image to 1d for BG phi
class CompLFNClean(nn.Module):
    """Compositional light field network with a single shared ``phi``.

    A U-Net encoder turns the context image into per-pixel features, slot
    attention groups them into ``num_phi`` latents, and two hypernetworks
    (one for slot 0 / background, one for the remaining foreground slots)
    produce the weights of the shared ray network ``phi``. Per-slot ray
    features are decoded to RGB and blended with weights predicted by a
    1D convolution that runs across the slot axis.

    Args:
        latent_dim: width of the features produced by ``phi``.
        num_phi: number of slots/layers (slot 0 is the background).
        max_slot_dim, fg_latent, bg_latent: slot sizes forwarded to
            ``custom_layers.SlotAttention``; ``fg_latent``/``bg_latent`` are
            also the hypernetwork input widths.
        use_gan: stored on the instance; not read by this class.
        img_feat_dim: input feature width given to the slot encoder.
        sato_cpu: if True, wrapped modules are moved to and run on the CPU
            and their output is moved back to CUDA.
    """

    def __init__(self, latent_dim, num_phi=1,
                    max_slot_dim=64,fg_latent=64, bg_latent=None,use_gan=False,
                    img_feat_dim=64, sato_cpu=False):
        super().__init__()
        print(f"num phi:{num_phi}")

        self.latent_dim     = latent_dim
        self.bg_latent      = bg_latent
        self.max_slot_dim   = max_slot_dim
        self.num_phi        = num_phi
        self.use_gan        = use_gan

        # Optionally run a module on the CPU and bring its result back to CUDA.
        self.sato_wrap = (lambda mod,x: mod.cpu()(x.cpu()).cuda())\
                                 if sato_cpu else (lambda mod,x:mod(x))

        # Shared ray network: 6D Plucker coordinates -> latent_dim features.
        self.phi = custom_layers.FCBlock(
                                hidden_ch=128,
                                num_hidden_layers=3,
                                in_features=6,
                                out_features=latent_dim,
                                outermost_linear=True,
                                norm='layernorm_na')
        # NOTE(review): bg_latent defaults to None and is passed straight in as
        # hyper_in_features -- callers presumably always set it; confirm.
        self.hyper_bg = hyperlayers.HyperNetwork(
                              hyper_in_features=bg_latent,
                              hyper_hidden_layers=1,
                              hyper_hidden_features=128,
                              hypo_module=self.phi)
        self.hyper_fg = hyperlayers.HyperNetwork(
                              hyper_in_features=fg_latent,
                              hyper_hidden_layers=1,
                              hyper_hidden_features=128,
                              hypo_module=self.phi)

        # Maps pixels to features for SlotAttention
        self.img_encoder = conv_modules.UnetEncoder(bottom=False)

        # Maps image features to set of latents
        self.slot_encoder = custom_layers.SlotAttention(self.num_phi,
                in_dim=img_feat_dim, max_slot_dim=max_slot_dim,fg_slot_dim=fg_latent,
                bg_slot_dim=bg_latent,)

        # Scores each slot per ray; the convolution slides along the slot axis.
        self.composition_conv = nn.Sequential(
            nn.Conv1d(in_channels=latent_dim, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=64, out_channels=1,  kernel_size=3, stride=1,  padding=1),
         )
        self.composition_conv.apply(custom_layers.init_weights_normal)

        # Maps features to rgb
        self.pix_gen_bg = nn.Sequential( custom_layers.FCBlock(
                        hidden_ch=128, num_hidden_layers=3, in_features=latent_dim,
                        out_features=3, outermost_linear=True,
                        norm='layernorm_na'), nn.Tanh() )
        self.pix_gen_fg = nn.Sequential( custom_layers.FCBlock(
                        hidden_ch=128, num_hidden_layers=3, in_features=latent_dim,
                        out_features=3, outermost_linear=True,
                        norm='layernorm_na'), nn.Tanh() )
        print(self)

    def compositor(self,feats):
        """Turn per-slot ray features into blending weights over the slot axis.

        Args:
            feats: tensor laid out as ``(p, b, q, pix, l)``.

        Returns:
            Tensor ``(p, b, q, pix, 1)``: softmax over ``p`` plus ``1e-9``.
        """
        conv_in  = rearrange(feats,"p b q pix l -> (b q pix) l p")
        conv_out = self.sato_wrap(self.composition_conv,conv_in)
        attn     = rearrange(conv_out,"(b q pix) 1 p -> p b q pix 1",
                                    p=feats.size(0),b=feats.size(1),q=feats.size(2))
        return attn.softmax(0)+1e-9

    def forward(self,input):
        """Render all query views as a soft composition of ``num_phi`` layers.

        Args:
            input: dict with a ``'query'`` entry providing ``"uv"``,
                ``"cam2world"`` and ``"input_rgb"``; the whole dict is also
                passed to ``util.get_query_cam``.

        Returns:
            dict with ``"rgbs"`` (per-layer colors), ``"rgb"`` (composite),
            ``"soft_seg"`` (blend weights), ``"slot_attn"`` and
            ``"fg_slot_z"`` (all slots but the first).
        """
        # Unpack input
        query = input['query']
        b, n_qry = query["uv"].shape[:2]
        cam2world, query_intrinsics, query_uv = util.get_query_cam(input)

        # Make query coords - first (bg) in world, second (fg) in context cam
        world2contextcam = repeat(query["cam2world"][:,0].inverse(),
                                  "b x y -> (b q) x y",q=n_qry)
        query_pose_fg = world2contextcam @ cam2world
        query_pose_bg = cam2world
        query_pose = torch.cat((query_pose_bg[None],
                               repeat(query_pose_fg, "n x y-> p n x y",p=self.num_phi-1)))
        coords = geometry.plucker_embedding(query_pose,
                repeat(query_uv,         "b n c -> phi b n c",phi=self.num_phi),
                repeat(query_intrinsics, "b n c -> phi b n c",phi=self.num_phi))
        coords.requires_grad_(True)

        # Context rgb is first query, reshaped from flat pixels to a square image
        img_sl = int(query["input_rgb"].size(-2)**(1/2))
        context_rgb = query["input_rgb"][:,0].permute(0,2,1).unflatten(-1,(img_sl,img_sl))

        # Create fg images: img_encoding -> slot attn -> compositor,rgb
        img_feats = self.sato_wrap(self.img_encoder,context_rgb)
        img_feats = rearrange(img_feats, "b c x y -> b (x y) c")
        slot_z, slot_attn = self.slot_encoder( img_feats )
        # Slot axis first; each slot latent is tiled once per query view.
        slot_z = repeat(slot_z,"b p l -> p (b q) l",q=n_qry)

        # Make and query lfn - stacking weights of bg and fg lfn.
        # (The tiling above is done exactly once; a second repeat of the
        # already-tiled tensor was unused and has been removed.)
        param_bg = self.hyper_bg(slot_z[:1,...,:self.bg_latent])
        param_fg = self.hyper_fg(slot_z[1:])
        phi_params=OrderedDict()
        for k in param_bg.keys(): phi_params[k]=torch.cat([param_bg[k],param_fg[k]])

        feats = rearrange( self.phi(coords.flatten(0,1),params=phi_params),
                      "(p b q) pix l -> p b q pix l",p=self.num_phi,b=b,q=n_qry)

        # Composite over phi dimension (first) to yield single scene
        soft_seg = self.compositor(feats)
        rgbs     = torch.cat((self.pix_gen_bg(feats[:1]),
                              self.pix_gen_fg(feats[1:])))
        rgb      = (rgbs*soft_seg).sum(0)

        # Packup for loss fns and summaries
        out_dict = {
            "rgbs":      rgbs,
            "rgb":       rgb,
            "soft_seg":  soft_seg.unsqueeze(2),
            "slot_attn": slot_attn.permute(1,0,2).unsqueeze(2),
            "fg_slot_z": slot_z[1:],
        }

        return out_dict

# Shallow conv maps image to 1d for BG phi
class BgFgSep(nn.Module):
    """Light field model with separate background and foreground networks.

    The background is rendered by ``phi_bg``, whose weights come from a
    hypernetwork fed by a shallow conv encoding of the context image. The
    foreground slots (``num_phi-1`` of them) come from a U-Net encoder plus
    slot attention and are rendered by ``phi_fg``. Foreground layers are
    blended by a conv compositor scaled by a per-ray alpha; whatever weight
    is left over goes to the background.

    Args:
        latent_dim: width of the ray features produced by both phi networks.
        num_phi: total number of layers (1 background + ``num_phi-1`` slots).
        fg_latent, bg_latent: hypernetwork input widths for FG / BG.
        bg_training: if True, ``forward`` returns only the background render.
        img_feat_dim: feature width given to the slot encoder.
        sato_cpu: if True, wrapped modules are run on the CPU and their
            output is moved back to CUDA.
        fit_single, max_slot_dim, temp, use_gt_seg, regress_bg, use_gt_w2m,
        use_gan, use_bg, query_encodings, rand_noise_seg: stored on the
            instance; not read by this class.
        depth, concat_hyper, img_encoder, compositor: accepted for signature
            compatibility; not used by this class.
    """

    def __init__(self, latent_dim, fit_single=False, depth=True, num_phi=1,
                    use_bg=True,concat_hyper=False,bg_training=False,
                    max_slot_dim=64,fg_latent=64,img_encoder="unet",
                    bg_latent=None,use_gan=False,query_encodings=False,
                     temp=5e0,use_gt_seg=False, regress_bg=False,
                     use_gt_w2m=False, compositor="depth",
                     rand_noise_seg=False, img_feat_dim=64,
                    sato_cpu=False):
        super().__init__()
        print(f"num phi:{num_phi}")

        self.latent_dim     = latent_dim
        self.bg_latent      = bg_latent
        self.max_slot_dim   = max_slot_dim
        self.fit_single     = fit_single
        self.num_phi        = num_phi
        self.temp           = temp
        self.use_gt_seg     = use_gt_seg
        self.regress_bg     = regress_bg
        self.use_gt_w2m     = use_gt_w2m
        self.use_gan        = use_gan
        self.use_bg         = use_bg
        self.query_encodings= query_encodings
        self.rand_noise_seg = rand_noise_seg
        self.bg_training    = bg_training

        # Optionally run a module on the CPU and bring its result back to CUDA.
        self.sato_wrap = (lambda mod,x: mod.cpu()(x.cpu()).cuda())\
                                 if sato_cpu else (lambda mod,x:mod(x))

        # The conv compositor is always used; the `compositor` argument is ignored.
        self.compositor = self.conv_compositor

        # Bg uses a hypernetwork (an earlier variant, since removed, fed the
        # latent to phi_bg by concatenation with in_features=6+bg_latent).
        self.phi_bg = custom_layers.FCBlock(
                                hidden_ch=128,
                                num_hidden_layers=3,
                                in_features=6,
                                out_features=latent_dim,
                                outermost_linear=True,
                                norm='layernorm_na')
        self.hyper_bg = hyperlayers.HyperNetwork(
                              hyper_in_features=bg_latent,
                              hyper_hidden_layers=1,
                              hyper_hidden_features=128,
                              hypo_module=self.phi_bg)

        # Fg will use hypernet
        self.phi_fg = custom_layers.FCBlock(
                                hidden_ch=128,
                                num_hidden_layers=3,
                                in_features=6,
                                out_features=latent_dim,
                                outermost_linear=True,
                                norm='layernorm_na')
        self.hyper_fg = hyperlayers.HyperNetwork(
                              hyper_in_features=fg_latent,
                              hyper_hidden_layers=1,
                              hyper_hidden_features=128,
                              hypo_module=self.phi_fg)

        # Maps img to 1D for BG.
        # NOTE(review): the final Linear emits a single value while hyper_bg is
        # built with hyper_in_features=bg_latent -- these only agree when
        # bg_latent == 1; confirm intended. Linear(1024, ...) also fixes the
        # flattened conv output size, hence the supported input resolution.
        self.bg_img_encoder = nn.Sequential(
                nn.Conv2d(3,  32, kernel_size=4, stride=2), nn.ReLU(), torch.nn.BatchNorm2d(32),
                nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(), torch.nn.BatchNorm2d(64),
                nn.Conv2d(64, 128,kernel_size=4, stride=2), nn.ReLU(), torch.nn.BatchNorm2d(128),
                nn.Conv2d(128,256,kernel_size=4, stride=2), torch.nn.BatchNorm2d(256), nn.ReLU(),
                nn.Flatten(), nn.Linear(1024,256), nn.ReLU(), nn.Linear(256,1),
                            )
        # Maps pixels to features for SlotAttention
        self.fg_img_encoder = conv_modules.UnetEncoder(bottom=False)

        # Maps image features to set of latents
        self.slot_encoder = custom_layers.SlotAttentionFG(self.num_phi-1,
                in_dim=img_feat_dim, slot_dim=fg_latent)

        # Scores each FG slot per ray; the convolution slides along the slot axis.
        self.composition_conv = nn.Sequential(
            nn.Conv1d(in_channels=latent_dim, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=64, out_channels=1,  kernel_size=3, stride=1,  padding=1),
         )
        self.composition_conv.apply(custom_layers.init_weights_normal)

        # Maps features to rgb
        self.pix_gen_bg = nn.Sequential( custom_layers.FCBlock(
                        hidden_ch=128, num_hidden_layers=3, in_features=latent_dim,
                        out_features=3, outermost_linear=True,
                        norm='layernorm_na'), nn.Tanh() )
        self.pix_gen_fg = nn.Sequential( custom_layers.FCBlock(
                        hidden_ch=128, num_hidden_layers=3, in_features=latent_dim,
                        out_features=3, outermost_linear=True,
                        norm='layernorm_na'), nn.Tanh() )
        # Maps features to a per-ray opacity in (0, 1) for the foreground layers
        self.alpha_gen_fg = nn.Sequential( custom_layers.FCBlock(
                        hidden_ch=128, num_hidden_layers=3, in_features=latent_dim,
                        out_features=1, outermost_linear=True,
                        norm='layernorm_na'), nn.Sigmoid() )
        print(self)

    def conv_compositor(self,feats):
        """Turn per-slot ray features into blending weights over the slot axis.

        Args:
            feats: tensor laid out as ``(p, b, q, pix, l)``.

        Returns:
            Tensor ``(p, b, q, pix, 1)``: softmax over ``p`` plus ``1e-9``.
        """
        conv_in  = rearrange(feats,"p b q pix l -> (b q pix) l p")
        conv_out = self.sato_wrap(self.composition_conv,conv_in)
        attn     = rearrange(conv_out,"(b q pix) 1 p -> p b q pix 1",
                                    p=feats.size(0),b=feats.size(1),q=feats.size(2))
        return attn.softmax(0)+1e-9

    def forward(self,input):
        """Render query views as alpha-weighted foreground over background.

        Args:
            input: dict with a ``'query'`` entry providing ``"uv"``,
                ``"cam2world"``, ``"rgb"`` and ``"input_rgb"``; the whole dict
                is also passed to ``util.get_query_cam``.

        Returns:
            If ``self.bg_training``: dict with only ``"rgb"`` and ``"rgbs"``
            of the background render. Otherwise a dict with ``"rgbs"``,
            ``"rgb"``, ``"soft_seg"``, ``"slot_attn"`` and ``"fg_slot_z"``.
        """
        # Unpack input
        query = input['query']
        b, n_qry = query["uv"].shape[:2]
        cam2world, query_intrinsics, query_uv = util.get_query_cam(input)

        # BG rays stay in world coordinates; FG rays are expressed relative to
        # the context (first query) camera.
        world2contextcam = repeat(query["cam2world"][:,0].inverse(),
                                  "b x y -> (b q) x y",q=n_qry)
        query_pose = (world2contextcam @ cam2world).unsqueeze(0).expand(
                                                        self.num_phi-1,-1,-1,-1)
        coords_bg = geometry.plucker_embedding(cam2world, query_uv,query_intrinsics,)
        coords_bg.requires_grad_(True)
        coords_fg = geometry.plucker_embedding(query_pose,
                repeat(query_uv,         "b n c -> phi b n c",phi=self.num_phi-1),
                repeat(query_intrinsics, "b n c -> phi b n c",phi=self.num_phi-1))
        coords_fg.requires_grad_(True)

        # Context rgb is first query, reshaped from flat pixels to a square image
        img_sl = int(query["input_rgb"].size(-2)**(1/2))
        context_rgb = query["input_rgb"][:,0].permute(0,2,1).unflatten(-1,(img_sl,img_sl))

        # Create bg image: img encoding -> hyper -> rgb
        bg_z = self.sato_wrap(self.bg_img_encoder,context_rgb)
        bg_params = self.hyper_bg(bg_z)
        bg_feats = self.phi_bg(coords_bg,params=bg_params).unflatten(0,(b,n_qry))
        bg_rgb = self.pix_gen_bg(bg_feats)
        # Signed BG residual on the first view, shifted/scaled by .5 and
        # averaged over channels; handed to the slot encoder below.
        bg_err = (.5+(bg_rgb[:,0]-query["rgb"][:,0])/2).mean(-1)

        if self.bg_training: return {"rgb":bg_rgb,"rgbs":bg_rgb[None],}

        # Create fg images: img_encoding -> slot attn -> compositor,rgb
        img_feats = self.sato_wrap(self.fg_img_encoder,context_rgb)
        img_feats = rearrange(img_feats, "b c x y -> b (x y) c")
        slot_z, slot_attn = self.slot_encoder( img_feats, bg_err )
        slot_z = repeat(slot_z,"b p l -> p (b q) l",q=n_qry)

        # Query the fg lfn with one set of hypernetwork weights per slot
        phi_params = self.hyper_fg(slot_z)
        fg_feats = rearrange( self.phi_fg(coords_fg.flatten(0,1),params=phi_params),
                      "(p b q) pix l -> p b q pix l",p=self.num_phi-1,b=b,q=n_qry)
        fg_rgb = self.pix_gen_fg(fg_feats)
        fg_alpha = self.alpha_gen_fg(fg_feats)

        # Composite. BG not involved in compositing just behind the alpha of queries
        fg_weights = self.compositor(fg_feats)
        fg_mask = fg_weights*fg_alpha
        bg_mask = 1-fg_mask.sum(0)
        rgb = (fg_mask*fg_rgb).sum(0) + bg_mask*bg_rgb

        # Force a `margin`-pixel frame around the image to be pure BG.
        # The mask is allocated on rgb's device (not a hard-coded .cuda()) so
        # the model also runs on CPU or on a non-default GPU.
        margin=3
        border_mask = torch.ones(b,n_qry,img_sl,img_sl,1,device=rgb.device)
        border_mask[:,:,margin:-margin,margin:-margin] = 0
        border_mask = border_mask.flatten(2,3)
        inv_border_mask = 1-border_mask
        rgb = bg_rgb*border_mask + rgb*inv_border_mask

        # Packup for loss fns and summaries
        out_dict = {
            "rgbs":      torch.cat((bg_rgb[None],fg_rgb)),
            "rgb":       rgb,
            "soft_seg":  torch.cat((bg_mask[None],fg_mask)).unsqueeze(2),
            "slot_attn": slot_attn.permute(1,0,2).unsqueeze(2),
            "fg_slot_z": slot_z,
        }

        return out_dict

class CompLFN(nn.Module):
    """Compositional light field network with selectable compositor.

    The context image is encoded into per-pixel features, slot attention
    groups them into ``num_phi`` latents, hypernetworks turn each latent into
    weights for the shared ray network ``phi``, and the per-slot ray features
    are decoded to RGB and blended by the chosen compositor.

    Args:
        latent_dim: width of the ray features produced by ``phi``.
        num_phi: number of slots/layers.
        use_bg: if True ``forward`` is ``forward_with_bg`` (slot 0 is a
            world-space background); otherwise ``forward_without_bg``.
        max_slot_dim, fg_latent, bg_latent: slot sizes passed to
            ``custom_layers.SlotAttention``; ``fg_latent``/``bg_latent`` are
            also the hypernetwork input widths.
        use_gan: if True ``forward_with_bg`` also returns discriminator
            outputs under ``"disc_out"``.
        compositor: one of ``"conv"``, ``"graph"``, ``"custom"``. The default
            ``"depth"`` is not a supported key and raises ``KeyError``.
        rand_noise_seg: if True the conv compositor perturbs its logits with
            uniform noise.
        img_feat_dim: feature width given to the slot encoder.
        fit_single, temp, use_gt_seg, regress_bg, use_gt_w2m,
        query_encodings: stored on the instance; not read by this class.
        depth, concat_hyper, img_encoder: accepted for signature
            compatibility; not used by this class.

    Raises:
        KeyError: if ``compositor`` is not a supported name.
    """

    def __init__(self, latent_dim, fit_single=False, depth=True, num_phi=1,
                    use_bg=True,concat_hyper=False,
                    max_slot_dim=64,fg_latent=64,img_encoder="unet",
                    bg_latent=None,use_gan=False,query_encodings=False,
                     temp=5e0,use_gt_seg=False, regress_bg=False,
                     use_gt_w2m=False, compositor="depth",
                     rand_noise_seg=False, img_feat_dim=64):
        super().__init__()
        print(f"num phi:{num_phi}")

        num_hidden_units_phi = 256

        self.latent_dim     = latent_dim
        self.bg_latent      = bg_latent
        self.max_slot_dim   = max_slot_dim
        self.fit_single     = fit_single
        self.num_phi        = num_phi
        self.temp           = temp
        self.use_gt_seg     = use_gt_seg
        self.regress_bg     = regress_bg
        self.use_gt_w2m     = use_gt_w2m
        self.use_gan        = use_gan
        self.use_bg         = use_bg
        self.query_encodings= query_encodings
        self.rand_noise_seg = rand_noise_seg

        compositors = {
            "conv"  : self.conv_compositor,
            "graph" : self.graph_compositor,
            "custom": self.custom_compositor,
        }
        if compositor not in compositors:
            # Same exception type as the plain dict lookup, clearer message.
            raise KeyError(f"Unknown compositor {compositor!r}; "
                           f"expected one of {sorted(compositors)}")
        self.compositor = compositors[compositor]

        # Factory for the fully connected blocks used throughout this class.
        phi_default = (lambda out=latent_dim,in_=6,h=4,h_ch=num_hidden_units_phi:
                                                custom_layers.FCBlock(
                                                hidden_ch=h_ch,
                                                num_hidden_layers=h,
                                                in_features=in_,
                                                out_features=out,
                                                outermost_linear=True,
                                                norm='layernorm_na'))

        self.unet=True
        if self.unet:
            self.img_encoder = conv_modules.UnetEncoder(bottom=False)
        else:
            self.img_encoder = conv_modules.FeaturePyramidEncoder(
                                            use_first_pool=False,feature_scale=2,)
            # -2 leaves room for the uv coordinates concatenated in forward
            self.img_feat_downscale = nn.Linear(512,img_feat_dim-2)

        self.slot_encoder = custom_layers.SlotAttention(self.num_phi,
                in_dim=img_feat_dim,max_slot_dim=max_slot_dim,
                fg_slot_dim=fg_latent,bg_slot_dim=bg_latent)

        # nn.Module.__call__ dispatches to self.forward, so binding it here
        # selects the variant once at construction time.
        self.forward = self.forward_with_bg if use_bg else self.forward_without_bg

        self.phi = phi_default()
        hyper = lambda x : hyperlayers.HyperNetwork(
                              hyper_in_features=x,
                              hyper_hidden_layers=1,
                              hyper_hidden_features=latent_dim,
                              hypo_module=self.phi)
        self.hyper_fg = hyper(fg_latent)
        if self.bg_latent is not None:
            self.hyper_bg = hyper(bg_latent)
            print("Note doing bottleneck for gan training")
            self.bg_bottleneck = nn.Linear(max_slot_dim,self.bg_latent)
        else:
            # Without a BG bottleneck, forward_with_bg slices
            # slot_z[..., :None], i.e. feeds the whole slot vector to hyper_bg.
            # Assumes slot vectors are max_slot_dim wide (as the bottleneck
            # Linear above does) -- TODO confirm against SlotAttention.
            self.hyper_bg = hyper(max_slot_dim)
            self.bg_bottleneck = lambda x : x

        self.feature_embedding = phi_default(out=64,in_=latent_dim,h=0)
        if compositor=="graph":
            self.graph_conv = Sequential('x, edge_index', [
                (DenseGraphConv(64, 64,aggr="max"), 'x, edge_index -> x'),
                nn.ReLU(inplace=True),
                (DenseGraphConv(64, 64,aggr="max"), 'x, edge_index -> x'),
                nn.ReLU(inplace=True),
                (DenseGraphConv(64, 64,aggr="max"), 'x, edge_index -> x'),
                nn.ReLU(inplace=True),
                nn.Linear(64, 1), ])
            self.graph_conv.apply(custom_layers.init_weights_normal)
        elif compositor=="conv":
            self.composition_conv = nn.Sequential(
                nn.Conv1d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv1d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv1d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv1d(in_channels=64, out_channels=1,  kernel_size=3, stride=1,  padding=1),
             )
            self.composition_conv.apply(custom_layers.init_weights_normal)
        elif compositor=="custom":
            self.feat_to_depth  = phi_default(out=1,in_=latent_dim,h=3,h_ch=256)
            self.depth_spreader = phi_default(out=1,in_=2,h=2,h_ch=128)

        # Maps ray features to rgb in (-1, 1)
        self.pixel_generator = nn.Sequential( phi_default(out=3,in_=latent_dim),
                                              nn.Tanh() )

        # Built unconditionally; only used by forward_with_bg when use_gan is set
        self.discriminator = custom_layers.Discriminator(64,64)

        print(self)

    def graph_compositor(self,feats):
        """Blend weights from a dense graph conv over a fully connected slot graph.

        Args:
            feats: tensor laid out as ``(p, b, q, pix, l)``.

        Returns:
            Tensor ``(p, b, q, pix, 1)``: softmax over ``p`` plus ``1e-9``.
        """
        feat_emb = self.feature_embedding(feats)
        nodes = rearrange(feat_emb,"p b q pix l -> (b q pix) p l")
        # All-ones adjacency (every slot connected to every slot), created on
        # the same device as the node features to avoid a CPU/GPU mismatch.
        adj = torch.ones(nodes.size(0),feats.size(0),feats.size(0),
                         device=nodes.device)
        nodes_out = self.graph_conv(nodes,adj)
        attn = rearrange(nodes_out,"(b q pix) p 1 -> p b q pix 1",
                            p=feats.size(0),b=feats.size(1),q=feats.size(2))
        return attn.softmax(0)+1e-9

    def conv_compositor(self,feats):
        """Blend weights from a 1D convolution sliding along the slot axis.

        Args:
            feats: tensor laid out as ``(p, b, q, pix, l)``.

        Returns:
            Tensor ``(p, b, q, pix, 1)``: softmax over ``p`` plus ``1e-9``.
        """
        feat_emb = self.feature_embedding(feats)
        conv_in  = rearrange(feat_emb,"p b q pix l -> (b q pix) l p")
        conv_out = self.composition_conv(conv_in)
        attn     = rearrange(conv_out,"(b q pix) 1 p -> p b q pix 1",
                                    p=feats.size(0),b=feats.size(1),q=feats.size(2))
        # Optional noise: logits are squashed by a sigmoid and jittered by U(0, 1/3)
        if self.rand_noise_seg: attn = torch.rand_like(attn)/3+attn.sigmoid()
        return attn.softmax(0)+1e-9

    def custom_compositor(self,feats):
        """Blend weights from a predicted non-negative per-slot depth.

        Each slot's depth is paired with the per-ray minimum depth across
        slots, negated, and scored by ``self.depth_spreader``.

        Args:
            feats: tensor laid out as ``(p, b, q, pix, l)`` with
                ``p == self.num_phi``.

        Returns:
            Tensor ``(p, b, q, pix, 1)``: softmax over ``p`` plus ``1e-9``.
        """
        depth = self.feat_to_depth(feats).relu()
        nodes = rearrange(depth,"p b q pix 1 -> (b q pix) p 1")
        min_depth = nodes.min(1,keepdim=True)[0]
        paired_depth = torch.cat((nodes,min_depth.expand(-1,self.num_phi,1)),-1)
        attn = rearrange(self.depth_spreader(-paired_depth),
                            "(b q pix) p 1 -> p b q pix 1",
                            p=feats.size(0),b=feats.size(1),q=feats.size(2))
        return attn.softmax(0)+1e-9

    def forward_without_bg(self,input):
        """Render query views with every slot treated as foreground.

        All slots are queried in the context-camera frame and conditioned
        through ``self.hyper_fg``.

        Args:
            input: dict with a ``'query'`` entry providing ``"uv"``,
                ``"cam2world"`` and ``"input_rgb"``; the whole dict is also
                passed to ``util.get_query_cam``.

        Returns:
            dict with ``"rgbs"``, ``"rgb"``, ``"soft_seg"``, ``"slot_attn"``
            and ``"fg_slot_z"`` (all slots).
        """
        # Unpack input
        query = input['query']
        b, n_qry = query["uv"].shape[:2]
        cam2world, query_intrinsics, query_uv = util.get_query_cam(input)

        # Query poses relative to the context (first query) camera, one copy per slot
        world2contextcam = repeat(query["cam2world"][:,0].inverse(),
                                  "b x y -> (b q) x y",q=n_qry)
        query_pose = (world2contextcam @ cam2world).unsqueeze(0).expand(
                                                        self.num_phi,-1,-1,-1)

        coords = geometry.plucker_embedding(query_pose,
                repeat(query_uv,         "b n c -> phi b n c",phi=self.num_phi),
                repeat(query_intrinsics, "b n c -> phi b n c",phi=self.num_phi))
        coords.requires_grad_(True)

        # Map first query image into slots
        img_sl = int(query["input_rgb"].size(-2)**(1/2))
        context_rgb = query["input_rgb"][:,0].permute(0,2,1).unflatten(-1,(img_sl,img_sl))

        img_feats = self.img_encoder(context_rgb)

        img_feats = rearrange(img_feats, "b c x y -> b (x y) c")
        if not self.unet:
            img_feats = self.img_feat_downscale(img_feats)
            img_feats = torch.cat((query["uv"].squeeze(1),img_feats),-1)

        slot_z, slot_attn = self.slot_encoder.forward_no_bg( img_feats )
        slot_z = repeat(slot_z,"b p l -> p (b q) l",q=n_qry)

        # Query the lfn with one set of hypernetwork weights per slot
        phi_params = self.hyper_fg(slot_z)
        feats = rearrange( self.phi(coords.flatten(0,1),params=phi_params),
                      "(p b q) pix l -> p b q pix l",p=self.num_phi,b=b,q=n_qry)

        # Composite over phi dimension (first) to yield single scene
        soft_seg = self.compositor(feats)
        rgbs     = self.pixel_generator(feats)
        rgb      = (rgbs*soft_seg).sum(0)

        # Packup for loss fns and summaries
        out_dict = {
            "rgbs":      rgbs,
            "rgb":       rgb,
            "soft_seg":  soft_seg.unsqueeze(2),
            "slot_attn": slot_attn.permute(1,0,2).unsqueeze(2),
            "fg_slot_z": slot_z,
        }

        return out_dict

    def forward_with_bg(self,input):
        """Render query views with slot 0 as a world-space background.

        Args:
            input: dict with a ``'query'`` entry providing ``"uv"``,
                ``"cam2world"`` and ``"input_rgb"``; the whole dict is also
                passed to ``util.get_query_cam``.

        Returns:
            dict with ``"rgbs"``, ``"rgb"``, ``"soft_seg"``, ``"slot_attn"``
            and ``"fg_slot_z"`` (all slots but the first). When
            ``self.use_gan`` is set it also holds ``"disc_out"``, a tuple
            ``(discriminator(fake), discriminator(real))``.
        """
        # Unpack input
        query = input['query']
        b, n_qry = query["uv"].shape[:2]
        cam2world, query_intrinsics, query_uv = util.get_query_cam(input)

        # Make query coords - first (bg) in world, second (fg) in context cam
        world2contextcam = repeat(query["cam2world"][:,0].inverse(),
                                  "b x y -> (b q) x y",q=n_qry)
        query_pose_fg    = world2contextcam @ cam2world
        query_pose_bg    =                    cam2world
        query_pose = torch.stack([query_pose_bg]+
                                 [query_pose_fg for _ in range(self.num_phi-1)])

        coords = geometry.plucker_embedding(query_pose,
                repeat(query_uv,         "b n c -> phi b n c",phi=self.num_phi),
                repeat(query_intrinsics, "b n c -> phi b n c",phi=self.num_phi))
        coords.requires_grad_(True)

        # Map first query image into slots
        img_sl = int(query["input_rgb"].size(-2)**(1/2))
        context_rgb = query["input_rgb"][:,0].permute(0,2,1).unflatten(-1,(img_sl,img_sl))
        img_feats = self.img_encoder(context_rgb)
        img_feats = rearrange(img_feats, "b c x y -> b (x y) c")
        # Same encoder-dependent post-processing as forward_without_bg; with
        # self.unet fixed to True this branch is currently never taken.
        if not self.unet:
            img_feats = self.img_feat_downscale(img_feats)
            img_feats = torch.cat((query["uv"].squeeze(1),img_feats),-1)
        slot_z, slot_attn = self.slot_encoder( img_feats )

        slot_z = repeat(slot_z,"b p l -> p (b q) l",q=n_qry)

        # Make and query lfn - stacking weights of bg and fg lfn
        param_bg = self.hyper_bg(slot_z[:1,...,:self.bg_latent])
        param_fg = self.hyper_fg(slot_z[1:])
        phi_params=OrderedDict()
        for k in param_bg.keys(): phi_params[k]=torch.cat([param_bg[k],param_fg[k]])

        feats = rearrange( self.phi(coords.flatten(0,1),params=phi_params),
                      "(p b q) pix l -> p b q pix l",p=self.num_phi,b=b,q=n_qry)

        # Composite over phi dimension (first) to yield single scene
        soft_seg = self.compositor(feats)
        rgbs     = self.pixel_generator(feats)
        rgb      = (rgbs*soft_seg).sum(0)

        # Packup for loss fns and summaries
        out_dict = {
            "rgbs":      rgbs,
            "rgb":       rgb,
            "soft_seg":  soft_seg.unsqueeze(2),
            "slot_attn": slot_attn.permute(1,0,2).unsqueeze(2),
            "fg_slot_z": slot_z[1:],
        }

        if self.use_gan:

            img_real = query["input_rgb"].flatten(0,1).permute(0,2,1).unflatten(-1,(img_sl,img_sl))
            rgb_sl = int(rgb.size(-2)**(1/2))
            img_fake = rgb.flatten(0,1).permute(0,2,1).unflatten(-1,(rgb_sl,rgb_sl))
            if rgb_sl!=img_sl: img_fake = F.interpolate( img_fake, (img_sl,img_sl) )

            # Evaluate eagerly into a tuple (fake, real): a bare generator
            # would run the discriminator lazily and could be consumed only once.
            out_dict["disc_out"] = tuple(self.discriminator(x)
                                         for x in (img_fake,img_real))

        return out_dict
