import torch
import torch.nn as nn
import os

from tqdm import tqdm
import numpy as np

from scipy.stats import spearmanr

def read_nodes(path):
    """Read a text file holding one integer per line.

    Blank or whitespace-only lines (e.g. an extra trailing newline) are
    skipped instead of aborting the whole read.

    Args:
        path: Path of the text file to read.

    Returns:
        A list with the integers in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-empty line is not a valid integer literal.
    """
    with open(path, 'r', encoding='utf-8') as file:
        # Strip first so surrounding whitespace never reaches int(), then
        # drop the lines that turn out to be empty.
        stripped_lines = (line.strip() for line in file)
        return [int(text) for text in stripped_lines if text]


def calculate_one(num_subsets=50, num_train=5000, num_test=500):
    """Average the per-column Spearman correlation between summed scores and stored outputs.

    For every subset ``k`` the rows of the score matrix selected by that
    subset's index file are summed, giving one predicted value per output
    column. For every output column the rank correlation across subsets
    between those sums and the values loaded from the matching checkpoint
    file is computed, and the correlations that are not NaN are averaged.

    Args:
        num_subsets: Number of index/checkpoint file pairs to load (files are
            numbered ``0 .. num_subsets - 1``).
        num_train: Size of the node universe; valid node ids are
            ``0 .. num_train - 1``.
        num_test: Number of leading output columns to correlate.

    Returns:
        A tuple ``(mean_correlation, loss_list, approx_output)`` where
        ``loss_list`` holds the loaded checkpoint tensors and
        ``approx_output`` the per-subset summed score tensors.
        ``mean_correlation`` is NaN when no column produced a finite
        correlation.

    Raises:
        ValueError: If ``num_subsets`` is smaller than 1, or an index file
            contains a node id outside the node universe.
    """
    if num_subsets < 1:
        raise ValueError("num_subsets must be at least 1")

    # Load onto the CPU, like the checkpoints below: the values are handed to
    # NumPy/SciPy later, which cannot read GPU tensors.
    score = torch.load(
        "./score/score_mnist_ensemble_10_independent_3_Q_True.pt",
        map_location=torch.device('cpu'),
    ).T  # _test_0225_regroup
    print("score shape:", score.shape)

    nodes_str = [
        f"./checkpoint/selected_indices_seed_{i}_sample_2500.txt"
        for i in range(num_subsets)
    ]
    loss_str = [
        f"./checkpoint/model_output_checkpoint_{i}_sample_2500_M.pt"
        for i in range(num_subsets)
    ]

    # Map node id -> row position once, so each lookup is O(1) instead of a
    # linear list.index() scan per node.
    full_nodes = range(num_train)
    position_of = {node: position for position, node in enumerate(full_nodes)}

    node_list = []
    for node_path in nodes_str:
        positions = []
        for number in read_nodes(node_path):
            if number not in position_of:
                raise ValueError(
                    f"node {number} from {node_path} is not in the node universe "
                    f"0..{num_train - 1}"
                )
            positions.append(position_of[number])
        node_list.append(positions)

    loss_list = [
        torch.load(loss_path, map_location=torch.device('cpu')).detach()
        for loss_path in loss_str
    ]

    # One summed score vector per subset: rows are selected by node position,
    # the sum runs over the selected rows.
    approx_output = [
        torch.sum(score[positions, :], dim=0) for positions in node_list
    ]

    print(len(loss_list), loss_list[0].shape)
    print(len(approx_output), approx_output[0].shape)

    # Stack once (subsets along axis 0) so every column below is a cheap
    # slice rather than a freshly built array.
    approx_matrix = torch.stack(approx_output).detach().numpy()
    loss_matrix = torch.stack(loss_list).numpy()

    res = 0
    counter = 0
    for i in range(num_test):
        tmp = spearmanr(approx_matrix[:, i], loss_matrix[:, i]).statistic
        if np.isnan(tmp):
            # spearmanr yields NaN e.g. when one of the inputs is constant.
            print("Numerical issue")
            continue
        res += tmp
        counter += 1

    print(counter)

    if counter == 0:
        # No usable correlation at all: report NaN instead of dividing by zero.
        return float("nan"), loss_list, approx_output

    return res / counter, loss_list, approx_output


if __name__ == "__main__":
    # Script entry point: only the averaged correlation is reported, the
    # tensors returned alongside it are not needed here.
    outcome = calculate_one()
    print(outcome[0])