import torch
import torch.nn as nn
from Network.network import Network
from Network.net_types import network_type
from Network.Dists.net_dist_utils import init_key_query, init_select_args, init_inter_args
from Network.network_utils import assign_distribution, pytorch_model
from Network.Dists.mask_utils import get_passive_mask, get_active_mask

class InteractionSelectionMaskNetwork(Network):
    """Blends interaction masks using a per-key selection over clusters.

    The selection network produces, for every key, a weighting over
    ``num_clusters`` candidate masks. The candidates are stacked in this
    order: the mask from ``get_passive_mask``, the mask from
    ``get_active_mask``, then the outputs of the learned interaction models.
    Two cluster slots are therefore reserved for the fixed masks and
    ``num_clusters - 2`` interaction models are learned.
    """

    def __init__(self, args):
        """Build the selection network, the learned inter models and the selection distributions.

        Reads ``args.cluster.num_clusters``, ``args.cluster.cluster_inter_hidden``,
        ``args.net_inter`` and ``args.factor``; the remaining sub-network
        arguments are derived by the ``init_*`` helpers.
        """
        super().__init__(args)
        self.num_clusters = args.cluster.num_clusters
        # two clusters are reserved for the fixed masks, so at least one learned model is required
        assert self.num_clusters >= 3, "num_clusters must be at least 3: two reserved masks plus one learned cluster"
        self.ip = args.net_inter
        self.fp = args.factor
        self.key_args, self.query_args, self.key_query_encoder = \
            init_key_query(args) # even if we don't embed, init a key_query that just slices
        select_args = init_select_args(args)
        self.selection_network = network_type[select_args.net_type](select_args)

        self.soft_inter_dist = assign_distribution("RelaxedHot")
        self.hard_inter_dist = assign_distribution("CategoricalHot")

        # inter models must operate by pointnet principles to be instance invariant, keypair, keyembed, mask_attn, raw_attn
        inter_args = init_inter_args(args, use_cluster = False)
        inter_args.hidden_sizes = args.cluster.cluster_inter_hidden
        # two clusters reserved, one for passive and one for full, so only
        # num_clusters - 2 models are learned; forward() stacks the two fixed
        # masks in front of them to reach num_clusters entries in total
        num_learned = self.num_clusters - 2
        self.inter_models = nn.ModuleList([network_type[inter_args.net_type](inter_args) for _ in range(num_learned)])

        self.softmax = nn.Softmax(dim=-1)
        self.model = [self.selection_network, self.inter_models]

        self.train()
        self.reset_network_parameters()

    def reset_environment(self, factor_params):
        """Store new factor parameters and forward them to every sub-network that accepts them."""
        self.fp = factor_params
        self.key_query_encoder.reset_environment(factor_params)
        if hasattr(self.selection_network, "reset_environment"):
            self.selection_network.reset_environment(factor_params)
        for im in self.inter_models:
            if hasattr(im, "reset_environment"):
                im.reset_environment(factor_params)

    def forward(self, x, valid=None, hard=False, ret_settings=None): # TODO: implement return settings
        """Compute the selection-weighted combination of the candidate masks.

        Args:
            x: input batch; wrapped into a tensor via ``pytorch_model.wrap``.
            valid: passed through unchanged to every learned inter model.
            hard: if True the selection is sampled from the hard
                ("CategoricalHot") distribution, otherwise a relaxed sample is
                drawn from the soft ("RelaxedHot") distribution.
            ret_settings: passed through to the inter models; not otherwise
                handled here (see TODO).

        Returns:
            A tuple ``(v, [s])`` where ``v`` is the blended mask flattened to
            ``(batch, -1)`` and ``s`` is the sampled selection of shape
            ``(batch, num_keys, num_clusters)``.
        """
        # NOTE(review): self.iscuda and self.selection_temperature are not set in this class — presumably provided by Network; confirm
        x = pytorch_model.wrap(x, cuda=self.iscuda)
        # NOTE(review): the key/query outputs are not used below; the call is kept in case the encoder is relied on elsewhere — confirm and remove
        self.key_query_encoder(x)
        s = self.selection_network.forward(x)
        # reshape BEFORE the softmax so the probabilities are normalized over
        # the clusters of each key rather than over the flattened key x cluster axis
        s = s.reshape(s.shape[0], -1, self.num_clusters) # (batch, num_keys, num_clusters)
        s = self.softmax(s)
        if hard:
            s = self.hard_inter_dist(s).sample()
        else:
            s = self.soft_inter_dist(self.selection_temperature, probs=s).rsample()
        # TODO: does not handle ret_settings logic properly
        inters = [im.forward(x, valid=valid, ret_settings=ret_settings) for im in self.inter_models]
        # relies on at least one learned model existing (num_clusters >= 3) to read the last dimension
        batch_size, nk, nq = s.shape[0], s.shape[1], inters[0].shape[-1]
        fixed_masks = [get_passive_mask(batch_size, nk, nq, self.fp), get_active_mask(batch_size, nk, nq, self.fp)]
        # stacking on dim 2 yields 2 + (num_clusters - 2) = num_clusters candidates per key
        inters = torch.stack(fixed_masks + inters, dim=2)

        # weights the masks by the selection criteria
        inters = inters * s.unsqueeze(-1)
        v = inters.sum(dim=2).reshape(x.shape[0], -1)
        return v, [s]