# Copyright (Modifications) 2024 NEAR paper authors
# Adapted from https://github.com/idstcv/ZenNAS/blob/main/ZeroShotProxy/compute_zen_score.py
# Licensed under the Apache License, Version 2.0 (the "License");

# Copyright (C) 2010-2021 Alibaba Group Holding Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import torch
import torch.nn as nn
import numpy as np

def weights_init_xavier(m):
    """Xavier-uniform initialise an ``nn.Linear`` and zero its bias (if any).

    Designed to be passed to ``nn.Module.apply``; modules that are not
    ``nn.Linear`` are returned from immediately and left unchanged.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)

def weights_init_uniform(m):
    """Initialise an ``nn.Linear`` with U(0, 1) weights and a zero bias (if any).

    Designed to be passed to ``nn.Module.apply``; any module that is not an
    ``nn.Linear`` is skipped.
    """
    if not isinstance(m, nn.Linear):
        return
    weight, bias = m.weight, m.bias
    # nn.init.uniform_ defaults to the interval [0, 1).
    nn.init.uniform_(weight)
    if bias is not None:
        nn.init.zeros_(bias)

def init_model(model):
    """Re-initialise ``model`` according to ``model.weight_initialization``.

    Supported schemes:
      * ``"xavier"``  -- Xavier-uniform weights and zero biases for every ``nn.Linear``.
      * ``"uniform"`` -- U(0, 1) weights and zero biases for every ``nn.Linear``.
      * ``"default"`` -- call ``reset_parameters()`` on every submodule that
        defines it, i.e. PyTorch's own per-layer default initialisation.
    Any other value leaves the model untouched.
    """
    scheme = model.weight_initialization
    if scheme == "xavier":
        model.apply(weights_init_xavier)
    elif scheme == "uniform":
        model.apply(weights_init_uniform)
    elif scheme == "default":
        # Walk the whole module tree. ``model.children()`` yields only direct
        # children, so layers nested in containers (nn.Sequential,
        # nn.ModuleList of cells, ...) were never reset between repeats.
        for layer in model.modules():
            if hasattr(layer, "reset_parameters"):
                layer.reset_parameters()

def forward_pre_GAP(model, inputs):
    """Run ``model`` on ``inputs`` through its feature extractor only.

    Per the function name, this is meant to stop before the global-average-pooling
    head; it applies exactly the modules listed below and nothing else.

    Two model layouts are recognised:
      * NATSBench-style (has ``stem`` and ``cells``): ``stem`` -> each cell in
        ``cells`` -> ``lastact``.
      * otherwise NASBench101-style: each module in ``model.layers`` in order.
    """
    is_natsbench = hasattr(model, "stem") and hasattr(model, "cells")
    if not is_natsbench:
        # NASBench101: a plain sequence of layers.
        feature = inputs
        for layer in model.layers:
            feature = layer(feature)
        return feature

    # NATSBench: stem, then the cell stack, then the final activation.
    feature = model.stem(inputs)
    for cell in model.cells:
        feature = cell(feature)
    return model.lastact(feature)

def _mean_per_sample_l1(output, mixup_output):
    """Return the batch mean of the per-sample L1 distance between two outputs."""
    # Accumulate in float32: summing many fp16 values can overflow to inf.
    diff = torch.abs(output.float() - mixup_output.float())
    # Sum every non-batch dimension, then average over the batch, as the
    # upstream Zen-NAS score does (sum over dims [1, 2, 3], then mean).
    # Summing over the batch axis as well made the trailing mean a no-op and
    # added log(batch_size) to every score.
    return diff.reshape(diff.shape[0], -1).sum(dim=1).mean()


def _log_bn_scaling_factor(model):
    """Return the sum of log(sqrt(mean(running_var))) over all BatchNorm2d layers."""
    total = 0.0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            total += torch.log(torch.sqrt(torch.mean(m.running_var)))
    return total


def _summarize_scores(nas_score_list):
    """Return mean, std and 95% normal-CI half-width of the per-trial scores."""
    std_nas_score = np.std(nas_score_list)
    avg_precision = 1.96 * std_nas_score / np.sqrt(len(nas_score_list))
    avg_nas_score = np.mean(nas_score_list)
    return {
        'avg_nas_score': float(avg_nas_score),
        'std_nas_score': float(std_nas_score),
        'avg_precision': float(avg_precision),
    }


def compute_nas_score(gpu, model, mixup_gamma, resolution, batch_size, repeat, fp16=False):
    """Estimate the Zen-NAS score of ``model`` from random Gaussian inputs.

    In each of ``repeat`` trials the model is re-initialised with ``init_model``.
    It is then fed a random input ``x`` and a perturbed copy
    ``x + mixup_gamma * eps``. With ``f = forward_pre_GAP``, the trial score is::

        log(mean_over_batch(sum_over_features |f(x) - f(x + mixup_gamma * eps)|))
        + sum over BatchNorm2d layers of log(sqrt(mean(running_var)))

    Args:
        gpu: CUDA device index, or ``None`` to run on CPU.
        model: network exposing ``weight_initialization`` and a layout understood
            by ``forward_pre_GAP``. It is not moved to the device/dtype here,
            so it must already match them.
        mixup_gamma: scale of the input perturbation.
        resolution: each input is a flat vector of ``resolution * resolution`` features.
        batch_size: number of random samples per trial.
        repeat: number of independent trials; must be >= 1.
        fp16: build the random inputs as ``torch.half`` instead of ``torch.float32``.

    Returns:
        dict with ``avg_nas_score``, ``std_nas_score`` and ``avg_precision``
        (1.96 * std / sqrt(repeat), the half-width of a 95% normal CI on the mean).

    Raises:
        ValueError: if ``repeat`` < 1 (mean/std over zero trials are undefined).
    """
    if repeat < 1:
        raise ValueError("repeat must be >= 1, got {}".format(repeat))

    device = torch.device('cuda:{}'.format(gpu)) if gpu is not None else torch.device('cpu')
    dtype = torch.half if fp16 else torch.float32
    shape = [batch_size, resolution * resolution]

    nas_score_list = []
    with torch.no_grad():
        for _ in range(repeat):
            init_model(model)
            x = torch.randn(size=shape, device=device, dtype=dtype)
            noise = torch.randn(size=shape, device=device, dtype=dtype)
            mixup_x = x + mixup_gamma * noise

            output = forward_pre_GAP(model, x)
            mixup_output = forward_pre_GAP(model, mixup_x)

            nas_score = torch.log(_mean_per_sample_l1(output, mixup_output))
            nas_score = nas_score + _log_bn_scaling_factor(model)
            nas_score_list.append(float(nas_score))

    return _summarize_scores(nas_score_list)
