from flax import struct
import jax, jax.numpy as jnp, optax
from SafeDreamer.nets import ScoreNetRCMonotone, scorenet_lipschitz_row_project
from SafeDreamer.scorenet.losses import dominance_hinge, tie_band, slope_ratio_regularizer
from SafeDreamer.scorenet.pairs import select_pairs

@struct.dataclass
class ScoreNetState:
  """Bundle of training state that is threaded through each score-net step."""
  # Current network parameters, as returned by `net.init` / the optimizer step.
  params: any
  # Optimizer state produced by `optimizer.init` / `optimizer.update`.
  opt_state: any
  # Exponential moving average of `params` (see `ema_update`).
  ema_params: any
  # Blend weight given to the fresh parameters in each EMA update.
  tau: float
def ema_update(e, p, tau):
  """Blend two matching pytrees leaf by leaf: ``(1 - tau) * e + tau * p``.

  Args:
    e: pytree holding the running average.
    p: pytree with the same structure holding the new values.
    tau: weight given to ``p``; ``0`` keeps ``e``, ``1`` returns ``p``.

  Returns:
    A new pytree with the same structure as ``e``.
  """
  def _blend(avg_leaf, new_leaf):
    return (1 - tau) * avg_leaf + tau * new_leaf

  return jax.tree_util.tree_map(_blend, e, p)
def linear_sched(step, a, b, total):
  """Linearly interpolate from ``a`` to ``b`` over ``total`` steps.

  The progress fraction ``step / max(1, total)`` is clipped to ``[0, 1]``,
  so the result stays at ``a`` before step 0 and at ``b`` once ``step``
  reaches ``total``. A ``total`` below 1 is treated as 1.

  Returns:
    The interpolated value as a Python ``float``.
  """
  horizon = max(1, total)  # guards the division when total is 0 or negative
  frac = jnp.clip(step / horizon, 0, 1)
  value = (1 - frac) * a + frac * b
  return float(value)

def init_scorenet(rng, cfg):
  """Build the score network, its initial training state, and its optimizer.

  Args:
    rng: PRNG key passed to ``net.init``.
    cfg: config object; the architecture fields read below plus
      ``feat_dim``, ``lr`` and ``ema_tau`` are used.

  Returns:
    ``(net, state, opt)`` where ``state`` is a ``ScoreNetState`` whose EMA
    parameters start out equal to the freshly initialised parameters.
  """
  net_kwargs = dict(
      hidden=cfg.hidden, depth=cfg.depth, use_tanh=cfg.use_tanh,
      use_pgr=cfg.use_pgr, pgr_skip_dim=cfg.pgr_skip_dim,
      pgr_gate_clip=cfg.pgr_gate_clip,
      use_mic=cfg.use_mic, mic_use_tanh=cfg.mic_use_tanh,
  )
  net = ScoreNetRCMonotone(**net_kwargs)

  # Dummy batch of two rows: two (2, 1) inputs and one (2, feat_dim) input.
  dummy_r = jnp.zeros((2, 1))
  dummy_c = jnp.zeros((2, 1))
  dummy_feat = jnp.zeros((2, cfg.feat_dim))
  params = net.init(rng, dummy_r, dummy_c, dummy_feat)

  opt = optax.adam(cfg.lr)
  state = ScoreNetState(params, opt.init(params), params, cfg.ema_tau)
  return net, state, opt

def _loss(net, params, batch, cfg):
  """Weighted sum of the dominance, tie-band and slope-ratio losses.

  Args:
    net: module whose ``apply(params, r, c, feat)`` yields one score per row.
    params: parameters passed to ``net.apply``.
    batch: dict with row inputs ``r``, ``c``, ``feat``, pair indices ``i``,
      ``j`` and per-pair boolean masks ``dom_ij``, ``dom_ji``, ``inc``.
    cfg: config supplying ``margin``, ``tie_delta``, ``ratio_subsample``,
      ``ratio_seed``, ``rho_min``, ``rho_max`` and the ``w_*`` weights.

  Returns:
    ``(total, logs)`` where ``logs`` holds each term and the total.

  Note:
    The masks are inspected with Python ``if`` and used for boolean
    indexing, so they must be concrete (non-traced) arrays.
  """
  r, c, feat = batch['r'], batch['c'], batch['feat']
  scores = net.apply(params, r, c, feat)
  i, j = batch['i'], batch['j']
  dom_ij, dom_ji, inc = batch['dom_ij'], batch['dom_ji'], batch['inc']
  s_i, s_j = scores[i], scores[j]

  # Dominance term: one hinge per direction, skipped when its mask is empty.
  loss_dom = 0.0
  for mask, first, second in ((dom_ij, s_i, s_j), (dom_ji, s_j, s_i)):
    if mask.any():
      loss_dom += jnp.mean(dominance_hinge(first[mask], second[mask], cfg.margin))

  # Tie term over the pairs flagged by `inc`.
  if inc.any():
    loss_tie = jnp.mean(tie_band(s_i[inc], s_j[inc], cfg.tie_delta))
  else:
    loss_tie = 0.0

  # Slope-ratio regulariser on a random subset of rows. The key is rebuilt
  # from cfg.ratio_seed on every call, so a given N always yields the same
  # subset.
  n_rows = r.shape[0]
  n_sub = max(1, int(n_rows * cfg.ratio_subsample))
  key = jax.random.PRNGKey(cfg.ratio_seed)
  idx = jax.random.permutation(key, jnp.arange(n_rows))[:n_sub]
  feat_sub = jax.lax.stop_gradient(feat[idx])
  loss_ratio = jnp.mean(slope_ratio_regularizer(
      net.apply, params, r[idx], c[idx], feat_sub, cfg.rho_min, cfg.rho_max))

  total = cfg.w_dom*loss_dom + cfg.w_tie*loss_tie + cfg.w_ratio*loss_ratio
  logs = dict(loss_dom=loss_dom, loss_tie=loss_tie, loss_ratio=loss_ratio, loss_total=total)
  return total, logs

def _step(params, opt_state, batch, cfg, net, optimizer):
  """Run one gradient step on the score-net loss.

  Args:
    params: current network parameters (the quantity being optimised).
    opt_state: optimizer state matching ``params``.
    batch: pair batch dict consumed by ``_loss``.
    cfg: config object forwarded to ``_loss``.
    net: network module forwarded to ``_loss``.
    optimizer: object exposing ``update(grads, opt_state, params)``.

  Returns:
    ``(params, opt_state, loss, logs)`` after applying the update.

  Notes:
    * Gradients are taken with respect to ``params`` only. ``_loss`` has
      ``net`` as its first positional argument, so differentiating it
      directly with the default ``argnums=0`` would target the module
      instead of the parameters; a closure over ``params`` avoids that.
    * This function is intentionally not wrapped in ``jax.jit``. ``cfg``,
      ``net`` and ``optimizer`` are not array pytrees, and ``_loss``
      branches in Python on the contents of the batch masks and indexes
      with boolean masks, both of which need concrete values.
  """
  def loss_fn(p):
    return _loss(net, p, batch, cfg)

  (loss, logs), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss, logs

def train_scorenet_step(rng, state, net, optimizer, batch_real, batch_imag, vR_target, vC_target, cfg, step):
  """Run one score-net training step on a mix of real and imagined data.

  Args:
    rng: PRNG key; it is split once and the sub-key drives pair selection.
    state: current ``ScoreNetState``.
    net: score network module.
    optimizer: optimizer matching ``state.opt_state``.
    batch_real: dict with ``r``, ``c``, ``feat`` and optionally ``gR``/``gC``.
    batch_imag: same layout as ``batch_real``; may be falsy or have zero
      rows, in which case only the real batch is used.
    vR_target, vC_target: callables ``f(batch)`` used to compute ``gR`` /
      ``gC`` when the batch does not already contain them.
    cfg: config object (warm-up length, label epsilons, pair-selection and
      Lipschitz settings are read here).
    step: current training step, used by the linear schedules.

  Returns:
    ``(rng, state, logs)``: the advanced key, the updated state, and the
    loss logs extended with the current ``eps_r`` / ``eps_c``.
  """
  def _prep(b):
    out = dict(r=b['r'].reshape(-1, 1), c=b['c'].reshape(-1, 1), feat=b['feat'])
    # Only call the target functions when the batch lacks the value;
    # `dict.get(key, f(b))` would evaluate `f(b)` unconditionally.
    out['gR'] = b['gR'] if 'gR' in b else vR_target(b)
    out['gC'] = b['gC'] if 'gC' in b else vC_target(b)
    return out

  real = _prep(batch_real)
  imag = _prep(batch_imag) if (batch_imag and batch_imag['r'].shape[0] > 0) else None

  # Over the warm-up, the kept share of real rows falls from 100% to 50%
  # while the kept share of imagined rows rises from 0% to 50%.
  mix_t = linear_sched(step, 0.0, 1.0, cfg.real_imag_warmup_steps)
  if imag is None:
    mix = real
  else:
    kR = int(real['r'].shape[0] * (1 - 0.5 * mix_t))
    kI = int(imag['r'].shape[0] * (0.5 * mix_t))
    mix = {
        key: jnp.concatenate([real[key][:kR], imag[key][:kI]])
        for key in ('r', 'c', 'feat', 'gR', 'gC')
    }

  rng, sub = jax.random.split(rng)
  eps_r = linear_sched(step, cfg.label_eps_r_init, cfg.label_eps_r_final, cfg.real_imag_warmup_steps)
  eps_c = linear_sched(step, cfg.label_eps_c_init, cfg.label_eps_c_final, cfg.real_imag_warmup_steps)
  i, j, dom_ij, dom_ji, inc = select_pairs(
      sub, mix['gR'], mix['gC'], cfg.pairs_per_sample, eps_r, eps_c, cfg.use_hard_mining)
  batch_pairs = dict(r=mix['r'], c=mix['c'], feat=mix['feat'],
                     i=i, j=j, dom_ij=dom_ij, dom_ji=dom_ji, inc=inc)

  params, opt_state, loss, logs = _step(state.params, state.opt_state, batch_pairs, cfg, net, optimizer)
  # The projection is applied unless cfg explicitly sets use_lipschitz falsy.
  if getattr(cfg, 'use_lipschitz', True):
    params = scorenet_lipschitz_row_project(params, L=cfg.lipschitz_L)

  ema = ema_update(state.ema_params, params, state.tau)
  state = ScoreNetState(params=params, opt_state=opt_state, ema_params=ema, tau=state.tau)
  logs = {**logs, 'eps_r': eps_r, 'eps_c': eps_c}
  return rng, state, logs
