# %%
import torch 
import os 
import json 
from ming.utils import client

# %%
# NOTE(review): presumably a smoke test that the kernel executes cells — confirm or remove.
print(1)

# %%
# Baseline run (ming1.8b-chat-1024) on math_500.
# Both paths are relative to the current working directory.

# Model outputs: logs/diverse/math_500/ming1.8b-chat-1024.jsonl
baseline_output_path = '../../logs/diverse/math_500/ming1.8b-chat-1024.jsonl'
baseline_output = client.read_jsonl(baseline_output_path)

# Entropy table: logs/attentions/math_500/ming1.8b-chat-1024_entropy.csv
baseline_entropy_path = '../../logs/attentions/math_500/ming1.8b-chat-1024_entropy.csv'
baseline_entropy = client.read_csv(baseline_entropy_path)

# %%
# "womolora" run (ming1.8b-molora-4x1-topk-openmath01-womolora) on math_500.

# Model outputs:
#   logs/diverse/math_500/ming1.8b-molora-4x1-topk-openmath01-womolora.jsonl
womolora_output_path = '../../logs/diverse/math_500/ming1.8b-molora-4x1-topk-openmath01-womolora.jsonl'
womolora_output = client.read_jsonl(womolora_output_path)

# Entropy table:
#   logs/attentions/math_500/ming1.8b-molora-4x1-topk-openmath01-womolora_entropy.csv
womolora_entropy_path = '../../logs/attentions/math_500/ming1.8b-molora-4x1-topk-openmath01-womolora_entropy.csv'
womolora_entropy = client.read_csv(womolora_entropy_path)

# %%
# "molora" run (ming1.8b-molora-4x1-topk-openmath01) on math_500.

# Model outputs (JSONL).
molora_output_path = '../../logs/diverse/math_500/ming1.8b-molora-4x1-topk-openmath01.jsonl'
molora_output = client.read_jsonl(molora_output_path)

# Entropy table (CSV).
molora_entropy_path = '../../logs/attentions/math_500/ming1.8b-molora-4x1-topk-openmath01_entropy.csv'
molora_entropy = client.read_csv(molora_entropy_path)

# %%
from ming.eval.eval_em import math_acc 
from ming.eval.eval_em import bbh_acc
# %%
from typing import List, Dict, Optional, Tuple
import numpy as np
def obtain_pair(output: List[str], entropy: List[str], layer_index: int = 23):
    """Pair every sample's accuracy with one entropy value from its row.

    Args:
        output: Per-sample records; each one is passed unchanged to
            ``math_acc``.
        entropy: Rows of tab-separated numeric fields. Row 0 and the last
            two rows are never read (presumably a header and trailing
            summary rows -- confirm against the file format), so it must
            contain exactly ``len(output) + 3`` rows.
        layer_index: Zero-based column to read from each entropy row.
            Defaults to 23, the column previously hard-coded here and
            labelled "last layer" by the original comment.

    Returns:
        A list of ``(acc, etp)`` tuples, one per sample and in input order.
        ``acc`` is whatever ``math_acc`` returned (possibly ``None``) and
        ``etp`` is the selected entropy field converted to ``float``.

    Raises:
        AssertionError: If the row counts do not match as described above.
        IndexError: If an entropy row has no column ``layer_index``.
        ValueError: If the selected field cannot be parsed as a float.
    """
    assert len(output) == len(entropy) - 3, (
        f"expected {len(output) + 3} entropy rows, got {len(entropy)}"
    )
    res = []
    for i, sample in enumerate(output):
        acc = math_acc(sample)
        if acc is None:
            # Surface samples that could not be scored; the None is still
            # stored so the result stays aligned with ``output``.
            print(sample)
        # Row i + 1 belongs to sample i because row 0 is skipped.
        etp = float(entropy[i + 1].split("\t")[layer_index])
        res.append((acc, etp))
    return res

# Build the (accuracy, entropy) pairs for the three runs, in the order
# baseline, womolora, molora.
baseline_pair, womolora_pair, molora_pair = [
    obtain_pair(run_output, run_entropy)
    for run_output, run_entropy in (
        (baseline_output, baseline_entropy),
        (womolora_output, womolora_entropy),
        (molora_output, molora_entropy),
    )
]


# %%

# Correlate the change in accuracy with the change in entropy, relative to
# the baseline run, for the womolora and molora runs.
import scipy
from scipy.stats import pearsonr, spearmanr


def _flipped_sample_deltas(model_pair, reference_pair):
    """Return (accuracy delta, entropy delta) for samples on which exactly
    one of the two runs has an accuracy equal to 0."""
    deltas = []
    for idx, (model_acc, model_etp) in enumerate(model_pair):
        reference_acc, reference_etp = reference_pair[idx]
        # Keep the sample only when one accuracy is 0 and the other is not.
        if (model_acc == 0) != (reference_acc == 0):
            deltas.append((model_acc - reference_acc, model_etp - reference_etp))
    return deltas


# womolora minus baseline
womolora_baseline_subtract_pair = _flipped_sample_deltas(womolora_pair, baseline_pair)
# molora minus baseline
molora_baseline_subtract_pair = _flipped_sample_deltas(molora_pair, baseline_pair)

# Split the delta tuples into parallel accuracy / entropy sequences.
womolora_baseline_acc = [acc_delta for acc_delta, _ in womolora_baseline_subtract_pair]
womolora_baseline_etp = [etp_delta for _, etp_delta in womolora_baseline_subtract_pair]
molora_baseline_acc = [acc_delta for acc_delta, _ in molora_baseline_subtract_pair]
molora_baseline_etp = [etp_delta for _, etp_delta in molora_baseline_subtract_pair]

# Pearson correlation between accuracy delta and entropy delta.
womolora_baseline_acc_etp = pearsonr(womolora_baseline_acc, womolora_baseline_etp)
molora_baseline_acc_etp = pearsonr(molora_baseline_acc, molora_baseline_etp)
print(womolora_baseline_acc_etp)
print(molora_baseline_acc_etp)

# Spearman correlation; the result names above are rebound to these values.
womolora_baseline_acc_etp = spearmanr(womolora_baseline_acc, womolora_baseline_etp)
print(womolora_baseline_acc_etp)
molora_baseline_acc_etp = spearmanr(molora_baseline_acc, molora_baseline_etp)
print(molora_baseline_acc_etp)


# %%
# Compare womolora against molora directly: keep the samples on which exactly
# one of the two runs has an accuracy equal to 0, and store
# (accuracy delta, entropy delta) as womolora minus molora.
womolora_molora_subtract_pair = [
    (wo_acc - molora_pair[idx][0], wo_etp - molora_pair[idx][1])
    for idx, (wo_acc, wo_etp) in enumerate(womolora_pair)
    if (wo_acc == 0) != (molora_pair[idx][0] == 0)
]
womolora_molora_acc = [acc_delta for acc_delta, _ in womolora_molora_subtract_pair]
# Only the sign of the entropy delta is kept for this comparison.
womolora_molora_etp = [np.sign(etp_delta) for _, etp_delta in womolora_molora_subtract_pair]

# Show the sizes first, then the raw sequences.
for series in (womolora_molora_acc, womolora_molora_etp):
    print(len(series))
for series in (womolora_molora_acc, womolora_molora_etp):
    print(series)

# Pearson first, then Spearman; the result name is rebound to the latter.
womolora_molora_acc_etp = pearsonr(womolora_molora_acc, womolora_molora_etp)
print(womolora_molora_acc_etp)
womolora_molora_acc_etp = spearmanr(womolora_molora_acc, womolora_molora_etp)
print(womolora_molora_acc_etp)
# %%
# Scratch check: Spearman correlation of two 80-element series that sit at
# +1 and -1 respectively, each perturbed by independent Gaussian noise
# scaled by 0.0001.
spearmanr(
    1 + 0.0001 * np.random.randn(80),
    -1 + 0.0001 * np.random.randn(80),
)
# %%

# For every sample, keep whichever of the womolora / molora pairs has the
# strictly higher entropy (second item); ties fall back to the molora pair.
best_among_womolora_molora = [
    wo_item if wo_item[1] > molora_pair[idx][1] else molora_pair[idx]
    for idx, wo_item in enumerate(womolora_pair)
]
# Mean accuracy of the selected pairs, followed by the mean accuracy of
# womolora alone for reference.
best_acc = [acc for acc, _ in best_among_womolora_molora]
print(np.mean(best_acc))
print(np.mean([acc for acc, _ in womolora_pair]))
# %%
