


import torch
import torch.nn as nn

from methods.base import TTAMethod
from utils.registry import ADAPTATION_REGISTRY
from utils.losses import Entropy


@ADAPTATION_REGISTRY.register()
class Buffer(TTAMethod):
    """Test-time adaptation by entropy minimization over the `tta` layers.

    Every adapting forward pass computes the batch-mean entropy loss of the
    model predictions, backpropagates it and steps ``self.optimizer``. Only
    parameters whose qualified name contains ``'tta'`` are unfrozen by
    ``configure_model`` and returned by ``collect_params``.

    NOTE(review): ``self.model``, ``self.optimizer``, ``self.scaler``,
    ``self.mixed_precision`` and ``self.device`` are not assigned in this
    class; presumably they are provided by ``TTAMethod`` -- confirm.
    """

    def __init__(self, cfg, model, num_classes):
        super().__init__(cfg, model, num_classes)

        # Loss used for adaptation (project-level `Entropy` helper).
        self.softmax_entropy = Entropy()

    def loss_calculation(self, x):
        """Run the model on ``x[0]`` and return ``(outputs, loss)``.

        Args:
            x: Indexable batch container; only ``x[0]`` is fed to the model.

        Returns:
            Tuple of the raw model outputs and the entropy loss reduced with
            ``mean(0)``.
        """
        imgs_test = x[0]
        outputs = self.model(imgs_test)
        loss = self.softmax_entropy(outputs).mean(0)
        return outputs, loss

    @torch.enable_grad()
    def forward_and_adapt(self, x):
        """Forward a batch, take one optimizer step and return the outputs.

        When ``self.mixed_precision`` is set and ``self.device == "cuda"`` the
        forward pass runs under CUDA autocast and the backward pass and
        optimizer step go through ``self.scaler``; otherwise a plain
        backward/step is used.

        Args:
            x: Indexable batch container; ``x[0]`` is the model input.

        Returns:
            The model outputs computed before the parameter update.
        """
        if self.mixed_precision and self.device == "cuda":
            with torch.cuda.amp.autocast():
                outputs, loss = self.loss_calculation(x)
            # Scale the loss before backward; the scaler then performs the step.
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()
        else:
            outputs, loss = self.loss_calculation(x)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
        return outputs

    def forward_without_adapt(self, x):
        """Forward pass without any adaptation (used for source eval).

        NOTE(review): ``x`` is passed to the model as-is here, whereas
        ``loss_calculation`` feeds ``x[0]`` -- confirm which input format the
        callers of this method use.
        """
        with torch.no_grad():
            imgs_test = x
            outputs = self.model(imgs_test)
        return outputs

    def collect_params(model):
        """Collect the trainable parameters of the additional TTA layers.

        A parameter is selected when its qualified name contains ``'tta'`` and
        it currently requires gradients. Normalization layers are not
        collected by this variant; to adapt them as well, the ``weight`` and
        ``bias`` of the BatchNorm2d/GroupNorm/LayerNorm modules would have to
        be added here and unfrozen in ``configure_model``.

        NOTE(review): declared without ``self``/``@staticmethod``. When called
        on an instance (``obj.collect_params()``), Python binds the instance
        itself to ``model``; that only works if the instance exposes
        ``named_parameters()`` -- confirm against ``TTAMethod``.

        Args:
            model: Object providing ``named_parameters()``.

        Returns:
            Tuple ``(params, names)`` of the selected parameters and their
            names, in ``named_parameters()`` order.
        """
        selected = [
            (name, param)
            for name, param in model.named_parameters()
            if 'tta' in name and param.requires_grad
        ]
        names = [name for name, _ in selected]
        params = [param for _, param in selected]

        print(names, 'param_names')
        return params, names

    def configure_model(model):
        """Configure model for use with BufferTTA.

        Puts the model in training mode, freezes every parameter and then
        re-enables gradients only for parameters whose qualified name contains
        ``'tta'`` -- the same criterion ``collect_params`` applies.

        NOTE(review): declared without ``self``/``@staticmethod``; when called
        on an instance, the instance itself is bound to ``model``.

        Args:
            model: Module to configure in place.

        Returns:
            The same ``model`` object.
        """
        model.train()
        model.requires_grad_(False)

        # Unfreeze per parameter name. The previous `'tta' in str(m)` test
        # looked at the module repr, which lists all descendants, so every
        # ancestor of a `tta` layer (typically the root) matched and the whole
        # network was unfrozen again.
        for name, param in model.named_parameters():
            if 'tta' in name:
                param.requires_grad_(True)

        # BatchNorm2d criterion kept exactly as before: the layer's own repr
        # must mention 'tta'. Such layers drop their running statistics and
        # normalize with the statistics of the current batch.
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d) and 'tta' in str(m):
                m.track_running_stats = False
                m.running_mean = None
                m.running_var = None
        return model


