import logging

import torch
import utils
import model
import numpy as np
import torchattacks
from genotypes import Genotype, PRIMITIVES
def remove_op(normal_weights, reduce_weights, robust_weights, op):
    """Zero out one candidate operation in a cell's architecture weights.

    ``op`` is an identifier of the form ``"<cell>_<edge>_<op>"``:
    ``<cell>`` selects the container (``'normal'``, ``'reduce'``; anything
    else is routed to the robust weights), ``<edge>`` selects the row and the
    last ``_``-separated field selects the column to zero.

    The selected row is replaced by ``row * mask`` (a new tensor) which is then
    written back into the container, so the containers passed in are modified
    and also returned.

    Returns:
        ``(normal_weights, reduce_weights, robust_weights)``.
    """
    parts = op.split('_')
    cell, edge, op_idx = parts[0], int(parts[1]), int(parts[-1])

    if cell == 'normal':
        target = normal_weights
    elif cell == 'reduce':
        target = reduce_weights
    else:
        target = robust_weights

    # Build the mask from the row actually being edited so it matches that
    # cell's width, dtype and device even when the cells differ.
    mask = torch.ones_like(target[edge])
    mask[op_idx] = 0
    target[edge] = target[edge] * mask

    return normal_weights, reduce_weights, robust_weights

def _count_iqr_outliers(signals, ops, num_samples):
    """Count how often each op is an IQR outlier within a sample, per branch.

    Args:
        signals: mapping ``branch -> (len(ops), num_samples)`` score matrix.
        ops: op identifiers aligned with the matrices' rows.
        num_samples: number of sample columns to scan.

    Returns:
        ``(pos_trigger, neg_trigger)``, each ``branch -> op -> count``. Within
        one sample column an op is a positive outlier when its score is above
        ``Q3 + 1.2 * IQR`` and a negative outlier when it is below
        ``Q1 - 2.0 * IQR``.
    """
    pos_trigger = {b: {op: 0 for op in ops} for b in signals}
    neg_trigger = {b: {op: 0 for op in ops} for b in signals}
    for s in range(num_samples):
        for branch_name, data in signals.items():
            col = data[:, s]
            q1, q3 = np.percentile(col, [25, 75])
            iqr = q3 - q1
            thr_pos = q3 + 1.2 * iqr
            thr_neg = q1 - 2.0 * iqr

            for idx in np.where(col > thr_pos)[0].tolist():
                op = ops[idx]
                logging.info(f"[{branch_name.upper()} IQR POS] "
                             f"op={op} | s={s}, thr={thr_pos:.4f}, val={col[idx]:.4f}")
                pos_trigger[branch_name][op] += 1
            for idx in np.where(col < thr_neg)[0].tolist():
                op = ops[idx]
                logging.info(f"[{branch_name.upper()} IQR NEG] "
                             f"op={op} | s={s}, thr={thr_neg:.4f}, val={col[idx]:.4f}")
                neg_trigger[branch_name][op] += 1
    return pos_trigger, neg_trigger


def compute_value(valid_queue, model, ops, num_samples, beta=0.3, G=5):
    """Estimate per-op importance scores for the normal, reduce and robust cells.

    For each of ``num_samples`` Monte-Carlo samples, one batch is taken from
    ``valid_queue`` and a PGD adversarial copy is generated (eps=8/255, 7
    steps). The ops are then removed one by one in a random permutation
    (``remove_op``; removals accumulate), and each op is credited with the
    *marginal* drop in clean and adversarial accuracy its removal causes
    relative to the previous step, i.e. a permutation-sampling estimate of its
    Shapley value. Each sample's contributions are z-scored across ops.

    Per branch, the signal is clean accuracy (normal), adversarial accuracy
    (robust) or their 50/50 mix (reduce). The final score of an op is
    ``(1 - beta) * MoM + beta * (#pos_outliers - #neg_outliers) / num_samples``
    where MoM is the median of the means over ``G`` groups of samples.

    Args:
        valid_queue: iterable of ``(input, target)`` batches; a fresh iterator
            is created per sample and only its first batch is used.
        model: supernet providing ``get_projected_weights(cell)``,
            ``num_edges``, ``num_ops`` and a forward accepting
            ``weights_dict=...`` and ``rm_key=True``. It is left in eval mode.
        ops: op identifiers ``"<cell>_<edge>_<op>"``.
        num_samples: number of sampled permutations (must be >= 1).
        beta: weight of the outlier-frequency term.
        G: number of median-of-means groups (clamped to ``[1, num_samples]``).

    Returns:
        ``(normal_values, reduce_values, robust_values)``, each a numpy array
        of shape ``(model.num_edges, model.num_ops)``; ops of other cells are
        ignored.

    Raises:
        ValueError: if ``num_samples < 1``.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    eps = 1e-8
    op2idx = {op: idx for idx, op in enumerate(ops)}
    permutations = [np.random.permutation(ops) for _ in range(num_samples)]

    N = len(ops)
    eval_values_std = np.zeros((N, num_samples))
    eval_values_adv = np.zeros((N, num_samples))
    pgd = torchattacks.PGD(model, eps=8/255, steps=7)

    for s in range(num_samples):
        input, target = next(iter(valid_queue))
        input, target = input.cuda(non_blocking=True), target.cuda(non_blocking=True)
        model.eval()

        with torch.no_grad():
            std_base, = utils.accuracy(model(input), target)
        std_base = std_base.item()
        print(f"[ComputeValue] std_acc {std_base:.4f}")

        # The attack needs input gradients, so it runs with autograd enabled.
        input_adv = pgd(input.clone().detach().requires_grad_(True), target)
        with torch.no_grad():
            adv_base, = utils.accuracy(model(input_adv), target)
        adv_base = adv_base.item()
        print(f"[ComputeValue] adv_acc {adv_base:.4f}")

        nw = model.get_projected_weights('normal')
        rw = model.get_projected_weights('reduce')
        bw = model.get_projected_weights('robust')

        prev_std, prev_adv = std_base, adv_base
        with torch.no_grad():
            for op in permutations[s]:
                real_idx = op2idx[op]
                nw, rw, bw = remove_op(nw, rw, bw, op)

                # A fresh weights_dict per forward call.
                # NOTE(review): unclear whether rm_key=True lets the model mutate it.
                cur_std, = utils.accuracy(
                    model(input, weights_dict={'normal': nw, 'reduce': rw, 'robust': bw}, rm_key=True),
                    target,
                )
                cur_std = cur_std.item()
                cur_adv, = utils.accuracy(
                    model(input_adv, weights_dict={'normal': nw, 'reduce': rw, 'robust': bw}, rm_key=True),
                    target,
                )
                cur_adv = cur_adv.item()

                # Marginal contribution: accuracy lost by removing *this* op on
                # top of the ops removed earlier in the permutation.
                delta_std = prev_std - cur_std
                delta_adv = prev_adv - cur_adv
                eval_values_std[real_idx, s] = delta_std
                print("deleta_std ", delta_std)
                eval_values_adv[real_idx, s] = delta_adv
                print("delta__adv ", delta_adv)
                prev_std, prev_adv = cur_std, cur_adv

        # Standardise this sample's contributions across ops.
        for mat in (eval_values_std, eval_values_adv):
            col = mat[:, s]
            mat[:, s] = (col - col.mean()) / (col.std() + eps)

    signals = {
        'normal': eval_values_std,
        'reduce': 0.5 * eval_values_std + 0.5 * eval_values_adv,
        'robust': eval_values_adv,
    }
    pos_trigger, neg_trigger = _count_iqr_outliers(signals, ops, num_samples)

    # More groups than samples would leave empty groups whose mean is NaN.
    groups = np.array_split(np.arange(num_samples), max(1, min(G, num_samples)))

    E, O = model.num_edges, model.num_ops
    values = {branch: np.zeros((E, O)) for branch in signals}
    for idx, op in enumerate(ops):
        cell, eid_str, opid_str = op.split('_')
        if cell not in values:
            continue
        arr = signals[cell][idx]
        mom = float(np.median([arr[g].mean() for g in groups]))
        outlier_score = (pos_trigger[cell][op] - neg_trigger[cell][op]) / num_samples
        final = (1 - beta) * mom + beta * outlier_score
        logging.info(
            f"[{cell.upper()} SCORE] op={op:15s}  "
            f"MoM={mom:.4f}, outlier={outlier_score:.4f}  "
            f"final={final:.4f}"
        )
        values[cell][int(eid_str), int(opid_str)] = final

    return values['normal'], values['reduce'], values['robust']
def update_alpha(eval_values, prev_value, step_size=0.1, momentum=0.8, device='cuda'):
    """Turn per-cell op scores into momentum-smoothed alpha increments.

    Each ``eval_values[i]`` is moved to ``device``, z-scored over all of its
    entries (``(x - mean) / (std + 1e-8)``, torch's default unbiased std), and
    blended with ``prev_value[i]`` as
    ``momentum * prev_value[i] + (1 - momentum) * z``.

    Args:
        eval_values: sequence of at least three numpy arrays (normal, reduce,
            robust scores, in that order).
        prev_value: per-cell previous values, indexable like ``eval_values``
            and broadcast-compatible with each array.
        step_size: scale applied to the blended values.
        momentum: weight of ``prev_value``.
        device: torch device for the computation (default ``'cuda'``, as before).

    Returns:
        ``[delta_normal, delta_reduce, delta_robust]``, each ``step_size``
        times the blended value.

    NOTE(review): the return is scaled by ``step_size``. If a caller feeds it
    back as ``prev_value`` next round, the momentum state shrinks by
    ``step_size`` each time. Confirm against the caller.
    """
    eps = 1e-8
    inc = []
    for i, raw in enumerate(eval_values):
        scores = torch.from_numpy(np.asarray(raw)).to(device)
        # Out-of-place on purpose: on CPU, from_numpy shares memory with the
        # caller's array, and in-place ops would overwrite it.
        scores = (scores - scores.mean()) / (scores.std() + eps)
        inc.append(momentum * prev_value[i] + (1 - momentum) * scores)

    return [step_size * inc[0], step_size * inc[1], step_size * inc[2]]
def get_best_op(alpha, index_list, epoch):
    """Choose the best (op, edge) pair among the top-two ops of the given edges.

    For every edge ``k`` in ``index_list``, the two highest-scoring op indices
    of ``alpha[k]`` (via ``np.argsort``) become candidates. Candidates are
    ranked by score, highest first. The first one whose primitive is not
    ``'none'``, ``'max_pool_3x3'`` or ``'avg_pool_3x3'`` wins; if every
    candidate is banned, the top-scoring candidate is returned anyway.

    Note that only each edge's top two ops are considered, so a lower-ranked
    non-banned op on an edge is never picked. ``epoch`` is unused.

    Returns:
        ``(op_index, edge_index)``.
    """
    banned = {'none', 'max_pool_3x3', 'avg_pool_3x3'}
    candidates = [
        (alpha[edge][op_id], edge, op_id)
        for edge in index_list
        for op_id in np.argsort(alpha[edge])[-2:]
    ]
    # sorted() is stable even with reverse=True, so equal scores keep
    # their insertion order.
    ranked = sorted(candidates, key=lambda cand: cand[0], reverse=True)
    chosen = next(
        (cand for cand in ranked if PRIMITIVES[cand[2]] not in banned),
        ranked[0],
    )
    _, best_edge, best_op = chosen
    return best_op, best_edge



# First flat edge index of each of the four intermediate nodes of a cell.
# Node j has j + 2 incoming edges: 2 + 3 + 4 + 5 = 14 edges in total.
_NODE_EDGE_STARTS = (0, 2, 5, 9)

# Flat edge index -> (node, input index within that node).
_EDGE_TO_NODE_INPUT = {
    start + inp: (node, inp)
    for node, start in enumerate(_NODE_EDGE_STARTS)
    for inp in range(node + 2)
}


def _derive_cell(alpha, slots, epoch):
    """Fill the eight ``[op_name, input]`` slots of one cell in place.

    Phase 1 (greedy): every edge proposes its argmax op. Proposals are taken in
    descending order of that max score (ties go to the lower edge index), and
    each node keeps its first two. Edges outside 0-13 are ignored.

    Phase 2 (repair): any slot still holding ``'none'`` is refilled with
    ``get_best_op`` over the node's edges. The edge already used by the node's
    other slot is excluded, so a node never takes the same input twice.

    Args:
        alpha: 2-D numpy array of scores, one row per edge.
        slots: list of eight mutable ``[op_name, input]`` pairs (two per node).
        epoch: forwarded to ``get_best_op``.
    """
    best_ops = [int(np.argmax(row)) for row in alpha]
    best_vals = np.array([np.max(row) for row in alpha])

    filled = [0] * len(_NODE_EDGE_STARTS)
    for edge in np.argsort(-best_vals, kind='stable'):
        edge = int(edge)
        if edge not in _EDGE_TO_NODE_INPUT:
            continue
        node, inp = _EDGE_TO_NODE_INPUT[edge]
        if filled[node] >= 2:
            continue
        slot = slots[2 * node + filled[node]]
        slot[0] = PRIMITIVES[best_ops[edge]]
        slot[1] = inp
        filled[node] += 1

    for i, slot in enumerate(slots):
        if slot[0] != 'none':
            continue
        node = i // 2
        start = _NODE_EDGE_STARTS[node]
        candidates = list(range(start, start + node + 2))
        sibling = slots[i ^ 1]
        if sibling[0] != 'none':
            used_edge = start + sibling[1]
            candidates = [e for e in candidates if e != used_edge]
        best_op, best_edge = get_best_op(alpha, candidates, epoch)
        slot[0] = PRIMITIVES[best_op]
        slot[1] = best_edge - start


def ranking(alpha_normal, alpha_reduce, alpha_robust, epoch):
    """Derive a discrete ``Genotype`` from the three cells' architecture scores.

    Each ``alpha_*`` is moved to CPU and converted to numpy. It is expected to
    have one row of op scores per edge, with 14 edges grouped into nodes of
    2, 3, 4 and 5 incoming edges (presumably the DARTS cell layout; confirm
    against the model). Each cell gets two ``[op_name, input]`` entries per
    node; see ``_derive_cell`` for the selection rules. All concat lists are
    ``[2, 3, 4, 5]``.

    Args:
        alpha_normal, alpha_reduce, alpha_robust: per-cell score tensors.
        epoch: forwarded to ``get_best_op``.

    Returns:
        The derived ``Genotype``.
    """
    alphas = [a.cpu().numpy() for a in (alpha_normal, alpha_reduce, alpha_robust)]

    def blank_cell():
        # Independent inner lists: each slot is mutated separately.
        return [['none', i % 2] for i in range(8)]

    genotype = Genotype(normal=blank_cell(), normal_concat=[2, 3, 4, 5],
                        reduce=blank_cell(), reduce_concat=[2, 3, 4, 5],
                        robust=blank_cell(), robust_concat=[2, 3, 4, 5])

    for alpha, slots in zip(alphas, (genotype.normal, genotype.reduce, genotype.robust)):
        _derive_cell(alpha, slots, epoch)

    return genotype