# Copyright (Modifications) 2024 NEAR paper authors
# Adapted from https://github.com/pym1024/SWAP/blob/main/correlation.py
# Licensed under the Academic Free License version 3.0

# Copyright 2024, Authors of "Sample-Wise Activation Patterns for Ultra-Fast NAS"
# Licensed under the Academic Free License version 3.0

import numpy as np
import torch
import torch.nn as nn

def count_parameters(model):
    """Return the number of parameters of ``model``, in millions.

    Parameters whose qualified name (as yielded by ``named_parameters``)
    contains the substring ``"auxiliary"`` are left out of the count.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs.

    Returns:
        ``numpy.float64``: parameter count divided by ``1e6``.
    """
    # Builtin ``sum`` is used on purpose: ``np.sum`` over a generator is
    # deprecated in NumPy and merely falls back to the builtin anyway.
    total = sum(
        param.numel()
        for name, param in model.named_parameters()
        if "auxiliary" not in name
    )
    # Keep a NumPy float64 result so tensors built from it stay float64.
    return np.float64(total) / 1e6

def cal_regular_factor(model, mu, sigma):
    """Compute the size-based regularisation factor for ``model``.

    The factor is ``exp(-(p - mu)^2 / sigma)`` where ``p`` is the value
    returned by ``count_parameters(model)`` (parameter count in millions).

    Args:
        model: Network whose parameters are counted.
        mu: Value subtracted from the parameter count before squaring.
        sigma: Divisor applied to the squared difference.

    Returns:
        A tensor holding the factor.
    """
    n_params = torch.as_tensor(count_parameters(model))
    squared_gap = torch.pow(n_params - mu, 2)
    return torch.exp(-(squared_gap / sigma))

def network_weight_gaussian_init(net: nn.Module):
    """Re-initialise the weights of ``net`` in place and return it.

    For every module reachable through ``net.modules()``:

    * ``nn.Conv2d`` / ``nn.Linear``: weights are drawn with
      ``nn.init.normal_`` (default arguments); the bias, if present, is
      zeroed.
    * ``nn.BatchNorm2d`` / ``nn.GroupNorm``: weight is set to ones and bias
      to zeros. Layers created with ``affine=False`` have no such
      parameters and are skipped.

    All other module types are left untouched.

    Args:
        net: Module whose parameters are overwritten.

    Returns:
        The same ``net`` object.
    """
    with torch.no_grad():
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # With affine=False these attributes are None; calling an
                # initialiser on None would raise.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    return net

class SampleWiseActivationPatterns(object):
    """Holds a sign-binarised activation matrix and counts its unique rows.

    Attributes:
        swap: Last unique-pattern count computed by ``calSWAP``; ``-1``
            until the first call.
        activations: Sign of the most recently collected activations, or
            ``None`` when nothing is pending.
        device: Value passed to the constructor. It is only stored; none of
            the methods of this class read it.
    """

    def __init__(self, device):
        self.swap = -1
        self.activations = None
        self.device = device

    @torch.no_grad()
    def collect_activations(self, activations):
        """Store the element-wise sign of ``activations``.

        Any previously stored matrix is replaced, not accumulated.

        Args:
            activations: Tensor of activations; ``calSWAP`` transposes it
                with ``.T``, so it is expected to be 2-D
                ``(samples, neurons)``.
        """
        # No placeholder buffer is allocated: the stored value is always
        # replaced wholesale by the sign tensor.
        self.activations = torch.sign(activations)

    @torch.no_grad()
    def calSWAP(self, regular_factor):
        """Count unique activation patterns and scale by ``regular_factor``.

        The stored matrix is transposed from (samples, neurons) to
        (neurons, samples) and the number of distinct rows is recorded in
        ``self.swap``. The stored activations are then released.

        Args:
            regular_factor: Multiplier applied to the pattern count.

        Returns:
            ``self.swap * regular_factor``.
        """
        # transpose the activation matrix: (samples, neurons) to (neurons, samples)
        patterns = self.activations.T
        self.swap = torch.unique(patterns, dim=0).size(0)

        # Drop the references before asking the CUDA allocator to release
        # its cached blocks.
        del patterns
        self.activations = None
        torch.cuda.empty_cache()

        return self.swap * regular_factor


class SWAP:
    """Runs a model on a fixed batch and scores its ReLU activation patterns.

    Forward hooks are attached to every ``nn.ReLU`` module of the model. On
    ``forward`` the hooked outputs are flattened per sample, concatenated
    and passed to ``SampleWiseActivationPatterns`` to obtain the score.

    The hooks stay attached to the model until ``remove_hooks`` is called
    (or ``reinit`` is called with a model, which replaces them).

    Note:
        ``__init__`` stores ``seed`` in ``self.seed`` before calling
        ``reinit``, and ``reinit`` only reseeds when the given seed differs
        from ``self.seed``. The constructor seed therefore never triggers
        ``torch.manual_seed``; only a later ``reinit(seed=...)`` with a
        different value does.
    """

    def __init__(self, model=None, inputs = None, device='cuda', seed=0, regular=False, mu=None, sigma=None):
        """Store the configuration and attach hooks to ``model``.

        Args:
            model: Network to score; hooks are registered when not ``None``.
            inputs: Batch tensor fed to the model by ``forward``.
            device: Device the inputs are moved to before the forward pass.
            seed: Seed value remembered for ``reinit`` (see class note).
            regular: When true and both ``mu`` and ``sigma`` are given, the
                score is scaled by ``cal_regular_factor(model, mu, sigma)``.
            mu: Passed to ``cal_regular_factor``.
            sigma: Passed to ``cal_regular_factor``.
        """
        self.model = model
        self.interFeature = []
        self.seed = seed
        self.regular_factor = 1
        self.inputs = inputs
        self.device = device
        # Handles returned by register_forward_hook, kept so that the hooks
        # can be detached again.
        self._hook_handles = []

        if regular and mu is not None and sigma is not None:
            self.regular_factor = cal_regular_factor(self.model, mu, sigma).item()

        self.reinit(self.model, self.seed)

    def reinit(self, model=None, seed=None):
        """Reset captured features; optionally swap the model or reseed.

        Args:
            model: If not ``None``, becomes the scored model. Hooks tracked
                by this object are removed first and fresh ones are
                registered, so repeated calls do not stack duplicates.
            seed: If not ``None`` and different from the stored seed, it is
                stored and applied via ``torch.manual_seed`` and
                ``torch.cuda.manual_seed``.
        """
        if model is not None:
            self.remove_hooks()
            self.model = model
            self.register_hook(self.model)
            self.swap = SampleWiseActivationPatterns(self.device)

        if seed is not None and seed != self.seed:
            self.seed = seed
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
        self.interFeature = []
        torch.cuda.empty_cache()

    def clear(self):
        """Discard captured features and start a fresh pattern collector."""
        self.swap = SampleWiseActivationPatterns(self.device)
        self.interFeature = []
        torch.cuda.empty_cache()

    def register_hook(self, model):
        """Attach ``hook_in_forward`` to every ``nn.ReLU`` in ``model``.

        The returned handles are remembered for ``remove_hooks``.
        """
        for _, m in model.named_modules():
            if isinstance(m, nn.ReLU):
                handle = m.register_forward_hook(hook=self.hook_in_forward)
                self._hook_handles.append(handle)

    def remove_hooks(self):
        """Detach every forward hook this object has registered."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def hook_in_forward(self, module, input, output):
        """Forward hook: record the output when the first input is 4-D."""
        if isinstance(input, tuple) and len(input[0].size()) == 4:
            self.interFeature.append(output.detach())

    def forward(self):
        """Run the model on the stored inputs and return the SWAP score.

        Returns:
            The value of ``calSWAP(self.regular_factor)``, or ``None`` when
            no hook captured any activation.
        """
        self.interFeature = []
        with torch.no_grad():
            # Hooks on submodules fire while the model runs its children.
            self.model.forward(self.inputs.to(self.device))
            if not self.interFeature:
                return None
            batch_size = self.inputs.size(0)
            # reshape (rather than view) also accepts non-contiguous outputs.
            activations = torch.cat(
                [f.reshape(batch_size, -1) for f in self.interFeature], 1
            )
            self.swap.collect_activations(activations)

            return self.swap.calSWAP(self.regular_factor)


def compute_nas_score(model, train_dataloader, num_repeats, regular=False, mu=None, sigma=None):
    """Average the SWAP score of ``model`` over several random initialisations.

    One batch is drawn from ``train_dataloader``. For each repeat the model
    weights are overwritten in place by ``network_weight_gaussian_init``
    and the batch is scored on the CPU. The forward hooks installed for
    scoring are removed before returning, so later forward passes of
    ``model`` do not keep accumulating activations.

    Args:
        model: Network to score. Its weights are modified.
        train_dataloader: Iterable yielding ``(inputs, target)`` pairs; only
            the first batch is used.
        num_repeats: Number of re-initialise-and-score rounds.
        regular: Forwarded to ``SWAP``.
        mu: Forwarded to ``SWAP``.
        sigma: Forwarded to ``SWAP``.

    Returns:
        ``np.mean`` of the per-repeat scores.

    Raises:
        ValueError: If a forward pass captured no activations, i.e. the
            model has no ``nn.ReLU`` module receiving a 4-D input.
    """
    inputs, _ = next(iter(train_dataloader))
    swap = SWAP(model=model, inputs=inputs, device="cpu", seed=1337, regular=regular, mu=mu, sigma=sigma)
    swap_score = []
    try:
        for _ in range(num_repeats):
            # Kept as ``apply`` so the sequence of random draws is unchanged.
            model.apply(network_weight_gaussian_init)
            swap.reinit()
            score = swap.forward()
            if score is None:
                raise ValueError(
                    "SWAP captured no activations: the model has no nn.ReLU "
                    "module that receives a 4-D input"
                )
            swap_score.append(score)
            swap.clear()
    finally:
        # Always detach the hooks, even when scoring fails part-way.
        swap.remove_hooks()
    return np.mean(swap_score)
