"""
Configurable settings
"""

# ---------------------------------------------------------------------------
# Model under analysis
# ---------------------------------------------------------------------------
# NOTE(review): GPU and CLEAN are not read anywhere in this file; their
# consumers live elsewhere — confirm their meaning there.
GPU = True
# Index CSV file name. The text between "index" and ".csv" becomes part of the
# output folder name, and the value "index_sm.csv" switches on TEST_MODE.
INDEX_FILE = "index_ade20k.csv"
CLEAN = False
# Architecture name; validated further down, where it also selects the feature
# layer names and (for places365) the checkpoint path.
MODEL = "resnet18"
# Dataset the model was trained on; selects NUM_CLASSES and MODEL_FILE below.
DATASET = "places365"
# Optional checkpoint id and training-data percentage. When not None they are
# woven into the output folder name and the checkpoint path.
MODEL_CHECKPOINT = None
MODEL_DATA_PERCENT = None

# ---------------------------------------------------------------------------
# Probing
# ---------------------------------------------------------------------------
# Only "broden" is accepted; any other value raises NotImplementedError below.
PROBE_DATASET = "broden"
# NOTE(review): the remaining values in this group are not referenced in this
# file; their meaning is implied by name only — confirm against the readers.
QUANTILE = 0.005
TOTAL_QUANTILE = 0.01
SEG_THRESHOLD = 0.04
SCORE_THRESHOLD = 0.04
CONTRIBUTIONS = True
TOPN = 5
PARALLEL = 8
CATEGORIES = ["object", "part", "scene", "texture", "color"]
UNIT_RANGE = None


# ---------------------------------------------------------------------------
# Summaries / analyses (all switched off here; not referenced in this file)
# ---------------------------------------------------------------------------
EMBEDDING_SUMMARY = False
WN_SUMMARY = False
WN_SIMILARITY = False
SEMANTIC_CONSISTENCY = False


# ---------------------------------------------------------------------------
# Formula search
# ---------------------------------------------------------------------------
# NOTE(review): presumably beam-search hyperparameters — verify with the
# search code; only MAX_FORMULA_LENGTH is used in this file.
FORMULA_COMPLEXITY_PENALTY = 1.0
BEAM_SEARCH_LIMIT = 100
BEAM_SIZE = 5
# Also embedded in OUTPUT_FOLDER as "_neuron_<MAX_FORMULA_LENGTH>".
MAX_FORMULA_LENGTH = 1


# ---------------------------------------------------------------------------
# Tree settings (not referenced in this file)
# ---------------------------------------------------------------------------
TREE_MAXDEPTH = 4
TREE_MAXCHILDREN = 3
# Every 10th integer starting at 1 and stopping before 365 (1, 11, ..., 361).
TREE_UNITS = range(1, 365, 10)


# Text between "index" and ".csv" in INDEX_FILE (e.g. "_ade20k"); it is only
# kept for the "broden" probe dataset. str.split() with a separator never
# returns an empty list, so the truthiness check is purely defensive.
INDEX_SUFFIX = INDEX_FILE.split("index")[1].split(".csv")
INDEX_SUFFIX = INDEX_SUFFIX[0] if PROBE_DATASET == "broden" and INDEX_SUFFIX else ""

# The small index file switches on the lightweight test configuration.
TEST_MODE = (INDEX_FILE == "index_sm.csv")

# Model tag: MODEL, plus "_<pct>pct" and "_<ckpt>ckpt" when those are set.
mbase = MODEL if MODEL_DATA_PERCENT is None else f"{MODEL}_{MODEL_DATA_PERCENT}pct"
mbase = mbase if MODEL_CHECKPOINT is None else f"{mbase}_{MODEL_CHECKPOINT}ckpt"

# Result directory name assembled from the model tag, datasets, index suffix
# and formula length; "_test" is appended in test mode and a checkpoint tag
# is appended for ade20k models.
OUTPUT_FOLDER = (
    f"result/{mbase}_{DATASET}_{PROBE_DATASET}{INDEX_SUFFIX}"
    f"_neuron_{MAX_FORMULA_LENGTH}"
    f"{'_test' if TEST_MODE else ''}"
    f"{f'_checkpoint_{MODEL_CHECKPOINT}' if DATASET == 'ade20k' else ''}"
)

# Echo the destination at import time so every run shows where results go.
print(OUTPUT_FOLDER)


# Only the "broden" probe dataset is supported; fail fast on anything else.
if PROBE_DATASET != "broden":
    raise NotImplementedError(f"Unknown dataset {PROBE_DATASET}")

# alexnet uses the 227-pixel copy of the dataset; every other model uses the
# 224-pixel copy.
DATA_DIRECTORY, IMG_SIZE = (
    ("dataset/broden1_227", 227)
    if MODEL == "alexnet"
    else ("dataset/broden1_224", 224)
)

# Number of output classes for the model's training dataset.
# NOTE(review): there is no branch for DATASET = None (or any other value),
# although the MODEL_FILE selection further down does handle None; in that
# case NUM_CLASSES stays undefined — confirm that is intended.
if DATASET in ("places365", "ade20k"):
    NUM_CLASSES = 365
elif DATASET == "imagenet":
    NUM_CLASSES = 1000

# Names of the layers to probe, keyed by supported architecture.
#
# A single mapping drives both the "is this model supported?" check and the
# lookup, so the two cannot drift apart. Previously the allow-list spelled
# "resnet101" as "renset101", which rejected the real resnet101 and let the
# misspelled name through without ever defining FEATURE_NAMES.
#
# NOTE(review): vgg16 is mapped to "layer4" as in the original; torchvision's
# stock VGG exposes `features` rather than `layer4` — confirm that the model
# loader provides a layer with this name.
try:
    FEATURE_NAMES = {
        "resnet18": ["layer4"],
        "resnet50": ["layer4"],
        "resnet101": ["layer4"],
        "densenet161": ["features"],
        "alexnet": ["features"],
        "vgg16": ["layer4"],
    }[MODEL]
except KeyError:
    # Same exception type and message as before for unsupported models.
    raise NotImplementedError(f"model = {MODEL}") from None

# Checkpoint path (MODEL_FILE) and whether the model is flagged as parallel
# (MODEL_PARALLEL), chosen by training dataset. Any DATASET value not listed
# here leaves both names undefined.
if DATASET == "places365":
    MODEL_PARALLEL = True
    if MODEL_CHECKPOINT is not None or MODEL_DATA_PERCENT is not None:
        # Custom-trained checkpoint.
        # NOTE(review): the path hard-codes "resnet18" regardless of MODEL,
        # and renders "None" when only MODEL_DATA_PERCENT is set — confirm
        # that only resnet18 checkpoints are expected here.
        datapctstr = "" if MODEL_DATA_PERCENT is None else f"_{MODEL_DATA_PERCENT}"
        MODEL_FILE = f"zoo/trained/places365/resnet18{datapctstr}/resnet18_{MODEL_CHECKPOINT}.pth.tar"
    elif MODEL == "densenet161":
        # densenet161 has its own differently named stock checkpoint file.
        MODEL_FILE = "zoo/whole_densenet161_places365_python36.pth.tar"
    else:
        MODEL_FILE = f"zoo/{MODEL}_places365.pth.tar"
elif DATASET == "imagenet":
    # No checkpoint file is configured for imagenet models.
    MODEL_FILE, MODEL_PARALLEL = None, False
elif DATASET == "ade20k":
    MODEL_FILE, MODEL_PARALLEL = (
        f"zoo/trained/{mbase}_ade20k_finetune/{MODEL_CHECKPOINT}.pth",
        False,
    )
elif DATASET is None:
    # Sentinel string instead of a path when no training dataset is set.
    MODEL_FILE, MODEL_PARALLEL = "<UNTRAINED>", False

# Loader / tally sizing: tiny values in test mode, full-size values otherwise.
WORKERS, BATCH_SIZE, TALLY_BATCH_SIZE, TALLY_AHEAD = (
    (1, 4, 2, 1) if TEST_MODE else (12, 128, 16, 4)
)
    
    