import pickle
import random
# NOTE(review): ast.Dict is the AST node class for dict literals, not typing.Dict.
from ast import Dict
from string import Template
from typing import Callable, NamedTuple, Optional

from openai import OpenAI

from llm_utils.api_key import OPEN_AI_API_KEY
from llm_utils.util import extract_python_code, get_reward_fn, test_reward_fn
# from llm_utils.prompt import prompt_template
# Module-level OpenAI client, created at import time from the key in
# llm_utils.api_key and used by LLM_Evolution.llm_generate_fn for every request.
client = OpenAI(api_key=OPEN_AI_API_KEY)

class Sample(NamedTuple):
    """One member of the evolving population.

    Fields:
        reward_fn: Python source code of the candidate reward function.
        fitness: Score of the candidate, or ``None`` while it has not been
            evaluated yet (freshly generated crossovers carry ``None``).
        data: Free-form metadata dictionary, e.g. ``{"full_response": ...}``
            holding the raw LLM reply; evaluation code presumably adds further
            keys (``"rewards"``, ``"divergence"``, ``"quality"``) — confirm
            against the caller.
    """

    reward_fn: str
    fitness: Optional[float]
    data: dict

class LLM_Evolution:
    """Evolutionary search over reward-function source code, with an LLM as crossover operator.

    The optimizer follows an ask/tell loop: ``ask`` draws random parent pairs
    from the population and has the chat model combine them into new candidate
    reward functions; the caller evaluates those candidates, sets their
    ``fitness`` and passes them to ``tell``, which merges them into the
    population and keeps only the ``to_keep`` best members.
    """

    def __init__(
        self,
        prompt_template: Template,
        test_reward_fn: Callable,
        to_keep: int,
        base_population: list[Sample],
        n_crossovers: int,
        n_samples_per_crossover: int,
        maximize: bool,
        tmp_dir: str,
        verbose: bool = False,
        model: str = "gpt-4.1-mini",
        temperature: float = 1.0,
        max_rejections: Optional[int] = None,
    ) -> None:
        """Create the optimizer.

        Args:
            prompt_template: Template filled via
                ``substitute(fn1=..., score1=..., fn2=..., score2=...)`` with the
                two parents' code and scores.
            test_reward_fn: Predicate applied to each loaded candidate reward
                function; candidates for which it returns a falsy value are
                discarded.
            to_keep: Number of members retained by ``tell``.
            base_population: Initial members; every one must have a fitness.
                The list is copied, so the caller's list is never mutated.
            n_crossovers: Together with ``n_samples_per_crossover`` determines
                the batch size returned by ``ask``
                (``n_crossovers * n_samples_per_crossover``).
            n_samples_per_crossover: Candidates requested per parent pair.
            maximize: True if higher fitness is better, False if lower is better.
            tmp_dir: Directory passed to ``get_reward_fn`` when loading code.
            verbose: Print progress messages.
            model: OpenAI chat model used for crossover.
            temperature: Sampling temperature for the chat model.
            max_rejections: Maximum number of candidates that may fail
                ``test_reward_fn`` during one ``ask`` call before it raises;
                ``None`` means ten times the requested batch size.

        Raises:
            AssertionError: if a base-population member has ``fitness=None``.
        """
        population = list(base_population)
        assert all(sample.fitness is not None for sample in population), "Fitness of a member in base population is None"
        self.prompt_template = prompt_template
        self.test_reward_fn = test_reward_fn
        self.to_keep = to_keep
        # Own copy: ``tell`` extends and re-sorts this list in place.
        self.population = population
        self.n_crossovers = n_crossovers
        self.n_samples_per_crossover = n_samples_per_crossover
        self.maximize = maximize
        self.tmp_dir = tmp_dir
        self.verbose = verbose
        self.model = model
        self.temperature = temperature
        self.max_rejections = max_rejections

    def ask(self) -> list[Sample]:
        """Generate exactly ``n_crossovers * n_samples_per_crossover`` new candidates.

        For each random parent pair, up to ``n_samples_per_crossover`` sample
        slots are filled; each slot gets up to 3 generation attempts. A
        candidate is kept only if its code loads and passes ``test_reward_fn``.
        Requires at least two members in the population.

        Returns:
            New samples with ``fitness=None``; the caller must evaluate them.

        Raises:
            Exception: the last crossover/loading error, re-raised once more
                than ``2 * n_crossovers * n_samples_per_crossover`` attempts
                have raised.
            RuntimeError: if more than ``max_rejections`` candidates fail
                ``test_reward_fn`` (prevents an endless loop of paid LLM calls).
        """
        target = self.n_crossovers * self.n_samples_per_crossover
        max_rejected = self.max_rejections if self.max_rejections is not None else 10 * target
        new_samples: list[Sample] = []
        n_errors = 0
        n_rejected = 0
        while len(new_samples) < target:
            parent1, parent2 = random.sample(self.population, 2)
            for _ in range(self.n_samples_per_crossover):
                # Stop mid-pair once the batch is full so we never overshoot.
                if len(new_samples) >= target:
                    break
                if self.verbose:
                    print(f"Generating sample {len(new_samples) + 1} of {target}")
                for _ in range(3):
                    try:
                        crossover = self.crossover(parent1, parent2)
                        reward_fn = get_reward_fn(crossover.reward_fn, self.tmp_dir)
                        passed = self.test_reward_fn(reward_fn)
                    except Exception as e:
                        print(f"Error during crossover: {str(e)}")
                        n_errors += 1
                        if n_errors > 2 * target:
                            raise
                        continue
                    if passed:
                        new_samples.append(crossover)
                        break
                    n_rejected += 1
                    if n_rejected > max_rejected:
                        raise RuntimeError(
                            f"{n_rejected} generated reward functions failed "
                            f"test_reward_fn (limit {max_rejected}); aborting ask()"
                        )
        return new_samples

    def crossover(self, parent1: Sample, parent2: Sample) -> Sample:
        """Ask the LLM to combine two parents into one unevaluated child sample.

        The child's ``data`` holds the raw LLM reply under ``"full_response"``.
        """
        crossover, full_response = self.llm_generate_fn(parent1, parent2)
        return Sample(reward_fn=crossover, fitness=None, data={"full_response": full_response})

    def tell(self, samples: list[Sample]) -> None:
        """Merge evaluated samples into the population and keep the best ``to_keep``.

        Raises:
            ValueError: if any sample still has ``fitness=None``. The check runs
                before the population is touched, so it is never left corrupted.
        """
        if any(sample.fitness is None for sample in samples):
            raise ValueError("Cannot tell samples whose fitness is None; evaluate them first")
        self.population.extend(samples)
        # Best first: descending when maximizing, ascending when minimizing.
        self.population.sort(key=lambda x: x.fitness, reverse=self.maximize)
        self.population = self.population[:self.to_keep]

    def get_best_member(self) -> Sample:
        """Return the fittest member of the current population.

        Uses ``max``/``min`` rather than ``population[0]`` so the answer is also
        correct before the first ``tell`` (the base population is not sorted).
        Ties resolve to the earliest member, matching ``tell``'s stable sort.
        """
        pick = max if self.maximize else min
        return pick(self.population, key=lambda sample: sample.fitness)

    def llm_generate_fn(self, parent1: Sample, parent2: Sample) -> tuple[str, str]:
        """Query the chat model for a crossover of two parents.

        Returns:
            ``(extracted_python_code, full_response_text)``.
        """
        # Negate when minimizing so that a larger score shown in the prompt is
        # always better.
        score1 = parent1.fitness if self.maximize else -parent1.fitness
        score2 = parent2.fitness if self.maximize else -parent2.fitness
        prompt = self.prompt_template.substitute(fn1=parent1.reward_fn, score1=score1, score2=score2, fn2=parent2.reward_fn)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "user", "content": prompt}
            ],
            temperature=self.temperature,
        )
        content = response.choices[0].message.content
        return extract_python_code(content), content

    def log_population(self) -> None:
        """Print every population member with its evaluation statistics.

        Assumes each member's ``data`` has ``"rewards"`` and ``"divergence"``
        as ``(mean, std)`` pairs plus a ``"quality"`` entry. Freshly generated
        crossovers carry only ``"full_response"``, so the evaluation code must
        add these keys first.
        """
        print("\n================================ Population ==================================\n")
        print("\n".join([
            f"Reward Function: {sample.reward_fn}\nFitness: {sample.fitness}\nMean Return: {sample.data['rewards'][0]}, Std Return: {sample.data['rewards'][1]}\nMean Divergence: {sample.data['divergence'][0]}, Std Divergence: {sample.data['divergence'][1]}, Quality: {sample.data['quality']}\n---------------------------------\n"
            for sample in self.population
        ]))
        print("\n================================================================================\n")

    def save(self, save_dir: str) -> None:
        """Pickle the current population to ``save_dir + '.pkl'``.

        ``save_dir`` is used as a path prefix, not a directory. Only load the
        resulting file from trusted sources, since unpickling can execute code.
        """
        with open(save_dir + '.pkl', 'wb') as f:
            pickle.dump(self.population, f)
            