import numpy as np
import torch
from scipy.stats import rankdata


def replace_module(model, old_module, new_module):
    """Swap every ``old_module`` instance nested inside ``model`` in place.

    Args:
        model: Object exposing ``named_children()``; its children are
            rewritten via ``setattr``.
        old_module: Class (or tuple of classes) handed to ``isinstance`` to
            decide which children get replaced.
        new_module: Zero-argument callable; it is invoked once per match and
            its result becomes the new child.

    Returns:
        None. ``model`` is modified in place.
    """
    for child_name, submodule in model.named_children():
        if not isinstance(submodule, old_module):
            # Not a match: keep searching further down this branch.
            replace_module(submodule, old_module, new_module)
            continue
        # A match is replaced wholesale; the replacement is not descended into.
        setattr(model, child_name, new_module())

class PreactMonitor:
    """Record the inputs ("pre-activations") fed to selected sub-modules.

    A forward hook is registered on every entry of ``model.named_modules()``
    that is an instance of ``module``. Each time a hooked module runs, its
    first positional input is detached, moved to the CPU and concatenated
    along dim 0 onto ``self.vals[name]``.

    Attributes set in ``__init__``:
        model: the model whose sub-modules are monitored.
        module: class (or tuple of classes) used in the ``isinstance`` test.
        vals: dict mapping the ``named_modules()`` name of each hooked module
            to the tensor of everything recorded so far (empty to begin with).
    """

    def __init__(self, model, module):
        # record all pre-activations to any nn.Module of instance module
        self.model = model
        self.module = module
        self.vals = {}
        # Handles returned by register_forward_hook, kept so that the hooks
        # can be detached again with remove_hooks().
        self._hook_handles = []
        self._monitor_module(self.model)

    def reset_vals(self):
        """Discard everything recorded so far; hooks stay registered."""
        # Only values of existing keys are reassigned, so iterating the dict
        # directly is safe (its size never changes).
        for name in self.vals:
            self.vals[name] = torch.empty(0)

    def remove_hooks(self):
        """Detach every forward hook this monitor registered.

        Already-recorded values in ``self.vals`` are left untouched. Calling
        this more than once is harmless because the handle list is emptied.
        """
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def _monitor_module(self, model):
        """Register a recording hook on each matching sub-module of ``model``."""
        for name, module in model.named_modules():
            if isinstance(module, self.module):
                self.vals[name] = torch.empty(0)
                handle = module.register_forward_hook(
                    self._record_preacts_gen(name)
                )
                self._hook_handles.append(handle)

    def _record_preacts_gen(self, hook_name):
        """Build a forward hook that appends inputs to ``self.vals[hook_name]``."""

        def _record_preacts(m, inputs, output):
            # Only the first positional input is kept; it is detached and
            # copied to the CPU before being appended to the running record.
            self.vals[hook_name] = torch.cat(
                (self.vals[hook_name], inputs[0].detach().cpu()), dim=0
            )

        return _record_preacts


def spearman_correlation(a, axis=0):
    """Return the Spearman rank-correlation matrix of ``a``.

    Every 1-D slice of ``a`` taken along ``axis`` is replaced by its ranks
    (``scipy.stats.rankdata``), and the Pearson correlation of those ranks is
    then computed with ``numpy.corrcoef``.

    Args:
        a: Array-like data; converted with ``np.asarray``.
        axis: Axis along which the observations of each variable run. With
            ``axis=0`` every column is a variable, with ``axis=1`` every row
            is. Negative values count from the last axis.

    Returns:
        The correlation coefficients as returned by ``numpy.corrcoef``.
    """
    a = np.asarray(a)
    # apply_along_axis also validates ``axis`` and raises for an invalid one.
    a_ranked = np.apply_along_axis(rankdata, axis, a)
    # Normalise the axis first: passing it to ``rowvar`` directly would treat
    # any non-zero value (e.g. -2) as "rows are variables".
    rows_are_variables = axis % a_ranked.ndim == 1
    rs = np.corrcoef(a_ranked, rowvar=rows_are_variables)
    return rs
