#from gpt_utils import query_gpt
#from claude_utils import query_claude
from model_utils import query_claude, parallel_query_gpt
import argparse
import os
import pickle
from experiment import *
from utils import get_outdir, get_git_commit

# Set environment variables

# Registry of runnable experiments, keyed by the value given to --experiment_name.
# Each value is a pair (experiment class, max_tokens):
#   * the class is instantiated with the parsed args in the __main__ block;
#   * the integer is the generation budget main() passes to get_outputs()
#     as max_tokens (main() raises it to at least 50 for llama models).
# In this table the 'evaluate-*' entries carry the smallest budgets (50 / 200).
EXPERIMENT_DICT = {
    'write-vulnerable-code': (WriteVulnerableCodeExperiment, 2000),
    'write-secure-code': (WriteSecureCodeExperiment, 2000),
    'write-vulnerable-code-from-secure': (WriteVulnerableCodeFromSecureExperiment, 2000),
    'evaluate-vulnerable-code': (EvaluateVulnerableCodeExperiment, 50),
    'generate-python-script': (GeneratePythonScriptExperiment, 1000),
    'generate-related-tasks': (GenerateRelatedTasksExperiment, 1000),
    'evaluate-hacking': (EvaluateHackingExperiment, 50),
    'personalized-tweet-disinfo': (PersonalizedDisinfoTweetsExperiment, 4000),
    'generate-related-disinfo': (GenerateRelatedDisinfoTasksExperiment, 1000),
    'evaluate-disinformation': (EvaluateDisinformationExperiment, 200),
}

def handle_model_names(args):
    """Expand shorthand model names on ``args`` to pinned version identifiers.

    The ``model``, ``output_model`` and ``reference_model`` attributes are
    inspected; attributes that are missing or ``None`` are skipped, and names
    that are not a known shorthand are left untouched.

    Args:
        args: Namespace-like object whose attributes are rewritten in place.

    Returns:
        The same ``args`` object, for call chaining.
    """
    # Shorthand -> fully versioned model identifier.
    aliases = {
        'gpt-4-turbo': 'gpt-4-0125-preview',
        'gpt-4': 'gpt-4-0613',
        'gpt-3.5-turbo': 'gpt-3.5-turbo-0125',
        'claude': 'claude-2.1',
        'claude-3': 'claude-3-opus-20240229',
        'claude-3-opus': 'claude-3-opus-20240229',
        'claude-3-sonnet': 'claude-3-sonnet-20240229',
        'claude-3-haiku': 'claude-3-haiku-20240307',
        'mistral': 'mistral-7B-instruct-v0.2',
        'mixtral': 'mixtral-8x7b-instruct-v0.1',
    }
    for field in ('model', 'output_model', 'reference_model'):
        current = getattr(args, field, None)
        if current is None:
            continue
        resolved = aliases.get(current)
        if resolved is not None:
            setattr(args, field, resolved)
    return args

def parse_args():
    """Parse the command-line options for a single experiment run.

    After parsing, model-name shorthands are expanded with
    ``handle_model_names`` and two bookkeeping attributes are attached:
    ``git_sha`` (the value returned by ``get_git_commit()``) and ``finished``
    (initialised to ``False``).

    Returns:
        The populated ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser()

    # General run configuration.
    cli.add_argument('--model', type=str, default='gpt-3.5-turbo')
    cli.add_argument('--outdir', type=str, default='experiment_results')
    cli.add_argument('--experiment_name', type=str, required=True)
    cli.add_argument('--trial', type=int, default=0)
    cli.add_argument('--n_examples', type=int, default=None)
    cli.add_argument('--sanity_check', action='store_true')
    cli.add_argument('--temperature', type=float, default=0.01)
    cli.add_argument('--overwrite', action='store_true')

    # Experiments that build on the outputs of a previous model.
    cli.add_argument('--output_model', type=str, default=None)

    # Function-inversion experiments.
    cli.add_argument('--difficulty', type=str, default='hard')
    cli.add_argument('--use_documentation', action='store_true')
    cli.add_argument('--modification_type', type=str, default=None)

    # Vulnerability-addition experiments (also intended for hacking).
    cli.add_argument('--reference_model', type=str, default=None)

    # Disinformation experiment.
    cli.add_argument(
        '--n_users',
        type=int,
        default=1,
        help='Number of users to generate disinformation for simultaneously',
    )

    parsed = handle_model_names(cli.parse_args())
    parsed.git_sha = get_git_commit()
    parsed.finished = False
    return parsed

def convert_to_logfile(outdir):
    """Return the path of the pickle log file that belongs to ``outdir``.

    The log lives at ``<outdir>/log/log.pkl``. When ``outdir`` exists, the
    ``log`` sub-directory is created if it is missing; when ``outdir`` does
    not exist nothing is created and only the path is returned. The log file
    itself is never created here.

    Args:
        outdir: Directory holding the results of the current run.

    Returns:
        Path to ``log.pkl`` inside the run's ``log`` directory.
    """
    log_dir = os.path.join(outdir, 'log')
    if os.path.exists(outdir) and not os.path.exists(log_dir):
        # exist_ok covers the window in which a concurrent run creates the
        # directory between the existence check above and this call.
        os.makedirs(log_dir, exist_ok=True)
    return os.path.join(log_dir, 'log.pkl')

def save_data(data, outdir, args = False):
    """Pickle ``data`` into ``outdir``.

    Args:
        data: Object to store. When ``args`` is truthy it must support
            ``vars()`` (e.g. an ``argparse.Namespace``); its attribute
            dictionary is what gets pickled.
        outdir: Existing directory that receives the file.
        args: Truthy to write ``args.pkl``; otherwise ``data`` is written
            unchanged to ``results.pkl``.
    """
    if args:
        target, payload = 'args.pkl', vars(data)
    else:
        target, payload = 'results.pkl', data
    with open(os.path.join(outdir, target), 'wb') as handle:
        pickle.dump(payload, handle)

def get_outputs(prompts, log_outfile, args, max_tokens = 10, model = None, tokenizer = None):
    """Send ``prompts`` to the backend selected by ``args.model``.

    The backend is chosen from the prefix of ``args.model``: ``gpt``,
    ``llama`` (names ending in ``chat`` use the vLLM helpers), ``claude``,
    ``mistral``/``mixtral`` and ``gemma``. Helpers for locally loaded models
    are imported lazily from ``model_utils`` inside the matching branch.

    Args:
        prompts: Prompts forwarded unchanged to the backend query function.
        log_outfile: Log path forwarded to the GPT and Claude query helpers;
            not used by the locally loaded backends.
        args: Namespace providing ``model`` and ``temperature`` (plus
            ``overwrite`` for the GPT and Claude backends).
        max_tokens: Generation limit forwarded to the backend.
        model: Optional pre-loaded local model. When ``None`` the model is
            loaded from ``args.model``.
        tokenizer: Tokenizer paired with ``model``; only used by the non-chat
            llama backend and required there whenever ``model`` is supplied.

    Returns:
        Whatever the selected backend query function returns.

    Raises:
        NotImplementedError: ``args.model`` matches no backend, or names an
            unsupported llama-chat or claude variant.
        ValueError: a pre-loaded non-chat llama ``model`` was given without
            its ``tokenizer``.
    """
    name = args.model
    if name.startswith('gpt'):
        outputs = parallel_query_gpt(prompts, log_outfile, model = name, max_tokens = max_tokens, temperature = args.temperature, overwrite = args.overwrite)
    elif name.startswith('llama') and name.endswith('chat'):
        from model_utils import query_llama_vllm as query_llama
        from model_utils import load_llama_model_vllm as load_llama_model
        supported_chat = ('llama-7B-chat', 'llama-13B-chat', 'llama-70B-chat')
        if not any(tag in name for tag in supported_chat):
            raise NotImplementedError(f'Unsupported llama chat model: {name!r}')
        if model is None:
            model = load_llama_model(name)
        outputs = query_llama(prompts, model, max_tokens = max_tokens, temperature = args.temperature)
    elif name.startswith('llama'):
        # A pre-loaded model is unusable without its tokenizer; fail with a
        # clear message instead of hitting an unbound local further down.
        if model is not None and tokenizer is None:
            raise ValueError('a pre-loaded llama model requires its tokenizer to be passed as well')
        from model_utils import query_llama, load_llama_model
        if model is None:
            model, tokenizer = load_llama_model(name)
        outputs = query_llama(prompts, model, tokenizer, max_tokens = max_tokens, temperature = args.temperature)
    elif name.startswith('claude'):
        supported_claude = ('claude-2.1', 'claude-3-opus-20240229', 'claude-3-sonnet-20240229', 'claude-3-haiku-20240307')
        if name not in supported_claude:
            raise NotImplementedError(f'Unsupported claude model: {name!r}')
        outputs = query_claude(prompts, log_outfile, model = name, max_tokens = max_tokens, temperature = args.temperature, overwrite = args.overwrite)
    elif name.startswith(('mistral', 'mixtral')):
        from model_utils import query_mistral_vllm as query_mistral
        from model_utils import load_mistral_model_vllm as load_mistral_model
        # Reuse a caller-supplied model, as the llama branches do.
        if model is None:
            model = load_mistral_model(name)
        outputs = query_mistral(prompts, model, name, max_tokens = max_tokens, temperature = args.temperature)
    elif name.startswith('gemma'):
        from model_utils import query_gemma_vllm as query_gemma
        from model_utils import load_gemma_model_vllm as load_gemma
        # Reuse a caller-supplied model, as the llama branches do.
        if model is None:
            model = load_gemma(name)
        outputs = query_gemma(prompts, model, name, max_tokens = max_tokens, temperature = args.temperature)
    else:
        raise NotImplementedError(f'No backend available for model: {name!r}')
    return outputs

def main(args, experiment):
    """Run one experiment end to end and pickle its results.

    Steps: resolve the output directory, build the prompts, optionally
    truncate them to ``args.n_examples``, query the model and store
    ``(prompts, outputs, metadata)`` in ``results.pkl``. ``args.pkl`` is
    written once before querying and again afterwards with
    ``args.finished`` set to ``True``.

    With ``args.sanity_check`` set, the first prompt, the output directory
    and the prompt count are printed and the function returns without
    querying any model or writing any file.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).
        experiment: Object whose ``get_prompts()`` returns
            ``(prompts, metadata)``.
    """
    outdir = get_outdir(args)
    log_outfile = convert_to_logfile(outdir)
    if not args.sanity_check:
        save_data(args, outdir, args = True)
    prompts, metadata = experiment.get_prompts()

    if args.n_examples is not None:
        if len(prompts) < args.n_examples:
            print(f'WARNING: Not enough examples: {len(prompts)} < {args.n_examples}')
        # This one configuration keeps three times as many prompts —
        # presumably each example expands into three prompts there; TODO confirm.
        keep_triple = (
            args.experiment_name == 'personalized-tweet-disinfo'
            and args.output_model is not None
            and args.reference_model is None
        )
        limit = args.n_examples * 3 if keep_triple else args.n_examples
        prompts = prompts[:limit]
        metadata = metadata[:limit]

    if args.sanity_check:
        if prompts:
            # Each prompt is indexed as [system message, user message].
            print("SYSTEM")
            print(prompts[0][0]['content'])
            print("USER")
            print(prompts[0][1]['content'])
        print(outdir)
        print("Total prompts: ", len(prompts))
        # Stop here with an explicit return: an `assert False` would be
        # stripped under `python -O` and the full experiment would run.
        return

    max_tokens = EXPERIMENT_DICT[args.experiment_name][1]
    if 'llama' in args.model:
        # Llama models always get a budget of at least 50 tokens.
        max_tokens = max(max_tokens, 50)
    outputs = get_outputs(prompts, log_outfile, args, max_tokens = max_tokens)
    args.finished = True
    save_data((prompts, outputs, metadata), outdir, args = False)
    save_data(args, outdir, args = True)


def get_experiment(args, experiment_name):
    """Instantiate the experiment registered under ``experiment_name``.

    Mirrors the construction done in the ``__main__`` block: the class is
    looked up in ``EXPERIMENT_DICT`` and called with ``args``.

    Args:
        args: Parsed command-line namespace handed to the experiment class.
        experiment_name: Key into ``EXPERIMENT_DICT``.

    Returns:
        A new experiment instance.

    Raises:
        KeyError: ``experiment_name`` is not a registered experiment.
    """
    try:
        experiment_cls, _max_tokens = EXPERIMENT_DICT[experiment_name]
    except KeyError:
        raise KeyError(
            f'Unknown experiment {experiment_name!r}; expected one of {sorted(EXPERIMENT_DICT)}'
        ) from None
    return experiment_cls(args)


if __name__ == '__main__':
    # Script entry point: parse the CLI options, instantiate the experiment
    # class registered under --experiment_name (an unknown name raises
    # KeyError on the lookup below), then run it.
    # NOTE: `args` and `experiment` are bound as module-level globals here.
    args = parse_args()
    experiment = EXPERIMENT_DICT[args.experiment_name][0](args)
    main(args, experiment)