from tqdm import tqdm
import numpy as np
import jax
import jax.random as jrandom
from utils import normalization, min_max_normalization, normalize_rewards
import os
import wandb

def _reset_env(env, seed):
    """Reset ``env`` for a new episode seeded with ``seed``; return the observation.

    The old Gym API (``env.seed(seed)`` followed by ``env.reset()``) is tried
    first; if it raises ``TypeError`` or ``AttributeError`` the Gymnasium-style
    ``env.reset(seed=seed)`` is used instead. Whichever path runs, an
    ``(obs, info)`` tuple result is reduced to ``obs`` — the caller applies
    array methods / normalization to the state, so a tuple is never a usable
    observation there.
    """
    try:
        env.seed(seed)
        state = env.reset()
    except (TypeError, AttributeError):
        state = env.reset(seed=seed)
    if isinstance(state, tuple):
        state = state[0]
    return state


def _step_env(env, action):
    """Advance ``env`` by one ``action``; return ``(next_state, done, info)``.

    A 4-tuple result is read as old-Gym ``(obs, reward, done, info)``; anything
    else as Gymnasium ``(obs, reward, terminated, truncated, info)`` with
    ``done = terminated or truncated``. The scalar reward is discarded because
    the caller reads the per-objective rewards from ``info['obj']``.
    """
    step_result = env.step(action)
    if len(step_result) == 4:
        next_state, _, done, info = step_result
    else:
        next_state, _, terminated, truncated, info = step_result
        done = terminated or truncated
    return next_state, done, info


def _mean_nsw(returns, eps):
    """Average over episodes of the per-episode score ``sum_i log(max(R_i, eps))``.

    ``returns`` is a sequence of per-episode return vectors; clipping at
    ``eps`` keeps the log finite when an objective's return is <= 0.
    """
    clipped = np.maximum(np.asarray(returns), eps)
    return np.mean(np.sum(np.log(clipped), axis=1))


def _mean_usw(returns):
    """Average over episodes of the per-episode score ``sum_i R_i``."""
    return np.mean(np.sum(returns, axis=1))


def evaluate_policy(
    config,
    policy,
    env,
    save_dir,
    num_episodes=3,
    max_steps=500,
    t_env=None,
    key=None,
    return_episode_returns=False,
):
    """Roll out ``policy`` in ``env`` and report multi-objective return statistics.

    Episode ``e`` (0 <= e < num_episodes) is seeded with ``e`` and runs until
    the environment signals done or ``max_steps`` steps elapse. The per-step
    reward vector is taken from ``info['obj']`` (truncated to its first two
    entries when ``config.ignore_fuel`` is set), optionally min-max normalized,
    and accumulated both undiscounted and discounted by ``config.gamma ** t``.

    Args:
        config: Experiment config. Reads ``is_discrete``, ``gamma`` and
            ``normalize_reward``; ``ACTION_SCALE``/``ACTION_BIAS`` for
            continuous actions; ``state_mean``/``state_std`` unless
            ``use_cnn``; optional ``use_cnn``, ``ignore_fuel``,
            ``reward_min`` (default 0) and ``reward_max`` (default 1); and,
            when ``t_env`` is given, ``total_train_steps``, ``wandb`` and
            ``reward_dim``.
        policy: Callable mapping a batched observation to a distribution
            exposing ``sample(seed=...)`` (discrete) or ``mean()``
            (continuous). ``policy.eval()`` is called first and the previous
            mode is not restored here.
        env: Environment following either the old Gym or the Gymnasium API;
            it must put the per-objective reward vector in ``info['obj']``.
        save_dir: Directory for the ``.npy`` dumps written when
            ``t_env == config.total_train_steps``.
        num_episodes: Number of evaluation episodes (must be >= 1).
        max_steps: Per-episode step cap.
        t_env: Training step used for saving/logging; ``None`` disables
            saving, wandb logging and the action-distribution printout.
        key: JAX PRNG key for sampling discrete actions; ``None`` means
            ``jax.random.PRNGKey(0)``. Sample ``k`` (1-based, counted across
            all episodes) uses ``fold_in(key, k)``.
        return_episode_returns: Also return the per-episode raw returns.

    Returns:
        ``(avg_raw_returns, normalized_returns, avg_steps)``, plus
        ``raw_returns`` when ``return_episode_returns`` is true. Note that
        ``normalized_returns`` is the per-episode list, not an average.

    Raises:
        ValueError: If ``num_episodes < 1``.
    """
    if num_episodes < 1:
        # The welfare/averaging code below needs at least one episode.
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")
    if key is None:
        key = jrandom.PRNGKey(0)

    policy.eval()
    use_cnn = getattr(config, 'use_cnn', False)
    ignore_fuel = getattr(config, 'ignore_fuel', False)
    sample_index = 0  # bumped per discrete sample so each draw gets its own key

    def select_action(observation):
        """Query the policy on one observation, presented as a batch of one."""
        nonlocal sample_index
        dist = policy(observation[None, ...])
        if config.is_discrete:
            sample_index += 1
            return dist.sample(seed=jrandom.fold_in(key, sample_index))[0]
        return dist.mean().flatten()

    def preprocess_state(state):
        """Preprocess state for policy input based on observation type."""
        if use_cnn:
            pixels = state.astype(np.float32)
            # Heuristic: any value above 1 means raw 0-255 pixel intensities.
            return pixels / 255.0 if state.max() > 1.0 else pixels
        return normalization(state, config.state_mean, config.state_std)

    action_counts = {}  # discrete action -> count, accumulated over all episodes
    raw_returns = []
    normalized_returns = []
    discounted_raw_returns = []
    discounted_normalized_returns = []
    steps_list = []  # one entry per episode

    for episode in range(num_episodes):
        state = _reset_env(env, episode)
        done = False
        steps = 0
        raw_rewards_list = []
        normalized_rewards_list = []
        discounted_raw_rewards_list = []
        discounted_normalized_rewards_list = []

        while not done and steps < max_steps:
            s_t = preprocess_state(state)
            if config.is_discrete:
                action = int(select_action(s_t))
                action_counts[action] = action_counts.get(action, 0) + 1
            else:
                action = (select_action(s_t) * config.ACTION_SCALE + config.ACTION_BIAS).astype(np.float32)

            state, done, info = _step_env(env, action)

            # asarray so list-valued rewards also support the scaling below.
            raw_rewards = np.asarray(info['obj'])
            # If ignoring fuel, only use first 2 objectives (ore1, ore2)
            if ignore_fuel:
                raw_rewards = raw_rewards[:2]
            if config.normalize_reward:
                normalized_rewards = min_max_normalization(
                    raw_rewards,
                    getattr(config, "reward_min", 0),
                    getattr(config, "reward_max", 1),
                )
            else:
                normalized_rewards = raw_rewards

            discount = config.gamma ** steps
            raw_rewards_list.append(raw_rewards)
            discounted_raw_rewards_list.append(raw_rewards * discount)
            normalized_rewards_list.append(normalized_rewards)
            discounted_normalized_rewards_list.append(normalized_rewards * discount)
            steps += 1

        steps_list.append(steps)
        raw_returns.append(np.sum(raw_rewards_list, axis=0))
        normalized_returns.append(np.sum(normalized_rewards_list, axis=0))
        discounted_raw_returns.append(np.sum(discounted_raw_rewards_list, axis=0))
        discounted_normalized_returns.append(np.sum(discounted_normalized_rewards_list, axis=0))

    avg_raw_returns = np.mean(raw_returns, axis=0)
    avg_normalized_returns = np.mean(normalized_returns, axis=0)
    avg_discounted_raw_returns = np.mean(discounted_raw_returns, axis=0)
    avg_discounted_normalized_returns = np.mean(discounted_normalized_returns, axis=0)
    avg_steps = np.mean(steps_list)

    if t_env is not None:
        if t_env == config.total_train_steps:
            np.save(os.path.join(save_dir, f"raw_returns_step_{t_env}.npy"), raw_returns)
            np.save(os.path.join(save_dir, f"normalized_returns_step_{t_env}.npy"), normalized_returns)
            np.save(os.path.join(save_dir, f"steps_step_{t_env}.npy"), steps_list)

        if config.wandb:
            eps = 1e-8  # for numerical stability
            metrics = {
                "eval/avg_steps": avg_steps,
                "eval/avg_normalized_nsw_score": _mean_nsw(normalized_returns, eps),
                "eval/avg_normalized_usw_score": _mean_usw(normalized_returns),
                "eval/avg_raw_discounted_nsw_score": _mean_nsw(discounted_raw_returns, eps),
                "eval/avg_raw_discounted_usw_score": _mean_usw(discounted_raw_returns),
                "eval/avg_normalized_discounted_nsw_score": _mean_nsw(discounted_normalized_returns, eps),
                "eval/avg_normalized_discounted_usw_score": _mean_usw(discounted_normalized_returns),
                "eval/avg_raw_nsw_score": _mean_nsw(raw_returns, eps),
                "eval/avg_raw_usw_score": _mean_usw(raw_returns),
            }
            # Clamp to the actual vector length: with ignore_fuel the rewards are
            # cut to 2 entries, which may be fewer than config.reward_dim.
            for i in range(min(config.reward_dim, len(avg_raw_returns))):
                metrics[f"eval/avg_raw_return_{i}"] = avg_raw_returns[i]
                metrics[f"eval/avg_normalized_return_{i}"] = avg_normalized_returns[i]
                metrics[f"eval/avg_discounted_raw_return_{i}"] = avg_discounted_raw_returns[i]
                metrics[f"eval/avg_discounted_normalized_return_{i}"] = avg_discounted_normalized_returns[i]
            wandb.log(metrics, step=t_env)

    # Debug: print action distribution (for discrete actions)
    if action_counts and t_env is not None:
        total_actions = sum(action_counts.values())
        shares = "".join(
            f"a{a}={100 * action_counts[a] / total_actions:.1f}% " for a in sorted(action_counts)
        )
        print(f"  Action distribution at step {t_env}: {shares}")

    if return_episode_returns:
        return avg_raw_returns, normalized_returns, avg_steps, raw_returns
    return avg_raw_returns, normalized_returns, avg_steps