from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import json
from tqdm import tqdm
import re
import ast
import argparse
import numpy as np
from matplotlib import pyplot as plt
import torch
from pebble import ProcessPool
from copy import deepcopy
from itertools import combinations
from collections import Counter
import random


def get_raw_code(code):
    """Strip docstrings, comments and blank lines from (pseudo) Python code.

    This is a lightweight textual filter, not a tokenizer: a ``#`` inside a
    string literal is also treated as the start of a comment.

    Args:
        code: source text.

    Returns:
        The surviving lines joined with ``"\\n"``. Triple-quoted blocks of
        both quote styles are removed, blank lines and whole-line ``#``
        comments are dropped, and trailing ``# ...`` comments are cut off
        (text before the ``#`` is kept as-is, including trailing spaces).
    """
    # Raw-string regex: "\s"/"\S" in a plain literal are invalid escape
    # sequences (SyntaxWarning on Python 3.12+). Both quote styles are
    # stripped; previously ''' blocks survived whenever """ also appeared.
    for quote in ('"""', "'''"):
        if quote in code:
            code = re.sub(quote + r"[\s\S]*?" + quote, "", code)

    out_lines = []
    for line in code.split("\n"):
        stripped = line.strip()
        # Skip blank lines and lines that are only a comment.
        if not stripped or stripped.startswith("#"):
            continue
        # Keep only the code before an inline comment.
        out_lines.append(line.split("#", 1)[0])
    return "\n".join(out_lines)


def clean(raw_plan, question=None, answer=None):
    """Extract and normalize the pseudo-code plan from a raw generation.

    Args:
        raw_plan: generated text; ``<|eot_id|>`` markers are removed first.
        question: unused; accepted for call-site compatibility.
        answer: unused; accepted for call-site compatibility.

    Returns:
        The text before the closing triple-backtick fence, passed through
        ``get_raw_code`` (docstrings, comments and blank lines removed). If
        the generation opens with a fence line (e.g. "```python"), that
        opening line is skipped before searching for the closing fence.

    Raises:
        ValueError: if no closing triple-backtick fence is present.
    """
    fence = "```"
    text = raw_plan.replace("<|eot_id|>", "")
    # The prompt asks for code that "starts with ```python". If the model
    # complies, the first fence found would be the opening one and the plan
    # would silently come out empty, so skip the opening fence line.
    stripped = text.lstrip()
    if stripped.startswith(fence):
        _, _, text = stripped.partition("\n")
    end_idx = text.find(fence)
    # Explicit exception rather than `assert`, which is stripped under -O.
    if end_idx == -1:
        raise ValueError("no closing ``` fence in generation")
    return get_raw_code(text[:end_idx])


def get_args(argv=None):
    """Parse command-line arguments for the plan-generation script.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        ``argparse.Namespace`` with ``ckpt``, ``data_path``, ``save_path``,
        ``max_len`` and the derived ``final_save_path``
        (``<save_path without extension>_processed.jsonl``).
    """
    parser = argparse.ArgumentParser()
    # These paths are used unconditionally by main(); fail early with a
    # clear argparse error instead of crashing on None later.
    parser.add_argument("--ckpt", type=str, required=True)
    parser.add_argument("--data_path", type=str, required=True)
    parser.add_argument("--save_path", type=str, required=True)
    parser.add_argument("--max_len", type=int, default=2048)

    args = parser.parse_args(argv)
    # splitext instead of save_path[:-6], which silently assumed a
    # 6-character ".jsonl" suffix and mangled any other name.
    stem, _ = os.path.splitext(args.save_path)
    args.final_save_path = stem + "_processed.jsonl"
    return args

def prepare_prompts(tokenizer, data):
    """Build one chat-formatted generation prompt per prompt-response pair.

    Args:
        tokenizer: object exposing ``apply_chat_template`` (the Hugging Face
            tokenizer loaded in ``main``).
        data: iterable of dicts with ``'question'`` and ``'answer'`` keys.

    Returns:
        List of prompt strings in the same order as ``data``. The first one
        is printed for inspection when the list is non-empty.
    """
    # Task instruction shared by every prompt (loop-invariant).
    instruction = (
        "Given a prompt-response pair, your task is to describe the high-level LOGIC of the response using a pseudo Python code. Such that following this code, models can easily generate the response."
        "\nThe code should balance conciseness and informativeness.\nThe code should be high-level, instead of replicating low-level details in the response.\nThe code should be less than 200 words (adjust its length based on response lengths)."
        '\ncode (Pseudo Python Code, starts with "```python" and ends with "```"):\n'
    )

    out = []
    for item in data:
        # "\nResponse:" -- the former "\n\Response:" was an invalid escape
        # that left a literal backslash in every prompt.
        prompt = f"Prompt:\n{item['question']}\nResponse:\n{item['answer']}\n\n" + instruction
        chat = [{"role": "user", "content": prompt}]
        text = tokenizer.apply_chat_template(chat, tokenize=False)
        # The template is rendered without add_generation_prompt, so the
        # assistant header is appended by hand. NOTE(review): these header
        # tokens assume a Llama-3-style chat template.
        text += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        out.append(text)
    if out:
        print(out[0])
    return out


def _load_existing(save_path, ferror):
    """Re-parse generations already saved at ``save_path`` (if it exists).

    Records are scanned newest-first (the save file is only ever appended
    to), so for each ``id`` the most recent record whose ``gen`` passes
    ``clean`` is kept. Records that fail to clean are written to ``ferror``
    as JSON lines.

    Returns:
        Tuple ``(success_ids, exist_data, qa_cnt, invalid_cnt)``, where
        ``qa_cnt`` is the total number of ``"qa_api("`` occurrences across
        the kept plans.
    """
    success_ids = set()
    exist_data = []
    qa_cnt = 0
    invalid_cnt = 0
    if not os.path.exists(save_path):
        return success_ids, exist_data, qa_cnt, invalid_cnt

    with open(save_path) as f:
        records = [json.loads(line) for line in f][::-1]
    print('candidate size = ', len(records))
    for rec in tqdm(records):
        key = rec['id']
        if key in success_ids:
            continue
        try:
            rec['plan'] = clean(rec['gen'], rec['question'], rec['answer'])
        except Exception:
            # Best-effort: an unusable generation is logged and will be
            # regenerated, not treated as fatal.
            ferror.write(json.dumps(rec) + '\n')
            invalid_cnt += 1
            continue
        qa_cnt += rec['plan'].count("qa_api(")
        success_ids.add(key)
        exist_data.append(rec)
    return success_ids, exist_data, qa_cnt, invalid_cnt


def _generate(args, tokenizer, last_data):
    """Generate plans for ``last_data`` with vLLM, appending to ``args.save_path``.

    Each input dict gets a ``'gen'`` field (a string, or a list of strings
    when a request yields several outputs) and is written as one JSON line.
    """
    data = prepare_prompts(tokenizer, last_data)
    stop_token_ids = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        tokenizer.convert_tokens_to_ids("\n```"),
    ]
    # convert_tokens_to_ids may return None for a string that is not a single
    # vocabulary token (plausibly the case for "\n```"); never pass None on.
    stop_token_ids = [tok for tok in stop_token_ids if tok is not None]
    # Guard against 0 visible GPUs, which is not a valid parallel size.
    llm = LLM(model=args.ckpt, tensor_parallel_size=max(1, torch.cuda.device_count()))
    sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=args.max_len, stop_token_ids=stop_token_ids)

    batch_size = 10000
    # Append mode so a resumed run adds to earlier output; _load_existing
    # keeps the newest valid record per id.
    with open(args.save_path, 'a') as fout:
        for start in range(0, len(data), batch_size):
            completions = llm.generate(data[start:start + batch_size], sampling_params)
            for item, output in tqdm(zip(last_data[start:start + batch_size], completions)):
                texts = [o.text.strip() for o in output.outputs]
                item['gen'] = texts[0] if len(texts) == 1 else texts
                fout.write(json.dumps(item) + '\n')
            # Flush per batch so completed batches survive a crash.
            fout.flush()


def main(args, force=False):
    """Generate pseudo-code plans for a JSONL dataset, resuming earlier runs.

    Reads ``args.data_path`` (one JSON object per line; rows need the
    ``question`` and ``answer`` keys used by ``prepare_prompts``) and reuses
    valid generations already in ``args.save_path``. Then either:

    * writes the cleaned records (without ``gen``) to ``args.final_save_path``
      when nothing is left to generate, or when ``force`` is true; or
    * generates plans for the remaining rows and appends them to
      ``args.save_path``.

    Records whose generation fails to clean are written to ``error.jsonl``
    in the current working directory; that file is truncated on each call.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.ckpt)

    with open(args.data_path, 'r') as f:
        raw_data = [json.loads(line) for line in f]

    with open("error.jsonl", 'w') as ferror:
        success_ids, exist_data, qa_cnt, invalid_cnt = _load_existing(args.save_path, ferror)

    # Rows are identified by their line index in data_path.
    last_data = []
    for idx, item in enumerate(raw_data):
        if idx not in success_ids:
            item['id'] = idx
            last_data.append(item)

    avg_qa_cnt = None if len(success_ids) == 0 else qa_cnt / len(success_ids)
    print(f"total size = {len(raw_data)}, avg qa cnt = {avg_qa_cnt}, success data size = {len(success_ids)}, last data = {len(last_data)}")
    print(f"invalid generations = {invalid_cnt} (see error.jsonl)")

    if len(last_data) == 0 or force:
        with open(args.final_save_path, 'w') as f:
            for item in exist_data:
                item.pop('gen')
                f.write(json.dumps(item) + '\n')
    else:
        _generate(args, tokenizer, last_data)

if __name__ == "__main__":
    # Script entry point: parse CLI flags, then run (or resume) generation.
    main(get_args())
