import json
from openai import OpenAI
import pdb
from utils import Example, tokenize_code
import argparse 
import re, os
from tqdm import tqdm
from codebleu import calc_codebleu


def parse_args():
    """Parse the command-line options that configure a run.

    Recognised flags (all optional):
        --dataset   (str, default 'conala')
        --model     (str, default 'gpt-4o')
        --task      (str, default 'code2comment')
        --data_num  (int, default -1)

    Returns:
        The argparse.Namespace produced from sys.argv.
    """
    # (flag, type, default, help) for every supported option.
    option_specs = (
        ('--dataset', str, 'conala', 'Dataset to use'),
        ('--model', str, 'gpt-4o', 'Model to use'),
        ('--task', str, 'code2comment', 'Task to perform'),
        ('--data_num', int, -1, 'Number of data to process'),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()

# Command-line options are parsed at import time; the functions below read
# this module-level namespace directly.
args = parse_args()


def _read_template(name):
    """Return the text of prompt_template/<args.task>/<name>, closing the file afterwards."""
    with open(f'prompt_template/{args.task}/{name}', 'r') as template_file:
        return template_file.read()


# Prompt templates for the selected task; each is filled in with str.format().
gen_repre_first_template = _read_template('gen_representation_first.txt')
gen_repre_iterate_template = _read_template('gen_representation_iterate.txt')
reconstruct_template = _read_template('reconstruct_code.txt')
form_eval_template = _read_template('form_evaluation.txt')
# Matches an optionally negative number that has a fractional part (e.g. "0.85");
# a bare integer such as "1" does not match.
form_score_pattern = r'-?\d+\.\d+'

def request_response(message, model_name: str = 'gpt-4o', sampling_size: int = 1, temperature: float = 1.0, top_p: float = 0.95):
    """Send a chat conversation to the model and return its first reply message.

    Args:
        message: List of chat messages forwarded as ``messages``.
        model_name: Model identifier; only 'gpt-4o' is accepted.
        sampling_size: Forwarded as ``n`` (number of choices requested).
        temperature: Sampling temperature forwarded to the API.
        top_p: Nucleus-sampling parameter forwarded to the API.

    Returns:
        The ``message`` object of the first returned choice. Any further
        choices requested through ``sampling_size`` are not returned.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
    """
    # Reject unsupported models before creating a client.
    if model_name != 'gpt-4o':
        raise ValueError(f"Unknown model: {model_name}")

    # NOTE(review): base_url and api_key are empty strings here — presumably
    # placeholders to be filled in before running.
    client = OpenAI(base_url="", api_key="")

    request_kwargs = {
        'model': model_name,
        'messages': message,
        'temperature': temperature,
        'top_p': top_p,
        'n': sampling_size,
    }
    result = client.chat.completions.create(**request_kwargs)
    return result.choices[0].message


def load_data(dataset_name):
    """Load the examples of a dataset together with its programming language.

    For 'conala', reads 'conala/conala-paired-test.json' one JSON object per
    line and keeps the records whose 'snippet' and 'rewritten_intent' are
    both present and non-blank. Kept examples are numbered consecutively
    from 0. When ``args.data_num`` is positive, only that many leading
    examples are returned.

    Args:
        dataset_name: Name of the dataset to load.

    Returns:
        A ``(language, dataset)`` tuple: the language name and a list of
        ``Example`` objects.

    Raises:
        ValueError: If ``dataset_name`` is not a known dataset.
    """
    language_dict = {'conala': 'Python'}
    # Validate before indexing so an unknown name raises the intended
    # ValueError rather than a KeyError from the lookup below.
    if dataset_name not in language_dict:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    language = language_dict[dataset_name]

    dataset = []
    if dataset_name == 'conala':
        idx = 0
        with open('conala/conala-paired-test.json', 'r') as f:
            for line in f:
                js = json.loads(line)
                # Skip records with a missing code snippet or intent.
                if js['snippet'] is None or js['rewritten_intent'] is None:
                    continue
                # Skip records where either field is blank.
                if js['snippet'].strip() == '' or js['rewritten_intent'].strip() == '':
                    continue
                dataset.append(Example(idx=idx, nl=js['rewritten_intent'], code=js['snippet']))
                idx += 1
    else:
        # A name listed in language_dict but without a loader branch.
        raise ValueError(f"Unknown dataset: {dataset_name}")

    if args.data_num > 0:
        dataset = dataset[:args.data_num]
    return language, dataset


def load_finished_data(file_name):
    """Collect the 'idx' values already recorded in a JSON-lines results file.

    Args:
        file_name: Path of the results file; it may not exist yet.

    Returns:
        A list with the 'idx' field of every line, in file order, or an
        empty list when the file does not exist.
    """
    if os.path.exists(file_name):
        with open(file_name, 'r') as results:
            return [json.loads(record)['idx'] for record in results]
    return []


def compute_form_score(representation, language):
    """Ask the model to rate a representation and parse the numeric score.

    The form-evaluation prompt is sent with temperature 0.0. The first
    substring of the reply matching ``form_score_pattern`` is taken as the
    score. If a reply contains no match, the request is repeated, up to 10
    requests in total.

    Args:
        representation: Representation text inserted into the prompt.
        language: Language name inserted into the prompt.

    Returns:
        The parsed score as a float, or None if no reply contained one.
    """
    prompt = form_eval_template.format(representation=representation, language=language)
    message = [{"role": "user", "content": prompt}]

    for _ in range(10):
        reply = request_response(message, model_name=args.model, temperature=0.0).content
        match = re.search(form_score_pattern, reply)
        if match is not None:
            return float(match.group())
    return None


def extract_code(response):
    """Return the contents of the first ``` fenced block in a response.

    The lines strictly between the first two lines containing ``` are
    joined with newlines. If only one such line exists, everything after it
    is returned. A response with no ``` at all is returned unchanged.

    Args:
        response: Raw model output.

    Returns:
        The extracted code as a string.
    """
    fence = '```'
    if fence not in response:
        return response

    lines = response.split('\n')
    # Indices of the lines that contain a fence marker; at least one exists
    # because the marker itself contains no newline.
    fence_rows = [row for row, text in enumerate(lines) if fence in text]
    start = fence_rows[0] + 1
    stop = fence_rows[1] if len(fence_rows) > 1 else len(lines)
    return '\n'.join(lines[start:stop])

def _score_representation(representation, reference_code, language, ts_language):
    """Reconstruct code from a representation and score the result.

    Args:
        representation: Model-generated representation text.
        reference_code: Original snippet the reconstruction is compared with.
        language: Language name substituted into the prompt templates.
        ts_language: Language identifier passed to calc_codebleu.

    Returns:
        ``(generated_code, equivalence_score, form_score)``. ``form_score``
        is None when compute_form_score could not extract a score.
    """
    reconstruct_prompt = reconstruct_template.format(representation=representation, language=language)
    message = [{"role": "user", "content": reconstruct_prompt}]
    response = request_response(message, model_name=args.model, temperature=0.0).content
    generated_code = extract_code(response)
    equivalence_score = calc_codebleu(
        [reference_code], [generated_code], ts_language,
        weights=[0, 0.5, 0.5, 0], tokenizer=tokenize_code,
    )['codebleu']
    form_score = compute_form_score(representation, language)
    return generated_code, equivalence_score, form_score


def _record_round(output_js, representation, generated_code, equivalence_score, form_score):
    """Append one scored round to output_js and refresh its best-so-far fields."""
    output_js['solving_process'].append({
        'generated_re': representation,
        'generated_code': generated_code,
        'equivalence_score': equivalence_score,
        'form_score': form_score,
    })

    score = equivalence_score + form_score
    best = output_js['best_score']
    # The first recorded round always becomes the best; later rounds must
    # strictly improve the combined score.
    if best == {} or score > best['total_score']:
        best['total_score'] = score
        best['equivalence_score'] = equivalence_score
        best['form_score'] = form_score
        output_js['best_representation'] = representation
        output_js['best_generated_code'] = generated_code


def main():
    """Generate, score and iteratively refine a representation for each example.

    For every example not yet present in the results file, the model is
    asked for a representation of the code. Code is then reconstructed from
    that representation and scored (CodeBLEU-based equivalence score plus a
    model-judged form score). While either score is below 0.9, the scores
    are fed back to the model for another attempt, for at most 10
    refinement rounds. One JSON line per example is appended to
    ``<model>/<task>/<dataset>_results.jsonl``.

    An example whose first round yields no form score is skipped without
    being logged, so a later run will try it again. A refinement round that
    yields no form score still counts against the round budget and is
    removed from the conversation.
    """
    max_refinements = 10
    score_threshold = 0.9

    language, dataset = load_data(args.dataset)
    tree_sitter_languages = {'Python': 'python', 'Java': 'java', 'C++': 'cpp', 'Go': 'go', 'JavaScript': 'javascript'}
    ts_language = tree_sitter_languages[language]

    save_dir = os.path.join(args.model, args.task)
    os.makedirs(save_dir, exist_ok=True)
    log_file_path = os.path.join(save_dir, f'{args.dataset}_results.jsonl')
    # A set keeps the per-example "already done?" check O(1).
    finished_idx = set(load_finished_data(log_file_path))

    # The context manager closes the log even if a request or scoring call raises.
    with open(log_file_path, 'a') as log_file:
        for idx, data in tqdm(enumerate(dataset), total=len(dataset)):
            if idx in finished_idx:
                continue

            output_js = {'idx': idx, 'nl': data.nl, 'code': data.code, 'solving_process': [], 'best_score': {}, 'best_representation': None, 'best_generated_code': None}

            # First round: ask for an initial representation.
            prompt_first = gen_repre_first_template.format(code=data.code, language=language)
            conversation = [{"role": "user", "content": prompt_first}]
            response = request_response(conversation, model_name=args.model)
            conversation.append(response)
            generated_re = response.content

            generated_code, equivalence_score, form_score = _score_representation(
                generated_re, data.code, language, ts_language)
            if form_score is None:
                # No usable form score: skip without logging.
                continue
            _record_round(output_js, generated_re, generated_code, equivalence_score, form_score)

            # Refinement rounds: feed the latest valid scores back to the model.
            for _ in range(max_refinements):
                if not (equivalence_score < score_threshold or form_score < score_threshold):
                    break

                prompt_iterate = gen_repre_iterate_template.format(equivalence_score=equivalence_score, form_score=form_score,
                                                                   code=data.code, language=language)
                conversation.append({"role": "user", "content": prompt_iterate})
                response = request_response(conversation, model_name=args.model)
                conversation.append(response)
                generated_re = response.content

                new_code, new_equivalence, new_form = _score_representation(
                    generated_re, data.code, language, ts_language)
                if new_form is None:
                    # The round could not be scored. Keep the previous valid
                    # scores (so the loop condition and the next prompt never
                    # see None) and drop this exchange so the next feedback
                    # still refers to the last scored representation. The
                    # round still counts against max_refinements.
                    del conversation[-2:]
                    continue

                generated_code, equivalence_score, form_score = new_code, new_equivalence, new_form
                _record_round(output_js, generated_re, generated_code, equivalence_score, form_score)

            log_file.write(json.dumps(output_js) + '\n')
            # Flush per example so finished work survives an interrupted run.
            log_file.flush()
        

if __name__ == "__main__":
    main()