# -*- coding: utf-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import model_utils


# =========================
# Adaptive Stable Loss (A/B/C)
# =========================
class AdaptiveStableLoss(nn.Module):
    """
    Batch-scalar base loss -> EMA baseline + volatility tracking + Huber.

    Each call compares the scalar batch loss with an EMA baseline of previous
    losses (no gradient through the baseline), passes the deviation through a
    threshold-normalised Huber function and scales it by an adaptive gain
    lambda = clamp(lambda_base * sigma_ema / sigma_ref, lambda_min, lambda_max),
    where sigma_ema is an EMA of |deviation| and sigma_ref is a reference
    volatility latched once `warmup_steps` calls have been made.

    Supports:
      - running sigma_ref (sl_use_running_ref, sl_ref_beta)
      - excess gate (sl_excess_gate > 1.0): gain is zeroed unless
        sigma_ema > sl_excess_gate * sigma_ref
      - delta mode 'abs' or 'frac' (fraction of volatility)

    Controller state lives in registered buffers, so it is saved with
    `state_dict()`. A controller restored with `load_state_dict()` (non-zero
    `warmup_count`) is kept as-is by the next forward instead of re-seeded.
    """
    def __init__(
        self,
        alpha=0.10,
        beta=0.10,
        delta_frac=0.20,
        lambda_base=0.50,
        lambda_min=0.0,
        lambda_max=2.0,
        warmup_steps=200,
        eps=1e-8,
        sl_use_running_ref=False,
        sl_ref_beta=0.01,
        sl_excess_gate=1.0,
        sl_delta_mode="abs",
    ):
        super().__init__()
        self.alpha = float(alpha)            # EMA rate of the loss baseline
        self.beta = float(beta)              # EMA rate of the volatility sigma_ema
        self.delta_frac = float(delta_frac)  # Huber threshold (abs) or its fraction of volatility (frac)
        self.lambda_base = float(lambda_base)
        self.lambda_min = float(lambda_min)
        self.lambda_max = float(lambda_max)
        self.warmup_steps = int(warmup_steps)
        self.eps = float(eps)

        self.sl_use_running_ref = bool(sl_use_running_ref)
        self.sl_ref_beta = float(sl_ref_beta)
        self.sl_excess_gate = float(sl_excess_gate)
        self.sl_delta_mode = str(sl_delta_mode).lower()

        # controller state
        self.register_buffer("l_ema", torch.tensor(0.0))
        self.register_buffer("sigma_ema", torch.tensor(0.0))
        self.register_buffer("sigma_ref", torch.tensor(0.0))
        self.register_buffer("warmup_count", torch.tensor(0))
        # telemetry
        self.register_buffer("last_lambda", torch.tensor(0.0))
        self.register_buffer("last_delta", torch.tensor(0.0))

        self._initialized = False

    @staticmethod
    def _huber(delta: torch.Tensor, thresh: float) -> torch.Tensor:
        """Huber of `delta` divided by `thresh`: 0.5*d^2/t if |d| <= t else |d| - 0.5*t."""
        # quadratic for |delta|<=thresh, linear otherwise; continuous at thresh
        thresh = max(float(thresh), 1e-12)
        absd = delta.abs()
        quad = 0.5 * (delta * delta) / thresh
        lin = absd - 0.5 * thresh
        return torch.where(absd <= thresh, quad, lin)

    @torch.no_grad()
    def reset_state(self):
        """Zero all controller state; the next forward re-seeds the baseline."""
        self.l_ema.zero_()
        self.sigma_ema.zero_()
        self.sigma_ref.zero_()
        self.warmup_count.zero_()
        self.last_lambda.zero_()
        self.last_delta.zero_()
        self._initialized = False

    @torch.no_grad()
    def _init_if_needed(self, loss_value: torch.Tensor):
        """Seed the baseline with the first observed loss, unless state was restored."""
        if self._initialized:
            return
        if int(self.warmup_count.item()) > 0:
            # Buffers already hold controller state (e.g. restored through
            # load_state_dict): keep it instead of wiping the controller.
            self._initialized = True
            return
        v = loss_value.detach().to(self.l_ema.device)
        self.l_ema.fill_(v)
        self.sigma_ema.zero_()
        self.sigma_ref.zero_()
        self.warmup_count.zero_()
        self.last_lambda.zero_()
        self.last_delta.zero_()
        self._initialized = True

    def forward(self, base_loss: torch.Tensor) -> torch.Tensor:
        """Return lambda_t * Huber(base_loss - EMA baseline) and update controller state.

        Args:
            base_loss: 0-dim tensor, the batch loss (gradient flows through it).

        Returns:
            0-dim tensor; lambda_t itself carries no gradient (built from buffers).
        """
        assert base_loss.dim() == 0, "AdaptiveStableLoss expects a scalar batch loss"
        self._init_if_needed(base_loss)

        # deviation w.r.t EMA baseline (no grad through baseline)
        l_prev = self.l_ema.detach()
        delta_t = base_loss - l_prev

        # Huber threshold: absolute vs fraction of volatility
        if self.sl_delta_mode == "frac":
            scale_src = self.sigma_ref if self.sigma_ref.item() > 0.0 else self.sigma_ema
            delta = max(float(scale_src.item()) * self.delta_frac, 1e-6)
        else:
            delta = max(self.delta_frac, 1e-6)

        huber_t = self._huber(delta_t, delta)

        # update volatility stats (no grad)
        with torch.no_grad():
            self.sigma_ema.mul_(1.0 - self.beta).add_(self.beta * delta_t.detach().abs())
            self.warmup_count.add_(1)

            # One-time latch of sigma_ref once warmup is reached. `>=` rather than
            # `==` so the latch still fires if warmup_steps is lowered mid-run
            # (e.g. via runtime reconfiguration) or is <= 0.
            if int(self.warmup_count.item()) >= self.warmup_steps and self.sigma_ref.item() == 0.0:
                self.sigma_ref.copy_(torch.clamp(self.sigma_ema, min=self.eps))

            # optional running reference after latch
            if self.sl_use_running_ref and self.sigma_ref.item() > 0.0:
                self.sigma_ref.mul_(1.0 - self.sl_ref_beta).add_(self.sl_ref_beta * self.sigma_ema)

        # adaptive gain (before the latch, sigma_ema is its own reference -> gain ~ 1)
        ref_ready = self.sigma_ref.item() > 0.0
        denom = self.sigma_ref if ref_ready else self.sigma_ema
        if self.sl_excess_gate > 1.0 and ref_ready:
            active = (self.sigma_ema > self.sl_excess_gate * denom).float()
            gain = (self.sigma_ema / (denom + self.eps)) * active
        else:
            gain = self.sigma_ema / (denom + self.eps)

        lam_t = torch.clamp(self.lambda_base * gain, self.lambda_min, self.lambda_max)

        sl_t = lam_t * huber_t

        # update baseline (no grad) + telemetry
        with torch.no_grad():
            self.l_ema.mul_(1.0 - self.alpha).add_(self.alpha * base_loss.detach())
            self.last_lambda.copy_(lam_t)
            self.last_delta.copy_(delta_t.detach())

        return sl_t


# =========================
# Adaptive Label-aware VPL
# =========================
class AdaptiveLabelVariancePenalty(nn.Module):
    """
    Label-aware within-class logit variance penalty with an adaptive weight.

    Statistic selectable via `stat`:
      - 'vector': per-class variance of full logit vector; mean across classes
      - 'true'  : per-class variance of the true-class logit; mean across classes
    Only classes with at least two samples in the batch contribute; if none
    do, the statistic is 0.

    Controller: EMA(v_batch) with warmup latch to v_ref; adaptive lambda in
    [lambda_min, lambda_max]. The returned penalty is scale * lambda * v_batch.

    Controller state lives in registered buffers, so it is saved with
    `state_dict()`. A controller restored with `load_state_dict()` (non-zero
    `warmup_count`) is kept as-is by the next forward instead of re-seeded.
    """
    def __init__(
        self,
        lambda_base: float = 1.0,
        scale: float = 1.0,
        alpha: float = 0.10,
        warmup_steps: int = 100,
        lambda_min: float = 0.0,
        lambda_max: float = 2.0,
        use_entropy: bool = False,
        entropy_scale: float = 0.5,
        eps: float = 1e-8,
        stat: str = "vector",
    ):
        super().__init__()
        self.lambda_base = float(lambda_base)
        self.scale = float(scale)
        self.alpha = float(alpha)
        self.warmup_steps = int(warmup_steps)
        self.lambda_min = float(lambda_min)
        self.lambda_max = float(lambda_max)
        self.use_entropy = bool(use_entropy)
        self.entropy_scale = float(entropy_scale)
        self.eps = float(eps)
        self.stat = str(stat).lower()
        if self.stat not in ("vector", "true"):
            raise ValueError(f"stat must be 'vector' or 'true', got {stat}")

        # controller state
        self.register_buffer("v_ema", torch.tensor(0.0))
        self.register_buffer("v_ref", torch.tensor(0.0))
        self.register_buffer("warmup_count", torch.tensor(0))
        # telemetry
        self.register_buffer("last_lambda", torch.tensor(0.0))
        self.register_buffer("last_vbatch", torch.tensor(0.0))

        self._initialized = False

    @torch.no_grad()
    def reset_state(self):
        """Zero all controller state; the next forward re-seeds the EMA."""
        self.v_ema.zero_()
        self.v_ref.zero_()
        self.warmup_count.zero_()
        self.last_lambda.zero_()
        self.last_vbatch.zero_()
        self._initialized = False

    @torch.no_grad()
    def _init_if_needed(self, v0: torch.Tensor):
        """Seed the EMA with the first batch statistic, unless state was restored."""
        if self._initialized:
            return
        if int(self.warmup_count.item()) > 0:
            # Buffers already hold controller state (e.g. restored through
            # load_state_dict): keep it instead of wiping the controller.
            self._initialized = True
            return
        v = v0.to(self.v_ema.device)
        self.v_ema.fill_(v)
        self.v_ref.zero_()
        self.warmup_count.zero_()
        self.last_lambda.zero_()
        self.last_vbatch.copy_(v)
        self._initialized = True

    def _v_batch_vector(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Mean over classes (>=2 samples) of the mean per-dimension logit variance."""
        B, C = logits.shape
        vals = []
        for c in range(C):
            idx = (labels == c).nonzero(as_tuple=True)[0]
            if idx.numel() >= 2:
                class_logits = logits.index_select(0, idx)         # [Nc, C]
                class_var = class_logits.var(dim=0, unbiased=False).mean()
                vals.append(class_var)
        return torch.stack(vals).mean() if vals else logits.new_tensor(0.0)

    def _v_batch_true(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Mean over classes (>=2 samples) of the variance of the true-class logit."""
        B, C = logits.shape
        vals = []
        for c in range(C):
            idx = (labels == c).nonzero(as_tuple=True)[0]
            if idx.numel() >= 2:
                class_true = logits.index_select(0, idx)[:, c]     # [Nc]
                class_var = class_true.var(unbiased=False)
                vals.append(class_var)
        return torch.stack(vals).mean() if vals else logits.new_tensor(0.0)

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return scale * lambda * v_batch and update controller state.

        Args:
            logits: [B, C] tensor.
            labels: class indices compared against range(C); assumed shape [B].

        Returns:
            0-dim penalty; gradient flows through v_batch only.
        """
        assert logits.dim() == 2, "logits must be [B, C]"

        # 1) batch statistic
        v_batch = (
            self._v_batch_vector(logits, labels)
            if self.stat == "vector"
            else self._v_batch_true(logits, labels)
        )
        self._init_if_needed(v_batch.detach())

        # 2) controller update
        with torch.no_grad():
            self.v_ema.mul_(1.0 - self.alpha).add_(self.alpha * v_batch.detach())
            self.warmup_count.add_(1)
            if int(self.warmup_count.item()) >= self.warmup_steps and self.v_ref.item() == 0.0:
                self.v_ref.copy_(torch.clamp(self.v_ema, min=self.eps))
            self.last_vbatch.copy_(v_batch.detach())

        # before the latch, v_ema is its own reference -> lambda ~ lambda_base
        denom = self.v_ref if self.v_ref.item() > 0.0 else self.v_ema
        lam = self.lambda_base * (self.v_ema / (denom + self.eps))

        # 3) optional entropy modulation (mean predictive entropy normalised by log C)
        if self.use_entropy:
            with torch.no_grad():
                p = torch.softmax(logits, dim=1)
                ent = -(p * (p.clamp_min(self.eps).log())).sum(dim=1).mean()
                ent_norm = ent / max(math.log(max(logits.shape[1], 2)), self.eps)
            lam = lam * (1.0 + self.entropy_scale * ent_norm)

        lam = torch.clamp(lam, self.lambda_min, self.lambda_max)
        self.last_lambda.copy_(lam)

        # 4) final penalty
        return self.scale * lam * v_batch


# =========================
# ShuffleNetV2 blocks
# =========================
class ShuffleBlock(nn.Module):
    """Channel shuffle for a [N, C, H, W] tensor.

    Splits the C channels into `groups` contiguous groups and interleaves
    them, so output channel k*groups + g comes from group g, position k.
    C must be divisible by `groups`.
    """
    def __init__(self, groups=2):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        batch, channels, height, width = x.size()
        per_group = channels // self.groups
        grouped = x.view(batch, self.groups, per_group, height, width)
        # swap the group axis with the within-group axis, then flatten back
        interleaved = grouped.transpose(1, 2)
        return interleaved.reshape(batch, channels, height, width)


class SplitBlock(nn.Module):
    """Split a [N, C, H, W] tensor along channels at int(C * ratio).

    Returns (first int(C * ratio) channels, remaining channels).
    """
    def __init__(self, ratio):
        super().__init__()
        self.ratio = ratio

    def forward(self, x):
        cut = int(x.size(1) * self.ratio)
        head = x[:, :cut, :, :]
        tail = x[:, cut:, :, :]
        return head, tail


class BasicBlock(nn.Module):
    """Stride-1 ShuffleNetV2 unit.

    Splits the input channels with `SplitBlock(split_ratio)`, keeps the first
    part unchanged, and runs the second part through
    1x1 conv+BN+ReLU -> 3x3 depthwise conv+BN -> 1x1 conv+BN+ReLU.
    Both parts are concatenated (first part first) and channel-shuffled, so
    the output has the same channel count as the input.
    """
    def __init__(self, in_channels, split_ratio=0.5):
        super().__init__()
        self.split = SplitBlock(split_ratio)
        # SplitBlock keeps the first int(C * ratio) channels as the identity
        # part, so the conv branch must be sized for the *remaining* channels.
        # For split_ratio=0.5 and even C this equals int(C * 0.5), as before;
        # other ratios (or odd C) previously produced a channel mismatch.
        branch_channels = in_channels - int(in_channels * split_ratio)
        self.conv1 = nn.Conv2d(branch_channels, branch_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(branch_channels)
        self.conv2 = nn.Conv2d(
            branch_channels, branch_channels, kernel_size=3, stride=1, padding=1,
            groups=branch_channels, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(branch_channels)
        self.conv3 = nn.Conv2d(branch_channels, branch_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(branch_channels)
        self.shuffle = ShuffleBlock()

    def forward(self, x):
        x1, x2 = self.split(x)
        out = F.relu(self.bn1(self.conv1(x2)))
        out = self.bn2(self.conv2(out))  # no ReLU after the depthwise conv
        out = F.relu(self.bn3(self.conv3(out)))
        out = torch.cat([x1, out], 1)
        out = self.shuffle(out)
        return out


class DownBlock(nn.Module):
    """Stride-2 ShuffleNetV2 unit that halves H/W and maps to `out_channels`.

    Both branches see the full input and each produces out_channels // 2
    channels:
      - left : 3x3 depthwise conv (stride 2) + BN -> 1x1 conv + BN + ReLU
      - right: 1x1 conv + BN + ReLU -> 3x3 depthwise conv (stride 2) + BN
               -> 1x1 conv + BN + ReLU
    The branch outputs are concatenated (left first) and channel-shuffled.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        half = out_channels // 2

        # left branch (submodules are created in the same order as before so
        # parameter initialisation and state_dict keys are unchanged)
        self.conv1 = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2,
            padding=1, groups=in_channels, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, half, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(half)

        # right branch
        self.conv3 = nn.Conv2d(in_channels, half, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(half)
        self.conv4 = nn.Conv2d(
            half, half, kernel_size=3, stride=2,
            padding=1, groups=half, bias=False,
        )
        self.bn4 = nn.BatchNorm2d(half)
        self.conv5 = nn.Conv2d(half, half, kernel_size=1, bias=False)
        self.bn5 = nn.BatchNorm2d(half)

        self.shuffle = ShuffleBlock()

    def _left(self, x):
        y = self.bn1(self.conv1(x))
        return F.relu(self.bn2(self.conv2(y)))

    def _right(self, x):
        y = F.relu(self.bn3(self.conv3(x)))
        y = self.bn4(self.conv4(y))
        return F.relu(self.bn5(self.conv5(y)))

    def forward(self, x):
        merged = torch.cat((self._left(x), self._right(x)), dim=1)
        return self.shuffle(merged)


# =========================
# Backbone + factories
# =========================
# Width multiplier -> stage widths (3 stages + final 1x1 conv) and the number
# of stride-1 units per stage. Every size uses the same (3, 7, 3) depth.
configs = {
    width: {'out_channels': widths, 'num_blocks': (3, 7, 3)}
    for width, widths in (
        (0.5, (48, 96, 192, 1024)),
        (1.0, (116, 232, 464, 1024)),
        (1.5, (176, 352, 704, 1024)),
        (2.0, (224, 488, 976, 2048)),
    )
}


class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 backbone with attached adaptive stable-loss and VPL controllers.

    The stem is a stride-1 3x3 conv on 3-channel input; three stages of
    (DownBlock + num_blocks BasicBlocks) follow, then a 1x1 conv, global
    average pooling and a linear classifier. Widths/depths come from
    `configs[net_size]`.

    The loss controllers are submodules (so their buffers are part of
    `state_dict`) but are not used by `forward`; a trainer is expected to
    call `adaptive_stable_loss` / `variance_penalty` itself.
    """
    def __init__(
        self,
        net_size,
        num_classes=10,
        init_strategy='he',

        # legacy knobs kept for API compatibility
        stable_weight=0.1,
        vpl_weight_decay=0.1,
        vpl_weight=0.1,

        # Adaptive SL defaults (overridable via configure_adaptive_stable)
        sl_alpha=0.10,
        sl_beta=0.10,
        sl_delta=0.20,           # interpreted as delta_frac
        sl_lambda_base=None,     # if None/0 -> derive from stable_weight (or 0.5)
        sl_lambda_min=0.0,
        sl_lambda_max=2.0,
        sl_warmup_steps=200,
        sl_eps=1e-8,
        sl_use_running_ref=False,
        sl_ref_beta=0.01,
        sl_excess_gate=1.0,
        sl_delta_mode="abs",

        # VPL
        vpl_stat="vector",
    ):
        super().__init__()
        oc = configs[net_size]['out_channels']
        nb = configs[net_size]['num_blocks']

        self.conv1 = nn.Conv2d(3, 24, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.in_channels = 24

        self.layer1 = self._make_layer(oc[0], nb[0])
        self.layer2 = self._make_layer(oc[1], nb[1])
        self.layer3 = self._make_layer(oc[2], nb[2])

        self.conv2 = nn.Conv2d(oc[2], oc[3], kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(oc[3])
        self.linear = nn.Linear(oc[3], num_classes)

        # --- Adaptive Stable Loss ---
        # None or any zero (int 0 as well as 0.0) means "derive": previously
        # only a float 0.0 was recognised, so sl_lambda_base=0 silently
        # disabled the stable loss instead of falling back.
        if sl_lambda_base is None or float(sl_lambda_base) == 0.0:
            sl_lambda_base = stable_weight if (stable_weight is not None and stable_weight != 0.0) else 0.50
        self.adaptive_stable_loss = AdaptiveStableLoss(
            alpha=sl_alpha, beta=sl_beta, delta_frac=sl_delta,
            lambda_base=sl_lambda_base, lambda_min=sl_lambda_min, lambda_max=sl_lambda_max,
            warmup_steps=sl_warmup_steps, eps=sl_eps,
            sl_use_running_ref=sl_use_running_ref, sl_ref_beta=sl_ref_beta,
            sl_excess_gate=sl_excess_gate, sl_delta_mode=sl_delta_mode,
        )
        self.stable_mod = self.adaptive_stable_loss  # alias for trainer

        # --- Adaptive VPL ---
        # vpl_weight_decay is used as the penalty's output scale.
        self.variance_penalty = AdaptiveLabelVariancePenalty(
            lambda_base=1.0,
            scale=float(vpl_weight_decay),
            alpha=0.10,
            warmup_steps=100,
            lambda_min=0.0,
            lambda_max=2.0,
            use_entropy=False,
            entropy_scale=0.5,
            eps=1e-8,
            stat=vpl_stat,
        )
        # stored for the trainer; not used inside this module
        self.vpl_weight = float(vpl_weight)

        self._initialize_weights(init_strategy)

    def _make_layer(self, out_channels, num_blocks):
        """One stage: a DownBlock to `out_channels`, then `num_blocks` BasicBlocks."""
        layers = [DownBlock(self.in_channels, out_channels)]
        self.in_channels = out_channels
        for _ in range(num_blocks):
            layers.append(BasicBlock(out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape [N, num_classes]."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn2(self.conv2(out)))
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

    def _initialize_weights(self, init_strategy):
        """Initialise Conv2d weights per `init_strategy`; BN to (1, 0); Linear to N(0, 0.01), bias 0.

        An unrecognised `init_strategy` leaves Conv2d weights at PyTorch's default.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if init_strategy == 'he':
                    init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif init_strategy == 'xavier':
                    init.xavier_normal_(m.weight)
                elif init_strategy == 'custom_uniform':
                    init.uniform_(m.weight, -0.0089, 0.0089)
                elif init_strategy == 'custom_xavier':
                    init.xavier_normal_(m.weight)
                    with torch.no_grad():
                        m.weight.clamp_(-0.0089, 0.0089)
                elif init_strategy == 'custom_kaiming':
                    init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    with torch.no_grad():
                        m.weight.clamp_(-0.0089, 0.0089)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, 0, 0.01)
                init.constant_(m.bias, 0)

    # runtime reconfiguration (trainer can call this)
    def configure_adaptive_stable(
        self,
        alpha=None, beta=None, delta_frac=None,
        lambda_base=None, lambda_min=None, lambda_max=None,
        warmup_steps=None, use_running_ref=None, eps=None,
        sl_ref_beta=None, sl_excess_gate=None, sl_delta_mode=None,
    ):
        """Overwrite any non-None hyperparameter of `adaptive_stable_loss`.

        Controller state (EMAs, counters) is left untouched. Returns `self`
        so calls can be chained.
        """
        mod = self.adaptive_stable_loss
        if alpha is not None:        mod.alpha = float(alpha)
        if beta  is not None:        mod.beta  = float(beta)
        if eps   is not None:        mod.eps   = float(eps)
        if lambda_base is not None:  mod.lambda_base = float(lambda_base)
        if lambda_min  is not None:  mod.lambda_min  = float(lambda_min)
        if lambda_max  is not None:  mod.lambda_max  = float(lambda_max)
        if warmup_steps is not None: mod.warmup_steps = int(warmup_steps)
        if delta_frac is not None:   mod.delta_frac = float(delta_frac)
        if use_running_ref is not None: mod.sl_use_running_ref = bool(use_running_ref)
        if sl_ref_beta is not None:      mod.sl_ref_beta = float(sl_ref_beta)
        if sl_excess_gate is not None:   mod.sl_excess_gate = float(sl_excess_gate)
        if sl_delta_mode is not None:    mod.sl_delta_mode = str(sl_delta_mode).lower()
        return self


# Factories expected by the trainer
def _build_shufflenet(net_size, flags, kwargs):
    """Build a ShuffleNetV2 under the RNG state installed by `model_utils.set_rng_state(flags)`.

    The previous RNG state is restored in a `finally`, so a failing
    constructor (e.g. bad kwargs) no longer leaves the caller's RNG altered.
    """
    old_state = model_utils.set_rng_state(flags)
    try:
        return ShuffleNetV2(net_size=net_size, **kwargs)
    finally:
        model_utils.restore_rng_state(old_state)


def shufflenet05(flags=None, **kwargs):
    """ShuffleNetV2 with net_size=0.5; extra kwargs go to ShuffleNetV2."""
    return _build_shufflenet(0.5, flags, kwargs)


def shufflenet1(flags=None, **kwargs):
    """ShuffleNetV2 with net_size=1.0; extra kwargs go to ShuffleNetV2."""
    return _build_shufflenet(1.0, flags, kwargs)


def shufflenet15(flags=None, **kwargs):
    """ShuffleNetV2 with net_size=1.5; extra kwargs go to ShuffleNetV2."""
    return _build_shufflenet(1.5, flags, kwargs)


def shufflenet2(flags=None, **kwargs):
    """ShuffleNetV2 with net_size=2.0; extra kwargs go to ShuffleNetV2."""
    return _build_shufflenet(2.0, flags, kwargs)
