
import os

import sys

# Put a directory derived from this file's location at the front of the module
# search path before the remaining imports run (presumably so the project-local
# `models` / `utils` packages resolve — confirm against the repo layout).
sys.path.insert(0,

                os.path.abspath(os.path.join(os.path.dirname(__file__),

                                             '')))

import argparse

import pandas as pd

import json

import numpy as np

from models import gpt, claude, opensource, hf_inference

from utils import textprocessing

from utils.clustering import clustering

from utils.clustering.clustering import tqdm_joblib

import joblib

from joblib import Parallel, delayed

from utils.clustering import lexical_diversity

from utils.clustering.ast_processing import AllSubtreeAnalysis, AstSubTree, parallel_subtree_analysis

from dataclasses import dataclass

import yaml

from tqdm import tqdm

from functools import partial

import datetime

import traceback

import logging

from transformers import AutoTokenizer



import signal

import traceback

import tiktoken 







def handler(pipe, experiment_output_dir, signum, frame):
    """Signal handler that tears down the inference service and exits.

    The first two parameters are meant to be pre-bound (e.g. with
    ``functools.partial``) so the resulting callable has the
    ``(signum, frame)`` signature that ``signal.signal`` expects.

    Args:
        pipe: Object exposing ``stop_service()`` and ``remove_service()``.
        experiment_output_dir: Directory in which the marker file is written.
        signum: Signal number passed by the interpreter (unused).
        frame: Stack frame passed by the interpreter (unused).

    Raises:
        SystemExit: Always, with exit code 0, after cleanup was attempted.
    """
    print("")
    try:
        # Writing the marker file is best-effort: a failure here must not
        # prevent the service from being stopped below.
        try:
            print("")
            with open(os.path.join(experiment_output_dir, ''), '') as f:
                f.write("")
        except Exception as e:
            print("", e)
            traceback.print_exc()

        # Tear the service down regardless of whether the marker was written.
        try:
            pipe.stop_service()
            pipe.remove_service()
            print("")
        except Exception as e:
            print("", e)
            traceback.print_exc()
    finally:
        # Exit unconditionally so an interrupted run never resumes.
        sys.exit(0)



from functools import partial

import sys 



# Append this file's directory to the module search path (presumably so the
# `async_driver` import that follows resolves — confirm).
sys.path.append(os.path.dirname(__file__))



from async_driver import Arguments, load_arguments_from_yaml























































    

    

# Directory scanned for prompt templates, resolved relative to this file.
template_dir = os.path.join(os.path.dirname(__file__), "")

# File names found in ``template_dir`` that end with the expected suffix.
templates = [entry for entry in os.listdir(template_dir) if entry.endswith('')]

# Accepted template identifiers: the leading piece of each file name, plus
# ``None``, which ``readin_template`` maps to an empty template.
template_names = [entry.split('')[0] for entry in templates] + [None]

    

def readin_template(template_arg):
    """Return the text of the prompt template named ``template_arg``.

    Args:
        template_arg: An entry of the module-level ``template_names`` list, or
            ``None`` to request no template.

    Returns:
        The contents of the template file, or an empty string when
        ``template_arg`` is ``None``.

    Raises:
        AssertionError: If ``template_arg`` is not listed in ``template_names``.
    """
    # Raised explicitly (rather than via ``assert``) so the check is still
    # enforced when Python runs with ``-O``; the exception type is unchanged.
    if template_arg not in template_names:
        raise AssertionError(
            f"Template {template_arg} not found. Available templates: {template_names}"
        )

    if template_arg is None:
        return ""

    path_to_template = os.path.join(os.path.dirname(__file__), f"../prompt_templates/{template_arg}.txt")
    with open(path_to_template, '') as file:
        return file.read()

    



def _format_template(prompt, template: str): 

    formatted_prompt = template.replace("", prompt)

    assert "" not in formatted_prompt, ""

    return formatted_prompt

















if __name__ == '':
    # Script entry point: load the YAML-described run configuration, start the
    # requested model backend, generate programs for every dataset row, and
    # write per-problem files plus one aggregated JSON-lines file.

    # Fail with a usage message instead of a bare IndexError when the
    # configuration path is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <path_to_yaml>")
    path_to_yaml = sys.argv[1]
    args = load_arguments_from_yaml(path_to_yaml)

    # Bind the chosen template once so each prompt only needs a single call.
    prompt_template = readin_template(args.template)
    format_template_fun = partial(_format_template, template=prompt_template)

    model_name_clean = args.model.replace("", "")

    experiment_name = args.experiment_name
    experiment_output_dir = args.experiment_output_dir
    is_directed = args.is_directed

    os.makedirs(experiment_output_dir, exist_ok=True)

    logs_file = os.path.join(experiment_output_dir, '')

    # Send every log record to both the log file and stdout; force=True
    # replaces any handlers already installed on the root logger.
    logging.basicConfig(level=logging.INFO,
                        handlers=[
                            logging.FileHandler(logs_file),
                            logging.StreamHandler(sys.stdout)
                        ],
                        force=True
    )
    logging.info(f"Starting generations for {experiment_name}")

    # Persist the resolved arguments alongside the experiment outputs.
    with open(os.path.join(experiment_output_dir, ''), '') as f:
        yaml.dump(args.__dict__, f)

    # Stays None until a backend is constructed, so the error path can tell
    # whether there is anything to tear down.
    pipe = None
    try:
        # Select the backend and a matching tokenizer from the model name.
        if '' in args.model or '' in args.model or '' in args.model:
            pipe = gpt.GPTModel(model_name=args.model)
            tokenizer = tiktoken.encoding_for_model(args.model)

        elif any(model in args.model for model in ['', '', '', ""]):
            pipe = claude.ClaudeModel(model_name=args.model)
            tokenizer = tiktoken.encoding_for_model("")

        else:
            with open(args.path_to_hf_token, "") as f:
                hf_key = f.read().strip()
            logging.info(f"Starting HF Inference Service with model {args.model}")
            pipe = hf_inference.HFInferenceService(model_name=args.model,
                                                   parallel_samples=max(args.parallel_samples, args.num_return_sequences),
                                                   port=args.port,
                                                   devices_list=args.devices_list,
                                                   volume=args.volume,
                                                   startup_timeout=args.startup_timeout,
                                                   generation_timeout=args.generation_timeout,
                                                   hf_key=hf_key)
            tokenizer = AutoTokenizer.from_pretrained(args.model, token=hf_key)

            # Only this backend gets a SIGINT hook that stops and removes the
            # service before the process exits.
            sigint_handler = partial(handler, pipe, experiment_output_dir)
            signal.signal(signal.SIGINT, sigint_handler)

        print(f'reading in data from {args.path_to_dataset}')
        df = pd.read_json(args.path_to_dataset, lines=True, orient='')

        # A non-positive max_programs means "use the whole dataset".
        if args.max_programs > 0:
            logging.info(f"Limiting to {args.max_programs} programs")
            df = df.iloc[:args.max_programs]

        results = []
        start = datetime.datetime.now()

        # total= lets tqdm show real progress; iterrows() alone has no length.
        for index, row in tqdm(df.iterrows(), total=len(df)):
            this_start = datetime.datetime.now()
            logging.info(f"Generating for index {index}")

            result = {}
            result[''] = args.model
            result[''] = index

            prompt = row['']
            problem_id = row[''] if not is_directed else row[""]

            extract_arguments_fun = row[""] if not is_directed else None

            # Copy every dataset column into the record; the assignments below
            # overwrite columns that share a key.
            result.update(row.to_dict())

            formatted_prompt = format_template_fun(prompt)
            result[''] = formatted_prompt

            # Some models get a generation budget reduced by the prompt length.
            # NOTE(review): these branches call the tokenizer directly, which
            # assumes an HF-style callable tokenizer — confirm for tiktoken.
            if "" in args.model.lower():
                n_prompt_tokens = len(tokenizer(formatted_prompt)[''])
                max_tokens = min(2048 - n_prompt_tokens, args.max_length - 32)
                logging.info(f"Tulu-2 model, max tokens: {max_tokens}")

            elif "" in args.model.lower():
                n_prompt_tokens = len(tokenizer(formatted_prompt)[''])
                max_tokens = min(4096 - n_prompt_tokens, args.max_length - 32)
                logging.info(f"Codellama model, max tokens: {max_tokens}")
            else:
                max_tokens = args.max_length

            if max_tokens <= 0:
                # Surface the case where the prompt leaves no room to generate;
                # control flow is unchanged, the backend still gets the value.
                logging.warning(f"Non-positive max tokens ({max_tokens}) for index {index}")

            raw_generations = pipe.generate(
                formatted_prompt,
                max_new_tokens=max_tokens,
                num_samples=args.num_return_sequences,
                temperature=args.temperature,
                do_sample=True,
                top_p=args.top_p,
                top_k=None,
                return_dict_in_generate=False,
                batch_size=args.batch_size,
            )

            # Extract the code from each raw generation, then normalise it
            # with the formatter that matches the experiment type.
            programs = [textprocessing.extract_python_code(g) for g in raw_generations]
            if is_directed:
                formatted_programs = [clustering.format_directed_code(program) for program in programs]
            else:
                formatted_programs = [clustering.format_open_ended_code(program, extract_arguments_fun) for program in programs]

            result[""] = raw_generations
            result[''] = prompt
            result[''] = programs
            result[''] = formatted_programs
            testcase_inputs = row['']
            result[''] = testcase_inputs
            result[''] = problem_id
            result[''] = extract_arguments_fun
            result[''] = prompt

            # Write each sample's raw text, extracted program and formatted
            # program to its own file under the problem's directory.
            problem_id_dir = os.path.join(experiment_output_dir, f'problem_{problem_id}')
            problem_id_gen_dir = os.path.join(problem_id_dir, '')
            os.makedirs(problem_id_gen_dir, exist_ok=True)
            for i, (generation, program, formatted_program) in enumerate(zip(raw_generations, programs, formatted_programs)):
                with open(os.path.join(problem_id_gen_dir, f'gen_{i}_coh_.txt'), '') as f:
                    f.write(generation)
                with open(os.path.join(problem_id_gen_dir, f'prog_{i}_coh.txt'), '') as f:
                    f.write(program)
                with open(os.path.join(problem_id_gen_dir, f'formatted_prog_{i}.txt'), '') as f:
                    f.write(formatted_program)

            results.append(result)
            this_end = datetime.datetime.now()
            run_elapsed = this_end - this_start
            logging.info(f"Finished index {index} in {run_elapsed}")

        end = datetime.datetime.now()
        total_elapsed = end - start

        logging.info(f"Finished all in {total_elapsed}")

        logging.info("")

        # One JSON object per line, one line per processed row.
        pd.DataFrame(results).to_json(os.path.join(experiment_output_dir, ''), orient='', lines=True)
        logging.info("")

        # Only tear down a service for model names not matched by these lists.
        if not any(model in args.model for model in ['', '', '']) and args.model not in ['', '', '']:
            pipe.stop_service()
            pipe.remove_service()

        print(f"Done generating for {experiment_name} in {total_elapsed}")

    except Exception:
        traceback_str = traceback.format_exc()
        try:
            with open(os.path.join(experiment_output_dir, ''), '') as f:
                f.write("")
                f.write(traceback_str)
            logging.error(f"Error during generation: {traceback_str}")
        finally:
            # Tear the service down even if recording the error itself failed.
            if not any(model in args.model for model in ['', '', '']) and args.model not in ['', '', ''] and pipe is not None:
                logging.info("")
                pipe.stop_and_remove_if_running()
                logging.info("")
        # Bare raise re-raises the original exception with its traceback.
        raise

