#!/usr/bin/env python

import os
import subprocess
import re
import random 
import argparse
from tqdm import tqdm
import sys

# Batch driver: runs the ``ReX`` command-line tool on a set of JPEG images and
# appends one CSV row per image with the numbers parsed from ReX's output.

# Matches optionally signed integers, decimals and scientific notation.
NUMBER_PATTERN = re.compile(r"[+-]? *(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?")

# File-name suffixes (case-sensitive) that are treated as input images.
IMAGE_EXTENSIONS = (".jpeg", ".jpg")

# Upper bound on the number of images processed per run.
MAX_FILES = 150

# Column names written once at the top of the results CSV.
CSV_HEADER = "Filename, actual classification, predicted classification, area, responsibility entropy, insertion curve, deletion curve, time\n"


def build_parser():
    """Return the argument parser describing this script's command line."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="scripts/pytorch_script_rex_AD_CT-256_mbnet_v4.py")
    parser.add_argument("--config", type=str, default="toml_files_mbnet_v4/CT-256/rex_AD_CT-256_mbnet_v4.toml")
    parser.add_argument("--database", type=str, default="CalTech-256/Results_mbnet_v4/rex_AD_seed_42_threshold_0.9.db")
    parser.add_argument("--dir", type=str, default="CalTech-256/Dataset/exp_test")
    parser.add_argument("--output_dir", type=str, default="CalTech-256/Results_mbnet_v4/MBNet_v4_AD_seed_42_threshold_0.9")
    return parser


def results_csv_path(output_dir):
    """Return ``<output_dir>/<last component of output_dir>.csv``.

    The directory is normalised first so that a trailing slash does not
    produce an empty base name (which would yield a file called ``.csv``).
    """
    base_name = os.path.basename(os.path.normpath(output_dir))
    return os.path.join(output_dir, base_name + ".csv")


def explanation_png_path(output_dir, image_path):
    """Return the ``.png`` path inside *output_dir* for *image_path*.

    The image's own extension is removed whatever it is, so ``cat.jpg`` and
    ``cat.jpeg`` both map to ``cat.png``.
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    return os.path.join(output_dir, stem + ".png")


def collect_images(directory, limit=MAX_FILES):
    """Return at most *limit* image paths found recursively under *directory*.

    NOTE(review): the order comes straight from ``os.walk`` and therefore
    depends on the file system, so which files make the cut can differ between
    machines. Confirm whether a sorted or seeded selection was intended.
    """
    found = []
    for root, _dirs, files in os.walk(directory):
        for file_name in files:
            if file_name.endswith(IMAGE_EXTENSIONS):
                found.append(os.path.join(root, file_name))
    return found[:limit]


def parse_metrics(line):
    """Extract the CSV fields carried by one line of ReX output.

    All numbers on the line are collected, the first three are discarded and
    the fourth is decremented by one (presumably converting a 1-based class
    label to 0-based; TODO confirm against ReX's output format).

    Returns:
        A list of whitespace-stripped field strings, or ``None`` when the line
        has fewer than four numbers or its fourth number is not an integer.
        Such lines are treated as ordinary log output instead of raising.
    """
    numbers = NUMBER_PATTERN.findall(line)[3:]
    if not numbers:
        return None
    try:
        numbers[0] = str(int(numbers[0]) - 1)
    except ValueError:
        return None
    return [number.strip() for number in numbers]


def run_rex(image_path, args):
    """Run ReX on *image_path*, echo its output and return the parsed fields.

    Every stdout line is printed. The fields of the last line accepted by
    ``parse_metrics`` are returned, or ``None`` if no line was accepted.
    """
    command = [
        "ReX", image_path,
        "--script", args.model,
        "--config", args.config,
        "--analyse",
        "--database", args.database,
        "--output", explanation_png_path(args.output_dir, image_path),
    ]
    fields = None
    # The context manager closes the pipe and waits for the child to exit.
    with subprocess.Popen(command, stdout=subprocess.PIPE) as process:
        for raw_line in process.stdout:  # type: ignore
            line = raw_line.decode(errors="replace").strip()
            print(line)
            parsed = parse_metrics(line)
            if parsed is not None:
                print(parsed)
                fields = parsed
    if process.returncode != 0:
        print(f"ReX exited with status {process.returncode} for {image_path}", file=sys.stderr)
    return fields


def ensure_header(csv_path):
    """Write the CSV header if *csv_path* is missing or empty.

    Rows are appended across runs, so the header must not be repeated when the
    file already has content.
    """
    if not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0:
        with open(csv_path, "a") as handle:
            handle.write(CSV_HEADER)


def main():
    """Parse the command line and process every selected image."""
    args = build_parser().parse_args()

    out = results_csv_path(args.output_dir)
    os.makedirs(args.output_dir, exist_ok=True)

    # No call into ``random`` is visible in this script; the seed is kept as in
    # the original in case downstream code relies on it.
    random.seed(42)
    all_files = collect_images(args.dir)

    ensure_header(out)

    for fp in tqdm(all_files):
        fields = run_rex(fp, args)
        # "None" is what has always been written when no metrics were found.
        row = ",".join(fields) if fields is not None else "None"
        # Re-open per image so finished rows survive an interrupted run.
        with open(out, "a") as e:
            e.write(f"{fp},{row}\n")


if __name__ == "__main__":
    main()

