import numpy as np
import os, pickle, time
import llms, envs, algs, util
from templates import buttons, oldbuttons, adverts
import pathlib

def exp(T,env,llm,template,seed=None,debug=False):
    """Run one bandit episode of ``T`` rounds where an LLM picks the arm.

    The template's output labels are shuffled to form ``arm_map`` (label at
    position ``i`` is arm ``i``), so which label corresponds to which arm is
    randomized per run.

    Args:
        T: number of rounds.
        env: bandit environment; ``env.get_rewards()`` is called once per
            round and indexed by arm.
        llm: language-model client forwarded to ``util.query_llm``.
        template: prompt template; must provide ``get_outputs()`` returning
            the list of valid LLM answers (arm labels).
        seed: if given, the label shuffle uses ``np.random.default_rng(seed)``;
            otherwise the global ``np.random`` state is used.
        debug: forwarded to ``util.query_llm``.

    Returns:
        ``(hist, reverted_count, failed)``:
        * ``hist``: list of ``(arm index, reward of that arm, full rewards)``
          per completed round.
        * ``reverted_count``: number of rounds for which ``util.query_llm``
          reported a truthy revert flag (meaning defined by ``util``).
        * ``failed``: True if ``util.query_llm`` reported a failure, in which
          case the episode stops early and ``hist`` holds only the rounds
          completed so far.
    """
    # Shuffle a private copy. Shuffling the template's list in place could
    # mutate template state if it hands out an internal list, which would make
    # the arm mapping for a given seed depend on earlier replicates.
    arm_map = list(template.get_outputs())
    if seed is None:
        np.random.shuffle(arm_map)
    else:
        np.random.default_rng(seed).shuffle(arm_map)

    # raw_hist records the LLM's labels and is what the LLM is shown;
    # hist records arm indices for analysis.
    raw_hist = []
    hist = []
    reverted_count = 0
    for t in range(T):
        (val, rev_tmp, fail_tmp) = util.query_llm(raw_hist, llm, template, debug=debug)
        if rev_tmp:
            reverted_count += 1
        if fail_tmp:
            # Abort the episode; return the partial history with failed=True.
            return (hist, reverted_count, True)
        rewards = env.get_rewards()
        idx = arm_map.index(val)
        raw_hist.append((val, rewards[idx], rewards))
        hist.append((idx, rewards[idx], rewards))
        print(f'[Debug] Done with round {t+1}', flush=True)
    return (hist, reverted_count, False)

def bandit_exp(T,env,Alg,n0=1,eta=1,eps=0.1):
    """Run a classical bandit algorithm against ``env`` for ``T`` rounds.

    ``Alg`` is instantiated with ``num_arms=env.K`` plus the single
    hyper-parameter it accepts: ``eta`` for ``algs.UCB1``, ``n0`` for
    ``algs.Greedy``, ``eps`` for ``algs.EGreedy``. Any other class receives
    ``num_arms`` only.

    Each round, the algorithm chooses an arm, the environment draws a reward
    vector, and the algorithm is updated with the chosen arm's reward.

    Returns:
        ``(hist, alg)`` where ``hist`` is a list of
        ``[arm, reward of arm, full reward vector]`` per round and ``alg`` is
        the trained algorithm instance.
    """
    if Alg == algs.UCB1:
        hyper = {'eta': eta}
    elif Alg == algs.Greedy:
        hyper = {'n0': n0}
    elif Alg == algs.EGreedy:
        hyper = {'eps': eps}
    else:
        hyper = {}
    alg = Alg(num_arms=env.K, **hyper)

    hist = []
    for _ in range(T):
        arm = alg.get_action()
        round_rewards = env.get_rewards()
        earned = round_rewards[arm]
        alg.update(arm, earned)
        hist.append([arm, earned, round_rewards])
    return (hist, alg)
        


if __name__=='__main__':
    import sys, argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--template', action='store',
                        default='buttons',
                        help='Template for LLMs', type=str)
    parser.add_argument('--suggestive', action=argparse.BooleanOptionalAction,
                        default=False,
                        help='Suggestive or neutral framing')
    parser.add_argument('--summarized', action=argparse.BooleanOptionalAction,
                        default=False,
                        help='Summarized or raw history')
    parser.add_argument('--distribution', action=argparse.BooleanOptionalAction,
                        default=False,
                        help='Allow LLM to output a distribution')
    parser.add_argument('--cot', action=argparse.BooleanOptionalAction,
                        default=False,
                        help='Allow LLM to use CoT')
    parser.add_argument('--T', action='store',
                        default=40,
                        help='number of rounds', type=int)
    parser.add_argument('--K', action='store',
                        default=5,
                        help='number of actions', type=int)
    parser.add_argument('--reps', action='store',
                        default=25,
                        help='number of replicates', type=int)
    parser.add_argument('--llm', action='store',
                        default='gpt4',
                        help='Large language model', type=str)
    parser.add_argument('--debug', action=argparse.BooleanOptionalAction,
                        default=False,
                        help='Log debugging information')
    # Gap between the best arm and the rest; one of '0.2', '0.4', '0.5',
    # '0.6'. Any other value runs both 0.2 and 0.4 (see `deltas` below).
    parser.add_argument('--delta', action='store', type=str,
                        default='0.5')
    # Passed unchanged (as a string) to the llms.* constructors.
    parser.add_argument('--temp', action='store', type=str,
                        default='0')
    # Baseline hyper-parameters: n0 -> Greedy, eta -> UCB1, eps -> EGreedy.
    parser.add_argument('--n0', action='store', type=int,
                        default=1)
    parser.add_argument('--eta', action='store', type=float,
                        default=1)
    parser.add_argument('--eps', action='store', type=float,
                        default=0.1)

    Args = parser.parse_args(sys.argv[1:])
    n0 = Args.n0
    eta = Args.eta
    eps = Args.eps
    T = Args.T
    K = Args.K
    reps = Args.reps
    llm_name = Args.llm
    debug = Args.debug
    delta_choice = Args.delta
    temp = Args.temp
    base_dir = '../results/'

    # Validate K before building any LLM client or template. Constructing
    # either (e.g. loading a Llama model) is wasted work if we exit anyway.
    if K > 7:
        print("[Debug] Only K<=7 supported by templates", flush=True)
        sys.exit(1)

    if llm_name == 'gpt4':
        llm = llms.OpenAILLM('gpt-4',temp=temp)
    elif llm_name == 'gpt3':
        llm = llms.OpenAILLM('gpt-3.5-turbo',temp=temp)
    elif llm_name == 'llama7b':
        llm = llms.LlamaLLM('Llama-2-7b-hf', temp=temp)
    elif llm_name == 'llama13b':
        llm = llms.LlamaLLM('Llama-2-13b-hf', temp=temp)
    elif llm_name == 'llama70b':
        llm = llms.LlamaLLM('Llama-2-70b-hf', temp=temp)
    else:
        llm = None
        print('[Debug] Invalid LLM name, only running baselines', flush=True)

    if Args.template == 'buttons':
        template = buttons.ButtonsPrompt(T, K,
                                         suggestive=Args.suggestive,
                                         summarized=Args.summarized,
                                         dist=Args.distribution,
                                         cot=Args.cot)
    elif Args.template =='oldbuttons':
        template = oldbuttons.OldButtonsPrompt(T, K,
                                               suggestive=Args.suggestive,
                                               summarized=Args.summarized,
                                               cot=Args.cot)
    elif Args.template == 'adverts':
        template = adverts.AdvertsPrompt(T, K,
                                         suggestive=Args.suggestive,
                                         summarized=Args.summarized,
                                         dist=Args.distribution,
                                         cot=Args.cot)
    else:
        template = None
        print("[Debug] Invalid template selected, only running baselines", flush=True)

    # name -> (algorithm class, result-filename suffix encoding the
    # hyper-parameter the algorithm is run with).
    Baselines = {
        'ts': (algs.BBThompson, ''),
        'greedy': (algs.Greedy, f'_{n0}'),
        'ucb': (algs.UCB1, f'_{eta:0.1f}'),
        'egreedy': (algs.EGreedy, f'_{eps:0.2f}'),
    }

    # Unrecognised --delta values fall back to running both 0.2 and 0.4.
    deltas = {
        '0.2': [0.2],
        '0.4': [0.4],
        '0.5': [0.5],
        '0.6': [0.6],
    }.get(delta_choice, [0.2, 0.4])

    def _save(obj, path):
        """Pickle ``obj`` to ``path`` via a temp file and atomic rename.

        Runs are skipped whenever their result file already exists. A write
        interrupted midway must therefore never leave a file at ``path``.
        """
        tmp_path = path + '.tmp'
        with open(tmp_path, 'wb') as fh:
            pickle.dump(obj, fh)
        os.replace(tmp_path, path)

    for delta in deltas:
        # Arm 0 has success probability 0.5+delta/2; all other arms 0.5-delta/2.
        ps = (0.5-delta/2)*np.ones(K)
        ps[0] = 0.5+delta/2
        if template is not None and llm is not None:
            dir_name = base_dir + f"{template.get_name()}_T={T}_delta={delta:0.1f}/"
            pathlib.Path(dir_name).mkdir(parents=True, exist_ok=True)

            for rep in range(reps):
                env = envs.BernoulliBandit(K,ps=ps,seed=rep)
                fname = dir_name + f"{llm.long_name}_{rep}.pkl"
                if os.path.isfile(fname):
                    print("[Debug]" + fname + " Already Completed", flush=True)
                else:
                    print("[Debug]" + fname + " Running", flush=True)
                    output = exp(T,env,llm,template,seed=rep,debug=debug)
                    _save(output, fname)

        dir_name = base_dir + f"baselines_K={K}_T={T}_delta={delta:0.1f}/"
        pathlib.Path(dir_name).mkdir(parents=True, exist_ok=True)

        for b_name, (b_alg, b_suffix) in Baselines.items():
            for rep in range(reps):
                env = envs.BernoulliBandit(K,ps=ps,seed=rep)
                fname = dir_name + f"{b_name}{b_suffix}_{rep}.pkl"
                if os.path.isfile(fname):
                    print("[Debug]" + fname + " Already Completed", flush=True)
                else:
                    print("[Debug]" + fname + " Running", flush=True)
                    (output,_) = bandit_exp(T,env,b_alg,n0,eta,eps)
                    _save(output, fname)
