import os
import sys
import numpy as np
import torch
import logging
import torch
import random

import tqdm
from timm import utils

# from .tester import get_cand_err
from ..models.wrappers import Supernet

import sys
# Raise the interpreter-wide recursion limit to 10000 as an import-time side effect.
# NOTE(review): nothing in this file shows which code path needs the deeper
# stack -- presumably something in the model or search code recurses deeply; confirm.
sys.setrecursionlimit(10000)


def set_seeds(seed):
    """Seed the torch (CPU and CUDA), NumPy and stdlib ``random`` generators.

    Args:
        seed: value handed unchanged to each library's seeding function.
    """
    # Order is kept as: torch CPU, torch CUDA (all devices), NumPy, stdlib random.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


class Choice:
    """Sampling helper that keeps a running seed counter.

    When ``reset_seed`` is true, every draw first increments ``self.seed`` and
    reseeds all RNGs through ``set_seeds``, so the draw is fully determined by
    the counter value. When it is false, the draw uses whatever global RNG
    state is currently active and the counter is left alone.
    """

    def __init__(self, seed):
        # Seed counter; advanced by one before each reseeded draw.
        self.seed = seed

    def _reseed(self):
        """Advance the seed counter and reseed every RNG with the new value."""
        self.seed += 1
        set_seeds(self.seed)

    def choice(self, x, reset_seed=True):
        """Return one element of ``x`` picked uniformly with NumPy's global RNG.

        Args:
            x: tuple, or any iterable that can be converted to a tuple.
            reset_seed: reseed the RNGs (once) before drawing.
        """
        if reset_seed:
            self._reseed()
        pool = x if isinstance(x, tuple) else tuple(x)
        return pool[np.random.randint(len(pool))]

    def mutate(self, m_prob, search_space, reset_seed=True):
        """Return a random entry of ``search_space`` with probability ``m_prob``.

        The accept/reject draw uses NumPy's RNG; the replacement itself is
        picked with the stdlib ``random`` module.

        Returns:
            The chosen entry, or ``None`` when no mutation takes place.
        """
        if reset_seed:
            self._reseed()
        return random.choice(search_space) if np.random.random_sample() < m_prob else None

    def __call__(self, x, reset_seed=True):
        """Shorthand for :meth:`choice`."""
        return self.choice(x, reset_seed)

# Module-level logger used by EvolutionSearcher for all progress messages.
# NOTE(review): the name "evoltionary_search" looks like a misspelling of
# "evolutionary_search"; it is left as-is because logging configuration
# elsewhere may refer to the logger by this exact name -- confirm before renaming.
_logger = logging.getLogger("evoltionary_search")


class EvolutionSearcher(object):

    def __init__(self, model: Supernet, val_loader, args, amp_autocast, device):
        """Store the search configuration and initialise the search state.

        Side effects: enables cuDNN deterministic mode, seeds all RNGs, and
        puts ``model`` into eval mode.

        Args:
            model: supernet to evaluate; when ``args.distributed`` is set the
                wrapped ``model.module`` is used for module-level queries.
            val_loader: iterable of ``(data, target)`` batches used to score
                candidates.
            args: configuration namespace (rank, seed, distributed, unit_test,
                population/selection/mutation/crossover sizes, flops_limit,
                explore_probability, fitness_criteria, and ``ea_log_dir`` on
                the primary process).
            amp_autocast: callable returning the context manager that wraps
                the forward pass.
            device: device that batches are moved to during evaluation.
        """
        self.args = args
        self.model = model
        self.val_loader = val_loader
        self.amp_autocast = amp_autocast
        self.device = device
        self.log_prefix = f"Worker ID: {args.rank}"

        torch.backends.cudnn.deterministic = True

        # In distributed runs the seed is shifted down by 42 and clamped at 0.
        # NOTE(review): the reason for the 42 offset is not visible here -- confirm.
        if args.distributed:
            initial_seed = max(args.seed - 42, 0)
        else:
            initial_seed = args.seed
        _logger.info(f"[{self.log_prefix}] seed = {initial_seed}")
        set_seeds(initial_seed)

        # Unit tests run with a small fixed population instead of the configured sizes.
        self.unit_test = args.unit_test
        if args.unit_test:
            self.select_num = 5
            self.population_num = 10
            self.crossover_num = 5
            self.mutation_num = 5
        else:
            self.select_num = args.select_num
            self.population_num = args.population_num
            self.crossover_num = args.crossover_num
            self.mutation_num = args.mutation_num

        self.max_epochs = args.max_epochs
        self.m_prob = args.m_prob
        self.flops_limit = args.flops_limit
        self.explore_probability = args.explore_probability
        self.fitness_criteria = args.fitness_criteria

        if args.distributed:
            self.module = model.module
        else:
            self.module = model

        # Only the primary process writes checkpoints, so only it gets the paths.
        if utils.is_primary(args):
            self.log_dir = args.ea_log_dir
            checkpoint_file = f"{self.module.task_idx}.pth.tar"
            self.checkpoint_name = os.path.join(self.log_dir, checkpoint_file)

        # Search state.
        self.epoch = 0
        self.memory = []        # one list of candidates per epoch
        self.candidates = []    # current population
        self.vis_dict = {}      # candidate -> bookkeeping dict ("flops", "err", "visited")
        self.keep_top_k = {self.select_num: [], 50: []}

        # A candidate holds one entry per key of task_to_expert_map.
        self.nr_layer = len(self.module.task_to_expert_map)

        self.choice = Choice(initial_seed)

        self.model.eval()

    def get_cand_err(self, cand, **kwargs):
        """Evaluate one candidate on the validation loader and return its fitness.

        Args:
            cand: candidate passed to the model as ``expert_ids``.
            **kwargs: extra keyword arguments forwarded to the model call.

        Returns:
            Tuple with, in priority order, whichever of these are named in
            ``self.fitness_criteria``: top-1 error and top-5 error (each
            computed as ``1 - accuracy / 100``) and
            ``self.module.arch_size(cand)``.
        """
        top1_m = utils.AverageMeter()
        top5_m = utils.AverageMeter()

        if utils.is_primary(self.args):
            _logger.info(f"[{self.log_prefix}] starting test....")

        with torch.no_grad():

            for data, target in tqdm.tqdm(self.val_loader):
                if not self.args.prefetcher:
                    data, target = data.to(self.device), target.to(self.device)
                if self.args.channels_last:
                    # Convert the batch itself; the previous code referenced an
                    # unbound local named `input` and crashed on this path.
                    data = data.contiguous(memory_format=torch.channels_last)

                with self.amp_autocast():
                    output = self.model(data, expert_ids=cand, verbose=self.args.debug, worker_id=self.args.rank, **kwargs)

                acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
                if self.args.distributed:
                    # Combine the per-worker accuracies across world_size workers.
                    acc1 = utils.reduce_tensor(acc1, self.args.world_size)
                    acc5 = utils.reduce_tensor(acc5, self.args.world_size)

                if self.device.type == "cuda":
                    torch.cuda.synchronize()

                top1_m.update(acc1.item(), output.size(0))
                top5_m.update(acc5.item(), output.size(0))

            # Turn the averaged accuracies (percent) into error fractions.
            top1, top5 = 1 - top1_m.avg / 100, 1 - top5_m.avg / 100

        criteria = []
        print_str = []

        # In order of priority
        if "top1" in self.fitness_criteria:
            criteria.append(top1)
            print_str.append(f"top1: {top1*100:.2f}")
        if "top5" in self.fitness_criteria:
            criteria.append(top5)
            print_str.append(f"top5: {top5*100:.2f}")
        if "size" in self.fitness_criteria:
            added_size = self.module.arch_size(cand)
            criteria.append(added_size)
            print_str.append(f"size: {added_size}")

        # Prefix once, then list the criteria; the prefix used to be the join
        # separator, which dropped it from the first item and repeated it between items.
        _logger.info(f"[{self.log_prefix}] " + ", ".join(print_str))

        return tuple(criteria)

    def save_checkpoint(self):
        """Write the current search state to ``self.checkpoint_name``.

        Only the primary process logs and saves; every other process returns
        immediately. The saved dict holds ``memory``, ``candidates``,
        ``vis_dict``, ``keep_top_k`` and ``epoch``.
        """
        if not utils.is_primary(self.args):
            return

        # Log the current best candidates together with their recorded fitness.
        for cand in self.keep_top_k[self.select_num]:
            _logger.info(f"[{self.log_prefix}] {cand}: {self.vis_dict[cand]['err']}")

        info = {
            "memory": self.memory,
            "candidates": self.candidates,
            "vis_dict": self.vis_dict,
            "keep_top_k": self.keep_top_k,
            "epoch": self.epoch,
        }

        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(self.log_dir, exist_ok=True)
        torch.save(info, self.checkpoint_name)
        _logger.info(f"[{self.log_prefix}] save checkpoint to {self.checkpoint_name}")

    def is_legal(self, cand):
        """Decide whether ``cand`` is a new, admissible candidate and score it.

        A candidate is rejected if it was already visited or if its recorded
        ``"flops"`` exceed ``self.flops_limit``. Otherwise it is evaluated with
        ``get_cand_err``, the result is stored under ``"err"`` and the
        candidate is marked as visited.

        Args:
            cand: tuple with exactly ``self.nr_layer`` entries.

        Returns:
            ``True`` if the candidate was accepted and evaluated, else ``False``.
        """
        assert isinstance(cand, tuple) and len(cand) == self.nr_layer

        record = self.vis_dict.setdefault(cand, {})
        if "visited" in record:
            return False

        # FLOPs are never measured in this method; a missing entry defaults to 0.
        record.setdefault("flops", 0)
        if record["flops"] > self.flops_limit:
            return False

        record["err"] = self.get_cand_err(cand)
        record["visited"] = True
        return True

    def update_top_k(self, candidates, *, k, key, reverse=False):
        """Merge ``candidates`` into the top-``k`` pool and keep the best ``k``.

        Args:
            candidates: iterable of hashable candidates to merge in.
            k: pool to update; must already be a key of ``self.keep_top_k``.
            key: sort key applied to each candidate.
            reverse: sort in descending order when true.
        """
        assert k in self.keep_top_k
        if utils.is_primary(self.args):
            _logger.info(f"[{self.log_prefix}] select ......")
        # Drop repeats while keeping first-seen order. When select_num == 50,
        # search() updates this same pool twice with the same candidates, and
        # without de-duplication every candidate would take up two slots.
        merged = list(dict.fromkeys([*self.keep_top_k[k], *candidates]))
        merged.sort(key=key, reverse=reverse)
        self.keep_top_k[k] = merged[:k]

    def stack_random_cand(self, random_func, *, batchsize=10):
        """Endlessly yield candidates produced by ``random_func``, in batches.

        Each round calls ``random_func`` ``batchsize`` times up front, makes
        sure every produced candidate has an entry in ``self.vis_dict``
        (existing entries are left untouched), then yields the batch one
        candidate at a time.

        Args:
            random_func: zero-argument callable returning a hashable candidate.
            batchsize: number of candidates generated per round.
        """
        while True:
            batch = [random_func() for _ in range(batchsize)]
            for candidate in batch:
                self.vis_dict.setdefault(candidate, {})
            yield from batch

    def get_random(self, num):
        """Top up ``self.candidates`` with random legal candidates until it holds ``num``.

        Candidates come from ``self.module._sample_nas_ops`` and are only kept
        if ``is_legal`` accepts (and thereby evaluates) them.

        NOTE(review): unlike ``get_mutation`` / ``get_crossover`` there is no
        attempt cap, so this loops until enough unvisited candidates are found.

        Args:
            num: target size of ``self.candidates``.
        """
        if utils.is_primary(self.args):
            _logger.info(f"[{self.log_prefix}] random select ........")

        def random_op():
            """Sample one candidate tuple from the module's sampler."""
            if self.args.debug:
                _logger.info(f"[{self.log_prefix}] seed: {self.choice.seed}")
            # The RNG is only consulted when exploitation is actually possible.
            if self.explore_probability < 1.:
                mode = "explore" if np.random.random_sample() < self.explore_probability else "exploit"
            else:
                mode = "explore"
            # In distributed mode an explicit, freshly advanced seed is handed
            # to the sampler; constant_sample_per_worker=True is passed so that
            # (per the original comment) every worker draws the same sample.
            seed = None
            if self.args.distributed:
                self.choice.seed += 1
                seed = self.choice.seed
            candidate = tuple(self.module._sample_nas_ops(seed=seed, constant_sample_per_worker=True, mode=mode))
            # Re-sync the seed counter from the module only in distributed mode.
            # It used to be overwritten with None otherwise, which would break
            # any later reseeded draw (None += 1).
            if self.args.distributed:
                self.choice.seed = self.module.seed
            if self.args.debug:
                _logger.info(f"[{self.log_prefix}] mode: {mode}, random_op: {candidate}")
            return candidate

        cand_iter = self.stack_random_cand(random_op)
        while len(self.candidates) < num:
            cand = next(cand_iter)
            if not self.is_legal(cand):
                continue
            self.candidates.append(cand)
            if utils.is_primary(self.args):
                _logger.info(f"[{self.log_prefix}] random {len(self.candidates)}/{num}")
        if utils.is_primary(self.args):
            _logger.info(f"[{self.log_prefix}] random_num = {len(self.candidates)}")

    def get_mutation(self, k, mutation_num, m_prob):
        """Create up to ``mutation_num`` new candidates by mutating top-``k`` members.

        A parent is drawn from ``self.keep_top_k[k]``; each of its
        ``self.nr_layer`` positions is then independently replaced, with
        probability ``m_prob``, by a random entry of
        ``self.module.nas_search_space[i]``. Children rejected by ``is_legal``
        are skipped. At most ``10 * mutation_num`` children are tried.

        Args:
            k: top-k pool to draw parents from; must be a key of ``self.keep_top_k``.
            mutation_num: number of mutated candidates wanted.
            m_prob: per-position mutation probability.

        Returns:
            List of accepted candidates (possibly fewer than ``mutation_num``).
        """
        assert k in self.keep_top_k
        if utils.is_primary(self.args):
            _logger.info("mutation ......")
        res = []
        # Attempt budget so the loop ends even if few legal children exist.
        max_iters = mutation_num * 10

        def random_func():
            """Return one mutated copy of a randomly chosen top-k candidate."""
            # RNGs are reseeded before every draw only in distributed mode.
            cand = list(self.choice(self.keep_top_k[k], reset_seed=self.args.distributed))
            has_been_mutated = False
            for i in range(self.nr_layer):
                mutated_cand = self.choice.mutate(m_prob, self.module.nas_search_space[i], reset_seed=self.args.distributed)
                if mutated_cand is not None:
                    has_been_mutated = True
                    cand[i] = mutated_cand
            if has_been_mutated and (utils.is_primary(self.args) or self.args.debug):
                _logger.info(f"[{self.log_prefix}] Mutated candidate {cand}")
            return tuple(cand)

        cand_iter = self.stack_random_cand(random_func)
        while len(res) < mutation_num and max_iters > 0:
            max_iters -= 1
            cand = next(cand_iter)

            if not self.is_legal(cand):
                continue
            res.append(cand)
            if utils.is_primary(self.args):
                _logger.info(f"[{self.log_prefix}] mutation {len(res)}/{mutation_num}")
        if utils.is_primary(self.args):
            _logger.info(f"[{self.log_prefix}] mutation_num = {len(res)}")
        return res

    def get_crossover(self, k, crossover_num):
        """Create up to ``crossover_num`` new candidates by crossing top-``k`` members.

        Two parents are drawn from ``self.keep_top_k[k]`` and the child takes,
        position by position, the entry of one parent chosen at random.
        Children rejected by ``is_legal`` are skipped. At most
        ``10 * crossover_num`` children are tried.

        Args:
            k: top-k pool to draw parents from; must be a key of ``self.keep_top_k``.
            crossover_num: number of crossover candidates wanted.

        Returns:
            List of accepted candidates (possibly fewer than ``crossover_num``).
        """
        assert k in self.keep_top_k
        if utils.is_primary(self.args):
            _logger.info("crossover ......")
        res = []
        # Attempt budget so the loop ends even if few legal children exist.
        max_iters = 10 * crossover_num

        def random_func():
            """Return one child built from two randomly chosen top-k parents."""
            # RNGs are reseeded before every draw only in distributed mode.
            p1 = self.choice(self.keep_top_k[k], reset_seed=self.args.distributed)
            p2 = self.choice(self.keep_top_k[k], reset_seed=self.args.distributed)
            return tuple(self.choice([i, j], reset_seed=self.args.distributed) for i, j in zip(p1, p2))

        cand_iter = self.stack_random_cand(random_func)
        while len(res) < crossover_num and max_iters > 0:
            max_iters -= 1
            cand = next(cand_iter)

            if not self.is_legal(cand):
                continue
            res.append(cand)
            if utils.is_primary(self.args) or self.args.debug:
                _logger.info(f"[{self.log_prefix}] crossover {len(res)}/{crossover_num}")
                _logger.info(f"[{self.log_prefix}] crossover candidate {cand}")

        if utils.is_primary(self.args):
            _logger.info(f"[{self.log_prefix}] crossover_num = {len(res)}")
        return res

    def search(self):
        """Run the evolutionary search for up to ``self.max_epochs`` epochs.

        The population is first filled with random candidates. Each epoch then
        records the population in ``self.memory``, refreshes the
        ``select_num`` and 50-entry top-k pools, breeds a new population from
        mutation and crossover, tops it up with random candidates, and saves a
        checkpoint. In unit-test mode the loop stops after the first epoch.
        """
        primary = utils.is_primary(self.args)

        if primary:
            _logger.info(
                f"[{self.log_prefix}] population_num = {self.population_num}"
                f" select_num = {self.select_num} mutation_num = {self.mutation_num}"
                f" crossover_num = {self.crossover_num}"
                f" random_num = {self.population_num - self.mutation_num - self.crossover_num}"
                f" max_epochs = {self.max_epochs}"
            )

        def recorded_err(candidate):
            # Sort key: the fitness tuple stored when the candidate was evaluated.
            return self.vis_dict[candidate]["err"]

        self.get_random(self.population_num)

        while self.epoch < self.max_epochs:
            if utils.is_primary(self.args):
                _logger.info(f"[{self.log_prefix}] epoch = {self.epoch}")

            # Snapshot this epoch's population.
            self.memory.append(list(self.candidates))

            for pool_size in (self.select_num, 50):
                self.update_top_k(self.candidates, k=pool_size, key=recorded_err)

            if utils.is_primary(self.args):
                leaders = self.keep_top_k[50]
                _logger.info(f"[{self.log_prefix}] epoch = {self.epoch} : top {len(leaders)} result")
                for position, cand in enumerate(leaders, start=1):
                    _logger.info(f"[{self.log_prefix}] No.{position} {cand} Top-1 err = {self.vis_dict[cand]['err']}")
                    _logger.info(list(cand))

            mutated = self.get_mutation(self.select_num, self.mutation_num, self.m_prob)
            crossed = self.get_crossover(self.select_num, self.crossover_num)

            # Next population: offspring first, then random candidates to fill up.
            self.candidates = mutated + crossed
            self.get_random(self.population_num)

            self.epoch += 1
            self.save_checkpoint()

            if self.unit_test:
                _logger.info(f"[{self.log_prefix}]Unit Test: Exiting after 1 epoch.")
                break
