# import json_repair
# # 示例文本
# text = '''Qilin offers a comprehensive collection of user sessions with heterogeneous results like image-text notes, video notes, commercial notes, and direct answers.
# ```json
# { "task": "xxx" }
# ```
# '''
# decoded_object = json_repair.repair_json(text, return_objects=True)
# print(type(decoded_object))
import math
import json
import os
import sys
from collections import defaultdict
from utils import *
from tqdm import tqdm
from metric import Reward
import random
from scipy.stats import spearmanr, kendalltau
from concurrent.futures import ProcessPoolExecutor


# Compute the Spearman distance (1 - rho) between two rows of a ranking matrix.
def compute_pairwise_distance(i, j, rankings):
    """Return ``(i, j, 1 - rho)`` for rows ``i`` and ``j`` of ``rankings``.

    ``rho`` is the Spearman rank correlation of the two rows, so the distance
    ranges from 0 (identical order) to 2 (exactly reversed order).

    Args:
        i: index of the first row.
        j: index of the second row.
        rankings: indexable collection of per-row rankings.

    Returns:
        Tuple ``(i, j, distance)``. The indices are echoed back so that results
        gathered out of order can still be placed into a matrix.
    """
    rho, _pvalue = spearmanr(rankings[i], rankings[j])
    distance = 1 - rho
    return (i, j, distance)


def compute_spearman_distance(rankings):
    """
    Build the pairwise Spearman distance matrix between the rows of ``rankings``.

    The distance between rows ``a`` and ``b`` is ``1 - rho(a, b)``, where ``rho``
    is the Spearman rank correlation. The result is symmetric with a zero
    diagonal.

    The whole matrix comes from one vectorised ``spearmanr(..., axis=1)`` call.
    The previous version submitted one process-pool task per pair, which pickled
    the full ``rankings`` array n*(n-1)/2 times.

    Args:
        rankings: 2-D array-like of shape (num_people, num_models). Each row is
            one ranking/score vector over the same set of models.

    Returns:
        distance_matrix: ndarray of shape (num_people, num_people).
    """
    rankings = np.asarray(rankings)
    num_people = rankings.shape[0]
    distance_matrix = np.zeros((num_people, num_people))
    if num_people < 2:
        # No pairs to compare.
        return distance_matrix

    # axis=1: every row is a variable, so corr[a, b] = rho(row a, row b).
    corr, _ = spearmanr(rankings, axis=1)
    if np.ndim(corr) == 0:
        # With exactly two variables spearmanr returns a scalar, not a matrix.
        distance_matrix[0, 1] = distance_matrix[1, 0] = 1 - corr
        return distance_matrix

    distance_matrix = 1 - np.asarray(corr, dtype=float)
    # Guard against floating-point asymmetry and make the diagonal exactly 0.
    distance_matrix = (distance_matrix + distance_matrix.T) / 2
    np.fill_diagonal(distance_matrix, 0.0)
    return distance_matrix

def mode(arr):
    """
    Return the most frequent element of ``arr``.

    When several elements share the highest frequency, the largest of them is
    returned. An empty ``arr`` raises ``IndexError``.
    """
    freq = defaultdict(int)
    for value in arr:
        freq[value] += 1
    # Ascending by (count, value): the last entry is the winner.
    ordered = sorted(freq, key=lambda v: (freq[v], v))
    return ordered[-1]


def merge(output_dir, output_file):
    """
    Load every ``*.json`` file in ``output_dir`` into a list.

    Files are read in sorted-name order so the merged output is deterministic.
    A loaded object without a ``'task_id'`` entry is reported but still kept.

    Args:
        output_dir: directory to scan (non-recursively) for ``.json`` files.
        output_file: if not ``None``, the merged list is also written there as
            indented JSON.

    Returns:
        list of the loaded JSON objects.
    """
    data = []
    for file in tqdm(sorted(os.listdir(output_dir))):
        if not file.endswith('.json'):
            continue
        path = os.path.join(output_dir, file)
        with open(path, 'r') as f:
            tmp = json.load(f)
        if 'task_id' not in tmp:
            print(f'missing task_id: {path}')
        data.append(tmp)

    if output_file is not None:
        # Use a context manager so the output is flushed and closed.
        with open(output_file, 'w') as f:
            json.dump(data, f, indent=4)
    return data


def merge_files(file_list):
    """
    Merge per-model result files into a ``task_id -> role -> records`` table.

    For each file, the mean ``Reward.reward_fn`` score is printed. Any
    ``'query_'`` in a record's ``task_id`` is replaced (in place) by the record's
    ``type``. Only tasks answered by the maximum number of distinct roles are
    kept.

    Args:
        file_list: paths of JSON files. Each file holds a list of records with at
            least ``type``, ``content``, ``solution``, ``task_id`` and ``role``.

    Returns:
        (table, tasks), where ``table`` is ``{task_id: {role: [record, ...]}}``
        and ``tasks`` is ``{task_id: record}``, holding the last record seen for
        each task.
    """
    data = []
    for file in file_list:
        with open(file) as f:
            records = json.load(f)
        total_reward = 0
        for line in records:
            total_reward += Reward.reward_fn(line['type'], line['content'], line['solution'])
            if 'query_' in line['task_id']:
                line['task_id'] = line['task_id'].replace('query_', line['type'])
        # An empty file would otherwise raise ZeroDivisionError.
        mean_reward = total_reward / len(records) if records else float('nan')
        print(file, '===>', mean_reward)
        data += records

    table = defaultdict(lambda: defaultdict(list))
    tasks = {}
    for item in data:
        table[item['task_id']][item['role']].append(item)
        tasks[item['task_id']] = item

    print(len(tasks))

    # Number of distinct roles on the best-covered task (0 if nothing was loaded).
    model_size = max((len(v) for v in table.values()), default=0)
    print(model_size)
    table = {k: v for k, v in table.items() if len(v) == model_size}
    tasks = {k: v for k, v in tasks.items() if k in table}

    return table, tasks

# def get_ROI(line, is_weighted=False):
#     r = Reward.reward_fn(line['type'], line['content'], line['solution'])
#     usage = (line['usage'][0] + 4 * line['usage'][1])  / 1000
#     # print(usage)
#     if is_weighted:
#         usage *= DOCUMENTATIONS[line['role']]['cost'] 
#     return r / usage

def get_ROI(line, is_weighted=False):
    """
    Return reward per unit of usage for a single output record.

    The base usage is ``(line['usage'][0] + line['usage'][1]) / 1000``. These
    are presumably input/output token counts; confirm against the record
    producer.

    When ``is_weighted`` is true, the usage is scaled by ``0.5 * avg_speed``,
    where ``avg_speed`` is the mean of ``rec[1] / (rec[0] / 1000)`` over the
    records in ``DOCUMENTATIONS[line['role']]['speed']``.

    Raises:
        ValueError: if weighting is requested and the role has no speed records.
    """
    reward = Reward.reward_fn(line['type'], line['content'], line['solution'])
    tokens = line['usage']
    cost = (tokens[0] + tokens[1]) / 1000

    if is_weighted:
        records = DOCUMENTATIONS[line['role']]['speed']
        if len(records) == 0:
            raise ValueError("speed record is empty")
        per_record = [rec[1] / (rec[0] / 1000) for rec in records]
        cost = 0.5 * cost * (sum(per_record) / len(records))

    return reward / cost


from sklearn.mixture import GaussianMixture
def gmm_clustering(distance_matrix, max_k=20):
    """
    Cluster rows with a Gaussian Mixture Model, choosing the component count by BIC.

    Component counts ``1 .. min(max_k - 1, n_samples)`` are tried. The BIC curve
    is saved to ``GMM.png`` on a fresh figure, so earlier plots are not
    overlaid. The count with the lowest BIC is used for the final fit.

    Args:
        distance_matrix: array of shape (num_samples, num_samples). Each row is
            used as that sample's feature vector.
        max_k: exclusive upper bound on the number of components tried
            (default 20, matching the previous hard-coded value).

    Returns:
        (labels, centers): the per-sample component labels, shape
        (num_samples,), and the component means (``gmm.means_``).

    Raises:
        ValueError: if ``distance_matrix`` has no rows.
    """
    n_samples = distance_matrix.shape[0]
    if n_samples == 0:
        raise ValueError("distance_matrix is empty")
    # GaussianMixture requires n_components <= n_samples.
    candidates = list(range(1, min(max_k, n_samples + 1)))

    bic_scores = []
    for n_components in candidates:
        gmm = GaussianMixture(n_components=n_components, covariance_type='full', random_state=42)
        gmm.fit(distance_matrix)
        bic_scores.append(gmm.bic(distance_matrix))

    # Plot the BIC curve on its own figure and release it afterwards.
    plt.figure()
    plt.plot(candidates, bic_scores, marker='o')
    plt.xlabel('Number of clusters')
    plt.ylabel('BIC')
    plt.title('BIC for GMM')
    plt.savefig('GMM.png')
    plt.show()
    plt.close()

    optimal_k = candidates[int(np.argmin(bic_scores))]

    gmm = GaussianMixture(n_components=optimal_k, covariance_type='full', random_state=42)
    labels = gmm.fit_predict(distance_matrix)
    centers = gmm.means_

    return labels, centers

import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

def kmeans(distance_matrix, optimal_k=10, max_k=20):
    """
    Cluster rows with K-means after saving an elbow (SSE vs. k) diagnostic plot.

    The elbow curve for ``k = 1 .. min(max_k - 1, n_samples)`` is saved to
    ``kmeans.png``. It is only a diagnostic: the final clustering always uses
    ``optimal_k``, capped at the number of samples.

    Args:
        distance_matrix: array of shape (num_samples, num_samples). Each row is
            used as that sample's feature vector.
        optimal_k: number of clusters for the final fit (default 10, the
            previous hard-coded value).
        max_k: exclusive upper bound on k for the elbow plot (default 20).

    Returns:
        (labels, centers): the per-sample cluster labels and the cluster
        centres (``KMeans.cluster_centers_``).

    Raises:
        ValueError: if ``distance_matrix`` has no rows.
    """
    print('使用肘部法则来确定最佳簇数')
    n_samples = distance_matrix.shape[0]
    if n_samples == 0:
        raise ValueError("distance_matrix is empty")
    # KMeans requires n_clusters <= n_samples.
    candidates = list(range(1, min(max_k, n_samples + 1)))

    sse = []
    for k in tqdm(candidates):
        model = KMeans(n_clusters=k, random_state=42)
        model.fit(distance_matrix)
        sse.append(model.inertia_)

    # Draw the elbow plot on its own figure so earlier plots are not overlaid.
    plt.figure()
    plt.plot(candidates, sse, marker='o')
    plt.xlabel('Number of clusters')
    plt.ylabel('SSE')
    plt.title('Elbow Method for Optimal K')
    plt.savefig('kmeans.png')
    plt.show()
    plt.close()

    model = KMeans(n_clusters=min(optimal_k, n_samples), random_state=42)
    model.fit(distance_matrix)

    labels = model.labels_
    centers = model.cluster_centers_

    print(f"Cluster labels: {labels}")
    print(f"Cluster centers (mean vectors): {centers}")

    return labels, centers

def patch(cache, tasks, batch_size=50, left=4, right=10, divide='random'):
    """
    Filter tasks by difficulty and group the survivors into batches.

    For every task and role, two values are computed:
      - score: the modal ``Reward.reward_fn`` value / 100;
      - ROI: the modal weighted ``get_ROI``.

    A task's difficulty is ``len(roles) - sum(scores)``. When each score is 0 or
    1 this is the number of roles that failed the task; presumably rewards are
    on a 0-100 scale, so partial credit gives fractional values (confirm).

    Tasks with difficulty 0 (all solved), ``len(roles)`` (none solved) or
    outside ``[left, right]`` are dropped.

    Args:
        cache: ``{task_id: {role: [record, ...]}}`` as built by ``merge_files``.
        tasks: ``{task_id: record}``, one representative record per task.
        batch_size: number of tasks per batch.
        left: inclusive lower bound on difficulty.
        right: inclusive upper bound on difficulty.
        divide: batching strategy.
            ``'random'``: shuffle and cut into full batches; a trailing partial
            batch is dropped.
            ``'rank'``: cluster tasks by their per-role ROI ranking (Spearman
            distance + K-means) and cut each cluster into batches; partial
            batches are kept.

    Returns:
        list of ``{"batch_id": int, "tasks": [record, ...]}``. The records are
        copies of ``tasks`` entries without ``role/raw/content/usage/end`` and
        with an added ``difficulty`` field.

    Raises:
        ValueError: if ``divide`` is not ``'random'`` or ``'rank'``.
    """
    # Fail fast instead of hitting an unbound ``results`` after all the work.
    if divide not in ('random', 'rank'):
        raise ValueError(f"unknown divide mode: {divide!r}")

    score = defaultdict(lambda: defaultdict(float))
    ROI = defaultdict(lambda: defaultdict(float))

    roles = set()
    for task_id, model_outputs in cache.items():
        for name, outputs in model_outputs.items():
            roles.add(name)
            rewards = [Reward.reward_fn(line['type'], line['content'], line['solution']) for line in outputs]
            score[task_id][name] = mode(rewards) / 100
            ROI[task_id][name] = mode([get_ROI(line, True) for line in outputs])
    # A sorted list gives stable column order across runs. Spearman distances
    # are unaffected, because every ranking row is permuted identically.
    roles = sorted(roles)

    print(len(roles))
    print(len(score))

    # task_id -> difficulty, looked up in O(1). The previous code scanned every
    # difficulty bucket for every task, which is quadratic.
    difficulty = {task_id: len(roles) - sum(scores.values()) for task_id, scores in score.items()}

    dropped_fields = ('role', 'raw', 'content', 'usage', 'end')
    selected = []
    for task_id, item in tasks.items():
        k = difficulty[task_id]
        if k == len(roles) or k == 0 or not (left <= k <= right):
            continue
        # Copy instead of popping in place: ``item`` is also referenced from ``cache``.
        record = {key: value for key, value in item.items() if key not in dropped_fields}
        record['difficulty'] = k
        selected.append(record)
        print(np.var(list(ROI[task_id].values())), ROI[task_id])
        print()

    def chunks(seq, size):
        """Yield consecutive slices of ``seq`` of length ``size`` (the last may be shorter)."""
        for start in range(0, len(seq), size):
            yield seq[start:start + size]

    def rank_scores(values):
        """Convert a score vector to 0-based ranks (0 = smallest)."""
        return np.argsort(np.argsort(values))

    # One shuffle is already uniformly random; repeating it adds nothing.
    random.shuffle(selected)

    results = []
    if divide == 'random':
        batch_difficulty = defaultdict(int)
        for batch in chunks(selected, batch_size):
            if len(batch) != batch_size:
                continue  # drop the trailing partial batch
            results.append({
                "batch_id": len(results),
                "tasks": batch
            })
            mean_difficulty = int(sum(task['difficulty'] for task in batch) / len(batch))
            batch_difficulty[mean_difficulty] += 1
        print(batch_difficulty)
        return results

    # divide == 'rank'
    if not selected:
        # Nothing to cluster; K-means cannot run on an empty matrix.
        return results

    rankings = np.array([rank_scores([ROI[task['task_id']][r] for r in roles]) for task in selected])
    print(rankings.shape)
    distance = compute_spearman_distance(rankings)
    labels, _ = kmeans(distance)

    print(roles)
    clusters = defaultdict(list)
    for label, task in zip(labels, selected):
        clusters[label].append(task)

    for label, members in clusters.items():
        print('=' * 20, label, '=' * 20)
        random.shuffle(members)
        for batch in chunks(members, batch_size):
            print()
            # Show only this batch's rankings, not the whole cluster again.
            for task in batch:
                print(rank_scores([ROI[task['task_id']][r] for r in roles]))
            results.append({
                "batch_id": len(results),
                "tasks": batch
            })

    return results


batch_size = 50
right = 11
left = 3
# Batching strategy handed to ``patch`` ('random' or 'rank'); also used in the output filename.
divide = 'rank'

if __name__ == '__main__':
    # Collect every non-empty result path listed in OUTPUT_CONFIG, and report
    # files whose records contain duplicate task_ids.
    file_list = []
    for model_files in OUTPUT_CONFIG.values():
        for files in model_files.values():
            for file in files:
                if file == '':
                    # Empty entries are placeholders; open('') would raise.
                    continue
                file_list.append(file)
                with open(file) as f:
                    tmp = json.load(f)
                task_ids = {line['task_id'] for line in tmp}
                if len(tmp) != len(task_ids):
                    print(file, len(tmp), len(task_ids))

    cache, tasks = merge_files(file_list)
    # The cache is not written to disk; its would-be filename is only reported.
    file = f'./cache_tasks.{len(cache)}.json'
    print(file)

    tasks = patch(cache, tasks, batch_size=batch_size, left=left, right=right, divide=divide)
    file = f'./batch_tasks.{len(tasks)}.batch={batch_size}.left={left}.right={right}.newcost.{divide}.json'
    with open(file, 'w') as f:
        json.dump(tasks, f, indent=4)
    print(file)
