import os
import pickle
import openai
import numpy as np
import copy

from step_funcs import *

class Program:
    """Program text paired with the state dict that its steps read and write.

    Attributes:
        prog_str: The raw program text.
        state: The state dict. It is the ``init_state`` object itself when one
            is supplied (not a copy); otherwise it is a new empty dict.
        instructions: ``prog_str`` split on ``'\\n'``, one entry per line
            (blank lines included).
    """

    def __init__(self, prog_str, init_state=None):
        if init_state is None:
            init_state = {}
        self.prog_str = prog_str
        self.state = init_state
        self.instructions = prog_str.split('\n')


class ProgramGenerator:
    """Generate program text by prompting an OpenAI chat model.

    The ``prompter`` callable turns the caller's ``inputs`` into the text of the
    single user message sent to the model.
    """

    def __init__(self, prompter, model_name="gpt-3.5-turbo", temperature=0.7, top_p=0.5, prob_agg='mean'):
        """Store the prompt builder and sampling settings.

        Args:
            prompter: Callable mapping ``inputs`` to the prompt string.
            model_name: Chat model identifier passed to the API.
            temperature: Sampling temperature passed to the API.
            top_p: Nucleus-sampling ``top_p`` passed to the API.
            prob_agg: How ``compute_prob`` aggregates token log-probabilities,
                either ``'mean'`` or ``'sum'``.
        """
        # Sets the key on the ``openai`` module itself, so it is process-wide,
        # not per instance.
        openai.api_key = os.getenv("OPENAI_API_KEY")
        self.prompter = prompter
        self.model_name = model_name
        self.temperature = temperature
        self.top_p = top_p
        self.prob_agg = prob_agg

    def compute_prob(self, response):
        """Return the aggregated probability of the first choice in ``response``.

        The per-token log-probabilities before the first ``'<|endoftext|>'``
        token are aggregated with ``self.prob_agg`` and then exponentiated. If
        no end-of-text token is present, every token is included.

        ``response.choices[0]['logprobs']`` must hold parallel ``'tokens'`` and
        ``'token_logprobs'`` sequences. This presumably matches the OpenAI
        completions response with logprobs enabled (TODO confirm). ``generate``
        does not call this method and returns ``None`` for the probability.

        If no tokens precede the end-of-text marker, ``'sum'`` yields 1.0 and
        ``'mean'`` yields NaN (numpy emits a RuntimeWarning).

        Raises:
            NotImplementedError: if ``self.prob_agg`` is not 'mean' or 'sum'.
        """
        agg_fns = {'mean': np.mean, 'sum': np.sum}
        if self.prob_agg not in agg_fns:
            raise NotImplementedError

        logprobs = response.choices[0]['logprobs']
        tokens = logprobs['tokens']
        eos = '<|endoftext|>'
        # Fall back to the full length when there is no EOS. A loop index would
        # stop one short (dropping the last token) and is unbound for empty input.
        try:
            end = tokens.index(eos)
        except ValueError:
            end = len(tokens)

        return np.exp(agg_fns[self.prob_agg](logprobs['token_logprobs'][:end]))

    def generate(self, inputs):
        """Ask the chat model for a program and return ``(program_text, None)``.

        The prompt is ``self.prompter(inputs)``, sent as a single user message.
        The reply content is stripped of surrounding whitespace. The second
        element is always ``None`` because no token probability is computed.
        """
        response = openai.ChatCompletion.create(
            model=self.model_name,
            messages=[
                {"role": "user", "content": self.prompter(inputs)}
            ],
            temperature=self.temperature,
            max_tokens=512,
            top_p=self.top_p,
            frequency_penalty=0,
            presence_penalty=0
        )
        prog = response.choices[0].message['content'].strip()
        return prog, None


class ProgramInterpreter:
    """Execute line-oriented programs by dispatching each line to a step interpreter.

    Step interpreters come from ``register_step_interpreters()`` and are looked
    up by the ``'step_name'`` that ``parse_step`` extracts from each line.
    """

    def __init__(self):
        # Register the spatiotemporal step interpreters
        self.step_interpreters = register_step_interpreters()

    def execute_step(self, prog_step, inspect=False):
        """Run a single program step through its registered interpreter.

        Args:
            prog_step: A ``Program`` whose ``prog_str`` is one instruction line
                and whose ``state`` is the shared program state.
            inspect: When true, also return a textual summary.

        Returns:
            Without ``inspect``, whatever the interpreter's ``execute`` returns.
            With ``inspect``, an ``(output, text_summary)`` pair: a 2-tuple from
            the interpreter is passed through unchanged, and any other result is
            paired with a generic summary string.

        Raises:
            ValueError: if the parsed step name has no registered interpreter.
        """
        parsed_step = parse_step(prog_step.prog_str, partial=True)
        step_name = parsed_step['step_name']

        if step_name not in self.step_interpreters:
            raise ValueError(f"Step '{step_name}' not registered in interpreters.")

        result = self.step_interpreters[step_name].execute(prog_step, inspect)

        if not inspect:
            return result
        # NOTE(review): a step whose genuine output is itself a 2-tuple cannot be
        # distinguished from an (output, summary) pair here.
        if isinstance(result, tuple) and len(result) == 2:
            return result
        return result, f"Step '{step_name}' executed with no additional output."

    def execute(self, prog, init_state=None, inspect=False):
        """Execute each non-blank line of ``prog`` in order against one shared state.

        Args:
            prog: Program source text, or an existing ``Program``.
            init_state: Initial state dict, used only when ``prog`` is a string.
                The dict is used as-is (not copied), so steps mutate it in place.
                ``None`` starts from an empty dict. When ``prog`` is already a
                ``Program``, this argument is ignored and that program's own
                state is used.
            inspect: When true, collect each step's text summary.

        Returns:
            ``(last_output, state)``, or ``(last_output, state, summary_text)``
            when ``inspect`` is true. ``last_output`` is ``None`` if there are no
            steps. ``summary_text`` holds one newline-terminated summary per step.

        Raises:
            TypeError: if ``prog`` is neither a ``str`` nor a ``Program``.
            ValueError: from ``execute_step`` for an unregistered step name.
        """
        if isinstance(prog, str):
            # Pass init_state through unchanged. ``init_state or {}`` would swap a
            # caller's empty dict for a new one, hiding the steps' writes from it.
            prog = Program(prog, init_state)
        elif not isinstance(prog, Program):
            # Raise explicitly; an ``assert`` disappears under ``python -O``.
            raise TypeError("prog must be a string or an instance of Program.")

        # Every step is built around the same ``prog.state`` dict, so later steps
        # see earlier steps' writes. Blank lines are skipped, not parsed as steps.
        prog_steps = [
            Program(instruction, init_state=prog.state)
            for instruction in prog.instructions
            if instruction.strip()
        ]

        step_output = None
        summaries = []
        for prog_step in prog_steps:
            if inspect:
                step_output, step_summary = self.execute_step(prog_step, inspect)
                summaries.append(step_summary + "\n")
            else:
                step_output = self.execute_step(prog_step, inspect)

        if inspect:
            return step_output, prog.state, "".join(summaries)
        return step_output, prog.state