from Model.InferenceModule.inference_module import InferenceModule
import numpy as np
from tianshou.data import Batch
from Model.InferenceModule.module_utils import trace_log_probs

class PairModule(InferenceModule):
    """Inference module for one ``source->target`` pair of named objects.

    The pair string has the form ``"src1|src2->target"``: everything before
    ``->`` is split on ``|`` into a list of source names, everything after is
    the target name. Depending on two flags read from ``args.inter`` the
    module wraps one of three models handed to the constructor:

    * ``use_all_as_single``  -> ``all_model`` (takes precedence),
    * ``use_full_as_pair``   -> ``full_model``,
    * neither                -> ``pair_model``.
    """

    def __init__(self, args, extractor, pair, forward_dist, pair_model, full_model, all_model):
        """Store the pair description and select the underlying model.

        Args:
            args: configuration object; ``args.inter`` must expose
                ``use_full_as_pair`` and ``use_all_as_single``. Also passed
                to the base constructor and to ``self.init_optimizer``.
            extractor: forwarded to the base class constructor.
            pair: pair name of the form ``"src1|src2->target"``; must contain
                exactly one ``->`` or the unpacking below raises ValueError.
            forward_dist: stored as ``self.forward_dist``.
            pair_model: model used when neither flag is set.
            full_model: model used when only ``use_full_as_pair`` is set.
            all_model: model used whenever ``use_all_as_single`` is set.
        """
        super().__init__(args, extractor)
        self.mp = args.inter
        self.name = pair
        # "a|b->c" yields source == ["a", "b"] and target == "c"
        self.source, self.target = pair.split('->')
        self.source = self.source.split("|")
        self.use_full_as_pair = self.mp.use_full_as_pair
        self.use_all_as_single = self.mp.use_all_as_single
        # use_all_as_single wins over use_full_as_pair when both are set
        if self.use_all_as_single:
            self.model = all_model
        elif self.use_full_as_pair:
            self.model = full_model
        else:
            self.model = pair_model

        self.forward_dist = forward_dist
        self.init_optimizer(args)

    def __call__(self, batch, valid, extractor, normalizer, additional=None, grad_settings=None, log_batch=None, keep_invalid=False, keep_all=False):
        """Run the wrapped model on ``batch`` and package the outputs.

        Args:
            batch: indexable batch exposing ``obs`` and ``target`` (and every
                key listed in ``log_batch``).
            valid: per-sample validity values, indexed with the same omit
                flags as ``batch`` before being passed to the model.
            extractor: provides ``num_objects``, ``get_index``,
                ``get_named_target`` and ``get_named_obs``.
            normalizer: not used by this method.
            additional: names of extra model outputs; forwarded to the model
                as ``ret_settings`` and used as keys for the extra entries
                stored on the result. Defaults to an empty list.
            grad_settings: forwarded to the model unchanged. Defaults to an
                empty list.
            log_batch: keys copied verbatim from the (filtered) batch into
                the result. Defaults to an empty list.
            keep_invalid: forwarded to ``self.get_omit``.
            keep_all: forwarded to ``self.get_omit``.

        Returns:
            A ``Batch`` with ``target``, ``params``, ``mask``, ``log_probs``,
            ``omit_flags``, ``pairwise_input`` and ``trace_log_probs``, plus
            one entry per name in ``additional`` and per key in ``log_batch``.
        """
        # Fresh lists per call: list defaults in the signature would be shared
        # between calls, and two of them are handed to the model.
        additional = [] if additional is None else additional
        grad_settings = [] if grad_settings is None else grad_settings
        log_batch = [] if log_batch is None else log_batch

        # TODO: the logic is very similar to single_passive_module
        omit_flags = self.get_omit(batch, keep_all=keep_all, keep_invalid=keep_invalid, use_name=self.target)
        batch = batch[omit_flags]
        valid = valid[omit_flags]

        if self.use_all_as_single:
            key_state = batch.target
        else:  # note that single passive doesn't use the key state. shouldn't hurt to pass it in
            key_state = extractor.get_named_target(batch.target, names=self.target)

        if self.use_full_as_pair or self.use_all_as_single:
            query_state = batch.obs
        else:
            # only the named source object(s), flattened to one row per sample
            query_state = extractor.get_named_obs(batch.obs, names=self.source).reshape(len(batch.obs), -1)

        if self.use_all_as_single:
            mask = np.eye(extractor.num_objects)
        else:
            # vector with ones at the source and target object indices
            mask = np.zeros(extractor.num_objects)
            mask[extractor.get_index(self.source)] = 1
            mask[extractor.get_index(self.target)] = 1

        # run the model to get return values
        model_input = np.concatenate([key_state, query_state], axis=-1)
        params, mask, info = self.model(model_input, m=mask, valid=valid, dist_settings=['mixed'], ret_settings=additional, grad_settings=grad_settings)
        pairwise_input, _keys, _queries, info1, info2 = info
        info = list(zip(info1, info2))

        if self.use_all_as_single:
            # assumes that values are of shape [batch, num_keys, ...], so index the appropriate key
            params, mask, info = self.single_index_all(self.target, extractor, params, mask, info)

        target, log_probs = self._target_dists(batch, params)
        result = Batch(target=target, params=params, mask=mask, log_probs=log_probs, omit_flags=omit_flags, pairwise_input=pairwise_input)
        result.trace_log_probs = trace_log_probs(extractor.num_objects, result.log_probs, batch, idx=extractor.get_index(self.target))
        # info is indexed positionally, in the same order as `additional`
        for i, aname in enumerate(additional):
            result[aname] = info[i]
        for k in log_batch:
            result[k] = batch[k]
        return result