import util
import llms, envs, algs
from templates import buttons, oldbuttons, adverts
import numpy as np
import os, pickle, time, pathlib

def print_dialogue(debug,system,prompt,pred,val):
    """Print one LLM exchange to stdout when debugging is enabled.

    Args:
        debug: When falsy nothing is printed.
        system: Value printed after the ``[SYSTEM]`` tag.
        prompt: Value printed after the ``[USER]`` tag.
        pred: Value printed after the ``[PRED]`` tag.
        val: Value printed after the ``[PARSED]`` tag.
    """
    if not debug:
        return
    labelled = (
        ("[SYSTEM]", system),
        ("[USER]", prompt),
        ("[PRED]", pred),
        ("[PARSED]", val),
    )
    # One line per field, flushed immediately so output is not held in a buffer.
    for label, text in labelled:
        print(f"{label} {text}", flush=True)

def generate_history(t,env,rng,hist_alg):
    """Simulate ``t`` rounds of a history-generating algorithm on ``env``.

    Every arm is forced to be pulled once before the algorithm's own choices
    are honoured, so with ``t >= env.K`` every arm has at least one sample.

    Args:
        t: Number of rounds to simulate.
        env: Environment exposing ``K`` (number of arms) and ``get_rewards()``,
            whose result is indexed by the chosen arm.
        rng: Random generator; only its ``shuffle`` method is used here, to
            randomize the order in which unpulled arms are forced.
        hist_alg: Callable taking the number of arms and returning an object
            with ``get_action()`` and ``update(action, reward)``.

    Returns:
        A tuple ``(history, means, counts)`` where ``history`` is a list of
        ``(action, reward)`` pairs, ``means`` is a length-``K`` array of the
        per-arm mean reward (``nan`` for arms that were never pulled, e.g.
        when ``t < env.K``), and ``counts`` is a length-``K`` float array of
        pull counts.
    """
    history = []
    counts = np.zeros(env.K)
    alg = hist_alg(env.K)
    order = np.arange(0,env.K)
    rng.shuffle(order)
    # Rewards are bucketed per arm as they arrive so the means can be computed
    # without re-scanning the full history once per arm.
    rewards_by_arm = [[] for _ in range(env.K)]
    for _ in range(t):
        act = alg.get_action()
        ## override the algorithm to ensure every arm gets 1 pull
        ## (no break: the last still-unpulled arm in `order` is the one chosen)
        for a in order:
            if counts[a] == 0:
                act = a
        rews = env.get_rewards()
        reward = rews[act]
        history.append((act,reward))
        alg.update(act,reward)
        counts[act] += 1
        rewards_by_arm[act].append(reward)
    # Arms without samples keep an explicit nan instead of relying on
    # np.mean([]), which yields nan only after emitting RuntimeWarnings.
    means = np.full(env.K, np.nan)
    for arm, arm_rewards in enumerate(rewards_by_arm):
        if arm_rewards:
            means[arm] = np.mean(arm_rewards)
    return (history, means, counts)
            


# def generate_history(t,env,rng,hist_alg='unif'):

#     history = []
#     counts = np.zeros(env.K)
#     order = np.arange(0,env.K)
#     rng.shuffle(order)
#     if hist_type == 'unif':
#         for tau in range(t):
#             act = rng.choice(env.K)
#             for a in order:
#                 if counts[a] == 0:
#                     act = a
#             rews = env.get_rewards()
#             history.append((act,rews[act]))
#             counts[act] += 1
#     means = np.zeros(env.K)
#     for i in range(env.K):
#         means[i] = np.mean([x[1] for x in history if x[0] == i])
#     return (history, means, counts)

def convert_history(history,action_map):
    """Relabel the actions of a history through ``action_map``.

    Args:
        history: Iterable of entries whose element 0 is an action used to
            index ``action_map`` and whose element 1 is the reward.
        action_map: Indexable mapping from action to its external label.

    Returns:
        A new list of ``[label, reward]`` two-element lists, in input order.
    """
    return [[action_map[entry[0]], entry[1]] for entry in history]

def puzzle_exp(t,N,env,llm,template,hist_alg,seed=None,debug=False):
    """Run ``N`` independent puzzles: build a history, ask the LLM for an action.

    Args:
        t: Length of each generated history.
        N: Number of replicates.
        env: Environment forwarded to ``generate_history``.
        llm: LLM handle forwarded to ``util.query_llm``.
        template: Prompt template; ``get_outputs()`` supplies the action labels
            and the object is forwarded to ``util.query_llm``.
        hist_alg: History-generating algorithm factory forwarded to
            ``generate_history``.
        seed: Seed for the NumPy generator; ``None`` draws fresh entropy.
        debug: Forwarded to ``util.query_llm``.

    Returns:
        A list with one ``(means, counts, idx)`` tuple per replicate, where
        ``idx`` is the position of the LLM's answer in the shuffled label list.
    """
    # default_rng(None) behaves exactly like default_rng() with no argument.
    rng = np.random.default_rng(seed)

    # Tallied from the flags returned by util.query_llm but not returned to the
    # caller -- presumably kept for inspection while debugging.
    reverted_count = 0
    failed_count = 0
    output = []
    for _ in range(N):
        ## Generate a history of length t
        history, means, counts = generate_history(t,env,rng,hist_alg)
        # Randomize which label stands for which arm (shuffled in place).
        action_map = template.get_outputs()
        rng.shuffle(action_map)
        raw_history = convert_history(history,action_map)
        ## Make the LLM call
        val, rev_tmp, fail_tmp = util.query_llm(raw_history, llm, template, debug=debug)
        reverted_count += 1 if rev_tmp else 0
        failed_count += 1 if fail_tmp else 0
        # Translate the returned label back to an arm index.
        output.append((means,counts,action_map.index(val)))
    return output

def baseline_exp(t,N,env,Alg,hist_alg,seed=None,debug=False):
    """Run ``N`` independent puzzles answered by a baseline algorithm.

    Args:
        t: Length of each generated history.
        N: Number of replicates.
        env: Environment forwarded to ``generate_history``.
        Alg: Zero-argument factory producing a fresh baseline with
            ``update(action, reward)`` and ``get_action()``.
        hist_alg: History-generating algorithm factory forwarded to
            ``generate_history``.
        seed: Seed for the NumPy generator; ``None`` draws fresh entropy.
        debug: Accepted for signature parity with ``puzzle_exp``; unused here.

    Returns:
        A list with one ``(means, counts, action)`` tuple per replicate, where
        ``action`` is the baseline's choice after replaying the history.
    """
    # default_rng(None) behaves exactly like default_rng() with no argument.
    rng = np.random.default_rng(seed)

    output = []
    for _ in range(N):
        ## Generate a history of length t
        history, means, counts = generate_history(t,env,rng,hist_alg)
        ## Make the algorithm call
        learner = Alg()
        # Replay the whole history into a fresh baseline before querying it.
        for act, rew in history:
            learner.update(act, rew)
        output.append((means,counts,learner.get_action()))
    return output

if __name__=='__main__':
    # Command-line driver: optionally runs the LLM puzzle experiment, then runs
    # every baseline algorithm, pickling each result under ../results/.
    # Result files that already exist are skipped, not recomputed.
    import sys, argparse

    parser = argparse.ArgumentParser()

    ## LLM parameters
    parser.add_argument('--llm', action='store',
                        default='gpt4',
                        help='Large language model', type=str)
    parser.add_argument('--template', action='store',
                        default='buttons',
                        help='Template for LLMs', type=str)
    parser.add_argument('--suggestive', action=argparse.BooleanOptionalAction,
                        default=False,
                        help='Suggestive of neutral framing')
    parser.add_argument('--summarized', action=argparse.BooleanOptionalAction,
                        default=False,
                        help='Summarized or raw history')
    parser.add_argument('--distribution', action=argparse.BooleanOptionalAction,
                        default=False,
                        help='Allow LLM to output a distribution')
    parser.add_argument('--cot', action=argparse.BooleanOptionalAction,
                        default=False,
                        help='Allow LLM to use CoT')
    parser.add_argument('--temp', action='store', type=str,
                        default='0')

    ## Env parameters
    parser.add_argument('--T', action='store',
                        default=40,
                        help='Episode length', type=int)
    parser.add_argument('--t', action='store',
                        default=40,
                        help='history length', type=int)
    parser.add_argument('--K', action='store',
                        default=5,
                        help='number of actions', type=int)
    parser.add_argument('--delta', action='store', type=str,
                        default='0.5')
    parser.add_argument('--history', action='store',
                        default='unif',
                        help='history generator', type=str)

    ## Replicates and debugging?
    parser.add_argument('--reps', action='store',
                        default=25,
                        help='number of replicates', type=int)
    parser.add_argument('--debug', action=argparse.BooleanOptionalAction,
                        default=False,
                        help='Log debugging information')

    # Factories for the algorithm that generates the history shown to the
    # LLM / baseline; keyed by the value of --history.
    hist_algs = {
        'unif': lambda k: algs.Unif(num_arms=k),
        'greedy': lambda k: algs.Greedy(num_arms=k),
        'egreedy': lambda k: algs.EGreedy(num_arms=k),
        'ts': lambda k: algs.BBThompson(num_arms=k),
        'ucb': lambda k: algs.UCB1(num_arms=k)
        }

    Args = parser.parse_args(sys.argv[1:])
    T = Args.T
    t = Args.t
    K = Args.K
    delta = float(Args.delta)
    hist_type = Args.history
    # An unknown --history value raises KeyError here.
    hist_alg = hist_algs[hist_type]

    reps = Args.reps
    debug = Args.debug
    llm_name = Args.llm
    temp = Args.temp
    base_dir = '../results/'
    # Either staying None means the LLM experiment is skipped below.
    llm = None; template = None
    if llm_name == 'gpt4':
        llm = llms.OpenAILLM('gpt-4',temp=temp)
    elif llm_name == 'gpt3' and llms.key == 'ak':
        llm = llms.OpenAILLM('gpt-3.5-turbo',temp=temp)
    elif llm_name == 'gpt3':
        llm = llms.OpenAILLM('gpt-35-turbo',temp=temp)
    elif llm_name == 'llama7b':
        llm = llms.LlamaLLM('Llama-2-7b-hf', temp=temp)
    elif llm_name == 'llama13b':
        llm = llms.LlamaLLM('Llama-2-13b-hf', temp=temp)
    elif llm_name == 'llama70b':
        llm = llms.LlamaLLM('Llama-2-70b-hf', temp=temp)
    else:
        print('[Debug] Invalid LLM name, only running baselines', flush=True)

    if Args.template == 'buttons':
        template = buttons.ButtonsPrompt(T, K,
                                         suggestive=Args.suggestive,
                                         summarized=Args.summarized,
                                         dist=Args.distribution,
                                         cot=Args.cot)
    elif Args.template =='oldbuttons':
        template = oldbuttons.OldButtonsPrompt(T, K,
                                               suggestive=Args.suggestive,
                                               summarized=Args.summarized,
                                               cot=Args.cot)
    elif Args.template == 'adverts':
        template = adverts.AdvertsPrompt(T, K,
                                         suggestive=Args.suggestive,
                                         summarized=Args.summarized,
                                         dist=Args.distribution,
                                         cot=Args.cot)
    else:
        print("[Debug] Invalid template selected, only running baselines", flush=True)

    # Per-arm parameters handed to BernoulliBandit: arm 0 gets 0.5+delta/2 and
    # every other arm gets 0.5-delta/2, so the gap between them is delta.
    ps = (0.5-delta/2)*np.ones(K)
    ps[0] = 0.5+delta/2
    if template is not None and llm is not None:
        dir_name = base_dir+f"puzzles_{template.get_name()}_T={T}_t={t}_delta={delta:0.1f}_{hist_type}/"
        try:
            pathlib.Path(dir_name).mkdir(parents=True,exist_ok=True)
        except FileExistsError:
            pass

        fname = dir_name + f"{llm.long_name}.pkl"
        if os.path.isfile(fname):
            print("[Debug]" + fname + " Already Completed", flush=True)
        else:
            print("[Debug]" + fname + " Running", flush=True)
            env = envs.BernoulliBandit(K,ps=ps,seed=None)
            output = puzzle_exp(t,reps,env,llm,template,hist_alg,debug=debug)
            # Context manager guarantees the file is flushed and closed.
            with open(fname, 'wb') as out_file:
                pickle.dump(output, out_file)

    Baselines = {
        'ts': algs.BBThompson,
        'greedy': algs.Greedy,
        'ucb': algs.UCB1,
        'egreedy': algs.EGreedy
    }

    dir_name = base_dir + f"puzzles_baselines_K={K}_T={T}_t={t}_delta={delta:0.1f}_{hist_type}/"
    try:
        pathlib.Path(dir_name).mkdir(parents=True, exist_ok=True)
    except FileExistsError:
        pass
    for (b_name,b_alg) in Baselines.items():
        env = envs.BernoulliBandit(K,ps=ps)
        # The lambdas close over b_alg/env late, which is safe only because
        # each Alg is consumed by baseline_exp within the same iteration.
        if b_name == 'ts':
            fname = dir_name + f"{b_name}.pkl"
            Alg = lambda: b_alg(num_arms=env.K)
        elif b_name == 'greedy':
            fname = dir_name + f"{b_name}_1.pkl"
            Alg = lambda: b_alg(num_arms=env.K,n0=1)
        elif b_name == 'ucb':
            fname = dir_name + f"{b_name}_1.0.pkl"
            Alg = lambda: b_alg(num_arms=env.K,eta=1.0)
        elif b_name == 'egreedy':
            fname = dir_name + f"{b_name}_0.1.pkl"
            Alg = lambda: b_alg(num_arms=env.K,eps=0.1)
        if os.path.isfile(fname):
            print("[Debug]" + fname + " Already Completed", flush=True)
        else:
            print("[Debug]" + fname + " Running", flush=True)
            output = baseline_exp(t,reps,env,Alg,hist_alg)
            # Context manager guarantees the file is flushed and closed.
            with open(fname, "wb") as out_file:
                pickle.dump(output, out_file)
            
