#!/usr/bin/env python3

import argparse
import os
import random
import re
import subprocess
import sys

from tqdm import tqdm

# Command-line interface. Every option is a path; the defaults point at the
# mbnet_v4 / IN-1k masking run (seed 42, threshold 0.9).
parser = argparse.ArgumentParser(
    description="Run ReX (with --analyse) on each .jpeg image found under "
    "--dir and append the extracted metrics to a CSV in --output_dir."
)
parser.add_argument(
    "--model",
    type=str,
    default="scripts/pytorch_script_rex_mask_IN-1k_mbnet_v4.py",
    help="model script passed to ReX via --script",
)
parser.add_argument(
    "--config",
    type=str,
    default="toml_files_mbnet_v4/IN-1k/rex_mask_val_avg_in1k_mbnet_v4.toml",
    help="TOML config file passed to ReX via --config",
)
parser.add_argument(
    "--database",
    type=str,
    default="ImageNet-onek/Results_mbnet_v4/rex_mask_seed_42_threshold_0.9.db",
    help="database path passed to ReX via --database",
)
parser.add_argument(
    "--dir",
    type=str,
    default="ImageNet-onek/IN-onek_data",
    help="directory searched recursively for .jpeg input images",
)
parser.add_argument(
    "--output_dir",
    type=str,
    default="ImageNet-onek/Results_mbnet_v4/MBNet_v4_masking_seed_42_threshold_0.9",
    help="directory for the per-image .png outputs and the results CSV "
    "(the CSV is named after this directory)",
)

args = parser.parse_args()

# The results CSV lives inside the output directory and is named after it,
# e.g. "<dir>/MBNet_v4_masking_seed_42_threshold_0.9/MBNet_v4_masking_seed_42_threshold_0.9.csv".
# normpath drops a trailing separator first, so "results/" still yields
# "results.csv" rather than an unnamed ".csv"; basename also handles the
# platform's own path separator.
out = os.path.join(
    args.output_dir,
    os.path.basename(os.path.normpath(args.output_dir)) + ".csv",
)

# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs(args.output_dir, exist_ok=True)

# NOTE(review): nothing in this script appears to draw from `random`; the seed
# is kept so that any sampling added later stays reproducible.
random.seed(42)

# Collect every .jpeg under --dir. os.walk yields entries in filesystem-
# dependent order, so sort the sub-directory list in place (which fixes the
# order the walk descends in) and the file names, making the subset selected
# below identical across machines.
all_files = []
for root, dirs, files in os.walk(args.dir):
    dirs.sort()
    all_files.extend(
        os.path.join(root, file) for file in sorted(files) if file.endswith(".jpeg")
    )

# Only run on the first MAX_IMAGES images (in the deterministic order above).
MAX_IMAGES = 150
all_files = all_files[:MAX_IMAGES]

# Write the CSV header only when starting a new (missing or empty) results
# file. Rows are appended, so unconditionally writing the header would insert
# a duplicate header row mid-file every time the script is re-run.
# NOTE(review): the columns after "Filename" are assumed to line up with the
# numeric fields parsed from ReX's output per image; confirm if ReX changes.
if not os.path.exists(out) or os.path.getsize(out) == 0:
    with open(out, "a") as e:
        e.write("Filename, target, classification, area, KL Divergence, robustness, insertion curve, deletion curve, time\n")

# Matches signed/unsigned integers, decimals and scientific notation. An
# optional sign may be separated from its digits by spaces (e.g. "- 0.5"),
# so spaces are removed from each match before writing it out.
NUMBER_RE = re.compile(r"[+-]? *(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?")

# Run ReX once per image, echo its output, and append one CSV row per image
# containing the numbers parsed from ReX's last output line.
for fp in tqdm(all_files):
    # The per-image ReX output is written to <output_dir>/<input stem>.png.
    stem = os.path.splitext(os.path.basename(fp))[0]
    cmd = [
        "ReX", fp,
        "--script", args.model,
        "--config", args.config,
        "--analyse",
        "--database", args.database,
        "--output", os.path.join(args.output_dir, f"{stem}.png"),
    ]

    values = []
    # The context manager closes the pipe and waits for ReX, so no zombie
    # process is left behind and returncode is set afterwards. errors="replace"
    # keeps a stray non-UTF-8 byte from aborting the whole run.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, encoding="utf-8", errors="replace") as process:
        for line in process.stdout:  # type: ignore
            line = line.strip()
            # tqdm.write keeps the progress bar intact while echoing output.
            tqdm.write(line)
            # Only the last line's numbers survive the loop. Its first two
            # numeric tokens are dropped -- presumably identifiers rather than
            # metrics; confirm against ReX's --analyse output format.
            values = [m.replace(" ", "") for m in NUMBER_RE.findall(line)[2:]]

    if process.returncode != 0:
        tqdm.write(f"warning: ReX exited with status {process.returncode} for {fp}", file=sys.stderr)

    # Opened per image so results collected so far survive an interrupted run.
    # With no parsed values the row holds just the filename (never "None").
    with open(out, "a") as e:
        e.write(",".join([fp, *values]) + "\n")

