import os
import sys
import time
import glob
import numpy as np
import pickle
import torch
import logging
import argparse
import torch
import random
import torch.nn.functional as F
from collections import OrderedDict
from copy import deepcopy

# Raise the interpreter's recursion limit to 10000 frames.
# NOTE(review): nothing in this file obviously recurses that deep; presumably
# needed by code that imports or is imported alongside it — confirm.
import sys
sys.setrecursionlimit(10000)
import argparse

# Rebind `print` at module level so every print() call in this file flushes
# its output immediately (useful when stdout is redirected to a log file).
import functools
print = functools.partial(print, flush=True)


def choice(x):
    """Return one uniformly random element of ``x``.

    ``x`` may be any iterable; anything that is not already a tuple is
    converted to one first so it can be indexed. The index is drawn with
    ``np.random.randint``, so results follow NumPy's global random state.
    """
    items = x if isinstance(x, tuple) else tuple(x)
    return items[np.random.randint(len(items))]


class EvolutionSearcher(object):
    """Evolutionary search over the word-substitution choices of a prompt.

    A candidate ("DNA") is a tuple of integer indices, one per substitutable
    word in ``prompt_opt.domain``; index ``i`` selects an entry of that word's
    ``sub_words`` list. Candidates are evaluated with ``eval_cand_fn`` and
    ranked by ascending ``avg_loss`` (lower is better).

    Args:
        args: Namespace-like object providing ``select_num``,
            ``population_num``, ``m_prob``, ``crossover_num``,
            ``mutation_num`` and ``gumbel_samples`` as attributes. The
            optional entries ``hard_sample``, ``first_argmax``, ``explore``
            and ``explore_inv`` are looked up with ``in`` and fall back to
            defaults when absent.
        eval_cand_fn: Called as ``eval_cand_fn(pos_prompt, neg_prompt)`` and
            expected to return ``(avg_loss, loss_list)``.
        prompt_opt: Object providing ``domain``, ``ori_prompt`` and
            ``sample_prompt(argmax=..., return_ids=True)``.
        constraint_fn: Called with the positive prompt; candidates whose
            value is below ``thresh`` are rejected without evaluation.
        thresh: Minimum acceptable value of ``constraint_fn``.
    """

    def __init__(self, args, eval_cand_fn, prompt_opt, constraint_fn, thresh):
        self.args = args
        self.constraint_fn = constraint_fn
        self.thresh = thresh

        ## EA hyper-params
        self.select_num = args.select_num
        self.population_num = args.population_num
        self.m_prob = args.m_prob
        self.crossover_num = args.crossover_num
        self.mutation_num = args.mutation_num

        # Every candidate ever generated maps to its info dict here.
        self.vis_dict = OrderedDict()
        # Ranked pools keyed by their size limit. If select_num equals
        # population_num both keys collapse into a single shared pool.
        self.keep_top_k = {self.select_num: [], self.population_num: []}
        self.epoch = 0
        self.candidates = []

        self.domain = prompt_opt.domain
        self.cand_choices = self.generate_cand_choices()
        self.eval_cand_fn = eval_cand_fn

        ## initial distribution to generate population
        self.prompt_opt = prompt_opt
        # hard_sample and explore_inv are stored but not read inside this class.
        self.hard_sample = self.args.hard_sample if 'hard_sample' in self.args else None
        self.first_argmax = self.args.first_argmax if 'first_argmax' in self.args else None
        self.gumbel_samples = args.gumbel_samples
        self.explore = args.explore if 'explore' in args else 1.0
        self.explore_inv = args.explore_inv if 'explore_inv' in args else 0
        self.max_iter_exceeded = False

        # Search-time state. search() resets both, but initialising them here
        # lets validate_and_eval / sample_from_gumbel be used on their own.
        self.budget = 0
        self.first_gumbel_sample = True

    def generate_cand_choices(self):
        """Return the number of substitute words for each substitutable slot.

        Slots are visited in domain order, then word order, which is the same
        order ``cand_id2prompt`` consumes candidate indices in.
        """
        cand_choices = []
        for _type, sub_dict in self.domain.items():
            for sub_word_dict in sub_dict['all_words']:
                if sub_word_dict['do_sub']:
                    sub_words = sub_word_dict['sub_words']
                    cand_choices.append(len(sub_words))
                    print(_type, sub_words)
        return cand_choices

    def sample_from_gumbel(self):
        """Draw a candidate from ``prompt_opt.sample_prompt``.

        ``argmax`` is only requested for the first draw after ``search()``
        starts, and only when ``first_argmax`` is set.
        """
        use_argmax = self.first_argmax and self.first_gumbel_sample
        cand = self.prompt_opt.sample_prompt(argmax=use_argmax, return_ids=True)
        self.first_gumbel_sample = False
        return cand

    def sample_from_uniform(self):
        """Draw one index uniformly at random for every substitutable slot."""
        return [np.random.randint(0, n_choices) for n_choices in self.cand_choices]

    def sample_cand_fn(self):
        """Sample a candidate tuple: uniform with probability ``explore``,
        otherwise from ``prompt_opt``."""
        if np.random.uniform() < self.explore:
            cand = self.sample_from_uniform()
        else:
            cand = self.sample_from_gumbel()
        return tuple(cand)

    def cand_id2prompt(self, cand):
        """ core function of transforming DNA (in this file) to prompt (in main.py)

        Returns:
            ``(pos_prompt, neg_prompt)``. ``pos_prompt`` falls back to
            ``prompt_opt.ori_prompt`` when the domain has no ``'pos'`` entry;
            ``neg_prompt`` is ``None`` when the domain has no ``'neg'`` entry.

        Raises:
            ValueError: if the domain contains a type other than
                ``'pos'`` / ``'neg'``.
        """
        all_prompt = {'pos': None, 'neg': None}
        idx = 0
        for _type, sub_dict in self.domain.items():
            words = []
            for sub_word_dict in sub_dict['all_words']:
                if sub_word_dict['do_sub']:
                    # Consume the next gene to pick the substitute word.
                    words.append(sub_word_dict['sub_words'][cand[idx]])
                    idx += 1
                else:
                    words.append(sub_word_dict['ori_word'])

            prompt = ' '.join(words)
            if _type == 'pos':
                all_prompt['pos'] = prompt
            elif _type == 'neg':
                ## TODO under dev ##
                all_prompt['neg'] = prompt.replace(' ', ',')
            else:
                raise ValueError(_type)

        if all_prompt['pos'] is None:
            all_prompt['pos'] = self.prompt_opt.ori_prompt
        # Every gene must have been consumed exactly once.
        assert idx == len(cand)
        return all_prompt['pos'], all_prompt['neg']

    def validate_and_eval(self, cand):
        """Evaluate ``cand`` once and record the outcome in ``vis_dict``.

        Returns:
            ``True`` if the candidate was new, satisfied the constraint and
            was evaluated (which consumes one unit of budget). ``False`` if it
            had been processed before or violated the constraint; in the
            latter case ``avg_loss`` is recorded as 1000.
        """
        assert isinstance(cand, tuple)
        info = self.vis_dict.setdefault(cand, {})
        if 'visited' in info:
            return False

        info['visited'] = True

        ## valid cands
        info['cand_prompt'] = self.cand_id2prompt(cand)
        info['constraint'] = self.constraint_fn(info['cand_prompt'][0])
        if info['constraint'] < self.thresh:
            info['avg_loss'] = 1000
            return False

        avg_loss, loss_list = self.eval_cand_fn(*info['cand_prompt'])
        self.budget += 1
        info['avg_loss'] = avg_loss
        info['loss_list'] = loss_list

        return True

    def update_top_k(self, candidates, *, k, key, reverse=False):
        """Merge ``candidates`` into the pool registered under ``k`` and keep
        the best ``k`` according to ``key``.

        Candidates already present in the pool are skipped, so registering
        the same population twice cannot create duplicate entries.
        """
        assert k in self.keep_top_k
        print('select ......')
        pool = self.keep_top_k[k]
        # Candidates are immutable tuples, so a shallow snapshot is enough
        # for the before/after comparison below.
        previous = list(pool)

        seen = set(pool)
        for cand in candidates:
            if cand not in seen:
                seen.add(cand)
                pool.append(cand)

        pool.sort(key=key, reverse=reverse)
        self.keep_top_k[k] = pool[:k]

        if pool[:k] != previous[:k]:
            logging.info(f'---> Top {k} updated')
        if len(previous) > 0 and pool[0] != previous[0]:
            logging.info('---> Best updated')

    def stack_random_cand(self, random_func, *, batchsize=10):
        """Endless generator of candidates produced by ``random_func``.

        Candidates are generated ``batchsize`` at a time and each one gets an
        (initially empty) entry in ``vis_dict`` before being yielded.
        """
        while True:
            cands = [random_func() for _ in range(batchsize)]
            for cand in cands:
                self.vis_dict.setdefault(cand, {})
            for cand in cands:
                yield cand

    def _fill_with_new_cands(self, random_func, pool, num, progress_label=None):
        """Append newly evaluated candidates to ``pool`` until it holds ``num``.

        At most ``10 * num`` candidates are drawn from ``random_func``. If the
        target is still not reached afterwards ``max_iter_exceeded`` is set;
        reaching the target on the last attempt, or having nothing to do, is
        not treated as exhaustion.
        """
        attempts_left = num * 10
        cand_iter = self.stack_random_cand(random_func)
        while len(pool) < num and attempts_left > 0:
            attempts_left -= 1
            cand = next(cand_iter)
            if not self.validate_and_eval(cand):
                continue
            pool.append(cand)
            if progress_label is not None:
                print('{} {}/{}'.format(progress_label, len(pool), num))

        if len(pool) < num:
            self.max_iter_exceeded = True

    def get_random(self, num):
        """Top up ``self.candidates`` with random candidates until it has ``num``."""
        print('random select ........')
        self._fill_with_new_cands(
            self.sample_cand_fn, self.candidates, num, progress_label='random')
        print('random_num = {}'.format(len(self.candidates)))

    def get_mutation(self, k, mutation_num, m_prob):
        """Return up to ``mutation_num`` new candidates obtained by mutating
        random members of the top-``k`` pool.

        Each gene is independently resampled with probability ``m_prob``.
        """
        assert k in self.keep_top_k
        print('mutation ......')
        res = []

        def random_select_and_mutate_func():  ## randomly pick a candidate and mutate
            cand = list(choice(self.keep_top_k[k]))
            for i in range(len(cand)):
                if np.random.random_sample() < m_prob:
                    cand[i] = np.random.randint(self.cand_choices[i])
            return tuple(cand)

        if mutation_num > 0 and not self.keep_top_k[k]:
            # No parents to mutate: nothing can be produced.
            self.max_iter_exceeded = True
        else:
            self._fill_with_new_cands(random_select_and_mutate_func, res, mutation_num)

        print('mutation_num = {}'.format(len(res)))
        return res

    def get_crossover(self, k, crossover_num):
        """Return up to ``crossover_num`` new candidates, each built by picking
        every gene at random from one of two random top-``k`` parents."""
        assert k in self.keep_top_k
        print('crossover ......')
        res = []

        def random_parent_crossover_func():
            p1 = choice(self.keep_top_k[k])
            p2 = choice(self.keep_top_k[k])
            return tuple(choice([i, j]) for i, j in zip(p1, p2))

        if crossover_num > 0 and not self.keep_top_k[k]:
            # No parents to cross: nothing can be produced.
            self.max_iter_exceeded = True
        else:
            self._fill_with_new_cands(random_parent_crossover_func, res, crossover_num)

        print('crossover_num = {}'.format(len(res)))
        return res

    def get_topk(self, k):
        """Return the ``vis_dict`` info dicts of the pool registered under ``k``."""
        return [self.vis_dict[cand] for cand in self.keep_top_k[k]]

    def _avg_loss(self, cand):
        """Ranking key: the recorded average loss of ``cand``."""
        return self.vis_dict[cand]['avg_loss']

    def _register_candidates(self):
        """Merge the current population into both ranked pools."""
        self.update_top_k(self.candidates, k=self.select_num, key=self._avg_loss)
        self.update_top_k(self.candidates, k=self.population_num, key=self._avg_loss)

    def search(self):
        """Run the evolutionary search and return the best candidates.

        The loop stops once ``gumbel_samples`` evaluations have been spent,
        when no new candidates can be found (``max_iter_exceeded``), or when
        an epoch evaluates nothing at all.

        Returns:
            Info dicts of the top ``population_num`` candidates, best first.
        """
        logging.info('population_num = {} select_num = {} mutation_num = {} crossover_num = {} random_num = {} budget = {}'.format(
            self.population_num, self.select_num, self.mutation_num, self.crossover_num, self.population_num - self.mutation_num - self.crossover_num, self.gumbel_samples))

        self.budget = 0
        self.first_gumbel_sample = True

        #### init population
        self.get_random(self.population_num)

        #### search
        while self.budget < self.gumbel_samples and not self.max_iter_exceeded:  ## the first epoch is random init
            logging.info('epoch = {}'.format(self.epoch))
            budget_at_epoch_start = self.budget

            ## register top k
            self._register_candidates()

            logging.info('epoch = {} : top {} result'.format(self.epoch, len(self.keep_top_k[self.population_num])))
            for rank, cand in enumerate(self.keep_top_k[self.population_num]):
                logging.info('No.{} {} Top-1 err = {}'.format(rank + 1, cand, self.vis_dict[cand]['avg_loss']))
                logging.info(list(cand))

            ## skip the last mutation crossover (redundant runs)
            mutation = self.get_mutation(self.select_num, self.mutation_num, self.m_prob)
            if self.budget >= self.gumbel_samples:
                # Keep the freshly evaluated mutations so the final top-k
                # update below still sees them.
                self.candidates = mutation
                break
            crossover = self.get_crossover(self.select_num, self.crossover_num)
            self.candidates = mutation + crossover

            self.get_random(self.population_num)

            self.epoch += 1

            if self.budget == budget_at_epoch_start and not self.max_iter_exceeded:
                # Nothing was evaluated and nothing was exhausted (all targets
                # are zero); looping again could never make progress.
                logging.info('no candidates evaluated in this epoch, stopping search')
                break

        self._register_candidates()
        return self.get_topk(k=self.population_num)

    def get_cand_history(self):
        """Return the info dicts of all candidates that have an ``avg_loss``
        (evaluated ones and constraint violators), in first-seen order."""
        return [info for info in self.vis_dict.values() if 'avg_loss' in info]
