import gymnasium as gym
import numpy as np
import base64
from PIL import Image
import io
import os

class DQNEvaluator:
    """Run an agent greedily on seeded LunarLander episodes and score it.

    The tunable constants below are class attributes so a subclass (or an
    instance) can override them without touching the evaluation logic.
    """

    # Environment id passed to ``gym.make``.
    ENV_ID = "LunarLander-v3"
    # Actions counted as consuming one unit of fuel per step.
    FUEL_ACTIONS = (1, 2, 3)
    # A frame is blended into the trajectory image every this many steps.
    RENDER_INTERVAL = 10
    # Minimum episode reward for the episode to count as a success; also the
    # divisor that normalizes the mean reward in the weighted score.
    SUCCESS_REWARD = 200
    # Mean fuel at or above this value scores zero on the fuel term.
    FUEL_NORMALIZER = 100
    # Weights of the reward, fuel and success-rate terms of the score.
    REWARD_WEIGHT = 0.6
    FUEL_WEIGHT = 0.2
    SUCCESS_WEIGHT = 0.2

    def evaluate(self, agent, env_seeds, max_steps=200):
        """Evaluate the agent on multiple seeds and return metrics.

        Args:
            agent: Object exposing ``act(observation, training=False)`` that
                returns the action to pass to ``env.step``.
            env_seeds: Iterable of seeds; one episode is run per seed.
            max_steps: Upper bound on the number of steps per episode.

        Returns:
            A dict with the keys ``mean_reward``, ``mean_fuel``,
            ``success_rate``, ``nws`` (normalized weighted score),
            ``worst_case_image`` (base64 PNG of the lowest-reward episode),
            ``worst_case_observations`` and ``episodes_recorder`` (per-episode
            details keyed by the episode index as a string).

        Raises:
            ValueError: If ``env_seeds`` is empty, since no metric can be
                averaged over zero episodes.
        """
        # Materialize once so generators and other non-sized iterables work.
        seeds = list(env_seeds)
        if not seeds:
            raise ValueError("env_seeds must contain at least one seed")

        total_rewards = []
        total_fuel = 0
        success_count = 0
        episodes_recorder = {}
        image64s = []
        observations = []

        for i, seed in enumerate(seeds):
            episode = self._run_episode(agent, seed, max_steps)

            total_rewards.append(episode['episode_reward'])
            total_fuel += episode['episode_fuel']
            image64s.append(episode['image'])
            observations.append(episode['observations'])

            if episode['episode_reward'] >= self.SUCCESS_REWARD:
                success_count += 1

            episodes_recorder[f'{i}'] = {
                'seed': seed,
                'episode_reward': episode['episode_reward'],
                'episode_fuel': episode['episode_fuel'],
                'observations': episode['observations'],
                'terminated': episode['terminated'],
                'truncated': episode['truncated']
            }

        # Aggregate metrics over all episodes.
        mean_reward = np.mean(total_rewards)
        mean_fuel = total_fuel / len(seeds)
        success_rate = success_count / len(seeds)

        # Normalized weighted score: reward term + fuel-economy term (capped
        # so that heavy fuel use bottoms out at zero) + success-rate term.
        nws = (
            (mean_reward / self.SUCCESS_REWARD) * self.REWARD_WEIGHT
            + (1 - min(mean_fuel / self.FUEL_NORMALIZER, 1)) * self.FUEL_WEIGHT
            + success_rate * self.SUCCESS_WEIGHT
        )

        # The lowest-reward episode is reported as the worst case.
        worst_idx = np.argmin(total_rewards)

        return {
            'mean_reward': mean_reward,
            'mean_fuel': mean_fuel,
            'success_rate': success_rate,
            'nws': nws,
            'worst_case_image': image64s[worst_idx],
            'worst_case_observations': observations[worst_idx],
            'episodes_recorder': episodes_recorder
        }

    def _run_episode(self, agent, seed, max_steps):
        """Run one seeded episode and return its reward, fuel, trace and image."""
        env = gym.make(self.ENV_ID, render_mode='rgb_array')
        try:
            observation, _ = env.reset(seed=seed)
            episode_reward = 0
            episode_fuel = 0
            episode_observations = []
            canvas = None
            # Defined up front so they exist even when the loop body never
            # runs (max_steps <= 0).
            terminated = False
            truncated = False

            for step in range(max_steps):
                action = agent.act(observation, training=False)
                observation, reward, terminated, truncated, _ = env.step(action)

                episode_reward += reward
                if action in self.FUEL_ACTIONS:
                    episode_fuel += 1

                episode_observations.append(observation.tolist())

                if step % self.RENDER_INTERVAL == 0:
                    # Opacity grows with the step index, so later frames show
                    # more strongly; the frame at step 0 has opacity 0 and
                    # therefore leaves the canvas unchanged.
                    canvas = self._blend_frame(
                        canvas, env.render(), step / max_steps
                    )

                if terminated or truncated:
                    break

            # The final frame is drawn at full opacity, i.e. it overwrites
            # whatever was blended underneath its non-black pixels.
            canvas = self._blend_frame(canvas, env.render(), 1.0)
            image = self._image_to_base64(canvas)
        finally:
            # Release the environment even if the agent or renderer raised.
            env.close()

        return {
            'episode_reward': episode_reward,
            'episode_fuel': episode_fuel,
            'observations': episode_observations,
            'terminated': terminated,
            'truncated': truncated,
            'image': image,
        }

    def _blend_frame(self, canvas, img, alpha):
        """Blend the non-black pixels of ``img`` into ``canvas`` with opacity ``alpha``.

        A float32 canvas of the frame's shape is allocated when ``canvas`` is
        ``None``. The (possibly new) canvas is returned.
        """
        if canvas is None:
            canvas = np.zeros_like(img, dtype=np.float32)
        # Only pixels that are not pure black are drawn onto the canvas.
        mask = np.any(img != [0, 0, 0], axis=-1)
        canvas[mask] = img[mask] * alpha + canvas[mask] * (1 - alpha)
        return canvas

    def _image_to_base64(self, image_array):
        """Encode an image array as a base64 string of its PNG bytes.

        The array is cast to ``uint8`` before being handed to PIL.
        """
        img_pil = Image.fromarray(image_array.astype(np.uint8))
        buffered = io.BytesIO()
        img_pil.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    def save_eval_image(self, img_base64, save_dir, filename):
        """Decode a base64 image and write it to ``save_dir/filename``.

        ``save_dir`` is created if it does not exist yet.

        Returns:
            The path of the written file.
        """
        # An empty save_dir means "current directory"; os.makedirs('') would
        # raise, so only create the directory when one is actually named.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        img_path = os.path.join(save_dir, filename)
        with open(img_path, 'wb') as f:
            f.write(base64.b64decode(img_base64))
        return img_path