
import os

import sys

sys.path.insert(0,

                os.path.abspath(os.path.join(os.path.dirname(__file__),

                                             '')))

import argparse

import pandas as pd

import json

import numpy as np

from models import gpt, opensource, hf_inference

from utils import textprocessing

from utils.clustering import clustering

from utils.clustering.clustering import tqdm_joblib

import joblib

from joblib import Parallel, delayed

from utils.clustering import lexical_diversity

from utils.clustering.ast_processing import AllSubtreeAnalysis, AstSubTree, parallel_subtree_analysis

from dataclasses import dataclass

import yaml

from tqdm import tqdm

from functools import partial

import datetime

import traceback



import signal

import traceback

































def _handler(pipe, sig, frame):
    """Signal handler: report where the run was interrupted, tear down, abort.

    The function takes ``pipe`` as an extra leading argument; it is bound with
    ``functools.partial`` (see ``partial_handler``) to obtain the two-argument
    ``(signum, frame)`` callable that ``signal.signal`` expects.

    Args:
        pipe: Model wrapper. Its ``stop_service()`` and ``remove_service()``
            methods are called when the model-name check below passes.
        sig: Signal number supplied by the interpreter (not used).
        frame: Stack frame that was executing when the signal arrived
            (may be ``None``).

    Raises:
        RuntimeError: Always, after the optional teardown, so that the
            interrupted code path is unwound.
    """
    print("")
    # A signal handler normally runs with no exception in flight, in which
    # case traceback.print_exc() only prints "NoneType: None". Dump the stack
    # of the interrupted frame instead so the log shows where the run was.
    traceback.print_stack(frame)
    # ``args`` is the module-level configuration object created by the script
    # entry point before this handler is installed.
    if "" not in args.model:
        pipe.stop_service()
        pipe.remove_service()
    raise RuntimeError("")



from functools import partial



def partial_handler(pipe):
    """Build a signal handler with ``pipe`` pre-bound.

    Args:
        pipe: Object forwarded as the first positional argument of
            ``_handler`` on every invocation.

    Returns:
        A ``functools.partial`` wrapping ``_handler``; calling it with
        ``(sig, frame)`` invokes ``_handler(pipe, sig, frame)``.
    """
    bound_handler = partial(_handler, pipe)
    return bound_handler









@dataclass
class Arguments:
    """Run configuration, populated from the keys of a YAML file.

    Every field has a default, so a partial configuration is accepted.
    Field order is part of the positional constructor signature.
    """

    # --- inputs / outputs ---
    # Dataset file; the script reads it with pandas.read_json(lines=True).
    path_to_dataset: str = ''
    # Parent directory in which the per-run output directory is created.
    experiment_output_root: str = ''

    # --- model and prompt ---
    # Model identifier; passed as model_name to the model wrapper.
    model: str = ''
    # Prompt template name handed to readin_template.
    template: str = ''

    # --- sampling settings ---
    # temperature, max_length and num_return_sequences are forwarded to
    # pipe.generate; top_p and repetition_penalty only appear in the output
    # directory name within this file.
    temperature: float = 1.0
    top_p: float = 1.0
    max_length: int = 768
    num_return_sequences: int = 10
    repetition_penalty: float = 1.0

    # --- inference-service settings (forwarded to HFInferenceService) ---
    parallel_samples: int = 5
    port: int = 9999
    devices_list: str = ''
    startup_timeout: int = 600
    generation_timeout: int = 100
    volume: str = ''

    # File whose stripped contents are passed to the service as hf_key.
    path_to_hf_token: str = None
    # Forwarded to pipe.generate as batch_size.
    batch_size: int = None

    

    

# Directory that is scanned for prompt templates, resolved relative to this file.
template_dir = os.path.join(os.path.dirname(__file__), "")

# Entries of that directory whose file name carries the expected suffix.
templates = [entry for entry in os.listdir(template_dir) if entry.endswith('')]

# Accepted template arguments: the leading component of each file name,
# followed by None (which readin_template maps to an empty template).
template_names = [entry.split('')[0] for entry in templates]
template_names.append(None)

    

def readin_template(template_arg):
    """Return the text of the prompt template selected by ``template_arg``.

    Args:
        template_arg: A member of the module-level ``template_names`` list,
            i.e. a discovered template name or ``None``.

    Returns:
        ``""`` when ``template_arg`` is ``None``; otherwise the full contents
        of ``../prompt_templates/<template_arg>.txt``, resolved relative to
        the directory containing this file.

    Raises:
        AssertionError: If ``template_arg`` is not in ``template_names``.
    """
    assert template_arg in template_names, f"Template {template_arg} not found. Available templates: {template_names}"

    # Identity comparison is the correct test for the None singleton.
    if template_arg is None:
        return ""

    path_to_template = os.path.join(os.path.dirname(__file__), f"../prompt_templates/{template_arg}.txt")
    with open(path_to_template, '') as file:
        return file.read()

    



def _format_template(prompt, template: str): 

    formatted_prompt = template.replace("", prompt)

    assert "" not in formatted_prompt, ""

    return formatted_prompt





def load_arguments_from_yaml(yaml_file):
    """Build an ``Arguments`` instance from a YAML configuration file.

    Args:
        yaml_file: Path of the YAML file. Its top-level mapping keys must be
            field names of ``Arguments``.

    Returns:
        An ``Arguments`` object; fields not mentioned in the file keep their
        dataclass defaults. An empty file yields an all-defaults object.

    Raises:
        TypeError: If the file contains a key that is not an ``Arguments``
            field, or its top-level value is not a mapping.
    """
    with open(yaml_file, '') as file:
        args_dict = yaml.safe_load(file)

    # safe_load returns None for an empty document; treat that as "no
    # overrides" rather than failing on ``**None``.
    if args_dict is None:
        args_dict = {}

    return Arguments(**args_dict)





if __name__ == '':
    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------
    # The single command-line argument is the path of a YAML file whose keys
    # populate the Arguments dataclass.
    path_to_yaml = sys.argv[1]
    args = load_arguments_from_yaml(path_to_yaml)

    # Bind the selected prompt template once so that each prompt can be
    # formatted with a single-argument call.
    prompt_template = readin_template(args.template)
    format_template_fun = partial(_format_template, template=prompt_template)

    # One output directory per run, named after the model, the sampling
    # settings, the template and the formatted current time. exist_ok=False
    # makes re-using an existing directory an error instead of overwriting it.
    model_name_clean = args.model.replace("", "")
    experiment_string = f"{model_name_clean}_temp_{args.temperature}_top_p_{args.top_p}_max_length_{args.max_length}_num_return_sequences_{args.num_return_sequences}_repetition_penalty_{args.repetition_penalty}_{args.template}_{datetime.datetime.now().strftime('')}"
    experiment_output_dir = os.path.join(args.experiment_output_root, experiment_string)
    os.makedirs(experiment_output_dir, exist_ok=False)

    # ------------------------------------------------------------------
    # Model backend
    # ------------------------------------------------------------------
    if '' in args.model or '' in args.model:
        pipe = gpt.GPTModel(model_name=args.model)
    else:
        # The token file's stripped contents are handed to the service as hf_key.
        with open(args.path_to_hf_token, "") as f:
            hf_key = f.read().strip()
        pipe = hf_inference.HFInferenceService(
            model_name=args.model,
            parallel_samples=max(args.parallel_samples, args.num_return_sequences),
            port=args.port,
            devices_list=args.devices_list,
            volume=args.volume,
            startup_timeout=args.startup_timeout,
            generation_timeout=args.generation_timeout,
            hf_key=hf_key,
        )
        # Route SIGINT through _handler so the service is torn down on Ctrl-C.
        sigint_handler = partial_handler(pipe)
        signal.signal(signal.SIGINT, sigint_handler)

    try:
        # --------------------------------------------------------------
        # Data and execution environment
        # --------------------------------------------------------------
        print(f'reading in data from {args.path_to_dataset}')
        df = pd.read_json(args.path_to_dataset, lines=True, orient='')

        # Image/client pair passed to clustering.instrument_code_docker below.
        client, image = clustering.build_docker_image(clustering.clustering_abs_dir)

        results = []
        count = 0

        for index, row in tqdm(df.iterrows()):
            # Hard-coded cap: stop once the row label exceeds 100.
            if index > 100:
                break

            result = {}
            result[''] = args.model
            result[''] = index

            prompt = row['']
            problem_id = row['']
            extract_arguments_fun = row[""]

            formatted_prompt = format_template_fun(prompt)

            # ----------------------------------------------------------
            # Generation
            # ----------------------------------------------------------
            generateds_program = pipe.generate(
                formatted_prompt,
                temperature=args.temperature,
                num_return_sequences=args.num_return_sequences,
                max_length=args.max_length,
                do_sample=True,
                return_dict_in_generate=True,
                batch_size=args.batch_size,
            )

            programs = [textprocessing.extract_python_code(g) for g in generateds_program]
            formatted_programs = [
                clustering.format_open_ended_code(program, extract_arguments_fun)
                for program in programs
            ]

            result[''] = prompt
            result[''] = programs
            result[''] = formatted_programs
            testcase_inputs = row['']
            result[''] = testcase_inputs

            # ----------------------------------------------------------
            # Execute every successfully formatted program in Docker
            # ----------------------------------------------------------
            # Programs whose formatting produced None are skipped, so
            # output_records can be shorter than formatted_programs.
            with tqdm_joblib(tqdm(desc="", total=len(formatted_programs))):
                output_records = Parallel(n_jobs=10, backend='')(
                    delayed(clustering.instrument_code_docker)(
                        formatted_program,
                        testcase_inputs,
                        None,
                        image,
                        client,
                        n_test_cases=-1,
                        indiv_tc_timeout=60,
                        verbose_docker=True,
                    )
                    for formatted_program in formatted_programs
                    if formatted_program is not None
                )

            result[''] = output_records
            coherent_records = clustering.get_coherent_records(output_records)
            incoherent_records = clustering.get_incoherent_records(output_records)
            result[''] = coherent_records
            result[''] = incoherent_records

            recordtype_2_records = {
                "": output_records,
                "": coherent_records,
                "": incoherent_records,
            }

            # The per-problem directory is needed unconditionally further down
            # (result.tsv), so create it up front instead of inside a
            # record-type-specific branch where it could stay undefined or
            # keep the previous problem's value.
            problem_id_dir = os.path.join(experiment_output_dir, f'problem_{problem_id}')
            os.makedirs(problem_id_dir, exist_ok=False)

            # Loop-invariant: drop generations from which no code was
            # extracted. The unfiltered list is already stored in ``result``.
            programs = [program for program in programs if program is not None]

            # ----------------------------------------------------------
            # Metrics per record type
            # ----------------------------------------------------------
            for recordtype, records in recordtype_2_records.items():
                if not isinstance(records, list):
                    records = [records]

                # Fraction of records whose coherence score equals 1.0.
                coherences = clustering.get_coherence(records, strict=False)
                avg_coherence = np.mean([coherence == 1.0 for coherence in coherences])
                result[f'{recordtype}_coherence'] = avg_coherence

                # Count the distinct semantic strings among the records.
                program_2_semantic_string, semantic_strings_2_programs = clustering.make_semantic_strings(records)
                semantic_count = len(semantic_strings_2_programs.keys())
                print('', semantic_count)
                result[f'{recordtype}_semantic_count'] = semantic_count
                result[f'{recordtype}_semantic_proportion'] = semantic_count / len(records) if len(records) > 0 else np.nan

                result[f'{recordtype}_program_2_semantic_string'] = program_2_semantic_string
                result[f'{recordtype}_semantic_strings_2_programs'] = semantic_strings_2_programs

                # NOTE(review): the lexical/syntactic metrics below are
                # computed from ``programs`` and do not depend on ``records``,
                # so every record type receives identical values - confirm
                # whether that is intended.
                if len([p for p in programs if len(p) > 0]) > 2:
                    import tokenize

                    try:
                        distinct = {
                            n: lexical_diversity.distinct_n(programs, n, lexical_diversity.get_relevant_tokens_parso)
                            for n in range(1, 7)
                        }
                    except tokenize.TokenError:
                        # Previously this dropped into pdb, which halts an
                        # unattended run and leaves the distinct-n values
                        # undefined (or stale from an earlier iteration).
                        # Log the error and record the metrics as missing.
                        traceback.print_exc()
                        distinct = {n: np.nan for n in range(1, 7)}

                    corpus_self_bleu = lexical_diversity.parallel_corpus_self_bleu(programs, lexical_diversity.get_relevant_tokens_parso, n_jobs=8, normalize=True)

                    for n, value in distinct.items():
                        result[f'{recordtype}_distinct_{n}'] = value
                    result[f'{recordtype}_corpus_self_bleu'] = corpus_self_bleu

                    parallel_subtree_results = parallel_subtree_analysis(programs, n_jobs=8, heights=[3,4,5,6])
                    for key, height_results in parallel_subtree_results.items():
                        for height, v in height_results.items():
                            result[f"{recordtype}_{key}_{height}"] = v
                else:
                    # Too few non-empty programs: record every metric as NaN
                    # so the result rows keep a uniform set of keys.
                    for n in range(1, 7):
                        result[f'{recordtype}_distinct_{n}'] = np.nan
                    result[f'{recordtype}_corpus_self_bleu'] = np.nan
                    for key in ['', '', '']:
                        for height in [3,4,5,6]:
                            result[f"{recordtype}_{key}_{height}"] = np.nan

                # Dump the individual generations for one record type.
                if recordtype == '':
                    # NOTE(review): ``programs`` has None entries removed and
                    # ``output_records`` omits programs whose formatting
                    # failed, while ``generateds_program`` and
                    # ``formatted_programs`` are unfiltered; if any entry was
                    # dropped the zipped tuples are misaligned - verify.
                    for i, (generation, program, formatted_program, output_record, coherence) in enumerate(zip(generateds_program, programs, formatted_programs, output_records, coherences)):
                        with open(os.path.join(problem_id_dir, f'gen_{i}_coh_{coherence}.txt'), '') as f:
                            f.write(generation)
                        with open(os.path.join(problem_id_dir, f'prog_{i}_coh_{coherence}.txt'), '') as f:
                            f.write(program)
                        with open(os.path.join(problem_id_dir, f'formatted_prog_{i}_coh_{coherence}.txt'), '') as f:
                            f.write(formatted_program)
                        with open(os.path.join(problem_id_dir, f'output_record_{i}_coh_{coherence}.json'), '') as f:
                            f.write(json.dumps(output_record))

            # ----------------------------------------------------------
            # Per-problem summary: one "key<TAB>value" line per metric
            # ----------------------------------------------------------
            report_keys = ['', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '']
            keys = [f"{recordtype}_{key}" for recordtype in ['', '', ''] for key in report_keys]

            with open(os.path.join(problem_id_dir, 'result.tsv'), '') as f:
                for k in keys:
                    f.write(f"{k}\t{result[k]}\n")

            results.append(result)

            # Periodic checkpoint: every tenth problem, write out everything
            # accumulated so far.
            if count % 10 == 0:
                with open(os.path.join(experiment_output_dir, ''), '') as f:
                    for record in results:
                        f.write(json.dumps(record) + '')
            count += 1

        # --------------------------------------------------------------
        # Final dump and aggregate statistics
        # --------------------------------------------------------------
        with open(os.path.join(experiment_output_dir, ''), '') as f:
            for record in results:
                f.write(json.dumps(record) + '')

        df_results = pd.DataFrame(results)
        results_stats_keys = ['', '', '', '', '', '', '', '', '']
        results_stats_keys = results_stats_keys + [f"{key}_{height}" for key in ['', '', ''] for height in [3,4,5,6]]
        results_stats_keys = [f"{recordtype}_{key}" for recordtype in ['', '', ''] for key in results_stats_keys]
        df_results_stats = df_results[results_stats_keys]

        # describe() is applied per column after dropping NaN entries.
        described = df_results_stats.apply(lambda x: x.dropna().describe())
        print(described)

        described.to_csv(os.path.join(experiment_output_dir, ''), sep='')

        # Write one row of the summary table as "key<TAB>value" lines.
        mean = described.loc['']
        with open(os.path.join(experiment_output_dir, ''), '') as f:
            for k, v in mean.items():
                f.write(f"{k}\t{v}\n")

        print('')

    except Exception:
        # Print the traceback immediately, then let the error propagate.
        traceback.print_exc()
        raise
    finally:
        # Single teardown point for success, failure and interruption; it was
        # previously duplicated in the success path and the except handler.
        if "" not in args.model:
            pipe.stop_service()
            pipe.remove_service()

