import types
import time
import pickle
import itertools
from lib.utils import *
from lib.causal_flow_tracer import *



def block_initializer(args, curr_block, symbolic_name="anonymoususer"):
    """Reset a transformer block to its original, un-instrumented state.

    Rebinds the original ``forward`` of the block and of its ``mlp`` and
    ``attn`` sub-modules (taken from ``return_forward_method_dict``), then
    deletes every instance attribute whose name contains ``symbolic_name``.
    That substring is the prefix under which the tracing code in this module
    stashes per-run state (e.g. ``anonymoususer_feats_normal``).

    Args:
        args: run configuration forwarded to ``return_forward_method_dict``.
        curr_block: the block to reset; it is mutated in place.
        symbolic_name: substring identifying tracer-injected attributes. This
            argument used to be ignored in favour of a hard-coded
            "anonymoususer"; the default keeps that behaviour.

    Returns:
        The same ``curr_block`` object, for call-chaining convenience.
    """
    forward_method_dict = return_forward_method_dict(args)

    # Restore the original forward implementations.
    curr_block.forward = types.MethodType(forward_method_dict["org_block_forward"], curr_block)
    curr_block.mlp.forward = types.MethodType(forward_method_dict["org_mlp_forward"], curr_block.mlp)
    curr_block.attn.forward = types.MethodType(forward_method_dict["org_attn_forward"], curr_block.attn)

    # Drop state left behind by a previous tracing run. The matching keys are
    # collected into a list first because deleting from __dict__ while
    # iterating it would raise.
    for module in (curr_block, curr_block.mlp, curr_block.attn):
        stale_keys = [key for key in vars(module) if symbolic_name in key]
        for key in stale_keys:
            delattr(module, key)

    return curr_block

def get_detailed_corrupted_flow_tracer_one_block(
        args, curr_meta_arg, flow_tracer, detailed_flow_tracer_idx):
    """Run one forward pass that records corrupted features at a target block.

    Per block ``bidx``:

    * ``bidx == detailed_flow_tracer_idx - 1``: switched to
      ``pass_corrupted_block_forward`` and given the cached corrupted
      features from ``flow_tracer``.
    * ``bidx == detailed_flow_tracer_idx`` (the target): switched to the
      custom intervention forwards with ``anonymoususer_save_mode = True`` and
      handed a fresh ``CausalFlowTracer``. When the target is block 0 there is
      no predecessor to supply corrupted features, so
      ``flow_tracer.feats_corrupted_init`` is attached to the target instead.
    * every other block: switched to ``pass_block_forward`` and given the
      cached normal features from ``flow_tracer``.

    All blocks are reset with ``block_initializer`` afterwards, even when the
    forward pass raises.

    Returns:
        The ``CausalFlowTracer`` that was attached to the target block.
    """
    fwd_method_dict = return_forward_method_dict(args)
    blocks = curr_meta_arg["mt"].model.blocks
    total_block_num = curr_meta_arg["total_block_num"]

    # Only used to collect the corrupted features saved by the target block.
    detailed_flow_tracer = CausalFlowTracer()
    detailed_flow_tracer.init_node()

    try:
        # Noise run.
        for bidx in range(total_block_num):
            curr_block = block_initializer(args, blocks[bidx])

            if bidx == detailed_flow_tracer_idx - 1:
                # Feed corrupted features into the target block.
                curr_block.forward = types.MethodType(fwd_method_dict["pass_corrupted_block_forward"], curr_block)
                curr_block.anonymoususer_feats_corrupted = flow_tracer.feats_corrupted[bidx]
                curr_block.anonymoususer_feats_k_corrupted = flow_tracer.feats_k_corrupted[bidx]
                curr_block.anonymoususer_feats_v_corrupted = flow_tracer.feats_v_corrupted[bidx]
            elif bidx == detailed_flow_tracer_idx:
                if bidx == 0:
                    # No predecessor block exists to prepare the corrupted
                    # input, so hand the initial corrupted features over directly.
                    curr_block.anonymoususer_feats_corrupted_init = flow_tracer.feats_corrupted_init

                curr_block.forward = types.MethodType(fwd_method_dict["intervention_forward"], curr_block)
                curr_block.mlp.forward = types.MethodType(fwd_method_dict["custom_mlp_forward"], curr_block.mlp)
                curr_block.attn.forward = types.MethodType(fwd_method_dict["custom_attn_forward"], curr_block.attn)

                curr_block.anonymoususer_flow_tracer = detailed_flow_tracer
                curr_block.anonymoususer_save_mode = True
            else:
                # Blocks both before the predecessor and after the target get
                # the same setup: pass through with the cached normal features.
                curr_block.forward = types.MethodType(fwd_method_dict["pass_block_forward"], curr_block)
                curr_block.anonymoususer_feats_normal = flow_tracer.feats_normal[bidx]
                curr_block.anonymoususer_feats_k_normal = flow_tracer.feats_k_normal[bidx]
                curr_block.anonymoususer_feats_v_normal = flow_tracer.feats_v_normal[bidx]

        with torch.no_grad():
            # The prediction itself is not needed; the pass is run for the
            # features the target block writes into detailed_flow_tracer.
            predict_from_input(
                curr_meta_arg["mt"].model, curr_meta_arg["inp"],
                multipred=(args.out_num != 1), end_symbol=args.end_symbol)
    finally:
        # Restore the original modules even if the forward pass failed.
        for bidx in range(total_block_num):
            block_initializer(args, blocks[bidx])

    return detailed_flow_tracer


def _install_intervention_forwards(curr_block, fwd_method_dict):
    """Bind the custom intervention forwards to ``curr_block`` and its mlp/attn."""
    curr_block.forward = types.MethodType(fwd_method_dict["intervention_forward"], curr_block)
    curr_block.mlp.forward = types.MethodType(fwd_method_dict["custom_mlp_forward"], curr_block.mlp)
    curr_block.attn.forward = types.MethodType(fwd_method_dict["custom_attn_forward"], curr_block.attn)


def _install_traced_path_forwards(curr_block, bidx, curr_meta_arg, flow_tracer, fwd_method_dict):
    """Configure block ``bidx`` to run in "path" mode over its already-traced subsets.

    The subset handed to the block is the union of all node indices in
    ``flow_tracer.traced_paths[bidx]``, i.e. of the subsets selected for this
    block by an earlier minimality search.
    """
    _install_intervention_forwards(curr_block, fwd_method_dict)
    curr_block.anonymoususer_curr_block_idx = curr_meta_arg["curr_block_idx"]
    curr_block.anonymoususer_trace_mode = True
    curr_block.anonymoususer_cond = "path"
    curr_block.anonymoususer_corrupted_feats = curr_meta_arg["detailed_flow_tracer"][bidx]["corrupted"]
    curr_block.anonymoususer_curr_subset = tuple(set().union(*flow_tracer.traced_paths[bidx]))


def _log_cause_edge_decision(args, curr_meta_arg, actual_cause_bool, preds, ps):
    """Log one subset decision (only causal ones when ``slightly_quiet_mode``)."""
    if args.slightly_quiet_mode and not actual_cause_bool:
        return
    cf_was_run = not args.efficient_mode
    log_line = "\t\t ({})[Subset: {}] | org:{}({}) path:CF-{}({}), CT-{}({})".format(
        "Causal" if actual_cause_bool else "Non-causal",
        curr_meta_arg["curr_subset"],
        curr_meta_arg["normal_pred"].item(),
        round(curr_meta_arg["normal_p"].item(), 5),
        preds[0].item() if cf_was_run else "",
        round(ps[0].item(), 5) if cf_was_run else "PASS",
        preds[1].item(),
        round(ps[1].item(), 5),
    )
    curr_meta_arg["logger"].info(log_line)


def _classify_cause_edge(args, curr_meta_arg, configure_block):
    """Run the counterfactual and contingency checks for the current subset.

    For each condition, every block is reset with ``block_initializer`` and
    then passed to ``configure_block(bidx, block, cond, fwd_method_dict)``.
    After that, one no-grad forward pass is run on ``curr_meta_arg["inp"]``.
    The prediction counts as *changed* when not every unique predicted value
    equals ``curr_meta_arg["normal_pred"]``.

    * counterfactual holds when the prediction changed;
    * contingency holds when it did not.

    In ``efficient_mode`` the counterfactual pass is skipped and assumed to
    hold, and the clean prediction/score are recorded in its place.

    Returns:
        ``(actual_cause_bool, cond_bool, decision_save)``:
        ``actual_cause_bool`` is True when both conditions hold.
        ``cond_bool`` is a bool tensor ``[counterfactual, contingency]``.
        ``decision_save`` is a dict of stringified predictions/scores.
    """
    fwd_method_dict = return_forward_method_dict(args)
    blocks = curr_meta_arg["mt"].model.blocks
    normal_pred = curr_meta_arg["normal_pred"]

    cond_bool = []
    preds = []
    ps = []
    for cond in ("counterfactual", "contingency"):
        if args.efficient_mode and cond == "counterfactual":
            cond_bool.append(True)
            preds.append(copy.deepcopy(normal_pred))
            ps.append(copy.deepcopy(curr_meta_arg["normal_p"]))
            continue

        for bidx in range(curr_meta_arg["total_block_num"]):
            curr_block = block_initializer(args, blocks[bidx])
            configure_block(bidx, curr_block, cond, fwd_method_dict)

        with torch.no_grad():
            curr_pred, curr_p, _ = predict_from_input(
                curr_meta_arg["mt"].model, curr_meta_arg["inp"],
                multipred=(args.out_num != 1), end_symbol=args.end_symbol, use_mean=True)

        preds.append(copy.deepcopy(curr_pred))
        ps.append(copy.deepcopy(curr_p))

        change_bool = torch.unique(curr_pred) == normal_pred
        pred_changed = (change_bool.sum() != change_bool.shape[0]).item()
        cond_bool.append(pred_changed if cond == "counterfactual" else not pred_changed)

    cond_bool = torch.tensor(cond_bool)
    actual_cause_bool = cond_bool.all().item()

    _log_cause_edge_decision(args, curr_meta_arg, actual_cause_bool, preds, ps)

    decision_save = {
        "CF": [str(preds[0].item()), str(ps[0].item())] if not args.efficient_mode else ["None", "None"],
        "CT": [str(preds[1].item()), str(ps[1].item())],
    }
    return actual_cause_bool, cond_bool, decision_save


def cause_edge_classifier(
    args, curr_meta_arg, flow_tracer
):
    """Decide whether ``curr_meta_arg["curr_subset"]`` is an actual cause at the current block.

    Block configuration per index, relative to ``curr_block_idx``:

    * earlier blocks: ``pass_block_forward`` with cached normal features;
    * the current block: intervention forwards, with the condition under test
      and the candidate subset;
    * later blocks: intervention forwards in "path" mode over the subsets
      already traced for them.

    Returns:
        ``(actual_cause_bool, cond_bool, decision_save)``; see
        ``_classify_cause_edge``.
    """
    curr_block_idx = curr_meta_arg["curr_block_idx"]

    def configure_block(bidx, curr_block, cond, fwd_method_dict):
        if bidx < curr_block_idx:
            curr_block.forward = types.MethodType(fwd_method_dict["pass_block_forward"], curr_block)
            curr_block.anonymoususer_feats_normal = flow_tracer.feats_normal[bidx]
            curr_block.anonymoususer_feats_k_normal = flow_tracer.feats_k_normal[bidx]
            curr_block.anonymoususer_feats_v_normal = flow_tracer.feats_v_normal[bidx]
        elif bidx == curr_block_idx:
            _install_intervention_forwards(curr_block, fwd_method_dict)
            curr_block.anonymoususer_curr_block_idx = curr_block_idx
            curr_block.anonymoususer_normal_feats = flow_tracer.feats_normal
            curr_block.anonymoususer_trace_mode = True
            curr_block.anonymoususer_cond = cond
            curr_block.anonymoususer_corrupted_feats = curr_meta_arg["detailed_flow_tracer"][bidx]["corrupted"]
            curr_block.anonymoususer_curr_subset = curr_meta_arg["curr_subset"]
        else:
            _install_traced_path_forwards(curr_block, bidx, curr_meta_arg, flow_tracer, fwd_method_dict)

    return _classify_cause_edge(args, curr_meta_arg, configure_block)


def cause_edge_classifier_for_evaluation(
    args, curr_meta_arg, flow_tracer
):
    """Evaluation variant: every block runs in "path" mode over its traced subsets.

    NOTE(review): ``cond`` is never passed to the blocks here, so both checks
    run identically configured forwards. Outside ``efficient_mode`` they
    presumably see the same prediction and cannot both hold; confirm this is
    intended.
    NOTE(review): this reads ``detailed_flow_tracer[bidx]`` for every bidx from
    0, but ``minimality_search`` only fills entries from ``curr_block_idx``
    upward. It looks like it would raise KeyError unless ``curr_block_idx == 0``;
    verify against the evaluation driver.

    Returns:
        ``(actual_cause_bool, cond_bool, decision_save)``; see
        ``_classify_cause_edge``.
    """
    def configure_block(bidx, curr_block, cond, fwd_method_dict):
        _install_traced_path_forwards(curr_block, bidx, curr_meta_arg, flow_tracer, fwd_method_dict)

    return _classify_cause_edge(args, curr_meta_arg, configure_block)

def _collect_detailed_flow_tracers(args, curr_meta_arg, flow_tracer):
    """Record corrupted features for every block from ``curr_block_idx`` to the last."""
    return {
        idx: {"corrupted": get_detailed_corrupted_flow_tracer_one_block(
            args=args, curr_meta_arg=curr_meta_arg, flow_tracer=flow_tracer,
            detailed_flow_tracer_idx=idx)}
        for idx in range(curr_meta_arg["curr_block_idx"], curr_meta_arg["total_block_num"])
    }


def _save_min_search_step(args, curr_meta_arg, curr_step, subset_decisions, subsets):
    """Write the causal subsets found at one step to ``detailed/block_XX/step_YYY.json``."""
    save_curr_step_data = {
        "org_decision": [curr_meta_arg["normal_pred"].item(), curr_meta_arg["normal_p"].item()],
        "subset_decision": subset_decisions,
        "subset": subsets,
    }
    detailed_block_dir = os.path.join(
        args.curr_save_dir, "detailed", "block_{:02d}".format(curr_meta_arg["curr_block_idx"]))
    os.makedirs(detailed_block_dir, exist_ok=True)
    with open(os.path.join(detailed_block_dir, "step_{:03d}.json".format(curr_step)), "w") as f:
        json.dump(save_curr_step_data, f, indent=4)


def _log_empty_selection(args):
    """Append the sample index to a per-model empty-subset error log.

    The directory can be overridden with an optional ``args.error_log_dir``.
    Otherwise the historical hard-coded location is used.
    """
    error_log_dir = getattr(args, "error_log_dir", None) or (
        f"/data5/users/anonymoususer/causal_path/CausalPathTracing_for_ViT/main/jobs/{args.model}")
    os.makedirs(error_log_dir, exist_ok=True)
    with open(os.path.join(error_log_dir, "log_error_samples.txt"), "a") as f:
        f.write(f"sample {args.target_sample_idx} has empty selected subset error.\n")


def minimality_search(
    args, curr_meta_arg, flow_tracer):
    """Search step by step for the causal path-node subsets of the current block.

    At step ``k`` (``1..num_path_node``), every ``k``-combination of path-node
    indices is a candidate. With ``subset_search == "minimality"``, candidates
    are first filtered by ``exclude_subsets`` against the subsets already
    selected. Each candidate is judged by ``cause_edge_classifier`` (or its
    evaluation variant), and the causal ones are collected.

    Side effects: sets ``detailed_flow_tracer``, ``normal_p``, ``normal_pred``,
    ``num_path_node``, ``curr_step`` and ``curr_subset`` on ``curr_meta_arg``.
    All blocks are reset with ``block_initializer`` on exit, even on error.

    Returns:
        ``(min_search_memory, stop_flag)``:
        ``min_search_memory`` holds ``selected_subset``, ``stop_step``,
        ``step_subset`` and ``step_subset_cond_bool``. ``stop_step`` is the step
        before the first step with no candidates left (``num_path_node`` if
        candidates never ran out); previously it was overwritten by every empty
        step. ``stop_flag`` is True when no subset was selected.
    """
    logger = curr_meta_arg["logger"]
    logger.info("[MinSearch Start]")

    try:
        curr_meta_arg["detailed_flow_tracer"] = _collect_detailed_flow_tracers(
            args, curr_meta_arg, flow_tracer)

        normal_p, normal_pred = torch.max(flow_tracer.scores_normal, dim=0)
        curr_meta_arg["normal_p"] = normal_p
        curr_meta_arg["normal_pred"] = normal_pred

        # NOTE(review): the `2 * n_head + 2` node count is taken as-is; its
        # breakdown is defined by the custom forwards, not visible here.
        n_head = curr_meta_arg['mt'].model.blocks[0].attn.num_heads
        num_path_node = n_head * 2 + 2
        curr_meta_arg["num_path_node"] = num_path_node

        path_node_index = np.arange(num_path_node)
        selected_subset = []
        step_subset = {}
        step_subset_cond_bool = {}
        first_empty_step = None
        debug_first_step = num_path_node - 3

        if args.debug_min_search:
            logger.info("\t [NOTE] This is the debug mode of Min. Search. It will pass 1~{} steps".format(debug_first_step - 1))

        classifier = cause_edge_classifier_for_evaluation if args.evaluation_mode else cause_edge_classifier

        for curr_step in range(1, num_path_node + 1):
            if args.debug_min_search and curr_step < debug_first_step:
                continue

            curr_step_subset = list(itertools.combinations(path_node_index, curr_step))
            if args.subset_search == "minimality":
                subset_candidate = exclude_subsets(curr_step_subset, selected_subset, logger)
            else:
                subset_candidate = copy.deepcopy(curr_step_subset)

            if len(subset_candidate) == 0:
                if first_empty_step is None:
                    first_empty_step = curr_step
            else:
                logger.info("\t MinSearch Step: ({}/{}) -> Target Num.: {}".format(curr_step, num_path_node, len(subset_candidate)))

            buffer_cond_bool = []
            save_decisions = {}
            save_subsets = {}

            curr_meta_arg["curr_step"] = curr_step
            for curr_subset in subset_candidate:
                curr_meta_arg["curr_subset"] = curr_subset
                actual_cause_bool, cond_bool, decision_save = classifier(
                    args=args, curr_meta_arg=curr_meta_arg, flow_tracer=flow_tracer)
                buffer_cond_bool.append(cond_bool.tolist())

                if actual_cause_bool:
                    selected_subset.append(curr_subset)
                    if args.detailed_save:
                        save_index = len(save_decisions)
                        save_decisions[save_index] = decision_save
                        save_subsets[save_index] = list(map(str, curr_subset))

            step_subset[curr_step] = subset_candidate
            step_subset_cond_bool[curr_step] = buffer_cond_bool

            if args.detailed_save:
                _save_min_search_step(args, curr_meta_arg, curr_step, save_decisions, save_subsets)

        stop_step = num_path_node if first_empty_step is None else first_empty_step - 1

        stop_flag = len(selected_subset) == 0
        if stop_flag:
            logger.info("[Error] Selected subsets are empty!")
            _log_empty_selection(args)

        min_search_memory = {
            "selected_subset": selected_subset,
            "stop_step": stop_step,
            "step_subset": step_subset,
            "step_subset_cond_bool": step_subset_cond_bool,
        }
    finally:
        # Restore original modules even if the search failed part-way.
        for bidx in range(curr_meta_arg["total_block_num"]):
            block_initializer(args, curr_meta_arg["mt"].model.blocks[bidx])

    return min_search_memory, stop_flag

def top_down_causal_flow_trace(
    args, curr_meta_arg, flow_tracer):
    """Trace causal paths block by block, from the last block down to block 0.

    The input images are repeated ``args.num_noise_sample`` times along the
    first dimension and stored as ``curr_meta_arg["inp"]``. For each block,
    ``minimality_search`` is run and its selected subsets are stored in
    ``flow_tracer.traced_paths[block_idx]``, where later (lower-index) searches
    read them.

    Returns:
        True if tracing stopped early because a block yielded no selected
        subset, otherwise False (including when there are no blocks).
    """
    # NOTE(review): `.repeat(n, 1, 1, 1)` presumably expects (B, C, H, W)
    # images — confirm against the caller.
    expanded_images = curr_meta_arg["images"].repeat(args.num_noise_sample, 1, 1, 1)
    curr_meta_arg["inp"] = {"pixel_values": expanded_images}

    logger = curr_meta_arg["logger"]
    stop_flag = False
    for curr_block_idx in tqdm.tqdm(range(curr_meta_arg["total_block_num"] - 1, -1, -1)):
        curr_meta_arg["curr_block_idx"] = curr_block_idx
        logger.info("-----------Block Change----------- Block:{}".format(curr_block_idx))

        min_search_memory, stop_flag = minimality_search(
            args=args, curr_meta_arg=curr_meta_arg, flow_tracer=flow_tracer)
        if stop_flag:
            logger.info(f"Stopping early at block {curr_block_idx} due to empty selected subset")
            break

        flow_tracer.traced_paths.update(
            {curr_block_idx: copy.deepcopy(min_search_memory["selected_subset"])})

    return stop_flag