import torch
import torch.nn as nn

from .base import Preprocessor
from architectures.pointnet import PointNet2SetEncoder
from architectures.main_architectures import MLP, CNN


class ConCat(Preprocessor):
    """Encode each observation modality separately and concatenate the codes.

    One encoder per enabled modality is stored in ``self.phi`` (an
    ``nn.ModuleDict`` keyed by modality name):

    * ``'state'``      -> ``MLP(state_dim, z_dim, ...)`` (not built when ``no_state``)
    * ``'image'``      -> ``CNN(3*3, z_dim, ...)``
    * ``'depth'``      -> ``CNN(1*3, z_dim, ...)``
    * ``'pointcloud'`` -> ``PointNet2SetEncoder(in_feat_dim=3, z_dim=z_dim, n_frames=3)``

    ``get_representation`` runs every available encoder on its input and
    concatenates the outputs along the last dimension, in sorted
    modality-name order.
    """

    def __init__(self, no_state, state_dim, img_dim, z_dim, action_dim, modalities, configs_rl, device, n_mods=2):
        """Build the per-modality encoders.

        Args:
            no_state: If truthy, no state encoder is built and the state code
                is excluded from the total output size.
            state_dim: Input size of the state MLP; also forwarded to
                ``Preprocessor.__init__``.
            img_dim: Forwarded to ``Preprocessor.__init__``.
            z_dim: Output size of each individual encoder.
            action_dim: Not used by this class; kept for interface compatibility.
            modalities: ``'state'``, ``'image'``, ``'depth'``, ``'pointcloud'``
                or ``'all'``; selects which encoders are built.
            configs_rl: Config mapping; ``configs_rl['architecture']`` is passed
                to the MLP/CNN constructors.
            device: Forwarded to ``Preprocessor.__init__``.
            n_mods: Number of non-state modality codes assumed to be present in
                the output when ``modalities == 'all'``; only used to compute
                the total output size passed to the base class.
        """
        if modalities == 'all':
            # NOTE(review): 'all' builds up to three non-state encoders
            # (image, depth, pointcloud), but the total size assumes only
            # ``n_mods`` of them are populated per observation -- confirm this
            # matches what the environment actually provides.
            final_z_dim = z_dim * (n_mods if no_state else n_mods + 1)
        else:
            final_z_dim = z_dim
        super().__init__(state_dim, img_dim, final_z_dim, modalities, device)

        phi = {}
        if modalities in ['state', 'all'] and not no_state:
            phi['state'] = MLP(state_dim, z_dim, configs_rl['architecture'])
        if modalities in ['image', 'all']:
            # presumably 3 colour channels x 3 stacked frames -- TODO confirm
            phi['image'] = CNN(3*3, z_dim, configs_rl['architecture'])
        if modalities in ['depth', 'all']:
            # presumably 1 depth channel x 3 stacked frames -- TODO confirm
            phi['depth'] = CNN(1*3, z_dim, configs_rl['architecture'])
        if modalities in ['pointcloud', 'all']:
            phi['pointcloud'] = PointNet2SetEncoder(in_feat_dim=3, z_dim=z_dim, n_frames=3)

        # ModuleDict registers the encoders as submodules so their parameters
        # are tracked by the parent module.
        self.phi = nn.ModuleDict(phi)

    def get_representation(self, obs, past_state_action=None, phase="collect"):
        """Encode every available modality in ``obs`` and concatenate the codes.

        Args:
            obs: Mapping from modality name to that modality's input. An entry
                is skipped when its key has no encoder in ``self.phi``, when
                its value is ``None``, or when its value is a dict whose
                ``'pc'`` entry is ``None``.
            past_state_action: Not used by this class; kept for interface
                compatibility.
            phase: Not used by this class; kept for interface compatibility.

        Returns:
            The encoder outputs concatenated along the last dimension, ordered
            by sorted modality name.

        Raises:
            RuntimeError: If no entry of ``obs`` could be encoded.
        """
        z_dict = {}
        for mode, x in obs.items():
            # Check encoder membership first so unrelated dict-valued entries
            # (e.g. ones without a 'pc' key) cannot trigger a KeyError.
            if mode not in self.phi or x is None:
                continue
            # Dict-valued inputs carry their payload under 'pc'; a None payload
            # means the modality is unavailable for this observation.
            if isinstance(x, dict) and x['pc'] is None:
                continue
            z_dict[mode] = self.phi[mode](x)

        if not z_dict:
            # torch.cat on an empty list fails with an opaque message; say why.
            raise RuntimeError(
                "ConCat.get_representation: no encodable modality in obs "
                f"(obs keys: {list(obs)}, encoders: {list(self.phi.keys())})"
            )
        # Sorted order keeps the concatenation layout stable regardless of
        # the insertion order of ``obs``.
        return torch.cat([z_dict[m] for m in sorted(z_dict)], -1)











