import torch
import numpy as np
from scipy.spatial.distance import jensenshannon
import argparse
from utils.model import ModelWrapper
from utils.load_data import load_json_data
from utils.metrics import draw_bar
import pandas as pd
from tqdm import tqdm
# Step 2: define the functions used to compute the Jensen-Shannon divergence
def get_token_probabilities(text, tokenizer, model):
    """
    Build a normalized token -> probability map for a single text.

    The text is tokenized and run through the model once. For every input
    position, the softmax probability that the model's output at that same
    position assigns to the input token is read off. Probabilities of
    repeated tokens (same decoded string) are summed, and the resulting map
    is rescaled so that its values sum to 1.

    Args:
        text: The text to score.
        tokenizer: Callable as ``tokenizer(text, return_tensors="pt")`` whose
            result supports ``.to(device)``, ``**`` unpacking and
            ``["input_ids"]``; must also provide ``decode(token_id)``.
        model: Callable with the tokenizer output as keyword arguments,
            exposing ``.device`` and returning an object with ``.logits``.

    Returns:
        dict mapping each decoded token string to its normalized weight, in
        order of first occurrence. Empty if the text produced no tokens.

    Raises:
        ZeroDivisionError: If tokens exist but all their probabilities are 0.
    """
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # Only the logits are needed; passing labels would make the model
        # compute a loss that is never read.
        logits = model(**inputs).logits

    # Select the single batch element explicitly. A blanket squeeze() would
    # also remove the sequence axis for a one-token input and break the
    # per-position indexing below.
    probabilities = torch.softmax(logits[0], dim=-1)
    token_ids = inputs["input_ids"][0]

    # NOTE(review): this reads the probability of token i from the output at
    # position i. For a causal LM the output at position i presumably scores
    # token i+1, so a one-position shift may be intended — confirm before
    # changing, since it alters every downstream score.
    index = token_ids.to(probabilities.device).unsqueeze(1)
    picked = probabilities.gather(1, index).squeeze(1).tolist()

    # Accumulate probability mass per decoded token string.
    token_probs = {}
    for token_id, prob in zip(token_ids.tolist(), picked):
        token = tokenizer.decode(token_id)
        token_probs[token] = token_probs.get(token, 0.0) + prob

    # Normalize so the values form a probability distribution.
    total = sum(token_probs.values())
    for token in token_probs:
        token_probs[token] /= total

    return token_probs

def align_distributions(distributions):
    """
    Align several token-probability maps onto one shared, sorted vocabulary.

    Args:
        distributions: Iterable of dicts mapping a token to its probability.
            Tokens must be mutually sortable (e.g. all strings).

    Returns:
        A list with one float NumPy vector per input distribution, in input
        order. Every vector has one entry per token of the sorted union of
        all keys; tokens missing from a distribution get 0.0.
    """
    # The input is walked twice (vocabulary, then vectors); materialize it so
    # a one-shot iterator does not silently produce an empty result.
    distributions = list(distributions)

    vocab = sorted(set().union(*distributions))

    # Force a float dtype so an empty distribution yields a float zero vector
    # like the others, not an integer one.
    return [
        np.array([dist.get(word, 0.0) for word in vocab], dtype=float)
        for dist in distributions
    ]

def compute_js_divergence_matrix(texts, tokenizer, model):
    """
    Compute the pairwise Jensen-Shannon divergence matrix for several texts.

    Each text is turned into a token-probability map with
    ``get_token_probabilities``, the maps are projected onto a common
    vocabulary with ``align_distributions``, and every pair (including each
    text with itself) is compared.

    Args:
        texts: Sequence of texts to compare.
        tokenizer: Tokenizer forwarded to ``get_token_probabilities``.
        model: Model forwarded to ``get_token_probabilities``.

    Returns:
        A symmetric ``(n, n)`` NumPy array, ``n = len(texts)``, whose entry
        ``[i, j]`` is the base-2 Jensen-Shannon divergence between text ``i``
        and text ``j``.
    """
    # 1. One token-probability map per text.
    per_text = []
    for text in texts:
        per_text.append(get_token_probabilities(text, tokenizer, model))

    # 2. Project all maps onto the same vocabulary.
    vectors = align_distributions(per_text)

    # 3. Fill the upper triangle (diagonal included) and mirror it.
    size = len(vectors)
    matrix = np.zeros((size, size))
    for row, col in zip(*np.triu_indices(size)):
        # scipy returns the JS *distance*; squaring it gives the divergence.
        distance = jensenshannon(vectors[row], vectors[col], base=2)
        divergence = distance ** 2
        matrix[row, col] = divergence
        matrix[col, row] = divergence
    return matrix

# def split_difficulty(dataset):
#     difficulty_dic = {}
#     sc_path = f'./result/{dataset}/{model_name}/sc10_e3_{n_samples}.json'
#     sc_result = load_json_data(sc_path)[:-1]
#     for item in sc_result:
#         id = item['id']
#         difficulty = 6 - item['corrects'].count(True) // 2
#         if difficulty in difficulty_dic.keys():
#             difficulty_dic[difficulty].append(id)
#         else:
#             difficulty_dic[difficulty] = [id]
#     return difficulty_dic   



# if __name__ == '__main__':
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--model', type=str, default='Llama3_1_8b_chat')
#     parser.add_argument('--n_samples', type=int, default=500)
#     parser.add_argument('--roll_num', type=int, default=10)
#     parser.add_argument('--type', type=str, default='js_div')
#     args = parser.parse_args()
#     # Step 1: load the pretrained model and tokenizer
#     model_name = args.model
#     n_samples = args.n_samples
#     roll_num = args.roll_num
#     type = args.type
    
#     model_wrapper = ModelWrapper(model_name)
#     model = model_wrapper.model
#     tokenizer = model_wrapper.tokenizer

#     datasets = ['gsm8k', 'aqua', 'math', 'siqa', 'proofwriter']
#     dataset_ls = []
#     difficulty_ls = []
#     scores = []
#     for dataset in datasets:
#         difficulty_dic = split_difficulty(dataset)
#         for difficulty, index in difficulty_dic.items():
#             sc_path = f'./result/{dataset}/{model_name}/sc{roll_num}_e3_{n_samples}.json'
#             sc_result = load_json_data(sc_path)[:-1]
#             result = [item for item in sc_result if item['id'] in index]
#             for item in tqdm(result):
#                 if type == 'js_div':
#                     responses = item['response']
#                     js_divergences = compute_js_divergence_matrix(responses)
#                     score = np.mean(np.array(js_divergences))
#                 elif type == 'none':
#                     answers = item['answer']
#                     score = answers.count(None)
#                 elif type == 'length':
#                     responses = item['response']
#                     score = np.mean(np.array([len(res) for res in responses]))
#                 else:
#                     answers = item['answer']
#                     score = len(set(answers))
#                 difficulty_ls.append(difficulty)
#                 dataset_ls.append(dataset)
#                 scores.append(score)
#                 path = f'fig/{model_name}_{type}.png'
#                 data = {'difficulty':difficulty_ls, 'js_div':scores, 'dataset':dataset_ls}
#                 data = pd.DataFrame(data, columns=['difficulty', 'js_div', 'dataset'])
#     draw_bar(data, path)