from data_utils import load_dataset
from utils import construct_prompt, random_sampling, construct_prompt_without_test, construct_prompt_instance_prompt_text
import numpy as np
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, LlamaForCompressionCausalLM, AutoConfig
import argparse
from typing import Dict, Optional, Sequence
import itertools
import json
import random
from scipy.spatial import distance
from openpyxl import Workbook

import numpy as np

def gram_linear(x):
    """Return the Gram matrix of ``x`` under a linear (dot-product) kernel.

    Args:
        x: A num_examples x num_features matrix of features.

    Returns:
        A num_examples x num_examples matrix whose (i, j) entry is the dot
        product of example i with example j.
    """
    # Pairwise inner products between all rows of x.
    return np.dot(x, x.T)


def gram_rbf(x, threshold=1.0):
    """Return the Gram matrix of ``x`` under a Gaussian (RBF) kernel.

    The kernel bandwidth is tied to the data: it is ``threshold`` times the
    median pairwise Euclidean distance between examples. (This is the heuristic
    used in the paper; other bandwidth choices were not explored.)

    Args:
        x: A num_examples x num_features matrix of features.
        threshold: Multiplier applied to the median Euclidean distance to obtain
            the RBF bandwidth.

    Returns:
        A num_examples x num_examples Gram matrix of examples.
    """
    inner = x.dot(x.T)
    norms_sq = np.diag(inner)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>, evaluated for every pair.
    pairwise_sq = -2 * inner + norms_sq[:, None] + norms_sq[None, :]
    median_sq = np.median(pairwise_sq)
    denominator = 2 * threshold ** 2 * median_sq
    return np.exp(-pairwise_sq / denominator)


def center_gram(gram, unbiased=False):
    """Center a symmetric Gram matrix.

    This is equivalent to centering the (possibly infinite-dimensional) features
    induced by the kernel before computing the Gram matrix.

    Args:
        gram: A num_examples x num_examples symmetric matrix (any array-like).
            Non-floating (e.g. integer) inputs are promoted to float64, since
            the centering subtracts fractional means; floating inputs keep
            their dtype.
        unbiased: Whether to adjust the Gram matrix in order to compute an unbiased
            estimate of HSIC. Note that this estimator may be negative.

    Returns:
        A symmetric matrix with centered columns and rows. The input is never
        modified; a new array is always returned.

    Raises:
        ValueError: If ``gram`` is not symmetric.
    """
    gram = np.asarray(gram)
    if not np.allclose(gram, gram.T):
        raise ValueError('Input must be a symmetric matrix.')
    # The means subtracted below are float64. Subtracting them in place from an
    # integer array raises a casting error, so promote non-floating inputs.
    # Floating inputs are copied as-is to preserve their dtype.
    if np.issubdtype(gram.dtype, np.inexact):
        gram = gram.copy()
    else:
        gram = gram.astype(np.float64)

    if unbiased:
        # This formulation of the U-statistic, from Szekely, G. J., & Rizzo, M.
        # L. (2014). Partial distance correlation with methods for dissimilarities.
        # The Annals of Statistics, 42(6), 2382-2412, seems to be more numerically
        # stable than the alternative from Song et al. (2007).
        n = gram.shape[0]
        np.fill_diagonal(gram, 0)
        means = np.sum(gram, 0, dtype=np.float64) / (n - 2)
        means -= np.sum(means) / (2 * (n - 1))
        gram -= means[:, None]
        gram -= means[None, :]
        np.fill_diagonal(gram, 0)
    else:
        # Subtracting half the grand mean from the column means makes the two
        # subtractions below equivalent to H @ gram @ H with H = I - 11^T / n.
        means = np.mean(gram, 0, dtype=np.float64)
        means -= np.mean(means) / 2
        gram -= means[:, None]
        gram -= means[None, :]

    return gram


def cka(gram_x, gram_y, debiased=False):
    """Return the centered kernel alignment (CKA) of two Gram matrices.

    Args:
        gram_x: A num_examples x num_examples Gram matrix.
        gram_y: A num_examples x num_examples Gram matrix.
        debiased: If True, center with the unbiased HSIC estimator. The
            resulting CKA may still be biased.

    Returns:
        The CKA value between the representations behind ``gram_x`` and
        ``gram_y``.
    """
    centered_x, centered_y = (
        center_gram(g, unbiased=debiased) for g in (gram_x, gram_y))

    # HSIC proper would divide this by (n-1)**2 (biased) or n*(n-3) (unbiased);
    # that factor appears in numerator and denominator alike and cancels.
    cross = np.dot(centered_x.ravel(), centered_y.ravel())
    denominator = np.linalg.norm(centered_x) * np.linalg.norm(centered_y)
    return cross / denominator


def _debiased_dot_product_similarity_helper( xty, sum_squared_rows_x, sum_squared_rows_y, squared_norm_x, squared_norm_y, n):
    """Helper for computing debiased dot product similarity (i.e. linear HSIC)."""
    # This formula can be derived by manipulating the unbiased estimator from
    # Song et al. (2007).
    return (
        xty - n / (n - 2.) * sum_squared_rows_x.dot(sum_squared_rows_y)
        + squared_norm_x * squared_norm_y / ((n - 1) * (n - 2)))


def feature_space_linear_cka(features_x, features_y, debiased=False):
    """Return linear-kernel CKA computed directly in feature space.

    Working with feature cross-products instead of Gram matrices is usually
    faster when there are fewer features than examples.

    Args:
        features_x: A num_examples x num_features matrix of features.
        features_y: A num_examples x num_features matrix of features.
        debiased: If True, use the unbiased dot-product similarity estimator.
            CKA may still be biased, and this estimator may be negative.

    Returns:
        The CKA value between X and Y.
    """
    # Center every feature (column) on its mean over examples.
    centered_x = features_x - np.mean(features_x, 0, keepdims=True)
    centered_y = features_y - np.mean(features_y, 0, keepdims=True)

    similarity = np.linalg.norm(centered_x.T.dot(centered_y)) ** 2
    norm_x = np.linalg.norm(centered_x.T.dot(centered_x))
    norm_y = np.linalg.norm(centered_y.T.dot(centered_y))

    if not debiased:
        return similarity / (norm_x * norm_y)

    n = centered_x.shape[0]
    # Row-wise squared norms: same as np.sum(a ** 2, 1) without the temporary.
    row_sq_x = np.einsum('ij,ij->i', centered_x, centered_x)
    row_sq_y = np.einsum('ij,ij->i', centered_y, centered_y)
    total_sq_x = np.sum(row_sq_x)
    total_sq_y = np.sum(row_sq_y)

    similarity = _debiased_dot_product_similarity_helper(
        similarity, row_sq_x, row_sq_y, total_sq_x, total_sq_y, n)
    norm_x = np.sqrt(_debiased_dot_product_similarity_helper(
        norm_x ** 2, row_sq_x, row_sq_x, total_sq_x, total_sq_x, n))
    norm_y = np.sqrt(_debiased_dot_product_similarity_helper(
        norm_y ** 2, row_sq_y, row_sq_y, total_sq_y, total_sq_y, n))
    return similarity / (norm_x * norm_y)


def write_similarity_to_excel(similarity_data):
    """Placeholder for exporting similarity data to Excel; not implemented.

    Args:
        similarity_data: Currently ignored.

    Returns:
        None. The function has no effect yet.
    """
    return None

def _load_representations(rep_file):
    """Parse the JSON representation file and return the decoded object."""
    with open(rep_file, 'r', encoding='utf-8') as reader:
        return json.load(reader)


def _load_scores(res_file):
    """Return one float per non-blank line of ``res_file``, in file order."""
    with open(res_file, 'r', encoding='utf-8') as reader:
        return [float(line.strip()) for line in reader if line.strip()]


def _adjacent_output_path(rep_file):
    """Derive the .xlsx output path from the representation file path.

    ``foo.json`` becomes ``foo_adjacent.xlsx``. Only a trailing ``.json`` is
    stripped; any other path simply gets ``_adjacent.xlsx`` appended, so the
    output can never overwrite the input file.
    """
    suffix = '.json'
    if rep_file.endswith(suffix):
        rep_file = rep_file[:-len(suffix)]
    return rep_file + '_adjacent.xlsx'


def _layer_similarity_rows(representations, rank_keys, sorted_rank_keys, layer, demonstration_num):
    """Build the sheet rows and per-demo summary values for one layer.

    For each demonstration, emits one row per key in ``rank_keys`` (original
    key order). Each row holds the cosine similarity of that key's vector to
    the vector of every key in ``sorted_rank_keys``. Five blank rows follow each
    demonstration's rows.

    Returns:
        ``(rows, summary)``. ``summary`` has one value per demonstration: the
        similarity between the last key of ``rank_keys`` and the first key of
        ``sorted_rank_keys``.
    """
    rows = []
    summary = []
    for demo in range(demonstration_num):
        # Rows follow the original key order, columns follow score order.
        # Iterate over sorted_rank_keys here instead to order both axes by score.
        for i_key in rank_keys:
            vec_i = representations[i_key][demo][layer][0]
            rows.append([
                1 - distance.cosine(vec_i, representations[j_key][demo][layer][0])
                for j_key in sorted_rank_keys
            ])
        summary.append(rows[-1][0])
        # Visual separation between demonstrations in the sheet.
        rows.extend([''] for _ in range(5))
    return rows, summary


def main(rep_file, res_file):
    """Write per-layer cosine similarities between permutations to an Excel file.

    Args:
        rep_file: Path to a JSON object keyed by permutation name. Each value is
            indexed as ``value[demo][layer][0]`` to obtain the feature vector
            that is compared (this is how the code reads it).
        res_file: Path to a text file with one float score per line. The scores
            are matched positionally to the keys of ``rep_file``, so this
            assumes the same order.

    Output:
        Saves a workbook next to ``rep_file`` (``*_adjacent.xlsx``) with:
        - one sheet per layer holding the similarity rows; and
        - one summary row per layer on the default sheet.

    Raises:
        ValueError: If ``rep_file`` holds no permutations, or if the number of
            scores differs from the number of permutations.
    """
    representations = _load_representations(rep_file)
    res = _load_scores(res_file)

    rank_keys = list(representations.keys())
    if not rank_keys:
        raise ValueError('No representations found in {}.'.format(rep_file))
    if len(res) != len(rank_keys):
        # zip() would silently drop the extras and misalign names and scores.
        raise ValueError(
            'Expected one score per permutation: got {} scores for {} '
            'permutations.'.format(len(res), len(rank_keys)))

    # Permutation names ordered by ascending score (ties broken by name).
    sorted_rank_keys = [key for _, key in sorted(zip(res, rank_keys))]
    for key in sorted_rank_keys:
        print(key)
    print("res = ", res)

    first = representations[rank_keys[0]]
    demonstration_num = len(first)
    layers_num = len(first[0])
    dim_num = len(first[0][0][0])
    # Example output: 4 27 2400, where 27 = the input to the first layer plus
    # the outputs of all layers.
    print(demonstration_num, layers_num, dim_num)

    workbook = Workbook()
    default_sheet = workbook.active
    default_sheet.title = "Default Sheet"

    for k in range(layers_num):
        cur_sheet = workbook.create_sheet(title="Layer " + str(k))
        data, big_difference_data = _layer_similarity_rows(
            representations, rank_keys, sorted_rank_keys, k, demonstration_num)
        for row in data:
            cur_sheet.append(row)
        default_sheet.append(big_difference_data)

    workbook.save(_adjacent_output_path(rep_file))
    # TODO: compare the best- and worst-performing permutations directly.




if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Compute per-layer cosine similarities between permutation '
                    'representations and write them to an Excel workbook.')
    # required arguments
    parser.add_argument(
        '--rep_file', dest='rep_file', action='store', required=True,
        help='path to the JSON file of representations, keyed by permutation name')
    parser.add_argument(
        '--res_file', dest='res_file', action='store', required=True,
        help='path to a text file with one score per line, in the same order '
             'as the keys of --rep_file')

    args = vars(parser.parse_args())
    main(**args)