import logging
from collections import Counter
import numpy as np
import torch
import torch.nn.functional as F

import utils
from genotypes_201 import Genotype, PRIMITIVES

# Edges of a 4-node cell as (first node, second node) pairs with first < second,
# grouped by the second node and ordered by the first node within each group:
#   (0, 1), (0, 2), (1, 2), (0, 3), (1, 3), (2, 3)
# NOTE(review): the name suggests the NAS-Bench-201 cell layout -- confirm.
EDGE_INDEX_NB201 = [(src, dst) for dst in range(1, 4) for src in range(dst)]

# First component of every pair above, indexed by edge id: [0, 0, 1, 0, 1, 2].
EDGE2U_LOG = [src for src, _dst in EDGE_INDEX_NB201]

@torch.no_grad()
def remove_op(normal_weights: torch.Tensor, op):
    """Return a copy of ``normal_weights`` with one op zeroed on its edge.

    The affected row is renormalised to sum to 1 afterwards.  If nothing is
    left on that row (sum <= 1e-12) it is replaced by a uniform distribution.
    The input tensor is not modified.

    Args:
        normal_weights: 2-D weight tensor indexed as ``[edge, op]``.
        op: Either an ``(edge, op)`` pair, or a string tag of the form
            ``"normal_<edge>_..._<op>"`` (second token is the edge id, last
            token is the op id).

    Returns:
        A new tensor with the same shape as ``normal_weights``.

    Raises:
        AssertionError: If a string tag does not start with ``"normal"``.
    """
    if not isinstance(op, str):
        edge, choice = op
    else:
        parts = op.split('_')
        assert parts[0] == 'normal', f"unexpected op tag: {op}"
        edge = int(parts[1])
        choice = int(parts[-1])

    pruned = normal_weights.clone()
    pruned[edge, choice] = 0.0

    remaining = pruned[edge]
    total = remaining.sum()
    # Keep the comparison in this direction: a NaN total must fall through to
    # the division branch, exactly as a plain ``<=`` test does.
    if total.item() <= 1e-12:
        replacement = torch.full_like(remaining, 1.0 / remaining.numel())
    else:
        replacement = remaining / total
    pruned[edge] = replacement
    return pruned

def _fmt_op_line(i, tag, name, u, eid, opid, val, thr, s, flag):
    return (f"[IQR-{flag}] s={s:02d} | idx={i:02d} | {tag} | name={name:<13} | "
            f"u={u} | edge={eid} | opid={opid} | val={val:+.6f} | thr={thr:+.6f}")

def _op_meta(ops, PRIMITIVES, idx):
    """Decode the tag stored at ``ops[idx]``.

    The tag's ``str()`` form is split on "_"; the second token is read as the
    edge id and the last token as the op id.

    Returns:
        ``(tag, name, u, eid, opid)`` where ``name`` is ``PRIMITIVES[opid]``
        (or ``"op<opid>"`` when ``opid`` is out of range) and ``u`` is
        ``EDGE2U_LOG[eid]``.
    """
    tag = str(ops[idx])
    fields = tag.split("_")
    eid, opid = int(fields[1]), int(fields[-1])

    if 0 <= opid < len(PRIMITIVES):
        name = PRIMITIVES[opid]
    else:
        # Unknown op id: fall back to a synthetic label instead of failing.
        name = f"op{opid}"

    return tag, name, EDGE2U_LOG[eid], eid, opid

def _name_list_from_indices(ops, PRIMITIVES, idx_list, max_show=16):
    labels = []
    for i in idx_list[:max_show]:
        tag, name, _u, eid, _opid = _op_meta(ops, PRIMITIVES, i)
        labels.append(f"{name}@e{eid}")
    if len(idx_list) > max_show:
        labels.append(f"...(+{len(idx_list) - max_show})")
    return ", ".join(labels) if labels else "-"

def _count_by_name(ops, PRIMITIVES, idx_list):
    c = Counter()
    for i in idx_list:
        _tag, name, _u, _eid, _opid = _op_meta(ops, PRIMITIVES, i)
        c[name] += 1
    return ", ".join([f"{k}×{v}" for k, v in c.most_common()]) if c else "-"


@torch.no_grad()
def compute_value(valid_queue, model, ops, num_samples, beta=0.3, G=5):
    """Score every (edge, op) pair by the accuracy lost when it is removed.

    For each of ``num_samples`` batches drawn from ``valid_queue`` (the queue
    is restarted when exhausted) the model is evaluated once with its
    projected weights and once per op with that op removed via ``remove_op``.
    The accuracy drops of one batch are z-scored within each edge, and
    per-edge IQR fences (``q3 + 2.0 * iqr`` / ``q1 - 1.5 * iqr``) flag
    unusually high ("POS") and unusually low ("NEG") ops.

    Two per-op statistics are then blended:

    * a median of group means of the z-scored drops, using up to ``G`` groups
      of samples, and
    * the net outlier frequency ``(#POS - #NEG) / num_samples`` in [-1, 1],

    as ``(1 - beta) * median_of_means + beta * frequency``.

    Progress and diagnostics are written with ``print`` and ``logging.info``.

    Args:
        valid_queue: Re-iterable source of ``(x, y)`` batches; both elements
            must support ``.cuda(non_blocking=True)``.
        model: Object callable as ``model(x)`` and
            ``model(x, weights_dict=w)`` that provides ``eval()`` and
            ``get_projected_weights()``.
        ops: Sequence of ``6 * len(PRIMITIVES)`` distinct op tags.  The
            ``str()`` form of a tag is split on "_": the second token is the
            edge id and the last token is the op id.
        num_samples: Number of batches to evaluate (integer >= 1).
        beta: Weight of the outlier-frequency term in the final score.
        G: Maximum number of sample groups for the median of means.

    Returns:
        ``float32`` array of shape ``(6, len(PRIMITIVES))`` holding the
        blended score of every (edge, op) pair.

    Raises:
        AssertionError: If ``num_samples`` is invalid, ``len(ops)`` does not
            equal ``6 * len(PRIMITIVES)``, or some (edge, op) pair has no tag.

    Note:
        Batches are moved with ``.cuda()``, so a CUDA device is required.
        ``utils.accuracy`` is expected to return a one-element sequence whose
        item supports ``.item()``.
    """
    assert isinstance(num_samples, int) and num_samples >= 1, "num_samples must be an integer >= 1"

    E, O = 6, len(PRIMITIVES)
    N = len(ops)
    assert N == E * O, f"N={N} but E*O={E*O}; ops size must match 6 edges × {O} ops."

    # idx_mat[e, o] is the flat position in ``ops`` of the tag for (edge e, op o).
    # Together with N == E * O, the assertion below makes it a bijection.
    idx_mat = np.full((E, O), -1, dtype=int)
    for i, op in enumerate(ops):
        toks = str(op).split('_')
        idx_mat[int(toks[1]), int(toks[-1])] = i
    assert (idx_mat >= 0).all(), "idx_mat has -1: some (e,o) not found in ops!"

    op2idx = {op: idx for idx, op in enumerate(ops)}
    # Each removal is evaluated independently, so the visiting order does not
    # affect the values; the shuffles are kept so the global NumPy RNG is
    # consumed exactly as callers may rely on for reproducibility.
    permutations = [np.random.permutation(ops) for _ in range(num_samples)]

    # eval_values_std[i, s]: accuracy drop (later z-score) of op i on sample s.
    eval_values_std = np.zeros((N, num_samples), dtype=np.float32)
    # Number of samples in which each op was flagged as a POS / NEG outlier.
    pos_count = np.zeros(N, dtype=np.float32)
    neg_count = np.zeros(N, dtype=np.float32)

    valid_iter = iter(valid_queue)
    MAX_LINES = 50
    eps = 1e-8

    for s in range(num_samples):
        try:
            x, y = next(valid_iter)
        except StopIteration:
            # Queue exhausted: start over from its first batch.
            valid_iter = iter(valid_queue)
            x, y = next(valid_iter)

        x = x.cuda(non_blocking=True)
        y = y.cuda(non_blocking=True)
        model.eval()

        std_base, = utils.accuracy(model(x), y)
        std_base = std_base.item()
        print(f"[ComputeValue] std_acc {std_base:.4f}")

        nw = model.get_projected_weights()

        for op in permutations[s]:
            real_idx = op2idx[op]
            nw_drop = remove_op(nw, op)
            cur_std, = utils.accuracy(model(x, weights_dict=nw_drop), y)
            eval_values_std[real_idx, s] = std_base - cur_std.item()

        # Gather this sample's drops into an (E, O) copy for per-edge statistics.
        mat = eval_values_std[idx_mat, s]

        pos_idxs_all = []
        neg_idxs_all = []
        edge_fences = []  # (thr_pos, thr_neg) per edge, reused by the log lines below

        for e in range(E):
            row = mat[e]
            mu = row.mean()
            sd = row.std()
            row_z = (row - mu) / (sd + eps)
            mat[e] = row_z

            q1, q3 = np.percentile(row_z, [25, 75])
            iqr = q3 - q1
            thr_pos = q3 + 2.0 * iqr
            thr_neg = q1 - 1.5 * iqr
            edge_fences.append((thr_pos, thr_neg))

            pos_o = np.where(row_z > thr_pos)[0].tolist()
            neg_o = np.where(row_z < thr_neg)[0].tolist()
            pos_idxs_all.extend(idx_mat[e, o] for o in pos_o)
            neg_idxs_all.extend(idx_mat[e, o] for o in neg_o)

            print(f"[IQR-EDGE] s={s:02d} e={e} thr_pos={thr_pos:+.6f} thr_neg={thr_neg:+.6f} "
                  f"| POS={len(pos_o)} | NEG={len(neg_o)}")

        # Replace the raw drops of this sample with their per-edge z-scores.
        eval_values_std[idx_mat, s] = mat

        pos_names_line = _name_list_from_indices(ops, PRIMITIVES, pos_idxs_all, max_show=16)
        neg_names_line = _name_list_from_indices(ops, PRIMITIVES, neg_idxs_all, max_show=16)
        pos_bucket_line = _count_by_name(ops, PRIMITIVES, pos_idxs_all)
        neg_bucket_line = _count_by_name(ops, PRIMITIVES, neg_idxs_all)
        msg_head = (f"[IQR-SUMMARY] s={s:02d} | POS[{len(pos_idxs_all)}] {pos_names_line} "
                    f"| NEG[{len(neg_idxs_all)}] {neg_names_line}")
        msg_bucket = (f"[IQR-BUCKET]  s={s:02d} | POS: {pos_bucket_line} | NEG: {neg_bucket_line}")
        print(msg_head)
        logging.info(msg_head)
        print(msg_bucket)
        logging.info(msg_bucket)

        for flag, flagged, counts, fence_col in (
            ("POS", pos_idxs_all, pos_count, 0),
            ("NEG", neg_idxs_all, neg_count, 1),
        ):
            for k, idx in enumerate(flagged):
                if k < MAX_LINES:
                    tag, name, u, eid, opid = _op_meta(ops, PRIMITIVES, idx)
                    val = eval_values_std[idx, s]
                    # Report the fence of the op's own edge, i.e. the value
                    # that actually triggered the flag.
                    thr = edge_fences[eid][fence_col]
                    msg = _fmt_op_line(idx, tag, name, u, eid, opid, val, thr, s, flag)
                    print(msg)
                    logging.info(msg)
                counts[idx] += 1
            if len(flagged) > MAX_LINES:
                logging.info(f"[IQR-{flag}] s={s:02d} ... (+{len(flagged)-MAX_LINES} more)")

    # Plain per-op mean of the z-scores; used for the diagnostics only.
    mean_vec = np.array([eval_values_std[i].mean() for i in range(N)], dtype=np.float32)

    S = float(max(1, num_samples))
    v_vec = (pos_count - neg_count) / S  # [-1,1]

    # Median of group means over at most G (and at least 1) groups of samples.
    G_eff = int(max(1, min(G, num_samples)))
    groups = np.array_split(np.arange(num_samples), G_eff)
    m_vec = np.zeros_like(v_vec)
    for i in range(N):
        gmeans = [float(eval_values_std[i, idxs].mean()) for idxs in groups if len(idxs) > 0]
        m_vec[i] = 0.0 if len(gmeans) == 0 else float(np.median(gmeans))

    score_vec = (1.0 - float(beta)) * m_vec + float(beta) * v_vec

    logging.info("====== MoM / IQR / ROSE ======")
    for i, op in enumerate(ops):
        logging.info(
            f"[DIAG] op={op:12s} | mean={mean_vec[i]:+.4f} "
            f"| MoM={m_vec[i]:+.4f} | IQRfreq={v_vec[i]:+.4f} "
            f"| ROSE={score_vec[i]:+.4f}"
        )

    # Lay the flat scores out as an (E, O) matrix.
    return score_vec[idx_mat].astype(np.float32)


def get_best_op(alpha, index_list, epoch=None, banned_ops=('none',)):
    """Pick the highest-scoring (op, edge) pair among the given edges.

    For every edge in ``index_list`` the best op is the arg-max of that
    edge's scores restricted to ops whose name is not in ``banned_ops``;
    non-finite scores are treated as ``-inf``.  When no allowed op has a
    finite score, the arg-max over all ops of that edge is used instead.
    The edge whose best score is strictly largest wins (the earliest edge is
    kept on ties).

    Args:
        alpha: Score table indexed as ``alpha[edge][op]``; objects with a
            ``detach`` attribute are converted via
            ``.detach().cpu().numpy()``, anything else via ``np.asarray``.
        index_list: Edge ids to consider.
        epoch: Not used by this function.
        banned_ops: Names from ``PRIMITIVES`` that should not be selected.

    Returns:
        ``(op_index, edge_index)``.  If ``index_list`` is empty, the arg-max
        of row 0 (ignoring NaNs) is returned together with edge 0.
    """
    alpha = alpha.detach().cpu().numpy() if hasattr(alpha, "detach") else np.asarray(alpha)

    is_allowed = np.array([name not in banned_ops for name in PRIMITIVES], dtype=bool)
    permitted = np.nonzero(is_allowed)[0]

    winner = None  # (score, edge, op) of the best candidate seen so far
    for edge in index_list:
        row = np.asarray(alpha[edge], dtype=np.float64)
        row = np.where(np.isfinite(row), row, -np.inf)
        candidates = row[permitted]

        if np.isfinite(candidates).any():
            pick = int(permitted[int(np.argmax(candidates))])
        else:
            # No usable allowed op on this edge: fall back to all ops.
            pick = int(np.nanargmax(row))
        value = float(row[pick])

        if winner is None or value > winner[0]:
            winner = (value, edge, pick)

    if winner is None:
        fallback_edge = index_list[0] if index_list else 0
        return int(np.nanargmax(alpha[fallback_edge])), fallback_edge

    _score, best_edge, best_op = winner
    return best_op, best_edge

def ranking(alpha_normal: torch.Tensor, epoch: int):
    """Derive a ``Genotype`` from the architecture parameters.

    Each of the 6 edges initially takes the op with the largest softmax
    probability.  Whenever that op is ``'none'``, the entry is replaced by
    the best non-``'none'`` (op, edge) pair that ``get_best_op`` finds among
    the edges grouped with it.

    Args:
        alpha_normal: Tensor indexed as ``[edge, op]`` with 6 edges.
        epoch: Forwarded to ``get_best_op``.

    Returns:
        ``Genotype(normal=[(op_name, u), ...], normal_concat=[1, 2, 3])``
        with one ``(op_name, u)`` entry per edge.
    """
    # Value paired with each edge id in the genotype entries.
    edge_to_u = (0, 0, 1, 0, 1, 2)
    # Edges searched for a replacement when a given edge selected 'none'.
    fallback_edges = {
        0: [0],
        1: [1, 2],
        2: [1, 2],
        3: [3, 4, 5],
        4: [3, 4, 5],
        5: [3, 4, 5],
    }

    probabilities = F.softmax(alpha_normal, dim=-1).detach().cpu().numpy()
    top_ops = np.argmax(probabilities, axis=1).tolist()
    normal = [(PRIMITIVES[top_ops[edge]], edge_to_u[edge]) for edge in range(6)]

    raw_alpha = alpha_normal.detach().cpu().numpy()
    for pos in range(6):
        if normal[pos][0] != 'none':
            continue
        best_op, best_edge = get_best_op(
            raw_alpha, fallback_edges[pos], epoch, banned_ops=('none',)
        )
        normal[pos] = (PRIMITIVES[best_op], edge_to_u[best_edge])

    return Genotype(normal=normal, normal_concat=[1, 2, 3])

@torch.no_grad()
def update_alpha(eval_values_np, prev_value, step_size=0.2, momentum=0.8):
    """Turn op scores into a momentum-smoothed update.

    The scores are standardised over all elements (zero mean, unit standard
    deviation), blended with the previous buffer as
    ``momentum * prev_value + (1 - momentum) * standardised`` and scaled by
    ``step_size``.

    Args:
        eval_values_np: NumPy array of scores.
        prev_value: Previous momentum buffer; a tensor broadcastable against
            the scores, or a plain number (e.g. ``0`` on the first call).
        step_size: Scale applied to the new buffer.
        momentum: Weight of ``prev_value`` in the new buffer.

    Returns:
        ``(delta_alpha, new_buf)`` as tensors.

    Note:
        The scores are placed on ``prev_value``'s device when it is a tensor
        with at least one dimension.  Otherwise CUDA is used when available
        and the CPU when it is not, so the function also runs on CPU-only
        hosts.
    """
    if torch.is_tensor(prev_value) and prev_value.dim() > 0:
        # Follow the buffer so the blend below never mixes devices.
        device = prev_value.device
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    v = torch.from_numpy(eval_values_np).float().to(device)
    v = (v - v.mean()) / (v.std() + 1e-8)
    new_buf = momentum * prev_value + (1 - momentum) * v
    delta_alpha = step_size * new_buf
    return delta_alpha, new_buf
