#!/usr/bin/env python

import os
import subprocess
import re
import random 
import argparse
from tqdm import tqdm
import sys
import pickle

# Command-line interface. Every path is forwarded to the ReX CLI or used to
# locate inputs/outputs for this batch run.
parser = argparse.ArgumentParser(
    description="Run ReX on a set of images and append one CSV row of metrics per image."
)
parser.add_argument("--model", type=str, default="scripts/pytorch_script_rex_mask_CT-256_mbnet_v4.py",
                    help="Model script passed to ReX via --script.")
parser.add_argument("--config", type=str, default="toml_files_mbnet_v4/CT-256/rex_mask_val_avg_CT-256_mbnet_v4.toml",
                    help="TOML config file passed to ReX via --config.")
parser.add_argument("--database", type=str, default="CalTech-256/Results_mbnet_v4/rex_mask_seed_42_threshold_0.9.db",
                    help="Database path passed to ReX via --database.")
parser.add_argument("--dir", type=str, default="CalTech-256/Dataset/exp_test",
                    help="Directory searched recursively for .jpeg/.jpg images "
                         "(only used when --pickle is empty).")
parser.add_argument("--output_dir", type=str, default="CalTech-256/Results_mbnet_v4/MBNet_v4_masking_seed_42_threshold_0.9",
                    help="Directory for ReX output images and the results CSV (created if missing).")
# NOTE: because this has a non-empty default, the directory scan over --dir
# only happens when the caller explicitly passes --pickle "".
parser.add_argument("--pickle", type=str, default="outfile.pkl",
                    help='Pickled list of image paths to process; pass "" to scan --dir instead.')

args = parser.parse_args()

# The results CSV lives inside the output directory and is named after it,
# e.g. ".../<name>/<name>.csv". normpath drops a trailing separator so that
# "dir/" still yields "dir.csv" rather than ".csv", and basename respects the
# platform's path separator instead of assuming "/".
out = os.path.join(
    args.output_dir,
    os.path.basename(os.path.normpath(args.output_dir)) + ".csv",
)

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(args.output_dir, exist_ok=True)

if not args.pickle:
    # Directory-scan mode (only reached with --pickle ""): collect every
    # .jpeg/.jpg file under --dir.
    # NOTE(review): random is seeded but nothing in this script draws from it;
    # confirm whether a shuffle of all_files was intended.
    random.seed(42)
    all_files = []
    for root, dirs, files in os.walk(args.dir):
        # os.walk yields entries in filesystem order, which differs between
        # machines. Sorting dirs in place fixes the descent order, and sorting
        # files fixes the order within each directory, so the "first 150"
        # selection below is reproducible.
        dirs.sort()
        for file in sorted(files):
            if file.endswith(('.jpeg', '.jpg')):
                all_files.append(os.path.join(root, file))

    # Only run on the first 150 images.
    all_files = all_files[:150]
else:
    # NOTE: pickle.load can execute arbitrary code; only pass trusted files.
    with open(args.pickle, 'rb') as f:
        all_files = pickle.load(f)

# The CSV is opened in append mode so results accumulate across runs. Write
# the header only when the file is empty; otherwise every re-run would insert
# another header row in the middle of the data. Append mode positions the
# stream at end-of-file on open, so tell() == 0 means the file is empty.
with open(out, "a") as e:
    if e.tell() == 0:
        e.write("Filename, actual classification, predicted classification, area, responsibility entropy, insertion curve, deletion curve, time\n")

# Matches signed ints/floats (optionally in scientific notation) in ReX's
# console output. Compiled once rather than on every output line.
_NUMBER_RE = re.compile(r"[+-]? *(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?")

# Run ReX once per image, echo its output, and append one CSV row per image.
for name in tqdm(all_files):
    fp = name
    # Strip whichever extension the image has (.jpeg or .jpg) so the output is
    # always "<stem>.png", never "<stem>.jpg.png".
    stem = os.path.splitext(os.path.basename(fp))[0]
    rr = None
    cmd = [
        'ReX', fp,
        "--script", args.model,
        "--config", args.config,
        "--analyse",
        "--database", args.database,
        '--output', os.path.join(args.output_dir, f"{stem}.png"),
    ]
    # The context manager closes the pipe and waits for the child, so finished
    # ReX processes are reaped instead of accumulating as zombies.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE) as process:
        for line in process.stdout:  # type: ignore
            line = line.decode(errors="replace").strip()
            print(line)
            # The first three numbers on a line are discarded. The CSV fields
            # are the numbers that follow. NOTE(review): which ReX line carries
            # the metrics is not visible here; the last parsable line wins.
            numbers = _NUMBER_RE.findall(line)[3:]
            if not numbers:
                # Line has too few numbers to be a result line; skip it rather
                # than raising IndexError on numbers[0].
                continue
            # The regex allows whitespace between sign and digits ("- 3"),
            # which int() rejects, so drop spaces before converting. The -1
            # shift is preserved from the original.
            # NOTE(review): presumably converts a 1-based label to 0-based; confirm.
            numbers[0] = str(int(numbers[0].replace(" ", "")) - 1)
            print(numbers)
            rr = ",".join(n.strip() for n in numbers)
    if process.returncode != 0:
        print(f"ReX exited with status {process.returncode} for {fp}", file=sys.stderr)
    # Reopened per image so partial results survive an interrupted run. If no
    # line parsed, rr is None and the row records "None", as before.
    with open(out, "a") as e:
        e.write(f"{fp},{rr}\n")