import itertools
import numpy as np
from functools import partial
from tot.models import gpt
from concurrent.futures import ThreadPoolExecutor
import pandas as pd

def get_value(task, x, y, n_evaluate_sample, cache_value=True, local_gpt=None):
    """Score a single partial output ``y`` for input ``x`` via the task's value prompt.

    Args:
        task: object providing ``value_prompt_wrap``, ``value_outputs_unwrap`` and
            (when caching) a dict-like ``value_cache`` keyed by prompt text.
        x: task input.
        y: partial output to score.
        n_evaluate_sample: number of completions requested from ``local_gpt``.
        cache_value: when true, look up / store the score in ``task.value_cache``.
        local_gpt: callable ``(prompt, n=..., stop=...) -> outputs``; must be supplied,
            it is called unconditionally on a cache miss.

    Returns:
        Whatever ``task.value_outputs_unwrap(x, y, outputs)`` returns.
    """
    prompt = task.value_prompt_wrap(x, y)
    if cache_value:
        cache = task.value_cache
        if prompt in cache:
            return cache[prompt]
    outputs = local_gpt(prompt, n=n_evaluate_sample, stop=None)
    score = task.value_outputs_unwrap(x, y, outputs)
    if cache_value:
        task.value_cache[prompt] = score
    return score

def get_values(task, x, ys, n_evaluate_sample, max_workers=1, cache_value=True, local_gpt=None):
    """Score every candidate in ``ys`` concurrently with ``get_value``.

    Each distinct candidate is evaluated exactly once. Repeated occurrences after the first
    receive a value of 0, so a duplicate cannot be selected a second time.

    Args:
        task, x, n_evaluate_sample, cache_value, local_gpt: forwarded to ``get_value``.
        ys: candidate partial outputs; must be hashable (they are used as dict keys).
        max_workers: thread-pool size used to run evaluations in parallel.

    Returns:
        A list of values aligned index-for-index with ``ys``.
    """
    # dict.fromkeys drops duplicates while keeping first-occurrence order, so duplicated
    # candidates do not trigger redundant (and cache-racing) LLM calls.
    unique_ys = list(dict.fromkeys(ys))

    def compute(y):
        return get_value(task, x, y, n_evaluate_sample, cache_value=cache_value, local_gpt=local_gpt)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        scores = dict(zip(unique_ys, executor.map(compute, unique_ys)))

    values = []
    seen = set()
    for y in ys:
        if y in seen:
            values.append(0)  # duplicate candidate: only its first occurrence keeps a score
        else:
            seen.add(y)
            values.append(scores[y])
    return values

def get_votes(task, x, ys, n_evaluate_sample, local_gpt):
    """Have the LLM vote among all candidates ``ys`` in one prompt.

    Returns whatever ``task.vote_outputs_unwrap(outputs, len(ys))`` produces from the
    ``n_evaluate_sample`` completions.
    """
    outputs = local_gpt(task.vote_prompt_wrap(x, ys), n=n_evaluate_sample, stop=None)
    return task.vote_outputs_unwrap(outputs, len(ys))

def get_proposals(task, x, y, local_gpt):
    """Ask the LLM for next-step proposals extending partial output ``y``.

    A single completion is requested and split into lines. Each non-blank line becomes a
    child candidate ``y + line + '\\n'``.

    Returns:
        List of child candidates (possibly empty).
    """
    propose_prompt = task.propose_prompt_wrap(x, y)
    completion = local_gpt(propose_prompt, n=1, stop=None)[0]
    # Skip blank / whitespace-only lines (e.g. from a trailing newline): they would create a
    # child that adds no new step and just re-scores the parent state.
    return [y + line + '\n' for line in completion.split('\n') if line.strip()]

def get_samples(task, x, y, n_generate_sample, prompt_sample, stop, local_gpt):
    """Sample ``n_generate_sample`` continuations of ``y`` and return them prefixed by ``y``.

    Args:
        prompt_sample: ``'standard'`` uses ``task.standard_prompt_wrap``; ``'cot'`` uses
            ``task.cot_prompt_wrap``.
        stop: stop sequence forwarded to ``local_gpt``.

    Raises:
        ValueError: if ``prompt_sample`` is not one of the recognised modes.
    """
    for mode, wrapper_name in (('standard', 'standard_prompt_wrap'), ('cot', 'cot_prompt_wrap')):
        if prompt_sample == mode:
            prompt = getattr(task, wrapper_name)(x, y)
            break
    else:
        raise ValueError(f'prompt_sample {prompt_sample} not recognized')
    completions = local_gpt(prompt, n=n_generate_sample, stop=stop)
    return [y + completion for completion in completions]

def verify_ys(task, x, new_ys, args, local_gpt):
    """Filter candidates through the task's LLM verification prompts.

    Candidates containing ``"answer"`` (case-insensitive) bypass verification and are kept
    as-is, in their original order, at the front of the result. Every other candidate gets
    the prompts from ``task.get_verification_prompts(x, y)``. Each prompt is sent to
    ``local_gpt`` (``args.n_verify_sample`` completions, ``args.max_workers`` threads). A
    candidate survives only if ``task.verification_outputs_unwrap`` is truthy for all of
    its prompts.

    Note: the surviving non-final candidates come out of a pandas ``groupby``, so they are
    de-duplicated and sorted by candidate text.

    Returns:
        ``final_ys + verified_non_final_ys``.
    """
    final_ys = []
    non_final = []
    for y in new_ys:
        (final_ys if "answer" in y.lower() else non_final).append(y)

    # Nothing to verify: avoid building/melting an empty frame and spinning up a thread pool.
    if not non_final:
        return final_ys

    prompts = pd.DataFrame([task.get_verification_prompts(x, y) for y in non_final])
    cols = prompts.columns
    prompts["new_y"] = non_final
    prompts_long = prompts.melt(id_vars="new_y", value_vars=cols, var_name="prompt_name", value_name="convo")

    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        prompts_long["outputs"] = list(
            executor.map(lambda convo: local_gpt(convo, n=args.n_verify_sample), prompts_long.convo)
        )

    prompts_long["is_correct"] = prompts_long.outputs.apply(task.verification_outputs_unwrap)

    verdicts = prompts_long.groupby("new_y").is_correct.all().reset_index()
    # Cast so the mask is always treated as a boolean filter, never as a column selector.
    verified = verdicts[verdicts.is_correct.astype(bool)].new_y.tolist()
    return final_ys + verified

def _generate_candidates(task, x, ys, step, args, local_gpt):
    """Expand every frontier node in ``ys`` and return all children as one flat list."""
    if args.method_generate == 'sample':
        nested = [
            get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample,
                        stop=task.stops[step], local_gpt=local_gpt)
            for y in ys
        ]
    elif args.method_generate == 'propose':
        with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
            nested = list(executor.map(lambda y: get_proposals(task, x, y, local_gpt=local_gpt), ys))
    else:
        raise ValueError(f'method_generate {args.method_generate} not recognized')
    return list(itertools.chain.from_iterable(nested))


def _evaluate_candidates(task, x, new_ys, args, local_gpt):
    """Score ``new_ys`` with the configured evaluation method (one value per candidate)."""
    if args.method_evaluate == 'vote':
        return get_votes(task, x, new_ys, args.n_evaluate_sample, local_gpt=local_gpt)
    if args.method_evaluate == 'value':
        return get_values(task, x, new_ys, args.n_evaluate_sample, local_gpt=local_gpt,
                          max_workers=args.max_workers)
    raise ValueError(f'method_evaluate {args.method_evaluate} not recognized')


def _select_candidates(new_ys, values, args):
    """Choose the next frontier from ``new_ys`` according to ``args.method_select``."""
    ids = list(range(len(new_ys)))
    if args.method_select == 'sample':
        total = sum(values)
        # All-zero values would divide by zero; fall back to uniform sampling (p=None).
        ps = np.array(values) / total if total != 0 else None
        select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
    elif args.method_select == 'greedy':
        select_ids = sorted(ids, key=lambda i: values[i], reverse=True)[:args.n_select_sample]
    else:
        raise ValueError(f'method_select {args.method_select} not recognized')
    return [new_ys[i] for i in select_ids]


def _scored_nodes(infos):
    """Yield (candidate, value) pairs from every logged step, pairing each step's own lists."""
    # Zip per step: backtrack entries log new_ys=[] with non-empty values, so chaining the two
    # lists separately across steps would pair candidates with the wrong scores.
    return itertools.chain.from_iterable(zip(k["new_ys"], k["values"]) for k in infos)


def _backtrack(infos, node_cache, n_select_sample):
    """Return (ys, values) for the best logged nodes not yet expanded, or None if none remain."""
    unexplored = [(y, v) for y, v in _scored_nodes(infos) if y not in node_cache]
    if not unexplored:
        return None
    topn = sorted(unexplored, key=lambda yv: yv[1], reverse=True)[:n_select_sample]
    return [y for y, _ in topn], [v for _, v in topn]


def _fill_up(ys, infos, n_select_sample):
    """Pad ``ys`` to ``n_select_sample`` entries with the best logged nodes, then with the root ``""``."""
    # Skip nodes that are substrings of an already-kept candidate (e.g. its own prefixes).
    candidates = [(y, v) for y, v in _scored_nodes(infos) if not any(y in kept for kept in ys)]
    missing = max(0, n_select_sample - len(ys))  # never a negative slice below
    patch = sorted(candidates, key=lambda yv: yv[1], reverse=True)[:missing]
    # Build a new list: ``ys`` is the same object logged as ``select_new_ys`` in ``infos``,
    # so extending it in place would corrupt the returned log.
    filled = list(ys) + [y for y, _ in patch]
    filled += [""] * (n_select_sample - len(filled))
    return filled


def solve(args, task, idx, to_print=False, client=None):
    """Search over partial outputs for task instance ``idx``, keeping a bounded frontier per step.

    Each of ``task.steps`` iterations does the following:
      1. Expand the frontier ``ys`` (``args.method_generate``: ``'sample'`` or ``'propose'``).
      2. Optionally filter the children with ``verify_ys`` (``args.do_verify``).
      3. Score them (``args.method_evaluate``: ``'vote'`` or ``'value'``).
      4. Pick up to ``args.n_select_sample`` of them (``args.method_select``: ``'sample'`` or
         ``'greedy'``).

    If no children survive a step, the search backtracks to the best-scored logged nodes
    not yet expanded. If there are none, it restarts from the root ``""``.

    Args:
        args: configuration namespace (backend, temperature, method_*, n_*_sample,
            prompt_sample, max_workers, do_verify, ...).
        task: task object supplying the input, prompt wrappers, ``steps`` and ``stops``.
        idx: index of the task instance passed to ``task.get_input``.
        to_print: print per-step candidates/values and the final result.
        client: forwarded to ``gpt`` as ``client``.

    Returns:
        ``(ys, {'steps': infos})``. ``ys`` is padded to ``args.n_select_sample`` entries;
        ``infos`` is the per-step log.

    Raises:
        ValueError: if a ``method_*`` option is not recognised.
    """
    local_gpt = partial(gpt, client=client, model=args.backend, temperature=args.temperature)
    x = task.get_input(idx)  # input
    ys = ['']  # current output candidates (frontier)
    infos = []
    node_cache = set()  # every node that has already been expanded
    for step in range(task.steps):
        node_cache.update(ys)

        # generation
        new_ys = _generate_candidates(task, x, ys, step, args, local_gpt)
        proposed_ys = new_ys

        if args.do_verify:
            new_ys = verify_ys(task, x, new_ys, args, local_gpt)

        if not new_ys:
            print(idx, x, "RAN OUT OF PROPOSALS")
            # backtrack to the best-scored nodes we have not expanded yet
            backtracked = _backtrack(infos, node_cache, args.n_select_sample)
            if backtracked is None:
                print(idx, x, "RAN OUT OF NODES TO VISIT, RUNNING FROM ROOT")
                select_new_ys, values = [""], [-1]
            else:
                select_new_ys, values = backtracked
            infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': [], 'values': values,
                          'select_new_ys': select_new_ys, "proposed_ys": proposed_ys})
            ys = select_new_ys
            continue

        # evaluation
        values = _evaluate_candidates(task, x, new_ys, args, local_gpt)

        # selection
        select_new_ys = _select_candidates(new_ys, values, args)

        # log
        if to_print:
            sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda yv: yv[1], reverse=True))
            print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')

        infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values,
                      'select_new_ys': select_new_ys, "proposed_ys": proposed_ys})
        ys = select_new_ys

    # The last step may not have produced exactly n_select_sample candidates: top up with the
    # highest-valued nodes found during the search, and finally with the root node.
    ys = _fill_up(ys, infos, args.n_select_sample)
    if to_print:
        print(ys)
    return ys, {'steps': infos}