"""
Builds upon: https://github.com/DequanWang/tent
Corresponding paper: https://arxiv.org/abs/2006.10726
"""
import torch
import torch.nn as nn

from methods.base import TTAMethod
from utils.registry import ADAPTATION_REGISTRY
from utils.losses import Entropy


@ADAPTATION_REGISTRY.register()
class Tent(TTAMethod):
    """Tent adapts a model by entropy minimization during testing.

    Once tented, a model adapts itself by updating on every forward.
    Only the affine parameters (weight/bias) of normalization layers are
    collected for optimization; all other parameters have their gradients
    disabled by ``configure_model``.

    NOTE(review): ``self.model``, ``self.optimizer``, ``self.scaler``,
    ``self.mixed_precision`` and ``self.device`` are read here but never
    assigned in this class -- presumably they are set up by ``TTAMethod``.
    """

    def __init__(self, cfg, model, num_classes):
        """Initialize the base TTA method and create the entropy loss.

        Args:
            cfg: Configuration object, forwarded unchanged to ``TTAMethod``.
            model: Model to adapt, forwarded unchanged to ``TTAMethod``.
            num_classes: Number of classes, forwarded unchanged to ``TTAMethod``.
        """
        super().__init__(cfg, model, num_classes)

        # setup loss function (its output is reduced over dim 0 in loss_calculation)
        self.softmax_entropy = Entropy()

    def loss_calculation(self, x):
        """Run the model on the test images and compute the entropy loss.

        Args:
            x: Indexable batch container; only its first element ``x[0]``
                (the test images) is used.

        Returns:
            Tuple ``(outputs, loss)``: the raw model outputs and the entropy
            of those outputs averaged over dim 0.
        """
        imgs_test = x[0]
        outputs = self.model(imgs_test)
        loss = self.softmax_entropy(outputs).mean(0)
        return outputs, loss

    @torch.enable_grad()
    def forward_and_adapt(self, x):
        """Forward and adapt model on batch of data.

        Measure entropy of the model prediction, take gradients, and update params.

        Args:
            x: Batch container as expected by ``loss_calculation``.

        Returns:
            The model outputs computed before the parameter update.
        """
        if self.mixed_precision and self.device == "cuda":
            # Forward under autocast; the gradient scaler handles the scaled
            # backward pass and the optimizer step.
            with torch.cuda.amp.autocast():
                outputs, loss = self.loss_calculation(x)
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            outputs, loss = self.loss_calculation(x)
            loss.backward()
            self.optimizer.step()
        # Clear gradients so they do not accumulate into the next step.
        self.optimizer.zero_grad()
        return outputs

    def forward_without_adapt(self, x):
        """Forward pass without any adaptation (used for source eval).

        Args:
            x: Either the test images directly, or a list/tuple whose first
                element holds the test images (the same layout that
                ``loss_calculation`` consumes).

        Returns:
            The model outputs, computed with gradient tracking disabled.
        """
        # Accept both calling conventions so this method agrees with
        # forward_and_adapt/loss_calculation, which index the batch as x[0].
        imgs_test = x[0] if isinstance(x, (list, tuple)) else x
        with torch.no_grad():
            outputs = self.model(imgs_test)
        return outputs

    def collect_params(self):
        """Collect the affine scale + shift parameters from normalization layers.

        Walk the model's modules and collect the ``weight`` and ``bias``
        parameters of all BatchNorm1d, BatchNorm2d, LayerNorm and GroupNorm
        layers.

        Note: other choices of parameterization are possible!

        Returns:
            Tuple ``(params, names)``: the parameters and their qualified
            names in the form ``"<module name>.<parameter name>"``.
        """
        norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)
        params = []
        names = []
        for module_name, module in self.model.named_modules():
            if not isinstance(module, norm_types):
                continue
            for param_name, param in module.named_parameters():
                if param_name in ('weight', 'bias'):  # weight is scale, bias is shift
                    params.append(param)
                    names.append(f"{module_name}.{param_name}")
        return params, names

    def configure_model(self):
        """Configure model for use with tent.

        Puts the model in eval mode, disables gradients everywhere, and then
        re-enables gradients only for the normalization layers that tent updates.
        """
        # eval mode (instead of train mode) to avoid stochastic depth in swin.
        # test-time normalization is still applied
        self.model.eval()
        # disable grad, to (re-)enable only what tent updates
        self.model.requires_grad_(False)
        # configure norm for tent updates: enable grad + force batch statistics
        for m in self.model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.requires_grad_(True)
                # force use of batch stats in train and eval modes
                m.track_running_stats = False
                m.running_mean = None
                m.running_var = None
            elif isinstance(m, nn.BatchNorm1d):
                m.train()   # always forcing train mode in bn1d will cause problems for single sample tta
                m.requires_grad_(True)
            elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
                m.requires_grad_(True)

#########################################################################################################################################################################
    # def collect_params(self):
    #     """Only collect from the first 1/3 of all normalization layers."""
    #     # 1) gather all (name, module) pairs that are normalization layers
    #     norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)
    #     all_norms = [
    #         (nm, m)
    #         for nm, m in self.model.named_modules()
    #         if isinstance(m, norm_types)
    #     ]

    #     # 2) decide how many to keep
    #     cutoff = len(all_norms) // 3

    #     params = []
    #     names = []
    #     # 3) only loop over the first third
    #     for nm, m in all_norms[:cutoff]:
    #         # print("included:", nm)
    #         for p_name, p in m.named_parameters():
    #             if p_name in ("weight", "bias") and p.requires_grad:
    #                 params.append(p)
    #                 names.append(f"{nm}.{p_name}")
    #     return params, names


    # def configure_model(self):
    #     """Only re‐enable and tweak the first 1/3 of normalization layers."""
    #     self.model.eval()
    #     self.model.requires_grad_(False)

    #     # gather all normalization modules (no names needed here)
    #     norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)
    #     all_norms = [m for m in self.model.modules() if isinstance(m, norm_types)]

    #     # compute cut‐off
    #     cutoff = len(all_norms) // 3

    #     # only touch the first third
    #     for m in all_norms[:cutoff]:
    #         # BatchNorm2d: force use batch‐stats and allow grad
    #         if isinstance(m, nn.BatchNorm2d):
    #             m.train()                  # use batch stats even in eval()
    #             m.track_running_stats = False
    #             m.running_mean = None
    #             m.running_var  = None
    #             m.requires_grad_(True)

    #         # BatchNorm1d: train mode avoids single‐sample issues
    #         elif isinstance(m, nn.BatchNorm1d):
    #             m.train()
    #             m.requires_grad_(True)

    #         # LayerNorm & GroupNorm: just turn gradients back on
    #         else:  # LayerNorm or GroupNorm
    #             m.requires_grad_(True)

################################################################################################################################################################################################################