# resitta.py
import random
from copy import deepcopy
import math
import os

import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms

from torch.cuda.amp import GradScaler, autocast
from torch import Tensor
from typing import Tuple
from collections import OrderedDict

# ====== 关键：与工程接口对齐 ======
# 继承 BaseAdapter，接口与 C2FTTA 一致
from src.adapter.base_adapter import BaseAdapter
from src.utils import ressitta_transforms as rt
from torchvision.transforms import InterpolationMode as IM


def get_tta_transforms(gaussian_std: float = 0.005, soft=False, clip_inputs=False, dataset='cifar'):
    """Build the random augmentation pipeline used to create strong views.

    Args:
        gaussian_std: standard deviation handed to ``rt.GaussianNoise``.
        soft: if True, use the narrower ("soft") ranges for colour jitter,
            rotation, scaling and blur; otherwise the wider ones.
        clip_inputs: kept for signature compatibility; this function ignores it.
        dataset: any name containing ``"cifar"`` (case-insensitive) selects a
            32-pixel side; every other name selects 224.

    Returns:
        ``transforms.Compose``: clip -> colour jitter -> edge pad -> random
        affine -> Gaussian blur -> centre crop back to the original side ->
        random horizontal flip -> Gaussian noise -> clip to [0, 1].
    """
    side = 32 if 'cifar' in str(dataset).lower() else 224
    img_shape = (side, side, 3)
    print('img_shape in cotta transform', img_shape)

    # Pick the augmentation strengths once instead of per-argument ternaries.
    if soft:
        jitter = dict(brightness=[0.8, 1.2], contrast=[0.85, 1.15], saturation=[0.75, 1.25],
                      hue=[-0.03, 0.03], gamma=[0.85, 1.15])
        rotation, zoom, blur_sigma = [-8, 8], (0.95, 1.05), [0.001, 0.25]
    else:
        jitter = dict(brightness=[0.6, 1.4], contrast=[0.7, 1.3], saturation=[0.5, 1.5],
                      hue=[-0.06, 0.06], gamma=[0.7, 1.3])
        rotation, zoom, blur_sigma = [-15, 15], (0.9, 1.1), [0.001, 0.5]

    pipeline = [
        rt.Clip(0.0, 1.0),
        rt.ColorJitterPro(**jitter),
        # Edge-pad by half the side so the affine warp does not pull in blank borders;
        # the centre crop below restores the original size.
        transforms.Pad(padding=side // 2, padding_mode='edge'),
        transforms.RandomAffine(
            degrees=rotation,
            translate=(1 / 16, 1 / 16),
            scale=zoom,
            shear=None,
            interpolation=IM.BILINEAR,  # enum form avoids the deprecation warning
            fill=None,
        ),
        transforms.GaussianBlur(kernel_size=5, sigma=blur_sigma),
        transforms.CenterCrop(size=side),
        transforms.RandomHorizontalFlip(p=0.5),
        rt.GaussianNoise(0, gaussian_std),
        rt.Clip(0.0, 1.0),
    ]
    return transforms.Compose(pipeline)


# ========= 原版 Low-Entropy Memory（保持不变） =========
class MemoryItem:
    """One slot of the low-entropy memory bank.

    Attributes:
        data: the stored sample; the string ``"empty"`` marks a placeholder slot.
        uncertainty: the uncertainty score the sample was stored with.
        age: counter incremented by ``increase_age`` (skipped for placeholders).
    """

    def __init__(self, data=None, uncertainty=0, age=0):
        self.data = data
        self.uncertainty = uncertainty
        self.age = age

    def increase_age(self):
        """Increment ``age`` by one unless this slot is an empty placeholder."""
        if not self.empty():
            self.age += 1

    def get_data(self):
        """Return the tuple ``(data, uncertainty, age)``."""
        return self.data, self.uncertainty, self.age

    def empty(self):
        """Return True iff ``data`` is the placeholder string ``"empty"``.

        The ``isinstance`` guard keeps the result a plain bool: comparing an
        array-like ``data`` to a string can yield an element-wise result whose
        truth value is ambiguous, which would make ``increase_age`` raise.
        """
        return isinstance(self.data, str) and self.data == "empty"


class LowEntropyMemoryBankV2:
    """Per-class memory bank that only admits low-uncertainty samples.

    Items are bucketed by predicted class in ``self.data[class_id]``. A sample
    is stored only when its uncertainty is strictly below ``threshold``. When
    the bank already holds ``capacity`` items, one random item is evicted
    first: from a random majority class if ``class_balance`` is True,
    otherwise from a random non-empty class.
    """

    def __init__(self, capacity, num_class, threshold, class_balance=True):
        self.capacity = capacity
        self.num_class = num_class
        # Nominal per-class share of the capacity; not consulted by the eviction logic below.
        self.per_class = max(self.capacity / self.num_class, 1)

        self.data: list[list[MemoryItem]] = [[] for _ in range(self.num_class)]
        self.threshold = threshold
        self.class_balance = class_balance

    def get_occupancy(self):
        """Return the total number of stored items across all classes."""
        return sum(len(data_per_cls) for data_per_cls in self.data)

    def per_class_dist(self):
        """Return the item counts, indexed by class."""
        return [len(class_list) for class_list in self.data]

    def get_majority_classes(self):
        """Return the indices of every class whose count equals the maximum count."""
        per_class_dist = self.per_class_dist()
        max_occupied = max(per_class_dist)
        return [i for i, occupied in enumerate(per_class_dist) if occupied == max_occupied]

    def add_instance(self, instance):
        """Age all stored items, then store ``instance`` if it is confident enough.

        Args:
            instance: ``(x, prediction, uncertainty)``; ``prediction`` is used
                directly as the class-bucket index.
        """
        assert (len(instance) == 3)
        self.add_age()
        x, prediction, uncertainty = instance

        # Written as ``not <`` (rather than ``>=``) so a NaN uncertainty is rejected,
        # exactly as the plain ``<`` admission test does.
        if not uncertainty < self.threshold:
            return
        if self.capacity <= 0:
            # A zero-capacity bank stores nothing. Without this guard the eviction
            # below would pick an empty bucket and call random.randint(0, -1),
            # raising ValueError.
            return
        if self.get_occupancy() >= self.capacity:
            self._evict_one()
        self.data[prediction].append(MemoryItem(data=x, uncertainty=uncertainty, age=1))

    def _evict_one(self):
        """Remove one random item from a randomly chosen candidate class (bank must be non-empty)."""
        if self.class_balance:
            candidates = self.get_majority_classes()
        else:
            candidates = [i for i, occupied in enumerate(self.per_class_dist()) if occupied > 0]
        victim_cls = random.choice(candidates)
        bucket = self.data[victim_cls]
        bucket.pop(random.randint(0, len(bucket) - 1))

    def add_age(self):
        """Increment the age of every stored item."""
        for class_list in self.data:
            for item in class_list:
                item.increase_age()
        return

    def get_memory(self):
        """Return ``(samples, ages)`` over all classes; ages are divided by ``capacity``."""
        tmp_data = []
        tmp_age = []

        for class_list in self.data:
            for item in class_list:
                tmp_data.append(item.data)
                tmp_age.append(item.age)

        tmp_age = [x / self.capacity for x in tmp_age]

        return tmp_data, tmp_age


# ========= BN 软对齐模块（修复：注册 buffer + 参数，自动随设备迁移）=========
class MomentumBN(nn.Module):
    """Base class for the soft-alignment BatchNorm replacements.

    Copies an existing BatchNorm layer into:
      * buffers ``source_mean``/``source_var``: a frozen snapshot of the
        layer's running statistics;
      * buffers ``target_mean``/``target_var``: statistics that subclasses
        update in ``forward`` and that ``regularize_statistics`` pulls back
        towards the source snapshot;
      * parameters ``weight``/``bias``: trainable copies of the affine
        parameters, with buffer snapshots ``source_weight``/``source_bias``.

    All buffers and parameters follow the module on ``.to(device)``.

    NOTE(review): the statistic buffers are registered only when the source
    layer tracks running stats. For a layer built with
    ``track_running_stats=False``, subclasses will fail on first use. The
    affine copies assume ``bn_layer.weight``/``bias`` are not None, i.e.
    ``affine=True``.
    """

    def __init__(self, bn_layer: nn.BatchNorm2d, momentum, lambda_bn_d, lambda_bn_w):
        super().__init__()
        self.num_features = bn_layer.num_features
        self.momentum = float(momentum)

        # Running stats and their "source" snapshot live in buffers so they move with .to(device).
        if bn_layer.track_running_stats and bn_layer.running_var is not None and bn_layer.running_mean is not None:
            self.register_buffer("source_mean", bn_layer.running_mean.detach().clone())
            self.register_buffer("source_var",  bn_layer.running_var.detach().clone())
            self.register_buffer("target_mean", bn_layer.running_mean.detach().clone())
            self.register_buffer("target_var",  bn_layer.running_var.detach().clone())
            # Not used here; kept for compatibility with the original logic.
            self.register_buffer("source_num",  bn_layer.num_batches_tracked.detach().clone()
                                 if isinstance(bn_layer.num_batches_tracked, torch.Tensor)
                                 else torch.tensor(bn_layer.num_batches_tracked, dtype=torch.long))

        # Learnable affine parameters.
        self.weight = nn.Parameter(bn_layer.weight.detach().clone())
        self.bias   = nn.Parameter(bn_layer.bias.detach().clone())

        # Snapshot of the source affine parameters, used by the weight soft-alignment loss.
        self.register_buffer("source_weight", bn_layer.weight.detach().clone())
        self.register_buffer("source_bias",   bn_layer.bias.detach().clone())

        self.eps = float(bn_layer.eps)
        self.lambda_bn_d = float(lambda_bn_d)
        self.lambda_bn_w = float(lambda_bn_w)

    def forward(self, x):
        raise NotImplementedError

    def get_soft_alignment_loss_weight(self):
        """Return the squared L2 distance of ``weight``/``bias`` from their source snapshots."""
        return torch.sum((self.weight - self.source_weight) ** 2) + torch.sum((self.bias - self.source_bias) ** 2)

    @torch.no_grad()
    def regularize_statistics(self):
        """Take one gradient step pulling the target statistics towards the source ones.

        The step minimises ``(mu_t - mu_s)^2 + (sigma_t - sigma_s)^2`` per channel,
        with step size ``lambda_bn_d`` and ``sigma = sqrt(var + eps)``.
        """
        gradient_mean = 2 * (self.target_mean - self.source_mean)

        target_std = torch.sqrt(self.target_var + self.eps)
        source_std = torch.sqrt(self.source_var + self.eps)
        gradient_std = 2 * (target_std - source_std)
        target_std = target_std - self.lambda_bn_d * gradient_std

        self.target_mean.copy_(self.target_mean - self.lambda_bn_d * gradient_mean)
        # sigma above already contains eps. Remove it again so that repeated calls do
        # not inflate the stored variance by eps each time (previously a no-op step
        # still turned var into var + eps). The clamp keeps the variance
        # non-negative for large step sizes.
        self.target_var.copy_(torch.clamp(target_std ** 2 - self.eps, min=0.0))


class SoftAlignmentBN1d(MomentumBN):
    """Soft-alignment replacement for ``nn.BatchNorm1d``.

    Accepts ``(N, C)`` or ``(N, C, L)`` input, like ``nn.BatchNorm1d``.

    In training mode, per-channel batch statistics are blended into the
    stored target statistics with weight ``momentum``. The blended values are
    both stored (detached) and used for normalisation, so gradients flow
    through the batch-statistic share. In eval mode, the stored target
    statistics are used as-is.
    """

    def forward(self, x):
        # Statistics are per channel (dim 1), so reduce over every other axis:
        # dim 0 for (N, C) input, dims (0, 2) for (N, C, L) input.
        reduce_dims = [0] + list(range(2, x.dim()))
        # Shape that broadcasts a (C,) vector against x: (1, C) or (1, C, 1).
        stat_shape = [1, -1] + [1] * (x.dim() - 2)

        # Buffers/parameters already live on the module's device; no .to() needed.
        if self.training:
            b_var, b_mean = torch.var_mean(x, dim=reduce_dims, unbiased=False, keepdim=False)  # (C,)
            mean = (1 - self.momentum) * self.target_mean + self.momentum * b_mean
            var  = (1 - self.momentum) * self.target_var  + self.momentum * b_var
            # Persist the blended statistics as the new target (stored in buffers).
            self.target_mean.copy_(mean.detach())
            self.target_var.copy_(var.detach())
            mean, var = mean.view(stat_shape), var.view(stat_shape)
        else:
            mean, var = self.target_mean.view(stat_shape), self.target_var.view(stat_shape)

        x = (x - mean) / torch.sqrt(var + self.eps)
        weight = self.weight.view(stat_shape)
        bias   = self.bias.view(stat_shape)
        return x * weight + bias


class SoftAlignmentBN2d(MomentumBN):
    """Soft-alignment replacement for ``nn.BatchNorm2d`` (NCHW input).

    In training mode, the current batch's per-channel statistics are first
    blended into the target buffers with weight ``momentum``. In both modes,
    the input is then normalised with the target buffers via
    ``F.batch_norm`` in inference form, so no running statistics are touched
    there.

    NOTE(review): the blend is computed under ``torch.no_grad()``, so unlike
    ``SoftAlignmentBN1d`` no gradient flows through the batch statistics.
    Confirm this difference is intended.
    """

    def forward(self, x):
        if self.training:
            self._absorb_batch_statistics(x)
        # training=False: use the (possibly just-updated) target stats directly.
        return F.batch_norm(x, self.target_mean, self.target_var, self.weight, self.bias, False, 0.0, self.eps)

    @torch.no_grad()
    def _absorb_batch_statistics(self, x):
        """Blend this batch's per-channel mean/variance into the target buffers."""
        batch_var, batch_mean = torch.var_mean(x, dim=[0, 2, 3], unbiased=False, keepdim=False)  # (C,)
        keep = 1 - self.momentum
        self.target_mean.copy_(keep * self.target_mean + self.momentum * batch_mean)
        self.target_var.copy_(keep * self.target_var + self.momentum * batch_var)


def get_batch_norm_modules(model):
    """Return every soft-alignment BN layer inside ``model``, in ``named_modules`` order."""
    soft_bn_types = (SoftAlignmentBN1d, SoftAlignmentBN2d)
    return [module for _, module in model.named_modules() if isinstance(module, soft_bn_types)]


@torch.jit.script
def softmax_entropy(x, x_ema):
    """Per-row cross-entropy of ``softmax(x)`` against the soft target ``softmax(x_ema)``.

    Class scores are read along dim 1; the result has one value per row.
    """
    target_probs = x_ema.softmax(1)
    student_log_probs = x.log_softmax(1)
    return -(target_probs * student_log_probs).sum(1)


def get_named_submodule(model, sub_name: str):
    """Resolve a dotted attribute path (e.g. ``"layer1.0.bn1"``) starting at ``model``."""
    current = model
    for attr in sub_name.split("."):
        current = getattr(current, attr)
    return current


def set_named_submodule(model, sub_name, value):
    """Assign ``value`` at the dotted attribute path ``sub_name`` under ``model``."""
    *parents, leaf = sub_name.split(".")
    owner = model
    for attr in parents:
        owner = getattr(owner, attr)
    setattr(owner, leaf, value)


# ========= 适配后的 ResiTTA（接口与 C2FTTA 保持一致） =========
class ResiTTA(BaseAdapter):
    """ResiTTA adapter, exposed through the same interface as C2FTTA.

    Keeps the ResiTTA algorithm; only the interface is aligned with C2FTTA:
      - ``__init__(cfg, model, optimizer)``
      - ``forward_and_adapt(self, batch_data, model, optimizer, label)``
    ``forward_and_adapt`` returns the EMA teacher's logits on the current batch (a Tensor).

    Pipeline per sample: the teacher's entropy decides whether the sample enters the
    low-entropy memory bank. Every ``update_frequency`` samples, the student is trained
    on strongly augmented memory samples against the teacher's outputs on the clean
    ones. Then the BN statistics of both models are pulled towards their source
    snapshots, and the teacher is updated by EMA.
    """

    def __init__(self, cfg, model, optimizer):
        super(ResiTTA, self).__init__(cfg, model, optimizer)

        # ====== Read hyper-parameters from cfg (defaults used when a key is missing) ======
        def _get_strict(path_list, default, *, key_name=None):
            """
            Walk ``cfg`` along ``path_list``. If any key is missing, print a
            WARNING and return ``default``.
            """
            cur = cfg
            for key in path_list:
                if not hasattr(cur, key):
                    name = key_name or ".".join(path_list)
                    print(f"[ResiTTA][WARN] cfg 缺少 {name}，使用默认值 {default}")
                    return default
                cur = getattr(cur, key)
            return cur

        self.num_class     = int(_get_strict(["CORRUPTION", "NUM_CLASS"], 1000, key_name="CORRUPTION.NUM_CLASS"))
        self.dataset       = str(_get_strict(["CORRUPTION", "DATASET"], "imagenet", key_name="CORRUPTION.DATASET"))
        self.bn_alpha      = float(_get_strict(["ADAPTER", "RESITTA", "BN_ALPHA"], 0.05))
        self.lambda_bn_d   = float(_get_strict(["ADAPTER", "RESITTA", "LAMBDA_BN_D"], 0.01))
        self.lambda_bn_w   = float(_get_strict(["ADAPTER", "RESITTA", "LAMBDA_BN_W"], 0.0))
        self.lambda_bn_d_ema = self.lambda_bn_d

        self.capacity         = int(_get_strict(["ADAPTER", "RESITTA", "CAPACITY"], 64))
        self.update_frequency = int(_get_strict(["ADAPTER", "RESITTA", "UPDATE_FREQUENCY"], 64))
        self.steps            = int(_get_strict(["ADAPTER", "RESITTA", "STEPS"], 1))
        self.e_margin         = float(_get_strict(["ADAPTER", "RESITTA", "E_MARGIN"], 0.4))
        self.class_balance    = bool(_get_strict(["ADAPTER", "RESITTA", "CLASS_BALANCE"], True))
        self.ema_nu           = float(_get_strict(["ADAPTER", "RESITTA", "EMA_NU"], 1e-3))

        # Log every key hyper-parameter at start-up so the run's settings are visible in the logs.
        print("[ResiTTA][CONFIG] "
              f"dataset={self.dataset} num_class={self.num_class} "
              f"BN_ALPHA={self.bn_alpha} LAMBDA_BN_D={self.lambda_bn_d} LAMBDA_BN_W={self.lambda_bn_w} "
              f"CAPACITY={self.capacity} UPDATE_FREQ={self.update_frequency} STEPS={self.steps} "
              f"E_MARGIN={self.e_margin} CLASS_BALANCE={self.class_balance} EMA_NU={self.ema_nu}")

        # Optimizer cache: ``optimizer`` may be a factory that is instantiated on first use.
        self._opt = None

        # ====== Model and optimizer ======
        self.optimizer = optimizer  # object or factory returned by main.py / build_optimizer

        # ====== Soft-alignment BN hyper-parameters ======
        # configure_model is expected to run inside BaseAdapter.__init__ (NOTE(review):
        # confirm), i.e. before the cfg values above are read. In that case it falls
        # back to its hard-coded defaults. Push the configured values into the BN layers
        # now, before the teacher is deep-copied, so student and teacher both honour cfg.
        # If no soft-alignment layers exist yet, this loop is a no-op.
        for bn in get_batch_norm_modules(self.model):
            bn.momentum = self.bn_alpha
            bn.lambda_bn_d = self.lambda_bn_d
            bn.lambda_bn_w = self.lambda_bn_w

        # ====== EMA teacher and its BN layers ======
        model_ema = deepcopy(self.model)
        for p in model_ema.parameters():
            p.detach_()
        self.model_ema = model_ema
        self.bn_modules_ema = get_batch_norm_modules(self.model_ema)
        for bn in self.bn_modules_ema:
            bn.lambda_bn_d = self.lambda_bn_d_ema

        # ====== Memory bank ======
        # Admission threshold: e_margin times the maximum entropy log(num_class).
        threshold = self.e_margin * math.log(max(self.num_class, 2))
        self.mem = LowEntropyMemoryBankV2(capacity=self.capacity,
                                          num_class=self.num_class,
                                          threshold=threshold,
                                          class_balance=self.class_balance)

        # ====== Strong TTA augmentation ======
        self.transform = get_tta_transforms(dataset=self.dataset)

        # Counters
        self.update_frequency = max(1, self.update_frequency)
        self.num_instance = 0

        # AMP scaler (enabled when CUDA is available)
        self.scaler = GradScaler(enabled=torch.cuda.is_available())

    def configure_model(self, model: nn.Module):
        """
        Replace every BatchNorm1d/2d in ``model`` with its soft-alignment version,
        make those layers trainable (everything else frozen), record them in
        ``self.bn_modules``, and return ``model``.
        """
        # Fallbacks: this may run before __init__ has read cfg.
        _bn_alpha = getattr(self, "bn_alpha", 0.05)
        _lambda_bn_d = getattr(self, "lambda_bn_d", 0.01)
        _lambda_bn_w = getattr(self, "lambda_bn_w", 0.0)

        model.requires_grad_(False)
        normlayer_names = []
        for name, sub_module in model.named_modules():
            if isinstance(sub_module, nn.BatchNorm1d) or isinstance(sub_module, nn.BatchNorm2d):
                normlayer_names.append(name)

        self.bn_modules = []
        for name in normlayer_names:
            bn_layer = get_named_submodule(model, name)
            if isinstance(bn_layer, nn.BatchNorm1d):
                NewBN = SoftAlignmentBN1d
            elif isinstance(bn_layer, nn.BatchNorm2d):
                NewBN = SoftAlignmentBN2d
            else:
                raise RuntimeError()

            momentum_bn = NewBN(bn_layer, _bn_alpha, _lambda_bn_d, _lambda_bn_w)
            momentum_bn.requires_grad_(True)
            set_named_submodule(model, name, momentum_bn)
            self.bn_modules.append(momentum_bn)

        self.bn_module_names = normlayer_names
        return model

    def _ensure_optimizer(self):
        """
        Turn ``self.optimizer`` into an instantiated optimizer exposing ``zero_grad`` and
        cache it in ``self._opt``. ``self.optimizer`` may be:
          - an already-instantiated optimizer (has ``zero_grad``);
          - a factory: ``callable(params) -> optimizer``.

        Raises:
            TypeError: if neither form yields an object with ``zero_grad``.
        """
        if self._opt is not None and hasattr(self._opt, "zero_grad"):
            return self._opt

        opt = self.optimizer
        # Already an optimizer object
        if hasattr(opt, "zero_grad"):
            self._opt = opt
            return self._opt

        # Factory: instantiate with the model's parameters
        if callable(opt):
            self._opt = opt(self.model.parameters())
            if hasattr(self._opt, "zero_grad"):
                return self._opt

        raise TypeError(f"optimizer 无法实例化：当前类型 {type(self.optimizer)}")

    @staticmethod
    def update_ema_variables(ema_model, model, nu):
        """In-place EMA ``ema = (1 - nu) * ema + nu * student`` for every parameter pair.

        Parameters are paired positionally, which relies on ``ema_model`` being a deep
        copy of ``model``. Returns ``ema_model``.
        """
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            # In-place on .data (outside autograd). Unlike ``tensor[:]`` slicing, this
            # also works for 0-dim parameters.
            ema_param.data.mul_(1 - nu).add_(param.data, alpha=nu)
        return ema_model

    @torch.enable_grad()
    def forward_and_adapt(self, batch_data, model, optimizer, label):
        """
        Same interface as C2FTTA; main.py calls ``tta_model(data, label={...})``.
        Runs ``self.steps`` adaptation passes on ``self.model`` and returns the teacher's
        logits on the current batch (Tensor) from the last pass.
        """
        # The passed-in optimizer is ignored (kept for BaseAdapter signature compatibility);
        # the optimizer is managed internally.
        _ = optimizer

        # Make sure the optimizer is instantiated (handles factories).
        self._ensure_optimizer()

        x = batch_data  # NCHW
        outputs = None
        for _ in range(self.steps):
            outputs = self._forward_and_adapt_once(x, self.model)
        return outputs  # teacher logits on the current batch

    def _forward_and_adapt_once(self, batch_data, model):
        # 1) Teacher pseudo-labels and entropies (no gradients).
        with torch.no_grad():
            model.eval()
            self.model_ema.eval()
            ema_out = self.model_ema(batch_data)
            predict = torch.softmax(ema_out, dim=1)
            pseudo_label = torch.argmax(predict, dim=1)
            entropy = torch.sum(- predict * torch.log(predict + 1e-6), dim=1)

        # Move labels/entropies to Python once, instead of one .item() device sync per sample.
        pseudo_labels = pseudo_label.tolist()
        entropies = entropy.tolist()

        # 2) Feed the memory bank sample by sample; trigger an update every update_frequency samples.
        for data, p_l, uncert in zip(batch_data, pseudo_labels, entropies):
            # Store a detached copy: a plain row view would keep the whole input batch's
            # storage alive for as long as the sample stays in the bank.
            current_instance = (data.detach().clone(), p_l, uncert)
            self.mem.add_instance(current_instance)
            self.num_instance += 1

            if self.num_instance % self.update_frequency == 0:
                self.update_model(model)

        # Teacher logits on the current batch (used by the caller to compute accuracy).
        return ema_out

    def update_model(self, model):
        """One student update on the memory bank, followed by BN pull-back and the EMA teacher update."""
        model.train()
        self.model_ema.train()
        device = next(model.parameters()).device

        # Support data from the memory bank.
        sup_data, _ = self.mem.get_memory()
        if len(sup_data) == 0:
            return

        # Instantiated optimizer.
        optimizer = self._ensure_optimizer()

        # Old-style autocast (no device_type argument) for compatibility.
        with autocast(enabled=torch.cuda.is_available(), dtype=torch.float16):
            sup_data = torch.stack(sup_data).to(device)           # (N, C, H, W)

            # Strong augmentation applied per sample (compatible with older torchvision).
            strong_sup_aug = torch.stack([self.transform(img) for img in sup_data]).to(device)

            # Teacher targets need no graph. Its train-mode forward still updates its BN target stats.
            with torch.no_grad():
                ema_sup_out = self.model_ema(sup_data)
            stu_sup_out = model(strong_sup_aug)

            # Consistency loss with the teacher's soft distribution as the target.
            l_sup = (softmax_entropy(stu_sup_out, ema_sup_out)).mean()

            # Soft alignment of BN affine weights to their source values.
            if self.lambda_bn_w > 0:
                l_soft_alignment = []
                for bn_module in self.bn_modules:
                    l_soft_alignment.append(bn_module.get_soft_alignment_loss_weight())
                l_soft_alignment = torch.stack(l_soft_alignment).sum()
            else:
                l_soft_alignment = torch.tensor(0.0, device=device)

            loss = l_sup + l_soft_alignment * self.lambda_bn_w

        optimizer.zero_grad()
        self.scaler.scale(loss).backward()
        self.scaler.step(optimizer)
        self.scaler.update()

        # Pull BN statistics back towards the source (Wasserstein soft alignment).
        for bn_module in self.bn_modules:
            bn_module.regularize_statistics()
        for bn_module in self.bn_modules_ema:
            bn_module.regularize_statistics()

        # EMA update of the teacher.
        self.update_ema_variables(self.model_ema, self.model, self.ema_nu)
