import torch
import torch.nn as nn
import os

from tqdm import tqdm
import numpy as np

from scipy.stats import spearmanr
import argparse

import sys
sys.path.append("./MusicTransformer_Pytorch")

from utilities.constants import SANITY_CHECK_LENGTH_256_TOTAL_5000


def read_nodes(path):
    """Read one integer per line from a text file.

    Whitespace-only lines (e.g. a trailing empty line) are skipped instead of
    raising on ``int('')``.

    Args:
        path: Path of the text file to read.

    Returns:
        list[int]: The integers in file order.

    Raises:
        ValueError: If a non-blank line is not a valid integer literal.
    """
    with open(path, "r", encoding="utf-8") as file:
        stripped_lines = (line.strip() for line in file)
        return [int(text) for text in stripped_lines if text]


def _score_path(score_str, drop_num, ind_num):
    """Return the saved score-tensor path for one (dropout, independent-model) pair.

    A dropout count of 0 maps to the ``ensemble_1`` file name, exactly as the
    original lookup did (presumably the single-model ensemble -- confirm).
    """
    ensemble = 1 if drop_num == 0 else drop_num
    return f"{score_str}/score_music_ensemble_{ensemble}_independent_{ind_num}_retry.pt"


def _load_node_indices(nodes_paths, full_nodes):
    """Translate each subset file's node ids into row positions within ``full_nodes``.

    Raises:
        ValueError: If a subset file lists an id that is absent from ``full_nodes``.
    """
    # One dict built up front replaces a linear ``list.index`` scan per id;
    # ``setdefault`` keeps the first occurrence, matching ``list.index``.
    position = {}
    for pos, node in enumerate(full_nodes):
        position.setdefault(node, pos)

    node_list = []
    for path in nodes_paths:
        numbers = read_nodes(path)
        missing = [number for number in numbers if number not in position]
        if missing:
            raise ValueError(
                f"{path}: node ids {missing[:5]} are not in the full training set"
            )
        node_list.append([position[number] for number in numbers])
    return node_list


def _mean_spearman(approx, losses):
    """Average the per-column Spearman correlation of two equally shaped 2-D arrays.

    Column ``i`` of ``approx`` is correlated with column ``i`` of ``losses``
    across rows. Columns whose correlation is NaN (e.g. constant input) are
    reported and skipped. Returns NaN when every column was skipped.
    """
    total = 0
    counter = 0
    for i in range(approx.shape[1]):
        tmp = spearmanr(approx[:, i], losses[:, i]).statistic
        if np.isnan(tmp):
            print("Numerical issue")
            continue
        total += tmp
        counter += 1
    if counter == 0:
        # The original divided by zero here; report an undefined score instead.
        print("No valid Spearman correlations; LDS is undefined")
        return float("nan")
    return total / counter


def calculate_one(score_str, gt_str, pair_list=((3, 3),), num_runs=50, num_test=178):
    """Print (and return) the LDS of ensemble scores against retraining losses.

    For every retrained run ``i`` in ``range(num_runs)`` this reads the training
    subset ``../{gt_str}/checkpoint/{i}/selected_indices_seed_{i}.txt`` and the
    test losses ``../{gt_str}/checkpoint/{i}/loss_test_generate_seed_0_length_1.pt``.
    For each ``(drop_num, ind_num)`` pair it loads the matching score tensor and
    sums the score rows of each run's subset into one predicted value per test
    sample. It then prints the mean, over test samples, of the Spearman
    correlation (across runs) between predicted values and actual losses.

    Args:
        score_str: Directory holding ``score_music_ensemble_*_independent_*_retry.pt``.
        gt_str: Ground-truth directory, resolved relative to ``../``.
        pair_list: ``(drop_num, ind_num)`` pairs to evaluate; defaults to ``((3, 3),)``.
        num_runs: Number of retrained checkpoints (subdirectories ``0..num_runs-1``).
        num_test: Number of leading test samples to evaluate; ``None`` uses every
            sample present in the loss tensors.

    Returns:
        dict: ``{(drop_num, ind_num): lds}``; ``lds`` is NaN when no test sample
        produced a finite correlation.

    Raises:
        ValueError: If ``num_runs < 2``, a subset id is not in
            ``SANITY_CHECK_LENGTH_256_TOTAL_5000``, or ``num_test`` exceeds the
            available samples.
    """
    print(score_str)
    print(gt_str)

    if num_runs < 2:
        raise ValueError("num_runs must be at least 2 to compute a rank correlation")

    checkpoint_dirs = [f"../{gt_str}/checkpoint/{i}" for i in range(num_runs)]
    nodes_paths = [
        f"{ckpt}/selected_indices_seed_{i}.txt" for i, ckpt in enumerate(checkpoint_dirs)
    ]
    loss_paths = [f"{ckpt}/loss_test_generate_seed_0_length_1.pt" for ckpt in checkpoint_dirs]

    # Subsets and losses do not depend on the score pair, so load them once.
    # Rows of the transposed score tensor are indexed by position in this list.
    full_nodes = list(SANITY_CHECK_LENGTH_256_TOTAL_5000)
    node_list = _load_node_indices(nodes_paths, full_nodes)
    loss_list = [
        torch.load(path, map_location=torch.device("cpu")).detach() for path in loss_paths
    ]

    available = min(len(loss) for loss in loss_list)
    if num_test is None:
        num_test = available
    elif num_test > available:
        raise ValueError(f"num_test={num_test} exceeds the {available} test losses available")

    # Shape (num_runs, num_test). double() is exact for float32/16 inputs, so
    # the ranks fed to spearmanr are unchanged.
    losses = torch.stack([loss[:num_test] for loss in loss_list]).cpu().double().numpy()

    results = {}
    for drop_num, ind_num in pair_list:
        print("number of dropouts: ", drop_num)
        print("number of independent model: ", ind_num)

        # map_location keeps GPU-saved scores usable alongside the CPU losses.
        score = torch.load(
            _score_path(score_str, drop_num, ind_num), map_location=torch.device("cpu")
        ).T

        # Predicted value for run k = sum of the score rows of its training subset.
        approx = torch.stack([score[index, :].sum(dim=0) for index in node_list])
        if approx.shape[1] < num_test:
            raise ValueError(
                f"score tensor has {approx.shape[1]} test columns, fewer than num_test={num_test}"
            )
        approx = approx[:, :num_test].detach().double().numpy()

        lds = _mean_spearman(approx, losses)
        print("LDS: ", lds)
        results[(drop_num, ind_num)] = lds
    return results

if __name__ == "__main__":
    # Command-line entry point: both directories are mandatory, because
    # calculate_one builds file paths from them and cannot work with None.
    parser = argparse.ArgumentParser(
        description="Compute the LDS between ensemble scores and retrained-model losses."
    )
    parser.add_argument(
        "--score_dir", type=str, required=True, help="the directory storing scores"
    )
    parser.add_argument(
        "--gt_dir",
        type=str,
        required=True,
        help="the directory storing ground truths (resolved relative to '../')",
    )
    args = parser.parse_args()

    calculate_one(args.score_dir, args.gt_dir)