import numpy as np
import matplotlib.pyplot as plt
import random
import sys
import os
import json
import torch
import pickle
import time
from sklearn.metrics.pairwise import cosine_similarity
from scipy.special import softmax

'''
    This script generates the coverage results across different test environments for A3.
'''

# ------------------------------------------------------------------------------
# Run configuration.
# Command line: <script> ARCH_TYPE D_ALPHA NUM_CLASSES SCORE_FUNC_TYPE
# ------------------------------------------------------------------------------
SEED = 10                         # NOTE(review): not passed to any RNG in this file -- confirm whether seeding was intended
ARCH_TYPE = sys.argv[1]           # vit, resnet50, or clip

# D_ALPHA ends up in the name of the test-image list that is read later on.
# Only 0.1 and 1 are recognised explicitly; every other value falls back to 10.
_d_alpha_arg = float(sys.argv[2])
if _d_alpha_arg == 0.1:
    D_ALPHA = 0.1
elif _d_alpha_arg == 1:
    D_ALPHA = 1
else:
    D_ALPHA = 10
LEVEL = 3                         # parameter for BREEDS, 3
NUM_CLASSES = sys.argv[3]         # parameter for BREEDS, 3 or 17 (kept as a string; only used inside paths)
SCORE_FUNC_TYPE = sys.argv[4]     # LAC, APS, or RAPS
TEMP = 0.4                        # temperature for softmax
TOP_K_PERCENT = 0.1               # percentage of the calibration data with top distance/similarity measures to keep

# WEIGHTED used to be read from sys.argv[5] (0 -> False, anything else -> True);
# it is now fixed to True.
WEIGHTED = True
MAIN_DIR = ''
VAL_DIR = os.path.join(MAIN_DIR, 'data/val_data')                                                                    # directory where validation data is stored
DOMAIN_CLASSIFIER_DIR = os.path.join(MAIN_DIR, f'best_domain_classifiers_{LEVEL}_{NUM_CLASSES}/{ARCH_TYPE}')         # directory where trained domain classifier is stored
SCORES_DIR = os.path.join(MAIN_DIR, f"calibration_scores/{ARCH_TYPE}/{LEVEL}_{NUM_CLASSES}")                         # directory where pre-computed scores is stored
INFO_DIR = os.path.join(MAIN_DIR, 'imagenet_class_hierarchy/modified')                                               # directory where class info is stored
RESULTS_DIR = os.path.join(MAIN_DIR, f'results/{LEVEL}_{NUM_CLASSES}')                                               # directory to store the results
ALPHA_LIST = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]            # parameter for conformal prediction thresholds

LAM_REG = 0.01                                                                  # parameter for RAPS
K_REG = 5                                                                       # parameter for RAPS
RAND = True                                                                     # parameter for RAPS
# The RAPS code paths read three entries (lam_reg, k_reg, rand); the list used
# to stop after K_REG, which made it unusable as a `raps_params` argument.
RAPS_PARAMS = [LAM_REG, K_REG, RAND]
NUM_CAL = 15                                                                    # number of calibration/test splits
NUM_TEST = 100                                                                  # number of test environments
SIM_TYPE = 'cosine'

# Width of the stored embedding vectors: 2048 for resnet50, 768 for everything else.
if ARCH_TYPE == 'resnet50':
    embedding_size = 2048
else:
    embedding_size = 768

# exist_ok avoids the check-then-create race when several runs start together.
os.makedirs(RESULTS_DIR, exist_ok=True)


def score_function(smx, y, func_type, raps_params):
    """Compute one conformal score per row of `smx` for the label given in `y`.

    Args:
        smx: 2-D array with one row per sample and one column per class
            (presumably softmax probabilities -- confirm with the caller).
        y: 1-D integer numpy array holding the label column of every row.
        func_type: 'LAC', 'APS' or 'RAPS'.
        raps_params: indexable (lam_reg, k_reg, rand); only read for 'RAPS'.

    Returns:
        1-D array of scores, one per row of `smx`.

    Raises:
        AssertionError: if `func_type` is not one of the supported names.
    """
    assert func_type in ['LAC', 'APS', 'RAPS'], "score function type not implemented"

    if func_type == 'LAC':
        # one minus the value assigned to the label
        return 1 - smx[np.arange(smx.shape[0]), y]

    if func_type == 'APS':
        # classes from largest to smallest value, per row
        order = smx.argsort(1)[:, ::-1]
        running_mass = np.take_along_axis(smx, order, axis=1).cumsum(axis=1)
        # undo the sort so that column j again refers to class j
        mass_by_class = np.take_along_axis(running_mass, order.argsort(axis=1), axis=1)
        return mass_by_class[range(running_mass.shape[0]), y]

    # RAPS: like APS, plus a penalty of lam_reg on every rank beyond k_reg
    lam_reg, k_reg, rand = raps_params[0], raps_params[1], raps_params[2]
    n = smx.shape[0]
    rows = np.arange(n)
    penalty = np.array(k_reg*[0, ] + (smx.shape[1] - k_reg) * [lam_reg,])[None, :]
    order = smx.argsort(1)[:, ::-1]
    penalized = np.take_along_axis(smx, order, axis=1) + penalty
    # rank (0 = largest value) at which the label sits in each row
    label_rank = np.where(order == y[:, None])[1]
    scores = penalized.cumsum(axis=1)[rows, label_rank]
    if rand:
        # subtract a uniform-random fraction of the label's own penalized value
        scores = scores - np.random.rand(n) * penalized[rows, label_rank]
    return scores

def calculate_scores(subclass_ratio, smx_list, save_dir, foldername_to_class, CP_files, score_func_type, RAPS_params):
    """Compute, save and return the calibration scores of every domain.

    For each domain the per-image rows of `smx_list` are stacked in the order
    of `subclass_ratio[domain]` / `CP_files[subclass]`, scored with
    `score_function`, and written to `<save_dir>/<domain>.npy`.

    Args:
        subclass_ratio: mapping domain -> iterable of subclass names.
        smx_list: mapping image name -> row(s) with 1000 columns.
        save_dir: existing directory that receives one .npy file per domain.
        foldername_to_class: mapping subclass name -> integer label.
        CP_files: mapping subclass name -> list of calibration image names.
        score_func_type: 'LAC', 'APS' or 'RAPS'.
        RAPS_params: forwarded to `score_function`.

    Returns:
        List with one score array per domain, in iteration order of
        `subclass_ratio`.
    """
    cal_score_list = []
    for domain in subclass_ratio:
        cal_labels = []
        # Seed with an empty (0, 1000) block so a domain without images still
        # produces a correctly shaped array.
        smx_rows = [np.empty((0, 1000), dtype=np.float64)]
        for subclass in subclass_ratio[domain]:
            images = list(CP_files[subclass])
            smx_rows.extend(smx_list[img_path] for img_path in images)
            cal_labels += [foldername_to_class[subclass]] * len(images)
        # Stack once per domain; stacking inside the image loop re-copied the
        # whole matrix for every image.
        cal_smx = np.vstack(smx_rows)
        cal_labels = np.array(cal_labels, dtype=np.uint16)
        cal_scores = score_function(cal_smx, cal_labels, score_func_type, RAPS_params)
        np.save(os.path.join(save_dir, f"{domain}.npy"), cal_scores)
        cal_score_list.append(cal_scores)
    return cal_score_list
    
def pairwise_sim(A, B, sim_type):
    """Return the pairwise distance/similarity between the rows of `B` and `A`.

    Args:
        A: 2-D array, one row per calibration point.
        B: 2-D array, one row per test point, same number of columns as `A`.
        sim_type: 'euclid' (L2 distance), 'dot' (inner product) or 'cosine'
            (cosine similarity).

    Returns:
        Array of shape (B.shape[0], A.shape[0]); entry [j, i] compares B[j]
        with A[i].

    Raises:
        AssertionError: if `sim_type` is not one of the supported names.
    """
    assert sim_type in ['euclid', 'dot', 'cosine'], "similarity type not implemented"

    if sim_type == 'dot':
        return (A @ B.T).T

    if sim_type == 'cosine':
        return cosine_similarity(A, B).T

    # euclid: broadcast to (n_A, dim, n_B), then move the B axis to the front
    # so the norm over the last axis yields a (n_B, n_A) matrix
    differences = A[:, :, None] - B.T[None, :, :]
    differences = np.transpose(differences, (2, 0, 1))
    return np.linalg.norm(differences, ord=2, axis=-1)

def calculate_threshold(score_list, weights, alpha, sim_type):
    """Return the weighted (1 - alpha) quantile of the calibration scores per test row.

    The scores of all domains are concatenated and sorted in ascending order.
    For every row of `weights`, the weights are accumulated in that order and
    the threshold is the first score at which the running weight reaches
    `1 - alpha`.

    Args:
        score_list: iterable of 1-D score arrays, concatenated in order.
        weights: 2-D array, one row per test point and one column per
            calibration score (same column order as the concatenated scores).
        alpha: target miscoverage level.
        sim_type: unused; kept so existing call sites keep working.

    Returns:
        (qhat, index): `qhat` has shape (n_rows, 1); `index` has shape
        (n_rows,) and counts how many sorted scores lie below the threshold
        position. When the weights of a row never add up to `1 - alpha`,
        `index` equals the number of scores and `qhat` is +inf for that row.
    """
    cal_scores = np.concatenate([np.array([])] + list(score_list))

    # Every test row shares the same scores, so sorting once is enough.
    order = cal_scores.argsort()
    sorted_scores = cal_scores[order]

    # running weight along the ascending scores, separately for each test row
    cum_weights = weights[:, order].cumsum(axis=1)
    index = (cum_weights < (1 - alpha)).sum(axis=1)

    # The caller in this file drops the test point's own weight column, so a
    # row may sum to less than 1 - alpha. The quantile then falls on the mass
    # that is no longer in `weights`; +inf stands in for it instead of an
    # out-of-range lookup.
    padded_scores = np.append(sorted_scores, np.inf)
    return padded_scores[index][:, None], index

def generate_prediction_sets(val_smx, qhat, weighted, func_type, raps_params):
    """Build boolean prediction sets by thresholding per-class scores at `qhat`.

    Args:
        val_smx: 2-D array, one row per sample and one column per class.
        qhat: threshold. With `weighted` it is sliced / tiled along the class
            axis, i.e. expected to be an (n_samples, 1) array; otherwise it is
            compared through plain broadcasting.
        weighted: selects how `qhat` is expanded (see above).
        func_type: 'LAC', 'APS' or 'RAPS'.
        raps_params: indexable (lam_reg, k_reg, rand); only read for 'RAPS'.

    Returns:
        Boolean array shaped like `val_smx`; True marks a class that is part of
        the sample's prediction set.

    Raises:
        AssertionError: if `func_type` is not one of the supported names.
    """
    assert func_type in ['LAC', 'APS', 'RAPS'], "score function type not implemented"

    if func_type == 'LAC':
        lac_scores = 1 - val_smx
        cutoff = qhat[:] if weighted else qhat
        return lac_scores <= cutoff

    if func_type == 'APS':
        # classes from largest to smallest value, then running mass
        order = val_smx.argsort(1)[:, ::-1]
        ranked_scores = np.take_along_axis(val_smx, order, axis=1).cumsum(axis=1)
    else:
        lam_reg, k_reg, rand = raps_params[0], raps_params[1], raps_params[2]
        penalty = np.array(k_reg*[0,] + (val_smx.shape[1]-k_reg)*[lam_reg,])[None, :]
        order = val_smx.argsort(1)[:, ::-1]
        penalized = np.take_along_axis(val_smx, order, axis=1) + penalty
        n_val = val_smx.shape[0]
        # randomized: remove a uniform fraction of each entry's own value;
        # otherwise remove the entry itself (running sum of the ranks before it)
        fraction = np.random.rand(n_val, 1) if rand else 1
        if rand:
            ranked_scores = penalized.cumsum(axis=1) - fraction * penalized
        else:
            ranked_scores = penalized.cumsum(axis=1) - penalized

    if weighted:
        cutoff = np.tile(qhat, (1, ranked_scores.shape[1]))
    else:
        cutoff = qhat
    # compare in sorted order, then map the columns back to class order
    return np.take_along_axis(ranked_scores <= cutoff, order.argsort(axis=1), axis=1)


# ----------------------------- subclass info ----------------------------------
# subclass_ratio maps every domain to its subclasses (only the keys of the
# inner containers are used in this file).
subclass_ratio_path = os.path.join(DOMAIN_CLASSIFIER_DIR, "subclass_ratio.json")
with open(subclass_ratio_path) as json_file:
    subclass_ratio = json.load(json_file)

# reverse lookup: subclass -> domain it belongs to
subclass_to_superclass = {
    subclass: domain
    for domain in subclass_ratio
    for subclass in subclass_ratio[domain]
}
K = len(subclass_ratio)
print(f"{K} domains")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------------------- map foldername to class label --------------------------
# each entry of dataset_class_info contributes entry[1] -> entry[0]
with open(os.path.join(INFO_DIR, "dataset_class_info.json"), "r") as file:
    dataset_class_info = json.load(file)
foldername_to_class = {info[1]: info[0] for info in dataset_class_info}

# ----------------------- load predictions/weights -----------------------------
def _load_pickle(path):
    """Unpickle and return the object stored at `path`."""
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
    with open(path, 'rb') as f:
        return pickle.load(f)


smx_list = _load_pickle(os.path.join(VAL_DIR, f"predictions_{ARCH_TYPE}.pkl"))
embedding_list = _load_pickle(os.path.join(VAL_DIR, f"embeddings_{ARCH_TYPE}.pkl"))
weights_list = _load_pickle(os.path.join(VAL_DIR, f"weights_{ARCH_TYPE}_{LEVEL}_{NUM_CLASSES}.pkl"))
# ------------------------------------------------------------------------------


######################## load test distribution info ###########################
# This file depends on neither alpha nor the calibration split, so it is read
# once here instead of once per (alpha, split) pair. Its content is not used
# further down; the load is kept so that a missing file still fails up front.
# NOTE: pickle.load can execute arbitrary code; only use trusted result files.
with open(os.path.join(RESULTS_DIR, 'test_distributions.pkl'), "rb") as f:
    distribution_list = pickle.load(f)
################################################################################

for alpha in ALPHA_LIST:

    # coverage_dict[test_i][cal_i] -> {'coverage': ..., 'set size': ...}
    coverage_dict = {str(i): {str(j): {} for j in range(NUM_CAL)} for i in range(NUM_TEST)}

    # everything produced for this alpha is written below this directory
    alpha_results_dir = os.path.join(RESULTS_DIR, f"coverage_results_A3_new_{TEMP}", ARCH_TYPE, SCORE_FUNC_TYPE, str(alpha), str(D_ALPHA))

    for cal_i in range(NUM_CAL):

        split_dir = os.path.join(SCORES_DIR, str(cal_i))

        ######################## load cal/test split info ##############################
        with open(os.path.join(split_dir, "CP_files.json"), "rb") as f:
            CP_files = json.load(f)
        ################################################################################

        ######################## load calibration scores ###############################
        cal_score_dir = os.path.join(split_dir, SCORE_FUNC_TYPE)
        cal_scores = []         # store calibration scores from each domain in a list
        cal_images = []         # store the list of calibration image file names
        for domain in subclass_ratio:
            cal_scores.append(np.load(os.path.join(cal_score_dir, f"{domain}.npy")))

            # same domain/subclass order as the stored scores, so that column i
            # of the weight matrix lines up with score i
            for subclass in subclass_ratio[domain]:
                cal_images += CP_files[subclass]

        cal_embeddings = np.empty((len(cal_images), embedding_size), dtype=np.float64)
        for data_i, cal_image in enumerate(cal_images):
            cal_embeddings[data_i] = embedding_list[cal_image]
        ################################################################################

        ############################ load test file names ##############################
        with open(os.path.join(split_dir, "test_files.json"), "rb") as f:
            test_files = json.load(f)
        ################################################################################

        ################### load file names for each distribution ######################
        test_images = {}
        with open(os.path.join(RESULTS_DIR, f'test_images/{cal_i}/filenames_{D_ALPHA}.json'), 'r') as f:
            test_images[str(D_ALPHA)] = json.load(f)
        ################################################################################

        prediction_sets_dict = {str(i): {} for i in range(NUM_TEST)}
        for test_i in range(NUM_TEST):

            # select the images of this test environment
            selected_images = test_images[str(D_ALPHA)][str(test_i)]['images']
            n_test = len(selected_images)
            test_img_paths = []
            test_labels = []
            test_smx = np.empty((n_test,1000), dtype=np.float64)
            test_embeddings = np.empty((n_test, embedding_size), dtype=np.float64)
            # rows are collected and stacked once after the loop; stacking per
            # image re-copied the whole matrix every time
            test_weight_rows = [np.empty((0,K), dtype=np.float64)]
            for data_i, selected_image in enumerate(selected_images):
                # entries look like ".../<subclass>/<image name>"
                path_parts = selected_image.split('/')
                subclass = path_parts[-2]
                img_path = path_parts[-1]
                test_smx[data_i] = smx_list[img_path]
                test_embeddings[data_i] = embedding_list[img_path]
                test_weight_rows.append(weights_list[img_path])
                test_img_paths.append(img_path)
                test_labels.append(foldername_to_class[subclass])
            test_weights = np.vstack(test_weight_rows)      # not consumed by the A3 procedure below
            test_labels = np.array(test_labels, dtype=np.uint16)

            ####################### A3 w/ distribution shifts ##############################
            print(f"{test_i}, {cal_i}, {alpha}, A3")

            # shape (n_test, n_cal): row j compares test point j with every calibration point
            pairwise_distance = pairwise_sim(cal_embeddings, test_embeddings, SIM_TYPE)

            if SIM_TYPE == 'euclid':
                # keep the closest TOP_K_PERCENT of calibration points per test
                # point; the rest get distance +inf, i.e. softmax weight 0
                thresholds = np.percentile(pairwise_distance, 100*TOP_K_PERCENT, axis=1, keepdims=True)
                pairwise_distance = np.where(pairwise_distance > thresholds, np.inf, pairwise_distance)
                # extra last column = the test point itself, at distance 0
                pairwise_distance = np.pad(pairwise_distance, pad_width=((0, 0), (0, 1)), mode='constant', constant_values=0)
                weights = softmax(-1*pairwise_distance/TEMP, axis=1)
            else:
                # keep the most similar TOP_K_PERCENT of calibration points per
                # test point; the rest get similarity -inf, i.e. softmax weight 0
                thresholds = np.percentile(pairwise_distance, 100*(1 - TOP_K_PERCENT), axis=1, keepdims=True)
                pairwise_distance = np.where(pairwise_distance <= thresholds, -np.inf, pairwise_distance)
                # extra last column = the test point itself, at similarity 1
                # NOTE(review): 1 is the self-similarity under 'cosine'; for 'dot' it generally is not -- confirm before using 'dot'
                pairwise_distance = np.pad(pairwise_distance, pad_width=((0, 0), (0, 1)), mode='constant', constant_values=1)
                weights = softmax(pairwise_distance/TEMP, axis=1)

            # drop the test point's own column; the remaining rows sum to less than 1
            weights = weights[:,:-1]

            qhat, index = calculate_threshold(cal_scores, weights, alpha, SIM_TYPE)

            prediction_sets = generate_prediction_sets(test_smx, qhat, True, SCORE_FUNC_TYPE, (LAM_REG, K_REG, RAND))

            for test_img_path, prediction_set in zip(test_img_paths, prediction_sets):
                prediction_sets_dict[str(test_i)][test_img_path] = prediction_set
            # fraction of test points whose label is inside the prediction set
            coverage = prediction_sets[np.arange(prediction_sets.shape[0]), test_labels].mean()
            set_size = np.sum(prediction_sets, axis=1).mean()
            print(f"A3 coverage distribution {test_i}: {coverage}")
            print(f"A3 set size distribution {test_i}: {set_size}")
            coverage_dict[str(test_i)][str(cal_i)]['coverage'] = coverage
            coverage_dict[str(test_i)][str(cal_i)]['set size'] = set_size
            ################################################################################

        # prediction sets of all test environments for this (alpha, split)
        SAVE_DIR = os.path.join(alpha_results_dir, "prediction_sets", str(cal_i))
        os.makedirs(SAVE_DIR, exist_ok=True)
        with open(os.path.join(SAVE_DIR, f"prediction_sets_{WEIGHTED}.pkl"), "wb") as f:
            pickle.dump(prediction_sets_dict, f)

    # coverage / set-size summary over all splits and test environments for this alpha
    os.makedirs(alpha_results_dir, exist_ok=True)
    with open(os.path.join(alpha_results_dir, f"coverage_{WEIGHTED}.json"), "w") as f:
        json.dump(coverage_dict, f, indent=4)