import random, os, copy, time
import torch
import torch.nn as nn
from nasbench import api as NASBench101API
import numpy as np
from argparse import ArgumentParser
from scipy import stats
from nb_dataset_101 import Nb101Dataset
from nb_dataset_201 import Nb201Dataset
from tb_dataset_101 import Trans101Dataset
from tb101_api.api import TransNASBenchAPI
from nas_201_api import NASBench201API
import torch.optim as optim

from scipy.stats import kendalltau,weightedtau
from model import PlainGCN
from torch.utils.data import DataLoader
import nni
from functools import cmp_to_key
import  collections
from utils import (AverageMeter,to_cuda,
                   set_seed,BPRLoss,list_mle,
                   top_k_best_rank,wpair2,rel_at_1)

class NAS(object):
    """Base class for predictor-guided architecture search on tabular NAS benchmarks.

    Wraps one benchmark API (NAS-Bench-101, NAS-Bench-201 or TransNAS-Bench-101,
    selected by ``search_space``) together with a ``PlainGCN`` predictor.
    Subclasses implement the search loop on top of ``sample_arch``,
    ``eval_arch``, ``train`` and ``predict``.

    Architecture tuples produced by ``sample_arch``:
      * '101':   (hash, module_adjacency, module_operations)
      * '201':   (index, arch_str)
      * 'TB101': (index, arch_str)
    """

    def __init__(self, N, search_space, dataset, flops_limit,params,epochs,task,api_loc=None,device='cpu',seed=None):
        """Load the benchmark API and build the predictor.

        Args:
            N: total number of architectures the search may evaluate.
            search_space: one of '101', '201', 'TB101'.
            dataset: benchmark dataset name (used by '201').
            flops_limit: stored; not used by this class.
            params: hyper-parameter dict. The predictor reads 'hidden_size',
                'layers', 'dropout' and 'p_dim'; ``train`` reads 'lr', 'wd'
                and 'loss'.
            epochs: predictor training epochs (stored for subclasses).
            task: TransNAS-Bench-101 task name (used by 'TB101').
            api_loc: path of the benchmark API file.
            device: torch device for the predictor.
            seed: stored; not used by this class.

        Raises:
            ValueError: if ``search_space`` is not supported.
        """
        self.N = N
        self.search_space = search_space
        self.dataset = dataset
        self.flops_limit = flops_limit
        self.device = device
        self.seed = seed
        self.visited = []
        self.epochs = epochs
        self.task = task
        # Keep the hyper-parameters on the instance so train() does not
        # depend on a module-level global.
        self.params = params
        if search_space == '101':
            self.nasbench = NASBench101API.NASBench(api_loc)
            self.available_ops = ['input', 'conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3', 'output']
            self.max_num_vertices = 7
            self.max_num_edges = 9
            self.input_feat = 5
        elif search_space == '201':
            self.nasbench = NASBench201API(api_loc, verbose=False)
            self.available_ops = ['input', 'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3',
                                  'output']
            self.max_num_vertices = 8
            self.input_feat = 7
        elif search_space == 'TB101':
            self.nasbench = TransNASBenchAPI(api_loc)
            self.available_ops = ['input', '0', '1', '2', '3',
                                  'output']
            self.max_num_vertices = 8
            self.input_feat = 6
        else:
            raise ValueError(f"unsupported search_space: {search_space!r}")

        self.model = PlainGCN(num_features=self.input_feat, num_classes=1, hidden=params["hidden_size"],
                              num_fc_layers=params['layers'] - 1, num_conv_layers=params['layers'],
                              dropout=params['dropout'], p_dim=params['p_dim']).to(self.device)

    def _arch_pool(self):
        """Return the list ``sample_arch`` draws from, building and caching it on first use."""
        pool = getattr(self, '_sample_pool', None)
        if pool is None:
            if self.search_space == '101':
                pool = list(self.nasbench.hash_iterator())
            elif self.search_space == '201':
                pool = list(enumerate(self.nasbench))
            elif self.search_space == 'TB101':
                # Only indices >= 3256 are sampled; mutate() parses these arch
                # strings as 6-op cell encodings (presumably the micro space —
                # TODO confirm against the TransNAS-Bench-101 API).
                pool = [(index, self.nasbench.index2arch(index))
                        for index in range(3256, len(self.nasbench))]
            else:
                raise ValueError(f"unsupported search_space: {self.search_space!r}")
            self._sample_pool = pool
        return pool

    def sample_arch(self):
        """Draw one architecture uniformly at random from the search space.

        Returns:
            An architecture tuple in the format described in the class docstring.
        """
        sampled = random.choice(self._arch_pool())
        if self.search_space == '101':
            fixed_statistic = self.nasbench.fixed_statistics[sampled]
            return (sampled, fixed_statistic['module_adjacency'], fixed_statistic['module_operations'])
        return sampled

    def eval_arch(self, arch, use_val_acc=False, model=None):
        """Evaluate architectures by benchmark lookup or with the predictor.

        Args:
            arch: with ``use_val_acc=True``, a single architecture tuple;
                otherwise, a list of architecture tuples to score with ``model``.
            use_val_acc: look up ground-truth metrics in the benchmark tables
                instead of running ``model``.
            model: predictor used when ``use_val_acc`` is False.

        Returns:
            With ``use_val_acc=True``: ``(val_acc, test_acc, eval_time)``, where
            ``eval_time`` is the benchmark-reported cost (0 for 'TB101').
            Otherwise: ``(measure, wall_time, pred)``, where ``pred`` holds the
            model outputs for every architecture in order and ``measure`` is
            ``pred.detach()``.
        """
        start_time = time.time()
        if use_val_acc:
            return self._query_benchmark(arch)
        return self._predict_with_model(arch, model, start_time)

    def _query_benchmark(self, arch):
        """Return ``(val_acc, test_acc, eval_time)`` for one architecture from the benchmark."""
        if self.search_space == '101':
            info = self.nasbench.computed_statistics[arch[0]][108]
            # Validation from the first entry; test averaged over the three entries.
            val_acc = info[0]['final_validation_accuracy']
            test_acc = (info[0]['final_test_accuracy'] + info[1]['final_test_accuracy']
                        + info[2]['final_test_accuracy']) / 3
            return val_acc, test_acc, info[0]['final_training_time']

        if self.search_space == '201':
            if self.dataset == 'imagenet16':
                valid_name = test_name = 'ImageNet16-120'
            elif self.dataset == 'cifar10':
                # CIFAR-10 validation accuracy is read from 'cifar10-valid',
                # test accuracy from 'cifar10'.
                valid_name, test_name = 'cifar10-valid', 'cifar10'
            else:
                valid_name = test_name = self.dataset
            valid_info = self.nasbench.get_more_info(arch[0], valid_name, iepoch=None, hp="200", is_random=False)
            val_acc = valid_info['valid-accuracy'] / 100.0
            test_info = self.nasbench.get_more_info(arch[0], test_name, iepoch=None, hp="200", is_random=False)
            test_acc = test_info['test-accuracy'] / 100.0
            total_eval_time = (valid_info["train-all-time"] + valid_info["valid-per-time"])
            return val_acc, test_acc, total_eval_time

        if self.search_space == 'TB101':
            query = self.nasbench.get_single_metric
            if self.task in ['class_scene', 'class_object', 'jigsaw']:
                val_acc = query(arch[1], self.task, 'valid_top1')
                test_acc = query(arch[1], self.task, 'test_top1')
            elif self.task == 'segmentsemantic':
                val_acc = query(arch[1], self.task, 'valid_mIoU')
                test_acc = query(arch[1], self.task, 'test_mIoU')
            elif self.task in ['normal', 'autoencoder']:
                val_acc = query(arch[1], self.task, 'valid_ssim')
                test_acc = query(arch[1], self.task, 'test_ssim')
            elif self.task == 'room_layout':
                # NOTE(review): neg_loss is multiplied by -1 and scaled by 100,
                # and the search keeps the maximum of this value — confirm the
                # benchmark's sign convention so that larger really is better.
                val_acc = query(arch[1], self.task, 'valid_neg_loss') * (-1) * 100
                test_acc = query(arch[1], self.task, 'test_neg_loss') * (-1) * 100
            else:
                raise ValueError(f"unsupported TB101 task: {self.task!r}")
            return val_acc, test_acc, 0

        raise ValueError(f"unsupported search_space: {self.search_space!r}")

    def _predict_with_model(self, archs, model, start_time):
        """Score every architecture in ``archs`` with ``model``; see ``eval_arch``."""
        model.eval()
        arch_hash = [c[0] for c in archs]
        if self.search_space == '101':
            eval_data = Nb101Dataset(split=len(arch_hash), datatype='eval',
                                     no_sample=True, hash_list=arch_hash)
        elif self.search_space == '201':
            eval_data = Nb201Dataset(split=len(arch_hash), data_type='eval', data_set=self.dataset,
                                     query_val=arch_hash)
        elif self.search_space == 'TB101':
            eval_data = Trans101Dataset(split=len(arch_hash), data_type='eval', task=self.task,
                                        query_val=arch_hash)
        else:
            raise ValueError(f"unsupported search_space: {self.search_space!r}")
        loader = DataLoader(eval_data, batch_size=500, shuffle=False, num_workers=0)

        batch_preds = []
        with torch.no_grad():
            for batch in loader:
                batch = to_cuda(batch, self.device)
                # atleast_1d so 0-d outputs of a single-element batch still concatenate.
                batch_preds.append(torch.atleast_1d(model(batch)))
        if not batch_preds:
            raise ValueError("no architectures to evaluate")
        # Concatenate every batch so no candidate is dropped when there are >500.
        pred = torch.cat(batch_preds, dim=0)
        measure = pred.detach()
        total_eval_time = time.time() - start_time
        return measure, total_eval_time, pred

    def cmp(self, x, y):
        """Three-way comparison of two records by their second element (index 1)."""
        ret = x[1] - y[1]
        if ret < 0:
            return -1
        elif ret > 0:
            return 1
        else:
            return 0

    def train(self, history, model,epochs):
        """Fit the predictor ``model`` on the architectures evaluated so far.

        Args:
            history: records whose first element is an architecture tuple. Only
                the architecture id ``record[0][0]`` is used here; targets come
                from the dataset's "n_val_acc" field.
            model: predictor, trained in place.
            epochs: number of training epochs (the cosine schedule spans them).

        Returns:
            The trained ``model``.

        Raises:
            ValueError: for an unsupported search space or loss name.
        """
        params = self.params
        device = self.device
        # ceil(len(history) / 2): about two mini-batches per epoch.
        batch_size = (len(history) - 1) // 2 + 1
        archs_hash = [record[0][0] for record in history]

        if self.search_space == '101':
            train_set = Nb101Dataset(split=len(history), datatype='train', no_sample=True, hash_list=archs_hash)
        elif self.search_space == '201':
            # Same dataset as eval_arch() so train and eval targets agree.
            train_set = Nb201Dataset(split=len(history), data_type='train', data_set=self.dataset,
                                     query_val=archs_hash, diy=True)
        elif self.search_space == 'TB101':
            # Same task as eval_arch() so train and eval targets agree.
            # NOTE(review): data_type='eval' is used for the training set — confirm intended.
            train_set = Trans101Dataset(split=len(history), data_type='eval', task=self.task,
                                        query_val=archs_hash)
        else:
            raise ValueError(f"unsupported search_space: {self.search_space!r}")

        loss_name = params["loss"]
        if loss_name == 'mse':
            loss_fn = nn.MSELoss()
        elif loss_name == 'bpr':
            loss_fn = BPRLoss()
        elif loss_name == 'listmle':
            loss_fn = list_mle
        elif loss_name == 'wpair2':
            loss_fn = wpair2
        else:
            raise ValueError(f"unknown loss: {loss_name!r}")

        train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=0, shuffle=True)
        optimizer = optim.Adam(model.parameters(), lr=params["lr"], weight_decay=params["wd"])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=1e-4)
        for epoch in range(1, epochs + 1):
            losses = AverageMeter('loss')
            model.train()
            targets = []
            outputs = []
            for batch in train_loader:
                target = batch["n_val_acc"].clone().detach().to(device)
                batch = to_cuda(batch, device)
                optimizer.zero_grad()
                output = model(batch)
                loss = loss_fn(output, target.float())
                loss.backward()
                optimizer.step()
                losses.update(loss.item(), target.size(0))
                # Flatten (rather than squeeze) so size-1 batches still concatenate.
                outputs.append(output.detach().cpu().numpy().reshape(-1))
                targets.append(target.detach().cpu().numpy().reshape(-1))

            tau = stats.kendalltau(np.concatenate(targets), np.concatenate(outputs), nan_policy='omit')[0]
            if epoch % 20 == 0:
                print("epoch:", epoch, "loss:", losses.avg, "tau:", tau)
            # One scheduler step per epoch, after the optimizer steps, so the
            # cosine schedule decays monotonically over `epochs` epochs.
            scheduler.step()
        print("training finish!")
        return model

    def predict(self, candidates, model):
        """Return the predictor outputs for ``candidates`` (a list of architecture tuples)."""
        predictions, _, _ = self.eval_arch(candidates, use_val_acc=False, model=model)
        return predictions


class Random_NAS(NAS):
    """Predictor-guided random search.

    Evaluates ``num_init_archs`` random architectures, then repeats: sample a
    pool of unseen random candidates, train a fresh copy of the predictor on
    everything evaluated so far, and query the benchmark for the top-``K``
    predicted candidates, until ``N`` architectures have been evaluated.
    """

    def __init__(self, N, search_space, dataset, flops_limit,params,epochs,task, api_loc=None , num_init_archs=20, K=10, seed=0,device='cpu'):
        super(Random_NAS, self).__init__(N, search_space, dataset, flops_limit, params, epochs, task,
                                         api_loc=api_loc, seed=seed, device=device)
        self.num_init_archs = num_init_archs
        self.K = K

    def get_candidates(self):
        """Sample up to 200 not-yet-visited architectures using at most 2000 draws.

        Every returned architecture is appended to ``self.visited``.
        """
        patience_factor = 10
        num = 200

        candidates = []
        for _ in range(patience_factor * num):
            arch = self.sample_arch()
            if arch[0] not in self.visited:
                candidates.append(arch)
                self.visited.append(arch[0])
                if len(candidates) >= num:
                    break
        return candidates

    def run(self):
        """Run the search.

        Returns:
            ``(best, history, total_eval_time)``: ``history`` is the list of
            ``(arch, valid_acc, test_acc, eval_time)`` records sorted best-first,
            ``best`` is the top record and ``total_eval_time`` sums the
            benchmark-reported evaluation times of the queried architectures.
        """
        # '101'/'201' records are ranked by test accuracy (index 2),
        # 'TB101' records by the validation metric (index 1).
        score_idx = 2 if self.search_space in ['101', '201'] else 1

        def score(record):
            return record[score_idx]

        def rank_history():
            best_record = max(history, key=score)
            history.sort(key=score, reverse=True)
            return best_record

        total_eval_time = 0
        history = []
        while len(history) < self.num_init_archs:
            arch = self.sample_arch()
            if arch[0] in self.visited:
                continue
            valid_acc, test_acc, eval_time = self.eval_arch(arch, use_val_acc=True)
            total_eval_time += eval_time
            history.append((arch, valid_acc, test_acc, eval_time))
            self.visited.append(arch[0])

        print("best:", max(history, key=score))
        while len(history) < self.N:
            print("pool size:", len(history))
            candidates = self.get_candidates()
            if not candidates:
                # No unseen architecture could be drawn; stop rather than crash in predict().
                break
            model = self.train(history, copy.deepcopy(self.model), self.epochs)
            preds = self.predict(candidates, model).cpu().numpy()
            ranked = np.argsort(preds)

            # Diagnostics only: ground-truth test metric of every candidate
            # (not added to total_eval_time).
            gts = np.array([self.eval_arch(arch, use_val_acc=True)[1] for arch in candidates])
            print("--------")
            print("eval tau:", kendalltau(preds, gts)[0])
            print("eval wtau:", weightedtau(preds, gts)[0])
            print("N@10:", top_k_best_rank(preds, gts, 10))
            print('Ref@10:', rel_at_1(preds, gts, 10))

            # Query the top-K predictions, never exceeding the remaining budget N.
            num_to_add = min(self.K, self.N - len(history))
            for i in ranked[max(0, len(ranked) - num_to_add):]:
                arch = candidates[i]
                valid_acc, test_acc, eval_time = self.eval_arch(arch, use_val_acc=True)
                total_eval_time += eval_time
                history.append((arch, valid_acc, test_acc, eval_time))

            print("eval finish!")
            print("best:", rank_history())

        best = rank_history()
        return best, history, total_eval_time


class Evolved_NAS(NAS):
    """Predictor-guided regularized (aging) evolution.

    Keeps a FIFO population of evaluated records. Each round, tournament
    selection plus mutation generates unseen children, a fresh copy of the
    predictor is trained on the full history, and the top-``K`` predicted
    children are evaluated. Each new record is appended to the population and
    the oldest is dropped.
    """

    def __init__(self, N, search_space, population_size, tournament_size, dataset, flops_limit,params,epochs,task, api_loc=None, K=5, device='cpu', seed=None):
        super(Evolved_NAS, self).__init__(N, search_space, dataset, flops_limit, params, epochs=epochs, task=task,
                                          api_loc=api_loc, device=device, seed=seed)
        self.population_size = population_size
        self.tournament_size = tournament_size
        self.K = K

    def mutate(self, parent, p):
        """Return a mutated child of ``parent`` with probability ``p``, else ``parent`` itself.

        Raises:
            ValueError: for an unsupported search space.
        """
        if self.search_space not in ('101', '201', 'TB101'):
            raise ValueError(f"unsupported search_space: {self.search_space!r}")
        if random.random() >= p:
            return parent
        if self.search_space == '101':
            return self._mutate_101(parent)
        if self.search_space == '201':
            return self._mutate_201(parent)
        return self._mutate_tb101(parent)

    def _mutate_101(self, parent):
        """Change one intermediate op and flip one upper-triangular edge, retrying until valid."""
        old_matrix, old_ops = parent[1], parent[2]
        num_node = len(old_matrix)
        # Row-major flat index over the strict upper triangle -> (i, j).
        idx_to_ij = {int(i * (num_node - 1) - i * (i - 1) / 2 + (j - i - 1)): (i, j)
                     for i in range(num_node) for j in range(i + 1, num_node)}
        num_edges = sum(range(1, num_node))
        intermediate_ops = self.available_ops[1:-1]
        while True:
            idx_to_change = random.randrange(len(old_ops[1:-1])) + 1
            entry_to_change = old_ops[idx_to_change]
            possible_entries = [x for x in intermediate_ops if x != entry_to_change]
            new_ops = copy.deepcopy(old_ops)
            new_ops[idx_to_change] = random.choice(possible_entries)
            i, j = idx_to_ij[random.randrange(num_edges)]
            new_matrix = copy.deepcopy(old_matrix)
            new_matrix[i][j] = 1 if new_matrix[i][j] == 0 else 0
            new_spec = NASBench101API.ModelSpec(matrix=new_matrix, ops=new_ops)
            if self.nasbench.is_valid(new_spec):
                return (new_spec.hash_spec(intermediate_ops), new_matrix, new_ops)

    def _mutate_201(self, parent):
        """Replace one of the six ops in a NAS-Bench-201 arch string."""
        nodes = parent[1].split('+')
        old_spec = [op_and_input.split('~')[0]
                    for node in nodes for op_and_input in node[1:-1].split('|')]
        idx_to_change = random.randrange(len(old_spec))
        possible_entries = [x for x in self.available_ops[1:-1] if x != old_spec[idx_to_change]]
        new_spec = list(old_spec)
        new_spec[idx_to_change] = random.choice(possible_entries)
        arch_str = '|{:}~0|+|{:}~0|{:}~1|+|{:}~0|{:}~1|{:}~2|'.format(*new_spec)
        return (self.nasbench.query_index_by_arch(arch_str), arch_str)

    def _mutate_tb101(self, parent):
        """Replace one op of the 'a_bc_def' cell after the last '-' in a TB101 arch string."""
        parts = parent[1].split('-')
        cell = parts[-1]
        # Op digits sit at positions 0, 2, 3, 5, 6, 7 of the cell string.
        old_spec = [cell[0], cell[2], cell[3], cell[5], cell[6], cell[7]]
        idx_to_change = random.randrange(len(old_spec))
        possible_entries = [x for x in self.available_ops[1:-1] if x != old_spec[idx_to_change]]
        new_spec = list(old_spec)
        new_spec[idx_to_change] = random.choice(possible_entries)
        arch_str_post = '{:}_{:}{:}_{:}{:}{:}'.format(*new_spec)
        arch_str = parts[0] + '-' + parts[1] + '-' + arch_str_post
        return (self.nasbench.arch2index(arch_str), arch_str)

    def get_candidates(self, arch_pool):
        """Generate unseen children by tournament selection and mutation.

        Args:
            arch_pool: population of ``(arch, valid_acc, test_acc, eval_time)`` records.

        Returns:
            Up to 200 ('101'/'201') or 100 ('TB101') new architectures. Each is
            appended to ``self.visited``.
        """
        p = 1.0
        num_arches_to_mutate = 1
        patience_factor = 5000
        num = 200 if self.search_space in ['101', '201'] else 100
        candidates = []
        for _ in range(patience_factor):
            samples = random.sample(arch_pool, self.tournament_size)
            # Tournament winners: highest record[1] (validation metric) first, via self.cmp.
            winners = sorted(samples, key=cmp_to_key(self.cmp), reverse=True)[:num_arches_to_mutate]
            for parent in [record[0] for record in winners]:
                for _ in range(int(num / num_arches_to_mutate)):
                    child = self.mutate(parent, p)
                    if child[0] not in self.visited:
                        candidates.append(child)
                        self.visited.append(child[0])
                    if len(candidates) >= num:
                        return candidates
        return candidates

    def run(self):
        """Run the search.

        Returns:
            ``(best, history, total_eval_time)``: ``history`` is the list of
            ``(arch, valid_acc, test_acc, eval_time)`` records sorted best-first,
            ``best`` is the top record and ``total_eval_time`` sums the
            benchmark-reported evaluation times of the queried architectures.
        """
        # '101'/'201' records are ranked by test accuracy (index 2),
        # 'TB101' records by the validation metric (index 1).
        score_idx = 2 if self.search_space in ['101', '201'] else 1

        def score(record):
            return record[score_idx]

        total_eval_time = 0
        history = []
        population = collections.deque()
        self.visited = []
        while len(history) < self.population_size:
            arch = self.sample_arch()
            if arch[0] in self.visited:
                continue
            valid_acc, test_acc, eval_time = self.eval_arch(arch, use_val_acc=True)
            record = (arch, valid_acc, test_acc, eval_time)
            total_eval_time += eval_time
            population.append(record)
            history.append(record)
            self.visited.append(arch[0])

        print("best:", max(history, key=score))
        while len(history) < self.N:
            candidates = self.get_candidates(population)
            if not candidates:
                # No unseen child could be generated; stop rather than crash in predict().
                break
            model = self.train(history, copy.deepcopy(self.model), self.epochs)
            preds = self.predict(candidates, model).cpu().numpy()
            ranked = np.argsort(preds)

            # Diagnostics only: ground-truth test metric of every candidate
            # (not added to total_eval_time).
            gts = np.array([self.eval_arch(arch, use_val_acc=True)[1] for arch in candidates])
            print("--------")
            print("eval tau:", kendalltau(preds, gts)[0])
            print("eval wtau:", weightedtau(preds, gts)[0])
            print("N@10:", top_k_best_rank(preds, gts, 10))
            print('Ref@10:', rel_at_1(preds, gts, 10))

            print('eval finish!')
            print("len:", len(history))
            # Query the top-K predictions, never exceeding the remaining budget N.
            num_to_add = min(self.K, self.N - len(history))
            for i in ranked[max(0, len(ranked) - num_to_add):]:
                arch = candidates[i]
                valid_acc, test_acc, eval_time = self.eval_arch(arch, use_val_acc=True)
                record = (arch, valid_acc, test_acc, eval_time)
                total_eval_time += eval_time
                population.append(record)
                history.append(record)
                # Aging: the oldest member leaves the population.
                population.popleft()

            print("best:", max(history, key=score))

        history.sort(key=score, reverse=True)
        best = max(history, key=score)
        return best, history, total_eval_time

def get_params():
    """Declare the command-line interface and parse ``sys.argv``.

    Unrecognised arguments are tolerated (``parse_known_args``) and discarded.

    Returns:
        argparse.Namespace with one attribute per option below; ``--p-dim``
        is exposed as ``p_dim``.
    """
    option_table = [
        # --- experiment and dataset ---
        ("--exp_name", {"type": str, "default": 'search'}),
        ("--bench", {"type": str, "default": '201', "choices": ['101', '201', 'TB101']}),
        ("--test_split", {"type": str, "default": 'all'}),
        ("--dataset", {"type": str, "default": 'cifar10', "choices": ['cifar10', 'cifar100', 'imagenet16']}),
        ('--task', {"default": 'autoencoder', "type": str,
                    "choices": ['image classification', 'class_scene', 'class_object', 'jigsaw',
                                'segmentsemantic', 'normal', 'autoencoder', 'room_layout']}),
        ('--outdir', {"default": './', "type": str, "help": 'output directory'}),
        # --- predictor training ---
        ("--seed", {"type": int, "default": 0}),
        ("--gpu", {"type": int, "default": 0}),
        ("--epochs", {"default": 200, "type": int}),
        ("--layers", {"default": 4, "type": int}),
        ("--hidden_size", {"default": 144, "type": int}),
        ("--p-dim", {"default": 128, "type": int}),
        ("--lr", {"default": 1e-3, "type": float}),
        ("--wd", {"default": 1e-3, "type": float}),
        ("--loss", {"default": 'mse', "type": str, "choices": ['mse', 'bpr', 'listmle', 'wpair2']}),
        ("--dropout", {"default": 0.15, "type": float}),
        ("--train_batch_size", {"default": 10, "type": int}),
        ("--test_batch_size", {"default": 10240, "type": int}),
        ("--runs", {"default": 20, "type": int}),
        # --- search ---
        ('--N', {"default": 100, "type": int, "help": 'the number of searched archs'}),
        ('--K', {"default": 10, "type": int, "help": 'the number of added archs'}),
        ('--population_size', {"default": 20, "type": int}),
        ('--tournament_size', {"default": 5, "type": int}),
        ('--search_algo', {"default": 'r', "type": str, "choices": ['r', 'rea'], "help": 'search algorithm'}),
        ('--EMA_momentum', {"default": 0.9, "type": float}),
        ('--flops_limit', {"default": 600e6, "type": float}),
    ]

    parser = ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    namespace, _ignored = parser.parse_known_args()
    return namespace





if __name__ == '__main__':
    # Command-line options, overridden by any NNI-tuned hyper-parameters.
    params = vars(get_params())
    tune_params = nni.get_next_parameter()
    # Guard in case NNI returns None instead of a dict.
    params.update(tune_params or {})

    # device
    device = torch.device('cuda:' + str(params['gpu'])) if torch.cuda.is_available() else torch.device('cpu')

    # One API file per benchmark. A single lookup avoids an if/if/else chain
    # whose trailing `else` would replace the NAS-Bench-101 path.
    api_locations = {
        '101': 'datasets/nasbench101/nasbench_full.tfrecord',
        '201': 'datasets/nasbench201/NAS-Bench-201-v1_1-096897.pth',
        'TB101': "tb101_api/api/api_home/transnas-bench_v10141024.pth",
    }
    api_loc = api_locations[params['bench']]

    search_space = params['bench']
    search_algo = params['search_algo']
    N = params['N']
    epochs = params['epochs']

    bests = []
    # Run i is seeded with i (the --seed option is not used here).
    for i in range(params['runs']):
        set_seed(i)
        print("run:", i)
        if search_algo == 'r':
            nas = Random_NAS(N, search_space, params['dataset'], params['flops_limit'], task=params['task'],
                             params=params, epochs=epochs, api_loc=api_loc, seed=i, K=params['K'],
                             device=device)
        elif search_algo == 'rea':
            nas = Evolved_NAS(N, search_space, population_size=params['population_size'],
                              tournament_size=params['tournament_size'], dataset=params['dataset'],
                              task=params['task'], flops_limit=params['flops_limit'], K=params['K'],
                              params=params, epochs=epochs, api_loc=api_loc, seed=i, device=device)
        else:
            raise ValueError(f"unknown search_algo: {search_algo!r}")
        best, history, total_eval_time = nas.run()
        print("==========best======")
        print(best)
        # Records are (arch, valid_acc, test_acc, eval_time): report test
        # accuracy on 101/201 and the validation metric on TB101.
        bests.append(best[2] if search_space in ['101', '201'] else best[1])
        print(total_eval_time)

    print('=====final=====')
    print("avg acc:", np.mean(np.array(bests)))
    print("best acc:", np.max(np.array(bests)))