import os
# Set before any further imports. Presumably this disables TensorFlow's oneDNN
# optimizations in case TensorFlow gets imported transitively — confirm.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
from lib.args import arg_func
# Arguments are parsed at import time (presumably from the command line). This
# module-level `args` is read throughout main(), and it must exist before the
# CUDA device mask below can be applied.
args = arg_func()
# GPU
# Restrict the visible GPUs before `torch` is imported further down, so CUDA
# only ever sees the selected device(s).
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

import torch
import numpy as np
import random
import pickle

from lib.utils import *
from lib.causal_flow_tracer import CausalFlowTracer
from lib.search_func import top_down_causal_flow_trace


def _seed_everything(random_seed=0):
    """Seed the Python, NumPy and torch RNGs and request deterministic kernels.

    Args:
        random_seed: Seed applied to every random number generator.
    """
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)
    # NOTE: PYTHONHASHSEED is read at interpreter start-up. Setting it here only
    # affects interpreters launched afterwards (e.g. subprocesses), not the
    # hash seed of the current process.
    os.environ['PYTHONHASHSEED'] = str(random_seed)
    # cuBLAS workspace setting used for reproducible GEMMs. It has to be in the
    # environment before the first cuBLAS handle is created, i.e. before the
    # model is run.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"


def _parse_index_list(raw):
    """Parse a comma-separated sample index list such as "3, 7, 12" into ints.

    Whitespace around entries and empty entries (e.g. a trailing comma) are
    ignored. This means "3,7" and "3, 7," are accepted in addition to the
    canonical "3, 7" form. A non-string iterable is converted element-wise.

    Args:
        raw: The index list as a string, or an iterable of int-convertible values.

    Returns:
        list[int]: The parsed indices, in input order.
    """
    if isinstance(raw, str):
        return [int(tok) for tok in raw.split(",") if tok.strip()]
    return [int(i) for i in raw]


def _parse_index_range(raw):
    """Parse an inclusive sample index range "start_end", or "start_-" for no upper bound.

    Args:
        raw: Range string such as "10_20" or "10_-".

    Returns:
        tuple[int, int | None]: ``(start, end)``, where ``end`` is None when open-ended.
    """
    parts = raw.split("_")
    start = int(parts[0])
    end = None if parts[1] == "-" else int(parts[1])
    return start, end


def _predict_kwargs(args):
    """Build the keyword arguments shared by every predict_from_normal_and_noise_input call."""
    return dict(
        end_symbol=args.end_symbol,
        out_num=1,
        # NOTE(review): the normal-sample count reuses args.num_noise_sample.
        # Confirm this is intended rather than a separate normal-sample option.
        num_normal_sample=args.num_noise_sample,
        num_noise_sample=args.num_noise_sample,
        noise_type=args.noise_type,
    )


def _run_correct_check(args, line_data, mt, n_lev, logger, func_time_saver):
    """Run prediction in correct-check-only mode over every sample.

    The indices of samples with a truthy result are saved to
    ``<args.save_root>/correct_data_idx.txt``, one per line.
    """
    logger.info("[Correct Check Start]")
    func_time_saver.logging(logger, header="\n\n===============Correct Check Log===============")
    predict_kwargs = _predict_kwargs(args)
    correct_ids = []
    for li, statement in tqdm.tqdm(enumerate(line_data), desc="Correct Check First"):
        prompt = statement["prompt"]
        y = statement["attribute"]

        flow_tracer = CausalFlowTracer(args)
        correctness = predict_from_normal_and_noise_input(
            prompt, flow_tracer, y, logger, mt, n_lev,
            correct_check_only=True, **predict_kwargs)
        if correctness:
            correct_ids.append(li)
    logger.info("\t Correct Data Num.: {}".format(len(correct_ids)))
    np.savetxt(os.path.join(args.save_root, "correct_data_idx.txt"), np.asarray(correct_ids), fmt="%s")


def _write_traced_paths_json(path, traced_paths):
    """Write ``traced_paths`` as JSON, with each path on its own line.

    Layout::

        {
          "<key>": [
            [i, j, ...],
            ...
          ],
          ...
        }

    Keys are stringified and JSON-escaped. Path elements are cast to ``int`` so
    that numpy or torch integer scalars serialize correctly.

    Args:
        path: Output file path.
        traced_paths: Mapping from key to an iterable of paths (iterables of ints).
    """
    import json

    entries = []
    for key, paths in traced_paths.items():
        rows = ",\n".join("    " + json.dumps([int(i) for i in p]) for p in paths)
        entry = "  {}: [\n".format(json.dumps(str(key), ensure_ascii=False))
        entry += (rows + "\n" if rows else "") + "  ]"
        entries.append(entry)

    with open(path, "w", encoding="utf-8") as f:
        f.write("{\n" + (",\n".join(entries) + "\n" if entries else "") + "}\n")


def _trace_sample(args, li, statement, mt, n_lev, logger):
    """Run the top-down causal flow trace for one sample and save its results.

    Samples whose prediction returns the -1 sentinel are logged as incorrect and
    skipped. When the search does not report ``stop_flag is True``, the
    input info (``I<li>.txt``) and traced paths (``C<li>.json``) are written.
    ``args.curr_save_dir`` is set to the per-sample result directory as a side effect.
    """
    prompt = statement["prompt"]
    y = statement["attribute"]

    flow_tracer = CausalFlowTracer(args)
    flow_tracer, answer, curr_total_token_num, curr_total_block_num, _ = predict_from_normal_and_noise_input(
        prompt, flow_tracer, y, logger, mt, n_lev, **_predict_kwargs(args))

    # A non-list answer equal to -1 marks an incorrect prediction: skip the sample.
    if not isinstance(answer, list) and answer == -1:
        logger.info("--------------------")
        logger.info("[Incorrect Answer Pass] Index:{}, Prompt:{}".format(li, prompt))
        return

    logger.info("--------------------")
    logger.info("[Searching Start] Index:{}, Prompt:{}, Ans:{}".format(li, prompt, y))
    args.curr_save_dir = os.path.join(args.save_result_root, "R{:04d}".format(li))
    os.makedirs(args.curr_save_dir, exist_ok=True)

    curr_meta_arg = {
        "total_block_num": curr_total_block_num,
        "total_token_num": curr_total_token_num,
        "prompt": prompt,
        "mt": mt,
        "logger": logger,
    }
    stop_flag = top_down_causal_flow_trace(
        args=args, curr_meta_arg=curr_meta_arg, flow_tracer=flow_tracer)

    # Results are only persisted when the search did not stop early.
    if stop_flag is False:
        success_logger = ["li:{}\nprompt:{}\ny:{}".format(li, prompt, y)]
        np.savetxt(os.path.join(args.save_inpinfo_root, "I{:06d}.txt".format(li)), success_logger, fmt="%s")
        _write_traced_paths_json(
            os.path.join(args.curr_save_dir, "C{:06d}.json".format(li)),
            flow_tracer.traced_paths)


def main():
    """Run the causal flow tracing pipeline over the loaded dataset.

    Configuration comes from the module-level ``args``. The steps are:
    seed every RNG, load the data, model and noise level, optionally run a
    correct-check pass, then trace each selected sample.

    Sample selection filters, all optional and combinable:
      * ``args.target_sample_index_list``: comma-separated indices to keep.
      * ``args.target_sample_index``: inclusive range "start_end" or "start_-".
      * ``args.debug_index``: keep only this single index.
    """
    random_seed = 0
    torch.set_grad_enabled(False)
    _seed_everything(random_seed)

    logger, func_time_saver = set_utils(args)
    line_data = get_data(args, reverse=args.reverse)

    mt = get_model(args)
    n_lev = get_noise_level(args, mt)
    func_time_saver.logging(logger, header="\n\n===============Prepare Function Log===============")

    if args.correct_check_first:
        _run_correct_check(args, line_data, mt, n_lev, logger, func_time_saver)

    # Parse the selection filters once, outside the data loop.
    selected_ids = None
    if args.target_sample_index_list is not None:
        args.target_sample_index_list = _parse_index_list(args.target_sample_index_list)
        selected_ids = set(args.target_sample_index_list)

    index_range = None
    if args.target_sample_index is not None:
        index_range = _parse_index_range(args.target_sample_index)

    # Data Iteration
    for li, statement in tqdm.tqdm(enumerate(line_data)):
        if selected_ids is not None and li not in selected_ids:
            continue
        if index_range is not None:
            start, end = index_range
            if li < start or (end is not None and li > end):
                continue
        if args.debug_index is not None and li != args.debug_index:
            continue

        _trace_sample(args, li, statement, mt, n_lev, logger)


# Script entry point: run the tracing pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()