import numpy as np
import random
import sys
import os
import json
import torch
import pickle
import time
from scipy.special import softmax

'''
    This script generates the coverage results across different test environments for standard CP, A1, A2, oracle, and max methods.
'''

SEED = 10
ARCH_TYPE = sys.argv[1]         # 'vit', 'resnet50', 'clip'
LEVEL = sys.argv[2]             # level of the class hierarchy for BREEDS
NUM_CLASSES = sys.argv[3]       # number of subclasses in a domain

# alpha prime (concentration) for the Dirichlet distribution over domains.
# It selects one of the test distributions generated for 0.1, 1 and 10 and is
# embedded in result paths via str(d_alpha), so 1 and 10 are kept as ints
# ("1", "10").  Any other value is rejected instead of silently becoming 10.
_D_ALPHA_CHOICES = {0.1: 0.1, 1.0: 1, 10.0: 10}
try:
    d_alpha = _D_ALPHA_CHOICES[float(sys.argv[4])]
except KeyError:
    raise ValueError(
        f"unsupported Dirichlet alpha {sys.argv[4]!r}; expected one of 0.1, 1, 10"
    ) from None

SCORE_FUNC_TYPE = sys.argv[5]   # 'LAC', 'APS', 'RAPS'

MAIN_DIR = ''
MDTS_DIR = os.path.join(MAIN_DIR, f'MDTS_regressor/{LEVEL}_{NUM_CLASSES}')
VAL_DIR = os.path.join(MAIN_DIR, 'data/val_data')
DOMAIN_CLASSIFIER_DIR = os.path.join(MAIN_DIR, f'best_domain_classifiers_{LEVEL}_{NUM_CLASSES}/{ARCH_TYPE}')
SCORES_DIR = os.path.join(MAIN_DIR, f"calibration_scores/{ARCH_TYPE}/{LEVEL}_{NUM_CLASSES}")
INFO_DIR = os.path.join(MAIN_DIR, 'imagenet_class_hierarchy/modified')          # directory where class info is stored
RESULTS_DIR = os.path.join(MAIN_DIR, f'results/{LEVEL}_{NUM_CLASSES}')
ALPHA_LIST = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
LAM_REG = 0.01      # param for RAPS
K_REG = 5           # param for RAPS
RAND = True         # whether to use randomization in RAPS
RAPS_PARAMS = [LAM_REG, K_REG]
NUM_CAL = 15            # number of calibration/test splits
NUM_TEST = 100          # number of test environments
ARCH_TYPE_LIST = ['unweighted', 'A1', 'A2', 'max', 'oracle']   # names of the evaluated CP methods

# Feature dimension of the backbone embeddings; architectures other than
# resnet50 (vit, clip) use 768.
embedding_size = {'vit': 768, 'resnet50': 2048}.get(ARCH_TYPE, 768)

os.makedirs(RESULTS_DIR, exist_ok=True)


def calculate_threshold(score_list, weights, alpha):
    """Compute the conformal threshold(s) ``qhat`` at miscoverage level ``alpha``.

    Parameters
    ----------
    score_list : sequence of 1-D array_like
        Calibration scores, one array per domain; domain ``k`` is paired with
        column ``k`` of ``weights``.
    weights : array_like of shape (m, >= len(score_list)), or None
        If given, row ``r`` puts mass ``weights[r, k] / (n_k + 1)`` on each of
        the ``n_k`` scores of domain ``k``.  If None, plain split conformal
        prediction on the pooled scores is used.
    alpha : float
        Target miscoverage level.

    Returns
    -------
    If ``weights`` is given: an array of shape (m, 1) whose row ``r`` is the
    smallest score at which row ``r``'s cumulative mass (scores taken in
    ascending order) reaches ``1 - alpha``, or ``inf`` when it never does.
    Otherwise: the ``ceil((n + 1)(1 - alpha)) / n`` quantile
    (``method="higher"``) of the ``n`` pooled scores, or ``inf`` when that
    level exceeds 1 or there are no scores.
    """
    scores = [np.asarray(s) for s in score_list]
    # np.empty(0) keeps concatenate valid for an empty score_list and makes the
    # pooled array float64.
    cal_scores = np.concatenate([np.empty(0), *scores])

    if weights is not None:
        weights = np.asarray(weights, dtype=np.float64)
        counts = np.array([len(s) for s in scores], dtype=np.intp)
        # Per-score mass, shape (m, n): block k repeats weights[:, k] / (n_k + 1).
        point_w = np.repeat(weights[:, :len(counts)] / (counts + 1), counts, axis=1)
        # Ascending order; every row shares the same calibration scores.
        order = np.argsort(cal_scores, kind="stable")
        cum_w = point_w[:, order].cumsum(axis=1)
        # Number of sorted scores whose cumulative mass is still below 1 - alpha.
        index = (cum_w < 1 - alpha).sum(axis=1)
        # Row r's masses sum to sum_k weights[r, k] * n_k / (n_k + 1); the missing
        # sum_k weights[r, k] / (n_k + 1) is treated as a point mass at +inf, so
        # index == n (mass never reaches 1 - alpha) yields an infinite threshold.
        sorted_scores = np.append(cal_scores[order], np.inf)
        return sorted_scores[index][:, None]

    n = cal_scores.shape[0]
    if n == 0:
        return np.inf
    level = np.ceil((n + 1) * (1 - alpha)) / n
    if level > 1:
        # Too few scores for this alpha: np.quantile would reject the level,
        # and the conformal threshold is +inf (include every label).
        return np.inf
    return np.quantile(cal_scores, level, method="higher")
        
def calculate_threshold_A3(score_list, weights, alpha):
    """Weighted conformal threshold(s) with one explicit weight per calibration score.

    Parameters
    ----------
    score_list : sequence of 1-D array_like
        Calibration scores; they are concatenated in the given order.
    weights : array_like of shape (m, n)
        ``weights[r, j]`` is the mass row ``r`` puts on the ``j``-th
        concatenated calibration score.
    alpha : float
        Target miscoverage level.

    Returns
    -------
    ndarray of shape (m, 1): for each row, the smallest score at which the
    cumulative mass (scores taken in ascending order) reaches ``1 - alpha``,
    or ``inf`` when the row's total mass stays below ``1 - alpha``.
    """
    cal_scores = np.concatenate([np.empty(0), *score_list])
    weights = np.asarray(weights)
    # Ascending order; every row shares the same calibration scores.
    order = np.argsort(cal_scores, kind="stable")
    cum_w = weights[:, order].cumsum(axis=1)
    index = (cum_w < 1 - alpha).sum(axis=1)
    # index == n means the row's mass never reached 1 - alpha: the remaining
    # mass is treated as sitting at +inf instead of indexing out of bounds.
    sorted_scores = np.append(cal_scores[order], np.inf)
    return sorted_scores[index][:, None]

def generate_prediction_sets(val_smx, qhat, weighted, func_type, raps_params):
    """Turn per-example class scores into boolean prediction sets.

    Parameters
    ----------
    val_smx : ndarray of shape (n, C)
        Per-example class scores (treated as probabilities by LAC/APS/RAPS).
    qhat : float or ndarray
        Threshold.  When ``weighted`` is True it is an array with one row per
        example (e.g. shape (n, 1)); otherwise a single scalar.
    weighted : bool
        Whether ``qhat`` holds per-example thresholds.
    func_type : {'LAC', 'APS', 'RAPS'}
        Non-conformity score used to build the sets.
    raps_params : sequence
        ``(lam_reg, k_reg, rand)``, read only for RAPS: ``lam_reg`` is added
        to every class ranked after the top ``k_reg``; ``rand`` toggles the
        randomized cumulative score.

    Returns
    -------
    Boolean ndarray of shape (n, C); True marks classes in the set.
    """
    assert func_type in ['LAC', 'APS', 'RAPS'], "score function type not implemented"

    if func_type == 'LAC':
        # LAC score of a class is 1 - its score; keep classes at or below qhat.
        lac_scores = 1 - val_smx
        bound = qhat[:] if weighted else qhat
        return lac_scores <= bound

    # APS / RAPS: rank classes by decreasing score and threshold the running total.
    ranking = val_smx.argsort(1)[:, ::-1]
    unrank = ranking.argsort(axis=1)
    ranked = np.take_along_axis(val_smx, ranking, axis=1)

    if func_type == 'APS':
        running = ranked.cumsum(axis=1)
    else:
        lam_reg, k_reg, rand = raps_params[0], raps_params[1], raps_params[2]
        penalty = np.array(k_reg * [0,] + (val_smx.shape[1] - k_reg) * [lam_reg,])[None, :]
        penalized = ranked + penalty
        if rand:
            # randomized RAPS: subtract a uniform fraction of each class's own term
            own_term = np.random.rand(val_smx.shape[0], 1) * penalized
        else:
            own_term = penalized
        running = penalized.cumsum(axis=1) - own_term

    bound = np.tile(qhat, (1, running.shape[1])) if weighted else qhat
    # map the decision back from rank order to the original class order
    return np.take_along_axis(running <= bound, unrank, axis=1)

# read subclass info from file: domain -> its subclasses
with open(os.path.join(DOMAIN_CLASSIFIER_DIR, "subclass_ratio.json")) as ratio_file:
    subclass_ratio = json.load(ratio_file)

# invert the mapping (later domains win if a subclass is listed twice)
subclass_to_superclass = {
    subclass_name: domain_name
    for domain_name, subclass_names in subclass_ratio.items()
    for subclass_name in subclass_names
}

K = len(subclass_ratio)  # number of domains
print(f"{K} domains")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#################### map foldername to class label #############################
with open(os.path.join(INFO_DIR, "dataset_class_info.json"), "r") as class_info_file:
    dataset_class_info = json.load(class_info_file)
# entry[1] is the image folder name, entry[0] the class index used as the label
foldername_to_class = {entry[1]: entry[0] for entry in dataset_class_info}
################################################################################

############ load/generate the test-environment domain distributions ###########
# For each Dirichlet concentration (0.1, 1, 10), NUM_TEST mixtures over the
# len(subclass_ratio) domains, drawn from a SEED-ed generator and cached on disk.
rng = np.random.default_rng(seed=SEED)
dist_path = os.path.join(RESULTS_DIR, 'test_distributions.pkl')
if not os.path.exists(dist_path):
    print("test distributions not found, generating new ones")
    # The concentration variable must not be the global ``d_alpha``: that holds
    # the command-line choice used for every path and lookup further down.
    distribution_list = {
        conc: rng.dirichlet(alpha=[conc] * len(subclass_ratio), size=NUM_TEST)
        for conc in [0.1, 1, 10]
    }
    with open(dist_path, "wb") as f:
        pickle.dump(distribution_list, f)
else:
    with open(dist_path, "rb") as f:
        distribution_list = pickle.load(f)
################################################################################

####################### load predictions/weights ###############################
def _load_pickle(path):
    """Unpickle and return the object stored at *path*.

    NOTE: pickle.load can execute arbitrary code; only load files you produced.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# presumably mappings from image file name to per-image arrays (the evaluation
# loop indexes them by file name) -- confirm against the producing script
smx_list = _load_pickle(os.path.join(VAL_DIR, f"predictions_{ARCH_TYPE}_{LEVEL}_{NUM_CLASSES}.pkl"))
embedding_list = _load_pickle(os.path.join(VAL_DIR, f"embeddings_{ARCH_TYPE}_{LEVEL}_{NUM_CLASSES}.pkl"))
weights_list = _load_pickle(os.path.join(VAL_DIR, f"weights_{ARCH_TYPE}_{LEVEL}_{NUM_CLASSES}.pkl"))
################################################################################



def _pickle_loads_cleanly(path):
    """Return True if the pickle at *path* can be opened and fully unpickled."""
    try:
        with open(path, "rb") as f:
            pickle.load(f)
    except Exception:  # missing file, truncated/corrupted pickle, ... -> recompute
        return False
    return True


def _sample_test_environments(test_files, distribution):
    """Randomly draw the image lists of the NUM_TEST test environments.

    Row ``i`` of *distribution* is the domain mixture of environment ``i``
    (one column per domain, in ``subclass_ratio`` order).  After rescaling the
    row so its largest entry is 1, every subclass of domain ``k`` contributes
    ``int(len(test_files[subclass]) * row[k])`` images drawn with
    ``random.sample`` (without replacement).

    Returns ``{str(i): {'ratio': row_i_as_list, 'images': [path, ...]}}`` where
    each path is ``os.path.join(VAL_DIR, subclass, image_name)``.
    """
    environments = {}
    for i in range(NUM_TEST):
        ratio = np.asarray(distribution[i])
        normalized_distribution = ratio / np.max(ratio)
        env_images = []
        for domain_num, domain in enumerate(subclass_ratio):
            for subclass in subclass_ratio[domain]:
                subdir_path = os.path.join(VAL_DIR, subclass)
                images = list(test_files[subclass])
                num_samples = int(len(images) * normalized_distribution[domain_num])
                env_images.extend(
                    os.path.join(subdir_path, img) for img in random.sample(images, num_samples)
                )
        environments[str(i)] = {'ratio': ratio.tolist(), 'images': env_images}
    return environments


def _load_test_arrays(selected_images):
    """Collect the per-image inputs of one test environment.

    *selected_images* are paths ending in ``<subclass>/<image name>``; the
    image name keys ``smx_list`` and ``weights_list`` and the subclass folder
    is mapped to its class label through ``foldername_to_class``.

    Returns ``(img_paths, labels, smx, weights)`` with ``smx`` of shape
    (n, 1000), ``weights`` of shape (n, K) and ``labels`` a uint16 array.
    """
    n = len(selected_images)
    test_smx = np.empty((n, 1000), dtype=np.float64)
    test_weights = np.empty((n, K), dtype=np.float64)
    test_img_paths = []
    test_labels = []
    for data_i, path in enumerate(selected_images):
        subclass, img_path = path.split('/')[-2:]
        # preallocated rows instead of np.vstack per image (quadratic copying)
        test_smx[data_i] = np.asarray(smx_list[img_path], dtype=np.float64).reshape(-1)
        test_weights[data_i] = np.asarray(weights_list[img_path], dtype=np.float64).reshape(-1)
        test_img_paths.append(img_path)
        test_labels.append(foldername_to_class[subclass])
    return test_img_paths, np.array(test_labels, dtype=np.uint16), test_smx, test_weights


def _evaluate_method(method, prediction_sets, test_i, cal_i, test_img_paths, test_labels,
                     prediction_sets_dict, coverage_dict):
    """Store *method*'s prediction sets for one environment and record its metrics.

    Coverage is the fraction of images whose label is inside their set; set
    size is the mean number of classes per set.  Both are printed and written
    to ``coverage_dict[method][str(test_i)][str(cal_i)]``.
    """
    env_sets = prediction_sets_dict[method][str(test_i)]
    for img_path, pred_set in zip(test_img_paths, prediction_sets):
        env_sets[img_path] = pred_set
    coverage = prediction_sets[np.arange(prediction_sets.shape[0]), test_labels].mean()
    set_size = np.sum(prediction_sets, axis=1).mean()
    print(f"{method} coverage distribution {test_i}: {coverage}")
    print(f"{method} set size distribution {test_i}: {set_size}")
    metrics = coverage_dict[method][str(test_i)][str(cal_i)]
    metrics['coverage'] = coverage
    metrics['set size'] = set_size


raps_args = (LAM_REG, K_REG, RAND)

for alpha in ALPHA_LIST:

    coverage_dict = {arch: {str(i): {str(j): {} for j in range(NUM_CAL)} for i in range(NUM_TEST)} for arch in ARCH_TYPE_LIST}

    for cal_i in range(NUM_CAL):
        # Only recompute splits whose saved prediction sets are missing or unreadable.
        existing_dir = os.path.join(RESULTS_DIR, "coverage_results", ARCH_TYPE, SCORE_FUNC_TYPE, str(alpha), str(d_alpha), "prediction_sets", str(cal_i))
        if _pickle_loads_cleanly(os.path.join(existing_dir, "prediction_sets.pkl")):
            continue

        ##################### load calibration scores ##########################
        cal_score_dir = os.path.join(SCORES_DIR, str(cal_i), SCORE_FUNC_TYPE)
        cal_scores = [np.load(os.path.join(cal_score_dir, f"{domain}.npy")) for domain in subclass_ratio]

        # qhat of the "hardest" domain: the largest per-domain unweighted threshold
        max_qhat = max([0, *(calculate_threshold([score_k], None, alpha) for score_k in cal_scores)])
        ########################################################################

        ######################## load test file names ###########################
        with open(os.path.join(SCORES_DIR, str(cal_i), f"test_files.json"), "rb") as f:
            test_files = json.load(f)
        ########################################################################

        ############### save/load file names for each distribution #####################
        # An existing filenames file is reused as-is; delete it to force re-sampling.
        images_dir = os.path.join(RESULTS_DIR, f'test_images/{cal_i}')
        filenames_path = os.path.join(RESULTS_DIR, f'test_images/{cal_i}/filenames_{d_alpha}.json')
        if not os.path.exists(filenames_path):
            os.makedirs(images_dir, exist_ok=True)
            with open(filenames_path, 'w') as f:
                json.dump(_sample_test_environments(test_files, distribution_list[d_alpha]), f, indent=4)
        with open(filenames_path, 'r') as f:
            test_environments = json.load(f)
        ################################################################################

        # Thresholds that do not depend on the environment's images.
        unweighted_qhat = calculate_threshold(cal_scores, None, alpha)
        # oracle: weight domains by the true mixture of each test environment
        oracle_qhat = [calculate_threshold(cal_scores, distribution_list[d_alpha][test_i][None, :], alpha)[0][0] for test_i in range(NUM_TEST)]

        prediction_sets_dict = {arch: {str(i): {} for i in range(NUM_TEST)} for arch in ARCH_TYPE_LIST}
        for test_i in range(NUM_TEST):
            print(f"{ARCH_TYPE}, {SCORE_FUNC_TYPE}, {test_i}, {cal_i}, {alpha}")
            test_img_paths, test_labels, test_smx, test_weights = _load_test_arrays(
                test_environments[str(test_i)]['images']
            )

            # (method, qhat, whether qhat holds one threshold per test image)
            methods = (
                ('unweighted', unweighted_qhat, False),
                # A1: one threshold per image from that image's domain weights
                ('A1', calculate_threshold(cal_scores, test_weights, alpha), True),
                # A2: a single threshold from the environment's average weights
                ('A2', calculate_threshold(cal_scores, np.mean(test_weights, axis=0)[None, :], alpha)[0][0], False),
                ('max', max_qhat, False),
                ('oracle', oracle_qhat[test_i], False),
            )
            # A3 (similarity-weighted threshold via calculate_threshold_A3) is disabled.
            for method, qhat, weighted in methods:
                prediction_sets = generate_prediction_sets(test_smx, qhat, weighted, SCORE_FUNC_TYPE, raps_args)
                _evaluate_method(method, prediction_sets, test_i, cal_i, test_img_paths, test_labels,
                                 prediction_sets_dict, coverage_dict)

        save_dir = os.path.join(RESULTS_DIR, "coverage_results_corrupt", ARCH_TYPE, SCORE_FUNC_TYPE, str(alpha), str(d_alpha), "prediction_sets", str(cal_i))
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "prediction_sets.pkl"), "wb") as f:
            pickle.dump(prediction_sets_dict, f)

    coverage_dir = os.path.join(RESULTS_DIR, "coverage_results_corrupt", ARCH_TYPE, SCORE_FUNC_TYPE, str(alpha), str(d_alpha))
    os.makedirs(coverage_dir, exist_ok=True)
    with open(os.path.join(coverage_dir, "coverage.json"), "w") as f:
        json.dump(coverage_dict, f, indent=4)