import json
import os
import numpy as np
import heapq
import random

from utils import setup_log, k_init_pop
from utils import (
    read_lines,
    get_final_prompt,
    extract_numbers,
)
from llm_client import paraphrase, llm_query
from data.template_ga import templates_2
from data.templates import *
from run_bbh import eval_task
import functools


class Evoluter:
    """Base class for prompt-evolution strategies.

    On construction it reads the task's prompt from ``lib_prompt/<task>.txt``
    and the examples from ``data/<task>.json`` (key ``'examples'``). It
    randomly samples ``args.sample_num`` examples as the dev split and keeps
    the remainder as the test split. It then binds ``self.eval_func`` to
    ``eval_task`` so a prompt can be scored on the dev split with
    ``self.eval_func(cot_prompt=...)``.

    Subclasses implement :meth:`evolute` and are expected to leave their
    results in ``self.population``, ``self.scores`` and ``self.marks``
    (parallel lists), which :meth:`sorted` ranks and writes out.
    """

    def __init__(self, args, llm_config, client):
        self.init_poplulation = []  # (sic) attribute name kept for compatibility
        self.population = []
        self.scores = []
        self.marks = []
        self.client, self.llm_config = client, llm_config
        self.public_out_path = args.output
        self.task = args.task
        with open('lib_prompt/%s.txt' % self.task, 'r') as rf:
            self.task_prompt = rf.read()

        self.logger = logger = setup_log(os.path.join(self.public_out_path, "evol.log"))
        logger.info("=" * 50)
        logger.info("\n\t" + "\n\t".join(f"{k} = {v}" for k, v in vars(args).items()))
        logger.info("=" * 50)
        self.args = args
        self.out_path = os.path.join(self.public_out_path, "dev_result.txt")
        with open('data/%s.json' % args.task) as rf:
            self.task_data = json.load(rf)['examples']
        self.dev_data = random.sample(self.task_data, args.sample_num)
        # Membership by equality (examples are presumably dicts, hence not
        # hashable -- TODO confirm); duplicates of a dev example are also kept
        # out of the test split.
        self.test_data = [ex for ex in self.task_data if ex not in self.dev_data]

        model = 'turbo' if 'turbo' in args.llm_type else 'davinci'

        self.eval_func = functools.partial(
            eval_task,
            task=self.task,
            task_prompt=self.task_prompt,
            eval_data=self.dev_data,
            client=client,
            model_index=model,
            logger=logger,
            demon=args.demon,
            consis=args.consis,
            **llm_config,
        )

    def sorted(self):
        """Rank the population by score (descending) and write it to disk.

        ``self.scores``, ``self.population`` and ``self.marks`` are reordered
        in lockstep. One ``mark\\tprompt\\tscore`` line per member is written to
        ``self.out_path``, followed by ``best score:`` and ``average score:``
        lines. With an empty population the file is truncated and no summary
        is written, since there is nothing to average.
        """
        ranked = sorted(
            zip(self.scores, self.population, self.marks),
            key=lambda x: x[0],
            reverse=True,
        )
        with open(self.out_path, "w") as wf:
            if not ranked:
                return
            self.scores, self.population, self.marks = (list(col) for col in zip(*ranked))
            float_scores = [float(score) for score in self.scores]
            for score, prompt, mark in ranked:
                wf.write(f"{mark}\t{prompt}\t{score}\n")
            # The 0 floor mirrors the historical "best so far starts at 0" output.
            wf.write(f"best score: {max([0] + float_scores)}\n")
            wf.write(f"average score: {sum(float_scores) / len(float_scores)}\n")

    def run(self):
        """Run the evolution, then rank and dump the final population."""
        self.evolute()
        self.sorted()

    def evolute(self):
        """Evolve ``self.population``; must be implemented by subclasses."""
        raise NotImplementedError


class GAEvoluter(Evoluter):
    """Genetic-algorithm prompt evolver.

    Each step draws parent pairs from the current population and builds a
    crossover request from ``self.template`` by filling its
    ``<prompt1>``/``<prompt2>`` placeholders. The request goes to the LLM,
    the child prompt is scored on the dev split, and the population is
    updated according to ``args.ga_mode``.
    """

    _SEL_MODES = ("wheel", "random", "tour")
    _GA_MODES = ("std", "topk")

    def __init__(self, args, llm_config, client):
        super().__init__(args, llm_config=llm_config, client=client)

    def init_pop(self):
        """Build, score and dump the initial population.

        Where the seed prompts come from depends on ``args.initial``:

        * ``'ckpt'``: parse ``mark\\tprompt\\tscore`` rows from the first
          ``args.popsize`` lines of ``args.ckpt_pop``. Resume from the number
          extracted from that file's name.
        * anything else: load a ``{prompt: score}`` JSON cache from
          ``args.cache_path``. If that fails, score every line of
          ``prompts.txt`` and write the cache (when a cache path is set).

        The pool is then narrowed and/or paraphrased according to
        ``args.initial_mode``. Any still-unscored prompt is evaluated.
        ``step0_pop_para.txt`` and ``init.txt`` are written to the output
        directory.

        Returns:
            ``(evaluated_prompts, prompts2mark, cur_budget)``: prompt -> dev
            score, prompt -> provenance mark, and the step to resume after
            (-1 when starting fresh).

        Raises:
            ValueError: if the final population size differs from
                ``args.popsize``.
        """
        args = self.args
        logger = self.logger
        prompts2mark = {}
        evaluated_prompts = {}
        out_path = self.public_out_path
        cur_budget = -1
        cache_path = args.cache_path
        if args.initial == 'ckpt':
            init_population = []
            logger.info(f"------------load from file {args.ckpt_pop}------------")
            for line in read_lines(args.ckpt_pop)[: args.popsize]:
                try:
                    mark, prompt, score = line.strip().split("\t")
                    score = float(score)
                except ValueError:
                    # Summary lines ("best score: ...") and malformed rows do not
                    # split into three fields / parse as a float; skip them.
                    continue
                prompts2mark[prompt] = mark
                evaluated_prompts[prompt] = score
                init_population.append(prompt)
            if init_population:
                # Resume after the step number embedded in the checkpoint file name.
                cur_budget = extract_numbers(args.ckpt_pop.split("/")[-1])
            logger.info("current budget: %d" % cur_budget)
        else:
            try:
                with open(cache_path, "r") as rf:
                    evaluated_prompts = json.load(rf)
                logger.info(f"---loading prompts from {cache_path}")
                evaluated_prompts = dict(
                    sorted(
                        evaluated_prompts.items(),
                        key=lambda item: item[1],
                        reverse=True,
                    )
                )
                init_population = list(evaluated_prompts)
            except (OSError, TypeError, ValueError, AttributeError):
                # No usable cache (missing/None path, invalid JSON, not a dict):
                # score every candidate in prompts.txt from scratch.
                topk_population = []
                evaluated_prompts = {}
                pop = read_lines("prompts.txt")
                logger.info(
                    "-----evaluating initial population and paraphrasing topk---------"
                )
                for prompt in pop:
                    eval_res = self.eval_func(cot_prompt=prompt)
                    evaluated_prompts[prompt] = eval_res
                    topk_population.append((eval_res, prompt))
                topk_population.sort(reverse=True, key=lambda x: x[0])

                evaluated_prompts = dict(
                    sorted(evaluated_prompts.items(), key=lambda item: item[1])
                )
                # Only persist when a cache path is configured; open(None) would
                # otherwise crash after the whole evaluation has been paid for.
                if cache_path:
                    with open(cache_path, "w") as wf:
                        json.dump(evaluated_prompts, wf)
                init_population = [prompt for _, prompt in topk_population]
        # Checkpoint rows keep their recorded marks; every other seed is "manual".
        for prompt in init_population:
            prompts2mark.setdefault(prompt, "manual")
        self.template = templates_2[args.template]['sim']

        # Test the LLM client with a throwaway paraphrase call.
        _ = paraphrase(
            sentence="Hi, I am a student.",
            type=args.llm_type,
            client=self.client,
            temperature=0.5,
            **self.llm_config,
        )
        logger.info("test LLM client success")
        if args.initial_mode in ["para_topk", "para_bottomk", "para_randomk"]:
            k_pop = k_init_pop(args.initial_mode, init_population, k=args.popsize)
            para_population = paraphrase(
                client=self.client,
                sentence=k_pop,
                type=args.llm_type,
                temperature=0.5,
                **self.llm_config,
            )
            for prompt in para_population:
                prompts2mark[prompt] = "para"
            # NOTE(review): if k_init_pop returns popsize prompts, this slice drops
            # every paraphrase -- confirm k_init_pop's output size.
            init_population = (k_pop + para_population)[: args.popsize]
        elif args.initial_mode in ["topk", "bottomk", "randomk"]:
            init_population = k_init_pop(args.initial_mode, init_population, k=args.popsize)
        self.population = list(init_population)
        if len(self.population) != args.popsize:
            raise ValueError(
                f"initial population has {len(self.population)} prompts, "
                f"expected popsize={args.popsize}"
            )
        with open(os.path.join(out_path, "step0_pop_para.txt"), "w") as wf:
            for prompt in init_population:
                if prompt not in evaluated_prompts:
                    evaluated_prompts[prompt] = self.eval_func(cot_prompt=prompt)
                wf.write(f"{prompts2mark[prompt]}\t{prompt}\t{evaluated_prompts[prompt]}\n")
        with open(f'{out_path}/init.txt', 'w') as wf:
            for prompt in self.population:
                logger.info(prompt)
                wf.write(f'{prompt}\n')
        return evaluated_prompts, prompts2mark, cur_budget

    def evolute(self):
        """Run GA steps ``cur_budget + 1`` .. ``args.budget - 1``.

        In each step:

        * Build a parent pool according to ``args.sel_mode``. ``'wheel'``
          resamples the population proportionally to fitness.
          ``'random'``/``'tour'`` copy the current population.
        * Breed ``args.popsize`` children through the LLM crossover template
          and score each child on the dev split.
        * Update the population according to ``args.ga_mode``. ``'std'``
          puts child j in slot j. ``'topk'`` keeps the best ``popsize`` of
          parents plus children.

        Every step writes ``step{n}_pop.txt``. The final step also scores the
        top half on the test split. The result is stored on ``self`` and
        ranked with :meth:`sorted`.

        Raises:
            ValueError: for an unknown ``sel_mode`` or ``ga_mode``. This is
                checked before any LLM work is done.
        """
        logger = self.logger
        args = self.args
        if args.sel_mode not in self._SEL_MODES:
            raise ValueError(f"unknown sel_mode {args.sel_mode!r}; expected one of {self._SEL_MODES}")
        if args.ga_mode not in self._GA_MODES:
            raise ValueError(f"unknown ga_mode {args.ga_mode!r}; expected one of {self._GA_MODES}")
        evaluated_prompts, prompts2mark, cur_budget = self.init_pop()
        out_path = self.public_out_path
        template = self.template

        best_scores = []
        avg_scores = []
        population = self.population

        for step in range(cur_budget + 1, args.budget):
            total_score = 0
            best_score = 0
            new_pop = []
            parent_pop = self._parent_pool(population, evaluated_prompts)
            for j in range(args.popsize):
                logger.info("step {i}, pop {j}".format(i=step, j=j))
                cand_a, cand_b = self._pick_parents(parent_pop, evaluated_prompts)
                child_prompt, child_score = self._crossover(template, cand_a, cand_b)
                prompts2mark[child_prompt] = "evoluted"
                evaluated_prompts[child_prompt] = child_score
                if args.ga_mode == 'std':
                    # Generational replacement: child j takes slot j.
                    population[j] = child_prompt
                new_pop.append(child_prompt)
                total_score += child_score
                if child_score > best_score:
                    best_score = child_score

            if args.ga_mode == 'topk':
                # dict.fromkeys dedups while preserving order, so ties in the
                # stable sort break the same way on every run (set() order
                # varies with string hash randomization).
                merged = list(dict.fromkeys(population + new_pop))
                merged.sort(key=lambda p: evaluated_prompts[p], reverse=True)
                population = merged[: args.popsize]
                total_score = sum(evaluated_prompts[p] for p in population)
                best_score = evaluated_prompts[population[0]]
            avg_score = total_score / args.popsize
            avg_scores.append(avg_score)
            best_scores.append(best_score)

            self._write_population(
                os.path.join(out_path, f"step{step}_pop.txt"),
                population,
                evaluated_prompts,
                prompts2mark,
                best_score,
                avg_score,
            )

            if step == args.budget - 1:
                logger.info(f"----------testing step{step} population----------")
                population = self._test_population(step, population, evaluated_prompts, prompts2mark)

        logger.info(f"best_scores: {','.join(str(s) for s in best_scores)}")
        logger.info(f"avg_scores: {','.join(str(round(s, 4)) for s in avg_scores)}")
        self.scores = [evaluated_prompts[p] for p in population]
        self.marks = [prompts2mark[p] for p in population]
        self.population = list(population)
        self.sorted()

    def _parent_pool(self, population, evaluated_prompts):
        """Return a fresh list that this step's parents are drawn from."""
        if self.args.sel_mode == 'wheel':
            fitness = np.array([evaluated_prompts[p] for p in population])
            total = fitness.sum()
            # Fitness-proportional resampling. Fall back to uniform when every
            # score is zero, because fitness / total would then be NaN.
            probs = fitness / total if total > 0 else None
            wheel_idx = np.random.choice(
                np.arange(len(population)),
                size=self.args.popsize,
                replace=True,
                p=probs,
            ).tolist()
            return [population[k] for k in wheel_idx]
        # A copy, so in-place 'std' replacement does not alter this step's parents.
        return list(population)

    def _pick_parents(self, parent_pop, evaluated_prompts):
        """Pick two parents: a uniform pair, or two binary tournaments for 'tour'."""
        if self.args.sel_mode == 'tour':
            group_a = random.sample(parent_pop, 2)
            group_b = random.sample(parent_pop, 2)
            cand_a = max(group_a, key=lambda x: evaluated_prompts[x])
            cand_b = max(group_b, key=lambda x: evaluated_prompts[x])
            return cand_a, cand_b
        cand_a, cand_b = random.sample(parent_pop, 2)
        return cand_a, cand_b

    def _crossover(self, template, cand_a, cand_b):
        """Ask the LLM to combine two parents; return ``(child_prompt, dev_score)``."""
        args = self.args
        logger = self.logger
        request_content = template.replace("<prompt1>", cand_a).replace("<prompt2>", cand_b)
        logger.info("evolution example:")
        logger.info(request_content)
        logger.info("parents:")
        logger.info(cand_a)
        logger.info(cand_b)
        child_prompt = llm_query(
            client=self.client,
            data=request_content,
            type=args.llm_type,
            task=False,
            temperature=0.5,
            **self.llm_config,
        )
        logger.info(f"original child prompt: {child_prompt}")
        child_prompt = get_final_prompt(child_prompt)
        logger.info(f"child prompt: {child_prompt}")
        child_score = self.eval_func(cot_prompt=child_prompt)
        logger.info(f"new score: {child_score}")
        return child_prompt, child_score

    @staticmethod
    def _write_population(path, population, evaluated_prompts, prompts2mark, best_score, avg_score):
        """Write ``mark\\tprompt\\tscore`` rows plus best/average summary lines to *path*."""
        with open(path, "w") as wf:
            for prompt in population:
                wf.write(f'{prompts2mark[prompt]}\t{prompt}\t{evaluated_prompts[prompt]}\n')
            wf.write(f"best score: {best_score}\n")
            wf.write(f"average score: {avg_score}\n")

    def _test_population(self, step, population, evaluated_prompts, prompts2mark):
        """Score the top ``popsize // 2`` prompts on the test split.

        Writes ``mark, prompt, dev, test, all`` rows to ``step{step}_pop_test.txt``.
        The "all" column is the dev and test scores weighted by split size over
        the whole task. Returns the population sorted by dev score, descending.
        """
        ranked = sorted(population, key=lambda p: evaluated_prompts[p], reverse=True)
        n_test, n_dev, n_all = len(self.test_data), len(self.dev_data), len(self.task_data)
        test_prompt_num = self.args.popsize // 2
        with open(os.path.join(self.public_out_path, f"step{step}_pop_test.txt"), "w") as wf:
            for test_prompt in ranked[:test_prompt_num]:
                test_mark = prompts2mark[test_prompt]
                test_score = self.eval_func(cot_prompt=test_prompt, eval_data=self.test_data)
                dev_score = evaluated_prompts[test_prompt]
                all_score = (test_score * n_test + n_dev * dev_score) / n_all
                wf.write(f'{test_mark}\t{test_prompt}\t{dev_score}\t{test_score}\t{all_score}\n')
                wf.flush()
        return ranked