# Copyright (Modifications) 2024 NEAR paper authors
# Adapted from https://github.com/HamsterMimi/MeCo/blob/main/correlation/NAS_Bench_201.py
# Licensed under the MIT license

# Copyright 2023, Authors of "MeCo: Zero-Cost Proxy for NAS Via Minimum Eigenvalue of Correlation on Feature Maps"
# Use of this source code is governed by the MIT license
# that can be found in the LICENSE-MIT.txt file or at https://opensource.org/licenses/MIT.

import torch
import random

def get_score(net, x, device, measure='meco'):
    """Compute the MeCo zero-cost proxy score of ``net`` on input ``x``.

    A forward hook is attached to every module yielded by
    ``net.named_modules()`` (including ``net`` itself). Each hook does the
    following with ``data_output[0]``:

    * reshape it to an ``(n, -1)`` matrix;
    * build the correlation matrix between its ``n`` rows with
      ``torch.corrcoef``;
    * record the minimum real eigenvalue of that matrix.

    The returned score is the sum of all finite per-module values. Hooks are
    always removed before returning, so the function can safely be called
    repeatedly on the same network.

    Args:
        net: ``torch.nn.Module`` to score. Only ``x`` is moved to ``device``
            here; ``net`` is presumably already on ``device`` (caller's job).
        x: Input batch fed to ``net``.
        device: Device ``x`` is moved to before the forward pass.
        measure: ``'meco'`` uses all ``n`` rows. ``'meco_opt'`` uses a random
            subset of ``k = min(n, 8)`` rows and rescales the minimum
            eigenvalue by ``n / k``.

    Returns:
        The score as a Python float (0.0 if no module produced a finite value).

    Raises:
        ValueError: If ``measure`` is not ``'meco'`` or ``'meco_opt'``.
    """
    if measure not in ('meco', 'meco_opt'):
        raise ValueError(f"unknown measure: {measure!r}")

    result_list = []

    def forward_hook(module, data_input, data_output):
        # Take the first element of the output: the first sample for a batched
        # tensor output, or the first member for a tuple output.
        fea = data_output[0].detach()
        n = fea.shape[0]
        if n == 1:
            # Original note: global pooling needs to be skipped. A single row
            # has no meaningful correlation matrix.
            return
        fea = fea.reshape(n, -1)
        if measure == 'meco_opt':
            # Subsample rows to keep the eigendecomposition cheap. Never ask for
            # more rows than exist, because random.sample would raise.
            k = min(n, 8)
            fea = fea[random.sample(range(n), k), :]
        corr = torch.corrcoef(fea)
        # Constant (zero-variance) rows produce NaN/inf correlations; neutralize them.
        corr[torch.isnan(corr)] = 0
        corr[torch.isinf(corr)] = 0
        values = torch.linalg.eig(corr)[0]
        result = torch.min(torch.real(values))
        if measure == 'meco_opt':
            # Scale the subsampled estimate back up to the full row count.
            result = result * n / k
        result_list.append(result.item())

    handles = [module.register_forward_hook(forward_hook)
               for _, module in net.named_modules()]
    try:
        # Hooks only read detached features, so no autograd graph is needed.
        with torch.no_grad():
            net(x.to(device))
    finally:
        # Remove the hooks so repeated calls on the same net don't stack them.
        for handle in handles:
            handle.remove()

    results = torch.tensor(result_list)
    results = results[torch.isfinite(results)]
    return torch.sum(results).item()
