import types
from lib.utils import *
from lib.causal_flow_tracer import return_forward_method_dict
from lib.search_func import block_initializer
from lib.causal_flow_tracer import CausalFlowTracer
from lib.search_func import top_down_causal_flow_trace


def compute_faithfulness(args):
    """Compare the model's prediction before and after a "path" intervention.

    Steps:
      1. Run ``block_initializer`` on every block of ``args.mt.model``.
      2. Run the normal/noised forward passes and the top-down causal-flow
         trace, which fills ``flow_tracer`` (and, presumably, adds
         ``"detailed_flow_tracer"`` to the meta dict -- TODO confirm).
      3. For the first ``args.num_layers`` blocks, install the intervention
         ``forward`` / ``mlp.forward`` / ``attn.forward`` methods in "path"
         mode, restricted to ``args.causal_subset[str(block_idx)]``.
      4. Re-run the normal input through the patched model and report how
         the top-1 prediction and its probability changed.

    Args:
        args: run configuration. Attributes read directly here: ``mt``
            (must expose ``.model`` and ``.model.blocks``), ``logger``,
            ``model``, ``num_noise_sample``, ``noise_type``, ``num_layers``,
            ``causal_subset``, ``out_num`` and ``end_symbol``. ``args`` is
            also forwarded to the ``lib`` helpers.

    Returns:
        dict with keys ``original_top_1_token``, ``original_top_1_p``,
        ``ablated_top_1_token``, ``ablated_top_1_p`` and
        ``original_top_1_token_p_ablated``.

    Note:
        The patched forward methods are left installed on the blocks when
        this function returns.
    """
    curr_meta_arg = {"total_block_num": len(args.mt.model.blocks)}

    # Reset every block before tracing; the return value is not used here.
    for block in args.mt.model.blocks:
        block_initializer(args, block)

    flow_tracer = CausalFlowTracer()
    # NOTE(review): `images` and `labels` are not defined in this function and
    # must resolve at module scope (possibly via `from lib.utils import *`)
    # -- TODO confirm, or pass them in through `args`.
    # NOTE(review): num_normal_sample reuses args.num_noise_sample -- confirm
    # this is intended.
    (flow_tracer, _answer, curr_total_token_num, curr_total_block_num,
     normal_inp) = predict_from_normal_and_noise_input(
        images, flow_tracer, labels, args.logger,
        args.mt, n_lev=None, end_symbol=None, out_num=1,
        model_name=args.model,
        num_normal_sample=args.num_noise_sample,
        num_noise_sample=args.num_noise_sample,
        noise_type=args.noise_type)

    curr_meta_arg["total_block_num"] = curr_total_block_num
    curr_meta_arg["total_token_num"] = curr_total_token_num
    curr_meta_arg["mt"] = args.mt
    curr_meta_arg["logger"] = args.logger

    # Only the trace's side effects on flow_tracer / curr_meta_arg are used;
    # its return value is not needed here.
    top_down_causal_flow_trace(
        args=args, curr_meta_arg=curr_meta_arg, flow_tracer=flow_tracer)

    scores_normal = flow_tracer.scores_normal
    # torch.max(..., dim=0) returns a (values, indices) namedtuple; take the
    # two fields explicitly instead of storing the whole tuple.
    # NOTE(review): the later `.item()` calls assume scores_normal is 1-D.
    # It is also unclear whether these are probabilities or logits (the
    # ablated scores below are softmaxed) -- TODO confirm.
    original_max = torch.max(scores_normal, dim=0)
    original_top_1_p = original_max.values
    original_answer_t = original_max.indices.unsqueeze(0)
    # Top-1 token id of the un-intervened run (shape [1]), derived the same
    # way as the index used to read its probability after ablation.
    original_top_1_token = original_answer_t

    fwd_method_dict = return_forward_method_dict(args)  # loop-invariant
    for bidx in range(args.num_layers):
        curr_block = args.mt.model.blocks[bidx]
        # The *returned* block is what gets patched below.
        curr_block = block_initializer(args, curr_block)

        curr_block.forward = types.MethodType(
            fwd_method_dict['intervention_forward'], curr_block)
        curr_block.mlp.forward = types.MethodType(
            fwd_method_dict['custom_mlp_forward'], curr_block.mlp)
        curr_block.attn.forward = types.MethodType(
            fwd_method_dict['custom_attn_forward'], curr_block.attn)

        # Configure the custom forwards for the "path" intervention. The
        # semantics of these attributes live in the lib forward functions.
        curr_block.anonymoususer_trace_mode = True
        curr_block.anonymoususer_cond = "path"
        curr_block.anonymoususer_curr_subset = args.causal_subset[str(bidx)]
        # "detailed_flow_tracer" is not set in this function; presumably
        # top_down_causal_flow_trace adds it -- TODO confirm.
        curr_block.anonymoususer_corrupted_feats = (
            curr_meta_arg["detailed_flow_tracer"][bidx]["corrupted"])

    with torch.no_grad():
        ablated_top_1_token, _, scores_ablated = predict_from_input(
            args.mt.model, normal_inp, multipred=(args.out_num != 1),
            end_symbol=args.end_symbol, use_mean=True)

        # Softmax over dim 1, then keep the first row. This presumably
        # converts (batch, vocab) logits to probabilities -- TODO confirm.
        scores_ablated = torch.softmax(scores_ablated, dim=1)[0]

        ablated_top_1_p = torch.max(scores_ablated)
        original_top_1_token_p_ablated = scores_ablated[original_answer_t]

    return {
        "original_top_1_token": original_top_1_token,
        "original_top_1_p": original_top_1_p.item(),
        "ablated_top_1_token": ablated_top_1_token,
        "ablated_top_1_p": ablated_top_1_p.item(),
        "original_top_1_token_p_ablated": original_top_1_token_p_ablated.item(),
    }