#!/usr/bin/env python

import os
import subprocess
import re
import pickle
import random 
import argparse
from tqdm import tqdm

# Batch driver: runs the external `ReX ... --analyse` command on each image and
# appends the numbers it reports to a CSV inside --output_dir.
parser = argparse.ArgumentParser(
    description="Run ReX --analyse over a set of images and collect the results in a CSV."
)
parser.add_argument("--model", type=str, default="scripts/pytorch_script_rex_mask_IN-1k_regnety.py",
                    help="model script passed to ReX via --script")
parser.add_argument("--config", type=str, default="rex_in1kv2.toml",
                    help="config file passed to ReX via --config")
parser.add_argument("--database", type=str, default="IN-1k_v2/Results_regnet/rex_regnet_mask_seed_42_threshold_0.9.db",
                    help="database path passed to ReX via --database")
parser.add_argument("--dir", type=str, default="IN-1k_v2/Dataset",
                    help="dataset root scanned for .jpeg files (only used when --pickle is empty)")
parser.add_argument("--output_dir", type=str, default="IN-1k_v2/Results_regnet/RegNet_masking_seed_42_threshold_0.9",
                    help="directory for ReX output images and the results CSV (created if missing)")
# NOTE: the default is non-empty, so the --dir scan only happens when the
# caller explicitly passes an empty string: --pickle ""
parser.add_argument("--pickle", type=str, default="outfile.pkl",
                    help='pickled list of image paths to process; pass --pickle "" to scan --dir instead')

args = parser.parse_args()

# The results CSV lives inside --output_dir and is named after that
# directory's last path component (e.g. ".../RunX/RunX.csv").
# normpath strips a trailing separator, so "results/" still yields
# "results/results.csv" rather than a file literally named ".csv".
out = os.path.join(
    args.output_dir,
    os.path.basename(os.path.normpath(args.output_dir)) + ".csv",
)

# exist_ok replaces the racy "check exists, then create" pattern.
os.makedirs(args.output_dir, exist_ok=True)

if not args.pickle:
    # Build the image list by walking --dir. The seed is kept for
    # reproducibility, although `random` is not used further in this script.
    random.seed(42)
    all_files = []
    for root, dirs, files in os.walk(args.dir):
        # os.walk yields entries in arbitrary filesystem order. Sorting dirs
        # in place makes the top-down traversal deterministic, and sorting
        # files does the same within each directory. Together they make the
        # "first 150" subset below identical across machines and runs.
        dirs.sort()
        for file in sorted(files):
            if file.endswith('.jpeg'):
                all_files.append(os.path.join(root, file))

    # Only run on the first 150 images
    all_files = all_files[:150]
else:
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # --pickle at a trusted file. Presumably it holds a list of image
    # paths, which is how the loop below consumes it - confirm.
    with open(args.pickle, 'rb') as f:
        all_files = pickle.load(f)

# Results are appended row by row, so a re-run continues the same file.
# Write the header only when the file is new or empty; otherwise every
# re-run would insert a duplicate header row in the middle of the data.
if not os.path.exists(out) or os.path.getsize(out) == 0:
    with open(out, "a") as e:
        e.write("Filename, target, classification, area, KL Divergence, robustness, insertion curve, deletion curve, time\n")

# Matches signed ints, decimals (including a trailing ".", e.g. "3.") and
# scientific notation. Compiled once rather than on every output line.
NUMBER_RE = re.compile(r"[+-]? *(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?")

for fp in tqdm(all_files):
    # ReX's output image is named after the input file with ".jpeg" removed.
    stem = os.path.basename(fp).split(".jpeg")[0]
    # Stays None (and is written as "None") if ReX prints nothing.
    rr = None
    cmd = [
        "ReX", fp,
        "--script", args.model,
        "--config", args.config,
        "--analyse",
        "--database", args.database,
        "--output", f"{args.output_dir}/{stem}.png",
    ]
    # The context manager closes the stdout pipe and waits for the child, so
    # no zombie processes or open descriptors pile up across the loop.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE) as process:
        for raw in process.stdout:  # type: ignore
            line = raw.decode(errors="replace").strip()
            print(line)
            tokens = NUMBER_RE.findall(line)

            # Keep the first number that ends with "." (stored without its
            # trailing period) plus every number that follows it on the line.
            found_class = False
            new_rr = []
            for r in tokens:
                if found_class:
                    new_rr.append(r)
                elif r.endswith("."):
                    found_class = True
                    new_rr.append(r.strip("."))

            # Overwritten on each line: only the last stdout line is recorded.
            rr = ",".join(r.strip() for r in new_rr)

    if process.returncode != 0:
        tqdm.write(f"warning: ReX exited with code {process.returncode} for {fp}")

    # Append per image so progress survives an interrupted run.
    with open(out, "a") as e:
        e.write(f"{fp},{rr}\n")

