import torch
import torch.nn as nn
from src.adapter.base_adapter import BaseAdapter
from src.utils.GaussianClusterMemory import GaussianClusterMemory 
from src.utils.loss_func import softmax_entropy
from src.utils.bn_layers import RobustBN1d, RobustBN2d
from src.utils import set_named_submodule, get_named_submodule
from src.utils.custom_transforms import get_tta_transforms

class RoTTA_C2FTTA(BaseAdapter):
    """RoTTA-style test-time adapter that replays samples from a ``GaussianClusterMemory``.

    Per incoming sample, the EMA teacher (``self.model_ema``) provides a pseudo-label
    and the prediction entropy. These are pushed into ``self.mem`` together with the
    optional ground-truth label and domain id. Every ``update_frequency`` samples,
    ``update_model`` does the following:

    - fetches replay samples from memory;
    - takes one optimizer step on
      ``softmax_entropy(student(transform(x)), teacher(x)).mean()``;
    - moves the teacher towards the student with an EMA of rate ``nu``.
    """

    def __init__(self, cfg, model, optimizer):
        super().__init__(cfg, model, optimizer)

        c2f_cfg = cfg.ADAPTER.C2FTTA
        rotta_cfg = cfg.ADAPTER.RoTTA

        # Replay memory; the meaning of the lambda_* weights and base_threshold lives in
        # GaussianClusterMemory (not visible here). Optional keys fall back to 1.0.
        self.mem = GaussianClusterMemory(
            capacity=c2f_cfg.STMEM_CAPACITY,
            num_class=cfg.CORRUPTION.NUM_CLASS,
            lambda_t=getattr(rotta_cfg, "LAMBDA_T", 1.0),
            lambda_u=getattr(rotta_cfg, "LAMBDA_U", 1.0),
            lambda_d=getattr(c2f_cfg, "LAMBDA_D", 1.0),
            max_bank_num=c2f_cfg.STMEM_MAX_CLUS,
            base_threshold=c2f_cfg.BASE_THRESHOLD,
        )
        print(f"[rottac2f] mem(capacity={c2f_cfg.STMEM_CAPACITY}, "
              f"max_bank_num={c2f_cfg.STMEM_MAX_CLUS}, "
              f"topk={c2f_cfg.STMEM_TOPK_CLUS})")

        # Teacher model; build_ema is provided by BaseAdapter.
        self.model_ema = self.build_ema(self.model)
        # Strong augmentation applied to replayed samples before the student sees them.
        self.transform = get_tta_transforms(cfg)

        # EMA rate: ema <- (1 - nu) * ema + nu * student (see update_ema_variables).
        self.nu = rotta_cfg.NU
        # Number of seen samples between two calls to update_model.
        self.update_frequency = rotta_cfg.UPDATE_FREQUENCY

        # Forwarded to mem.get_sup_data as topk / max_samples.
        self.topk_per_cluster = c2f_cfg.STMEM_TOPK_CLUS
        self.max_replay = getattr(c2f_cfg, "STMEM_MAX_REPLAY", 32)

        # Running count of samples pushed into memory.
        self.current_instance = 0
        # Most recent batch, handed to mem.get_sup_data inside update_model.
        self._last_batch = None

    @staticmethod
    def _value_at(values, index, default):
        """Return the Python scalar belonging to sample ``index`` of an optional label field.

        Args:
            values: ``None``, an array-like exposing ``shape`` (tensor / ndarray), a
                list/tuple with one entry per sample, or a single scalar.
            index: Position of the sample in the batch.
            default: Value returned when ``values`` is ``None``.
        """
        if values is None:
            return default
        if hasattr(values, "shape"):
            # A 0-d tensor/array holds one value shared by every sample of the batch.
            if len(values.shape) == 0:
                return values.item()
            return values[index].item()
        if isinstance(values, (list, tuple)):
            return int(values[index])
        return int(values)

    @torch.enable_grad()
    def forward_and_adapt(self, batch_data, model, optimizer, label=None):
        """Predict with the EMA teacher, store every sample in memory and adapt periodically.

        Args:
            batch_data: Batched input; iterated along its first dimension to obtain samples.
            model: Student model that ``update_model`` trains.
            optimizer: Optimizer stepping ``model``'s trainable parameters.
            label: Optional. Either a ``{'label': ..., 'domain': ...}`` dict or the labels
                themselves. Missing labels are stored as ``-1`` and missing domains as ``0``.

        Returns:
            The teacher's raw outputs for ``batch_data``, computed before any adaptation
            happening in this call.
        """
        self._last_batch = batch_data

        with torch.no_grad():
            model.eval()
            self.model_ema.eval()
            ema_out = self.model_ema(batch_data)
            predict = torch.softmax(ema_out, dim=1)
            pseudo_label = torch.argmax(predict, dim=1)
            entropy = torch.sum(-predict * torch.log(predict + 1e-6), dim=1)

        if isinstance(label, dict):
            true_labels = label.get('label', None)
            domains = label.get('domain', None)
        else:
            true_labels = label
            domains = None

        # One device->host transfer per tensor instead of an ``.item()`` sync per sample.
        pseudo_labels = pseudo_label.tolist()
        entropies = entropy.tolist()

        for i, data in enumerate(batch_data):
            y_i = self._value_at(true_labels, i, -1)
            d_i = self._value_at(domains, i, 0)

            self.mem.add_instance((data, pseudo_labels[i], entropies[i], y_i, d_i))
            self.current_instance += 1

            if self.current_instance % self.update_frequency == 0:
                self.update_model(model, optimizer)

        return ema_out

    def update_model(self, model, optimizer):
        """Take one replay step on ``model`` and then refresh the EMA teacher from it.

        The replay samples returned by ``mem.get_sup_data`` are stacked into a batch.
        The student sees them strongly augmented; the teacher sees them unchanged. The
        mean ``softmax_entropy`` between the two outputs is minimised. If memory returns
        nothing, only the EMA update runs.
        """
        model.train()
        # Train mode for the teacher too, so its norm layers behave as in training
        # during the replay forward pass.
        self.model_ema.train()

        sup_data = self.mem.get_sup_data(
            batch_samples=self._last_batch,
            topk=self.topk_per_cluster,
            max_samples=self.max_replay,
        )

        if len(sup_data) > 0:
            sup_data = torch.stack(sup_data)
            strong_sup_aug = self.transform(sup_data)
            ema_sup_out = self.model_ema(sup_data)
            stu_sup_out = model(strong_sup_aug)
            l_sup = softmax_entropy(stu_sup_out, ema_sup_out).mean()

            optimizer.zero_grad()
            l_sup.backward()
            optimizer.step()

        # Track the model that was actually trained above (the argument, not self.model).
        self.update_ema_variables(self.model_ema, model, self.nu)

    @staticmethod
    def update_ema_variables(ema_model, model, nu):
        """Blend ``model``'s parameters into ``ema_model`` in place and return ``ema_model``.

        For each parameter pair: ``ema = (1 - nu) * ema + nu * param``. Parameters are
        paired positionally via ``parameters()``, so both models must share their layout.
        Buffers are not touched.
        """
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            # copy_ (rather than slice assignment) also works for 0-dim parameters.
            ema_param.data.copy_((1 - nu) * ema_param.data + nu * param.data)
        return ema_model

    def configure_model(self, model: nn.Module) -> nn.Module:
        """Freeze ``model`` and swap every BatchNorm1d/2d for a trainable RobustBN wrapper.

        Each replacement is built from the original layer and
        ``cfg.ADAPTER.RoTTA.ALPHA`` (presumably its statistics momentum; see RobustBN).
        The replacement is the only part left with ``requires_grad=True``. The model is
        modified in place and returned.
        """
        model.requires_grad_(False)

        # Names are collected up front so the module tree is not modified while
        # named_modules() is still walking it.
        normlayer_names = [
            name
            for name, sub_module in model.named_modules()
            if isinstance(sub_module, (nn.BatchNorm1d, nn.BatchNorm2d))
        ]

        for name in normlayer_names:
            bn_layer = get_named_submodule(model, name)
            NewBN = RobustBN1d if isinstance(bn_layer, nn.BatchNorm1d) else RobustBN2d
            momentum_bn = NewBN(bn_layer, self.cfg.ADAPTER.RoTTA.ALPHA)
            momentum_bn.requires_grad_(True)
            set_named_submodule(model, name, momentum_bn)
        return model
