import itertools
import numpy as np
from functools import partial
from tot.models import gpt

def get_value(task, x, f, y, n_evaluate_sample, cache_value=True):
    """Score a single candidate ``y`` by querying the model.

    The prompt produced by ``task.value_prompt_wrap(x, f, y)`` is also used
    as the key into ``task.value_cache``. With ``cache_value`` enabled, a
    cached entry is returned without calling the model, and a newly computed
    value is written back to the cache.

    Returns:
        The cached value, or the result of
        ``task.value_outputs_unwrap(x, y, value_outputs)``.
    """
    prompt = task.value_prompt_wrap(x, f, y)

    if cache_value:
        known = task.value_cache
        if prompt in known:
            return known[prompt]

    outputs = gpt(prompt, n=n_evaluate_sample, stop=None)
    score = task.value_outputs_unwrap(x, y, outputs)
    if cache_value:
        task.value_cache[prompt] = score
    return score

def get_values(task, x, f, ys, n_evaluate_sample, cache_value=True):
    """Score every candidate in ``ys``, in order.

    The first occurrence of a candidate is scored through ``get_value``;
    any later repeat of the same candidate within ``ys`` is given the value
    0 so that duplicates are not favoured during selection.

    Returns:
        A list of values, one per element of ``ys``.
    """
    already_scored = set()
    scores = []
    for candidate in ys:
        if candidate in already_scored:
            # repeated candidate: do not score it again
            scores.append(0)
            continue
        already_scored.add(candidate)
        scores.append(
            get_value(task, x, f, candidate, n_evaluate_sample, cache_value=cache_value)
        )
    return scores

def get_value_gsm8k(task, q, y, n_evaluate_sample, cache_value=True):
    """Score a single candidate ``y`` for question ``q`` by querying the model.

    The prompt produced by ``task.value_prompt_wrap(q, y)`` is also used as
    the key into ``task.value_cache``. With ``cache_value`` enabled, a cached
    entry is returned without calling the model, and a newly computed value
    is written back to the cache.

    Returns:
        The cached value, or the result of
        ``task.value_outputs_unwrap(q, y, value_outputs)``.
    """
    prompt = task.value_prompt_wrap(q, y)

    if cache_value:
        known = task.value_cache
        if prompt in known:
            return known[prompt]

    outputs = gpt(prompt, n=n_evaluate_sample, stop=None)
    score = task.value_outputs_unwrap(q, y, outputs)
    if cache_value:
        task.value_cache[prompt] = score
    return score

def get_values_gsm8k(task, q, ys, n_evaluate_sample, cache_value=True):
    """Score every candidate in ``ys`` for question ``q``, in order.

    The first occurrence of a candidate is scored through
    ``get_value_gsm8k``; any later repeat of the same candidate within ``ys``
    is given the value 0 so that duplicates are not favoured during
    selection.

    Returns:
        A list of values, one per element of ``ys``.
    """
    already_scored = set()
    scores = []
    for candidate in ys:
        if candidate in already_scored:
            # repeated candidate: do not score it again
            scores.append(0)
            continue
        already_scored.add(candidate)
        scores.append(
            get_value_gsm8k(task, q, candidate, n_evaluate_sample, cache_value=cache_value)
        )
    return scores

def get_votes(task, x, ys, n_evaluate_sample):
    """Ask the model to vote over the candidates ``ys`` for input ``x``.

    A single prompt from ``task.vote_prompt_wrap(x, ys)`` is sent to the
    model with ``n_evaluate_sample`` completions requested.

    Returns:
        The result of ``task.vote_outputs_unwrap(outputs, len(ys))``.
    """
    prompt = task.vote_prompt_wrap(x, ys)
    ballots = gpt(prompt, n=n_evaluate_sample, stop=None)
    return task.vote_outputs_unwrap(ballots, len(ys))

def get_proposals(task, x, f, y):
    """Extend the partial output ``y`` with model-proposed next lines.

    One completion is requested for ``task.propose_prompt_wrap(x, f, y)``
    and split on newlines; each resulting line is appended to ``y``.

    Returns:
        A list with one candidate per line of the completion. Each candidate
        is ``y + line + '\\n'`` with leading and trailing whitespace stripped.
    """
    prompt = task.propose_prompt_wrap(x, f, y)
    completion = gpt(prompt, n=1, stop=None)[0]
    return [(y + line + '\n').strip() for line in completion.split('\n')]

def get_proposals_gsm8k(task, q, y):
    """Extend the partial output ``y`` for question ``q`` with proposed lines.

    One completion is requested for ``task.propose_prompt_wrap(q, y)`` and
    split on newlines; each resulting line is appended to ``y``.

    Returns:
        A list with one candidate per line of the completion. Each candidate
        is ``y + line + '\\n'`` with leading and trailing whitespace stripped.
    """
    prompt = task.propose_prompt_wrap(q, y)
    completion = gpt(prompt, n=1, stop=None)[0]
    return [(y + line + '\n').strip() for line in completion.split('\n')]

def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
    """Sample ``n_generate_sample`` continuations of ``y`` from the model.

    ``prompt_sample`` selects the prompt builder: ``'standard'`` uses
    ``task.standard_prompt_wrap`` and ``'cot'`` uses ``task.cot_prompt_wrap``.

    Returns:
        A list containing ``y`` concatenated with each sampled continuation.

    Raises:
        ValueError: if ``prompt_sample`` is neither ``'standard'`` nor ``'cot'``.
    """
    if prompt_sample not in ('standard', 'cot'):
        raise ValueError(f'prompt_sample {prompt_sample} not recognized')

    # Look up only the wrapper that is actually needed.
    if prompt_sample == 'standard':
        build_prompt = task.standard_prompt_wrap
    else:
        build_prompt = task.cot_prompt_wrap

    continuations = gpt(build_prompt(x, y), n=n_generate_sample, stop=stop)
    return [y + continuation for continuation in continuations]

def solve(args, task, idx, to_print=True):
    """Search for outputs to task input ``idx`` by generate / evaluate / select.

    For each of ``task.steps`` steps, every current candidate is expanded
    (``args.method_generate``: ``'sample'`` or ``'propose'``), the expansions
    are scored (``args.method_evaluate``: ``'vote'`` or ``'value'``), and
    ``args.n_select_sample`` of them are kept (``args.method_select``:
    ``'sample'`` or ``'greedy'``) as the candidates for the next step.

    Side effect: rebinds the module-level ``gpt`` to a partial carrying
    ``args.backend`` and ``args.temperature``, and prints it.

    Args:
        args: options object; reads ``backend``, ``temperature``,
            ``method_generate``, ``method_evaluate``, ``method_select``,
            ``n_generate_sample``, ``n_evaluate_sample``, ``n_select_sample``
            and ``prompt_sample``.
        task: task object providing ``get_input``, ``steps``, ``stops`` and
            the prompt wrap/unwrap methods used by the helper functions.
        idx: index passed to ``task.get_input``.
        to_print: when true, print the ranked candidates at each step and the
            final candidates.

    Returns:
        ``(ys, {'steps': infos})`` where ``ys`` is the final list of selected
        candidates and ``infos`` holds one dict per step.

    Raises:
        ValueError: if one of the ``args.method_*`` values is not recognized.
    """
    global gpt
    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
    print(gpt)
    x = task.get_input(idx)  # input
    ys = ['']  # current output candidates
    infos = []
    for step in range(task.steps):
        # generation
        if args.method_generate == 'sample':
            new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
        elif args.method_generate == 'propose':
            # NOTE(review): get_proposals in this file is declared as
            # (task, x, f, y); this three-argument call does not match it.
            # No `f` is available here -- confirm the intended value.
            new_ys = [get_proposals(task, x, y) for y in ys]
        else:
            raise ValueError(f'method_generate {args.method_generate} not recognized')

        new_ys = list(itertools.chain(*new_ys))
        ids = list(range(len(new_ys)))

        # evaluation
        if args.method_evaluate == 'vote':
            values = get_votes(task, x, new_ys, args.n_evaluate_sample)
        elif args.method_evaluate == 'value':
            # NOTE(review): get_values in this file is declared as
            # (task, x, f, ys, n_evaluate_sample, ...); this call omits `f`.
            # Confirm the intended value before relying on this branch.
            values = get_values(task, x, new_ys, args.n_evaluate_sample)
        else:
            raise ValueError(f'method_evaluate {args.method_evaluate} not recognized')

        # selection
        if args.method_select == 'sample':
            total = sum(values)
            # With a zero total the division yields NaNs, which
            # np.random.choice rejects; p=None draws uniformly instead.
            ps = np.array(values) / total if total else None
            select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
        elif args.method_select == 'greedy':
            select_ids = sorted(ids, key=lambda i: values[i], reverse=True)[:args.n_select_sample]
        else:
            raise ValueError(f'method_select {args.method_select} not recognized')
        select_new_ys = [new_ys[select_id] for select_id in select_ids]

        # log
        if to_print:
            # Built without zip(*...) unpacking so an empty candidate list
            # prints empty tuples instead of raising.
            ranked = sorted(zip(new_ys, values), key=lambda pair: pair[1], reverse=True)
            sorted_new_ys = tuple(cand for cand, _ in ranked)
            sorted_values = tuple(val for _, val in ranked)
            print(f'-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n')

        infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
        ys = select_new_ys

    if to_print:
        print(ys)
    return ys, {'steps': infos}

def solve_new(args, task, idx, to_print=True):
    """Generate / evaluate / select search for tasks whose input is ``(x, f)``.

    ``task.get_input(idx)`` is unpacked into ``x`` and ``f``; ``f`` is
    forwarded to ``get_proposals`` and ``get_values``. For each of
    ``task.steps`` steps, every current candidate is expanded
    (``args.method_generate``: ``'sample'`` or ``'propose'``), a newline is
    appended to each expansion, the expansions are scored
    (``args.method_evaluate``: ``'vote'`` or ``'value'``), and
    ``args.n_select_sample`` of them are kept (``args.method_select``:
    ``'sample'`` or ``'greedy'``) for the next step.

    Side effect: rebinds the module-level ``gpt`` to a partial carrying
    ``args.backend`` and ``args.temperature``.

    Args:
        args: options object; reads ``backend``, ``temperature``,
            ``method_generate``, ``method_evaluate``, ``method_select``,
            ``n_generate_sample``, ``n_evaluate_sample``, ``n_select_sample``
            and ``prompt_sample``.
        task: task object providing ``get_input``, ``steps``, ``stops`` and
            the prompt wrap/unwrap methods used by the helper functions.
        idx: index passed to ``task.get_input``.
        to_print: accepted for interface compatibility; this variant does not
            print anything.

    Returns:
        ``(ys, {'steps': infos})`` where ``ys`` is the final list of selected
        candidates and ``infos`` holds one dict per step.

    Raises:
        ValueError: if one of the ``args.method_*`` values is not recognized.
    """
    global gpt
    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
    x, f = task.get_input(idx)  # input
    ys = ['']  # current output candidates
    infos = []

    for step in range(task.steps):
        # generation
        if args.method_generate == 'sample':
            new_ys = [get_samples(task, x, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
        elif args.method_generate == 'propose':
            new_ys = [get_proposals(task, x, f, y) for y in ys]
        else:
            raise ValueError(f'method_generate {args.method_generate} not recognized')

        new_ys = list(itertools.chain(*new_ys))
        ids = list(range(len(new_ys)))
        # terminate every candidate with a newline before scoring it
        new_ys = [i + '\n' for i in new_ys]

        # evaluation
        if args.method_evaluate == 'vote':
            values = get_votes(task, x, new_ys, args.n_evaluate_sample)
        elif args.method_evaluate == 'value':
            values = get_values(task, x, f, new_ys, args.n_evaluate_sample)
        else:
            raise ValueError(f'method_evaluate {args.method_evaluate} not recognized')

        # selection
        if args.method_select == 'sample':
            total = sum(values)
            # With a zero total the division yields NaNs, which
            # np.random.choice rejects; p=None draws uniformly instead.
            ps = np.array(values) / total if total else None
            select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
        elif args.method_select == 'greedy':
            select_ids = sorted(ids, key=lambda i: values[i], reverse=True)[:args.n_select_sample]
        else:
            raise ValueError(f'method_select {args.method_select} not recognized')
        select_new_ys = [new_ys[select_id] for select_id in select_ids]

        infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
        ys = select_new_ys

    return ys, {'steps': infos}


def solve_new_gsm8k(args, task, idx, to_print=True):
    """Generate / evaluate / select search for tasks whose input is ``(q, ins, fa)``.

    ``task.get_input(idx)`` is unpacked into ``q``, ``ins`` and ``fa``. The
    search runs for ``len(ins) + 1`` steps; ``fa`` is not used here. At each
    step every current candidate is expanded (``args.method_generate``:
    ``'sample'`` or ``'propose'``), a newline is appended to each expansion,
    the expansions are scored (``args.method_evaluate``: ``'vote'`` or
    ``'value'``), and ``args.n_select_sample`` of them are kept
    (``args.method_select``: ``'sample'`` or ``'greedy'``) for the next step.

    Side effect: rebinds the module-level ``gpt`` to a partial carrying
    ``args.backend`` and ``args.temperature``.

    Args:
        args: options object; reads ``backend``, ``temperature``,
            ``method_generate``, ``method_evaluate``, ``method_select``,
            ``n_generate_sample``, ``n_evaluate_sample``, ``n_select_sample``
            and ``prompt_sample``.
        task: task object providing ``get_input``, ``stops`` and the prompt
            wrap/unwrap methods used by the helper functions.
        idx: index passed to ``task.get_input``.
        to_print: accepted for interface compatibility; this variant does not
            print anything.

    Returns:
        ``(ys, {'steps': infos})`` where ``ys`` is the final list of selected
        candidates. Per-step records are not collected in this variant, so
        ``infos`` is always an empty list.

    Raises:
        ValueError: if one of the ``args.method_*`` values is not recognized.
    """
    global gpt
    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
    q, ins, fa = task.get_input(idx)  # input
    ys = ['']  # current output candidates
    infos = []

    for step in range(len(ins) + 1):
        # generation
        if args.method_generate == 'sample':
            # The question `q` is this function's task input; the previous
            # code referenced an undefined name `x` here.
            new_ys = [get_samples(task, q, y, args.n_generate_sample, prompt_sample=args.prompt_sample, stop=task.stops[step]) for y in ys]
        elif args.method_generate == 'propose':
            new_ys = [get_proposals_gsm8k(task, q, y) for y in ys]
        else:
            raise ValueError(f'method_generate {args.method_generate} not recognized')

        new_ys = list(itertools.chain(*new_ys))
        ids = list(range(len(new_ys)))
        # terminate every candidate with a newline before scoring it
        new_ys = [i + '\n' for i in new_ys]

        # evaluation
        if args.method_evaluate == 'vote':
            # Same fix as in the generation step: use `q`, not the undefined `x`.
            values = get_votes(task, q, new_ys, args.n_evaluate_sample)
        elif args.method_evaluate == 'value':
            values = get_values_gsm8k(task, q, new_ys, args.n_evaluate_sample)
        else:
            raise ValueError(f'method_evaluate {args.method_evaluate} not recognized')

        # selection
        if args.method_select == 'sample':
            total = sum(values)
            # With a zero total the division yields NaNs, which
            # np.random.choice rejects; p=None draws uniformly instead.
            ps = np.array(values) / total if total else None
            select_ids = np.random.choice(ids, size=args.n_select_sample, p=ps).tolist()
        elif args.method_select == 'greedy':
            select_ids = sorted(ids, key=lambda i: values[i], reverse=True)[:args.n_select_sample]
        else:
            raise ValueError(f'method_select {args.method_select} not recognized')
        select_new_ys = [new_ys[select_id] for select_id in select_ids]

        ys = select_new_ys

    return ys, {'steps': infos}


def naive_solve(args, task, idx, to_print=True):
    """Sample outputs for task input ``idx`` in one shot, without any search.

    Side effect: rebinds the module-level ``gpt`` to a partial carrying
    ``args.backend`` and ``args.temperature``, and prints it.

    Returns:
        ``(ys, {})`` where ``ys`` is the list returned by ``get_samples`` for
        an empty starting output, ``args.n_generate_sample`` samples and the
        ``args.prompt_sample`` prompt style.
    """
    global gpt
    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
    print(gpt)
    task_input = task.get_input(idx)
    samples = get_samples(
        task,
        task_input,
        '',
        args.n_generate_sample,
        args.prompt_sample,
        stop=None,
    )
    return samples, {}