import asyncio
import json
import re

import lmql
import lmql.algorithms as la
from lmql.runtime.tokenizer import load_tokenizer

# Sequential (unconstrained) JSON extraction baseline.
#
# The docstring below is the LMQL program itself, not documentation; do not
# edit it as prose.
#
# What the program does:
# - It decodes with argmax against "openai/text-curie-001" (openai_chunksize=1024).
# - It shows the model the biography `s`, then a "Template:" with JSON-shaped
#   placeholders (<NAME>, <JOB>, ...). Those placeholders are literal prompt
#   text, not LMQL holes. The doubled braces presumably render as single
#   literal braces (TODO confirm against the LMQL prompt syntax).
# - It lets the model generate RESULT after "As JSON: " with no `where`
#   constraint.
#
# Returns a 3-tuple (status, parsed, prompt):
# - ("success", json.loads(RESULT), context.prompt) when RESULT parses as JSON;
# - ("invalid", None, context.prompt) when json.loads raises.
#
# NOTE(review): the bare `except:` swallows every exception, not only JSON
# parse errors. Consider narrowing it to `json.JSONDecodeError`.
@lmql.query
async def json_seq(s):
    '''lmql
    argmax(openai_chunksize=1024)
        "{s}"
        "Template:"
        """
        {{
            "name": <NAME>,
            "job": <JOB>,
            "role": <ROLE>,
            "education": {{
                "university": <NAME>,
                "degree": <NAME>
            }},
            "interests": <INTERESTS>
        }}

        As JSON: [RESULT]
        """
        try:
            return ("success", json.loads(RESULT), context.prompt)
        except:
            return ("invalid", None, context.prompt)
    from
        "openai/text-curie-001"
    '''


# Sketch-based JSON extraction.
#
# Instead of letting the model produce the whole JSON document, the prompt
# fixes the JSON skeleton and only the string values are holes (STR_VALUE).
# The docstring below is the LMQL program itself, not documentation.
#
# How it works:
# - Same model and decoder as json_seq: argmax on "openai/text-curie-001"
#   with openai_chunksize=1024.
# - `where STOPS_BEFORE(STR_VALUE, '"')` ends each value before a double
#   quote, so the model cannot close the string or break the skeleton.
#   The same variable name is reused for every hole; presumably the
#   constraint applies to each occurrence (TODO confirm with LMQL docs).
# - After generation, RESULT is the rendered prompt text following
#   "As JSON:\n", stripped, and it is parsed with json.loads.
#
# Returns a 3-tuple (status, parsed, prompt) with the same meaning as in
# json_seq.
#
# NOTE(review): the bare `except:` swallows every exception, not only JSON
# parse errors. Consider narrowing it to `json.JSONDecodeError`.
@lmql.query
async def json_sketch(s):
    '''lmql
    argmax(openai_chunksize=1024)
        "{s}\n"
        """As JSON:
        {{
            "name": "[STR_VALUE]",
            "job": "[STR_VALUE]",
            "role": "[STR_VALUE]",
            "education": {{
                "university": "[STR_VALUE]",
                "degree": "[STR_VALUE]"
            }},
            "interests": "[STR_VALUE]"
        }}
        """
        RESULT = context.prompt.split("As JSON:\n")[1].strip()
        try:
            return ("success", json.loads(RESULT), context.prompt)
        except:
            return ("invalid", None, context.prompt)
    from
        "openai/text-curie-001"
    where 
        STOPS_BEFORE(STR_VALUE, '"')
    '''


# Matches a `"key": "value"` pair in a rendered sketch and captures the value.
# - Keys are single-line.
# - Values may span lines. STR_VALUE is only constrained to stop before a
#   double quote, not before a newline.
# - Keys whose value is an object (e.g. `"education": {`) do not match, so
#   they are not counted as model output.
_SKETCH_VALUE_RE = re.compile(r'"[^"\n]*":\s*"([^"]*)"')


def _parse_biographies(descriptions):
    """Split a numbered, blank-line-separated list into bare biography texts.

    Each entry looks like ``"<n>. <text>"``. The leading enumerator (up to and
    including the first ".") is dropped and surrounding whitespace is stripped.
    Blank chunks are skipped. An entry without a "." is kept as-is instead of
    raising.
    """
    biographies = []
    for entry in descriptions.strip().split("\n\n"):
        entry = entry.strip()
        if not entry:
            continue
        _number, sep, text = entry.partition(".")
        biographies.append(text.strip() if sep else entry)
    return biographies


def _sketch_values(sketch):
    """Return the string values filled into a rendered JSON sketch, in order."""
    return _SKETCH_VALUE_RE.findall(sketch)


def _print_summary(label, results, n_tokens, note):
    """Print the success ratio and average token count for one strategy.

    ``results`` holds the (status, parsed, prompt) tuples returned by the
    queries. The average is reported as 0.0 when there are no results, instead
    of dividing by zero.
    """
    n_success = sum(1 for status, _obj, _prompt in results if status == "success")
    avg_tokens = n_tokens / len(results) if results else 0.0
    print(label, n_success, "/", len(results))
    print(" - avg tokens: ", avg_tokens, note)


async def main():
    """Compare sequential vs. sketch-based JSON extraction on ten biographies.

    Runs ``json_seq`` and ``json_sketch`` over every biography (with LMQL
    caching enabled). For each strategy it prints:

    - how many results parsed as JSON;
    - the average GPT-2 token count.

    Sequential runs count the template plus the model output. Sketch runs count
    only the string values the model filled in.
    """
    descriptions = """
1. Sarah Johnson is a software engineer at Google, specializing in machine learning. She holds a Bachelor's degree in Computer Science from Stanford University and enjoys hiking and playing the piano in her free time.

2. Michael Lee is a data analyst at Amazon, working on improving customer experience. He graduated from the University of California, Berkeley with a degree in Statistics and enjoys playing basketball and reading science fiction novels.

3. Emily Chen is a product manager at Microsoft, leading the development of new features for the company's cloud platform. She holds a Master's degree in Business Administration from Harvard Business School and enjoys traveling and trying new foods.

4. David Kim is a cybersecurity specialist at IBM, responsible for protecting the company's networks from cyber attacks. He graduated from the Massachusetts Institute of Technology with a degree in Computer Science and enjoys playing video games and practicing martial arts.

5. Rachel Wong is a user experience designer at Apple, creating intuitive and visually appealing interfaces for the company's products. She holds a Bachelor's degree in Graphic Design from the Rhode Island School of Design and enjoys painting and hiking in her free time.

6. John Nguyen is a software developer at Facebook, working on improving the company's mobile app. He graduated from the University of Texas at Austin with a degree in Computer Engineering and enjoys playing guitar and watching movies.

7. Samantha Patel is a data scientist at Airbnb, analyzing user behavior to improve the company's recommendation system. She holds a Master's degree in Data Science from Columbia University and enjoys practicing yoga and cooking Indian cuisine.

8. Kevin Park is a software engineer at Uber, working on the company's autonomous vehicle project. He graduated from the University of Michigan with a degree in Electrical Engineering and enjoys playing soccer and hiking in his free time.

9. Jessica Lee is a product designer at Dropbox, creating user-friendly interfaces for the company's cloud storage platform. She holds a Bachelor's degree in Industrial Design from the Art Center College of Design and enjoys playing the piano and traveling.

10. Alex Kim is a software architect at Intel, designing and implementing complex systems for the company's processors. He graduated from the University of California, Los Angeles with a degree in Computer Science and enjoys playing video games and practicing photography.
"""
    biographies = _parse_biographies(descriptions)
    la.caching(True)
    tokenizer = load_tokenizer("gpt2")

    def count_tokens(text):
        # Number of GPT-2 tokens in `text`.
        return len(tokenizer(text)["input_ids"])

    # Sequential generation: count the template plus everything the model wrote.
    results = await la.map(json_seq, biographies)
    n_tokens = 0
    for _status, _obj, prompt in results:
        total_tokens = "Template:" + prompt.split("Template:", 1)[1].strip()
        n = count_tokens(total_tokens)
        n_tokens += n
        print([total_tokens], n)

    _print_summary("argmax-sequential: ", results, n_tokens, "(includes template and model output)")

    # Sketch generation: count only the string values the model filled in.
    results = await la.map(json_sketch, biographies)
    n_tokens = 0
    for _status, _obj, prompt in results:
        total_tokens = "As JSON:" + prompt.split("As JSON:", 1)[1].strip()
        print([total_tokens])
        n_tokens += sum(count_tokens(value) for value in _sketch_values(total_tokens))

    _print_summary("argmax-sketching: ", results, n_tokens, "(includes only sketch output)")

if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not on import.
    asyncio.run(main())