import torch
from torch.utils.data import Dataset
import numpy as np
from darts.genotypes import Genotype, PRIMITIVES


def sample_darts_arch(available_ops):
    """Sample a random DARTS genotype (a normal cell and a reduction cell).

    Each cell has 4 intermediate nodes. Intermediate node ``i`` picks two
    operations from ``available_ops`` (with replacement) and two distinct,
    sorted input nodes from ``0 .. i + 1``. All four intermediate nodes
    (ids 2-5) are concatenated into the cell output.

    Args:
        available_ops: sequence of operation names to sample from.

    Returns:
        A ``Genotype`` whose ``normal`` and ``reduce`` lists each hold 8
        ``(op_name, input_node)`` pairs as plain Python ``str``/``int`` values,
        with ``normal_concat == reduce_concat == [2, 3, 4, 5]``.
    """
    cells = []
    for _ in range(2):  # normal cell first, then reduction cell
        cell = []
        for i in range(4):
            ops = np.random.choice(available_ops, 2)
            inputs = sorted(np.random.choice(range(i + 2), 2, replace=False))
            # np.random.choice yields NumPy scalars; convert to builtins so the
            # genotype's repr stays plain Python (NumPy >= 2 would print
            # np.str_(...)/np.int64(...), breaking str()/eval() round-trips).
            cell.extend((str(op), int(node)) for op, node in zip(ops, inputs))
        cells.append(cell)
    normal_cell, reduce_cell = cells
    return Genotype(normal=normal_cell, normal_concat=[2, 3, 4, 5],
                    reduce=reduce_cell, reduce_concat=[2, 3, 4, 5])


def darts_to_nasbench101(genotype):
    """Encode a DARTS genotype's normal cell as an ops-on-vertices graph.

    Vertex layout of the cell: 0 = ``'input1'``, 1 = ``'input2'``, then one
    vertex per ``(op_name, input_node)`` edge of ``genotype.normal`` (in
    order), and finally ``'output'``. Every vertex that realizes a DARTS node
    feeds each op that reads from that node. Every vertex of the nodes listed
    in ``genotype.normal_concat`` feeds ``'output'``.

    Only the normal cell is encoded. The cell's ``n x n`` adjacency is placed
    in the top-left corner of a ``2n x 2n`` uint8 matrix whose remaining
    entries are zero. ``genotype.reduce`` is not encoded.

    Returns:
        ``{'adj': <(2n, 2n) uint8 ndarray>, 'ops': <list of n labels>}`` where
        ``n = len(genotype.normal) + 3``.
    """
    # DARTS node id -> graph vertices realizing it (two op vertices per
    # intermediate node; node 6 maps to the output vertex of an 8-edge cell).
    node_to_vertices = [[0], [1], [2, 3], [4, 5], [6, 7], [8, 9], [10]]

    edges = genotype.normal
    n = len(edges) + 3
    cell_adj = np.zeros((n, n), dtype=np.uint8)
    labels = ['input1', 'input2', 'output']

    for edge in edges:
        op_name, src_node = edge[0], edge[1]
        # The new op vertex takes the slot currently held by 'output'.
        dst = len(labels) - 1
        cell_adj[node_to_vertices[src_node], dst] = 1
        labels.insert(-1, op_name)

    out_sources = [v for node in genotype.normal_concat for v in node_to_vertices[node]]
    cell_adj[out_sources, -1] = 1

    padded = np.zeros((2 * n, 2 * n), dtype=np.uint8)
    padded[:n, :n] = cell_adj
    return {'adj': padded, 'ops': labels}


# Integer id for every vertex label that can appear in an encoded cell's
# 'ops' list: the three structural vertices first, then the candidate
# operations. Ids follow list order, so the order below must not change.
op_list = {
    label: op_id
    for op_id, label in enumerate([
        'input1',
        'input2',
        'output',
        'sep_conv_3x3',
        'sep_conv_5x5',
        'dil_conv_3x3',
        'dil_conv_5x5',
        'avg_pool_3x3',
        'max_pool_3x3',
        'skip_connect',
        'none',
    ])
}





class Darts_dataset(Dataset):
    """Dataset of DARTS architectures encoded as ops-on-vertices graphs.

    In ``'train'`` mode the data is supplied by the caller. ``source_arch`` is
    an indexable collection of dicts with at least ``'adj'``, ``'ops'``,
    ``'val_acc'`` and ``'test_acc'`` keys. ``raw_genotype`` holds the matching
    genotypes under the same indices.

    In any other mode, ``split`` random genotypes are sampled with
    ``sample_darts_arch`` and encoded with ``darts_to_nasbench101``. Their
    reported accuracies are 0.

    Args:
        split: number of architectures to sample when not in ``'train'`` mode.
        datatype: ``'train'`` to wrap caller-provided data; anything else samples.
        source_arch: encoded architectures for ``'train'`` mode.
        raw_genotype: genotypes parallel to ``source_arch`` for ``'train'`` mode.
            ``None`` (or the legacy string ``'None'``) means "not provided".
    """

    # Edge lists shorter than this are zero-padded with (0, 0) pairs up to
    # this length; longer lists are returned unpadded.
    max_edges = 42

    def __init__(self, split=0, datatype='train', source_arch=None, raw_genotype=None):
        self.data_type = datatype
        self.source_arch = []
        self.sample_range = split
        self.genotype = []
        # Width of the one-hot operation encoding: one column per id in op_list.
        self.candidate_ops = len(op_list)
        if datatype == 'train':
            self.source_arch = self._or_empty(source_arch)
            self.genotype = self._or_empty(raw_genotype)
        else:
            print("creating test arch...")
            # Integer step (~1% of the run, at least 1) so the progress check is
            # exact; a float step such as split / 100 rarely divides i exactly.
            report_every = max(1, self.sample_range // 100)
            for i in range(1, 1 + self.sample_range):
                # NOTE(review): drops PRIMITIVES[0] (presumably 'none' in DARTS) — confirm.
                genotype = sample_darts_arch(PRIMITIVES[1:])
                self.genotype.append(genotype)
                self.source_arch.append(darts_to_nasbench101(genotype))
                if i % report_every == 0:
                    print("have done:", i)

    @staticmethod
    def _or_empty(value):
        """Return ``[]`` for a missing argument (``None`` or legacy ``'None'``), else ``value``."""
        if value is None or (isinstance(value, str) and value == 'None'):
            return []
        return value

    def __len__(self):
        if self.data_type == 'train':
            return len(self.source_arch)
        else:
            return self.sample_range

    @classmethod
    def _build_edge_index(cls, adj):
        """Turn adjacency matrix ``adj`` into ``(edge_index, edge_num)``.

        ``edge_index`` is an int64 tensor of shape ``(2, E)`` holding
        ``(src, dst)`` pairs in row-major order of ``adj``. When there are
        fewer than ``cls.max_edges`` edges, it is zero-padded to that width.
        ``edge_num`` is the number of edges before padding.
        """
        size = adj.shape[0]
        edges = [[i, j] for i in range(size) for j in np.where(adj[i])[0].tolist()]
        # NOTE(review): when the edge list sums to zero (in practice: no edges),
        # fall back to every pair i < j; kept from the original implementation.
        if np.sum(edges) == 0:
            edges = [[i, j] for i in range(size) for j in range(size - 1, i, -1)]
        edge_num = len(edges)
        # reshape keeps an empty edge list 2-D so np.pad below cannot fail.
        edge_array = np.array(edges, dtype=np.int64).reshape(-1, 2)
        pad_num = cls.max_edges - edge_num
        if pad_num > 0:
            edge_array = np.pad(edge_array, ((0, pad_num), (0, 0)), 'constant', constant_values=(0, 0))
        edge_index = torch.tensor(edge_array, dtype=torch.int64).transpose(1, 0)
        return edge_index, edge_num

    def __getitem__(self, index):
        """Return the encoded architecture at ``index`` as a dict.

        Keys: ``'adj'`` (the stored adjacency array), ``'operations'`` (float32
        one-hot rows, one per entry of ``ops``), ``'num_vertices'``,
        ``'edge_num'``, ``'features'`` (LongTensor of op ids),
        ``'edge_index_list'`` (int64 tensor of shape ``(2, E)``),
        ``'val_acc'``/``'test_acc'`` (floats; 0.0 outside ``'train'`` mode) and
        ``'genotype'``.
        """
        arch = self.source_arch[index]
        adj = arch['adj']
        ops = arch['ops']

        operation = np.array([op_list[op] for op in ops])
        ops_onehot = np.array([[i == k for i in range(self.candidate_ops)]
                               for k in operation], dtype=np.float32)

        s_genotype = self.genotype[index]
        val_acc = 0
        test_acc = 0
        if self.data_type == 'train':
            val_acc = arch['val_acc']
            test_acc = arch['test_acc']

        edge_index, edge_num = self._build_edge_index(adj)

        result = {
            'adj': adj,
            'operations': ops_onehot,
            # 22 = 2 * (8 + 3): darts_to_nasbench101's padded size for the
            # 8-edge cells produced by sample_darts_arch.
            'num_vertices': 22,
            'edge_num': edge_num,
            'features': torch.from_numpy(operation).long(),
            'edge_index_list': edge_index,
            'val_acc': float(val_acc),
            'test_acc': float(test_acc),
            'genotype': s_genotype,
        }
        return result


if __name__ == '__main__':
    # Smoke test: sample 20 random architectures and print one encoded item.
    demo_dataset = Darts_dataset(split=20, datatype='test')
    print(demo_dataset[10])

