from argparse import ArgumentParser
from torchvision import models
import utils
from tqdm import tqdm
import torch
import os.path as path
from scipy import stats
from sklearn.metrics import ndcg_score

# This script contains the logic to calculate the Kendall-Tau coefficient (and NDCG scores) of the rankings of the activations of a layer.
# It can be run as its own script or imported as a module.
# It does not calculate the activations on its own - that is handled in get_imagenet_activations_by_channel.py


#Loads in the initial and final activations only from the activations saved by get_imagenet_activations_by_channel.py
def load_in_activations(output_folder, steps):
    """Load the activations saved for the first and the last step.

    Args:
        output_folder: Directory containing the
            ``model_checkpoint_activations_step_<step>.pt`` files.
        steps: Indexable sequence of step identifiers; only ``steps[0]`` and
            ``steps[-1]`` are read.

    Returns:
        A ``(init_activations, final_activations)`` tuple holding whatever
        ``torch.load`` returns for the two files.
    """
    first_step, last_step = steps[0], steps[-1]
    checkpoint_files = [
        path.join(output_folder, f"model_checkpoint_activations_step_{step}.pt")
        for step in (first_step, last_step)
    ]
    init_activations = torch.load(checkpoint_files[0])
    final_activations = torch.load(checkpoint_files[1])
    return init_activations, final_activations

def load_in_activations_from_paths(init_path, final_path):
    """Load initial and final activations from two explicit file paths.

    Args:
        init_path: Path of the file holding the initial activations.
        final_path: Path of the file holding the final activations.

    Returns:
        A ``(init_activations, final_activations)`` tuple holding whatever
        ``torch.load`` returns for the two files.
    """
    # path.join with a single argument hands the path back as-is (after
    # os.fspath conversion); it is kept so both paths are validated up front.
    resolved_init, resolved_final = (path.join(p) for p in (init_path, final_path))
    return torch.load(resolved_init), torch.load(resolved_final)

def generate_ndcg_results(output_folder, steps):
    """Compute and save NDCG matrices for init-vs-init and init-vs-final.

    Loads the first/last step activations from ``output_folder`` and hands
    them to ``compute_ndcg``, which writes ``ndcg_init_init.pt`` and
    ``ndcg_init_final.pt`` into the same folder.

    Args:
        output_folder: Directory holding the saved activations; results are
            written here as well.
        steps: Sequence of step identifiers (first and last are used).
    """
    init_activations, final_activations = load_in_activations(output_folder, steps)
    init_shape, final_shape = init_activations.shape, final_activations.shape
    print(f"init and final activation shapes: {init_shape}, {final_shape}")
    print(f"init and final activations shapes: {init_shape}, {final_shape}")

    # The initial activations are the reference in both comparisons:
    # once against themselves and once against the final activations.
    comparisons = (
        ('init_init', init_activations),
        ('init_final', final_activations),
    )
    for tag, other in comparisons:
        compute_ndcg(init_activations, other, output_folder, name=tag)

def compute_ndcg(activations_a, activations_b, output_folder, name):
    """Compute the pairwise NDCG matrix between channels and save it.

    For every channel ``i`` of ``activations_a`` the samples whose activation
    exceeds 10% of that channel's maximum are kept. Those samples act as the
    true relevance, and the same samples of every channel ``j`` of
    ``activations_b`` act as the predicted scores, giving
    ``ndcg_scores[i, j]``.

    Channels of ``activations_a`` that keep fewer than two samples (for
    example an all-zero channel) have no meaningful NDCG; their whole row is
    stored as NaN instead of letting ``ndcg_score`` raise and abort the run.

    Args:
        activations_a: Tensor indexed by channel along dim 0 (assumed to be
            (channels, samples) -- TODO confirm against the producer script).
        activations_b: Tensor with at least as many channels as
            ``activations_a`` and the same per-channel layout.
        output_folder: Directory where ``ndcg_<name>.pt`` is written.
        name: Suffix used in the output file name.
    """
    num_channels = activations_a.shape[0]
    total_pairs = num_channels * num_channels
    ndcg_scores = torch.zeros([num_channels, num_channels])

    for channel_a in range(num_channels):
        row_a = activations_a[channel_a]
        mask = row_a > 0.1 * row_a.max()
        if int(mask.sum()) < 2:
            # NDCG needs at least two documents; mark the row as undefined.
            ndcg_scores[channel_a, :] = float('nan')
            continue

        # The reference relevance does not depend on channel_b: convert once.
        relevance = row_a[mask].unsqueeze(0).cpu().numpy()
        for channel_b in range(num_channels):
            step = channel_a * num_channels + channel_b
            if step % 1000 == 0:
                print(f'step {step} of {total_pairs} for calculating ndcg')
            predicted = activations_b[channel_b][mask].unsqueeze(0).cpu().numpy()
            ndcg_scores[channel_a, channel_b] = ndcg_score(relevance, predicted)

    torch.save(ndcg_scores, path.join(output_folder, f'ndcg_{name}.pt'))
    print(f'max of scores - scores.t: {(ndcg_scores-ndcg_scores.T).max()}')


# Gets the kt of 2 vectors
def get_kt_from_vectors(data1, data2):
    """Return the Kendall tau correlation between two tensors.

    Both inputs are moved to the CPU before being handed to
    ``scipy.stats.kendalltau``; only the correlation is returned, the
    p-value is discarded.
    """
    return stats.kendalltau(data1.cpu(), data2.cpu()).correlation

def generate_kt_results(init_activations, final_activations,output_folder):
    """Run the Kendall tau analysis for init-vs-init and init-vs-final.

    Delegates to ``calc_kendall_tau``, which writes the
    ``correlations_<name>.pt`` / ``pvalues_<name>.pt`` files for the names
    ``init_init`` and ``init_final`` into ``output_folder``.

    Args:
        init_activations: Activations of the initial checkpoint.
        final_activations: Activations of the final checkpoint.
        output_folder: Directory where the result files are written.
    """
    print(f"init and final activation shapes: {init_activations.shape}, {final_activations.shape}")

    # The initial activations are the reference in both comparisons:
    # once against themselves and once against the final activations.
    comparisons = (
        ('init_init', init_activations),
        ('init_final', final_activations),
    )
    for tag, other in comparisons:
        calc_kendall_tau(init_activations, other, output_folder, name=tag)

# Takes two tensors with the per-channel activations and saves two tensors to output_folder (nothing is returned):
# the first saved tensor has t[i,j] = kendall tau correlation value between channels i and j,
# the second holds the p-values associated with the kendall tau.
def calc_kendall_tau(ranks_a, ranks_b, output_folder, name):
    """Compute pairwise Kendall tau between channels and save the results.

    For channel ``i`` of ``ranks_a`` and channel ``j`` of ``ranks_b`` the
    samples lying strictly above the 0.95 quantile of *either* channel are
    selected, and Kendall's tau is computed on that union of samples.
    ``correlations[i, j]`` and ``pvalues[i, j]`` hold the result.

    When both inputs hold the same data the matrix is symmetric, so only the
    upper triangle is computed and mirrored. When they differ (e.g. initial
    vs. final activations), ``tau(a[i], b[j])`` is generally not equal to
    ``tau(a[j], b[i])``, so every pair is computed explicitly.

    Args:
        ranks_a: Tensor indexed by channel along dim 0 (assumed to be a
            floating-point (channels, samples) tensor -- TODO confirm).
        ranks_b: Tensor with at least as many channels as ``ranks_a`` and the
            same per-channel layout.
        output_folder: Directory where ``correlations_<name>.pt`` and
            ``pvalues_<name>.pt`` are written.
        name: Suffix used in the output file names.
    """
    q = .95
    num_channels = ranks_a.shape[0]
    correlations = torch.zeros([num_channels, num_channels])
    pvalues = torch.zeros([num_channels, num_channels])
    print('correlations shape: ', correlations.shape)
    print('performing kendall tau analysis')

    # Mirroring [i, j] into [j, i] is only valid when a and b are the same data.
    symmetric = ranks_a is ranks_b or (
        ranks_a.shape == ranks_b.shape and bool(torch.equal(ranks_a, ranks_b))
    )
    if symmetric:
        total_steps = num_channels * (num_channels + 1) / 2
    else:
        total_steps = num_channels * num_channels

    # The top-quantile mask of each b channel does not depend on i, so it is
    # computed once per channel rather than once per (i, j) pair.
    top_masks_b = [
        ranks_b[j] > torch.quantile(ranks_b[j], q) for j in range(num_channels)
    ]

    count = 0
    for i in tqdm(range(num_channels)):
        channel_a = ranks_a[i]
        mask_a = channel_a > torch.quantile(channel_a, q)
        first_j = i if symmetric else 0
        for j in range(first_j, num_channels):
            if count % 1000 == 0:
                print(f'step {count} of {total_steps}')

            channel_b = ranks_b[j]
            # Boolean OR selects the same samples, in the same ascending index
            # order, as the unique union of both index lists. It also works
            # when a mask holds exactly one (or zero) True entries.
            mask = mask_a | top_masks_b[j]
            channel_a_pruned = channel_a[mask].cpu().numpy()
            channel_b_pruned = channel_b[mask].cpu().numpy()

            res = stats.kendalltau(channel_a_pruned, channel_b_pruned)
            cor = res.correlation
            pvalue = res.pvalue
            correlations[i, j] = cor
            pvalues[i, j] = pvalue
            if symmetric:
                correlations[j, i] = cor
                pvalues[j, i] = pvalue
            count += 1
    torch.save(correlations, path.join(output_folder, f'correlations_{name}.pt'))
    torch.save(pvalues, path.join(output_folder, f'pvalues_{name}.pt'))


def calc_kendall_tau_old(ranks_a, ranks_b, output_folder, name):
    """Legacy pairwise Kendall tau over the full channel-by-channel grid.

    For each channel ``i`` of ``ranks_a`` the samples exceeding 10% of that
    channel's maximum are kept; Kendall's tau is then computed between those
    samples and the same samples of every channel ``j`` of ``ranks_b``.
    Nothing is mirrored: every ``(i, j)`` pair is evaluated.

    Args:
        ranks_a: Tensor indexed by channel along dim 0.
        ranks_b: Tensor with at least as many channels as ``ranks_a``.
        output_folder: Directory where ``correlations_<name>.pt`` and
            ``pvalues_<name>.pt`` are written.
        name: Suffix used in the output file names.
    """
    num_channels = ranks_a.shape[0]
    total_pairs = num_channels * num_channels
    correlations = torch.zeros([num_channels, num_channels])
    pvalues = torch.zeros([num_channels, num_channels])
    print('correlations shape: ', correlations.shape)
    print('performing kendall tau analysis')

    for row in range(num_channels):
        reference = ranks_a[row]
        keep = reference > 0.1 * reference.max()
        reference = reference[keep]
        for col in range(num_channels):
            # Flat index of the current pair; drives the progress message.
            step = row * num_channels + col
            if step % 1000 == 0:
                print(f'step {step} of {total_pairs}')
            candidate = ranks_b[col][keep]
            result = stats.kendalltau(reference.cpu(), candidate.cpu())
            correlations[row, col] = result.correlation
            pvalues[row, col] = result.pvalue

    correlations_file = path.join(output_folder, f'correlations_{name}.pt')
    torch.save(correlations, correlations_file)
    print(f'saving to {correlations_file}')
    torch.save(pvalues, path.join(output_folder, f'pvalues_{name}.pt'))
#This version using torchmetrics was even slower than numpy
# def calc_kendall_tau(ranks_a, ranks_b, save_path=''):
#     num_channels = ranks_a.shape[0]
#     correlations = torch.zeros([num_channels, num_channels])
#     pvalues = torch.zeros([num_channels, num_channels])
#     print('correlations shape: ', correlations.shape)
#     print('performing kendall tau analysis')
#     for i in tqdm(range(num_channels)):
#         for j in range(i, num_channels):
#             print(i,j)
#             res = kendall_rank_corrcoef(ranks_a[i], ranks_b[j], t_test=True)
#             cor = res[0]
#             pvalue = res[1]
#             correlations[i, j] = cor
#             correlations[j, i] = cor
#             pvalues[i, j] = pvalue
#             pvalues[j, i] = pvalue
#     return correlations, pvalues


# def get_cos_sim_results(output_folder, init_top_images, final_atop_images):
#     sim_dict = {}
#     last_path = path.join(output_folder, f'cos_sim_init_final.pt')
#     if path.exists(last_path):
#         print('loading precalculated cosine similarity stats')
#     else:
#         print('Cosine Similarity stats not found, calculating directly')
#         generate_cos_sim_results(init_top_images, final_top_images, output_folder)

#     cos_sim_ii = torch.load(path.join(output_folder, f'cos_ii.pt'))
#     cos_sim_if = torch.load(path.join(output_folder, f'cos_if.pt'))

#     sim_dict['ii'] = cos_sim_ii
#     sim_dict['if'] = cos_sim_if
    
#     return sim_dict

# def generate_cos_sim_results(init_activations, final_activations,output_folder):
#     print(f"init and final activation shapes: {init_activations.shape}, {final_activations.shape}")


#     #init_ranks = init_activations.argsort(dim=1)
#     #final_ranks = final_activations.argsort(dim=1)
#     # Need to compare initial to itself and initial to final
#     calc_cos_sim(init_activations, init_activations, output_folder, name='ii')
#     calc_cos_sim(init_activations, final_activations, output_folder, name='if')

# #takes two tensors with the activations ranks and returns two tensors
# # first returned has t[i,j] = kendall tau correlation value between channels i and j
# # second is the p values associated with the kendall tau
# def calc_cos_sim(ranks_a, ranks_b, output_folder, name):
#     num_channels = ranks_a.shape[0]
#     cos_sim = torch.zeros([num_channels, num_channels])
#     print('correlations shape: ', correlations.shape)
#     print('performing kendall tau analysis')

#     count=0

#     for i in range(num_channels):
#         pruned_a = ranks_a[i]
#         mask = pruned_a > 0.1 * pruned_a.max()
#         pruned_a = pruned_a[mask]
#         for j in range(num_channels):
#             # print(i,j)
#             if count % 1000 == 0:
#                 print(f'step {count} of {num_channels * num_channels}')
#             pruned_b = ranks_b[j]

#             pruned_b = pruned_b[mask]

#             res = stats.kendalltau(pruned_a.cpu(), pruned_b.cpu())
#             cor = res.correlation
#             pvalue = res.pvalue
#             correlations[i, j] = cor
#             #correlations[j, i] = cor
#             pvalues[i, j] = pvalue
#             #pvalues[j, i] = pvalue
#             count+=1
#     torch.save(correlations, path.join(output_folder, f'correlations_{name}.pt'))
#     torch.save(pvalues, path.join(output_folder, f'pvalues_{name}.pt'))

#Use this to get the KT stats for other scripts
def get_overall_kt_results(output_folder, init_activations_path, final_activations_path, do_overwrite = False):
    """Return the Kendall tau statistics, computing them first if needed.

    The presence of ``pvalues_init_final.pt`` in ``output_folder`` is used as
    the marker that results already exist. If it is missing, or if
    ``do_overwrite`` is true, the activations are loaded from the two given
    paths and the statistics are regenerated before being read back.

    Args:
        output_folder: Directory holding (or receiving) the result files.
        init_activations_path: File with the initial activations; only read
            when the statistics have to be computed.
        final_activations_path: File with the final activations; only read
            when the statistics have to be computed.
        do_overwrite: Recompute even if cached results are present.

    Returns:
        Dict with keys ``cor_ii``, ``cor_if``, ``p_ii`` and ``p_if`` mapping
        to the loaded correlation / p-value objects.
    """
    cached = path.exists(path.join(output_folder, 'pvalues_init_final.pt'))
    if do_overwrite or not cached:
        print('KT stats not found, calculating directly')
        init_activations, final_activations = load_in_activations_from_paths(init_activations_path, final_activations_path)
        generate_kt_results(init_activations, final_activations, output_folder)
    else:
        print('loading precalculated KT stats')

    # Result key -> file written by the Kendall tau computation.
    saved_files = {
        'cor_ii': 'correlations_init_init.pt',
        'cor_if': 'correlations_init_final.pt',
        'p_ii': 'pvalues_init_init.pt',
        'p_if': 'pvalues_init_final.pt',
    }
    return {
        key: torch.load(path.join(output_folder, file_name))
        for key, file_name in saved_files.items()
    }


if __name__ == "__main__":
    # Command-line entry point: read the run configuration for a results
    # directory, load the first/last step activations and run the Kendall
    # tau analysis on them.
    cli = ArgumentParser()
    cli.add_argument("--results-directory", type=str)
    #cli.add_argument("--num-img-to-save", type=int, default=10)
    #cli.add_argument("--num-top-classes", type=int, default=6)
    #cli.add_argument("--num-indices-to-track", type=int, default=10)

    results_directory = cli.parse_args().results_directory

    # get_run_config yields five values; only the steps and the output
    # folder are needed here.
    steps, _layers, _config_dict, _channels, output_folder = utils.get_run_config(results_directory)

    generate_kt_results(*load_in_activations(output_folder, steps), output_folder)
