import itertools
import numpy as np
from functools import partial
from tot.models import gpt

def get_value(task, x, y, n_evaluate_sample, cache_value=True):
    """Score one partial output ``y`` for input ``x``.

    The task builds the value prompt (and optionally a corrected version of
    ``y``) via ``task.value_prompt_wrap``. If no prompt is produced, the
    candidate scores 0 and no correction is reported. Otherwise the score is
    read from ``task.value_cache`` when ``cache_value`` is set and the prompt
    was seen before; if not, ``n_evaluate_sample`` LLM completions are
    requested and reduced to a value by ``task.value_outputs_unwrap`` (and
    stored in the cache when ``cache_value`` is set).

    Returns:
        (value, y_correct) where ``y_correct`` is whatever the task's
        ``value_prompt_wrap`` returned as its second element.
    """
    prompt, corrected = task.value_prompt_wrap(x, y)
    if prompt is None:
        # Nothing to evaluate for this candidate.
        return 0, None

    if cache_value and prompt in task.value_cache:
        score = task.value_cache[prompt]
    else:
        completions = gpt(prompt, n=n_evaluate_sample, stop=None, task=task)
        score = task.value_outputs_unwrap(x, y, completions)
        if cache_value:
            task.value_cache[prompt] = score
    return score, corrected

def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
    """Score every candidate in ``ys`` with :func:`get_value`.

    Identical candidates are scored only once: the first occurrence gets its
    value and every later copy gets 0, so duplicates are not favoured during
    selection.

    Whenever ``get_value`` reports a corrected version of a candidate, a row
    ``{'Input', 'Steps', 'Corrected-steps'}`` is appended to
    ``task.df_replacements`` (via its ``_append`` method; presumably a pandas
    DataFrame, so confirm against the task class). If
    ``task.replace_incorrect_results`` is truthy, the candidate is also
    replaced **in place** in ``ys`` by its corrected form. Later copies of the
    same incorrect candidate are replaced the same way.

    Returns:
        list of values, one per entry of ``ys``, in the same order.
    """
    values = []
    local_value_cache = {}  # candidate text -> value, for de-duplication within this call
    corrections = {}  # original candidate -> corrected candidate (only when replacing)
    for index, y in enumerate(ys):  # each partial output, e.g., y = '4 + 8 = 12 (left: 4 6 12)\n'
        if y in local_value_cache:  # avoid duplicate candidates
            if y in corrections:
                # Keep repeated copies of an incorrect candidate consistent
                # with the first copy, which was already replaced.
                ys[index] = corrections[y]
            values.append(0)
            continue

        value, y_correct = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
        # Always remember the original text, even if it gets corrected, so a
        # later identical candidate is treated as a duplicate (scored 0, not
        # re-evaluated and not logged twice).
        local_value_cache[y] = value
        if y_correct:  # log all the time
            task.df_replacements = task.df_replacements._append({'Input': x, 'Steps': y, 'Corrected-steps': y_correct}, ignore_index=True)
            if task.replace_incorrect_results:
                ys[index] = y_correct
                corrections[y] = y_correct
                # A later candidate equal to the corrected text is a duplicate too.
                local_value_cache[y_correct] = value
        values.append(value)
    return values

def get_votes(task, x, ys, n_evaluate_sample):
    """Score all candidates ``ys`` together with a single voting prompt.

    The task builds one prompt covering every candidate. ``n_evaluate_sample``
    LLM completions of that prompt are requested, and
    ``task.vote_outputs_unwrap`` turns them into the returned values. It is
    given ``len(ys)``, presumably to produce one value per candidate; confirm
    against the task.
    """
    prompt = task.vote_prompt_wrap(x, ys)
    ballots = gpt(prompt, n=n_evaluate_sample, stop=None, task=task)
    candidate_count = len(ys)
    return task.vote_outputs_unwrap(ballots, candidate_count)

def get_proposals(task, x, y):
    """Ask the LLM once for next-step proposals extending partial output ``y``.

    The single completion is split on newlines. Each line becomes a new
    candidate ``y + line + '\\n'``, including any empty lines in the
    completion.

    For example, initially ``y`` is empty, so the prompt is built from the
    input ``x`` alone. On later steps ``y`` holds earlier steps such as
    ``'4 + 8 = 12 (left: 4 6 12)\\n'``, and the task's propose prompt is
    expected to use the 'left' part of its last line as the new input.
    """
    prompt = task.propose_prompt_wrap(x, y)
    completions = gpt(prompt, n=1, stop=None, propose_prompt_flag=True, x_y_pair=(x, y), task=task)
    proposal_lines = completions[0].split('\n')
    candidates = []
    for line in proposal_lines:
        candidates.append(y + line + '\n')
    return candidates

def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
    """Sample ``n_generate_sample`` continuations of ``y`` from the LLM.

    ``prompt_sample`` selects the prompt builder: ``'standard'`` uses
    ``task.standard_prompt_wrap`` and ``'cot'`` uses ``task.cot_prompt_wrap``.
    Each returned sample is ``y`` followed by one completion.

    Raises:
        ValueError: if ``prompt_sample`` is neither ``'standard'`` nor ``'cot'``.
    """
    wrapper_by_mode = (
        ('standard', 'standard_prompt_wrap'),
        ('cot', 'cot_prompt_wrap'),
    )
    wrapper_name = next(
        (attr for mode, attr in wrapper_by_mode if prompt_sample == mode),
        None,
    )
    if wrapper_name is None:
        raise ValueError(f'prompt_sample {prompt_sample} not recognized')
    prompt = getattr(task, wrapper_name)(x, y)
    completions = gpt(prompt, n=n_generate_sample, stop=stop, task=task)
    return [y + completion for completion in completions]

def _select_ids(ids, values, method_select, n_select_sample):
    """Choose which candidate indices in ``ids`` survive to the next step.

    ``'sample'`` draws ``n_select_sample`` indices (numpy's default is with
    replacement) with probability proportional to ``values``. ``'greedy'``
    keeps the ``n_select_sample`` highest-valued indices.

    Raises:
        ValueError: if ``method_select`` is not recognized.
    """
    if method_select == 'sample':
        total = sum(values)
        # All-zero values would divide by zero and give NaN probabilities,
        # which numpy rejects; fall back to uniform sampling (p=None) instead.
        ps = np.array(values) / total if total else None
        return np.random.choice(ids, size=n_select_sample, p=ps).tolist()
    if method_select == 'greedy':
        return sorted(ids, key=lambda i: values[i], reverse=True)[:n_select_sample]
    raise ValueError(f'method_select {method_select} not recognized')


def solve(args, task, idx, to_print=True):
    """Run the step-wise tree search on task instance ``idx``.

    For each of ``task.steps`` steps:
      1. expand every current candidate (``args.method_generate``: 'sample'
         or 'propose'),
      2. score the expansions (``args.method_evaluate``: 'vote' or 'value'),
      3. keep ``args.n_select_sample`` of them (``args.method_select``:
         'sample' or 'greedy').

    Side effect: rebinds the module-level ``gpt`` to a ``partial`` carrying
    ``args.backend`` and ``args.temperature``, so every helper in this module
    uses them.

    Returns:
        (ys, {'steps': infos}): the final candidates and a per-step log.

    Raises:
        ValueError: if any of the three method options is not recognized.
    """
    global gpt
    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
    print(gpt)
    x = task.get_input(idx)  # input
    ys = ['']  # current output candidates
    infos = []
    for step in range(task.steps):
        # generation
        if args.method_generate == 'sample':
            new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
        elif args.method_generate == 'propose':
            # On the first step y is '' so proposals are built from x alone;
            # each current candidate yields a list of proposals.
            new_ys = [get_proposals(task, x, y) for y in ys]
        else:
            raise ValueError(f'method_generate {args.method_generate} not recognized')
        new_ys = list(itertools.chain.from_iterable(new_ys))
        ids = list(range(len(new_ys)))

        # evaluation
        if args.method_evaluate == 'vote':
            values = get_votes(task, x, new_ys, args.n_evaluate_sample)
        elif args.method_evaluate == 'value':
            # NOTE: get_values may overwrite entries of new_ys in place with
            # corrected steps (when task.replace_incorrect_results is set), so
            # new_ys below already reflects those replacements.
            values = get_values(task, x, new_ys, args.n_evaluate_sample)
        else:
            raise ValueError(f'method_evaluate {args.method_evaluate} not recognized')

        # selection
        select_ids = _select_ids(ids, values, args.method_select, args.n_select_sample)
        select_new_ys = [new_ys[select_id] for select_id in select_ids]

        # log
        if to_print:
            ranked = sorted(zip(new_ys, values), key=lambda pair: pair[1], reverse=True)
            # Build the tuples explicitly so an empty candidate list prints
            # "()" instead of failing to unpack zip(*[]).
            sorted_new_ys = tuple(cand for cand, _ in ranked)
            sorted_values = tuple(val for _, val in ranked)
            print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')

        infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
        ys = select_new_ys

    if to_print:
        print(ys)
    return ys, {'steps': infos}

def naive_solve(args, task, idx, to_print=True):
    """Baseline without search: sample ``args.n_generate_sample`` full outputs at once.

    Like :func:`solve`, this rebinds the module-level ``gpt`` to a ``partial``
    carrying ``args.backend`` and ``args.temperature``. ``to_print`` is
    accepted for signature parity with :func:`solve` but is not used.

    Returns:
        (samples, {}): the sampled outputs and an empty info dict.
    """
    global gpt
    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
    print(gpt)
    problem = task.get_input(idx)
    samples = get_samples(
        task,
        problem,
        '',
        n_generate_sample=args.n_generate_sample,
        prompt_sample=args.prompt_sample,
        stop=None,
    )
    return samples, {}