import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from Network.network import Network, network_type
from Network.network_utils import pytorch_model, get_acti
from Network.Dists.net_dist_utils import init_key_query, init_forward_args
import copy, time

# Only handles FLAT inputs with no masking; infers interactions without
# object-based representations (or only pairwise). Used for Granger analysis.
class DiagGaussianForwardNetwork(Network):
    """Forward model that outputs a (mean, variance) pair from two sub-networks.

    Two networks of type ``network_type[args.net_type]`` are built from
    ``init_forward_args(args)``: ``self.mean`` and ``self.std``. Both are
    called on the keys/queries produced by ``self.key_query_encoder`` and are
    always given ``None`` in the mask position.
    """

    def __init__(self, args):
        """Build the key/query encoder and the mean and spread networks.

        Args:
            args: network configuration. This constructor reads ``factor``
                (including ``factor.object_dim`` and ``factor.embed_inputs``),
                ``net_type`` and ``pre_embed``; ``args`` is also handed to
                ``init_key_query`` and ``init_forward_args``.
        """
        super().__init__(args)

        self.fp = args.factor
        self.key_args, self.query_args, self.key_query_encoder = \
            init_key_query(args) # even if we don't embed, init a key_query that just slices
        mean_args = init_forward_args(args)
        # the final squashing (tanh) is applied in forward(), not by the network
        mean_args.activation_final = "none"
        self.mean = network_type[args.net_type](mean_args)
        std_args = init_forward_args(args)
        # the final squashing (sigmoid) is applied in forward(), not by the network
        std_args.activation_final = "none"
        self.std = network_type[args.net_type](std_args)
        self.model = [self.key_query_encoder, self.mean, self.std]

        self.base_variance = .01 # hardcoded based on normalized values, base variance 1% of the average variance

        self.pre_embed = args.pre_embed
        self.object_dim = args.factor.object_dim
        self.embed_dim = args.factor.embed_inputs

        self.train()
        self.reset_network_parameters()

    def reset_environment(self, factor_params):
        """Store new factor parameters and forward them to the sub-networks.

        Args:
            factor_params: replacement for ``self.fp``; passed unchanged to
                every sub-network that exposes ``reset_environment``.
        """
        self.fp = factor_params
        self.key_query_encoder.reset_environment(factor_params)
        # This class never assigns ``inter_models``; only forward the reset when
        # the attribute exists (e.g. set by a base class or subclass) and is
        # non-empty, instead of failing on the attribute lookup or on [0].
        for im in getattr(self, "inter_models", None) or ():
            if hasattr(im, "reset_environment"):
                im.reset_environment(factor_params)
        # Check each network on its own rather than assuming std mirrors mean.
        for net in (self.mean, self.std):
            if hasattr(net, "reset_environment"):
                net.reset_environment(factor_params)

    def forward(self, x, m=None, valid=None, dist_settings=None, ret_settings=None):
        """Compute the squashed mean and variance outputs for ``x``.

        Args:
            x: network input (original note: batch_size, input_dim). When
                ``self.pre_embed`` is set, ``x`` is unpacked directly as
                ``(keys, queries)`` instead of going through the encoder.
            m: mask (original note: batch_size, num_keys, num_queries). It is
                not used in the computation and is returned unchanged.
            valid: unused here.
            dist_settings: unused here (original note: soft, mixed, flat, full).
            ret_settings: forwarded to both sub-networks (original note:
                embedding, reconstruction, weights).

        Returns:
            A 3-tuple ``((meanv, varv), m, (mean_extra, var_extra))`` where
            ``meanv`` is ``tanh`` of the mean network's first output, ``varv``
            is ``sigmoid`` of the std network's first output plus
            ``self.base_variance``, and the extras are the remaining outputs
            (``[1:]``) of each sub-network.
        """
        # keyword hyperparameters are used only for consistency with the mixture of experts model
        x = pytorch_model.wrap(x, cuda=self.iscuda)
        # if pre_embed, assumes that x is already a tuple of key, query
        keys, queries = self.key_query_encoder(x) if not self.pre_embed else x # [batch size, embed dim, num keys], [batch size, embed dim, num queries]
        # no mask is passed to either network: the third positional argument is always None
        mean = self.mean(keys, queries, None, ret_settings=ret_settings)
        var = self.std(keys, queries, None, ret_settings=ret_settings)
        # squash the raw outputs: tanh bounds the mean to (-1, 1); sigmoid plus the
        # base variance keeps the second output strictly above base_variance
        meanv = (torch.tanh(mean[0]))
        varv = (torch.sigmoid(var[0]) + self.base_variance)
        return (meanv, varv), m, (mean[1:], var[1:])
    
class InteractionNetwork(Network):
    """Wraps one network configured with a single sigmoid-activated output."""

    def __init__(self, args):
        """Build the inner network from a modified copy of ``args``.

        Args:
            args: network configuration; ``args.net_type`` selects the class
                from ``network_type``. ``args`` itself is not mutated.
        """
        super().__init__(args)
        # deep-copy so the overrides below do not leak into the caller's args
        inter_args = copy.deepcopy(args)
        inter_args.num_outputs = 1
        inter_args.activation_final = "sigmoid"
        self.inter = network_type[args.net_type](inter_args)
        self.model = [self.inter]

        self.train()
        self.reset_network_parameters()

    def forward(self, x):
        """Wrap ``x`` via ``pytorch_model.wrap`` and return the inner network's output."""
        x = pytorch_model.wrap(x, cuda=self.iscuda)
        # Call the module rather than .forward() so registered hooks run, matching
        # how the sibling network in this file invokes its sub-networks.
        return self.inter(x)