import matplotlib.pyplot as plt
import numpy as np
import sys
import os
import json
import pickle
import random

# Run configuration: seed, command-line arguments, data locations and
# score-function hyper-parameters used by the rest of this script.
SEED = 10               # seed passed to random.seed() before the calibration/test splits are drawn
# Command-line arguments; all four are kept as strings.
ARCH_TYPE = sys.argv[1]         # architecture name; used in directory names and in predictions_<ARCH_TYPE>.pkl
LEVEL = sys.argv[2]             # combined with NUM_CLASSES to name the classifier and score directories
NUM_CLASSES = sys.argv[3]       # only interpolated into directory names in this file
SCORE_FUNC_TYPE = sys.argv[4]   # score function name; score_function() accepts 'LAC', 'APS' or 'RAPS'

MAIN_DIR = ''           # root that every path below is joined onto ('' keeps the paths relative)
# directory holding subclass_ratio.json for the selected level / class count / architecture
DOMAIN_CLASSIFIER_DIR = os.path.join(MAIN_DIR, f'best_domain_classifiers_{LEVEL}_{NUM_CLASSES}/{ARCH_TYPE}')
# output root; one sub-directory per calibration/test split is created underneath it
SCORES_DIR = os.path.join(MAIN_DIR, f"calibration_scores/{ARCH_TYPE}/{LEVEL}_{NUM_CLASSES}")
# validation data: one sub-folder of .JPEG images per subclass plus the pickled softmax predictions
VAL_DIR = os.path.join(MAIN_DIR, 'data/val_data')
INFO_DIR = os.path.join(MAIN_DIR, 'imagenet_class_hierarchy/modified')          # path to the class hierarchy information. Download it from BREEDS GitHub repo
BATCH_SIZE = 16         # not referenced anywhere in the code visible in this file
LAM_REG = 0.01          # RAPS regularization weight added to every class ranked after the first K_REG
K_REG = 5               # RAPS: number of top-ranked classes that receive no regularization
RAND = True             # whether to use randomization in RAPS
NUM_CAL = 15            # number of calibration/test splits

def generate_prediction_sets(val_smx, qhat, weighted, func_type, raps_params, rand):
    """Build conformal prediction sets from softmax outputs.

    Args:
        val_smx: 2-D array of softmax outputs, one row per sample and one
            column per class.
        qhat: score threshold. When ``weighted`` is true it is tiled along
            the class axis, so it is presumably a per-sample column of
            shape (n_samples, 1) -- verify against the caller.
        weighted: selects the per-sample threshold handling described above.
        func_type: one of 'LAC', 'APS' or 'RAPS'.
        raps_params: sequence whose first two entries are the RAPS
            regularization weight and the number of unpenalized top classes.
            Only read when ``func_type`` is 'RAPS'.
        rand: whether the RAPS scores are randomized with ``np.random.rand``.

    Returns:
        Boolean array with the shape of ``val_smx``; entry (i, j) is True
        when class j belongs to the prediction set of sample i.
    """
    assert func_type in ['LAC', 'APS', 'RAPS'], "score function type not implemented"

    if func_type == 'LAC':
        # LAC keeps every class whose score 1 - p is within the threshold.
        lac_scores = 1 - val_smx
        threshold = qhat[:] if weighted else qhat
        return lac_scores <= threshold

    # APS and RAPS both work on the classes sorted by decreasing probability.
    order = val_smx.argsort(1)[:, ::-1]
    sorted_smx = np.take_along_axis(val_smx, order, axis=1)

    if func_type == 'APS':
        sorted_scores = sorted_smx.cumsum(axis=1)
    else:
        lam_reg, k_reg = raps_params[0], raps_params[1]
        # No penalty for the first k_reg ranks, lam_reg for every later rank.
        penalty = np.array([0] * k_reg + [lam_reg] * (val_smx.shape[1] - k_reg))[None, :]
        penalized = sorted_smx + penalty
        if rand:
            # One uniform draw per sample, shared by all of its classes.
            removed = np.random.rand(val_smx.shape[0], 1) * penalized
        else:
            removed = penalized
        sorted_scores = penalized.cumsum(axis=1) - removed

    if weighted:
        threshold = np.tile(qhat, (1, sorted_smx.shape[1]))
    else:
        threshold = qhat

    # Undo the sort so that columns line up with the original class indices.
    return np.take_along_axis(sorted_scores <= threshold, order.argsort(axis=1), axis=1)

def score_function(smx, y, func_type, raps_params):
    """Compute one conformal score per sample for its true label.

    Args:
        smx: 2-D array of softmax outputs, one row per sample and one column
            per class.
        y: integer class label of every row of ``smx``. The RAPS branch
            indexes it as ``y[:, None]``, so it must be a NumPy array there.
        func_type: one of 'LAC', 'APS' or 'RAPS'.
        raps_params: sequence ``(lam_reg, k_reg, rand)``; only read when
            ``func_type`` is 'RAPS'.

    Returns:
        1-D array with one score per row of ``smx``.
    """
    assert func_type in ['LAC', 'APS', 'RAPS'], "score function type not implemented"

    if func_type == 'LAC':
        # One minus the probability assigned to the true class.
        row_idx = np.arange(smx.shape[0])
        return 1 - smx[row_idx, y]

    if func_type == 'APS':
        # Cumulative probability mass down to (and including) the true class
        # when the classes are ranked by decreasing probability.
        order = smx.argsort(1)[:, ::-1]
        cum_sorted = np.take_along_axis(smx, order, axis=1).cumsum(axis=1)
        cum_by_class = np.take_along_axis(cum_sorted, order.argsort(axis=1), axis=1)
        return cum_by_class[range(cum_sorted.shape[0]), y]

    # RAPS: same idea as APS plus a penalty on every rank after the first k_reg.
    lam_reg, k_reg, rand = raps_params[0], raps_params[1], raps_params[2]
    n = smx.shape[0]
    row_idx = np.arange(n)
    penalty = np.array([0] * k_reg + [lam_reg] * (smx.shape[1] - k_reg))[None, :]
    order = smx.argsort(1)[:, ::-1]
    penalized = np.take_along_axis(smx, order, axis=1) + penalty
    # Rank (column in the sorted order) at which each true label sits.
    true_rank = np.where(order == y[:, None])[1]
    scores = penalized.cumsum(axis=1)[row_idx, true_rank]
    if rand:
        # Subtract a uniformly random fraction of the true class' own term.
        scores = scores - np.random.rand(n) * penalized[row_idx, true_rank]
    return scores

def calculate_scores(subclass_ratio, smx_list, save_dir, foldername_to_class, CP_files, score_func_type, RAPS_params):
    """Compute and save the calibration scores of every domain.

    For each domain, the softmax outputs of all calibration images of its
    subclasses are stacked into one matrix, scored with ``score_function``
    and written to ``<save_dir>/<domain>.npy``.

    Args:
        subclass_ratio: mapping from domain name to an iterable of the
            subclass (folder) names belonging to that domain.
        smx_list: mapping from image name to its softmax output.
        save_dir: existing directory that receives one ``.npy`` file per domain.
        foldername_to_class: mapping from subclass folder name to the integer
            label used to index the softmax columns.
        CP_files: mapping from subclass folder name to the image names of its
            calibration split.
        score_func_type: score function name forwarded to ``score_function``.
        RAPS_params: RAPS parameters forwarded to ``score_function``.

    Returns:
        List with one array of scores per domain, in the iteration order of
        ``subclass_ratio``.
    """
    cal_score_list = []
    for domain in subclass_ratio:
        smx_rows = []
        cal_labels = []
        for subclass in subclass_ratio[domain]:
            images = list(CP_files[subclass])
            # Collect the rows and stack them once per domain: calling
            # np.vstack per image re-copies the whole matrix every time.
            smx_rows.extend(smx_list[img_path] for img_path in images)
            cal_labels += [foldername_to_class[subclass]] * len(images)

        if smx_rows:
            # The class count follows the data instead of being fixed to 1000.
            cal_smx = np.vstack(smx_rows).astype(np.float64, copy=False)
        else:
            # Domain without calibration images: keep the historical empty shape.
            cal_smx = np.empty((0, 1000), dtype=np.float64)
        cal_labels = np.array(cal_labels, dtype=np.uint16)
        cal_scores = score_function(cal_smx, cal_labels, score_func_type, RAPS_params)

        np.save(os.path.join(save_dir, f"{domain}.npy"), cal_scores)
        cal_score_list.append(cal_scores)
    return cal_score_list

# Load the per-domain subclass information; its top-level entries are the domains.
subclass_ratio_path = os.path.join(DOMAIN_CLASSIFIER_DIR, "subclass_ratio.json")
with open(subclass_ratio_path) as json_file:
    subclass_ratio = json.load(json_file)
K = len(subclass_ratio)
print(f"{K} domains")

#################### map foldername to class label #############################
# Every entry of dataset_class_info.json is indexed as info[0] (the label later
# used to index the softmax columns) and info[1] (the folder name).
with open(os.path.join(INFO_DIR, "dataset_class_info.json"), "r") as file:
    dataset_class_info = json.load(file)
foldername_to_class = {info[1]: info[0] for info in dataset_class_info}
################################################################################

# Seed both random number generators: `random` drives the shuffling of the
# calibration/test splits, while the randomized RAPS scores computed inside
# score_function() draw from NumPy's global generator.
random.seed(SEED)
np.random.seed(SEED)

# load pre computed softmax output of the prediction model. In a dictionary
# where the key is the image name and the value is the softmax output as numpy.
# The file does not depend on the split, so it is read once instead of once
# per split.
# NOTE(security): pickle.load can execute arbitrary code; only use prediction
# files from a trusted source.
with open(os.path.join(VAL_DIR, f"predictions_{ARCH_TYPE}.pkl"), 'rb') as f:
    smx_list = pickle.load(f)

for cal_i in range(NUM_CAL):

    print(f"calibration/test split {cal_i}")
    cal_score_dir = os.path.join(SCORES_DIR, f"{cal_i}")
    cp_files_path = os.path.join(cal_score_dir, "CP_files.json")
    test_files_path = os.path.join(cal_score_dir, "test_files.json")

    # Create the split when its file list is missing. Checking the file rather
    # than the directory lets an interrupted run recover instead of failing
    # when CP_files.json is read below. An existing split is always reused.
    if not os.path.exists(cp_files_path):
        os.makedirs(cal_score_dir, exist_ok=True)
        # create CP files and test files which store the image paths in a dictionary
        CP_files = {}
        test_files = {}
        for domain in subclass_ratio:
            for subclass in subclass_ratio[domain]:
                subdir_path = os.path.join(VAL_DIR, subclass)
                # os.listdir returns names in arbitrary order; sort them so
                # that the seeded shuffle yields the same split on every machine.
                images = sorted(f for f in os.listdir(subdir_path) if f.endswith('.JPEG'))
                random.shuffle(images)
                half = int(len(images) * 0.5)
                CP_files[subclass] = images[:half]      # calibration half
                test_files[subclass] = images[half:]    # test half
        print(len(CP_files))
        with open(cp_files_path, "w") as f:
            json.dump(CP_files, f, indent=4)
        with open(test_files_path, "w") as f:
            json.dump(test_files, f, indent=4)

    # scores of each score function go into their own sub-directory of the split
    save_dir = os.path.join(cal_score_dir, SCORE_FUNC_TYPE)
    os.makedirs(save_dir, exist_ok=True)

    # load CP file names
    with open(cp_files_path, "r") as f:
        CP_files = json.load(f)

    cal_scores_predicted = calculate_scores(subclass_ratio, smx_list, save_dir, foldername_to_class, CP_files, SCORE_FUNC_TYPE, (LAM_REG, K_REG, RAND))