import random
import numpy as np
import torch
from torch.utils.data import Dataset
from tb101_api.api import TransNASBenchAPI

# Skeleton character -> operation name. trans101_2_nb101 uses this to decode
# the middle '-'-separated field of an architecture string, one character at
# a time.
ops = {
    str(code): name
    for code, name in enumerate(
        ['none', 'downsample', 'double channels', 'both'], start=1
    )
}

# Operation name -> integer id used for the one-hot operation encoding.
# 'input' and 'output' bracket the decoded skeleton operations.
op_list = {
    name: idx
    for idx, name in enumerate(
        ['input', 'none', 'downsample', 'double channels', 'both', 'output']
    )
}

# Sample architecture string; it is not referenced elsewhere in this file's
# visible code.
st = '64-432332-basic'

def trans101_2_nb101(arch_str, ops=ops, num_nodes=8):
    """Convert a TransNAS-Bench-101 architecture string into a chain-shaped cell.

    Take the second ``-``-separated field of ``arch_str`` (e.g. ``'432332'``
    in ``'64-432332-basic'``) and map each character through ``ops`` to an
    operation name. The nodes are laid out as a linear chain
    ``input -> op_1 -> ... -> op_k -> output``.

    Args:
        arch_str: architecture string with at least two ``-``-separated fields.
        ops: mapping from skeleton character to operation name.
        num_nodes: side length of the zero-padded adjacency matrix. The
            default of 8 matches the original fixed size.

    Returns:
        A dict with two keys:

        - ``'adj'``: a ``(num_nodes, num_nodes)`` ``uint8`` matrix with
          ``adj[i, i + 1] = 1`` for each consecutive pair of chain nodes and
          zeros elsewhere.
        - ``'ops'``: the list of operation names, starting with ``'input'``
          and ending with ``'output'``.

    Raises:
        IndexError: if ``arch_str`` has no second ``-``-separated field.
        KeyError: if a skeleton character is not a key of ``ops``.
        ValueError: if the chain (skeleton length + 2 nodes) does not fit in
            ``num_nodes``. Previously this surfaced as an opaque IndexError.
    """
    skeleton = arch_str.split('-')[1]
    operation = ['input', *(ops[node] for node in skeleton), 'output']
    if len(operation) > num_nodes:
        raise ValueError(
            f"architecture {arch_str!r} needs {len(operation)} nodes, "
            f"but the adjacency matrix only has {num_nodes}"
        )
    adj = np.zeros((num_nodes, num_nodes), dtype=np.uint8)
    # One edge from every chain node to its successor:
    # (0,1), (1,2), ..., (k, k+1).
    src = np.arange(len(operation) - 1)
    adj[src, src + 1] = 1
    return {'adj': adj, 'ops': operation}

class Trans101DatasetMacro(Dataset):
    """Dataset of TransNAS-Bench-101 architectures encoded as chain graphs.

    Each item pairs an architecture (converted with ``trans101_2_nb101``) with
    its validation and test metric for ``task``. The metrics are returned
    divided by 100, and also z-score normalized with fixed per-task constants.

    NOTE(review): indices are drawn from ``range(NUM_ARCHS)`` and passed to
    ``TransNASBenchAPI.index2arch``. Presumably these are the macro-space
    architectures; confirm against the API's index ordering.
    """

    # Architecture indices are sampled from range(NUM_ARCHS).
    NUM_ARCHS = 3256
    # Operation one-hot rows are zero-padded up to this many vertices.
    # This matches the 8x8 adjacency produced by trans101_2_nb101.
    NUM_VERTICES = 8
    # Edge lists shorter than this are padded with [1, 1] entries so that
    # graph tensors share a common minimum width.
    MIN_EDGES = 7

    # task -> ((val_mean, val_std), (test_mean, test_std)), used by
    # normalize() and denormalize().
    _NORM_STATS = {
        'class_scene': ((0.529250, 0.020794), (0.617579, 0.018162)),
        'class_object': ((0.442406, 0.013759), (0.510327, 0.011456)),
        # NOTE(review): the test std is more than 10x the val std here,
        # unlike every other task. Verify this constant.
        'jigsaw': ((0.934337, 0.021795), (0.932362, 0.278064)),
        'segmentsemantic': ((0.2447, 0.0207), (0.234019, 0.063810)),
        'room_layout': ((0.65, 0.03), (0.704089, 0.142574)),
        'normal': ((0.57, 0.02), (0.542159, 0.065221)),
        'autoencoder': ((0.52, 0.08), (0.479246, 0.103602)),
    }

    # task -> (validation metric name, test metric name), as passed to
    # TransNASBenchAPI.get_single_metric.
    _METRIC_KEYS = {
        'class_scene': ('valid_top1', 'test_top1'),
        'class_object': ('valid_top1', 'test_top1'),
        'jigsaw': ('valid_top1', 'test_top1'),
        'segmentsemantic': ('valid_mIoU', 'test_mIoU'),
        'normal': ('valid_ssim', 'test_ssim'),
        'autoencoder': ('valid_ssim', 'test_ssim'),
        'room_layout': ('valid_neg_loss', 'test_neg_loss'),
    }

    def __init__(self, split=50, candidate_ops=6, data_type='train', task='class_scene',
                 query_val=None,
                 api_path="tb101_api/api/api_home/transnas-bench_v10141024.pth"):
        """Load the benchmark API and choose which architecture indices to serve.

        Args:
            split: number of indices to sample at random (anything ``int()``
                accepts). With ``data_type='test'``, ``'all'`` selects every
                index.
            candidate_ops: width of the one-hot operation encoding.
            data_type: ``'train'`` (validation-set normalization stats),
                ``'test'`` or ``'eval'`` (test-set stats). Any other value
                yields an empty dataset.
            task: benchmark task name; must be a key of ``_NORM_STATS``.
            query_val: iterable of indices to serve when
                ``data_type == 'eval'``.
            api_path: path of the TransNAS-Bench-101 database file. The
                default is the previously hard-coded path.

        Raises:
            ValueError: if ``task`` is not supported. This is checked before
                the (potentially slow) API load.
        """
        if task not in self._NORM_STATS:
            raise ValueError(
                f"unsupported task {task!r}; expected one of {sorted(self._NORM_STATS)}"
            )
        self.nasbench = TransNASBenchAPI(api_path)
        self.candidate_ops = candidate_ops
        if data_type == 'train':
            self.sample_range = random.sample(range(self.NUM_ARCHS), int(split))
        elif data_type == 'test':
            if split == 'all':
                self.sample_range = range(self.NUM_ARCHS)
            else:
                self.sample_range = random.sample(range(self.NUM_ARCHS), int(split))
        elif data_type == 'eval':
            self.sample_range = list(query_val)
        else:
            # Unknown data_type: keep the long-standing "empty dataset" behavior.
            self.sample_range = []

        self.data_type = data_type
        self.task = task
        (self.val_mean, self.val_std), (self.test_mean, self.test_std) = self._NORM_STATS[task]

    def __len__(self):
        """Return the number of selected architecture indices."""
        return len(self.sample_range)

    def normalize(self, num):
        """Z-score ``num`` with this dataset's statistics.

        Uses the validation stats for ``'train'`` and the test stats for
        ``'test'``/``'eval'``. Returns ``None`` for any other ``data_type``,
        as before.
        """
        if self.data_type == 'train':
            return (num - self.val_mean) / self.val_std
        if self.data_type in ('test', 'eval'):
            return (num - self.test_mean) / self.test_std
        return None

    def denormalize(self, num):
        """Invert :meth:`normalize`. Returns ``None`` for an unknown ``data_type``."""
        if self.data_type == 'train':
            return num * self.val_std + self.val_mean
        if self.data_type in ('test', 'eval'):
            return num * self.test_std + self.test_mean
        return None

    def _query_metrics(self, arch):
        """Return ``(val_metric, test_metric)`` for ``arch`` on ``self.task``."""
        valid_key, test_key = self._METRIC_KEYS[self.task]
        val_acc = self.nasbench.get_single_metric(arch, self.task, valid_key)
        test_acc = self.nasbench.get_single_metric(arch, self.task, test_key)
        if self.task == 'room_layout':
            # The metric is a negated loss. Flip its sign and scale by 100 so
            # that the "/ 100" applied in __getitem__ yields the raw loss.
            val_acc = val_acc * (-1) * 100
            test_acc = test_acc * (-1) * 100
        return val_acc, test_acc

    @staticmethod
    def _encode_operations(op_ids, candidate_ops, num_vertices=8):
        """One-hot encode ``op_ids`` into a float32 ``(max(n, num_vertices), candidate_ops)`` array.

        Ids outside ``range(candidate_ops)`` give an all-zero row. Missing
        rows (when ``n < num_vertices``) are zero-padded.
        """
        ids = np.asarray(op_ids)
        onehot = (ids[:, None] == np.arange(candidate_ops)).astype(np.float32)
        missing = num_vertices - len(ids)
        if missing > 0:
            # Build the padding with the same width and dtype. The previous
            # code appended a hard-coded 6-wide int array, which upcast the
            # result to float64 and broke for candidate_ops != 6.
            pad = np.zeros((missing, candidate_ops), dtype=np.float32)
            onehot = np.concatenate([onehot, pad], axis=0)
        return onehot

    @staticmethod
    def _build_edge_index(adjacency, min_edges=7):
        """Convert an adjacency matrix into an edge-index tensor.

        Returns:
            A tuple ``(edge_index, edge_num)``:

            - ``edge_index``: an int64 tensor of shape ``(2, E)``.
            - ``edge_num``: the number of edges before padding.

            Edges are listed in row-major order of the non-zero entries. If
            there are none, every upper-triangular pair is used as a
            fallback. Lists shorter than ``min_edges`` are padded with
            ``[1, 1]``.
        """
        edges = np.argwhere(adjacency).tolist()
        if not np.any(edges):
            size = adjacency.shape[0]
            edges = [[i, j] for i in range(size) for j in range(size - 1, i, -1)]
        edge_num = len(edges)
        pad_num = min_edges - edge_num
        if pad_num > 0:
            edges = np.pad(np.array(edges), ((0, pad_num), (0, 0)),
                           'constant', constant_values=(1, 1))
        edge_index = torch.tensor(edges, dtype=torch.int64).transpose(1, 0)
        return edge_index, edge_num

    def __getitem__(self, index):
        """Return the encoded graph and metrics for the ``index``-th selected architecture."""
        arch = self.nasbench.index2arch(self.sample_range[index])
        val_acc, test_acc = self._query_metrics(arch)

        cell = trans101_2_nb101(arch)
        adjacency = cell['adj']
        op_ids = [op_list[name] for name in cell['ops']]
        ops_onehot = self._encode_operations(op_ids, self.candidate_ops, self.NUM_VERTICES)
        edge_index, edge_num = self._build_edge_index(adjacency, self.MIN_EDGES)

        # Metrics are divided by 100 before use. Presumably they are reported
        # as percentages; confirm against the API.
        val_frac = val_acc / 100
        test_frac = test_acc / 100
        return {
            "num_vertices": self.NUM_VERTICES,
            "edge_num": edge_num,
            "adjacency": np.array(adjacency, dtype=np.float32),
            "operations": ops_onehot,
            "n_val_acc": torch.tensor(self.normalize(val_frac), dtype=torch.float32),
            "n_test_acc": torch.tensor(self.normalize(test_frac), dtype=torch.float32),
            "val_acc": val_frac,
            "test_acc": test_frac,
            "edge_index_list": edge_index,
        }
