import os
# Set before the heavy imports further down the file so it is already in the
# environment when they run (presumably turns off TensorFlow's oneDNN ops).
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
import argparse

# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# Where to run and which stored results to evaluate.
parser.add_argument("--gpu", type=str, default="6", help="gpu")
parser.add_argument("--result_folder", type=str, required=True)

# How stored paths are aggregated and which side of the path is intervened on.
parser.add_argument(
    "--path_agg_type",
    default="union",
    choices=[
        "union",
        "union_calib",
        "min",
        "min_calib",
        "intersection",
        "over1",
        "top_min_path_size",
    ],
    help="agg type",
)
parser.add_argument(
    "--interv_target",
    default="off_path",
    choices=["off_path", "on_path"],
    help="interv_target type",
)

# Model, data and noise configuration.
parser.add_argument("--num_noise_sample", type=int, default=100)
parser.add_argument("--model", type=str, default="gpt-xl")
parser.add_argument(
    "--noise_type",
    default="other",
    choices=["other", "emb_added", "zero", "mean"],
    help="noise type",
)
parser.add_argument(
    "--dataset_type",
    default="rome_factual_1000",
    choices=["rome_factual_1000", "lama_trex"],
    help="dataset_type",
)
parser.add_argument("--except_stopword", action="store_true", help="except_stopword")

parser.add_argument("--model_save", type=str, default=None)

args = parser.parse_args()

# The output file is named after the side that is NOT intervened on:
# intervening on the off-path nodes evaluates the "path", and vice versa.
if args.interv_target == "off_path":
    _evaluated_side = "path"
else:
    _evaluated_side = "off_path"
SAVE_JSON_NAME = "J_{}_{}.json".format(args.path_agg_type, _evaluated_side)

# Short model names accepted on the command line, expanded to the identifiers
# below; any other value of --model is passed through unchanged.
_HF_MODEL_ALIASES = {
    "gpt2-mini": "erwanf/gpt2-mini",
    "gpt2-xs": "AlgorithmicResearchGroup/gpt2-xs",
    "pythia-14m": "EleutherAI/pythia-14m",
    "pythia-1b": "EleutherAI/pythia-1b",
}
args.model = _HF_MODEL_ALIASES.get(args.model, args.model)

# GPU: must be exported before torch is imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

import json
import torch
import numpy as np
import random
import pickle

from collections import Counter
from lib.utils import *
from lib.search_func import *
from lib.causal_flow_tracer import CausalFlowTracer, return_forward_method_dict




def _aggregate_paths(paths, agg_type):
    """Collapse the candidate paths of one block into a single list of node ids.

    Args:
        paths: list of paths, each path being a list of hashable node ids.
        agg_type: one of the ``--path_agg_type`` choices.

    Returns:
        A list of node ids.

    Raises:
        ValueError: if ``agg_type`` is not a known aggregation type.
    """
    # Exact match first: "top_min_path_size" contains the substring "min" and
    # would otherwise be swallowed by the substring test for min / min_calib.
    if agg_type == "top_min_path_size":
        min_len = min(len(lst) for lst in paths)
        counter = Counter(item for sublist in paths for item in sublist)
        return [item for item, _ in counter.most_common(min_len)]
    if "union" in agg_type:  # "union" and "union_calib"
        return list(set.union(*(set(lst) for lst in paths)))
    if "min" in agg_type:  # "min" and "min_calib"
        # min() returns the first shortest path, like the original filter+[0].
        return min(paths, key=len)
    if agg_type == "intersection":
        # Return a list like every other branch: np.setdiff1d() cannot take a
        # Python set as its second argument.
        return list(set.intersection(*(set(lst) for lst in paths)))
    if agg_type == "over1":
        counter = Counter(item for sublist in paths for item in sublist)
        return sorted(item for item, count in counter.items() if count >= 2)
    raise ValueError("unknown path_agg_type: {!r}".format(agg_type))


def _load_causal_paths(results_path, agg_type):
    """Read every stored per-sample path file and aggregate it block by block.

    Each entry of ``results_path`` is a directory whose name ends in
    ``R<idx>`` and which contains ``C<idx, 6 digits>.json``; that JSON maps a
    block index to a list of paths.

    Returns:
        ``{sample_idx: {block_idx: [node ids]}}``.
    """
    causal_paths = {}
    for curr_result in sorted(os.listdir(results_path)):
        idx = int(curr_result.split("R")[-1])
        json_path = os.path.join(results_path, curr_result, "C{:06d}.json".format(idx))
        # Context manager so the handle is closed for every result directory.
        with open(json_path, "r") as f:
            path_data = json.load(f)
        causal_paths[idx] = {
            int(block_idx): _aggregate_paths(paths, agg_type)
            for block_idx, paths in path_data.items()
        }
    return causal_paths


def _get_transformer_block(model, bidx):
    """Return block ``bidx`` for the two model layouts handled by this script."""
    if hasattr(model, "transformer"):
        return model.transformer.h[bidx]
    if hasattr(model, "gpt_neox"):
        return model.gpt_neox.layers[bidx]
    raise AttributeError(
        "model has neither a 'transformer' nor a 'gpt_neox' attribute; "
        "cannot locate its blocks"
    )


def main():
    """Re-run the model with path-based interventions and write summary metrics.

    For every sample that has stored causal paths under
    ``<result_folder>/results``, the blocks of the model are patched with the
    intervention forward methods, the prompt is predicted again, and the
    intervened output distribution is compared with the normal one. The
    aggregated metrics are written to
    ``<result_folder>/path_forward_results/<SAVE_JSON_NAME>``.

    Relies on the module-level ``args`` and ``SAVE_JSON_NAME`` and on helpers
    from ``lib`` (``tqdm`` and ``types`` are presumably provided by the star
    imports at the top of the file -- TODO confirm).

    Raises:
        RuntimeError: if no sample was evaluated, so no metric can be computed.
    """
    # Fix every source of randomness used here.
    random_seed = 0
    torch.set_grad_enabled(False)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)
    os.environ['PYTHONHASHSEED'] = str(random_seed)
    os.environ["CUBLAS_WORKSPACE_CONFIG"]=":16:8"

    results_path = os.path.join(args.result_folder, "results")
    causal_path_result_zip = _load_causal_paths(results_path, args.path_agg_type)

    line_data = get_data(args)
    mt = get_model(args)
    n_lev = get_noise_level(args, mt)

    args.out_num=1
    args.end_symbol=[]

    sample_wise_result = {
        "org_decision": [],
        "path_decision": [],
        "org_output": [],
        "path_output": [],
        "org_pred_order": [],
        "path_pred_order": []
    }
    cnt=0
    prgs = tqdm.tqdm(enumerate(line_data))
    for li, statement in prgs:

        # Only samples with stored causal paths are evaluated.
        if li not in causal_path_result_zip:
            continue

        prompt = statement["prompt"]
        y = statement["attribute"]

        flow_tracer = CausalFlowTracer(args)
        flow_tracer, answer, curr_total_token_num, curr_total_block_num, normal_inp = predict_from_normal_and_noise_input(
            prompt, flow_tracer, y, None,
            mt, n_lev, end_symbol=[],
            out_num=1, num_normal_sample=args.num_noise_sample,
            num_noise_sample=args.num_noise_sample, noise_type=args.noise_type)

        curr_meta_arg = {}
        curr_meta_arg["total_block_num"] = curr_total_block_num
        curr_meta_arg["total_token_num"] = curr_total_token_num
        curr_meta_arg["prompt"] = prompt
        curr_meta_arg["mt"] = mt
        # The prompt is repeated once per noise sample.
        inp = make_inputs(
            curr_meta_arg["mt"].tokenizer, [curr_meta_arg["prompt"]] * (args.num_noise_sample))

        curr_meta_arg["inp"] = inp

        # Corrupted features of every block, fed to the intervention below.
        detailed_flow_tracer = {}
        for detailed_flow_tracer_idx in range(0, curr_meta_arg["total_block_num"], 1):
            curr_detailed_corrupted_flow_tracer = get_detailed_corrupted_flow_tracer_one_block(
                args=args, curr_meta_arg=curr_meta_arg, flow_tracer=flow_tracer,
                detailed_flow_tracer_idx=detailed_flow_tracer_idx)

            detailed_flow_tracer.update({
                detailed_flow_tracer_idx: {"corrupted": curr_detailed_corrupted_flow_tracer}
                })

        curr_meta_arg["detailed_flow_tracer"] = detailed_flow_tracer

        # Normal prediction and the token order it induces (descending score).
        if args.except_stopword is False:
            normal_p, normal_pred = torch.max(flow_tracer.scores_normal, dim=0)
            desc_idx = torch.argsort(flow_tracer.scores_normal, dim=0, descending=True)
            normal_output = torch.index_select(flow_tracer.scores_normal, 0, desc_idx)
            pred_ord = desc_idx
        else:
            # Keep only the tokens selected by stwd_mask (presumably the
            # non-stopword tokens -- TODO confirm), still in descending order.
            desc_idx = torch.argsort(flow_tracer.scores_normal, dim=0, descending=True)
            sorted_stwd_mask = flow_tracer.stwd_mask[desc_idx]
            preds = desc_idx[sorted_stwd_mask]
            normal_pred = preds[0]
            normal_p = flow_tracer.scores_normal[normal_pred]
            normal_output = torch.index_select(flow_tracer.scores_normal, 0, preds)
            pred_ord = preds

        curr_meta_arg["normal_p"] = normal_p
        curr_meta_arg["normal_pred"] = normal_pred
        curr_meta_arg["normal_output"] = normal_output
        curr_meta_arg["normal_pred_order"] = pred_ord
        if hasattr(curr_meta_arg['mt'].model.config, "use_parallel_residual"):
            curr_meta_arg["num_path_node"] = curr_meta_arg['mt'].model.config.num_attention_heads + 2
        else:
            curr_meta_arg["num_path_node"] = curr_meta_arg['mt'].model.config.n_head *2 + 2

        # Patch every block with the intervention forward methods.
        for bidx in range(curr_meta_arg["total_block_num"]):
            curr_block = _get_transformer_block(curr_meta_arg["mt"].model, bidx)
            curr_block = block_initializer(args, curr_block)

            fwd_method_dict = return_forward_method_dict(args)

            curr_block.forward = types.MethodType(fwd_method_dict['intervention_forward'], curr_block)
            curr_block.mlp.forward = types.MethodType(fwd_method_dict['custom_mlp_forward'], curr_block.mlp)
            if hasattr(curr_block, "attn"):
                curr_block.attn.forward = types.MethodType(fwd_method_dict["custom_attn_forward"], curr_block.attn)
            elif hasattr(curr_block, "attention"):
                curr_block.attention.forward = types.MethodType(fwd_method_dict["custom_attn_forward"], curr_block.attention)

            curr_block.anonymoususer_trace_mode = True
            curr_block.anonymoususer_cond = "path"
            curr_block.anonymoususer_corrupted_feats = curr_meta_arg["detailed_flow_tracer"][bidx]["corrupted"]

            if args.interv_target=="off_path":
                # The stored path nodes themselves form the subset.
                curr_block.anonymoususer_curr_subset = tuple(causal_path_result_zip[li][bidx])
            elif args.interv_target=="on_path":
                # The subset is every node id that is NOT on the stored path.
                curr_block.anonymoususer_curr_subset = tuple(np.setdiff1d(np.arange(curr_meta_arg["num_path_node"]), causal_path_result_zip[li][bidx]).tolist())
            else:
                # argparse restricts the choices; fail loudly instead of
                # dropping into an interactive debugger.
                raise ValueError("unknown interv_target: {!r}".format(args.interv_target))

        with torch.no_grad():
            curr_pred, curr_p, raw_out = predict_from_input(
                curr_meta_arg["mt"].model, curr_meta_arg["inp"],
                multipred=(args.out_num!=1), end_symbol=args.end_symbol, use_mean=True, stwd_mask=flow_tracer.stwd_mask)
            # Softmax at the last position, averaged over the batch of samples.
            raw_out_sftmx = torch.softmax(raw_out[:, -1], dim=1).mean(dim=0).unsqueeze(0)
            desc_idx = torch.argsort(raw_out_sftmx, dim=1, descending=True)

            if args.except_stopword is False:
                output = torch.index_select(raw_out_sftmx, 1, desc_idx.squeeze(0)).squeeze(0)
                pred_order = desc_idx.squeeze(0)
            else:
                sorted_stwd_mask = flow_tracer.stwd_mask[desc_idx]
                preds = desc_idx[sorted_stwd_mask]
                output = torch.index_select(raw_out_sftmx, 1, preds).squeeze(0)

                pred_order = preds

        # Re-initialise every block so the next sample starts from a clean model.
        for bidx in range(curr_meta_arg["total_block_num"]):
            curr_block = _get_transformer_block(curr_meta_arg["mt"].model, bidx)
            curr_block = block_initializer(args, curr_block)

        # Position of every token id inside the intervened ordering, then read
        # in the order of the normal prediction so both outputs are aligned.
        token_id_to_index = {token.item(): i for i, token in enumerate(pred_order)}
        sorted_pred_order = torch.tensor([token_id_to_index[i.item()] for i in curr_meta_arg["normal_pred_order"]])

        # *_calib aggregation types drop samples whose top-1 decision changed.
        if "calib" in args.path_agg_type:
            if curr_meta_arg["normal_pred"].item()!=curr_pred.item():
                continue
        sample_wise_result["org_decision"].append(curr_meta_arg["normal_pred"].item())
        sample_wise_result["path_decision"].append(curr_pred.item())
        sample_wise_result["org_output"].append(curr_meta_arg["normal_output"].cpu())
        # Intervened output, sorted following the normal prediction order.
        sample_wise_result["path_output"].append(output[sorted_pred_order].cpu())
        cnt+=1
        prgs.set_description("Processing line {:4d}/{:4d}".format(cnt, len(causal_path_result_zip)))

    if not sample_wise_result["org_output"]:
        # torch.vstack() would fail on an empty list with an opaque message.
        raise RuntimeError(
            "no sample was evaluated: no stored causal path matched the dataset "
            "(or every sample was dropped by the calib filter)"
        )

    sample_wise_result["org_decision"] = torch.tensor(sample_wise_result["org_decision"])
    sample_wise_result["path_decision"] = torch.tensor(sample_wise_result["path_decision"])
    sample_wise_result["org_output"] = torch.vstack(sample_wise_result["org_output"])
    sample_wise_result["path_output"] = torch.vstack(sample_wise_result["path_output"])

    sample_wise_result["metric"] = {}
    # Fraction of samples whose top-1 decision is unchanged by the intervention.
    sample_wise_result["metric"]["hit_rate"] = (sample_wise_result["org_decision"]==sample_wise_result["path_decision"]).float().mean().item()

    # Column 0 is the normal top-1 token because both outputs are sorted by the
    # normal prediction order.
    sample_wise_result["metric"]["top1_faithful"] = (sample_wise_result["path_output"][:, 0]/sample_wise_result["org_output"][:, 0]).mean().item()
    # Sorted by the original argsort indices:
    # higher values indicate that the model can make more distinctive (confident, good) predictions,
    # while lower values suggest the model struggles (don't know) to do so.

    sample_wise_result["metric"]["top1_decision_improv"] = (sample_wise_result["path_output"][:, 0] - sample_wise_result["org_output"][:, 0]).mean().item()

    # NOTE: the intervened output is a softmax average, so despite the
    # "logit" in the names these are differences of the stored output scores.
    org_top2_logit_diff = sample_wise_result["org_output"][:, 0] - sample_wise_result["org_output"][:, 1]
    path_top2_logit_diff = sample_wise_result["path_output"][:, 0] - sample_wise_result["path_output"][:, 1]

    sample_wise_result["metric"]["top2_faithful"] = (path_top2_logit_diff/org_top2_logit_diff).mean().item()
    # Sorted by the original argsort indices:
    # higher values indicate that the model can make more distinctive (confident, good) predictions,
    # while lower values suggest the model struggles (don't know) to do so.

    sample_wise_result["metric"]["top2_decision_improv"] = (path_top2_logit_diff - org_top2_logit_diff).mean().item()

    # all faithfulness: gap between the top-1 column and every other column.
    org_all_logit_diff = torch.vstack([sample_wise_result["org_output"][:, 0]-sample_wise_result["org_output"][:, i] for i in range(1, sample_wise_result["org_output"].shape[1])])
    path_all_logit_diff = torch.vstack([sample_wise_result["path_output"][:, 0]-sample_wise_result["path_output"][:, i] for i in range(1, sample_wise_result["path_output"].shape[1])])
    sample_wise_result["metric"]["all_faithful"] = (path_all_logit_diff.mean(0)/org_all_logit_diff.mean(0)).mean().item()
    # Sorted by the original argsort indices:
    # higher values indicate that the model can make more distinctive (confident, good) predictions,
    # while lower values suggest the model struggles (don't know) to do so.

    sample_wise_result["metric"]["all_decision_improv"] = (path_all_logit_diff.mean(0) - org_all_logit_diff.mean(0)).mean().item()

    path_forward_results_dir = os.path.join(args.result_folder, "path_forward_results")
    os.makedirs(path_forward_results_dir, exist_ok=True)

    # Context manager guarantees the metrics are flushed and the file closed.
    with open(os.path.join(path_forward_results_dir, SAVE_JSON_NAME), "w") as f:
        json.dump(sample_wise_result["metric"], f, indent=4 )
    
# Run the evaluation only when the file is executed as a script. Note that the
# command-line arguments are parsed at module level, above, not inside main().
if __name__ == "__main__":
    main()