
import time
import pickle
import logging
import os
import numpy as np
import torch
import torch.nn as nn


from collections import OrderedDict
from yaml import safe_dump
from yacs.config import load_cfg, CfgNode#, _to_dict
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.engine.inference import _accumulate_predictions_from_multiple_gpus
from maskrcnn_benchmark.modeling.backbone.nas import get_layer_name
from maskrcnn_benchmark.utils.comm import synchronize, get_rank, is_main_process, get_world_size, all_gather
from maskrcnn_benchmark.data.datasets.evaluation import evaluate
from maskrcnn_benchmark.utils.flops import profile


def choice(x):
    """Return a uniformly random element of ``x``.

    Non-tuple iterables are materialized into a tuple first. The draw uses
    NumPy's global RNG (``np.random.randint``), so it follows
    ``np.random.seed``.

    Raises:
        ValueError: if ``x`` is empty (``np.random.randint(0)`` rejects it).
    """
    pool = x if isinstance(x, tuple) else tuple(x)
    return pool[np.random.randint(len(pool))]


def gather_candidates(all_candidates):
    """Collect candidates from every process and drop duplicates.

    Args:
        all_candidates: this process's list of hashable candidates.

    Returns:
        A list with each distinct candidate from all processes exactly once
        (ordering is whatever ``set`` iteration yields).
    """
    merged = []
    for rank_candidates in all_gather(all_candidates):
        merged.extend(rank_candidates)
    return list(set(merged))


def gather_stats(all_candidates):
    """Merge the per-process ``{candidate: score}`` dicts into a single dict.

    Args:
        all_candidates: this process's dict of candidate statistics.

    Returns:
        One dict containing every key seen on any process. When a key is
        present in several gathered dicts, the one gathered last wins.
    """
    merged_stats = {}
    for rank_stats in all_gather(all_candidates):
        merged_stats.update(rank_stats)
    return merged_stats


def compute_on_dataset(model, rngs, data_loader, device=None):
    """Run ``model`` in eval mode over ``data_loader`` and collect predictions.

    Args:
        model: network called as ``model(images, rngs=rngs)``; must return an
            iterable of per-image outputs supporting ``.to(device)``.
        rngs: path selection forwarded unchanged to the model.
        data_loader: iterable of ``(images, targets, image_ids)`` batches;
            ``targets`` are ignored here.
        device: device the images are moved to. ``None`` (default) means
            ``cfg.MODEL.DEVICE`` read at call time.

    Returns:
        dict mapping each image id to its prediction, moved to the CPU.
    """
    if device is None:
        # Resolve lazily: a default bound at import time would silently ignore
        # any later cfg.merge_from_file()/merge_from_list().
        device = cfg.MODEL.DEVICE
    model.eval()
    results_dict = {}
    cpu_device = torch.device("cpu")
    for images, _targets, image_ids in data_loader:
        with torch.no_grad():
            output = model(images.to(device), rngs=rngs)
            output = [o.to(cpu_device) for o in output]
        results_dict.update(zip(image_ids, output))
    return results_dict


def bn_statistic(model, rngs, data_loader, device=None, max_iter=500):
    """Re-estimate BatchNorm running statistics for the path ``rngs``.

    Every buffer whose name contains ``running_mean`` is reset to 0 and every
    one containing ``running_var`` to 1. The model is then put in train mode
    and fed at most ``max_iter`` batches under ``torch.no_grad()``. In train
    mode, BatchNorm layers update their running buffers from these batches.

    Args:
        model: network called as ``model(images, targets, rngs)``.
        rngs: path selection forwarded to the model.
        data_loader: iterable of ``(images, targets, _)`` batches, where
            ``targets`` is an iterable of objects supporting ``.to(device)``.
        device: target device. ``None`` (default) means ``cfg.MODEL.DEVICE``
            read at call time.
        max_iter: maximum number of batches consumed; ``<= 0`` consumes none.

    Returns:
        The same ``model`` object, left in train mode.
    """
    if device is None:
        # Resolve lazily so later config merges are honoured.
        device = cfg.MODEL.DEVICE

    for name, buf in model.named_buffers():
        if 'running_mean' in name:
            nn.init.constant_(buf, 0)
        if 'running_var' in name:
            nn.init.constant_(buf, 1)

    model.train()
    # zip stops on the range first, so no extra batch is pulled from the loader.
    for _step, (images, targets, _) in zip(range(max_iter), data_loader):
        images = images.to(device)
        targets = [target.to(device) for target in targets]
        with torch.no_grad():
            # The returned losses are not needed; the forward pass itself
            # updates the BatchNorm running statistics.
            model(images, targets, rngs)

    return model


def inference(
        model,
        rngs,
        data_loader,
        iou_types=("bbox",),
        box_only=False,
        device="cuda",
        expected_results=(),
        expected_results_sigma_tol=4,
        output_folder=None,
):
    """Evaluate the path ``rngs`` of ``model`` on ``data_loader``'s dataset.

    Each process runs ``compute_on_dataset`` on its ``data_loader``. The
    processes then synchronize, and the predictions are combined with
    ``_accumulate_predictions_from_multiple_gpus``. Only the main process
    calls ``evaluate``.

    Returns:
        The result of ``evaluate(...)`` on the main process; ``None`` on every
        other process.
    """
    # Build the torch.device once instead of re-parsing the string per batch.
    torch_device = torch.device(device)
    dataset = data_loader.dataset
    local_predictions = compute_on_dataset(model, rngs, data_loader, torch_device)
    # All processes must finish before their predictions are combined.
    synchronize()
    merged_predictions = _accumulate_predictions_from_multiple_gpus(local_predictions)

    if is_main_process():
        return evaluate(
            dataset=dataset,
            predictions=merged_predictions,
            output_folder=output_folder,
            box_only=box_only,
            iou_types=iou_types,
            expected_results=expected_results,
            expected_results_sigma_tol=expected_results_sigma_tol,
        )
    return None


def fitness(cfg, model, rngs, val_loaders):
    """Evaluate ``model`` along path ``rngs`` on every validation loader.

    Args:
        cfg: config providing ``MODEL.MASK_ON``, ``MODEL.DEVICE``,
            ``TEST.EXPECTED_RESULTS`` and ``TEST.EXPECTED_RESULTS_SIGMA_TOL``.
        model: network to evaluate.
        rngs: path selection forwarded to the model.
        val_loaders: iterable of validation data loaders.

    Returns:
        The ``inference`` result for the *last* loader. This is ``None`` on
        non-main processes (``inference`` returns ``None`` there) and ``None``
        when ``val_loaders`` is empty.
    """
    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm",)
    # Initialise so an empty loader list returns None instead of raising
    # UnboundLocalError.
    results = None
    for data_loader_val in val_loaders:
        results = inference(
            model,
            rngs,
            data_loader_val,
            iou_types=iou_types,
            box_only=False,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
        )
        synchronize()

    return results


class EvolutionTrainer(object):
    """Evolutionary search over supernet paths.

    A candidate is a tuple holding one op index per searchable layer
    (``self.states[i]`` choices for layer ``i``). Each process generates its
    share of every generation. Candidate lists and score tables are then merged
    across processes with ``gather_candidates`` / ``gather_stats``.
    """

    def __init__(self, cfg, model, flops_limit=None, is_distributed=True):
        self.log_dir = cfg.OUTPUT_DIR
        self.checkpoint_name = os.path.join(self.log_dir, 'evolution.pth')
        self.is_distributed = is_distributed

        # Number of op choices for each searchable layer.
        self.states = model.module.mix_nums if is_distributed else model.mix_nums
        # Deep copy of the supernet weights; restored before each candidate
        # because bn_statistic overwrites the BatchNorm buffers in place.
        self.supernet_state_dict = pickle.loads(pickle.dumps(model.state_dict()))
        self.flops_limit = flops_limit
        self.model = model

        self.candidates = []
        self.vis_dict = {}  # candidate tuple -> bbox AP (float)

        self.max_epochs = cfg.SEARCH.MAX_EPOCH
        self.select_num = cfg.SEARCH.SELECT_NUM
        # Per-process shares of the global counts, rounded up. These yield the
        # same number of candidates the old float division did, but as ints.
        self.population_num = self._per_rank(cfg.SEARCH.POPULATION_NUM)
        self.mutation_num = self._per_rank(cfg.SEARCH.MUTATION_NUM)
        self.crossover_num = self._per_rank(cfg.SEARCH.CROSSOVER_NUM)
        # A per-gene probability is not a count: it must NOT be split across
        # processes (dividing it made mutation rarer as GPUs were added).
        self.mutation_prob = cfg.SEARCH.MUTATION_PROB

        self.keep_top_k = {self.select_num: [], 50: []}
        self.epoch = 0
        self.cfg = cfg

    @staticmethod
    def _per_rank(total):
        """Return ``ceil(total / world_size)``, this process's share of a global count."""
        return -(-total // get_world_size())

    def save_checkpoint(self):
        """Persist the search state to ``self.checkpoint_name`` (main process only)."""
        if not is_main_process():
            return
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)
        info = {}
        info['candidates'] = self.candidates
        info['vis_dict'] = self.vis_dict
        info['keep_top_k'] = self.keep_top_k
        info['epoch'] = self.epoch
        torch.save(info, self.checkpoint_name)
        print('Save checkpoint to', self.checkpoint_name)

    def load_checkpoint(self):
        """Restore state written by ``save_checkpoint``; return False if none exists."""
        if not os.path.exists(self.checkpoint_name):
            return False
        info = torch.load(self.checkpoint_name)
        self.candidates = info['candidates']
        self.vis_dict = info['vis_dict']
        self.keep_top_k = info['keep_top_k']
        self.epoch = info['epoch']
        print('Load checkpoint from', self.checkpoint_name)
        return True

    def legal(self, cand):
        """Return True if ``cand`` is unseen and, when limited, within the FLOPs budget.

        FLOPs are profiled on the backbone for a (1, 3, 224, 224) input and
        compared against ``self.flops_limit`` in units of 1e6.
        """
        assert isinstance(cand, tuple) and len(cand) == len(self.states)
        if cand in self.vis_dict:
            return False

        if self.flops_limit is not None:
            net = self.model.module.backbone if self.is_distributed else self.model.backbone
            inp = (1, 3, 224, 224)
            flops, params = profile(net, inp, extra_args={'paths': list(cand)})
            flops = flops / 1e6
            print('flops:', flops)
            if flops > self.flops_limit:
                return False

        return True

    def update_top_k(self, candidates, *, k, key, reverse=False):
        """Merge ``candidates`` into ``keep_top_k[k]`` and keep the best ``k`` by ``key``."""
        assert k in self.keep_top_k
        # dict.fromkeys drops repeats (keeping first-seen order), so a candidate
        # cannot occupy two slots, e.g. when select_num == 50 and both train()
        # updates hit the same list.
        pool = list(dict.fromkeys(self.keep_top_k[k] + list(candidates)))
        pool.sort(key=key, reverse=reverse)
        self.keep_top_k[k] = pool[:k]

    def eval_candidates(self, train_loader, val_loader):
        """Score each candidate: reload supernet weights, recalibrate BN, evaluate.

        All processes must run this loop in lockstep (``fitness`` synchronizes).
        Scores (bbox AP) are recorded in ``self.vis_dict`` on the main process.
        """
        for cand in self.candidates:
            t0 = time.time()

            # Undo the BatchNorm buffer changes made for the previous candidate.
            self.model.load_state_dict(self.supernet_state_dict)
            model = bn_statistic(self.model, list(cand), train_loader,
                                 device=self.cfg.MODEL.DEVICE)
            evals = fitness(self.cfg, model, list(cand), val_loader)

            if is_main_process():
                # Store a plain float so the checkpoint contains no numpy scalars.
                acc = float(evals[0].results['bbox']['AP'])
                self.vis_dict[cand] = acc
                print('candiate ', cand)
                print('time: {}s'.format(time.time() - t0))
                print('acc ', acc)

    def stack_random_cand(self, random_func, *, batchsize=10):
        """Yield candidates from ``random_func`` forever, drawing ``batchsize`` at a time."""
        while True:
            cands = [random_func() for _ in range(batchsize)]
            for cand in cands:
                yield cand

    def random_can(self, num):
        """Return ``num`` uniformly sampled legal candidates (``num <= 0`` gives [])."""
        candidates = []
        cand_iter = self.stack_random_cand(
            lambda: tuple(np.random.randint(i) for i in self.states))
        while len(candidates) < num:
            cand = next(cand_iter)
            if not self.legal(cand):
                continue
            candidates.append(cand)
        return candidates

    def get_mutation(self, k, mutation_num, m_prob):
        """Mutate random members of ``keep_top_k[k]``; each gene is resampled with prob ``m_prob``.

        Returns up to ``mutation_num`` legal candidates, giving up after
        ``10 * mutation_num`` attempts.
        """
        assert k in self.keep_top_k
        res = []
        max_iters = mutation_num * 10

        def random_func():
            cand = list(choice(self.keep_top_k[k]))
            for i, n_choices in enumerate(self.states):
                if np.random.random_sample() < m_prob:
                    cand[i] = np.random.randint(n_choices)
            return tuple(cand)

        cand_iter = self.stack_random_cand(random_func)
        while len(res) < mutation_num and max_iters > 0:
            # Count every attempt, legal or not; otherwise the cap never
            # triggers and an exhausted neighbourhood loops forever.
            max_iters -= 1
            cand = next(cand_iter)
            if not self.legal(cand):
                continue
            res.append(cand)
        return res

    def get_crossover(self, k, crossover_num):
        """Uniform crossover of random pairs from ``keep_top_k[k]``.

        Returns up to ``crossover_num`` legal candidates, giving up after
        ``10 * crossover_num`` attempts.
        """
        assert k in self.keep_top_k
        res = []
        max_iters = 10 * crossover_num

        def random_func():
            p1 = choice(self.keep_top_k[k])
            p2 = choice(self.keep_top_k[k])
            return tuple(choice([i, j]) for i, j in zip(p1, p2))

        cand_iter = self.stack_random_cand(random_func)
        while len(res) < crossover_num and max_iters > 0:
            # Count every attempt so the cap can actually stop the loop.
            max_iters -= 1
            cand = next(cand_iter)
            if not self.legal(cand):
                continue
            res.append(cand)
        return res

    def train(self, train_loader, val_loader):
        """Run the evolutionary search, resuming from the checkpoint if present."""
        logger = logging.getLogger("maskrcnn_benchmark.evolution")

        if not self.load_checkpoint():
            self.candidates = gather_candidates(self.random_can(self.population_num))

        while self.epoch < self.max_epochs:
            self.eval_candidates(train_loader, val_loader)
            self.vis_dict = gather_stats(self.vis_dict)

            self.update_top_k(self.candidates, k=self.select_num, key=lambda x: 1 - self.vis_dict[x])
            self.update_top_k(self.candidates, k=50, key=lambda x: 1 - self.vis_dict[x])

            if is_main_process():
                logger.info('Epoch {} : top {} result'.format(self.epoch + 1, len(self.keep_top_k[self.select_num])))
                for i, cand in enumerate(self.keep_top_k[self.select_num]):
                    logger.info('     No.{} {} perf = {}'.format(i + 1, cand, self.vis_dict[cand]))

            mutation = gather_candidates(self.get_mutation(self.select_num, self.mutation_num, self.mutation_prob))
            crossover = gather_candidates(self.get_crossover(self.select_num, self.crossover_num))
            # mutation/crossover are already gathered (global) lists, so compute
            # the global shortfall first, then this process's share of it.
            shortfall = max(self.cfg.SEARCH.POPULATION_NUM - len(mutation) - len(crossover), 0)
            rand = gather_candidates(self.random_can(self._per_rank(shortfall)))

            # The three sources may overlap; evaluate each candidate once.
            self.candidates = list(dict.fromkeys(mutation + crossover + rand))

            self.epoch += 1
            self.save_checkpoint()

    @staticmethod
    def _cfg_to_dict(node):
        """Recursively convert a config node (a dict subclass) into plain dicts/lists for ``safe_dump``."""
        if isinstance(node, dict):
            return {key: EvolutionTrainer._cfg_to_dict(val) for key, val in node.items()}
        if isinstance(node, (list, tuple)):
            return [EvolutionTrainer._cfg_to_dict(val) for val in node]
        return node

    @staticmethod
    def _select_path_weights(state_dict, paths):
        """Keep only the weights of the chosen op in every searchable layer.

        Keys containing ``layers.<i>.ops.<paths[i]>.`` are renamed to
        ``layers.<i>.``. Other keys containing an ``.ops.`` component (the
        unchosen ops) are dropped, and all remaining keys are copied unchanged.
        """
        # Match whole dotted components so 'layers.1.ops.1' cannot also capture
        # 'layers.1.ops.11' (a plain substring test did).
        chosen = {'.layers.{}.ops.{}.'.format(i, c): '.layers.{}.'.format(i)
                  for i, c in enumerate(paths)}
        cand_weight = OrderedDict()
        for key, val in state_dict.items():
            dotted = '.' + key
            if '.ops.' not in dotted:
                cand_weight[key] = val
                continue
            for src, dst in chosen.items():
                if src in dotted:
                    cand_weight[dotted.replace(src, dst, 1)[1:]] = val
        return cand_weight

    def save_candidates(self, cand, template):
        """Export the ``cand``-th best (1-based) architecture as a config plus weights.

        Writes ``<basename(template) with .yaml -> _cand<cand>.yaml>`` and
        ``init_cand<cand>.pth`` into ``self.cfg.OUTPUT_DIR``.
        """
        paths = self.keep_top_k[self.select_num][cand - 1]

        with open(template, "r") as f:
            super_cfg = load_cfg(f)

        search_spaces = {}
        for mix_ops in super_cfg.MODEL.BACKBONE.LAYER_SEARCH:
            search_spaces[mix_ops] = super_cfg.MODEL.BACKBONE.LAYER_SEARCH[mix_ops]
        search_layers = super_cfg.MODEL.BACKBONE.LAYER_SETUP

        layer_setup = []
        for i, layer in enumerate(search_layers):
            name, setup = get_layer_name(layer, search_spaces)
            if not isinstance(name, list):
                name = [name]
            name = name[paths[i]]

            layer_setup.append("('{}', {})".format(name, str(setup)[1:-1]))
        super_cfg.MODEL.BACKBONE.LAYER_SETUP = layer_setup

        # yacs' private _to_dict is not imported in this module (its import is
        # commented out), so convert with a local helper instead.
        cand_cfg = self._cfg_to_dict(super_cfg)
        del cand_cfg['MODEL']['BACKBONE']['LAYER_SEARCH']
        # Rename only the file name, not directories that happen to contain '.yaml'.
        cfg_name = os.path.basename(template).replace('.yaml', '_cand{}.yaml'.format(cand))
        with open(os.path.join(self.cfg.OUTPUT_DIR, cfg_name), 'w') as f:
            f.write(safe_dump(cand_cfg))

        cand_weight = self._select_path_weights(self.supernet_state_dict, paths)
        torch.save({'model': cand_weight}, os.path.join(self.cfg.OUTPUT_DIR, 'init_cand{}.pth'.format(cand)))
