import random
import numpy as np
import torch
from torch.utils.data import Dataset
from tb101_api.api import TransNASBenchAPI

# Digit code found in a cell string (see trans101_2_nb101) -> operation name.
ops = dict(zip('0123', ('none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3')))

# Node label of the converted DAG -> integer feature id (its position in the
# tuple below); Trans101Dataset uses these ids as features and one-hot columns.
op_list = {
    label: position
    for position, label in enumerate(
        ('input', 'nor_conv_1x1', 'nor_conv_3x3', 'skip_connect', 'none', 'output')
    )
}

def trans101_2_nb101(arch_str):
    """Convert a TransNAS-Bench-101 cell string into an op-on-node DAG.

    Only the last 8 characters of ``arch_str`` are read. Three groups of digit
    codes are taken:
    - 1 op for cell node 1 (``arch_str[-8]``);
    - 2 ops for cell node 2 (``arch_str[-6:-4]``);
    - 3 ops for cell node 3 (``arch_str[-3:]``).
    The characters at ``-7`` and ``-4`` are skipped (presumably ``'_'``
    separators).

    The j-th op of a cell node reads from cell node j, where node 0 is the
    graph input. Every op becomes its own DAG vertex, and a cell node's
    output is the set of vertices of its ops.

    Args:
        arch_str: Architecture string whose tail encodes the cell.

    Returns:
        dict with
        - ``'adj'``: square ``np.uint8`` adjacency matrix of size
          ``num_ops + 2``, with ``adj[src, dst] == 1`` for every edge;
        - ``'ops'``: vertex labels, starting with ``'input'``, then the op
          names in vertex order, and ending with ``'output'``.

    Raises:
        KeyError: if a cell character is not a key of ``ops``.
        IndexError: if ``arch_str`` is shorter than 8 characters.
    """
    cell_nodes = [arch_str[-8], arch_str[-6:-4], arch_str[-3:]]
    num_ops = sum(len(codes) for codes in cell_nodes)
    adj = np.zeros((num_ops + 2, num_ops + 2), dtype=np.uint8)

    labels = ['input']
    # producers[j] lists the DAG vertices whose outputs make up cell node j.
    producers = [[0]]
    for codes in cell_nodes:
        current = []
        for src, code in enumerate(codes):
            dst = len(labels)  # next free vertex index
            adj[producers[src], dst] = 1
            labels.append(ops[code])
            current.append(dst)
        producers.append(current)

    # The graph output gathers the ops of the last cell node.
    adj[producers[-1], -1] = 1
    labels.append('output')
    return {'adj': adj, 'ops': labels}

class Trans101Dataset(Dataset):
    """Dataset of TransNAS-Bench-101 architectures with their task performance.

    Benchmark indices are sampled from ``[_MICRO_OFFSET, len(api))``, or given
    explicitly for ``'eval'``. Each architecture is converted to an op-on-node
    DAG with ``trans101_2_nb101``. It is returned with its validation and
    test metric on ``task``: the raw metric divided by 100, plus a z-scored
    copy.

    Args:
        split: Number of indices to sample for ``'train'``/``'test'``. For
            ``'test'``, the string ``'all'`` selects every candidate index.
        candidate_ops: Width of the one-hot operation encoding.
        data_type: ``'train'``, ``'test'`` or ``'eval'``; any other value
            yields an empty dataset. Also selects the statistics used by
            ``normalize``: validation stats for ``'train'``, test stats
            otherwise.
        task: Benchmark task; selects the queried metric and the
            normalization statistics.
        query_val: Iterable of benchmark indices, required when
            ``data_type == 'eval'``.
        api_path: Path of the TransNAS-Bench-101 API file passed to
            ``TransNASBenchAPI``.
    """

    # First benchmark index that is sampled. NOTE(review): presumably the
    # micro-cell architectures (the only ones trans101_2_nb101 can parse)
    # start here; confirm against the API's index ordering.
    _MICRO_OFFSET = 3256

    # task -> (val_mean, val_std, test_mean, test_std) of the metric after it
    # has been divided by 100 in __getitem__.
    _NORM_STATS = {
        'class_scene': (0.452286, 0.121533, 0.540011, 0.132722),
        'class_object': (0.397118, 0.059229, 0.462211, 0.067488),
        'jigsaw': (0.765728, 0.283390, 0.768727, 0.278064),
        'segmentsemantic': (0.199482, 0.055198, 0.234019, 0.063810),
        'room_layout': (0.726471, 0.116026, 0.704089, 0.142574),
        'normal': (0.518584, 0.058653, 0.542159, 0.065221),
        'autoencoder': (0.458917, 0.100607, 0.479246, 0.103602),
    }

    # task -> (validation metric, test metric) names for get_single_metric.
    _METRIC_NAMES = {
        'class_scene': ('valid_top1', 'test_top1'),
        'class_object': ('valid_top1', 'test_top1'),
        'jigsaw': ('valid_top1', 'test_top1'),
        'segmentsemantic': ('valid_mIoU', 'test_mIoU'),
        'normal': ('valid_ssim', 'test_ssim'),
        'autoencoder': ('valid_ssim', 'test_ssim'),
        'room_layout': ('valid_neg_loss', 'test_neg_loss'),
    }

    def __init__(self, split=50, candidate_ops=6, data_type='train', task='class_scene',
                 query_val=None,
                 api_path="tb101_api/api/api_home/transnas-bench_v10141024.pth"):
        self.nasbench = TransNASBenchAPI(api_path)
        self.candidate_ops = candidate_ops
        self.data_type = data_type
        self.task = task

        candidates = range(self._MICRO_OFFSET, len(self.nasbench))
        if data_type == 'test' and split == 'all':
            self.sample_range = candidates
        elif data_type in ('train', 'test'):
            self.sample_range = random.sample(candidates, int(split))
        elif data_type == 'eval':
            if query_val is None:
                raise ValueError("query_val is required when data_type == 'eval'")
            self.sample_range = list(query_val)
        else:
            # Unrecognised data_type: keep the dataset empty.
            self.sample_range = []

        # Unknown tasks get no statistics; __getitem__ rejects them explicitly.
        if task in self._NORM_STATS:
            (self.val_mean, self.val_std,
             self.test_mean, self.test_std) = self._NORM_STATS[task]

    def __len__(self):
        return len(self.sample_range)

    def _norm_params(self):
        """Return ``(mean, std)`` for ``self.data_type``.

        Raises:
            ValueError: if data_type is not 'train', 'test' or 'eval'.
        """
        if self.data_type == 'train':
            return self.val_mean, self.val_std
        if self.data_type in ('test', 'eval'):
            return self.test_mean, self.test_std
        raise ValueError(f"unsupported data_type {self.data_type!r}")

    def normalize(self, num):
        """Z-score ``num`` with the statistics selected by ``data_type``."""
        mean, std = self._norm_params()
        return (num - mean) / std

    def denormalize(self, num):
        """Invert :meth:`normalize`."""
        mean, std = self._norm_params()
        return num * std + mean

    def _query_metrics(self, arch):
        """Return the (validation, test) metric of ``arch`` on ``self.task``.

        The values are on the scale that __getitem__ divides by 100.

        Raises:
            ValueError: if ``self.task`` is not a supported task.
        """
        if self.task not in self._METRIC_NAMES:
            raise ValueError(f"unsupported task {self.task!r}")
        val_name, test_name = self._METRIC_NAMES[self.task]
        val_acc = self.nasbench.get_single_metric(arch, self.task, val_name)
        test_acc = self.nasbench.get_single_metric(arch, self.task, test_name)
        if self.task == 'room_layout':
            # Negate 'neg_loss' and pre-scale by 100 so the later /100 cancels.
            val_acc = val_acc * (-1) * 100
            test_acc = test_acc * (-1) * 100
        return val_acc, test_acc

    @staticmethod
    def _edge_list(adjacency):
        """Return ``[i, j]`` for every nonzero ``adjacency[i, j]``, row-major.

        If the matrix has no edges at all, fall back to every forward pair
        ``i < j`` (ordered by descending ``j`` within each row).
        """
        edges = np.argwhere(adjacency).tolist()
        if not edges:
            n = adjacency.shape[0]
            edges = [[i, j] for i in range(n) for j in range(n - 1, i, -1)]
        return edges

    def __getitem__(self, index):
        index = self.sample_range[index]
        arch = self.nasbench.index2arch(index)
        val_acc, test_acc = self._query_metrics(arch)

        new_arch = trans101_2_nb101(arch)
        adjacency = new_arch['adj']
        operations = new_arch['ops']
        operation = np.array([op_list[name] for name in operations])
        # One-hot rows; an id >= candidate_ops yields an all-zero row.
        ops_onehot = (operation[:, None] == np.arange(self.candidate_ops)).astype(np.float32)

        edge_index = self._edge_list(adjacency)
        edge_num = len(edge_index)
        # reshape keeps the (2, E) layout valid even when E == 0.
        edge_index = torch.tensor(edge_index, dtype=torch.int64).reshape(-1, 2).transpose(1, 0)

        val_acc = val_acc / 100
        test_acc = test_acc / 100
        result = {
            "arch_str": arch,
            "arch": new_arch,
            "num_vertices": len(operations),
            "edge_num": edge_num,
            "adjacency": np.array(adjacency, dtype=np.float32),
            "operations": ops_onehot,
            "features": torch.from_numpy(operation).long(),
            "n_val_acc": torch.tensor(self.normalize(val_acc), dtype=torch.float32),
            "n_test_acc": torch.tensor(self.normalize(test_acc), dtype=torch.float32),
            "val_acc": val_acc,
            "test_acc": test_acc,
            "edge_index_list": edge_index,
        }
        return result
