# resittac2f.py
import os
import random
from copy import deepcopy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from torch.cuda.amp import GradScaler, autocast
from typing import Tuple
from torchvision.transforms import InterpolationMode as IM

from src.adapter.base_adapter import BaseAdapter
from src.utils import ressitta_transforms as rt
from src.utils import GaussianClusterMemory  
def get_named_submodule(model, sub_name: str):
    """Resolve a dotted attribute path (e.g. ``"layer1.0.bn1"``) starting at *model*.

    Each dot-separated component is looked up with ``getattr`` on the result
    of the previous lookup; the final object is returned.
    """
    target = model
    for attr in sub_name.split("."):
        target = getattr(target, attr)
    return target

def set_named_submodule(model, sub_name, value):
    """Assign *value* to the attribute named by dotted path *sub_name* under *model*.

    All components but the last are resolved with ``getattr``; the last one
    is set with ``setattr`` on the object reached.
    """
    *parents, leaf = sub_name.split(".")
    owner = model
    for attr in parents:
        owner = getattr(owner, attr)
    setattr(owner, leaf, value)

def get_tta_transforms(gaussian_std: float = 0.005, soft=False, clip_inputs=False, dataset='cifar'):
    """Build the random augmentation pipeline applied to support images.

    Stages, in order: ``rt.Clip(0.0, 1.0)``, ``rt.ColorJitterPro``, edge
    padding by half the image side, ``RandomAffine`` (bilinear), a 5x5
    ``GaussianBlur``, ``CenterCrop`` back to the image side, a horizontal
    flip with p=0.5, ``rt.GaussianNoise(0, gaussian_std)`` and a final
    ``rt.Clip(0.0, 1.0)``.

    Args:
        gaussian_std: second argument passed to ``rt.GaussianNoise``.
        soft: when True, use the narrower parameter ranges for jitter,
            affine and blur.
        clip_inputs: accepted for interface compatibility; not used.
        dataset: a name containing "cifar" (case-insensitive) selects a
            32-pixel side, anything else 224.

    Returns:
        A ``transforms.Compose`` holding the stages above.
    """
    side = 32 if 'cifar' in str(dataset).lower() else 224
    img_shape = (side, side, 3)
    print('img_shape in cotta transform', img_shape)

    def choose(mild, strong):
        # Narrower range when ``soft`` is requested, otherwise the default one.
        return mild if soft else strong

    low, high = 0.0, 1.0
    return transforms.Compose([
        rt.Clip(low, high),
        rt.ColorJitterPro(
            brightness=choose([0.8, 1.2], [0.6, 1.4]),
            contrast=choose([0.85, 1.15], [0.7, 1.3]),
            saturation=choose([0.75, 1.25], [0.5, 1.5]),
            hue=choose([-0.03, 0.03], [-0.06, 0.06]),
            gamma=choose([0.85, 1.15], [0.7, 1.3]),
        ),
        transforms.Pad(padding=side // 2, padding_mode='edge'),
        transforms.RandomAffine(
            degrees=choose([-8, 8], [-15, 15]),
            translate=(1 / 16, 1 / 16),
            scale=choose((0.95, 1.05), (0.9, 1.1)),
            shear=None,
            interpolation=IM.BILINEAR,
            fill=None,
        ),
        transforms.GaussianBlur(kernel_size=5, sigma=choose([0.001, 0.25], [0.001, 0.5])),
        transforms.CenterCrop(size=side),
        transforms.RandomHorizontalFlip(p=0.5),
        rt.GaussianNoise(0, gaussian_std),
        rt.Clip(low, high),
    ])

class MemoryItem:
    """A single stored sample together with its uncertainty and age.

    The literal string ``"empty"`` as ``data`` marks a placeholder slot;
    placeholders never age.
    """

    def __init__(self, data=None, uncertainty=0, age=0):
        self.data = data
        self.uncertainty = uncertainty
        self.age = age

    def increase_age(self):
        """Increment ``age`` by one unless this item is an "empty" placeholder."""
        if not self.empty():
            self.age += 1

    def get_data(self):
        """Return the tuple ``(data, uncertainty, age)``."""
        return self.data, self.uncertainty, self.age

    def empty(self):
        """Return True only when ``data`` is the placeholder string ``"empty"``."""
        # Check the type first: comparing e.g. a NumPy array to a string is
        # elementwise, and its truth value is ambiguous inside ``if not ...``.
        return isinstance(self.data, str) and self.data == "empty"


class LowEntropyMemoryBankV2:
    """Class-indexed memory of confident (low-uncertainty) samples.

    Samples are bucketed by predicted class. A sample is stored only when its
    uncertainty is strictly below ``threshold``. When the bank already holds
    ``capacity`` items, one stored item is evicted first: a uniformly random
    item from a randomly chosen majority class (``class_balance=True``) or
    from a randomly chosen non-empty class (``class_balance=False``).
    A non-positive ``capacity`` stores nothing.
    """

    def __init__(self, capacity, num_class, threshold, class_balance=True):
        self.capacity = capacity
        self.num_class = num_class
        # Nominal per-class quota; kept as a public attribute, not used by eviction.
        self.per_class = max(self.capacity / self.num_class, 1)
        self.data: list[list[MemoryItem]] = [[] for _ in range(self.num_class)]
        self.threshold = threshold
        self.class_balance = class_balance

    def get_occupancy(self):
        """Total number of stored items across all classes."""
        return sum(len(data_per_cls) for data_per_cls in self.data)

    def per_class_dist(self):
        """Number of stored items per class, indexed by class id."""
        return [len(class_list) for class_list in self.data]

    def get_majority_classes(self):
        """Class ids whose bucket size equals the current maximum."""
        per_class_dist = self.per_class_dist()
        max_occupied = max(per_class_dist)
        return [i for i, occupied in enumerate(per_class_dist) if occupied == max_occupied]

    def _evict_one(self):
        """Remove one random item from a randomly chosen eligible class bucket."""
        if self.class_balance:
            candidates = self.get_majority_classes()
        else:
            candidates = [i for i, occupied in enumerate(self.per_class_dist()) if occupied > 0]
        bucket = self.data[random.choice(candidates)]
        bucket.pop(random.randint(0, len(bucket) - 1))

    def add_instance(self, instance):
        """Offer ``(x, prediction, uncertainty)`` to the bank.

        Every call ages all stored items by one, whether or not the new
        sample is kept. ``prediction`` is used as the bucket index.

        Raises:
            ValueError: if ``instance`` does not have exactly three elements.
        """
        if len(instance) != 3:
            raise ValueError(f"expected (x, prediction, uncertainty), got {len(instance)} elements")
        self.add_age()
        x, prediction, uncertainty = instance
        # Written as ``not <`` so a NaN uncertainty is rejected, as before.
        if not uncertainty < self.threshold:
            return
        if self.capacity <= 0:
            # Nothing can be stored, and eviction would pick from empty buckets.
            return
        if self.get_occupancy() >= self.capacity:
            self._evict_one()
        self.data[prediction].append(MemoryItem(data=x, uncertainty=uncertainty, age=1))

    def add_age(self):
        """Increment the age of every stored item."""
        for class_list in self.data:
            for item in class_list:
                item.increase_age()

    def get_memory(self):
        """Return ``(data, ages)`` for all stored items in class order.

        Ages are divided by ``capacity``.
        """
        tmp_data, tmp_age = [], []
        for class_list in self.data:
            for item in class_list:
                tmp_data.append(item.data)
                tmp_age.append(item.age)
        tmp_age = [x / self.capacity for x in tmp_age]
        return tmp_data, tmp_age

class MomentumBN(nn.Module):
    """Base class for the soft-alignment batch-norm replacements.

    Snapshots the wrapped layer's running statistics twice: as ``source_*``
    buffers (never updated here) and as ``target_*`` buffers (updated by the
    subclasses' forward and by :meth:`regularize_statistics`). The affine
    parameters are copied into trainable ``weight``/``bias`` plus frozen
    ``source_weight``/``source_bias`` buffers for the alignment loss.

    Args:
        bn_layer: BatchNorm layer to wrap; it must track running statistics.
            With ``affine=False`` the parameters start at weight=1, bias=0,
            which is the identity affine transform.
        momentum: weight of the current batch when updating target stats.
        lambda_bn_d: step size pulling target stats toward source stats.
        lambda_bn_w: stored on the module; the adapter applies the loss scale.

    Raises:
        ValueError: if ``bn_layer`` has no running statistics.
    """

    def __init__(self, bn_layer: nn.BatchNorm2d, momentum, lambda_bn_d, lambda_bn_w):
        super().__init__()
        self.num_features = bn_layer.num_features
        self.momentum = float(momentum)

        if not (bn_layer.track_running_stats
                and bn_layer.running_mean is not None
                and bn_layer.running_var is not None):
            # Without running statistics there is no source distribution to
            # start from or regularize toward; forward() would otherwise fail
            # later on the missing target_*/source_* buffers.
            raise ValueError(f"{type(bn_layer).__name__} must track running statistics")

        running_mean = bn_layer.running_mean.detach()
        running_var = bn_layer.running_var.detach()
        self.register_buffer("source_mean", running_mean.clone())
        self.register_buffer("source_var", running_var.clone())
        self.register_buffer("target_mean", running_mean.clone())
        self.register_buffer("target_var", running_var.clone())
        self.register_buffer("source_num", bn_layer.num_batches_tracked.detach().clone()
                             if isinstance(bn_layer.num_batches_tracked, torch.Tensor)
                             else torch.tensor(bn_layer.num_batches_tracked, dtype=torch.long))

        if bn_layer.weight is not None and bn_layer.bias is not None:
            init_weight = bn_layer.weight.detach().clone()
            init_bias = bn_layer.bias.detach().clone()
        else:
            # affine=False behaves like weight=1, bias=0.
            init_weight = torch.ones_like(running_mean)
            init_bias = torch.zeros_like(running_mean)

        self.weight = nn.Parameter(init_weight)
        self.bias = nn.Parameter(init_bias)

        self.register_buffer("source_weight", init_weight.clone())
        self.register_buffer("source_bias", init_bias.clone())

        self.eps = float(bn_layer.eps)
        self.lambda_bn_d = float(lambda_bn_d)
        self.lambda_bn_w = float(lambda_bn_w)

    def forward(self, x):
        raise NotImplementedError

    def get_soft_alignment_loss_weight(self):
        """Squared L2 distance of (weight, bias) from their source values."""
        return torch.sum((self.weight - self.source_weight) ** 2) + torch.sum((self.bias - self.source_bias) ** 2)

    @torch.no_grad()
    def regularize_statistics(self):
        """Take one gradient step (size ``lambda_bn_d``) pulling target stats toward source.

        The step is on ``||mu_t - mu_s||^2`` for the mean and on
        ``||sigma_t - sigma_s||^2`` for the standard deviation, where
        ``sigma = sqrt(var + eps)``.
        """
        gradient_mean = 2 * (self.target_mean - self.source_mean)

        target_std = torch.sqrt(self.target_var + self.eps)
        source_std = torch.sqrt(self.source_var + self.eps)
        gradient_std = 2 * target_std - 2 * source_std

        target_std = target_std - self.lambda_bn_d * gradient_std

        self.target_mean.copy_(self.target_mean - self.lambda_bn_d * gradient_mean)
        # sigma includes eps; subtract it again so repeated calls do not inflate
        # the variance by eps each time. Clamp guards step sizes above 0.5.
        self.target_var.copy_((target_std ** 2 - self.eps).clamp_min(0.0))


class SoftAlignmentBN1d(MomentumBN):
    """Soft-alignment BN for ``(N, C)`` or ``(N, C, L)`` inputs.

    In training mode the target statistics become an EMA of themselves and
    the current batch statistics (weight ``momentum`` on the batch). The
    normalization uses that blended value, and gradients flow through its
    batch part. In eval mode the stored target statistics are used.
    """

    def forward(self, x):
        # Reduce over every dim except channels (dim 1), as nn.BatchNorm1d does.
        reduce_dims = [0] + list(range(2, x.dim()))
        stat_shape = tuple([1, -1] + [1] * (x.dim() - 2))
        if self.training:
            b_var, b_mean = torch.var_mean(x, dim=reduce_dims, unbiased=False, keepdim=False)  # (C,)
            mean = (1 - self.momentum) * self.target_mean + self.momentum * b_mean
            var = (1 - self.momentum) * self.target_var + self.momentum * b_var
            self.target_mean.copy_(mean.detach())
            self.target_var.copy_(var.detach())
            mean, var = mean.view(stat_shape), var.view(stat_shape)
        else:
            mean, var = self.target_mean.view(stat_shape), self.target_var.view(stat_shape)

        x = (x - mean) / torch.sqrt(var + self.eps)
        weight = self.weight.view(stat_shape)
        bias = self.bias.view(stat_shape)
        return x * weight + bias


class SoftAlignmentBN2d(MomentumBN):
    """Soft-alignment BN for ``(N, C, H, W)`` inputs.

    In training mode the target statistics are EMA-updated from the batch
    statistics without tracking gradients. The input is then normalized with
    the (updated) target statistics in both modes.
    """

    def forward(self, x):
        if self.training:
            with torch.no_grad():
                b_var, b_mean = torch.var_mean(x, dim=[0, 2, 3], unbiased=False, keepdim=False)  # (C,)
                mean = (1 - self.momentum) * self.target_mean + self.momentum * b_mean
                var = (1 - self.momentum) * self.target_var + self.momentum * b_var
                self.target_mean.copy_(mean)
                self.target_var.copy_(var)

        return F.batch_norm(x, self.target_mean, self.target_var, self.weight, self.bias, False, 0.0, self.eps)

def get_batch_norm_modules(model):
    """Collect every SoftAlignmentBN1d / SoftAlignmentBN2d inside *model*.

    Modules are returned in the order ``model.modules()`` visits them.
    """
    wanted = (SoftAlignmentBN1d, SoftAlignmentBN2d)
    return [module for module in model.modules() if isinstance(module, wanted)]

@torch.jit.script
def softmax_entropy(x: torch.Tensor, x_ema: torch.Tensor) -> torch.Tensor:
    """Per-row cross-entropy of softmax(x_ema) (target) against softmax(x) (prediction).

    Classes are along dim 1; the result has one value per row.
    """
    log_probs = x.log_softmax(1)
    target_probs = x_ema.softmax(1)
    return -(target_probs * log_probs).sum(1)

class ResiTTAC2F(BaseAdapter):
    """ResiTTA-style adapter with a Gaussian-cluster support memory.

    A detached deep copy of the adapted model (``model_ema``) produces the
    outputs and pseudo-labels. Every test sample is offered to two memories:
    ``gcmem`` (a ``GaussianClusterMemory``) and ``mem`` (a low-entropy,
    class-indexed bank). Every ``update_frequency`` samples the student is
    trained on strongly augmented support samples to match the teacher's
    predictions. Then the BN target statistics of both models are pulled
    toward the source statistics and the teacher is EMA-updated.
    """

    def __init__(self, cfg, model, optimizer):
        super(ResiTTAC2F, self).__init__(cfg, model, optimizer)

        # -- Read ResiTTA hyperparameters.
        def _get(path_list, default):
            # Walk nested config attributes; fall back to ``default`` when any key is missing.
            cur = cfg
            for key in path_list:
                if not hasattr(cur, key):
                    return default
                cur = getattr(cur, key)
            return cur

        self.num_class = int(_get(["CORRUPTION", "NUM_CLASS"], 1000))
        self.dataset = str(_get(["CORRUPTION", "DATASET"], "imagenet"))
        self.bn_alpha = float(_get(["ADAPTER", "RESITTA", "BN_ALPHA"], 0.05))
        self.lambda_bn_d = float(_get(["ADAPTER", "RESITTA", "LAMBDA_BN_D"], 0.01))
        self.lambda_bn_w = float(_get(["ADAPTER", "RESITTA", "LAMBDA_BN_W"], 0.0))
        self.lambda_bn_d_ema = self.lambda_bn_d

        self.capacity = int(_get(["ADAPTER", "RESITTA", "CAPACITY"], 64))
        self.update_frequency = int(_get(["ADAPTER", "RESITTA", "UPDATE_FREQUENCY"], self.capacity))
        self.steps = int(_get(["ADAPTER", "RESITTA", "STEPS"], 1))
        self.e_margin = float(_get(["ADAPTER", "RESITTA", "E_MARGIN"], 0.4))
        self.class_balance = bool(_get(["ADAPTER", "RESITTA", "CLASS_BALANCE"], True))
        self.ema_nu = float(_get(["ADAPTER", "RESITTA", "EMA_NU"], 1e-3))

        self.stmem_capacity   = int(_get(["ADAPTER", "C2FTTA", "STMEM_CAPACITY"], 16))
        self.stmem_max_clus   = int(_get(["ADAPTER", "C2FTTA", "STMEM_MAX_CLUS"], 5))
        self.stmem_topk_clus  = int(_get(["ADAPTER", "C2FTTA", "STMEM_TOPK_CLUS"], 3))
        self.base_threshold   = float(_get(["ADAPTER", "C2FTTA", "BASE_THRESHOLD"], 1.0))
        self.lambda_t = float(_get(["ADAPTER", "RoTTA", "LAMBDA_T"], 1.0))
        self.lambda_u = float(_get(["ADAPTER", "RoTTA", "LAMBDA_U"], 1.0))

        print("[ResiTTAC2F][CONFIG] "
              f"dataset={self.dataset} num_class={self.num_class} "
              f"BN_ALPHA={self.bn_alpha} LAMBDA_BN_D={self.lambda_bn_d} LAMBDA_BN_W={self.lambda_bn_w} "
              f"CAPACITY={self.capacity} UPDATE_FREQ={self.update_frequency} STEPS={self.steps} "
              f"E_MARGIN={self.e_margin} CLASS_BALANCE={self.class_balance} EMA_NU={self.ema_nu} | "
              f"STMEM_CAP={self.stmem_capacity} STMEM_MAX_CLUS={self.stmem_max_clus} "
              f"STMEM_TOPK_CLUS={self.stmem_topk_clus} BASE_TH={self.base_threshold} "
              f"LAMBDA_T={self.lambda_t} LAMBDA_U={self.lambda_u}")

        self._opt = None
        self.optimizer = optimizer

        # configure_model() reads bn_alpha / lambda_bn_* via getattr defaults,
        # which means it can run before the values above exist (presumably from
        # BaseAdapter.__init__ — verify). Push the configured values onto the
        # student's BN modules now, before the EMA copy is taken, so the config
        # is honored by both student and teacher.
        for bn in getattr(self, "bn_modules", []):
            bn.momentum = self.bn_alpha
            bn.lambda_bn_d = self.lambda_bn_d
            bn.lambda_bn_w = self.lambda_bn_w

        model_ema = deepcopy(self.model)
        for p in model_ema.parameters():
            p.detach_()
        self.model_ema = model_ema

        self.bn_modules_ema = get_batch_norm_modules(self.model_ema)
        for bn in self.bn_modules_ema:
            bn.lambda_bn_d = self.lambda_bn_d

        # Entropy threshold is a fraction of the maximum possible entropy log(C).
        threshold = self.e_margin * math.log(max(self.num_class, 2))
        self.mem = LowEntropyMemoryBankV2(
            capacity=self.capacity,
            num_class=self.num_class,
            threshold=threshold,
            class_balance=self.class_balance
        )

        self.gcmem = GaussianClusterMemory.GaussianClusterMemory(
            capacity=self.stmem_capacity,
            num_class=self.num_class,
            lambda_t=self.lambda_t,
            lambda_u=self.lambda_u,
            max_bank_num=self.stmem_max_clus,
            base_threshold=self.base_threshold,
        )

        self.transform = get_tta_transforms(dataset=self.dataset)

        self.update_frequency = max(1, self.update_frequency)
        self.num_instance = 0

        # AMP scaler (a no-op pass-through when CUDA is unavailable).
        self.scaler = GradScaler(enabled=torch.cuda.is_available())

    def configure_model(self, model: nn.Module):
        """Freeze *model* and swap every BatchNorm1d/2d for a trainable soft-alignment BN.

        Hyperparameters fall back to defaults when this runs before
        ``__init__`` has read them; ``__init__`` re-applies the configured values.
        """
        _bn_alpha = getattr(self, "bn_alpha", 0.05)
        _lambda_bn_d = getattr(self, "lambda_bn_d", 0.01)
        _lambda_bn_w = getattr(self, "lambda_bn_w", 0.0)

        model.requires_grad_(False)
        normlayer_names = []
        for name, sub_module in model.named_modules():
            if isinstance(sub_module, nn.BatchNorm1d) or isinstance(sub_module, nn.BatchNorm2d):
                normlayer_names.append(name)

        self.bn_modules = []
        for name in normlayer_names:
            bn_layer = get_named_submodule(model, name)
            if isinstance(bn_layer, nn.BatchNorm1d):
                NewBN = SoftAlignmentBN1d
            elif isinstance(bn_layer, nn.BatchNorm2d):
                NewBN = SoftAlignmentBN2d
            else:
                raise RuntimeError(f"unsupported norm layer: {type(bn_layer).__name__}")
            momentum_bn = NewBN(bn_layer, _bn_alpha, _lambda_bn_d, _lambda_bn_w)
            momentum_bn.requires_grad_(True)
            set_named_submodule(model, name, momentum_bn)
            self.bn_modules.append(momentum_bn)

        self.bn_module_names = normlayer_names
        return model

    def _ensure_optimizer(self):
        """Return a usable optimizer, building one lazily from a factory if needed.

        Raises:
            TypeError: if ``self.optimizer`` is neither an optimizer nor a
                callable returning one.
        """
        if hasattr(self, "_opt") and self._opt is not None and hasattr(self._opt, "zero_grad"):
            return self._opt
        opt = self.optimizer
        if hasattr(opt, "zero_grad"):
            self._opt = opt
            return self._opt
        if callable(opt):
            self._opt = opt(self.model.parameters())
            if hasattr(self._opt, "zero_grad"):
                return self._opt
        raise TypeError(f"optimizer:{type(self.optimizer)}")

    def _warn_once(self, key, message):
        """Print *message* the first time *key* is reported; stay silent afterwards."""
        seen = getattr(self, "_warned_keys", None)
        if seen is None:
            seen = set()
            self._warned_keys = seen
        if key not in seen:
            seen.add(key)
            print(message)

    @staticmethod
    def update_ema_variables(ema_model, model, nu):
        """In place, set ``ema = (1 - nu) * ema + nu * param`` for each parameter pair."""
        with torch.no_grad():
            # In-place ops (rather than ``[:]`` slicing) also handle 0-dim parameters.
            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                ema_param.data.mul_(1 - nu).add_(param.data, alpha=nu)
        return ema_model

    @torch.enable_grad()
    def forward_and_adapt(self, batch_data, model, optimizer, label):
        """Run ``self.steps`` adaptation passes on the batch and return the last teacher output.

        The ``model`` and ``optimizer`` arguments are ignored; ``self.model``
        and the lazily built optimizer are used instead.
        """
        _ = optimizer
        self._ensure_optimizer()

        x = batch_data  # (N, C, H, W)
        outputs = None
        for _ in range(self.steps):
            outputs = self._forward_and_adapt_once(x, self.model, label)
        return outputs

    def _extract_label_domain(self, label, i):
        """Best-effort read of sample *i*'s ground-truth label and domain id; -1 when unavailable."""
        gt = -1
        dm = -1
        try:
            if isinstance(label, dict):
                if "label" in label and isinstance(label["label"], torch.Tensor):
                    gt = int(label["label"][i].item())
                if "domain" in label and isinstance(label["domain"], torch.Tensor):
                    dm = int(label["domain"][i].item())
        except Exception:
            pass
        return gt, dm

    def _forward_and_adapt_once(self, batch_data, model, label):
        """Predict with the teacher, feed both memories, and update when due."""
        with torch.no_grad():
            model.eval()
            self.model_ema.eval()
            ema_out = self.model_ema(batch_data)
            predict = torch.softmax(ema_out, dim=1)
            pseudo_label = torch.argmax(predict, dim=1)
            entropy = torch.sum(-predict * torch.log(predict + 1e-6), dim=1)

        for i, data in enumerate(batch_data):
            # Iterating a tensor yields views that share the batch's storage;
            # store an independent copy so memories don't pin whole batches.
            data = data.detach().clone()
            p_l = int(pseudo_label[i].item())
            uncert = float(entropy[i].item())
            gt_i, dm_i = self._extract_label_domain(label, i)

            try:
                self.gcmem.add_instance((data, p_l, uncert, gt_i, dm_i))
            except Exception as exc:
                # Best-effort: the entropy bank below still receives the sample.
                self._warn_once("gcmem.add_instance",
                                f"[ResiTTAC2F][GCM][WARN] add_instance failed: {exc}")

            self.mem.add_instance((data, p_l, uncert))

            self.num_instance += 1
            if self.num_instance % self.update_frequency == 0:
                # The selector returns a non-empty tensor or None; update_model
                # treats None as "use the entropy bank" and returns if that is empty.
                self.update_model(model, sup_data=self._select_support_from_gcm_or_mem(batch_data))

        return ema_out

    def _select_support_from_gcm_or_mem(self, batch_data):
        """Pick support samples: cluster memory first, then the entropy bank, else None."""
        device = next(self.model.parameters()).device
        try:
            data_tensor = batch_data if isinstance(batch_data, torch.Tensor) else torch.stack(batch_data)
            sup_data_short = self.gcmem.get_sup_data(data_tensor, topk=self.stmem_topk_clus)
            if isinstance(sup_data_short, list) and len(sup_data_short) > 0:
                return torch.stack(sup_data_short).to(device)
        except Exception as exc:
            self._warn_once("gcmem.get_sup_data",
                            f"[ResiTTAC2F][GCM][WARN] get_sup_data failed, falling back to memory bank: {exc}")

        sup_data_mem, _ = self.mem.get_memory()
        if sup_data_mem:
            return torch.stack(sup_data_mem).to(device)
        return None

    def update_model(self, model, sup_data=None):
        """One optimisation step of the student on *sup_data*, then BN regularisation and EMA.

        When *sup_data* is None the entropy bank is used; if it is empty the
        method returns without updating (models are left in train mode).
        """
        model.train()
        self.model_ema.train()
        device = next(model.parameters()).device

        if sup_data is None:
            mem_data, _ = self.mem.get_memory()
            if len(mem_data) == 0:
                return
            sup_data = torch.stack(mem_data).to(device)

        optimizer = self._ensure_optimizer()

        with autocast(enabled=torch.cuda.is_available(), dtype=torch.float16):
            strong_sup_aug = torch.stack([self.transform(img) for img in sup_data]).to(device)

            # Teacher targets need no autograd graph (its parameters are detached).
            with torch.no_grad():
                ema_sup_out = self.model_ema(sup_data)
            stu_sup_out = model(strong_sup_aug)

            l_sup = softmax_entropy(stu_sup_out, ema_sup_out).mean()

            if self.lambda_bn_w > 0 and self.bn_modules:
                l_soft_alignment = torch.stack(
                    [bn_module.get_soft_alignment_loss_weight() for bn_module in self.bn_modules]
                ).sum()
            else:
                # Also covers a model without BN layers, where stacking [] would fail.
                l_soft_alignment = torch.tensor(0.0, device=device)

            loss = l_sup + l_soft_alignment * self.lambda_bn_w

        optimizer.zero_grad()
        self.scaler.scale(loss).backward()
        self.scaler.step(optimizer)
        self.scaler.update()

        for bn_module in self.bn_modules:
            bn_module.regularize_statistics()
        for bn_module in self.bn_modules_ema:
            bn_module.regularize_statistics()

        self.update_ema_variables(self.model_ema, self.model, self.ema_nu)

    def visualize_support_data(self, short_data, long_data, step, out_dir):
        """Save a 2-D PCA scatter of short- vs long-term support samples (best effort).

        Each sample is reduced to per-channel spatial means before PCA.
        *short_data* / *long_data* may be None, lists of tensors, or stacked
        tensors (iterated along dim 0). Any failure, including missing
        sklearn/matplotlib, is printed as a warning instead of raised.
        """
        try:
            import numpy as np
            from sklearn.decomposition import PCA
            import matplotlib.pyplot as plt

            short_data = [] if short_data is None else list(short_data)
            long_data = [] if long_data is None else list(long_data)
            if not short_data and not long_data:
                return

            def _to_feat(t):
                # detach(): tensors that require grad cannot be converted to NumPy.
                t = t.detach().float()
                if t.ndim == 3:
                    return torch.mean(t, dim=(1, 2)).cpu().numpy()
                elif t.ndim == 4:
                    return torch.mean(t, dim=(2, 3)).squeeze(0).cpu().numpy()
                else:
                    return t.view(-1).cpu().numpy()

            feats, colors = [], []
            for d in short_data:
                feats.append(_to_feat(d))
                colors.append(0)
            for d in long_data:
                feats.append(_to_feat(d))
                colors.append(1)

            feats = np.stack(feats)
            colors = np.array(colors)
            pca = PCA(n_components=2)
            xy = pca.fit_transform(feats)

            os.makedirs(out_dir, exist_ok=True)
            plt.figure(figsize=(7, 6))
            m0 = colors == 0
            m1 = colors == 1
            plt.scatter(xy[m0, 0], xy[m0, 1], alpha=0.7, label="Short-Term")
            plt.scatter(xy[m1, 0], xy[m1, 1], alpha=0.7, label="Long-Term")
            plt.legend()
            plt.title(f"Support Data @ Step {step}")
            plt.tight_layout()
            plt.savefig(os.path.join(out_dir, f"support_vis_step_{step:05d}.png"))
            plt.close()
        except Exception as e:
            print(f"[ResiTTAC2F][VIS][WARN] {e}")
