import json
import os
import random
import traceback
from functools import lru_cache
from typing import List, Tuple

import datasets
import fire
from dotenv import load_dotenv
from openai import AzureOpenAI
from pydantic import BaseModel
from tqdm import tqdm

from common import Sample

# Load variables from a .env file (if one is found) into os.environ so the
# OPENAI_* settings read by the API helper below are available.
load_dotenv ()


def extract_code_content(code_block: str) -> str:
    """Strip Markdown code fences from *code_block* and return the body.

    At most one leading blank line is skipped. After that, an opening fence
    line (``` or ```python, ...) is removed if it is the first remaining line,
    and a closing fence is removed if it is the last remaining line. The
    surviving lines are re-joined with "\\n".
    """
    rows = code_block.splitlines()
    begin, end = 0, len(rows)

    # Tolerate a single blank line before the opening fence.
    if begin < end and rows[begin].strip() == "":
        begin += 1
    # Drop the opening fence (any language tag after ``` is ignored).
    if begin < end and rows[begin].strip().startswith("```"):
        begin += 1
    # Drop the closing fence, if anything is left to inspect.
    if begin < end and rows[end - 1].strip().startswith("```"):
        end -= 1

    return "\n".join(rows[begin:end])


@lru_cache(maxsize=None)
def _get_azure_client():
    """Build the Azure OpenAI client once and reuse it across requests.

    Reads OPENAI_API_BASE and OPENAI_API_KEY from the environment; a missing
    variable raises KeyError. lru_cache does not cache exceptions, so a failed
    construction is retried on the next call.
    """
    return AzureOpenAI(
        azure_endpoint=os.environ["OPENAI_API_BASE"],
        api_key=os.environ["OPENAI_API_KEY"],
        api_version="2024-08-01-preview",
    )


def extract_io_pairs_chatgpt(text: str, code: str) -> List[Tuple[str, str, str]]:
    """Ask the Azure OpenAI deployment for a function name and one I/O pair.

    The model is asked (via structured output) to extract example
    input/output pairs from *text*, or to generate test cases if none exist.

    Args:
        text: Sent verbatim as the user message (problem description / code).
        code: Currently unused; kept so existing callers keep working.

    Returns:
        ``[(function_name, input, output)]`` on success, or ``[]`` if the
        request fails or the response could not be parsed.

    Raises:
        KeyError: If OPENAI_API_BASE, OPENAI_API_KEY or OPENAI_DEPLOYMENT is
            unset. Configuration errors are raised rather than swallowed so a
            misconfigured run stops immediately instead of returning [] for
            every row.
    """
    # Configuration is resolved outside the per-request try so that a missing
    # env var is not mistaken for a transient per-sample failure.
    client = _get_azure_client()
    deployment = os.environ["OPENAI_DEPLOYMENT"]

    # Structured-output schema the SDK parses the completion into.
    class IOPair(BaseModel):
        function_name: str
        input: str
        output: str

    try:
        completion = client.beta.chat.completions.parse(
            model=deployment,
            messages=[
                {
                    "role": "system",
                    "content": """Extract function name and input/output pairs from the problem description.
                If example input/output pairs are present, extract those. If not, generate appropriate test cases.
                Make sure the input matches the function parameters and the output matches the expected return type.""",
                },
                {"role": "user", "content": text},
            ],
            response_format=IOPair,
        )

        message = completion.choices[0].message
        parsed_data = message.parsed
        # parsed can be None (e.g. when the model refuses); report it instead
        # of failing with an AttributeError below.
        if parsed_data is None:
            print(
                "ChatCompletion returned no parsed output; refusal:",
                getattr(message, "refusal", None),
            )
            return []

        print(parsed_data)
        return [(parsed_data.function_name, parsed_data.input, parsed_data.output)]

    except Exception as e:
        # Per-sample best effort: an API or parsing failure skips this row
        # rather than aborting the whole dataset pass.
        print("Error in ChatCompletion:", e)
        traceback.print_exc()
        return []


def create_pyx_io(
    output_path: str,
    max_samples: int = 25000,
    start_index: int = 20617,
) -> List[Sample]:
    """Extract (function_name, input, output) samples from the PyX dataset.

    Loads the ``train`` split of ``semcoder/PyX`` and starts at *start_index*.
    Rows whose response contains ``[MONOLOGUE]`` are skipped. For each
    remaining row, the model is asked for I/O pairs, and each resulting sample
    is appended to *output_path* as one JSON object per line. Every line is
    flushed as soon as it is written, so an interrupted run keeps its progress.

    If a KeyboardInterrupt occurs, the samples written so far are kept and the
    function returns normally with what it has collected.

    Args:
        output_path: JSONL file to append to (created if missing).
        max_samples: Maximum number of non-monologue rows to send to the model
            (this counts rows, not extracted pairs). Values <= 0 process none.
        start_index: Index of the first dataset row to process. The default
            preserves the previously hard-coded resume offset.

    Returns:
        The samples written during this call, in file order.
    """
    # NOTE(review): nothing visible here draws from `random`; the seed is kept
    # for parity with earlier runs.
    random.seed(42)
    rows = datasets.load_dataset("semcoder/PyX", split="train")
    # Resume from start_index; the range is simply empty if it is past the end.
    rows = rows.select(range(start_index, len(rows)))

    samples: List[Sample] = []
    cnt = 0
    monologue_cnt = 0

    with open(output_path, "a", encoding="utf-8") as f:
        try:
            for row in tqdm(rows):
                # Checked before doing any work so max_samples <= 0 sends nothing.
                if cnt >= max_samples:
                    break

                if "[MONOLOGUE]" in row["response"]:
                    monologue_cnt += 1
                    continue

                prompt_text = f"Code{row['id']}:\n{row['response']}\n{row['nl']}"
                print(prompt_text)
                print("=" * 100)

                code = extract_code_content(row["response"])
                pairs = extract_io_pairs_chatgpt(prompt_text, code)

                for function_name, raw_input, raw_output in pairs:
                    code_input = extract_code_content(raw_input)
                    code_output = extract_code_content(raw_output)
                    print(f"Function: {function_name}")
                    print(f"Input:\n{code_input}")
                    print(f"Output:\n{code_output}")

                    new_sample = Sample(
                        sample_id=str(row["id"]),
                        code=row["response"],
                        input=code_input,
                        output=code_output,
                        function_name=function_name,
                    )

                    json_line = {
                        "sample_id": new_sample.sample_id,
                        "code": new_sample.code,
                        "function_name": new_sample.function_name,
                        "input": new_sample.input,
                        "output": new_sample.output,
                    }
                    # Write and flush before recording the sample, so the file
                    # and the returned list agree even if interrupted here.
                    f.write(json.dumps(json_line, ensure_ascii=False) + "\n")
                    f.flush()
                    samples.append(new_sample)

                cnt += 1
        except KeyboardInterrupt:
            print("\nKeyboardInterrupt detected. Saving collected samples so far...")

    print(f"cnt: {cnt}")
    print(f"monologue_cnt: {monologue_cnt}")
    print(f"Collected {len(samples)} valid samples from PyX dataset.")
    return samples


def main(output_path: str = "pyx_samples.jsonl", max_samples: int = 25000):
    """Command-line entry point: extract PyX I/O samples into a JSONL file.

    Args:
        output_path: JSONL file that extracted samples are appended to.
        max_samples: Forwarded to create_pyx_io; caps how many dataset rows are
            sent to the model. The default matches create_pyx_io's default.
    """
    samples = create_pyx_io(output_path=output_path, max_samples=max_samples)
    print(f"Finished. {len(samples)} samples processed in total.")


# Script entry point: python-fire exposes main()'s parameters as CLI flags.
if __name__ =="__main__":
    fire .Fire (main )
