# -*- coding: utf-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import model_utils

# ---------- Adaptive Stable Loss (A/B/C) ----------
class AdaptiveStableLoss(nn.Module):
    """Adaptive "stable loss": penalise jumps of a scalar loss away from its running mean.

    Each call compares ``base_loss`` with ``l_ema`` (an exponential moving
    average of previous losses) and returns ``lambda_t * huber(base_loss - l_ema)``.

    The weight is ``lambda_t = clamp(lambda_base * sigma_ema / sigma_ref,
    lambda_min, lambda_max)``, where:

    * ``sigma_ema`` is an EMA of ``|base_loss - l_ema|``.
    * ``sigma_ref`` is captured once ``warmup_steps`` calls have been seen.
      Until then ``sigma_ema`` is normalised by itself.

    Args:
        alpha: EMA rate of ``l_ema``.
        beta: EMA rate of ``sigma_ema``.
        delta_frac: Huber threshold. It is used as-is in ``"abs"`` mode, or as a
            fraction of the deviation scale in ``"frac"`` mode. The threshold is
            floored at 1e-6 in both modes.
        lambda_base: base weight.
        lambda_min, lambda_max: clamp range for the weight.
        warmup_steps: number of calls after which ``sigma_ref`` is captured.
        eps: floor for ``sigma_ref`` and additive term in the gain's divisor.
        sl_use_running_ref: if True, ``sigma_ref`` keeps tracking ``sigma_ema``
            with rate ``sl_ref_beta`` after it has been captured.
        sl_ref_beta: EMA rate used by the running reference.
        sl_excess_gate: when > 1 and a reference exists, the weight is zero
            unless ``sigma_ema > sl_excess_gate * sigma_ref``.
        sl_delta_mode: ``"frac"`` selects the relative threshold; any other
            value uses the absolute threshold.
    """

    def __init__(self,
                 alpha=0.10, beta=0.10, delta_frac=0.20,
                 lambda_base=0.50, lambda_min=0.0, lambda_max=2.0,
                 warmup_steps=200, eps=1e-8,
                 sl_use_running_ref=False, sl_ref_beta=0.01,
                 sl_excess_gate=1.0, sl_delta_mode="abs"):
        super().__init__()
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.delta_frac = float(delta_frac)
        self.lambda_base = float(lambda_base)
        self.lambda_min = float(lambda_min)
        self.lambda_max = float(lambda_max)
        self.warmup_steps = int(warmup_steps)
        self.eps = float(eps)
        self.sl_use_running_ref = bool(sl_use_running_ref)
        self.sl_ref_beta = float(sl_ref_beta)
        self.sl_excess_gate = float(sl_excess_gate)
        self.sl_delta_mode = str(sl_delta_mode).lower()
        # Running statistics are buffers, so they are saved in state_dict.
        self.register_buffer("l_ema", torch.tensor(0.0))        # EMA of base_loss
        self.register_buffer("sigma_ema", torch.tensor(0.0))    # EMA of |base_loss - l_ema|
        self.register_buffer("sigma_ref", torch.tensor(0.0))    # reference deviation; 0 = not captured yet
        self.register_buffer("warmup_count", torch.tensor(0))   # forward calls since initialisation
        self.register_buffer("last_lambda", torch.tensor(0.0))  # diagnostics: last weight
        self.register_buffer("last_delta", torch.tensor(0.0))   # diagnostics: last raw deviation
        # Plain attribute, NOT part of state_dict; _init_if_needed handles restores.
        self._initialized = False

    @staticmethod
    def _huber(delta, thresh):
        """Scaled Huber function of ``delta``.

        Returns ``0.5 * delta**2 / thresh`` where ``|delta| <= thresh``, and
        ``|delta| - 0.5 * thresh`` elsewhere. The two pieces meet at ``thresh``.
        """
        absd = delta.abs()
        quad = 0.5 * (delta * delta) / max(thresh, 1e-12)
        lin = absd - 0.5 * thresh
        return torch.where(absd <= thresh, quad, lin)

    @torch.no_grad()
    def _init_if_needed(self, loss_value):
        """Seed ``l_ema`` from the first observed loss and zero the other statistics.

        ``_initialized`` is not saved in ``state_dict``, so it is False again
        after ``load_state_dict``. ``warmup_count`` starts at 0 and every
        forward call increments it. A non-zero count therefore means the
        buffers already hold progress (e.g. restored from a checkpoint), and
        they are kept instead of being wiped.
        """
        if self._initialized:
            return
        self._initialized = True
        if int(self.warmup_count.item()) > 0:
            return
        v = loss_value.detach().to(self.l_ema.device)
        self.l_ema.fill_(v)
        self.sigma_ema.zero_()
        self.sigma_ref.zero_()
        self.warmup_count.zero_()
        self.last_lambda.zero_()
        self.last_delta.zero_()

    def forward(self, base_loss: torch.Tensor) -> torch.Tensor:
        """Return the weighted Huber penalty for the current scalar loss.

        Args:
            base_loss: 0-dim loss tensor.

        Returns:
            0-dim tensor ``lambda_t * huber(base_loss - l_ema)``. Gradient flows
            only through ``base_loss``, because ``l_ema`` is detached and the
            weight is built from buffers.
        """
        assert base_loss.dim() == 0
        self._init_if_needed(base_loss)
        l_prev = self.l_ema.detach()
        delta_t = base_loss - l_prev

        # Huber threshold: proportional to the deviation scale in "frac" mode,
        # otherwise the fixed delta_frac.
        if self.sl_delta_mode == "frac":
            scale_src = self.sigma_ref if self.sigma_ref.item() > 0.0 else self.sigma_ema
            delta = max(float(scale_src.item()) * self.delta_frac, 1e-6)
        else:
            delta = max(self.delta_frac, 1e-6)
        huber_t = self._huber(delta_t, delta)

        with torch.no_grad():
            self.sigma_ema.mul_(1.0 - self.beta).add_(self.beta * delta_t.detach().abs())
            self.warmup_count.add_(1)
            # ">=" instead of "==": the sigma_ref == 0 guard still makes this a
            # one-time capture. With ">=" the capture is not skipped forever if
            # warmup_steps is lowered below the current count mid-run.
            if int(self.warmup_count.item()) >= self.warmup_steps and self.sigma_ref.item() == 0.0:
                self.sigma_ref.copy_(torch.clamp(self.sigma_ema, min=self.eps))
            if self.sl_use_running_ref and self.sigma_ref.item() > 0.0:
                self.sigma_ref.mul_(1.0 - self.sl_ref_beta).add_(self.sl_ref_beta * self.sigma_ema)

        has_ref = self.sigma_ref.item() > 0.0
        # Before a reference exists, sigma_ema is normalised by itself.
        denom = self.sigma_ref if has_ref else self.sigma_ema
        gain = self.sigma_ema / (denom + self.eps)
        if self.sl_excess_gate > 1.0 and has_ref:
            # Penalise only when the deviation exceeds the gate multiple of the reference.
            gain = gain * (self.sigma_ema > self.sl_excess_gate * denom).float()

        lam_t = torch.clamp(self.lambda_base * gain, self.lambda_min, self.lambda_max)
        sl_t = lam_t * huber_t

        with torch.no_grad():
            self.l_ema.mul_(1.0 - self.alpha).add_(self.alpha * base_loss.detach())
            self.last_lambda.copy_(lam_t)
            self.last_delta.copy_(delta_t.detach())
        return sl_t


# ---------- Adaptive Label-aware VPL ----------
class AdaptiveLabelVariancePenalty(nn.Module):
    """Penalty on the within-class spread of logits, with an adaptive weight.

    For every class with at least two samples in the batch, the biased
    variance of that class's logits is computed. The per-class values are
    averaged into ``v_batch``:

    * ``stat="vector"``: variance of every logit column, averaged over columns.
    * ``stat="true"``: variance of the class's own logit column only.

    The penalty is ``scale * lambda * v_batch``, where
    ``lambda = lambda_base * v_ema / v_ref``. Optionally it is multiplied by
    ``1 + entropy_scale * H`` (H = mean softmax entropy / log(num_classes)),
    then clamped to ``[lambda_min, lambda_max]``. ``v_ema`` is an EMA of
    ``v_batch``. ``v_ref`` is captured from it once ``warmup_steps`` calls have
    been seen; before that, ``v_ema`` is normalised by itself.
    """

    def __init__(self, lambda_base=1.0, scale=1.0, alpha=0.10, warmup_steps=100,
                 lambda_min=0.0, lambda_max=2.0, use_entropy=False, entropy_scale=0.5,
                 eps=1e-8, stat="vector"):
        super().__init__()
        self.lambda_base = float(lambda_base)
        self.scale = float(scale)
        self.alpha = float(alpha)
        self.warmup_steps = int(warmup_steps)
        self.lambda_min = float(lambda_min)
        self.lambda_max = float(lambda_max)
        self.use_entropy = bool(use_entropy)
        self.entropy_scale = float(entropy_scale)
        self.eps = float(eps)
        self.stat = str(stat).lower()
        assert self.stat in ("vector", "true")
        # Running statistics are buffers, so they are saved in state_dict.
        self.register_buffer("v_ema", torch.tensor(0.0))        # EMA of v_batch
        self.register_buffer("v_ref", torch.tensor(0.0))        # reference; 0 = not captured yet
        self.register_buffer("warmup_count", torch.tensor(0))   # forward calls since initialisation
        self.register_buffer("last_lambda", torch.tensor(0.0))  # diagnostics
        self.register_buffer("last_vbatch", torch.tensor(0.0))  # diagnostics
        # Plain attribute, NOT part of state_dict; _init_if_needed handles restores.
        self._initialized = False

    @torch.no_grad()
    def _init_if_needed(self, v0):
        """Seed ``v_ema`` from the first batch statistic.

        A non-zero ``warmup_count`` means the buffers were restored (e.g. via
        ``load_state_dict``), so they are kept instead of being reset.
        """
        if self._initialized:
            return
        self._initialized = True
        if int(self.warmup_count.item()) > 0:
            return
        self.v_ema.fill_(v0)
        self.v_ref.zero_()
        self.warmup_count.zero_()
        self.last_lambda.zero_()
        self.last_vbatch.copy_(v0)

    @staticmethod
    def _mean_over_classes(logits, labels, per_class):
        """Average ``per_class(rows, c)`` over classes that have >= 2 samples.

        ``rows`` holds the logits of the samples labelled ``c``. If no class
        qualifies, a 0-dim zero tensor (without gradient) is returned.
        """
        _, num_classes = logits.shape
        vals = []
        for c in range(num_classes):
            idx = (labels == c).nonzero(as_tuple=True)[0]
            if idx.numel() >= 2:
                vals.append(per_class(logits.index_select(0, idx), c))
        return torch.stack(vals).mean() if vals else logits.new_tensor(0.0)

    def _vb_vector(self, logits, labels):
        """Mean over classes of the column-averaged biased variance of the class's logits."""
        return self._mean_over_classes(
            logits, labels, lambda rows, c: rows.var(dim=0, unbiased=False).mean())

    def _vb_true(self, logits, labels):
        """Mean over classes of the biased variance of each class's own logit column."""
        # Uses logits.new_tensor via the shared helper; the previous
        # logits.newTensor raised AttributeError when no class had 2+ samples.
        return self._mean_over_classes(
            logits, labels, lambda rows, c: rows[:, c].var(unbiased=False))

    def forward(self, logits, labels):
        """Return the adaptively weighted variance penalty for one batch.

        Args:
            logits: 2-D tensor ``(batch, num_classes)``. Rank 2 is required by
                the shape unpacking in ``_mean_over_classes``.
            labels: integer class ids aligned with the rows of ``logits``.
                Presumably 1-D of length ``batch``; TODO confirm.

        Returns:
            0-dim tensor ``scale * lambda * v_batch``. Gradient flows only
            through ``v_batch``.
        """
        if self.stat == "vector":
            v_batch = self._vb_vector(logits, labels)
        else:
            v_batch = self._vb_true(logits, labels)
        self._init_if_needed(v_batch.detach())
        with torch.no_grad():
            self.v_ema.mul_(1.0 - self.alpha).add_(self.alpha * v_batch.detach())
            self.warmup_count.add_(1)
            if int(self.warmup_count.item()) >= self.warmup_steps and self.v_ref.item() == 0.0:
                self.v_ref.copy_(torch.clamp(self.v_ema, min=self.eps))
            self.last_vbatch.copy_(v_batch.detach())
        denom = self.v_ref if self.v_ref.item() > 0.0 else self.v_ema
        lam = self.lambda_base * (self.v_ema / (denom + self.eps))
        if self.use_entropy:
            with torch.no_grad():
                p = torch.softmax(logits, dim=1)
                ent = -(p * p.clamp_min(self.eps).log()).sum(dim=1).mean()
                # Normalise by the maximum entropy log(C); at least log(2).
                ent_norm = ent / max(math.log(max(logits.shape[1], 2)), self.eps)
            lam = lam * (1.0 + self.entropy_scale * ent_norm)
        lam = torch.clamp(lam, self.lambda_min, self.lambda_max)
        with torch.no_grad():
            self.last_lambda.copy_(lam)
        return self.scale * lam * v_batch


# ---------- MobileNetV2 blocks ----------
def _make_divisible(v, divisor=8, min_value=None):
    if min_value is None: min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v: new_v += divisor
    return new_v

class ConvBNReLU(nn.Sequential):
    """Sequential ``Conv2d (bias=False) -> BatchNorm2d -> ReLU6(inplace)``.

    The convolution pads by ``(kernel_size - 1) // 2`` on each side. ``groups``
    is passed straight to ``Conv2d``, e.g. ``groups=channels`` for a depthwise
    conv.
    """
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        pad = (kernel_size - 1) // 2
        conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad,
                         groups=groups, bias=False)
        norm = nn.BatchNorm2d(out_planes)
        act = nn.ReLU6(inplace=True)
        super().__init__(conv, norm, act)

class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block.

    Layout of ``self.conv``:

    1. Optional 1x1 ``ConvBNReLU`` expansion to ``round(inp * expand_ratio)``
       channels, present only when ``expand_ratio != 1``.
    2. 3x3 depthwise ``ConvBNReLU`` with the given stride.
    3. 1x1 ``Conv2d`` projection to ``oup`` channels.
    4. ``BatchNorm2d``, with no activation after it.

    The input is added to the output only when ``stride == 1`` and ``inp == oup``.
    """
    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        assert stride in [1, 2]
        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = stride == 1 and inp == oup
        expansion = [] if expand_ratio == 1 else [ConvBNReLU(inp, hidden_dim, kernel_size=1)]
        self.conv = nn.Sequential(
            *expansion,
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        )

    def forward(self, x):
        out = self.conv(x)
        if self.use_res_connect:
            out = x + out
        return out


# ---------- Backbone ----------
class MobileNetV2(nn.Module):
    """MobileNetV2 feature extractor plus linear classifier, with two auxiliary regularisers.

    The stem convolution uses stride 1 (marked "CIFAR" below). ``forward``
    only runs the network. The two loss modules owned by this model are not
    applied there, and a training loop is expected to call them itself
    (NOTE(review): confirm against the training script):

    * ``adaptive_stable_loss`` (also registered as ``stable_mod``): an
      ``AdaptiveStableLoss`` built from the ``sl_*`` arguments.
    * ``variance_penalty``: an ``AdaptiveLabelVariancePenalty`` whose
      ``scale`` is ``vpl_weight_decay``.

    Args:
        width_mult: channel multiplier for the stem and every stage. Widths are
            rounded to multiples of 8 by ``_make_divisible``. The head width
            is never scaled below 1280.
        num_classes: output size of the final ``nn.Linear``.
        init_strategy: ``'he'``, ``'xavier'``, ``'custom_uniform'``,
            ``'custom_xavier'`` or ``'custom_kaiming'``. Any other value leaves
            Conv2d weights at PyTorch's default initialisation.
        stable_weight: used as ``sl_lambda_base`` when that is ``None`` or the
            float ``0.0``. If ``stable_weight`` is also ``None``/zero, 0.50 is used.
        vpl_weight_decay: ``scale`` of the variance penalty.
        vpl_weight: only stored as ``self.vpl_weight``; unused inside this class.
        sl_*: forwarded to ``AdaptiveStableLoss``. ``sl_delta`` becomes its
            ``delta_frac``.
        vpl_stat: ``"vector"`` or ``"true"`` statistic for the variance penalty.
    """
    def __init__(self, width_mult=1.0, num_classes=10, init_strategy='he',
                 stable_weight=0.1, vpl_weight_decay=0.1, vpl_weight=0.1,
                 sl_alpha=0.10, sl_beta=0.10, sl_delta=0.20,
                 sl_lambda_base=None, sl_lambda_min=0.0, sl_lambda_max=2.0,
                 sl_warmup_steps=200, sl_eps=1e-8,
                 sl_use_running_ref=False, sl_ref_beta=0.01,
                 sl_excess_gate=1.0, sl_delta_mode="abs",
                 vpl_stat="vector"):
        super().__init__()
        # Widths are rounded to multiples of 8. The head uses max(1.0, width_mult),
        # so it never shrinks below 1280 channels.
        input_channel = _make_divisible(32 * width_mult, 8)
        last_channel  = _make_divisible(1280 * max(1.0, width_mult), 8)

        # NOTE: the order in which submodules are created below matters. It
        # fixes the state_dict key order, and it determines how the RNG is
        # consumed while layers are constructed.
        features = []
        features.append(ConvBNReLU(3, input_channel, stride=1))  # CIFAR: stride 1

        # Stage table, one row per stage:
        # t = expansion ratio, c = output channels before width_mult,
        # n = number of blocks, s = stride of the stage's first block
        # (the remaining blocks use stride 1).
        inv = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]
        for t,c,n,s in inv:
            out_ch = _make_divisible(c * width_mult, 8)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(InvertedResidual(input_channel, out_ch, stride, expand_ratio=t))
                input_channel = out_ch
        # 1x1 conv up to the head width; global pooling happens in forward().
        features.append(ConvBNReLU(input_channel, last_channel, kernel_size=1))
        self.features = nn.Sequential(*features)
        self.classifier = nn.Linear(last_channel, num_classes)

        # Only None or the float 0.0 count as "unset". In that case fall back to
        # stable_weight, or to 0.50 if that is None/zero too. An int 0 is kept as given.
        if sl_lambda_base is None or (isinstance(sl_lambda_base,float) and sl_lambda_base==0.0):
            sl_lambda_base = stable_weight if (stable_weight is not None and stable_weight != 0.0) else 0.50
        self.adaptive_stable_loss = AdaptiveStableLoss(
            alpha=sl_alpha, beta=sl_beta, delta_frac=sl_delta,
            lambda_base=sl_lambda_base, lambda_min=sl_lambda_min, lambda_max=sl_lambda_max,
            warmup_steps=sl_warmup_steps, eps=sl_eps,
            sl_use_running_ref=sl_use_running_ref, sl_ref_beta=sl_ref_beta,
            sl_excess_gate=sl_excess_gate, sl_delta_mode=sl_delta_mode
        )
        # Second name for the same module object. NOTE(review): because the
        # module is registered under two names, its buffers appear in
        # state_dict under both the "adaptive_stable_loss." and "stable_mod."
        # prefixes.
        self.stable_mod=self.adaptive_stable_loss

        # Only the scale (vpl_weight_decay) and the statistic (vpl_stat) are
        # configurable from here; the other penalty settings are fixed.
        self.variance_penalty=AdaptiveLabelVariancePenalty(lambda_base=1.0, scale=float(vpl_weight_decay),
                                                           alpha=0.10, warmup_steps=100, lambda_min=0.0, lambda_max=2.0,
                                                           use_entropy=False, entropy_scale=0.5, eps=1e-8, stat=vpl_stat)
        # Stored for callers; not read anywhere in this class.
        self.vpl_weight=float(vpl_weight)
        # Runs last so it re-initialises every Conv2d/BatchNorm2d/Linear built above.
        self._initialize_weights(init_strategy)

    def forward(self,x):
        """Return class logits of shape (N, num_classes).

        The stem conv expects 3 input channels; x is presumably (N, 3, H, W),
        TODO confirm. The auxiliary losses are not applied here.
        """
        x=self.features(x)
        # Global average pool to (N, C, 1, 1), then flatten to (N, C).
        x=F.adaptive_avg_pool2d(x,1)
        x=torch.flatten(x,1)
        return self.classifier(x)

    def _initialize_weights(self,init_strategy):
        """Re-initialise layer parameters in place according to ``init_strategy``.

        Conv2d weights:

        * ``'he'``: Kaiming-normal (fan_out, relu).
        * ``'custom_kaiming'``: Kaiming-normal, then clipped to [-0.0089, 0.0089].
        * ``'xavier'``: Xavier-normal.
        * ``'custom_xavier'``: Xavier-normal, then clipped to [-0.0089, 0.0089].
        * ``'custom_uniform'``: U(-0.0089, 0.0089).
        * Any other value leaves Conv2d weights untouched.

        Conv2d layers in this model are built with bias=False, so no conv bias
        is initialised. BatchNorm2d gets weight 1 and bias 0. Linear gets
        weight ~ N(0, std=0.01) and bias 0.
        """
        for m in self.modules():
            if isinstance(m,nn.Conv2d):
                if init_strategy=='he':
                    init.kaiming_normal_(m.weight,mode='fan_out',nonlinearity='relu')
                elif init_strategy=='xavier':
                    init.xavier_normal_(m.weight)
                elif init_strategy=='custom_uniform':
                    init.uniform_(m.weight,-0.0089,0.0089)
                elif init_strategy=='custom_xavier':
                    init.xavier_normal_(m.weight); m.weight.data.clamp_(-0.0089,0.0089)
                elif init_strategy=='custom_kaiming':
                    init.kaiming_normal_(m.weight,mode='fan_out',nonlinearity='relu'); m.weight.data.clamp_(-0.0089,0.0089)
            elif isinstance(m,nn.BatchNorm2d):
                init.constant_(m.weight,1); init.constant_(m.bias,0)
            elif isinstance(m,nn.Linear):
                init.normal_(m.weight,0,0.01); init.constant_(m.bias,0)

    def configure_adaptive_stable(self, alpha=None,beta=None,delta_frac=None,
                                  lambda_base=None,lambda_min=None,lambda_max=None,
                                  warmup_steps=None,use_running_ref=None,eps=None,
                                  sl_ref_beta=None,sl_excess_gate=None,sl_delta_mode=None):
        """Update hyper-parameters of ``adaptive_stable_loss`` in place.

        Any argument left as None keeps its current value. ``use_running_ref``
        sets the module's ``sl_use_running_ref`` attribute. Only attributes
        change; the running-statistic buffers are not reset. Returns ``self``
        so calls can be chained.
        """
        mod=self.adaptive_stable_loss
        if alpha is not None: mod.alpha=float(alpha)
        if beta  is not None: mod.beta=float(beta)
        if eps   is not None: mod.eps=float(eps)
        if lambda_base is not None: mod.lambda_base=float(lambda_base)
        if lambda_min  is not None: mod.lambda_min=float(lambda_min)
        if lambda_max  is not None: mod.lambda_max=float(lambda_max)
        if warmup_steps is not None: mod.warmup_steps=int(warmup_steps)
        if delta_frac is not None: mod.delta_frac=float(delta_frac)
        if use_running_ref is not None: mod.sl_use_running_ref=bool(use_running_ref)
        if sl_ref_beta is not None: mod.sl_ref_beta=float(sl_ref_beta)
        if sl_excess_gate is not None: mod.sl_excess_gate=float(sl_excess_gate)
        if sl_delta_mode is not None: mod.sl_delta_mode=str(sl_delta_mode).lower()
        return self


# Factories
def _build_mobilenet(width_mult, flags, kwargs):
    """Build a MobileNetV2 between model_utils.set_rng_state and restore_rng_state.

    The restore runs in a ``finally`` clause, so the previous RNG state is put
    back even when construction raises (e.g. on an invalid keyword argument).
    NOTE(review): what ``flags`` selects (presumably a seed) is defined in
    model_utils; confirm there.
    """
    old = model_utils.set_rng_state(flags)
    try:
        return MobileNetV2(width_mult=width_mult, **kwargs)
    finally:
        model_utils.restore_rng_state(old)


def mobilenet05(flags=None, **kwargs):
    """MobileNetV2 with width_mult=0.5; extra kwargs go to MobileNetV2."""
    return _build_mobilenet(0.5, flags, kwargs)


def mobilenet1(flags=None, **kwargs):
    """MobileNetV2 with width_mult=1.0; extra kwargs go to MobileNetV2."""
    return _build_mobilenet(1.0, flags, kwargs)


def mobilenet14(flags=None, **kwargs):
    """MobileNetV2 with width_mult=1.4; extra kwargs go to MobileNetV2."""
    return _build_mobilenet(1.4, flags, kwargs)
