import torch
import torch.nn as nn
import os

from tqdm import tqdm
import numpy as np

from scipy.stats import spearmanr

from torchvision import datasets, transforms
from model_train import get_cifar2_indices_and_adjust_labels
import argparse


def read_nodes(path):
    """Read a list of integers from a text file, one integer per line.

    Blank or whitespace-only lines (for example an empty line at the end of
    the file) are skipped instead of raising ``ValueError``.

    Args:
        path: Path of the text file to read.

    Returns:
        A list of ints, in the order they appear in the file.

    Raises:
        ValueError: If a non-blank line cannot be parsed as an integer.
    """
    with open(path, 'r') as file:
        # int() tolerates surrounding whitespace, so only blank lines need
        # to be filtered out explicitly.
        return [int(line) for line in file if line.strip()]


def calculate_one():
    """Average the Spearman correlation between summed scores and saved model outputs.

    Steps performed:
      1. Parse ``--model`` from the command line (used in checkpoint file names).
      2. Load the score tensor from ``./score/`` and transpose it.
      3. For each of 50 seeds, read the selected indices from
         ``./checkpoint/selected_indices_seed_{i}.txt`` and the saved output
         tensor from ``./checkpoint/model_{model}_output_checkpoint_{i}.pt``.
      4. Translate every selected index into its position within the first
         5000 entries returned by ``get_cifar2_indices_and_adjust_labels``,
         then sum the corresponding rows of the score tensor.
      5. For each of the first 500 entries along dim 0, compute the Spearman
         correlation (across the 50 seeds) between the summed scores and the
         saved outputs, skipping NaN results.

    Returns:
        A tuple ``(mean_correlation, loss_list, approx_output)`` where
        ``mean_correlation`` is the mean of the non-NaN correlations (NaN if
        every correlation was NaN), ``loss_list`` is the list of loaded output
        tensors and ``approx_output`` is the list of summed score tensors.

    Raises:
        ValueError: If a selected index is not among the first 5000 entries
            returned by ``get_cifar2_indices_and_adjust_labels``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="resnet9", help="model to train")
    args = parser.parse_args()

    num_subsets = 50        # number of seeds / checkpoints on disk
    num_test_samples = 500  # number of entries along dim 0 that are evaluated

    # NOTE(review): torch.load unpickles its input; only load files from a
    # trusted source.
    score = torch.load("./score/score_cifar_dropout_ensemble_25_independent_25.pt").T
    print("score shape:", score.shape)

    nodes_str = [f"./checkpoint/selected_indices_seed_{i}.txt" for i in range(num_subsets)]
    loss_str = [f"./checkpoint/model_{args.model}_output_checkpoint_{i}.pt"
                for i in range(num_subsets)]

    # Load the CIFAR-10 training set; the helper's result is truncated to its
    # first 5000 entries, which define the row order used to index `score`.
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    cifar2_indices_train = get_cifar2_indices_and_adjust_labels(train_dataset)
    full_nodes = cifar2_indices_train[0:5000]

    # Map each dataset index to its first position in `full_nodes` once, so
    # the lookups below are O(1) instead of a linear `list.index` scan.
    # NOTE(review): assumes the entries are integer-like — confirm.
    position_of = {}
    for position, node in enumerate(full_nodes):
        position_of.setdefault(int(node), position)

    node_list = []
    for node_str in nodes_str:
        numbers = read_nodes(node_str)
        try:
            node_list.append([position_of[number] for number in numbers])
        except KeyError as err:
            # Keep the exception type that `list.index` would have raised.
            raise ValueError(f"{err.args[0]} is not in list") from None

    loss_list = [torch.load(loss_path, map_location=torch.device('cpu')).detach()
                 for loss_path in loss_str]

    # Summed score rows of each selected subset.
    approx_output = [torch.sum(score[index, :], dim=0) for index in node_list]

    print(len(loss_list), loss_list[0].shape)
    print(len(approx_output), approx_output[0].shape)

    res = 0
    counter = 0
    for i in range(num_test_samples):
        tmp = spearmanr(np.array([approx[i] for approx in approx_output]),
                        np.array([loss[i].numpy() for loss in loss_list])).statistic

        # spearmanr yields NaN when it cannot compute a correlation (for
        # example for constant input); such entries are left out of the mean.
        if np.isnan(tmp):
            print("Numerical issue")
            continue
        res += tmp
        counter += 1

    # Avoid ZeroDivisionError when every correlation was NaN.
    mean_correlation = res / counter if counter else float("nan")

    return mean_correlation, loss_list, approx_output


if __name__ == "__main__":
    # Script entry point: run the evaluation and report only the mean
    # correlation; the two returned tensor lists are not used here.
    mean_correlation, _outputs, _estimates = calculate_one()
    print(mean_correlation)