from Model.InferenceModule.inference_module import InferenceModule
from tianshou.data import Batch
import numpy as np
from Model.InferenceModule.module_utils import apply_mask
from Network.network_utils import pytorch_model
import time
from Model.InferenceModule.module_utils import trace_log_probs
from Network.General.Factor.Pair.pair import merge_key_queries

class FullModule(InferenceModule):
    """Interaction-mask inference for a single named object.

    Predicts mask logits with an interaction model, then runs the forward model
    under each requested masking form ('flat', 'soft', 'hard', 'mixed'), producing
    distribution parameters, log probabilities and trace log probabilities.

    When ``args.inter.use_all_as_single`` is set, the "all" models
    (``inter_model`` / ``all_model``) are used and the forward-model outputs are
    indexed down to this module's target with ``_single_index_all``; otherwise the
    single-object models (``single_inter_model`` / ``full_model``) are used.
    """

    # Masking forms that __call__ can evaluate, in evaluation order.
    _MASKING_FORMS = ('flat', 'soft', 'hard', 'mixed')

    def __init__(self, args, extractor, name, dists, single_inter_model, inter_model, full_model, all_model):
        """Select the interaction/forward models and initialize the optimizer.

        Args:
            args: configuration; ``args.inter`` holds the interaction settings and
                ``args.inter.masking`` the masking settings (``masking_form`` etc.).
            extractor: forwarded to ``InferenceModule.__init__``.
            name: object name used to select targets, trace indices and omit flags.
            dists: distribution helper; ``dists.forward`` is cached as ``forward_dist``.
            single_inter_model, full_model: models used when ``use_all_as_single`` is off.
            inter_model, all_model: models used when ``use_all_as_single`` is on.
        """
        super().__init__(args, extractor)
        self.mp = args.inter
        self.mask_args = args.inter.masking
        self.dist_settings = [self.mask_args.masking_form]
        self.name = name
        if self.mp.use_all_as_single:
            self.inter_model = inter_model
            self.model = all_model
        else:
            self.inter_model = single_inter_model
            self.model = full_model

        self.dists = dists
        self.forward_dist = self.dists.forward
        self.init_optimizer(args)

    def __call__(self, batch, valid, extractor, normalizer, additional=None, grad_settings=None, log_batch=None, full=False, keep_invalid=False, keep_all=False, probs=False):
        """Compute interaction masks and masked forward-model outputs for ``batch``.

        Args:
            batch: Batch with at least ``target``, ``obs`` and ``trace`` fields
                (plus any keys listed in ``log_batch``).
            valid: validity values, indexed by the same omit flags as ``batch``
                and multiplied into the mask logits.
            extractor: provides ``get_named_target``, ``get_index`` and ``num_objects``.
            normalizer: not used by this method.
            additional: extra outputs to return (forwarded as ``ret_settings``);
                any masking-form names listed here are evaluated as well.
            grad_settings: forwarded to both models.
            log_batch: batch keys copied unchanged into the result.
            full: use an all-ones mask instead of the interaction model's logits.
            keep_invalid, keep_all: forwarded to ``get_omit``; with ``keep_all``
                the batch itself is not filtered.
            probs: return right after the mask logits are computed.

        Returns:
            Batch with ``utrace``, ``mask_logits`` (plus ``raw_mask_logits`` and
            ``mask_add`` when not ``full``), the ``log_batch`` keys, ``omit_flags``,
            and one sub-Batch per evaluated masking form. The entries of the
            sub-Batch for ``mask_args.masking_form`` are also copied to the top
            level, except keys that name a masking form (those stay sub-Batches).
        """
        # None defaults avoid one list object being shared across calls (these lists
        # are handed to the models, so a callee mutating them would leak state).
        additional = [] if additional is None else additional
        grad_settings = [] if grad_settings is None else grad_settings
        log_batch = [] if log_batch is None else log_batch

        # omit states where the name is not valid, or at dones
        omit_flags = self.get_omit(batch, keep_all=keep_all, keep_invalid=keep_invalid, use_name=self.name)
        if not keep_all:
            batch = batch[omit_flags]
        if self.mp.use_all_as_single:
            key_state = batch.target
        else:  # note that single passive doesn't use the key state. shouldn't hurt to pass it in
            key_state = extractor.get_named_target(batch.target, names=self.name)
        query_state = batch.obs
        # NOTE(review): valid is filtered even when keep_all leaves batch unfiltered;
        # presumably get_omit keeps every row in that case — confirm.
        valid = valid[omit_flags]
        key_query_state = np.concatenate([key_state, query_state], axis=-1)

        result = Batch()
        # NOTE(review): this uses self.extractor with a list, while trace_log_probs below
        # uses the extractor argument with a bare name — confirm both select the same index.
        result.utrace = batch.trace[:, self.extractor.get_index([self.name])]
        if full:
            result.mask = pytorch_model.wrap(np.ones(extractor.num_objects), cuda=self.iscuda)
            result.mask_logits = pytorch_model.wrap(np.ones(extractor.num_objects), cuda=self.iscuda)
        else:
            result.raw_mask_logits, info = self.inter_model(key_query_state, valid=valid, ret_settings=additional, grad_settings=grad_settings)
            # scale the logits by validity (rows with valid == 0 become 0)
            result.mask_logits = result.raw_mask_logits * pytorch_model.wrap(valid, cuda=self.iscuda).unsqueeze(1)
            result.mask_add = Batch()
            result.mask_add.mask_input, keys, queries, info = info
            for i, aname in enumerate(additional):
                result.mask_add[aname] = info[i]

        for k in log_batch:
            result[k] = batch[k]
        result.omit_flags = omit_flags

        if probs:  # only the mask probabilities are needed; skip the forward model
            return result

        # Run the model for each requested masking form, reusing the same logits.
        # The outputs of self.mask_args.masking_form are additionally exposed at the top level.
        masking_form = self.mask_args.masking_form
        for masking in self._MASKING_FORMS:
            if masking not in additional and masking != masking_form:
                continue
            result[masking] = Batch()
            out = result[masking]
            mask = apply_mask(self.mask_args, self.dists, result.mask_logits, soft=masking in ('soft', 'mixed'), flat=masking == 'flat', mixed=masking == 'mixed', test=self.test, iscuda=self.iscuda)
            out.params, out.mask, info = self.model(key_query_state, m=mask, valid=valid, dist_settings=self.dist_settings, ret_settings=additional, grad_settings=grad_settings)
            out.full_active_input, keys, queries, info1, info2 = info
            info = list(zip(info1, info2))
            if self.mp.use_all_as_single:
                # assumes that values are of shape [batch, num_keys, ...], so index the appropriate key
                out.params, out.mask, info = self._single_index_all(self.target, extractor, out.params, out.mask, info)
            out.target, out.log_probs = self._target_dists(batch, out.params)
            out.trace_log_probs = trace_log_probs(extractor.num_objects, out.log_probs, batch, idx=extractor.get_index(self.name))
            if "masked_pre_embeddings" in additional:  # TODO: not implemented for all
                out["masked_pre_embeddings"] = merge_key_queries(keys, queries, mask, append_keys=False, append_broadcast_mask=self.model.append_broadcast_mask, append_mask=False)
            # note that there will be a bug if info is not a tensor
            for i, aname in enumerate(additional):
                out[aname] = info[i]

            if masking == masking_form:
                # Copy the chosen form's outputs to the top level. Keys that name a
                # masking form (present because they were listed in `additional`) are
                # skipped: copying them would overwrite the per-form sub-Batches,
                # including this one, depending on evaluation order.
                for key in out.keys():
                    if key not in self._MASKING_FORMS:
                        result[key] = out[key]
        return result