#!/usr/bin/env python3

from src import const
from torch import nn
import torchvision
import torch


class Model(nn.Module):
    def __init__(self, is_contrastive=True, multilabel=False, return_logits=False, logits_only=False,
                 register_backward_hook=False, load_pretrained_weights=None, n_classes=None, upsampling_level=None,
                 device=None, segmentation_threshold=None):
        super().__init__()
        self.populate_with_defaults(locals())

        self.backbone = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2 if load_pretrained_weights else None)

        if self.upsampling_level >= 1 or self.upsampling_level <= -5:
            self.backbone.layer4[0].conv2.stride = (1, 1)
            self.backbone.layer4[0].downsample[0].stride = (1, 1)
        if self.upsampling_level >= 2 or self.upsampling_level <= -4:
            self.backbone.layer3[0].conv2.stride = (1, 1)
            self.backbone.layer3[0].downsample[0].stride = (1, 1)
        if self.upsampling_level >= 3 or self.upsampling_level <= -3:
            self.backbone.layer2[0].conv2.stride = (1, 1)
            self.backbone.layer2[0].downsample[0].stride = (1, 1)
        if self.upsampling_level >= 4 or self.upsampling_level <= -2:
            self.backbone.conv1.stride = (1, 1)
        if self.upsampling_level >= 5 or self.upsampling_level <= -1:
            self.backbone.maxpool.stride = 1

        if is_contrastive:
            self.backbone.layer4[-1].bn3 = nn.Identity()
            self.backbone.layer4[-1].relu = nn.Identity()
        self.backbone.layer4[-1].conv3.register_forward_hook(self._hook)

        self.linear = nn.Linear(2048, self.n_classes, bias=not is_contrastive)
        self.probabilities = nn.Identity() if self.return_logits else nn.Softmax(dim=1) if not self.multilabel else nn.Sigmoid()

        if const.DATASET in ['imagenet', 'salientimagenet'] and self.load_pretrained_weights:
            self.linear.weight = self.backbone.fc.weight
            if not is_contrastive: self.linear.bias = self.backbone.fc.bias
        self.backbone.fc = nn.Identity()

        self.to(self.device)
        if not self.logits_only: self.initialize_and_verify()

    def populate_with_defaults(self, kwargs):
        for key in kwargs:
            if kwargs[key] is None: setattr(self, key, getattr(const, key.upper()))
            elif key not in ['self', '__class__']: setattr(self, key, kwargs[key])

    def initialize_and_verify(self):
        with torch.no_grad():
            x = torch.randn(100, *const.IMAGE_SHAPE, device=self.device)
            logits, cam = self(x)
            cam_logits = cam.view(*cam.shape[:2], -1).sum(2)

            if not self.is_contrastive: cam_logits -= self.linear.bias
            print('Approx. cam logit err bound:', (logits - cam_logits).abs().max().item())

            if self.is_contrastive: assert torch.allclose(logits, cam_logits, atol=1E-5 if torch.get_float32_matmul_precision() == 'highest' else 1E-2)

    def _hook(self, model, i, o):
        def assign(grad):
            self.feature_grad = grad
        self.feature_rect = o
        if self.register_backward_hook: o.register_hook(assign)

    @torch.compiler.disable
    def forward(self, x):
        x = self.backbone(x)
        logits = self.linear(x)

        if self.logits_only: return logits
        return logits if self.training else self.probabilities(logits), self._bp_free_hi_res_cams()

    def get_semantic_map(self, cams):
        ohe = torch.zeros(cams.shape[:2], device=cams.device)
        ohe[:, 0] = 1.  # choice of index does not matter
        cams = self.get_contrastive_cams(ohe, cams)

        cc_mink = (-cams).topk(k=2, axis=1)
        return ((cc_mink.values[:, 0] - cc_mink.values[:, 1]) > 1e-3).to(torch.uint8) * (cc_mink.indices[:, 0] + 1).to(torch.uint8)

    def get_recon_cams(self, cams):
        return cams - cams.mean(1, keepdim=True)

    def get_contrastive_cams(self, y, cams):
        return torch.index_select(cams.view(-1, *cams.shape[2:]), 0, y.argmax(1) + (torch.arange(cams.size(0), device=self.device) * cams.size(1))).repeat(1, cams.size(1), 1).view(*cams.shape) - cams

    def _bp_free_hi_res_cams(self):  # required to obtain gradients on self.linear.weight
        return (self.linear.weight @ self.feature_rect.flatten(2)).unflatten(2, self.feature_rect.shape[2:]) / self.feature_rect.shape[-1]**2

    def _hi_res_cams(self, logits):  # inefficient but more general; not restricted to single dense layer
        cams = torch.zeros(*logits.shape, *self.feature_rect.shape[2:], device=self.device)
        for img_idx in range(logits.shape[0]):
            for class_idx in range(logits.shape[1]):
                logits[img_idx, class_idx].backward(retain_graph=True, inputs=self.backbone.layer4[-1].conv3.weight)
                cams[img_idx, class_idx] = (self.feature_rect * self.feature_grad).sum(dim=1)[img_idx]

        self.feature_grad = None
        self.feature_rect = None

        return cams
