import json
import os
import pickle
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, Subset

from lib.args import arg_func
from lib.utils import *
from lib.causal_flow_tracer import CausalFlowTracer
from lib.search_func import top_down_causal_flow_trace


def _seed_everything(seed):
    """Seed every RNG the script uses and request deterministic cuDNN kernels.

    Seeds torch (CPU and all CUDA devices), NumPy and ``random``, disables
    cuDNN benchmarking and exports the reproducibility environment variables.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    # NOTE: PYTHONHASHSEED is read at interpreter startup, so this only affects
    # child processes started later, not hash randomization in this process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Deterministic cuBLAS workspace (see PyTorch reproducibility notes); must
    # be set before the first cuBLAS call, which happens later in main().
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"


def _write_traced_paths_json(path, traced_paths):
    """Write ``traced_paths`` to ``path`` as JSON, one traced path per line.

    Keys are stringified and every path element is cast to ``int``. Keys and
    paths are encoded with ``json.dumps`` so the output is always valid JSON,
    even for keys containing quotes or backslashes.
    """
    items = [
        (str(key), [[int(i) for i in path] for path in paths])
        for key, paths in traced_paths.items()
    ]
    lines = ["{"]
    for pos, (key, paths) in enumerate(items):
        lines.append(f"  {json.dumps(key, ensure_ascii=False)}: [")
        for j, p in enumerate(paths):
            comma = "," if j < len(paths) - 1 else ""
            lines.append(f"    {json.dumps(p)}{comma}")
        lines.append("  ]" + ("," if pos < len(items) - 1 else ""))
    lines.append("}")
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")


def main():
    """Run the top-down causal flow trace over the selected dataset samples.

    For each sample, the prediction is computed from normal and
    noise-corrupted inputs. Samples whose answer is reported as ``-1`` are
    logged as incorrect and skipped. Otherwise ``top_down_causal_flow_trace``
    is run. If it returns ``False``, the sample info (``I{idx:06d}.txt``) and
    the traced paths (``C{idx:06d}.json``) are saved.
    """
    torch.set_grad_enabled(False)
    _seed_everything(0)

    args = arg_func()
    logger, func_time_saver = set_utils(args)

    # Load dataset
    dataset, num_classes = get_data(args.dataset)

    if args.target_sample_idx is not None:
        dataset = Subset(dataset, [args.target_sample_idx])

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    # Load model
    mt = get_model(model_name=args.model, num_classes=num_classes, dataset_name=args.dataset)

    func_time_saver.logging(logger, header="\n\n===============Prepare Function Log===============")

    for batch_idx, (images, labels) in enumerate(dataloader):
        if args.debug_index is not None and batch_idx != args.debug_index:
            continue

        # Index of this sample in the full dataset. With a one-element Subset,
        # batch_idx is always 0, so prefer the requested target index.
        sample_idx = args.target_sample_idx if args.target_sample_idx is not None else batch_idx

        # Move images to GPU
        images = images.cuda()
        labels = labels.cuda()
        flow_tracer = CausalFlowTracer()
        # NOTE(review): num_normal_sample is fed args.num_noise_sample; confirm
        # whether a separate "num_normal_sample" argument was intended.
        flow_tracer, answer, curr_total_token_num, curr_total_block_num, _normal_inp = predict_from_normal_and_noise_input(
            images, flow_tracer, labels, logger,
            mt, n_lev=None, end_symbol=None, out_num=1,
            model_name=args.model,
            dataset_name=args.dataset,
            num_normal_sample=args.num_noise_sample,
            num_noise_sample=args.num_noise_sample,
            noise_type=args.noise_type)

        label = labels[0].item()

        # A non-list answer of -1 marks an incorrect prediction: skip the sample.
        if not isinstance(answer, list) and answer == -1:
            logger.info("--------------------")
            logger.info("[Incorrect Answer Pass] Index:{}, Label:{}".format(sample_idx, label))
            continue

        logger.info("--------------------")
        logger.info("[Searching Start] Index:{}, Label:{}, Pred:{}".format(sample_idx, label, answer))

        # The search routine reads its output directory from args.
        args.curr_save_dir = os.path.join(args.save_result_root, "R{:04d}".format(sample_idx))
        os.makedirs(args.curr_save_dir, exist_ok=True)

        curr_meta_arg = {
            "total_block_num": curr_total_block_num,
            "total_token_num": curr_total_token_num,
            "images": images,
            "mt": mt,
            "logger": logger,
        }

        stop_flag = top_down_causal_flow_trace(
            args=args, curr_meta_arg=curr_meta_arg, flow_tracer=flow_tracer)

        # Results are saved only when the search returns exactly False.
        if stop_flag is not False:
            continue

        success_logger = ["idx:{}\nlabel:{}\npred:{}".format(sample_idx, label, answer[0])]
        np.savetxt(os.path.join(args.save_inpinfo_root, "I{:06d}.txt".format(sample_idx)), success_logger, fmt="%s")
        _write_traced_paths_json(
            os.path.join(args.curr_save_dir, "C{:06d}.json".format(sample_idx)),
            flow_tracer.traced_paths)


if __name__ == "__main__":
    main()