import numpy as np
import torch
from copy import deepcopy
import random
import logging
import argparse
import os
import sys
import time
from models import get_train_val_loaders
from models import ResNet18_child


class EvoSearch(object):
    """Search over backward-connection ("bwpath") layouts of a child network.

    A candidate is a list holding one or two ``[from, to]`` index pairs with
    ``0 <= to <= from < size``.  Each pair switches on one entry of a
    ``size x size`` mask that is passed to the network constructor.

    :meth:`search` runs two phases: a zero-shot phase that ranks a large
    random pool with two training-free scores, then ``max_search_iter``
    rounds that evaluate candidates with the weights stored at
    ``args.model_path`` and derive the next pool from the current top-k
    (by mutation/crossover/random sampling, or random sampling only when
    ``search_way == 'random'``).
    """

    # Mutation moves as (exclusive upper bound of the uniform draw,
    # index inside the [from, to] pair, step).  The bounds are cumulative,
    # so the first four moves have probability 0.2 each and the last four
    # 0.05 each.
    _MUTATION_MOVES = (
        (0.2, 0, -1),
        (0.4, 1, -1),
        (0.6, 0, 1),
        (0.8, 1, 1),
        (0.85, 0, -2),
        (0.9, 1, -2),
        (0.95, 0, 2),
        (float('inf'), 1, 2),
    )

    # Number of consecutive rejected samples after which random_pool gives
    # up instead of spinning forever on an exhausted search space.
    _MAX_SAMPLING_MISSES = 10000

    def __init__(self, args, net, search_way):
        """Store the search hyper-parameters.

        Args:
            args: namespace providing ``size``, ``num_pool``, ``num_pool2``,
                ``num_topk``, ``num_mutation``, ``mutation_prob``,
                ``cross_prob``, ``increase_prob``, ``num_crossover`` and
                ``model_path`` (``random_pool_per_iter`` is read as well
                when ``search_way == 'random'``).
            net: callable ``net(args, bwpath)`` returning a network.
            search_way: ``'random'`` for pure random search; any other
                value selects the evolutionary update.
        """
        super(EvoSearch, self).__init__()
        self.search_way = search_way
        self.args = args
        self.net = net
        self.size = args.size
        self.num_pool = args.num_pool
        self.num_pool2 = args.num_pool2
        self.num_topk = args.num_topk
        self.num_mutation = args.num_mutation
        self.mutation_prob = args.mutation_prob
        self.cross_prob = args.cross_prob
        self.increase_prob = args.increase_prob
        self.num_crossover = args.num_crossover

        # One entry per search iteration: [[bwpath_list, accuracy], ...].
        self.history_topk = []
        # Every candidate that has ever been put into a pool.
        self.seen_bwpath = []

        self.model_path = args.model_path

    def initialize_net(self, bwpath):
        """Build a network for the mask ``bwpath`` and move it to the GPU."""
        net = self.net(self.args, bwpath)
        net.cuda()
        return net

    def load_net(self, path, bwpath):
        """Build a network for ``bwpath`` and load ``ckpt['net']`` from ``path``.

        NOTE: ``torch.load`` unpickles the file, so only trusted checkpoints
        should be passed here.
        """
        ckpt = torch.load(path, map_location=torch.device('cpu'))
        net = self.initialize_net(bwpath)
        net.load_state_dict(ckpt['net'])
        net.cuda()
        return net

    def random_pool(self, pool_size):
        """Sample up to ``pool_size`` distinct candidates that were never seen.

        Returns fewer than ``pool_size`` candidates (with a warning) when
        ``_MAX_SAMPLING_MISSES`` samples in a row are rejected, which means
        the unseen part of the search space is (almost) exhausted.
        """
        newpool = []
        misses = 0
        while len(newpool) < pool_size:
            bwpath_list = self.sampling()
            if bwpath_list in newpool or bwpath_list in self.seen_bwpath:
                misses += 1
                if misses >= self._MAX_SAMPLING_MISSES:
                    logging.warning(
                        'random_pool: no unseen candidate after %d samples, '
                        'returning %d of %d requested',
                        misses, len(newpool), pool_size)
                    break
                continue
            misses = 0
            newpool.append(bwpath_list)
        return newpool

    def _mutate_path(self, path):
        """Apply one randomly drawn move to ``path`` (a ``[from, to]`` list) in place.

        The move is skipped when it would leave ``0 <= to <= from < size``.
        """
        draw = np.random.uniform(low=0.0, high=1.0)
        for bound, idx, step in self._MUTATION_MOVES:
            if draw < bound:
                break
        moved = list(path)
        moved[idx] += step
        if 0 <= moved[1] <= moved[0] <= self.size - 1:
            path[idx] = moved[idx]

    def mutation_pool(self, pool_use, pool_size):
        """Create up to ``pool_size`` unseen candidates by mutating ``pool_use``.

        At most ``3 * pool_size`` parents are drawn.  Each copy is mutated
        with probability ``mutation_prob`` (every path of the copy gets one
        move); two paths that become identical are merged into one.

        Raises:
            NotImplementedError: a parent has neither one nor two paths.
        """
        newpool = []
        if not pool_use:
            # Nothing to mutate; the caller can fill up with random samples.
            return newpool

        for _ in range(pool_size * 3):
            bwpath_list = deepcopy(random.choice(pool_use))
            if len(bwpath_list) not in (1, 2):
                raise NotImplementedError

            if np.random.random() < self.mutation_prob:
                for path in bwpath_list:
                    self._mutate_path(path)
                if len(bwpath_list) == 2 and bwpath_list[0] == bwpath_list[1]:
                    del bwpath_list[1]

            if (bwpath_list not in newpool) and (bwpath_list not in self.seen_bwpath):
                newpool.append(bwpath_list)
            if len(newpool) == pool_size:
                break
        return newpool

    def crossover_or_increase_pool(self, pool_use, pool_size):
        """Create up to ``pool_size`` unseen candidates by crossover or growth.

        A single-path parent gets a random second path when a uniform draw
        exceeds ``increase_prob``.  A two-path parent has the ``to``
        endpoints of its paths exchanged when a uniform draw exceeds
        ``cross_prob`` and both resulting paths keep ``from > to``.

        Raises:
            NotImplementedError: a parent has neither one nor two paths.
        """
        newpool = []
        if not pool_use:
            return newpool

        for _ in range(pool_size * 3):
            bwpath_list = deepcopy(random.choice(pool_use))
            num_path = len(bwpath_list)

            if num_path == 1:
                if np.random.uniform(low=0.0, high=1.0) > self.increase_prob:
                    froms = random.choice(range(self.size))
                    tos = random.choice(range(froms + 1))
                    if [froms, tos] not in bwpath_list:
                        bwpath_list.append([froms, tos])

            elif num_path == 2:
                first, second = bwpath_list
                if np.random.uniform(low=0.0, high=1.0) > self.cross_prob:
                    if first[0] > second[1] and second[0] > first[1]:
                        # Tuple assignment reads both values before writing.
                        # Two sequential assignments through the aliased
                        # lists would copy second[1] onto itself and lose
                        # first[1].
                        first[1], second[1] = second[1], first[1]

            else:
                raise NotImplementedError

            if (bwpath_list not in newpool) and (bwpath_list not in self.seen_bwpath):
                newpool.append(bwpath_list)
            if len(newpool) == pool_size:
                break
        return newpool

    def _bwpath_mask(self, bwpath_list):
        """Return the ``size x size`` 0/1 tensor with one entry set per path."""
        mask = torch.zeros(self.size, self.size)
        for src, dst in bwpath_list:
            mask[src][dst] = 1.
        return mask

    @staticmethod
    def _top_indices(scores, k):
        """Return the indices of the ``k`` largest scores, best first.

        NaN scores are treated as the worst possible value.  A plain
        ``np.argsort(scores)[::-1]`` would rank them first, because NumPy
        sorts NaN to the end of the ascending order.
        """
        values = np.asarray(scores, dtype=float)
        values = np.where(np.isnan(values), -np.inf, values)
        return np.argsort(values)[::-1][:k]

    def search(self, max_search_iter, tloader_epe, tloader_nassnn, vloader):
        """Run the zero-shot pre-selection followed by the iterative search.

        Args:
            max_search_iter: number of evaluate-and-breed iterations.
            tloader_epe: loader handed to ``epe_score``.
            tloader_nassnn: loader handed to ``nassnn_score``.
            vloader: loader handed to ``infer``.

        Returns:
            ``self.history_topk``; one ``[[bwpath_list, acc], ...]`` list is
            appended per iteration.
        """
        # Zero-shot phase: score a large random pool without any training.
        pool = self.random_pool(self.num_pool)
        escore = []
        nscore = []
        for bwpath_list in pool:
            self.seen_bwpath.append(bwpath_list)
            fnet = self.initialize_net(self._bwpath_mask(bwpath_list))
            escore.append(epe_score(fnet, tloader_epe))
            fnet.zero_grad()
            nscore.append(nassnn_score(fnet, tloader_nassnn, self.args))
            del fnet

        # Keep the best half-pool under each score and merge them without
        # duplicates (order: epe winners first, then nassnn winners).
        half = int(self.num_pool2 / 2)
        winners = list(self._top_indices(escore, half)) + list(self._top_indices(nscore, half))
        pool2 = []
        for idx in winners:
            if pool[idx] not in pool2:
                pool2.append(pool[idx])
        logging.info(f'after zero-shot phase, the num of arcs is {len(pool2)}')

        # One-shot phase: evaluate with the stored weights, keep the top-k,
        # breed the next pool.
        topk_pool, topk_acc = [], []
        for it in range(max_search_iter):
            logging.info(f'search_iter: {it}')
            acc_list = []
            for idx, bwpath_list in enumerate(pool2):
                snet = self.load_net(self.model_path, self._bwpath_mask(bwpath_list))
                acc = infer(snet, vloader)
                del snet
                acc_list.append(acc)
                logging.info(f'{idx} \t {bwpath_list} \t {acc:.4f}')

            temp_pool = topk_pool + pool2
            temp_acc = topk_acc + acc_list

            topk_idx = self._top_indices(temp_acc, self.num_topk)
            bwpath_acc = [[temp_pool[idx], temp_acc[idx]] for idx in topk_idx]
            topk_pool = [temp_pool[idx] for idx in topk_idx]
            topk_acc = [temp_acc[idx] for idx in topk_idx]

            logging.info(f'\ttop-{self.num_topk} paths')
            for bwpath_list, acc in zip(topk_pool, topk_acc):
                logging.info(f'\t{bwpath_list} {acc:.4f}')

            self.history_topk.append(bwpath_acc)

            # Prepare the next pool.  Each sub-pool is marked as seen before
            # the next one is generated so the sub-pools cannot overlap.
            if self.search_way == 'random':
                logging.info(f'[newpool] random: {self.args.random_pool_per_iter}')
                pool2 = self.random_pool(self.args.random_pool_per_iter)
                self.seen_bwpath.extend(pool2)
            else:
                mut_pool = self.mutation_pool(topk_pool, self.num_mutation)
                self.seen_bwpath.extend(mut_pool)
                cro_pool = self.crossover_or_increase_pool(topk_pool, self.num_crossover)
                self.seen_bwpath.extend(cro_pool)
                rnd_pool = self.random_pool(self.num_pool2 - len(mut_pool) - len(cro_pool))
                self.seen_bwpath.extend(rnd_pool)
                logging.info(
                    f'[new pool] mutation: {len(mut_pool)}, crossover: {len(cro_pool)}, random: {len(rnd_pool)}')

                pool2 = mut_pool + cro_pool + rnd_pool

        return self.history_topk

    def _sample_path(self):
        """Draw two indices in ``range(size)`` and return them as ``[larger, smaller]``."""
        a = random.choice(range(self.size))
        b = random.choice(range(self.size))
        return [a, b] if a > b else [b, a]

    def sampling(self):
        """Sample a candidate with one or two paths (two equal paths collapse to one)."""
        num_path = random.choice(range(1, 3))
        bw_list = [self._sample_path() for _ in range(num_path)]
        if num_path == 2 and bw_list[0] == bw_list[1]:
            del bw_list[1]
        return bw_list


def infer(net, loader):
    """Return the top-1 accuracy of ``net`` over all batches of ``loader``.

    The network is switched to eval mode and evaluated without gradients.
    Every batch is moved to the GPU; ``net(img)`` must return a pair whose
    first element holds the class scores.
    """
    net.eval()
    n_seen = 0
    n_correct = 0
    with torch.no_grad():
        for batch_img, batch_label in loader:
            batch_img = batch_img.cuda()
            batch_label = batch_label.cuda()
            logits, _ = net(batch_img)
            hits = logits.argmax(dim=1) == batch_label
            n_correct += hits.float().sum().item()
            n_seen += batch_label.numel()
    return n_correct / n_seen


def epe_score(net, loader, ):
    """Score ``net`` from the input-gradients of the first batch of ``loader``.

    The batch is moved to the GPU, the gradient of the summed network output
    with respect to the input is taken per sample, and the per-class
    correlation matrices of those gradients are folded into a single number.

    Returns:
        The score, or ``np.nan`` when computing it raised an exception
        (the exception is logged).
    """

    def get_batch_jacobian(model, x, ttarget,):
        """Return (d sum(y) / d x, detached targets, size of y's last dim)."""
        model.zero_grad()

        x.requires_grad_(True)

        y, _ = model(x)

        y.backward(torch.ones_like(y))
        jacob = x.grad.detach()

        return jacob, ttarget.detach(), y.shape[-1]

    def eval_score_perclass(jacob, llabels=None, num_classes=10):
        """Combine per-class log-|correlation| sums of the rows of ``jacob``."""
        k = 1e-5  # keeps log() finite for zero correlations

        # Collect the rows of every class first and stack once per class;
        # stacking inside the loop re-copies the class matrix for every row.
        rows_by_class = {}
        for i, label in enumerate(llabels[0]):
            rows_by_class.setdefault(label, []).append(jacob[i])

        ind_corr_matrix_score = {}
        for c, rows in rows_by_class.items():
            class_rows = np.vstack(rows)
            try:
                corrs = np.array(np.corrcoef(class_rows))
                sco = np.sum(np.log(abs(corrs) + k))
                if num_classes > 100:
                    # len() fails for a class with a single sample (corrcoef
                    # returns a scalar); such classes are skipped below.
                    sco /= len(corrs)
            except Exception:
                # Best effort: skip the class, but let KeyboardInterrupt and
                # SystemExit propagate.
                continue
            ind_corr_matrix_score[c] = sco

        score = 0
        class_scores = list(ind_corr_matrix_score.values())
        if num_classes <= 100:
            # Sum of the absolute per-class scores.
            for sco in class_scores:
                score += np.absolute(sco)
        else:
            # Sum of all pairwise absolute differences ...
            for sco_a in class_scores:
                for sco_b in class_scores:
                    score += np.absolute(sco_a - sco_b)

            # ... divided by the number of classes seen.
            if len(class_scores) > 0:
                score /= len(class_scores)

        return score

    inputs, targets = next(iter(loader))
    inputs, targets = inputs.cuda(), targets.cuda()

    try:
        jacobs_batch, target, n_classes = get_batch_jacobian(net, inputs, targets)
        jacobs = jacobs_batch.reshape(jacobs_batch.size(0), -1).cpu().numpy()
        labels = [target.cpu().numpy()]
        s = eval_score_perclass(jacobs, labels, n_classes)
    except Exception as e:
        logging.warning('epe_score failed, returning NaN: %s', e)
        s = np.nan

    return s


def nassnn_score(net, loader, argg):
    """Score ``net`` from the activation patterns of its ``SpikeNode`` modules.

    A forward hook on every module whose type name contains ``'SpikeNode'``
    binarises the module output (``> 0``) and adds a kernel matrix built
    from the pairwise Hamming distances of the samples, rescaled by the
    per-sample sparsity, to ``net.K``.  The score of one run is the
    log-determinant of ``net.K / net.num_action``; the result is the mean
    over ``argg.repeat`` runs, each on the first batch of a fresh iterator
    over ``loader``.

    The kernel has shape ``(2 * argg.batch_size, 2 * argg.batch_size)``, so
    the hooked outputs are expected to have that many rows.

    Side effects: ``net.K`` and ``net.num_action`` are (re)assigned.  The
    hooks are removed again before returning, also on error.
    """

    def logdet(K):
        _, ld = np.linalg.slogdet(K)
        return ld

    search_batchsize = argg.batch_size * 2
    repeat = argg.repeat

    neuron_type = 'SpikeNode'

    net.K = np.zeros((search_batchsize, search_batchsize))
    net.num_action = 0

    def computing_K_eachtime(mod, inp, out):
        if isinstance(out, tuple):
            out = out[0]
        out = out.view(out.size(0), -1)
        neuron_num = out.size(1)
        x = (out > 0).float()

        # Constant matrices are created on the device of the activations
        # instead of being built on the CPU and copied to the GPU.
        kernel_shape = (search_batchsize, search_batchsize)
        full_matrix = torch.full(kernel_shape, float(neuron_num), device=out.device)
        half_matrix = torch.full(kernel_shape, 0.5, device=out.device)

        sparsity = (x.sum(1) / neuron_num).unsqueeze(1)
        norm_K = (sparsity @ (1 - sparsity.t())) + ((1 - sparsity) @ sparsity.t())
        rescale_factor = torch.div(half_matrix, norm_K + 1e-3)
        K1_0 = x @ (1 - x.t())
        K0_1 = (1 - x) @ x.t()
        Hamming_distance_array = K0_1 + K1_0
        K_total = (full_matrix - rescale_factor * Hamming_distance_array)

        net.K = net.K + K_total.cpu().numpy()
        net.num_action += 1

    # Keep the handles: hooks that stay registered would accumulate on
    # repeated calls and keep `net` alive through this closure.
    handles = [
        module.register_forward_hook(computing_K_eachtime)
        for module in net.modules()
        if neuron_type in str(type(module))
    ]

    batches_score_list = []
    try:
        with torch.no_grad():
            for _ in range(repeat):
                net.K = np.zeros((search_batchsize, search_batchsize))
                net.num_action = 0
                inputs, _ = next(iter(loader))
                net(inputs.cuda())
                batches_score_list.append(logdet(net.K / net.num_action))
    finally:
        for handle in handles:
            handle.remove()

    return np.mean(batches_score_list)


def set_seed(_seed_=2023):
    """Seed NumPy, ``random``, and the torch CPU and CUDA generators with ``_seed_``."""
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(_seed_)


if __name__ == '__main__':
    # Script entry point: parse the options, set up logging, build the data
    # loaders and run the search.

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-model_name', type=str, default='cifar10t4')
    parser.add_argument('-model_path', type=str, default='/data_smr/lw/project/res_bw_ICML2023/checkpoints/evo_search/models/cifar10t4.pt')
    parser.add_argument('-seed', type=int, default=2023)

    parser.add_argument('-T', type=int, default=4)
    parser.add_argument('-num_class', type=int, default=10)
    # NOTE: store_false -- passing -if_static turns the option OFF.
    parser.add_argument('-if_static', action='store_false', default=True)
    parser.add_argument('-batch_size', type=int, default=128)
    parser.add_argument('-size', type=int, default=16)
    parser.add_argument('-repeat', type=int, default=2)
    parser.add_argument('-zeroshot', type=str, default='nassnn', help='[nassnn, epe]')
    parser.add_argument('-search_way', type=str, default='evolution', help='[evolution, random]')
    parser.add_argument('-random_pool_per_iter', type=int, default=20)
    parser.add_argument('-dataset', type=str, default='CIFAR10')
    parser.add_argument('-log_dir_prefix', type=str, default='/data_smr/lw/project/res_bw_ICML2023/checkpoints/evo_search/models/')
    parser.add_argument('-dir_name', type=str, default='cifar10t4_search')

    # evolution search setting
    parser.add_argument('-max_search_iter', type=int, default=20)
    parser.add_argument('-num_pool', type=int, default=2000)
    parser.add_argument('-num_pool2', type=int, default=20)
    parser.add_argument('-num_mutation', type=int, default=10)
    parser.add_argument('-mutation_prob', type=float, default=0.2)
    parser.add_argument('-cross_prob', type=float, default=0.5)
    parser.add_argument('-increase_prob', type=float, default=0.5)
    parser.add_argument('-num_crossover', type=int, default=10)
    parser.add_argument('-num_topk', type=int, default=10)

    args = parser.parse_args()

    # names
    T = args.T
    num_class = args.num_class
    if_static = args.if_static
    dataset_name = args.dataset
    log_dir_prefix = args.log_dir_prefix
    dir_name = args.dir_name

    # path setting
    log_dir = os.path.join(log_dir_prefix, dir_name)
    print(f'log location: {log_dir}')
    # makedirs also creates missing parents and does not fail when the
    # directory already exists (no exists()/mkdir() race).
    os.makedirs(log_dir, exist_ok=True)

    # logging setting
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
    # Format only the file name; formatting the joined path would break on
    # a log_dir that contains braces.
    log_name = 'log_{}.txt'.format(time.strftime("%Y%m%d-%H%M%S"))
    fh = logging.FileHandler(os.path.join(log_dir, log_name))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    # logging args (name and value padded with dots to 100 columns)
    for arg, val in args.__dict__.items():
        logging.info(arg + '.' * (100 - len(arg) - len(str(val))) + str(val))

    # dataset seed
    set_seed(args.seed)

    # dataloader
    epe_tloader, _ = get_train_val_loaders(args, dataset_name, search=False)
    nassnn_tloader, _ = get_train_val_loaders(args, dataset_name, search=False, if_nassnn=True)
    _, infer_vloader = get_train_val_loaders(args, dataset_name, search=True)

    # # dataloader for tiny-imagenet
    # train_dataset, train_data, val_data = get_train_val_loaders(args, dataset_name, search=True)
    # epe_tloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, pin_memory=True,
    #                                           num_workers=4)
    # nassnn_tloader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True, pin_memory=True,
    #                                              num_workers=4)
    # infer_vloader = torch.utils.data.DataLoader(dataset=val_data, batch_size=args.batch_size, shuffle=False,
    #                                             num_workers=4, drop_last=False, pin_memory=True)

    search_net = ResNet18_child
    worker = EvoSearch(args, search_net, args.search_way)
    history = worker.search(args.max_search_iter, epe_tloader, nassnn_tloader, infer_vloader)
