import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from Network.network import Network
from Network.net_types import network_type
from Network.network_utils import pytorch_model, get_acti, assign_distribution
from Network.Dists.mask_utils import expand_mask, get_hot_mask, get_active_mask, get_passive_mask, apply_symmetric
from Network.Dists.net_dist_utils import init_key_query, init_inter_args
import copy, time

class InteractionMaskNetwork(Network):
    """Network that scores interactions between key and query factors.

    The input is split into keys and queries by a key/query encoder (built by
    ``init_key_query``). An interaction sub-network, chosen from
    ``network_type`` by ``inter_args.net_type``, then scores the pairs. Its
    output is reshaped to ``(-1, n_keys, n_queries)`` and squashed either with
    a softmax over the last (query) axis or with an element-wise sigmoid.
    """

    def __init__(self, args):
        super().__init__(args)
        self.fp = args.factor
        # Even if we don't embed, init_key_query builds a key/query encoder
        # that just slices the input into keys and queries.
        self.key_args, self.query_args, self.key_query_encoder = \
            init_key_query(args)
        # Write the encoder's key/query dims back onto args.factor.
        # Presumably init_inter_args (called next) reads them; confirm.
        args.factor.key_dim, args.factor.query_dim = \
            self.key_query_encoder.key_dim, self.key_query_encoder.query_dim
        inter_args = init_inter_args(args)
        self.inter = network_type[inter_args.net_type](inter_args)
        # Final activation choice: softmax over the last axis when
        # activation_final == 'softmax', otherwise sigmoid (see forward).
        self.use_softmax = args.activation_final == 'softmax'
        self.softmax = nn.Softmax(dim=-1)
        self.model = [self.inter]
        # When True, forward skips the key/query encoder and unpacks its
        # input directly into (keys, queries).
        self.pre_embed = args.inter_net.shared_encoding

        self.train()
        self.reset_network_parameters()

    def reset_environment(self, factor_params):
        """Swap in new factor parameters.

        The parameters are stored on ``self.fp`` and propagated to the
        key/query encoder. They are also passed to the interaction network
        when it exposes its own ``reset_environment``.
        """
        self.fp = factor_params
        self.key_query_encoder.reset_environment(factor_params)
        if hasattr(self.inter, "reset_environment"):
            self.inter.reset_environment(factor_params)

    def forward(self, x, valid=None, ret_settings=None, grad_settings=None):
        """Compute the interaction mask between keys and queries.

        Args:
            x: input, wrapped into a tensor on this network's device via
                ``pytorch_model.wrap``. When ``self.pre_embed`` is set, it is
                unpacked directly as ``(keys, queries)``.
            valid: optional validity mask. It is wrapped, sliced by the
                key/query encoder and passed to the interaction network as
                ``mask``.
            ret_settings: forwarded unchanged to the interaction network.
                TODO: implement return settings like returning embeddings.
            grad_settings: optional container of flags. ``"input"`` makes the
                wrapped input require grad. ``"embed"`` makes keys and queries
                require grad. ``None`` (the default) means no flags.

        Returns:
            ``(mask, (x, keys, queries, extras))``.
            ``mask`` is the interaction network's first output, reshaped to
            ``(-1, n_keys, n_queries)`` and passed through softmax or
            sigmoid. ``extras`` holds the rest of the interaction network's
            output (everything after its first element).
        """
        # Use None instead of a shared mutable [] default. Any container that
        # supports ``in`` still works for callers that pass one explicitly.
        grad_settings = () if grad_settings is None else grad_settings

        x = pytorch_model.wrap(x, cuda=self.iscuda)
        if "input" in grad_settings:
            x.requires_grad = True

        # NOTE(review): in pre_embed mode x is unpacked into (keys, queries),
        # yet x.shape[0] is still read below. Confirm the caller contract for
        # this mode.
        keys, queries = x if self.pre_embed else self.key_query_encoder(x)
        if "embed" in grad_settings:
            # NOTE(review): PyTorch only lets you change requires_grad on leaf
            # tensors. If the encoder output is a non-leaf tensor this may
            # raise; confirm.
            keys.requires_grad, queries.requires_grad = True, True

        if valid is not None:
            valid = pytorch_model.wrap(valid, cuda=self.iscuda)
        n_keys, n_queries = keys.shape[1], queries.shape[1]
        validv = self.key_query_encoder.slice_masks(
            valid, x.shape[0], n_keys, n_queries)

        v = self.inter(keys, queries, mask=validv, ret_settings=ret_settings)
        vv = v[0].reshape(-1, n_keys, n_queries)
        # torch.sigmoid is the canonical spelling; F.sigmoid is a legacy alias.
        vv = self.softmax(vv) if self.use_softmax else torch.sigmoid(vv)
        return vv, (x, keys, queries, v[1:])


