# -*- coding: utf-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.init as init
import model_utils

# -----------------------------
# Adaptive Stable Loss (A/B/C)
# -----------------------------
class AdaptiveStableLoss(nn.Module):
    """
    Adaptive "stable loss": a Huber penalty on how far the current scalar batch
    loss deviates from its running average, scaled by an adaptive weight.

    State (registered buffers, so they move with ``.to()`` and are saved in
    ``state_dict()``):
      l_ema         EMA of the batch loss (rate ``alpha``)
      sigma_ema     EMA of |loss - l_ema| (rate ``beta``), a volatility estimate
      sigma_ref     reference volatility; 0 until warmup completes
      warmup_count  number of forward calls seen
      last_lambda / last_delta  diagnostics from the most recent call

    Each call uses the weight
        lambda_t = clamp(lambda_base * sigma_ema / (denom + eps), lambda_min, lambda_max)
    where ``denom`` is ``sigma_ref`` once it is set, else ``sigma_ema``.

    A: ``sigma_ref`` is snapshotted from ``sigma_ema`` once ``warmup_count``
       reaches ``warmup_steps``. With ``sl_use_running_ref=True`` it then
       tracks ``sigma_ema`` at rate ``sl_ref_beta``.
    B: when ``sl_excess_gate > 1.0`` and ``sigma_ref`` is set, lambda is zero
       unless ``sigma_ema > sl_excess_gate * sigma_ref``.
    C: ``sl_delta_mode='frac'`` uses ``delta_frac * sigma_ref`` (or
       ``sigma_ema`` before warmup) as the Huber threshold. Any other mode
       (default ``'abs'``) uses ``delta_frac`` directly. The threshold is
       floored at 1e-6 in both cases.
    """
    def __init__(self,
                 alpha=0.10,
                 beta=0.10,
                 delta_frac=0.20,
                 lambda_base=0.50,
                 lambda_min=0.0,
                 lambda_max=2.0,
                 warmup_steps=200,
                 eps=1e-8,
                 sl_use_running_ref=False,
                 sl_ref_beta=0.01,
                 sl_excess_gate=1.0,
                 sl_delta_mode="abs"):
        super().__init__()
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.delta_frac = float(delta_frac)
        self.lambda_base = float(lambda_base)
        self.lambda_min = float(lambda_min)
        self.lambda_max = float(lambda_max)
        self.warmup_steps = int(warmup_steps)
        self.eps = float(eps)
        self.sl_use_running_ref = bool(sl_use_running_ref)
        self.sl_ref_beta = float(sl_ref_beta)
        self.sl_excess_gate = float(sl_excess_gate)
        self.sl_delta_mode = str(sl_delta_mode).lower()

        self.register_buffer("l_ema", torch.tensor(0.0))
        self.register_buffer("sigma_ema", torch.tensor(0.0))
        self.register_buffer("sigma_ref", torch.tensor(0.0))
        self.register_buffer("warmup_count", torch.tensor(0))
        self.register_buffer("last_lambda", torch.tensor(0.0))
        self.register_buffer("last_delta", torch.tensor(0.0))
        # Python-side flag; not part of state_dict (see _init_if_needed).
        self._initialized = False

    @torch.no_grad()
    def reset_state(self):
        """Zero all running statistics so the next forward re-initialises them."""
        self.l_ema.zero_(); self.sigma_ema.zero_(); self.sigma_ref.zero_()
        self.warmup_count.zero_()
        self.last_lambda.zero_(); self.last_delta.zero_()
        self._initialized = False

    @staticmethod
    def _huber(delta, thresh):
        """Huber-style penalty: 0.5*d^2/thresh inside |d| <= thresh, |d| - 0.5*thresh outside.

        Dividing the quadratic part by ``thresh`` makes both pieces meet with
        value and slope continuity at |d| == thresh.
        """
        absd = delta.abs()
        quad = 0.5 * (delta * delta) / max(thresh, 1e-12)
        lin = absd - 0.5 * thresh
        return torch.where(absd <= thresh, quad, lin)

    @torch.no_grad()
    def _init_if_needed(self, loss_value: torch.Tensor):
        """Seed ``l_ema`` with the first observed loss and zero the other buffers.

        ``_initialized`` is not saved in ``state_dict``. A non-zero
        ``warmup_count`` means the buffers were restored (e.g. via
        ``load_state_dict``), so they are kept rather than wiped.
        """
        if self._initialized:
            return
        if int(self.warmup_count.item()) > 0:
            self._initialized = True
            return
        v = loss_value.detach().to(self.l_ema.device)
        self.l_ema.fill_(v)
        self.sigma_ema.zero_()
        self.sigma_ref.zero_()
        self.warmup_count.zero_()
        self.last_lambda.zero_()
        self.last_delta.zero_()
        self._initialized = True

    def forward(self, base_loss: torch.Tensor) -> torch.Tensor:
        """Return ``lambda_t * huber(base_loss - l_ema)`` and update the running statistics.

        Gradients flow only through ``base_loss``. Lambda and all buffers are
        computed from detached statistics.
        """
        assert base_loss.dim() == 0, "AdaptiveStableLoss expects a scalar batch loss"
        self._init_if_needed(base_loss)

        l_prev = self.l_ema.detach()
        delta_t = base_loss - l_prev

        # (C) Huber threshold, read from the volatility state *before* this step's update.
        if self.sl_delta_mode == "frac":
            ref_now = float(self.sigma_ref.item())
            scale = ref_now if ref_now > 0.0 else float(self.sigma_ema.item())
            delta = max(scale * self.delta_frac, 1e-6)
        else:
            delta = max(self.delta_frac, 1e-6)

        huber_t = self._huber(delta_t, delta)

        # Update volatility and (A) the reference volatility.
        with torch.no_grad():
            self.sigma_ema.mul_(1.0 - self.beta).add_(self.beta * delta_t.detach().abs())
            self.warmup_count.add_(1)
            ref_set = self.sigma_ref.item() > 0.0
            # '>=' (not '==') so the snapshot still happens if warmup_steps is
            # lowered below the current count or set to 0.
            if not ref_set and int(self.warmup_count.item()) >= self.warmup_steps:
                self.sigma_ref.copy_(torch.clamp(self.sigma_ema, min=self.eps))
                ref_set = self.sigma_ref.item() > 0.0
            if self.sl_use_running_ref and ref_set:
                self.sigma_ref.mul_(1.0 - self.sl_ref_beta).add_(self.sl_ref_beta * self.sigma_ema)

        # Single host read; .item() forces a device sync on accelerators.
        ref_ready = self.sigma_ref.item() > 0.0
        denom = self.sigma_ref if ref_ready else self.sigma_ema
        gain = self.sigma_ema / (denom + self.eps)
        if self.sl_excess_gate > 1.0 and ref_ready:
            # (B) only penalise while volatility exceeds gate * reference.
            gain = gain * (self.sigma_ema > self.sl_excess_gate * denom).float()

        lam_t = torch.clamp(self.lambda_base * gain, self.lambda_min, self.lambda_max)
        sl_t = lam_t * huber_t

        with torch.no_grad():
            self.l_ema.mul_(1.0 - self.alpha).add_(self.alpha * base_loss.detach())
            self.last_lambda.copy_(lam_t)
            self.last_delta.copy_(delta_t.detach())

        return sl_t


# -----------------------------
# Adaptive, label-aware VPL
# -----------------------------
class AdaptiveLabelVariancePenalty(nn.Module):
    """
    Adaptive, label-aware variance penalty on classifier logits.

    For every class with at least two samples in the batch, measure the
    within-class spread of the logits and average over those classes to get
    ``v_batch``. The module returns ``scale * lambda * v_batch``. If no class
    has two samples, ``v_batch`` is 0.

    stat:
      - 'vector': per class, population variance of each logit dimension,
                  averaged over dimensions; then averaged over classes
      - 'true'  : per class, population variance of the true-class logit;
                  then averaged over classes

    The weight is
        lambda = clamp(lambda_base * v_ema / (denom + eps) * ent_factor, lambda_min, lambda_max)
    where ``v_ema`` is an EMA of ``v_batch`` (rate ``alpha``). ``denom`` is
    ``v_ref`` (a snapshot of ``v_ema`` taken once ``warmup_steps`` calls have
    run) when set, else ``v_ema``. ``ent_factor`` is
    ``1 + entropy_scale * H(softmax(logits)) / log(max(C, 2))`` with
    ``use_entropy=True``, else 1.

    Expects ``logits`` of shape (B, C). ``labels`` are compared against
    class indices 0..C-1; other label values are ignored.
    """
    def __init__(self,
                 lambda_base: float = 1.0,
                 scale: float = 1.0,
                 alpha: float = 0.10,
                 warmup_steps: int = 100,
                 lambda_min: float = 0.0,
                 lambda_max: float = 2.0,
                 use_entropy: bool = False,
                 entropy_scale: float = 0.5,
                 eps: float = 1e-8,
                 stat: str = "vector"):
        super().__init__()
        self.lambda_base = float(lambda_base)
        self.scale = float(scale)
        self.alpha = float(alpha)
        self.warmup_steps = int(warmup_steps)
        self.lambda_min = float(lambda_min)
        self.lambda_max = float(lambda_max)
        self.use_entropy = bool(use_entropy)
        self.entropy_scale = float(entropy_scale)
        self.eps = float(eps)
        self.stat = str(stat).lower()
        assert self.stat in ("vector", "true")

        self.register_buffer("v_ema", torch.tensor(0.0))
        self.register_buffer("v_ref", torch.tensor(0.0))
        self.register_buffer("warmup_count", torch.tensor(0))
        self.register_buffer("last_lambda", torch.tensor(0.0))
        self.register_buffer("last_vbatch", torch.tensor(0.0))
        # Python-side flag; not part of state_dict (see _init_if_needed).
        self._initialized = False

    @torch.no_grad()
    def reset_state(self):
        """Zero all running statistics so the next forward re-initialises them."""
        self.v_ema.zero_(); self.v_ref.zero_()
        self.warmup_count.zero_()
        self.last_lambda.zero_(); self.last_vbatch.zero_()
        self._initialized = False

    @torch.no_grad()
    def _init_if_needed(self, v0: torch.Tensor):
        """Seed ``v_ema``/``last_vbatch`` with the first batch statistic.

        ``_initialized`` is not saved in ``state_dict``. A non-zero
        ``warmup_count`` means the buffers were restored (e.g. via
        ``load_state_dict``), so they are kept rather than wiped.
        """
        if self._initialized:
            return
        if int(self.warmup_count.item()) > 0:
            self._initialized = True
            return
        self.v_ema.fill_(v0)
        self.v_ref.zero_()
        self.warmup_count.zero_()
        self.last_lambda.zero_()
        self.last_vbatch.copy_(v0)
        self._initialized = True

    def _v_batch_vector(self, logits, labels):
        """Mean over classes (>= 2 samples) of the dimension-averaged population variance."""
        num_classes = logits.shape[1]
        vals = []
        for c in range(num_classes):
            idx = (labels == c).nonzero(as_tuple=True)[0]
            if idx.numel() >= 2:
                vals.append(logits.index_select(0, idx).var(dim=0, unbiased=False).mean())
        return torch.stack(vals).mean() if vals else logits.new_tensor(0.0)

    def _v_batch_true(self, logits, labels):
        """Mean over classes (>= 2 samples) of the population variance of the true-class logit."""
        num_classes = logits.shape[1]
        vals = []
        for c in range(num_classes):
            idx = (labels == c).nonzero(as_tuple=True)[0]
            if idx.numel() >= 2:
                vals.append(logits.index_select(0, idx)[:, c].var(unbiased=False))
        return torch.stack(vals).mean() if vals else logits.new_tensor(0.0)

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return ``scale * lambda * v_batch`` and update the running statistics.

        Gradients flow only through ``v_batch``. Lambda is built from detached
        buffers and a no-grad entropy term.
        """
        assert logits.dim() == 2
        if self.stat == "vector":
            v_batch = self._v_batch_vector(logits, labels)
        else:
            v_batch = self._v_batch_true(logits, labels)
        self._init_if_needed(v_batch.detach())

        with torch.no_grad():
            v_det = v_batch.detach()
            self.v_ema.mul_(1.0 - self.alpha).add_(self.alpha * v_det)
            self.warmup_count.add_(1)
            if int(self.warmup_count.item()) >= self.warmup_steps and self.v_ref.item() == 0.0:
                # Clamp so the reference is strictly positive even if v_ema is 0.
                self.v_ref.copy_(torch.clamp(self.v_ema, min=self.eps))
            self.last_vbatch.copy_(v_det)

        denom = self.v_ref if self.v_ref.item() > 0.0 else self.v_ema
        lam = self.lambda_base * (self.v_ema / (denom + self.eps))

        if self.use_entropy:
            with torch.no_grad():
                p = torch.softmax(logits, dim=1)
                ent = -(p * p.clamp_min(self.eps).log()).sum(dim=1).mean()
                # Normalise by the maximum entropy log(C) (C >= 2 guarded).
                ent_norm = ent / max(math.log(max(logits.shape[1], 2)), self.eps)
            lam = lam * (1.0 + self.entropy_scale * ent_norm)

        lam = torch.clamp(lam, self.lambda_min, self.lambda_max)
        with torch.no_grad():
            self.last_lambda.copy_(lam)
        return self.scale * lam * v_batch


# -----------------------------
# VGG backbone with SL/VPL
# -----------------------------
# Layer plans per VGG depth. Each stage is (conv_width, conv_count): that many
# 3x3 convs of the given width, followed by a 'M' (2x2 max-pool) marker.
cfg = {
    name: [layer for width, reps in stages for layer in [width] * reps + ['M']]
    for name, stages in (
        ('VGG11', ((64, 1), (128, 1), (256, 2), (512, 2), (512, 2))),
        ('VGG13', ((64, 2), (128, 2), (256, 2), (512, 2), (512, 2))),
        ('VGG16', ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))),
        ('VGG19', ((64, 2), (128, 2), (256, 4), (512, 4), (512, 4))),
    )
}

class VGG(nn.Module):
    """
    CIFAR-style VGG backbone (conv/BN/ReLU stacks + one linear classifier) that
    also owns the adaptive stable-loss (SL) and label-variance-penalty (VPL)
    modules. ``forward`` does not call SL/VPL; presumably the training loop
    invokes ``adaptive_stable_loss`` / ``variance_penalty`` itself (confirm
    against the trainer).

    The classifier takes 512 flattened features. Each cfg plan has five
    max-pools, so the spatial size must reach 1x1 (e.g. 32x32 inputs).
    NOTE(review): confirm the input resolution with the data pipeline.

    Args:
        vgg_name: key into ``cfg`` ('VGG11', 'VGG13', 'VGG16', 'VGG19').
        init_strategy: conv init: 'he', 'xavier', 'custom_uniform',
            'custom_xavier', 'custom_kaiming'. The custom_xavier/custom_kaiming
            variants also clip weights to +/-0.0085. Any other value leaves
            PyTorch's default conv init.
        num_classes: classifier output size.
        stable_weight: fallback SL ``lambda_base`` when ``sl_lambda_base`` is
            None or 0.0 (0.50 if this is also None/0).
        vpl_weight_decay: VPL ``scale``.
        vpl_weight: stored as ``self.vpl_weight`` for the caller.
        sl_*: forwarded to ``AdaptiveStableLoss``.
        vpl_stat: VPL statistic ('vector' or 'true').
    """
    # Clip bound applied by the 'custom_*' conv init strategies.
    _CUSTOM_CLIP = 0.0085

    def __init__(self,
                 vgg_name: str,
                 init_strategy='he',
                 num_classes=10,
                 # VPL/SL knobs
                 stable_weight=0.0,
                 vpl_weight_decay=0.1,
                 vpl_weight=0.1,
                 # SL defaults
                 sl_alpha=0.10, sl_beta=0.10, sl_delta=0.20,
                 sl_lambda_base=None, sl_lambda_min=0.0, sl_lambda_max=2.0,
                 sl_warmup_steps=200, sl_eps=1e-8,
                 # A/B/C extras (defaults off)
                 sl_use_running_ref=False, sl_ref_beta=0.01,
                 sl_excess_gate=1.0, sl_delta_mode="abs",
                 # VPL stat
                 vpl_stat: str = "vector"):
        super().__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, num_classes)
        self.init_strategy = init_strategy
        self._initialize_weights()

        # --- SL ---
        # A missing or 0.0 sl_lambda_base falls back to stable_weight, then 0.50.
        if sl_lambda_base is None or (isinstance(sl_lambda_base, float) and sl_lambda_base == 0.0):
            sl_lambda_base = stable_weight if (stable_weight is not None and stable_weight != 0.0) else 0.50
        self.adaptive_stable_loss = AdaptiveStableLoss(
            alpha=sl_alpha, beta=sl_beta, delta_frac=sl_delta,
            lambda_base=sl_lambda_base, lambda_min=sl_lambda_min, lambda_max=sl_lambda_max,
            warmup_steps=sl_warmup_steps, eps=sl_eps,
            sl_use_running_ref=sl_use_running_ref, sl_ref_beta=sl_ref_beta,
            sl_excess_gate=sl_excess_gate, sl_delta_mode=sl_delta_mode
        )
        # Alias to the same module object (registered under both names).
        self.stable_mod = self.adaptive_stable_loss

        # --- VPL ---
        self.variance_penalty = AdaptiveLabelVariancePenalty(
            lambda_base=1.0, scale=float(vpl_weight_decay),
            alpha=0.10, warmup_steps=100, lambda_min=0.0, lambda_max=2.0,
            use_entropy=False, entropy_scale=0.5, eps=1e-8, stat=vpl_stat
        )
        self.vpl_weight = float(vpl_weight)

    def forward(self, x):
        """Return classifier logits for input images ``x`` (no SL/VPL applied)."""
        out = self.features(x)
        out = torch.flatten(out, 1)
        return self.classifier(out)

    def _make_layers(self, cfg_list):
        """Build the feature stack: int -> Conv3x3(no bias)+BN+ReLU, 'M' -> MaxPool2d(2)."""
        layers, in_c = [], 3
        for x in cfg_list:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_c, x, 3, padding=1, bias=False),
                           nn.BatchNorm2d(x), nn.ReLU(inplace=True)]
                in_c = x
        # kernel_size=1/stride=1 pooling is an identity op; kept so the
        # module layout (and state_dict indices) stays unchanged.
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Initialise conv weights per ``init_strategy``; BN to (1, 0); Linear to N(0, 0.01) with 0 bias."""
        clip = self._CUSTOM_CLIP
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                strategy = self.init_strategy
                if strategy in ('he', 'custom_kaiming'):
                    init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif strategy in ('xavier', 'custom_xavier'):
                    init.xavier_normal_(m.weight)
                elif strategy == 'custom_uniform':
                    init.uniform_(m.weight, -clip, clip)
                if strategy in ('custom_xavier', 'custom_kaiming'):
                    # Clip in place without recording the op in autograd.
                    with torch.no_grad():
                        m.weight.clamp_(-clip, clip)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1); init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, 0, 0.01); init.constant_(m.bias, 0)

    def configure_adaptive_stable(self, alpha=None, beta=None, delta_frac=None,
                                  lambda_base=None, lambda_min=None, lambda_max=None,
                                  warmup_steps=None, use_running_ref=None, eps=None,
                                  sl_ref_beta=None, sl_excess_gate=None, sl_delta_mode=None):
        """Override selected SL hyper-parameters in place; ``None`` leaves a value unchanged.

        Only the hyper-parameter attributes are touched; running buffers are kept.
        Returns ``self`` for chaining.
        """
        mod = self.adaptive_stable_loss
        if alpha is not None: mod.alpha = float(alpha)
        if beta is not None: mod.beta = float(beta)
        if eps is not None: mod.eps = float(eps)
        if lambda_base is not None: mod.lambda_base = float(lambda_base)
        if lambda_min is not None: mod.lambda_min = float(lambda_min)
        if lambda_max is not None: mod.lambda_max = float(lambda_max)
        if warmup_steps is not None: mod.warmup_steps = int(warmup_steps)
        if delta_frac is not None: mod.delta_frac = float(delta_frac)
        if use_running_ref is not None: mod.sl_use_running_ref = bool(use_running_ref)
        if sl_ref_beta is not None: mod.sl_ref_beta = float(sl_ref_beta)
        if sl_excess_gate is not None: mod.sl_excess_gate = float(sl_excess_gate)
        if sl_delta_mode is not None: mod.sl_delta_mode = str(sl_delta_mode).lower()
        return self


# Factories (trainer expects these names)
def _build_vgg(vgg_name, flags, **kwargs):
    """Construct ``VGG(vgg_name, **kwargs)`` between ``model_utils`` RNG-state set/restore calls.

    ``model_utils.restore_rng_state`` is called even if construction raises.
    NOTE(review): the exact meaning of ``flags`` is defined by
    ``model_utils.set_rng_state``; confirm there.
    """
    old = model_utils.set_rng_state(flags)
    try:
        return VGG(vgg_name, **kwargs)
    finally:
        model_utils.restore_rng_state(old)


def vgg11(flags=None, **kwargs):
    """Build a VGG-11 model; ``kwargs`` are forwarded to ``VGG``."""
    return _build_vgg('VGG11', flags, **kwargs)


def vgg13(flags=None, **kwargs):
    """Build a VGG-13 model; ``kwargs`` are forwarded to ``VGG``."""
    return _build_vgg('VGG13', flags, **kwargs)


def vgg16(flags=None, **kwargs):
    """Build a VGG-16 model; ``kwargs`` are forwarded to ``VGG``."""
    return _build_vgg('VGG16', flags, **kwargs)


def vgg19(flags=None, **kwargs):
    """Build a VGG-19 model; ``kwargs`` are forwarded to ``VGG``."""
    return _build_vgg('VGG19', flags, **kwargs)
