# Adapted from the official GAIA repository: https://github.com/JGEthanChen/GAIA-OOD

import torch
import torch.nn as nn
from typing import Any, Callable

class Grad_all_hook:
    """Record the gradient that flows back into a module's forward output.

    A forward hook is installed on ``module``. On every forward pass it
    attaches a tensor hook to the module's output. When backward reaches that
    output, the detached gradient is stored in ``self.data``. Each new
    backward pass overwrites the previous value.
    """

    def __init__(self, module):
        # Keep the handle so ``close`` can unregister the forward hook later.
        self.hook = module.register_forward_hook(self.save_grad)
        # Empty placeholder until a backward pass populates it.
        self.data = torch.Tensor()

    def save_grad(self, module, input, output):
        """Forward hook: arrange for the gradient w.r.t. ``output`` to be kept."""
        def _keep_detached(gradient):
            # Detached so the stored tensor is not tied to any autograd graph.
            self.data = gradient.detach()

        output.register_hook(_keep_detached)

    def close(self):
        """Unregister the forward hook; ``self.data`` keeps its last value."""
        self.hook.remove()

class Activation_all_hook:
    """Keep the most recent forward output of a module in ``self.data``.

    The stored tensor is ``output.detach()``. It shares storage with the
    output but is cut off from the autograd graph.
    """

    def __init__(self, module):
        # Handle used by ``close`` to unregister the forward hook.
        self.hook = module.register_forward_hook(self.save_activations)
        self.data = torch.Tensor()  # empty until the first forward pass

    def save_activations(self, module, input, output):
        """Forward hook: overwrite ``self.data`` with the detached output."""
        activation = output.detach()
        self.data = activation

    def close(self):
        """Unregister the forward hook; ``self.data`` keeps its last value."""
        self.hook.remove()

class Grad_feature_hook:
    """Record both a module's output feature and the gradient w.r.t. it.

    On each forward pass the module's output is cloned into ``self.feature``.
    A tensor hook is also attached so the detached gradient lands in
    ``self.data`` during backward.
    """

    def __init__(self, module):
        # Handle used by ``close`` to unregister the forward hook.
        self.hook = module.register_forward_hook(self.save_grad)
        # Placeholders until the first forward / backward pass.
        self.data = torch.Tensor()
        self.feature = torch.Tensor()

    def save_grad(self, module, input, output):
        """Forward hook: capture ``output`` and arrange to record its gradient."""
        def _keep_detached(gradient):
            self.data = gradient.detach()

        output.register_hook(_keep_detached)
        # ``clone`` (not ``detach``) keeps the copy attached to the autograd graph.
        self.feature = output.clone()

    def close(self):
        """Unregister the forward hook; stored tensors are left untouched."""
        self.hook.remove()

class GAIADetector:
    """Out-of-distribution scorer built on GAIA gradient statistics.

    The detector holds a list of gradient hooks registered on ``model``. It
    scores a batch with one of two variants, selected by ``method_name``:

    * ``'GAIA-Z'`` uses :meth:`cal_zero`, built from the fraction of non-zero
      entries in each hooked gradient.
    * ``'GAIA-A'`` uses :meth:`cal_grad_value`, built from the ratio of
      inner-layer mean gradients to the mean gradient at the last hooked module.

    Build instances with :meth:`create`. Call :meth:`end` when finished so the
    hooks are removed from the model.
    """

    method_name: str
    model: Any
    detector: Any

    def __init__(self, method_name, model, hooks, device):
        self.method_name = method_name
        self.model = model
        # Hook objects exposing ``.data`` and ``.close()``; GAIA-A reads ``hooks[-1]``.
        self.hooks = hooks
        self.device = device

    @classmethod
    def create(cls, method_name, model_name, model, device):
        """Build a detector and register the hooks the chosen method needs.

        Args:
            method_name: ``'GAIA-Z'`` or ``'GAIA-A'``.
            model_name: architecture identifier used to choose which modules
                are hooked (see ``get_bn_hooks`` / ``get_beforehead_hooks``).
            model: the network to hook.
            device: device that inputs are moved to before scoring.

        Raises:
            ValueError: if ``method_name`` is not a supported GAIA variant, or
                if ``model_name`` is not supported by ``get_beforehead_hooks``.
        """
        if method_name == 'GAIA-Z':
            if model_name == 'resnet34':
                hooks = cls.get_bn_hooks(model, model_name)
            else:
                hooks = cls.get_beforehead_hooks(model, model_name, cal_method='cal_zero')
        elif method_name == 'GAIA-A':
            hooks = cls.get_beforehead_hooks(model, model_name, cal_method='cal_grad_value')
        else:
            # Previously this fell through to an UnboundLocalError on ``hooks``.
            raise ValueError(
                f"Unsupported GAIA method {method_name!r}; expected 'GAIA-A' or 'GAIA-Z'"
            )
        return cls(method_name, model, hooks, device)

    def __call__(self, x):
        """Score the batch ``x`` and return the scores as a NumPy array.

        ``x`` is moved to ``self.device`` first. The result holds one score per
        sample, assuming hooked outputs are batch-first (TODO confirm per model).

        Raises:
            ValueError: if ``self.method_name`` is neither ``'GAIA-A'`` nor ``'GAIA-Z'``.
        """
        if self.method_name == 'GAIA-A':
            score_fn = self.cal_grad_value
        elif self.method_name == 'GAIA-Z':
            score_fn = self.cal_zero
        else:
            # Previously an unknown method silently produced ``None``.
            raise ValueError(
                f"Unsupported GAIA method {self.method_name!r}; expected 'GAIA-A' or 'GAIA-Z'"
            )
        x = x.to(self.device)
        scores = score_fn(self.model, x, device=self.device, hooks=self.hooks)
        return scores.cpu().numpy()

    def end(self):
        """Remove every registered hook from the model."""
        for hook in self.hooks:
            hook.close()

    @staticmethod
    def get_conv_hooks(net):
        """Attach a ``Grad_all_hook`` to every ``nn.Conv2d`` in ``net``."""
        return [Grad_all_hook(module) for module in net.modules() if isinstance(module, nn.Conv2d)]

    @staticmethod
    def get_bn_hooks(net, model_name):
        """Attach a ``Grad_all_hook`` to the normalization layers of ``net``.

        For ``'BiT-S-R101x1'``, only the ``nn.GroupNorm`` layers inside
        ``net.body.block4`` are hooked, and their count is printed. For any
        other model, every ``nn.BatchNorm2d`` in the network is hooked.
        """
        if model_name == 'BiT-S-R101x1':
            bn_hooks = [
                Grad_all_hook(module)
                for module in net.body.block4.modules()
                if isinstance(module, nn.GroupNorm)
            ]
            print("model bn hook length", len(bn_hooks))
        else:
            bn_hooks = [
                Grad_all_hook(module)
                for module in net.modules()
                if isinstance(module, nn.BatchNorm2d)
            ]
        return bn_hooks

    @staticmethod
    def get_beforehead_hooks(net, model_name, cal_method='', dataset=''):
        """Hook the normalization layers and the module feeding the classifier head.

        Modules are collected in a model-specific order. Every collected module
        except the last gets a ``Grad_all_hook``; the last gets a
        ``Grad_feature_hook``, which ``cal_grad_value`` reads as ``hooks[-1]``.

        Args:
            net: the network to hook.
            model_name: one of the architecture names handled below.
            cal_method: only consulted for the BiT R101/R152 models.
            dataset: only consulted for the BiT R101/R152 models (``'textures'``
                selects ``before_head.gn`` for non-GAIA-A use).

        Raises:
            ValueError: if ``model_name`` is not supported. Previously an empty
                hook list was returned, which failed later with an obscure error.
        """
        module_list = []
        if model_name in ['resnet34', 'resnet18', 'resnet50']:
            for layer in (net.layer3, net.layer4):
                module_list.extend(m for m in layer.modules() if isinstance(m, nn.BatchNorm2d))
            module_list.append(net.layer4)
        elif model_name in ['wrn_40_2']:
            module_list.extend(m for m in net.modules() if isinstance(m, nn.BatchNorm2d))
            module_list.append(net.AdaptAvgPool)
        elif model_name in ['vgg', 'vgg16']:
            module_list.extend(m for m in net.modules() if isinstance(m, nn.BatchNorm2d))
            module_list.append(net.pool4)
        elif model_name in ['BiT-S-R101x1', 'BiT-M-R101x1', 'BiT-M-R152x2']:
            module_list.extend(m for m in net.body.block4.modules() if isinstance(m, nn.GroupNorm))
            if cal_method == 'cal_grad_value' or dataset != 'textures':
                module_list.append(net.before_head)
            else:
                module_list.append(net.before_head.gn)
        elif model_name in ['BiT-M-R50x1', 'BiT-S-R50x1']:
            module_list.extend(m for m in net.body.modules() if isinstance(m, nn.GroupNorm))
            module_list.append(net.before_head.gn)
            module_list.append(net.before_head)
        else:
            raise ValueError(f"Unsupported model_name for GAIA hooks: {model_name!r}")

        # Hooks are registered in collection order; the last one also keeps the feature.
        beforehead_hooks = [Grad_all_hook(module) for module in module_list[:-1]]
        beforehead_hooks.append(Grad_feature_hook(module_list[-1]))
        return beforehead_hooks

    @staticmethod
    def get_square(gradients):
        """Concatenate tensors along dim 1 and return the mean of squares over the last dim."""
        concatenated = torch.cat(gradients, dim=1)
        return torch.pow(concatenated, 2).mean(dim=(-1))

    def cal_zero(self, net, input, device=None, hooks=None):
        """Compute the GAIA-Z score from the sparsity of hooked gradients.

        The method backpropagates the sum of each sample's maximum logit. It
        then binarizes every hooked gradient (1 where non-zero, else 0) and
        averages it over the last two dims. Assuming N x C x H x W hooked
        outputs (TODO confirm), that average is a per-channel non-zero ratio.
        The results are combined with ``get_square``.

        ``device`` is accepted for interface symmetry but is not used.
        """
        net.zero_grad()
        logits = net(input)
        logits.max(dim=1).values.sum().backward()
        nonzero_ratios = [
            (hook.data != 0).to(hook.data.dtype).mean(dim=(-1, -2)) for hook in hooks
        ]
        return self.get_square(nonzero_ratios)

    @staticmethod
    def cal_grad_value(net, input, device, hooks=None):
        """Compute the GAIA-A score from hooked mean-gradient magnitudes.

        Pass 1 backpropagates the summed log-softmax of the logits. From the
        last hook's gradient it forms a per-sample normalizer: the square root
        of the mean absolute spatial-mean gradient. Pass 2 backpropagates
        ``net.before_head_data.sum()``. The absolute spatial-mean gradients of
        the remaining hooks are divided by the normalizer, squared, and averaged.

        ``device`` is accepted for interface symmetry but is not used.

        Returns:
            A detached tensor with one score per sample.
        """
        net.zero_grad()
        logits = net(input)
        # Functional log-softmax follows the logits' device; the old
        # ``LogSoftmax(...).cuda()`` hard-coded CUDA.
        log_probs = torch.log_softmax(logits, dim=-1)
        # Retain the graph because a second backward pass follows.
        log_probs.sum().backward(retain_graph=True)
        # Read before pass 2, which may overwrite hook data.
        last_grad = hooks[-1].data.mean(dim=(-1, -2))
        output_component = torch.sqrt(torch.abs(last_grad).mean(dim=1)).unsqueeze(dim=1)

        # NOTE(review): requires the model to set ``before_head_data`` during
        # forward; confirm which tensor each supported model stores there.
        net.before_head_data.sum().backward()
        inner_grads = [hook.data.mean(dim=(-1, -2)) for hook in hooks[:-1]]
        inner_component = torch.abs(torch.cat(inner_grads, dim=1))
        score = torch.pow(inner_component / output_component, 2).mean(dim=1)
        return score.detach()

    
    