import os
import argparse
import torch
import matplotlib.pyplot as plt

# Global torch setup: make float64 the default floating-point dtype.
torch.set_default_dtype(torch.float64)
# Re-seed torch's random number generator with a non-deterministic seed.
torch.seed()

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Directory that contains this script (path resolved through symlinks).
script_dir = os.path.split(os.path.realpath(__file__))[0]


def _str2bool(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including ``"False"`` -- becomes ``True``. This converter
    interprets the text instead. The empty string still maps to ``False``,
    which was the only way to switch a flag off under ``type=bool``.

    Args:
        value: Raw command-line text (or an already-converted bool).

    Returns:
        The boolean the text spells out.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def process_nerf_config(argv=None):
    """Parse the NeRF command-line configuration and print a short summary.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding every option defined below.
    """
    # Create an ArgumentParser object
    parser = argparse.ArgumentParser(description="NeRF configuration")

    # Add arguments with default values
    parser.add_argument('--dataname', type=str, default='airplane', help='Dataset name')
    parser.add_argument('--n_samples', type=int, default=32, help='Number of samples')
    parser.add_argument('--n_layers', type=int, default=2, help='Number of layers')
    parser.add_argument('--d_filter', type=int, default=128, help='Filter size')

    parser.add_argument('--n_iters', type=int, default=100000, help='Number of iterations')
    parser.add_argument('--chunksize', type=int, default=2**5, help='Chunk size')

    parser.add_argument('--xyz_step', type=float, default=0.0002, help='XYZ step size')
    parser.add_argument('--ry_step', type=float, default=0.00010, help='RY step size')
    parser.add_argument('--xyz_eps', type=float, default=0.0006, help='XYZ epsilon')
    parser.add_argument('--ry_eps', type=float, default=0.0003, help='RY epsilon')

    parser.add_argument('--hue_step', type=float, default=0.00007, help='Hue step size')
    parser.add_argument('--satur_step', type=float, default=0.00012, help='Saturation step size')
    parser.add_argument('--hue_eps', type=float, default=0.00007, help='Hue epsilon')
    parser.add_argument('--satur_eps', type=float, default=0.00012, help='Saturation epsilon')

    parser.add_argument('--hue_offset', type=float, default=0.0, help='Hue offset')
    parser.add_argument('--satur_offset', type=float, default=0.0, help='Saturation offset')
    parser.add_argument('--input_type', type=str, default='y', help='Input type')
    parser.add_argument('--num_sampling', type=int, default=0, help='Number of samples for testing')

    parser.add_argument('--testimgidx', type=int, default=13, help='Test image index')
    # Boolean options use _str2bool: type=bool would turn "False" into True.
    parser.add_argument('--visual_flag', type=_str2bool, default=True, help='Visual flag')
    parser.add_argument('--bound_whole_flag', type=_str2bool, default=True, help='Bounding flag for whole scene')
    parser.add_argument('--xdown_factor', type=int, default=5, help='X downscale factor')
    # The default (3.75) is fractional, so the option has to be parsed as a float;
    # with type=int the default value itself could not be given on the command line.
    parser.add_argument('--ydown_factor', type=float, default=3.75, help='Y downscale factor')

    parser.add_argument('--tile_height', type=int, default=32, help='Tile height')
    parser.add_argument('--tile_width', type=int, default=32, help='Tile width')

    parser.add_argument('--print_flag', type=_str2bool, default=False, help='Print flag')
    parser.add_argument('--save_npz_flag', type=_str2bool, default=True, help='Save NPZ flag')
    parser.add_argument('--save_img_flag', type=_str2bool, default=True, help='Save image flag')
    parser.add_argument('--save_img_sep_flag', type=_str2bool, default=False, help='Save image separately flag')

    parser.add_argument('--hue_min', type=int, default=-30, help='Minimum hue')
    parser.add_argument('--hue_max', type=int, default=30, help='Maximum hue')
    parser.add_argument('--sat_min', type=float, default=-0.5, help='Minimum saturation')
    parser.add_argument('--sat_max', type=float, default=0.5, help='Maximum saturation')

    parser.add_argument('--near', type=float, default=2.0, help='Near plane')
    parser.add_argument('--far', type=float, default=6.0, help='Far plane')
    parser.add_argument('--distance_to_infinity', type=float, default=1e2, help='Distance to infinity')

    parser.add_argument('--perturb', type=_str2bool, default=False, help='Perturbation flag')
    parser.add_argument('--inverse_depth', type=_str2bool, default=False, help='Inverse depth flag')
    parser.add_argument('--n_samples_hierarchical', type=int, default=0, help='Number of hierarchical samples')

    parser.add_argument('--d_input', type=int, default=3, help='Input dimension')
    parser.add_argument('--env_input', type=int, default=2, help='Environment input dimension')
    parser.add_argument('--n_freqs', type=int, default=10, help='Number of frequencies')
    parser.add_argument('--log_space', type=_str2bool, default=True, help='Logarithmic space flag')
    parser.add_argument('--n_freqs_views', type=int, default=4, help='Number of view frequencies')
    # type=list would split the raw string into single characters ("12" -> ['1', '2']).
    # Accept zero or more values instead, e.g. `--skip 4 6`.
    # NOTE(review): entries are parsed as integers (presumably layer indices) -- confirm.
    parser.add_argument('--skip', type=int, nargs='*', default=[], help='Skip list')
    parser.add_argument('--raw_noise_std', type=float, default=0.0, help='Raw noise standard deviation')

    parser.add_argument('--conf_threshold', type=float, default=-1.5, help='Detection confidence threshold')
    parser.add_argument('--manual_focal_factor', type=float, default=1.0, help='Manual focal factor')

    # Parse the arguments (argv=None makes argparse fall back to sys.argv[1:])
    args = parser.parse_args(argv)

    # Echo the main settings so they end up in the run log
    print("Configurations:")
    print(f"Dataset name: {args.dataname}")
    print(f"Number of samples: {args.n_samples}")
    print(f"Number of layers: {args.n_layers}")
    print(f"Filter size: {args.d_filter}")
    print(f"Number of iterations: {args.n_iters}")
    print(f"Chunk size: {args.chunksize}")
    print(f"input_type: {args.input_type}")

    # Report the epsilon(s) that belong to the selected input type;
    # an unrecognised input type prints no perturbation line.
    input_type = args.input_type
    if input_type in ("xyz", "x", "y", "z"):
        print(f"Perturbation: {args.xyz_eps}")
    elif input_type in ("ry", "roll", "yaw"):
        print(f"Perturbation: {args.ry_eps}")
    elif input_type == "xyzry":
        print(f"Perturbation: {args.xyz_eps,args.ry_eps}")
    elif input_type == "hue":
        print(f"Perturbation: {args.hue_eps}")
    elif input_type == "satur":
        print(f"Perturbation: {args.satur_eps}")
    elif input_type == "env":
        print(f"Perturbation: {args.hue_eps,args.satur_eps}")

    print(f"Visual flag: {args.visual_flag}")
    print(f"Save image flag: {args.save_img_flag}")

    return args


if __name__ == "__main__":
    # Post-process saved detection-confidence results: plot the confidence
    # curves, threshold them, print the fraction above threshold, and save the
    # boolean masks and a red/green bar plot next to the inputs.

    # Call process_nerf_config to get the arguments
    args = process_nerf_config()

    # Results are read from and written to yolo_results/<mark>/ beside this script.
    # The mark always embeds xyz_eps / xyz_step, whatever input_type is selected.
    mark = f"{args.input_type}_{args.dataname}_{args.testimgidx}_xyzeps{args.xyz_eps}_xyzstep{args.xyz_step}_focalfactor{args.manual_focal_factor}"
    save_dir = os.path.join(script_dir, 'yolo_results', mark)

    # map_location="cpu" lets results that contain GPU tensors load on a
    # CPU-only machine and be handed straight to matplotlib.
    bounded_conf_max_all = torch.load(os.path.join(save_dir, 'bounded_conf_max_all.pt'), map_location="cpu")
    unperturbed_conf_all = torch.load(os.path.join(save_dir, 'unperturbed_conf_all.pt'), map_location="cpu")
    # The next three objects are not used further in this script; loading them
    # still checks that the files exist and are readable.
    max_conf_idx_all = torch.load(os.path.join(save_dir, 'max_conf_idx_all.pt'), map_location="cpu")
    reg_lower_all = torch.load(os.path.join(save_dir, 'reg_lower_all.pt'), map_location="cpu")
    reg_upper_all = torch.load(os.path.join(save_dir, 'reg_upper_all.pt'), map_location="cpu")

    # Confidence curves; the legend is required for the labels to show up.
    plt.plot(bounded_conf_max_all, label='Bounded Conf Max')
    plt.plot(unperturbed_conf_all, label='Unperturbed Conf')
    plt.legend()
    plt.savefig(os.path.join(save_dir, 'conf_plot.png'))
    plt.close()

    # NOTE(review): sigmoid() yields values in (0, 1), so with the default
    # --conf_threshold of -1.5 every entry compares True. Confirm whether the
    # threshold is meant in logit space (compare the raw values) or in
    # probability space (use a threshold inside (0, 1)).
    overthreshold = torch.sigmoid(torch.tensor([bounded_conf_max_all])) > args.conf_threshold
    n_bounded = overthreshold.numel()
    # Empty results would otherwise raise ZeroDivisionError; report NaN instead.
    bounded_rate = overthreshold.sum().item() / n_bounded if n_bounded else float('nan')
    print("Verified rate: ", bounded_rate)
    torch.save(overthreshold, os.path.join(save_dir, 'overthreshold.pt'))

    overthreshold_unperturbed = torch.sigmoid(torch.tensor([unperturbed_conf_all])) > args.conf_threshold
    n_unperturbed = overthreshold_unperturbed.numel()
    unperturbed_rate = overthreshold_unperturbed.sum().item() / n_unperturbed if n_unperturbed else float('nan')
    print("Verified rate unperturbed: ", unperturbed_rate)
    torch.save(overthreshold_unperturbed, os.path.join(save_dir, 'overthreshold_unperturbed.pt'))

    # One vertical bar per entry: green when above the threshold, red otherwise.
    # A single vlines call with a colour list replaces one call per entry.
    plt.figure()
    flags = overthreshold.flatten().tolist()
    if flags:
        bar_colors = ['green' if flag else 'red' for flag in flags]
        plt.vlines(list(range(len(flags))), ymin=0, ymax=1, colors=bar_colors, linewidth=2)
    plt.savefig(os.path.join(save_dir, 'overthreshold_plot.png'))
    plt.close()
    