import numpy as np
import tensorflow as tf
from datasets.NASBench301 import NasBench301Dataset

def inverse_from_acc(model: tf.keras.Model, num_sample_z: int, x_dim: int, z_dim: int, to_inv_acc,
                     noise_std=0., version=2):
    """Run the model's inverse pass on target values and decode architectures.

    Each row of ``to_inv_acc`` is repeated ``num_sample_z`` times, perturbed
    with Gaussian noise (stddev 0.003), prefixed with a standard-normal ``z``
    of width ``z_dim`` and passed to ``model.inverse``. The result is reshaped
    per node, perturbed with ``noise_std`` noise and decoded. Decoded operation
    scores are summed and adjacency values averaged over the ``num_sample_z``
    axis before being discretised.

    Args:
        model: Object exposing ``inverse``, ``decode``, ``num_nodes`` and
            ``num_ops``. ``num_nvp`` is optional and defaults to 1.
        num_sample_z: Number of latent samples drawn per target row.
        x_dim: Not used by the implemented ``version=2`` path.
        z_dim: Width of the sampled ``z`` vector.
        to_inv_acc: Tensor of target values; its first dimension is taken as
            the batch size. Presumably shaped ``(batch, 1)`` - TODO confirm
            against all callers.
        noise_std: Stddev of the Gaussian noise added to the latent before
            decoding.
        version: Only ``2`` is implemented.

    Returns:
        Tuple ``(ops_idx_list, adj_list, rev_latent_flattened)``:
        per-sample lists of argmax operation indices, per-sample binarised
        adjacency arrays, and the latent reshaped to two dimensions.

    Raises:
        NotImplementedError: If ``version == 1``.
        ValueError: If ``version`` is neither 1 nor 2.
    """
    batch_size = int(tf.shape(to_inv_acc)[0])
    # Only a missing attribute falls back to 1; any other error propagates.
    num_nvp = getattr(model, 'num_nvp', 1)

    y = tf.repeat(to_inv_acc, num_sample_z, axis=0)  # batch_size * num_sample_z rows
    y = y + tf.random.normal(tf.shape(y), stddev=0.003)
    # Sampled z goes first, the (noisy) target last.
    z = tf.random.normal((batch_size * num_sample_z, z_dim), stddev=1)
    y = tf.concat([z, y], axis=-1)

    rev_latent = model.inverse(y)
    if version == 1:
        raise NotImplementedError
    elif version == 2:
        # One latent vector per node; the last dimension is inferred.
        rev_latent = tf.reshape(rev_latent, (batch_size * num_nvp, model.num_nodes, -1))
    else:
        raise ValueError('version')

    decode_noise = tf.random.normal(tf.shape(rev_latent), stddev=noise_std)
    _, adj, ops_cls, _ = model.decode(rev_latent + decode_noise)

    # Sum the operation scores over the num_sample_z axis (a vote across samples).
    ops_cls = tf.reshape(ops_cls, (batch_size * num_nvp, num_sample_z, -1, model.num_ops))
    ops_vote = tf.reduce_sum(ops_cls, axis=1).numpy()

    # Keep an edge when its mean over the num_sample_z axis is at least 0.5.
    adj = tf.reshape(adj, (batch_size * num_nvp, num_sample_z, model.num_nodes, model.num_nodes))
    adj = tf.where(tf.reduce_mean(adj, axis=1) >= 0.5, x=1., y=0.).numpy()

    ops_idx_list = [np.argmax(i, axis=-1).tolist() for i in ops_vote]
    adj_list = list(adj)
    rev_latent_flattened = tf.reshape(rev_latent, (rev_latent.shape[0], -1))

    return ops_idx_list, adj_list, rev_latent_flattened

def eval_query_target(model: tf.keras.Model, x_dim: int, z_dim: int, query_amount=10, noise_scale=0.0,
                    version=2, target=0.949, hash_version=0):
    """Query the model for architectures at a target value and score them.

    ``target`` is replicated ``query_amount`` times and inverted through
    ``inverse_from_acc`` with one latent sample per query. Each decoded
    adjacency is converted into a 22x22 two-cell matrix and looked up through
    ``NasBench301Dataset``.

    Args:
        model: Model forwarded to ``inverse_from_acc``; ``num_ops`` is read
            here to one-hot encode the operations.
        x_dim: Forwarded to ``inverse_from_acc``.
        z_dim: Forwarded to ``inverse_from_acc``.
        query_amount: Number of copies of ``target`` to invert.
        noise_scale: Passed as ``noise_std`` to ``inverse_from_acc``.
        version: Forwarded to ``inverse_from_acc``.
        target: Value to invert.
        hash_version: Forwarded to ``NasBench301Dataset.get_hash_by_darts_cell``.

    Returns:
        Dict keyed by architecture hash. Each value holds ``'x'`` (one-hot
        operations, float32), ``'a'`` (22x22 adjacency, float32), ``'y'``
        (one-element float32 array with the ``get_prediction`` result) and
        ``'latent'``. A later candidate with the same hash overwrites the
        earlier one. Candidates that raise during conversion or lookup are
        skipped.
    """
    found_arch_list = {}
    to_inv = tf.repeat(tf.reshape(tf.constant(float(target)), [-1, 1]), query_amount, axis=0)
    ops_idx_lis, adj_list, rev_latent = inverse_from_acc(model, num_sample_z=1, x_dim=x_dim, z_dim=z_dim,
                                             noise_std=noise_scale, to_inv_acc=to_inv, version=version)

    for ops_idx, decoded_adj, encoding in zip(ops_idx_lis, adj_list, rev_latent):
        try:
            A_darts = to_darts_graph(decoded_adj[:11, :11], darts_groups)
            A_fixed = fix_darts_adj(A_darts,
                                    input_nodes=[0, 1],  # (D0, D1)
                                    output_node=6)       # (D6)
            m_11_normal = from_darts_graph_sequential(A_fixed, darts_groups)

            # NOTE(review): the reduced cell is built from the same [:11, :11]
            # slice as the normal cell. The two can still differ because
            # fix_darts_adj samples randomly. Confirm whether [11:, 11:] was
            # intended.
            A_darts = to_darts_graph(decoded_adj[:11, :11], darts_groups)
            A_fixed = fix_darts_adj(A_darts,
                                    input_nodes=[0, 1],  # (D0, D1)
                                    output_node=6)       # (D6)
            m_11_reduced = from_darts_graph_sequential(A_fixed, darts_groups)

            # Block-diagonal layout: normal cell top-left, reduced cell bottom-right.
            cell_adj = np.zeros((22, 22), dtype=float)
            cell_adj[:11, :11] = m_11_normal
            cell_adj[11:, 11:] = m_11_reduced

            acc = NasBench301Dataset.get_prediction(cell_adj, ops_idx)
            hash_number = NasBench301Dataset.get_hash_by_darts_cell(cell_adj, ops_idx, hash_version)
            ops = np.eye(model.num_ops)[ops_idx]
            found_arch_list[hash_number] = {'x': np.array(ops).astype(np.float32),
                                            'a': np.array(cell_adj).astype(np.float32),
                                            'y': np.array([acc]).astype(np.float32),
                                            'latent': encoding}
        except Exception:
            # Best effort: skip a candidate that cannot be converted or scored.
            # KeyboardInterrupt and SystemExit are deliberately not caught.
            continue

    return found_arch_list

def eval_from_lat(model: tf.keras.Model, rev_latent, batch_size, num_nodes, latent_dim, hash_version=0):
    """Decode latent vectors into architectures and score them, keyed by hash.

    Args:
        model: Object exposing ``decode``, ``num_nodes`` and ``num_ops``.
            ``num_nvp`` is optional and defaults to 1.
        rev_latent: Latent tensor; reshaped to
            ``(batch_size, num_nodes, latent_dim)`` before decoding.
        batch_size: Leading dimension used for the reshape.
        num_nodes: Node dimension used for the reshape.
        latent_dim: Per-node latent width used for the reshape.
        hash_version: Forwarded to ``NasBench301Dataset.get_hash_by_darts_cell``.

    Returns:
        Tuple ``(found_arch_list, record)``. ``found_arch_list`` is a dict
        keyed by architecture hash whose values hold ``'x'``, ``'a'``, ``'y'``,
        ``'latent'`` and ``'origin_index'``; a later candidate with the same
        hash overwrites the earlier one. ``record`` lists the position of
        every candidate that was scored successfully, overwritten ones
        included.
    """
    num_samples = rev_latent.shape[0]
    rev_latent = tf.reshape(rev_latent, (batch_size, num_nodes, latent_dim))

    # Only a missing attribute falls back to 1.
    num_nvp = getattr(model, 'num_nvp', 1)

    # Decode the latent representation
    print("rev latent", tf.shape(rev_latent))
    _, adj, ops_cls, _ = model.decode(rev_latent)

    # The singleton axis mirrors the sample axis used elsewhere; summing or
    # averaging over it leaves the values unchanged.
    ops_cls = tf.reshape(ops_cls, (batch_size * num_nvp, 1, -1, model.num_ops))
    ops_vote = tf.reduce_sum(ops_cls, axis=1).numpy()

    adj = tf.reshape(adj, (batch_size * num_nvp, 1, model.num_nodes, model.num_nodes))
    adj = tf.where(tf.reduce_mean(adj, axis=1) >= 0.5, x=1., y=0.).numpy()  # binarise

    ops_idx_list = [np.argmax(i, axis=-1).tolist() for i in ops_vote]
    adj_list = list(adj)

    found_arch_list = {}
    invalid = 0
    record = []

    for idx, (ops_idx, decoded_adj, encoding) in enumerate(zip(ops_idx_list, adj_list, rev_latent)):
        try:
            A_darts = to_darts_graph(decoded_adj[:11, :11], darts_groups)
            A_fixed = fix_darts_adj(A_darts,
                                    input_nodes=[0, 1],  # (D0, D1)
                                    output_node=6)       # (D6)
            m_11_normal = from_darts_graph_sequential(A_fixed, darts_groups)

            # NOTE(review): the reduced cell is built from the same [:11, :11]
            # slice as the normal cell. Confirm whether [11:, 11:] was intended.
            A_darts = to_darts_graph(decoded_adj[:11, :11], darts_groups)
            A_fixed = fix_darts_adj(A_darts,
                                    input_nodes=[0, 1],  # (D0, D1)
                                    output_node=6)       # (D6)
            m_11_reduced = from_darts_graph_sequential(A_fixed, darts_groups)

            # Block-diagonal layout: normal cell top-left, reduced cell bottom-right.
            cell_adj = np.zeros((22, 22), dtype=float)
            cell_adj[:11, :11] = m_11_normal
            cell_adj[11:, 11:] = m_11_reduced

            acc = NasBench301Dataset.get_prediction(cell_adj, ops_idx)
            hash_number = NasBench301Dataset.get_hash_by_darts_cell(cell_adj, ops_idx, hash_version)

            ops = np.eye(model.num_ops)[ops_idx]

            found_arch_list[hash_number] = {
                'x': np.array(ops).astype(np.float32),
                'a': np.array(cell_adj).astype(np.float32),
                'y': np.array([acc]).astype(np.float32),
                'latent': encoding,
                'origin_index': idx
            }
            record.append(idx)
        except Exception:
            # Best effort: count and skip a candidate that cannot be converted
            # or scored. KeyboardInterrupt and SystemExit are not caught.
            invalid += 1

    print(f"Invalid architectures: {invalid}/{len(found_arch_list)}/{num_samples}")

    return found_arch_list, record

def eval_from_lat_v2(model: tf.keras.Model, rev_latent, batch_size, num_nodes, latent_dim, hash_version=0):
    """Decode latent vectors into architectures and score them, keeping duplicates.

    Unlike a hash-keyed dict, the result is a list, so candidates that share
    a hash are all retained; the hash is stored on each entry.

    Args:
        model: Object exposing ``decode``, ``num_nodes`` and ``num_ops``.
            ``num_nvp`` is optional and defaults to 1.
        rev_latent: Latent tensor; reshaped to
            ``(batch_size, num_nodes, latent_dim)`` before decoding.
        batch_size: Leading dimension used for the reshape.
        num_nodes: Node dimension used for the reshape.
        latent_dim: Per-node latent width used for the reshape.
        hash_version: Forwarded to ``NasBench301Dataset.get_hash_by_darts_cell``.

    Returns:
        Tuple ``(found_arch_list, record)``. ``found_arch_list`` is a list of
        dicts with ``'x'``, ``'a'``, ``'y'``, ``'latent'``, ``'origin_index'``
        and ``'hash'``. ``record`` lists the position of every candidate that
        was scored successfully.
    """
    num_samples = rev_latent.shape[0]
    rev_latent = tf.reshape(rev_latent, (batch_size, num_nodes, latent_dim))

    # Only a missing attribute falls back to 1.
    num_nvp = getattr(model, 'num_nvp', 1)

    # Decode the latent representation
    print("rev latent", tf.shape(rev_latent))
    _, adj, ops_cls, _ = model.decode(rev_latent)

    # The singleton axis mirrors the sample axis used elsewhere; summing or
    # averaging over it leaves the values unchanged.
    ops_cls = tf.reshape(ops_cls, (batch_size * num_nvp, 1, -1, model.num_ops))
    ops_vote = tf.reduce_sum(ops_cls, axis=1).numpy()

    adj = tf.reshape(adj, (batch_size * num_nvp, 1, model.num_nodes, model.num_nodes))
    adj = tf.where(tf.reduce_mean(adj, axis=1) >= 0.5, x=1., y=0.).numpy()  # binarise

    ops_idx_list = [np.argmax(i, axis=-1).tolist() for i in ops_vote]
    adj_list = list(adj)

    found_arch_list = []
    invalid = 0
    record = []

    for idx, (ops_idx, decoded_adj, encoding) in enumerate(zip(ops_idx_list, adj_list, rev_latent)):
        try:
            A_darts = to_darts_graph(decoded_adj[:11, :11], darts_groups)
            A_fixed = fix_darts_adj(A_darts,
                                    input_nodes=[0, 1],  # (D0, D1)
                                    output_node=6)       # (D6)
            m_11_normal = from_darts_graph_sequential(A_fixed, darts_groups)

            # NOTE(review): the reduced cell is built from the same [:11, :11]
            # slice as the normal cell. Confirm whether [11:, 11:] was intended.
            A_darts = to_darts_graph(decoded_adj[:11, :11], darts_groups)
            A_fixed = fix_darts_adj(A_darts,
                                    input_nodes=[0, 1],  # (D0, D1)
                                    output_node=6)       # (D6)
            m_11_reduced = from_darts_graph_sequential(A_fixed, darts_groups)

            # Block-diagonal layout: normal cell top-left, reduced cell bottom-right.
            cell_adj = np.zeros((22, 22), dtype=float)
            cell_adj[:11, :11] = m_11_normal
            cell_adj[11:, 11:] = m_11_reduced

            acc = NasBench301Dataset.get_prediction(cell_adj, ops_idx)

            hash_number = NasBench301Dataset.get_hash_by_darts_cell(cell_adj, ops_idx, hash_version)

            ops = np.eye(model.num_ops)[ops_idx]

            found_arch_list.append({
                'x': np.array(ops).astype(np.float32),
                'a': np.array(cell_adj).astype(np.float32),
                'y': np.array([acc]).astype(np.float32),
                'latent': encoding,
                'origin_index': idx,
                'hash': hash_number
            })
            record.append(idx)
        except Exception:
            # Best effort: count and skip a candidate that cannot be converted
            # or scored. KeyboardInterrupt and SystemExit are not caught.
            invalid += 1

    print(f"Invalid architectures: {invalid}/{len(found_arch_list)}/{num_samples}")

    return found_arch_list, record

import numpy as np
import random
from typing import List

# Membership table: entry ``g`` lists the row/column indices of the 11-node
# adjacency matrix that belong to DARTS-level node ``g`` (7 groups in total).
# Callers in this file treat groups 0 and 1 as inputs and group 6 as the output.
darts_groups = [[0], [1], [2, 3], [4, 5], [6, 7], [8, 9], [10]]

def to_darts_graph(adj_11: np.ndarray,
                   groups: List[List[int]]) -> np.ndarray:
    """Collapse a fine-grained adjacency matrix into a group-level one.

    Args:
        adj_11: Square adjacency matrix indexed by the members of ``groups``.
        groups: For each group, the row/column indices of ``adj_11`` it owns.

    Returns:
        Integer matrix of shape ``(len(groups), len(groups))`` where entry
        ``[i, j]`` is 1 when any member of group ``i`` has an entry equal to
        1 towards any member of group ``j`` (``i != j``), else 0.
    """
    n_groups = len(groups)
    collapsed = np.zeros((n_groups, n_groups), dtype=int)

    for src, src_members in enumerate(groups):
        for dst, dst_members in enumerate(groups):
            # Edges inside a group are ignored, so the diagonal stays zero.
            if src != dst and np.any(adj_11[np.ix_(src_members, dst_members)] == 1):
                collapsed[src, dst] = 1

    return collapsed


def fix_darts_adj(A_D: np.ndarray,
                  input_nodes: List[int],
                  output_node: int,
                  seed: int = None) -> np.ndarray:
    """Repair a group-level adjacency matrix so intermediate nodes have two inputs.

    Works on a copy of ``A_D``. Input nodes lose all incoming edges and their
    direct edge to the output. Every intermediate node keeps only edges from
    lower-numbered nodes: surplus inputs beyond the two lowest-numbered are
    dropped, and missing inputs are filled by random choice among the
    lower-numbered nodes not yet connected.

    Args:
        A_D: Square group-level adjacency matrix; not modified.
        input_nodes: Indices of the input nodes.
        output_node: Index of the output node.
        seed: When given, seeds the global ``random`` module before sampling.

    Returns:
        The repaired adjacency matrix (same shape and dtype as ``A_D``).
    """
    if seed is not None:
        random.seed(seed)

    A_fix = A_D.copy()
    D = A_fix.shape[0]

    # Input nodes receive no edges and never feed the output directly.
    for inp in input_nodes:
        A_fix[:, inp] = 0
        A_fix[inp, output_node] = 0

    mid_nodes = set(range(D)) - set(input_nodes) - {output_node}

    # Sorted so the order of random draws (and thus seeded results) does not
    # depend on set iteration order.
    for m in sorted(mid_nodes):
        # Only lower-numbered nodes may feed node m.
        A_fix[m+1:, m] = 0
        current_in = [x for x in range(m) if A_fix[x, m] == 1]
        if len(current_in) > 2:
            # Keep the two lowest-numbered inputs.
            for x in current_in[2:]:
                A_fix[x, m] = 0
        elif len(current_in) < 2:
            need = 2 - len(current_in)
            candidates = [c for c in range(m) if A_fix[c, m] == 0]
            # Clamp: with fewer candidates than needed, connect what exists
            # instead of letting random.sample raise ValueError.
            chosen = random.sample(candidates, min(need, len(candidates)))
            for x in chosen:
                A_fix[x, m] = 1

    # NOTE(review): the output wiring is hard-coded to nodes 2-5 rather than
    # derived from mid_nodes; the two agree for the 7-node layout with inputs
    # [0, 1] and output 6. Confirm before using other layouts.
    for m in [2, 3, 4, 5]:
        A_fix[m, output_node] = 1

    return A_fix

def from_darts_graph_sequential(A_D: np.ndarray,
                                groups: List[List[int]]) -> np.ndarray:
    """Expand a group-level adjacency matrix back into a fine-grained one.

    For each of groups 2-5, the first (at most two) lower-numbered groups
    with an edge into it are assigned, in order, to that group's successive
    members; every member of the source group is connected to the assigned
    member. Rows 2-9 are then all connected to column 10.

    Args:
        A_D: Group-level adjacency matrix.
        groups: For each group, the fine-grained indices it owns.

    Returns:
        Integer matrix whose side equals the total number of indices in
        ``groups``.
    """
    total_nodes = sum(len(members) for members in groups)
    expanded = np.zeros((total_nodes, total_nodes), dtype=int)

    for dst in range(2, 6):
        filled = 0  # how many members of the destination group are taken
        for src in range(dst):
            if A_D[src, dst] != 1:
                continue
            src_members, dst_members = groups[src], groups[dst]
            for member in src_members:
                expanded[member, dst_members[filled]] = 1
            filled += 1
            if filled == 2:
                break

    # Every index in 2..9 feeds index 10.
    expanded[2:10, 10] = 1

    return expanded