import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from Network.network import Network
from Network.net_types import network_type
from Network.network_utils import pytorch_model, get_acti
from Network.Dists.mask_utils import expand_mask, apply_symmetric
from Network.Dists.forward_mask import DiagGaussianForwardPadMaskNetwork
import copy, time


class DiagGaussianForwardMultiMaskNetwork(Network):
    '''
    Multi-mask diagonal Gaussian forward model for EM based algorithms.

    Holds args.multi.num_masks + 1 DiagGaussianForwardPadMaskNetwork forward
    models and, when args.multi.use_embedding is set, one shared key-query
    embedding network that is applied before whichever forward model is selected.
    self.index selects the forward model that forward() runs (see set_index,
    set_full and reset_index); self.forward_models holds the forward models by
    index so each one can be optimized separately.
    '''
    def __init__(self, args):
        '''
        @param args: network arguments. This constructor reads args.factor,
            args.multi.num_masks and args.multi.use_embedding, and passes args on
            to every DiagGaussianForwardPadMaskNetwork it builds. When the shared
            embedding is used, args.pre_embed is set to True in place before the
            forward models are constructed.
        '''
        super().__init__(args)
        self.index = 0
        # one network per mask, plus a final one used for the full mask (see set_full)
        self.num_networks = args.multi.num_masks + 1

        self.fp = args.factor
        self.use_embedding = args.multi.use_embedding

        layers = []
        if self.use_embedding:
            # we only have a key_query encoder if we are pre-embedding, otherwise it is handled in forward
            # NOTE(review): init_key_query is not imported anywhere in the visible
            # part of this file, so this branch raises NameError unless the name is
            # provided elsewhere -- confirm which module defines it and import it.
            self.key_args, self.query_args, self.key_query_encoder = \
                init_key_query(args)
            args.pre_embed = True
            layers = [self.key_query_encoder]

        self.forward_models = nn.ModuleList(
            [DiagGaussianForwardPadMaskNetwork(args) for _ in range(self.num_networks)]
        )

        self.model = layers + [self.forward_models]

        self.train()
        self.reset_network_parameters()

    def reset_environment(self, factor_params):
        '''
        Replaces the stored factor parameters and propagates them to the shared
        key-query encoder (if one was built) and to every forward model.

        @param factor_params: the new factor parameters, stored as self.fp
        '''
        self.fp = factor_params
        if hasattr(self, "key_query_encoder"): self.key_query_encoder.reset_environment(factor_params)
        for model in self.forward_models:
            # The previous call passed class_index, num_objects and submodel_first,
            # none of which exist in this scope, so it always raised NameError.
            # NOTE(review): assumes the forward models accept factor_params the same
            # way the key-query encoder does -- confirm against
            # DiagGaussianForwardPadMaskNetwork.reset_environment.
            model.reset_environment(factor_params)
        # NOTE(review): self.total_instances used to be assigned from the undefined
        # name num_objects here (a line that could never execute). If that attribute
        # is still needed, derive it from factor_params once its field is confirmed.

    def set_index(self, idx):
        '''Selects the forward model at index idx for subsequent forward() calls.'''
        self.index = idx

    def set_full(self):
        '''Selects the last forward model, which is the one for the full mask.'''
        self.index = self.num_networks - 1 # the last network is the full mask

    def reset_index(self, idx, optimizer_args, embedding_optimizer=False):
        '''
        Selects the forward model at idx, re-initializes its parameters and
        returns a fresh Adam optimizer over them.

        @param idx: index of the forward model to select and reset
        @param optimizer_args: object providing lr, eps, betas and weight_decay
        @param embedding_optimizer: when True, the embedding network's parameters
            are optimized together with the selected forward model's
        @return the newly constructed torch.optim.Adam optimizer
        @raises AttributeError if embedding_optimizer is True but this network has
            neither an embedding nor a key_query_encoder attribute
        '''
        self.index = idx
        selected = self.forward_models[self.index]
        selected.reset_network_parameters()
        # parameters() returns a generator, which does not support +=, so the
        # parameters are collected into a list before any extension.
        params = list(selected.parameters())
        if embedding_optimizer:
            # Prefer self.embedding (the attribute this method originally read) and
            # fall back to the key-query encoder, the only embedding this class builds.
            embedding = getattr(self, "embedding", None)
            if embedding is None:
                embedding = getattr(self, "key_query_encoder", None)
            if embedding is None:
                raise AttributeError(
                    "embedding_optimizer=True requires an embedding network, "
                    "but neither 'embedding' nor 'key_query_encoder' is set")
            params.extend(embedding.parameters())
        return optim.Adam(params,
                lr=optimizer_args.lr, eps=optimizer_args.eps, betas=optimizer_args.betas, weight_decay=optimizer_args.weight_decay)

    def get_queries(self, x):
        '''
        Slices the trailing part of the last axis of x and reshapes it to
        (-1, num_obj, self.object_dim).

        The slice starts at 0 when self.symmetric_key_query is truthy and at
        max(0, self.first_obj_dim) otherwise; num_obj is the remaining width
        floor-divided by self.object_dim.
        NOTE(review): symmetric_key_query, first_obj_dim and object_dim are not
        assigned in this class -- presumably set by the Network base class; confirm.

        @param x: array or tensor whose last axis holds the flattened values
        @return the reshaped queries
        '''
        first_dim = 0 if self.symmetric_key_query else max(0, self.first_obj_dim)
        num_obj = int((x.shape[-1] - first_dim) // self.object_dim)
        return x[..., first_dim:].reshape(-1, num_obj, self.object_dim)

    def forward(self, x, m=None, valid=None, dist_settings=None, ret_settings=None):
        '''
        Runs the currently selected forward model on x.

        x is wrapped with pytorch_model.wrap, passed through the shared key-query
        encoder when self.use_embedding is set, and then given to
        self.forward_models[self.index] together with the keyword arguments.

        @param x: network input
        @param m, valid, dist_settings, ret_settings: forwarded unchanged to the
            selected forward model
        @return (meanvar, m): the first two of the three values returned by the
            selected forward model; the third is discarded
        '''
        # keyword hyperparameters are used only for consistency with the mixture of experts model
        x = pytorch_model.wrap(x, cuda=self.iscuda)
        if self.use_embedding: x = self.key_query_encoder(x)
        meanvar, m, _ = self.forward_models[self.index](x, m=m, valid=valid, dist_settings=dist_settings, ret_settings=ret_settings)
        return meanvar, m
