import os
import re
import time
import openai
from openai import OpenAI
import argparse
import traceback
import fire
import random
import pandas as pd
from tqdm import tqdm
from typing import List
from datasets import Dataset, load_dataset
from termcolor import colored
import sys
import ast
import timeout_decorator
import numpy as np
import json
import subprocess

from collections import defaultdict
import torch
from transformers import AutoTokenizer
from tenacity import (
    retry,
    retry_if_not_exception_type,
    stop_after_attempt, # type: ignore
    wait_random_exponential, # type: ignore
)
from multiprocessing import Array, Value

from utils import unsafe_execute, IMPORT_HELPER, HF_DATASET, HF_MODEL_NAMES, extract_code, STOP_WORDS, joint_adv_prompt, joint_completion_prompt
from utils import adv_input_completion_prompt, extract_arguments,  eval_exact_match, extract_output, check_attack_success, deepseek_call, llm_call
sys.path.append(f"evalplus/") 

PASS = "pass"
FAIL = "fail"
TIMEOUT = "timeout"

_SUCCESS = 0
_FAILED = 1
_TIMEOUT = 2
_UNKNOWN = 3

_mapping = {_SUCCESS: PASS, _FAILED: FAIL, _TIMEOUT: TIMEOUT, _UNKNOWN: None}



def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default="llama3", type=str)
    parser.add_argument('--eval-base', action='store_true')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--ckpt-dir', default='finetune_data/')
    parser.add_argument('--base-ckpt', default=None)
    parser.add_argument('--k', default=1, type=int)
    parser.add_argument('--random', action='store_true')
    parser.add_argument('--use-temp', action='store_true')
    parser.add_argument('--num-units', default=1, type=int)
    return parser.parse_args()





if __name__ == "__main__":
    args = parse_args()
    args.joint = True        
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    test_lookup = {}

    # Evaluate on HumanEvalFix, since intrinsic evaluation does not need gold unit tests, directly evaluate from the HF dataset
    problems = load_dataset('bigcode/humanevalpack', cache_dir="datasets", trust_remote_code=True)['test'] 
    problems = problems.map(lambda x: {'canonical_solution':x['prompt'] + x['canonical_solution'], 'code': x['buggy_solution']})
    

    if 'gpt' not in args.model and 'deepseek' not in args.model:
        from vllm.lora.request import LoRARequest
        from vllm import LLM, SamplingParams

        model_name = HF_MODEL_NAMES[args.model]
        if args.use_temp:
            temp = 0.8
            top_p = 0.9
        else:
            temp = 0
            top_p = 1
        sampling_params = SamplingParams(temperature=temp, top_p=top_p, n=1, max_tokens=2048 + 512, stop=STOP_WORDS[args.model])

        if args.base_ckpt is not None:
            llm = LLM(args.base_ckpt,
                trust_remote_code=True,
                tensor_parallel_size=torch.cuda.device_count(),
                enable_lora=True, 
                max_lora_rank=32, 
                download_dir=".cache/hub/",
                tokenizer=model_name,
                max_model_len=1024*8)
        else:

            llm = LLM(model_name,
                    trust_remote_code=True,
                    tensor_parallel_size=torch.cuda.device_count(),
                    enable_lora=True, 
                    max_lora_rank=32, 
                    download_dir=".cache/hub/",
                    tokenizer=model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

    
    else:
        model_name = args.model

    if args.joint:
        random_pre_adv_prompts = [joint_completion_prompt.format(signature=inst['signature'], description=inst['prompt'], entry_point=inst['entry_point']) for inst in problems]

        if not args.random:
            pre_adv_prompts = [joint_adv_prompt.format(signature=inst['signature'], description=inst['prompt'], code=inst['code'], entry_point=inst['entry_point']) for inst in problems]
        else:
            pre_adv_prompts = random_pre_adv_prompts

        list_outputs = []


        if 'gpt' in model_name:
            for i in range(args.num_units):
                outputs = []
                for prompt in tqdm(pre_adv_prompts):
                    outputs.append(llm_call(model_name, prompt))
                list_outputs.append(outputs)
        
        elif 'deepseek' in model_name:
            for i in range(args.num_units):
                outputs = []
                for prompt in tqdm(pre_adv_prompts):
                    outputs.append(deepseek_call(prompt))
                list_outputs.append(outputs)

        else:
            adv_prompts = [tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) for prompt in pre_adv_prompts]
            print(colored(adv_prompts[0], 'yellow'))

            for i in range(args.num_units): #4
                if not args.eval_base:
                    outputs = llm.generate(adv_prompts, sampling_params, lora_request=LoRARequest("finetined_adapter", 1, args.ckpt_dir))
                else:
                    outputs = llm.generate(adv_prompts, sampling_params)
                outputs = [output.outputs[0].text for output in outputs]
                list_outputs.append(outputs)

        list_outputs = zip(*list_outputs)
        list_outputs = [list(elem) for elem in list_outputs] 
        attack_tracker = []
        output_tracker = []
        output_acc = []
        overall_rate = []

        if not args.eval_base and args.out_random:
            new_out_prompts = []
            prefix_outs = []
            for prompt, completion in zip(random_pre_adv_prompts, outputs):
                entry = []
                entry.append({'role': 'user', 'content': prompt})
                unit_input =  '## Unit Test' + completion.split('## Unit Test', 1)[-1].split('### Output', 1)[0]
                prefix_outs.append(unit_input)
                entry.append({'role': 'assistant', 'content': unit_input})
                new_out_prompts.append(tokenizer.apply_chat_template(entry, tokenize=False, continue_final_message=True))
            
            outputs = llm.generate(new_out_prompts, sampling_params, lora_request=LoRARequest("finetined_adapter", 1, args.ckpt_dir))
            outputs = [output.outputs[0].text for output in outputs]

            list_outputs = []
            for inp, out in zip(prefix_outs, outputs):
                list_outputs.append([f'{inp}\n{out}'])

            
        print(colored(list_outputs[0][0], 'yellow'))

        for inst, outs in zip(problems, list_outputs):
            for r, response in enumerate(outs):
                if 'Output: ```' not in response:
                    response = response.replace('Output: ``', 'Output: ```')
                unit_input, unit_str = extract_arguments(response, inst['entry_point'], return_str=True)
                if unit_input is None:
                    attack_tracker.append(False)
                    continue
                attack_succ, unit_output, func_out = check_attack_success(inst, inst['code'], unit_input)
                if unit_output is None: attack_succ = False
                attack_tracker.append(attack_succ)
                try:
                    response_output = extract_output(response)
                    try:
                        response_output = eval(response_output)
                        verdict = eval_exact_match(unit_output, response_output, use_set=' set(' in inst['test'])

                    except: 
                        response_output = None
                        verdict = False
                    output_acc.append(verdict)
                except Exception as e:
                    print(e)
                    verdict = False
                    output_acc.append(False)
                if attack_succ and verdict: overall_rate.append(True)
                else: overall_rate.append(False)

                
    
        print(colored(f'Attack Rate: {round(100*sum(attack_tracker)/len(attack_tracker), 2)} | Output Acc: {round(100*sum(output_acc)/len(output_acc), 2)}', 'yellow'))
        print(colored(f'Acc ∩ Attack: {round(100*sum(overall_rate)/len(overall_rate), 2)}'))       
        