from ActualCausal.Inference.compute_inference import evaluate_inference
import numpy as np
import cv2
from Network.network_utils import pytorch_model
from ACState.object_dict import ObjDict

def render_idxes(env, name_idx, model, batch, idxes, window=2, delay=100):
    """Render and print a small window of batch entries around each index.

    For every ``eidx`` in ``idxes``, the entries ``eidx - window`` through
    ``eidx + window`` (clamped to the batch bounds) are visited in order. For
    each entry, the environment is reset to that entry's factored state and
    rendered with OpenCV. The trace row, interaction-mask row (for
    ``name_idx``) and bin-error row are printed alongside.

    Args:
        env: environment exposing ``set_from_factored_state`` and ``render``.
        name_idx: object index used to select the trace / inter_mask rows.
        model: object whose ``extractor.get_factored_from_obs`` converts an
            observation into a factored state.
        batch: indexable batch with ``obs``, ``trace``, ``inter_masks`` and
            ``bin_error`` fields. NOTE(review): ``bin_error`` is assumed to be
            stored on the batch; confirm callers attach it.
        idxes: iterable of integer indexes into ``batch`` to centre windows on.
        window: number of entries shown before and after each index.
        delay: milliseconds passed to ``cv2.waitKey`` between frames.
    """
    num_entries = len(batch.obs)
    for eidx in idxes:
        # Clamp the window so indexes near the edges neither wrap around
        # (negative numpy indexes) nor run past the end of the batch.
        start = max(eidx - window, 0)
        stop = min(eidx + window + 1, num_entries)
        for idx in range(start, stop):
            # Reset the env for every entry; otherwise the same frame would be
            # shown repeatedly while the printed rows change.
            factored_state = model.extractor.get_factored_from_obs(batch.obs[idx])
            env.set_from_factored_state(factored_state)
            print(np.concatenate([batch.trace[idx, name_idx], batch.inter_masks[idx, name_idx], batch.bin_error[idx]], axis=-1))
            cv2.imshow("frame", env.render())
            cv2.waitKey(delay)


def render_key_states(args, model, buffer, env, test=False):
    """Render states with high inference error, plus states with interactions.

    Samples buffer entries ``0..999`` and runs ``evaluate_inference`` on them.
    Then, for every inference type in ``args.infer.render_eval`` and every
    object name in ``args.infer.infer_names``, it renders (via
    ``render_idxes``):

    * false positives: rows whose max error exceeds ``args.infer.render_threshold``
    * false negatives: rows whose min error is below ``-args.infer.render_threshold``
    * traces: rows whose trace for ``name`` is nonzero for some object other
      than ``name`` itself

    Args:
        args: config object; reads ``args.infer.infer_num``,
            ``args.infer.train_mask_mode``, ``args.infer.render_eval``,
            ``args.infer.infer_names`` and ``args.infer.render_threshold``.
        model: model whose ``extractor`` maps names to indexes.
        buffer: replay buffer whose ``sample(indices)`` returns ``(batch, indices)``.
        env: environment forwarded to ``render_idxes`` for drawing.
        test: forwarded to ``evaluate_inference``.
    """
    # Only the first 1000 entries, so evaluation and rendering stay cheap.
    batch, _ = buffer.sample(np.arange(1000))
    result = evaluate_inference(0, args,
                                ObjDict({"infer_num": args.infer.infer_num, "mask_mode": args.infer.train_mask_mode}),
                                model, batch, test=test)
    batch = batch[result.omit_flags]  # match batch indexes with result idxes
    threshold = args.infer.render_threshold

    for infer_type in args.infer.render_eval:
        raw_error = pytorch_model.unwrap(result[infer_type])
        for name in args.infer.infer_names:
            name_idx = model.extractor.get_index(name)
            # TODO: this removes the passive flag, which depending might not be correct
            # Copy before zeroing: the unwrapped array may share memory with
            # result[infer_type], and the zeroed column must not leak into
            # later names.
            bin_error = np.array(raw_error, copy=True)
            bin_error[..., name_idx] = 0
            # indexes of interest based on bin_error
            high_FP = (np.max(bin_error, axis=-1) > threshold).nonzero()[0]
            high_FN = (np.min(bin_error, axis=-1) < -threshold).nonzero()[0]
            # Drop the object's own (passive) entry by zeroing it rather than
            # subtracting 1. Subtracting would select rows whose self-entry is 0
            # (sum becomes -1) and could cancel genuine interactions.
            traces = np.array(batch.trace[:, name_idx], copy=True)
            traces[..., name_idx] = 0
            sampled_traces = (traces != 0).any(axis=-1).nonzero()[0]

            print("Rendering false positives")
            render_idxes(env, name_idx, model, batch, high_FP)

            print("Rendering false negatives")
            render_idxes(env, name_idx, model, batch, high_FN)

            print("Rendering traces")
            render_idxes(env, name_idx, model, batch, sampled_traces)
