import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from Network.network import Network
from Network.net_types import network_type
from Network.network_utils import pytorch_model, get_acti
import copy, time
from Network.Dists.mask_utils import expand_mask, apply_probabilistic_mask, count_keys_queries, get_hot_mask, get_active_mask, get_passive_mask, apply_symmetric, MASK_ATTENTION_TYPES
from Network.Dists.net_dist_utils import init_key_query, init_forward_args

class DiagGaussianForwardPadHotNetwork(Network):
    """Diagonal-Gaussian forward model whose per-key output is a weighted mixture over clusters.

    Each cluster ``c`` has its own interaction mask and its own mean/std forward network:

    * cluster 0 uses the passive mask (``get_passive_mask``),
    * cluster 1 uses the active/full mask (``get_active_mask``),
    * clusters ``2 .. num_clusters - 1`` use masks produced by learned inter models.

    The per-key cluster weights ``m`` (presumably one-hot selections, given the class
    name -- confirm) combine both the masks and the forward-network outputs.
    """

    def __init__(self, args):
        super().__init__(args)

        self.fp = args.factor
        # needs to have: passive_mask, first_obj_dim (if pair), object_dim, class_index, num_objects
        self.dist = args.dist_params
        self.cluster = args.cluster_params
        # NOTE(review): self.num_clusters (used below) and self.op (used by
        # compute_cluster_masks) are not assigned in this class; presumably the
        # Network base class or a caller provides them -- confirm.

        # even if we don't embed, init a key_query that just slices
        self.key_args, self.query_args, self.key_query_encoder = init_key_query(args)

        # inter models must operate by pointnet principles to be instance invariant.
        # Two clusters are reserved (one passive, one full), so only num_clusters - 2 are learned.
        inter_args = copy.deepcopy(args.cluster_params.inter_args)
        self.inter_models = nn.ModuleList(
            [network_type[inter_args.net_type](inter_args) for _ in range(self.num_clusters - 2)]
        )

        # forward networks, must be a factor or pairnet;
        # network_type must be a network that handles keys, values and valid.
        # One mean head and one std head per cluster.
        forward_args = init_forward_args(args)
        self.means = nn.ModuleList([network_type[args.net_type](forward_args) for _ in range(self.num_clusters)])
        self.stds = nn.ModuleList([network_type[args.net_type](forward_args) for _ in range(self.num_clusters)])

        # NOTE(review): key_query_encoder is not listed in self.model; if
        # reset_network_parameters iterates self.model it will skip the encoder -- confirm intended.
        self.model = [self.means, self.stds, self.inter_models]
        self.test = False
        self.base_variance = .01  # hardcoded based on normalized values, base variance 1% of the average variance
        self.mask_dim = args.factor.total_instances  # does not handle arbitrary number of instances

        self.train()
        self.reset_network_parameters()

    def reset_environment(self, factor_params):
        """Store new factor parameters and push them to every submodule that supports it.

        Args:
            factor_params: replacement for ``self.fp``; forwarded unchanged to the
                key/query encoder and to any inter/mean/std network that defines
                ``reset_environment``.
        """
        self.fp = factor_params
        self.key_query_encoder.reset_environment(factor_params)
        # Check modules individually: indexing [0] would raise IndexError when
        # there are no learned inter models (num_clusters == 2).
        for module in (*self.inter_models, *self.means, *self.stds):
            if hasattr(module, "reset_environment"):
                module.reset_environment(factor_params)

    def get_inter_mask(self, i, x, num_keys, soft, mixed, flat):
        """Return the interaction mask produced by learned inter model ``i`` (cluster ``i + 2``).

        Args:
            i: index into ``self.inter_models``.
            x: network input passed directly to the inter model.
            num_keys: unused; kept for interface compatibility.
            soft: use the relaxed distribution unless ``mixed`` is also set.
            mixed: passed to ``apply_probabilistic_mask``; forces ``inter_dist`` to be supplied.
            flat: when true, ``self.test`` is forwarded as ``test``; otherwise ``test=None``.
        """
        inter = self.inter_models[i](x)  # [batch, num_keys, num_queries]
        # (not soft) or (soft and mixed) simplifies to (not soft) or mixed
        pass_inter_dist = (not soft) or mixed
        pass_relaxed_dist = soft and not mixed
        return apply_probabilistic_mask(inter,
                                        inter_dist=self.dist.inter_dist if pass_inter_dist else None,
                                        relaxed_inter_dist=self.dist.relaxed_inter_dist if pass_relaxed_dist else None,
                                        mixed=mixed, test=self.test if flat else None,
                                        dist_temperature=self.dist.dist_temperature,
                                        revert_mask=False)

    def compute_cluster_masks(self, x, m, num_keys, num_queries, soft=False, mixed=False, flat=False):
        """Compute every cluster's interaction mask and their weighted combination.

        Args:
            x: network input; ``x.shape[0]`` is the batch size.
            m: cluster weights, reshapeable to ``[batch, num_keys, num_clusters]``.
            num_keys, num_queries: mask dimensions.
            soft, mixed, flat: forwarded to ``get_inter_mask``.

        Returns:
            ``(all_masks, total_mask)`` where ``all_masks`` is
            ``[num_clusters, batch, num_keys, num_queries]`` and ``total_mask`` is the
            cluster-weighted sum reshaped to ``[batch, num_keys * num_queries]``.
        """
        batch_size = x.shape[0]
        passive_masks = [get_passive_mask(batch_size, num_keys, num_queries, self.op)]  # broadcast to batch size
        active_masks = [get_active_mask(batch_size, num_keys, num_queries, self.op)]
        learned_masks = [self.get_inter_mask(i, x, num_keys, soft, mixed, flat) for i in range(self.num_clusters - 2)]
        all_masks = torch.stack(passive_masks + active_masks + learned_masks, dim=0)
        # [batch, num_keys, num_clusters] -> [num_clusters, batch, num_keys, 1]:
        # clusters to the front, trailing axis broadcasts over queries
        weights = m.reshape(batch_size, num_keys, self.num_clusters).permute(2, 0, 1).unsqueeze(-1)
        return all_masks, (all_masks * weights).sum(0).reshape(batch_size, -1)

    def compute_clusters(self, cluster_nets, keys, queries, masks, m, valid=None, ret_settings=None):
        """Run each cluster's forward network and mix the outputs with the cluster weights.

        Args:
            cluster_nets: one network per cluster (``self.means`` or ``self.stds``).
            keys: key tensor; ``keys.shape[0]`` is the batch and ``keys.shape[-1]`` the
                number of keys (assumed ``[batch, embed_dim, num_keys]`` -- confirm).
            queries: query tensor, passed through to the cluster networks.
            masks: per-cluster masks indexed by cluster first
                (``[num_clusters, batch, num_keys, num_queries]`` from ``compute_cluster_masks``).
            m: cluster weights, either ``[batch, num_keys * num_clusters]`` or
                ``[batch, num_keys, num_clusters]``; both are reshaped internally.
            valid, ret_settings: forwarded to every cluster network.

        Returns:
            Tensor of shape ``[batch, num_keys * out_dim]``.
        """
        ret_settings = [] if ret_settings is None else ret_settings
        batch_size, num_keys = keys.shape[0], keys.shape[-1]
        outputs = torch.stack(
            [cluster_nets[i](keys, queries, masks[i], valid=valid, ret_settings=ret_settings)
             .reshape(batch_size, num_keys, -1)
             for i in range(self.num_clusters)],
            dim=-1,
        )  # [batch, num_keys, out_dim, num_clusters]
        # Explicit reshape so a flat [batch, num_keys * num_clusters] weight tensor
        # (as passed by forward) broadcasts over out_dim; cluster index is the fastest
        # axis, matching compute_cluster_masks.
        weights = m.reshape(batch_size, num_keys, 1, self.num_clusters)
        return (outputs * weights).sum(-1).reshape(batch_size, -1)  # [batch size, n_keys * out_dim]

    def forward(self, x, m=None, valid=None, dist_settings=None, ret_settings=None):
        """Predict a diagonal Gaussian for every key as a cluster-weighted mixture.

        Args:
            x: batch size, single_obj_dim * num_keys (first_obj_dim) + object_dim * num_queries.
            m: cluster weights, batch size x (num_keys * num_clusters). Required.
            valid: forwarded to the mean/std cluster networks.
            dist_settings: collection that may contain ``'soft'``, ``'mixed'``, ``'flat'``.
            ret_settings: forwarded to the mean/std cluster networks.

        Returns:
            ``((mean, var), total_mask, ((), ()))``. ``mean = tanh(...)`` and
            ``var = sigmoid(...) + base_variance``, each ``[batch, num_keys * out_dim]``.
            ``total_mask`` comes from ``compute_cluster_masks``. The last element is an
            empty placeholder: ``compute_clusters`` produces no auxiliary outputs.

        Raises:
            ValueError: if ``m`` is not supplied.
        """
        if m is None:
            raise ValueError("DiagGaussianForwardPadHotNetwork.forward requires cluster weights m")
        dist_settings = [] if dist_settings is None else dist_settings
        ret_settings = [] if ret_settings is None else ret_settings
        x = pytorch_model.wrap(x, cuda=self.iscuda)
        keys, queries, _ = self.key_query_encoder(x)  # [batch size, embed dim, num keys], [batch size, embed dim, num queries]
        inter_masks, total_mask = self.compute_cluster_masks(
            x, m, keys.shape[-1], queries.shape[-1],
            soft='soft' in dist_settings, mixed='mixed' in dist_settings, flat='flat' in dist_settings)
        mean = self.compute_clusters(self.means, keys, queries, inter_masks, m, valid=valid, ret_settings=ret_settings)
        var = self.compute_clusters(self.stds, keys, queries, inter_masks, m, valid=valid, ret_settings=ret_settings)
        # mean/var are single [batch, ...] tensors: indexing [0] / [1:] would split the
        # batch, so apply the output activations to the whole tensors.
        return (torch.tanh(mean), torch.sigmoid(var) + self.base_variance), total_mask, ((), ())