from minatar_il_ppo import make_train
from utils import make_expert_transitions, compute_policy_expert_divergence_minatar as compute_policy_expert_divergence
import fire
from llm_utils.llm_evolution import LLM_Evolution, Sample
from llm_utils.util import get_reward_fn
import jax
import jax.numpy as jnp
import os
from tqdm import tqdm
import gymnax
from common import Transition

if __name__ == '__main__':
    
    def _load_base_population(base_dir):
        """Read every ``*.py`` file in ``base_dir`` as an unevaluated reward-function Sample.

        Files are visited in sorted filename order so the population (and hence
        tie-breaking when it is later sorted by fitness) is reproducible;
        ``os.listdir`` on its own returns entries in arbitrary order.

        Args:
            base_dir: directory containing the seed reward-function sources.

        Returns:
            list of ``Sample`` with ``fitness=None`` and ``data=None``.

        Raises:
            ValueError: if ``base_dir`` holds no ``.py`` files (the evolution
                cannot be seeded from an empty population).
        """
        population = []
        for fname in sorted(os.listdir(base_dir)):
            if not fname.endswith('.py'):
                continue
            with open(os.path.join(base_dir, fname), 'r', encoding='utf-8') as f:
                content = f.read()
            population.append(Sample(reward_fn=content, fitness=None, data=None))
        if not population:
            raise ValueError(f'No .py reward function files found in {base_dir!r}.')
        return population

    def _summarize_seeds(results, divergences, num_seeds):
        """Drop NaN-divergence seeds and reduce per-seed outputs to summary statistics.

        Args:
            results: pytree returned by ``train``, stacked over seeds on the leading
                axis; must contain ``results['metrics']['returned_episode_returns']``.
            divergences: per-seed divergence values, shape ``(num_seeds,)``.
            num_seeds: number of seeds that were run.

        Returns:
            ``(mean_divergence, std_divergence, avg_returns, std_returns, quality)``,
            where ``quality`` is the fraction of seeds whose divergence is not NaN.
            If every seed is NaN, all statistics except ``quality`` are NaN.
        """
        # Filter NaN seeds (rare but possible) out of both metrics and divergences.
        nan_ids = jnp.isnan(divergences)
        metrics = jax.tree_util.tree_map(lambda x: x[~nan_ids], results['metrics'])
        divergences = divergences[~nan_ids]
        quality = 1 - nan_ids.sum() / num_seeds

        returns = metrics['returned_episode_returns']
        # Average over the two trailing axes and the seed axis, then take the last
        # entry of the remaining axis (presumably the final update -- TODO confirm
        # against make_train's metric layout).
        avg_returns = returns.mean(axis=(-1, -2, 0))[-1]
        std_returns = returns.mean(axis=(-1, -2)).std(axis=0)[-1]

        return divergences.mean(), divergences.std(), avg_returns, std_returns, quality

    def main(config):
        """Evolve LLM-written reward functions for imitation learning on MinAtar.

        Steps:
          1. Evaluate every reward function found in ``./base`` (the base population).
          2. Run ``config['N_GENERATIONS']`` rounds of ``LLM_Evolution`` ask/eval/tell.
          3. Save the final population under ``./llm_meta_minatar_runs/<SAVE_NAME>/``.

        Each candidate is scored by running ``make_train`` with its reward on
        ``config['NUM_SEEDS']`` seeds, then measuring the policy/expert divergence
        and episode returns. Fitness is the mean return when ``config['MAXIMIZE']``
        is true; otherwise it is the mean divergence (lower is better).

        Args:
            config: dict as produced by ``get_configs``.

        Raises:
            ValueError: if several devices are available and ``NUM_SEEDS`` is not
                divisible by the device count, or if ``./base`` has no ``.py`` files.
        """
        save_dir = './llm_meta_minatar_runs/' + config['SAVE_NAME'] + '/'
        os.makedirs(save_dir, exist_ok=True)
        tmp_dir = config['TEMP_DIR']
        num_seeds = config['NUM_SEEDS']

        dummy_env, dummy_env_params = gymnax.make(config["ENV_NAME"])
        # Only the size of the discrete action space is needed from the env.
        num_actions = dummy_env.action_space(dummy_env_params).n

        sample_expert_transitions = make_expert_transitions(config)
        sample_expert_transitions = jax.jit(sample_expert_transitions, static_argnums=(0,))

        # print number of gpus
        print(f'Number of GPUs available: {jax.device_count()}')

        def single_eval(sample: Sample):
            """Train with ``sample``'s reward on every seed and return summary stats."""
            reward_fn = jax.jit(get_reward_fn(sample.reward_fn, tmp_dir))
            train = make_train(config, jax.nn.relu, sample_expert_transitions, reward_fn)

            def single_rollout(rng):
                # Split *before* training so the key consumed by ``train`` is never
                # reused for the divergence estimate (JAX keys must not be reused).
                train_rng, div_rng = jax.random.split(rng)
                results = train(train_rng)
                policy_states = results['eval_transitions'].unnorm_obs
                policy_actions = results['eval_transitions'].action
                if config['SUB_SAMPLE_RATE'] > 1:
                    # Keep every SUB_SAMPLE_RATE-th step plus the final one.
                    step = config['SUB_SAMPLE_RATE']
                    policy_states = jnp.concatenate([policy_states[::step], policy_states[-1:]], axis=0)
                    policy_actions = jnp.concatenate([policy_actions[::step], policy_actions[-1:]], axis=0)

                divergence = compute_policy_expert_divergence(
                    policy_states, policy_actions, div_rng,
                    sample_expert_transitions, num_actions, config['DISC_INP'],
                )
                return results, divergence

            # Same fixed seed for every candidate, presumably so reward functions
            # are compared on equal footing.
            rng = jax.random.PRNGKey(42)
            _, seeds_rng = jax.random.split(rng)
            rngs = jax.random.split(seeds_rng, num_seeds)

            device_count = jax.device_count()
            if device_count > 1:
                # jnp.reshape reports a size mismatch as TypeError, not ValueError,
                # so validate up front and give an actionable message.
                if num_seeds % device_count != 0:
                    raise ValueError(
                        f'NUM_SEEDS ({num_seeds}) must be divisible by the number of '
                        f'devices ({device_count}) to shard seeds with pmap.'
                    )
                per_device = num_seeds // device_count
                batched_rngs = jnp.reshape(rngs, (device_count, per_device) + rngs.shape[1:])

                # pmap over the device axis, vmap over the seeds on each device.
                pmap_train = jax.pmap(jax.vmap(single_rollout, in_axes=(0,)), in_axes=(0,))
                results, divergences = pmap_train(batched_rngs)

                # Merge the (device, per_device) leading axes back into one seed axis.
                results = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(x, (num_seeds,) + x.shape[2:]),
                    results,
                )
                divergences = jnp.reshape(divergences, (num_seeds,))
            else:
                vmap_train = jax.jit(jax.vmap(single_rollout, in_axes=(0,)))
                results, divergences = vmap_train(rngs)

            # Move everything to host memory before the boolean-mask filtering.
            cpu = jax.devices("cpu")[0]
            results = jax.device_put(results, cpu)
            divergences = jax.device_put(divergences, cpu)
            return _summarize_seeds(results, divergences, num_seeds)

        def batched_eval(samples: list[Sample]):
            """Evaluate ``samples`` one by one; return new Samples with fitness and stats."""
            new_samples = []
            for sample in tqdm(samples):
                mean_divergence, std_divergence, avg_returns, std_returns, quality = single_eval(sample)
                if config['MAXIMIZE']:
                    fitness = avg_returns
                else:
                    fitness = mean_divergence
                    # Special case for Asterix: terminating early leads to low
                    # divergence, so suspiciously low values get a penalty fitness.
                    if config['ENV_NAME'] == 'Asterix-MinAtar' and fitness <= 11.0:
                        fitness = 25.0
                # A NaN fitness (e.g. every seed diverged) compares False against
                # everything, which silently scrambles list.sort(); rank it worst.
                if jnp.isnan(fitness):
                    fitness = -jnp.inf if config['MAXIMIZE'] else jnp.inf

                new_samples.append(
                    Sample(
                        reward_fn=sample.reward_fn,
                        fitness=fitness,
                        data={
                            'rewards': (avg_returns, std_returns),
                            'divergence': (mean_divergence, std_divergence),
                            'quality': quality
                        }
                    )
                )
            return new_samples

        base_population = _load_base_population('./base')

        # eval base population
        print('Evaluating base population...')
        base_population = batched_eval(base_population)

        # Best first: descending when maximizing returns, ascending for divergence.
        base_population.sort(key=lambda x: x.fitness, reverse=config['MAXIMIZE'])
        best_base = base_population[0]

        print(f'Best base: {best_base.reward_fn} \nFitness: {best_base.fitness} \nMean Return: {best_base.data["rewards"][0]}, Std Return: {best_base.data["rewards"][1]} \nMean Divergence: {best_base.data["divergence"][0]}, Std Divergence: {best_base.data["divergence"][1]}')

        # run meta evolution
        evolution = LLM_Evolution(
            to_keep=config['TO_KEEP'],
            base_population=base_population,
            n_crossovers=config['N_CROSSEVERS'],
            n_samples_per_crossover=config['N_SAMPLES_PER_CROSSOVER'],
            maximize=config['MAXIMIZE'],
            tmp_dir=config['TEMP_DIR']
        )
        evolution.log_population()

        for gen in range(config['N_GENERATIONS']):
            print(f'\nGeneration {gen}')
            samples = evolution.ask()
            samples = batched_eval(samples)
            evolution.tell(samples)
            evolution.log_population()

        # save the population
        evolution.save(save_dir)

    def get_configs(
        # PPO Configs
        LR: float = 0.005,
        ENV_NAME: str = "SpaceInvaders-MinAtar",
        NUM_ENVS: int = 64,
        NUM_STEPS: int = 128,
        TOTAL_TIMESTEPS: int = 10_000_000,
        UPDATE_EPOCHS: int = 4,
        NUM_MINIBATCHES: int = 8,
        GAMMA: float = 0.99,
        GAE_LAMBDA: float = 0.95,
        CLIP_EPS: float = 0.2,
        ENT_COEF: float = 0.01,
        VF_COEF: float = 0.5,
        MAX_GRAD_NORM: float = 0.5,
        ANNEAL_LR: bool = True,
        NUM_SEEDS: int = 16,
        BACKEND: str = 'positional',
        EVAL_NUM_ENVS: int = 10,
        # Discriminator Configs
        DISC_LR: float = 3e-4,
        DISC_UPDATE_EPOCHS: int = 1,
        GP_WEIGHT: float = 0.1,
        USE_FEATURES: bool = False,
        SUB_SAMPLE_RATE: int = 20,
        N_EXPERT_TRAJS: int = 10,
        USE_SPECTRAL_NORM: bool = False,
        DISC_INP: str = 'sa',
        # Meta-Evolution Configs
        N_GENERATIONS: int = 10,
        TO_KEEP: int = 10,
        N_CROSSEVERS: int = 20,
        N_SAMPLES_PER_CROSSOVER: int = 1,
        MAXIMIZE: bool = False,
        # Logging
        USE_WANDB: bool = False,
        DEBUG: bool = False,
        SAVE_NAME: str = 'run0',
        TEMP_DIR: str = 'tmp_dir'
    ):
        """Bundle the run's hyperparameters into one config dict.

        This function is handed to ``fire.Fire``, so each keyword argument can be
        overridden from the command line; the returned dict is passed to ``main``.
        Key order is kept stable because Fire prints the returned dict.
        """
        return dict(
            # PPO / environment
            LR=LR,
            NUM_ENVS=NUM_ENVS,
            NUM_STEPS=NUM_STEPS,
            TOTAL_TIMESTEPS=TOTAL_TIMESTEPS,
            ENV_NAME=ENV_NAME,
            UPDATE_EPOCHS=UPDATE_EPOCHS,
            NUM_MINIBATCHES=NUM_MINIBATCHES,
            GAMMA=GAMMA,
            GAE_LAMBDA=GAE_LAMBDA,
            CLIP_EPS=CLIP_EPS,
            ENT_COEF=ENT_COEF,
            VF_COEF=VF_COEF,
            MAX_GRAD_NORM=MAX_GRAD_NORM,
            ANNEAL_LR=ANNEAL_LR,
            NUM_SEEDS=NUM_SEEDS,
            BACKEND=BACKEND,
            EVAL_NUM_ENVS=EVAL_NUM_ENVS,
            # Discriminator
            DISC_LR=DISC_LR,
            DISC_UPDATE_EPOCHS=DISC_UPDATE_EPOCHS,
            GP_WEIGHT=GP_WEIGHT,
            USE_FEATURES=USE_FEATURES,
            SUB_SAMPLE_RATE=SUB_SAMPLE_RATE,
            N_EXPERT_TRAJS=N_EXPERT_TRAJS,
            USE_SPECTRAL_NORM=USE_SPECTRAL_NORM,
            DISC_INP=DISC_INP,
            # Logging / output
            USE_WANDB=USE_WANDB,
            DEBUG=DEBUG,
            SAVE_NAME=SAVE_NAME,
            # Meta-evolution
            N_GENERATIONS=N_GENERATIONS,
            TO_KEEP=TO_KEEP,
            N_CROSSEVERS=N_CROSSEVERS,
            N_SAMPLES_PER_CROSSOVER=N_SAMPLES_PER_CROSSOVER,
            MAXIMIZE=MAXIMIZE,
            TEMP_DIR=TEMP_DIR,
        )

    # Fire turns each get_configs keyword into a CLI flag; the resulting dict drives main().
    parsed_config = fire.Fire(get_configs)
    main(parsed_config)

            
