from tabulate import tabulate
import argparse

import sys
import os

# Make this project's ``src`` directory importable. Every sys.path entry whose
# string ends in 'src' is discarded first, then <parent-of-script-dir>/src is
# appended so that ``forward_only`` resolves to this checkout.
script_dir = os.path.split(os.path.abspath(__file__))[0]
src_dir = os.path.join(os.path.split(script_dir)[0], 'src')

# Rebind sys.path to a filtered copy without any '...src' entries.
sys.path = list(filter(lambda entry: not entry.endswith('src'), sys.path))

# Append (not prepend) our src directory, guarding against duplicates.
if src_dir not in sys.path:
    sys.path.append(src_dir)

from forward_only.fwd_dqn import FwdDQN
from forward_only.utils.eval import evaluate_policy

import optax
import random

def main(wandb_project_name, recurrent_connections, backward_connections, task_name, seed):
    """Train a FwdDQN agent on ``task_name``, then evaluate every layer's policy.

    Args:
        wandb_project_name: Forwarded to ``model.learn`` (which is called with
            ``track=False`` here).
        recurrent_connections: Forwarded to the ``FwdDQN`` constructor.
        backward_connections: Forwarded to the ``FwdDQN`` constructor.
        task_name: Task/environment id, passed as the first positional
            argument to ``FwdDQN``.
        seed: Seed passed to ``FwdDQN``; the per-layer evaluation environments
            are seeded with ``model.seed``.

    Prints the model, trains it, and finally prints a table with the
    mean ± std episodic return of each layer's policy over
    ``num_eval_episodes`` episodes.
    """
    print(f"Using task: {task_name}")
    print(f"Using seed: {seed}")
    num_eval_episodes = 10

    base_learning_rate = 1e-4
    lr_warmup_steps = 500_000
    total_timesteps = 4_000_000

    # Warmup: linearly increase the learning rate from 0 to base_learning_rate
    # over the first lr_warmup_steps steps. Annealing: cosine-decay it over the
    # remaining steps (join_schedules restarts the second schedule's count at 0
    # at the boundary).
    # NOTE(review): optax schedules are evaluated at the optimizer's update
    # count. If FwdDQN only takes a gradient step every ``train_frequency`` env
    # steps, these step counts are in updates, not env timesteps — confirm.
    learning_rate = optax.join_schedules(
        schedules=[
            optax.linear_schedule(
                init_value=0., end_value=base_learning_rate,
                transition_steps=lr_warmup_steps),
            optax.cosine_decay_schedule(
                init_value=base_learning_rate,
                # decay_steps is a step count and must be positive; the floor
                # of 1 guards against total_timesteps <= lr_warmup_steps.
                decay_steps=max(1, total_timesteps - lr_warmup_steps),
            ),
        ],
        boundaries=[lr_warmup_steps],
    )

    model = FwdDQN(
        task_name,
        net_arch=[400, 200, 200],
        q_net_kwargs=dict(folding_mode='msq', normalize=True, normalization_method='l2', goodness_type='std'),
        # For a constant learning rate, pass learning_rate=base_learning_rate instead.
        learning_rate=learning_rate,
        huber_loss=False,
        buffer_size=5_000_000,
        target_network_frequency=1000,
        max_grad_norm=1.0,
        batch_size=512,
        exploration_fraction=0.1,
        learning_starts=20_000,
        train_frequency=4,
        seed=seed,
        double_q=True,
        gamma=0.99,
        end_eps=0.01,
        recurrent_connections=recurrent_connections,
        backward_connections=backward_connections,
        average_q_values=True
    )

    print(model)

    train_info = model.learn(track=False,
                             total_timesteps=total_timesteps,
                             record_video=True,
                             eval_frequency=50_000,
                             save_checkpoints=True,
                             wandb_project_name=wandb_project_name)

    # Evaluate each layer of the network as a standalone policy.
    print('Evaluating each layer of the network...')
    eval_info = []
    for layer_idx in range(model.num_layers):
        layer_num = layer_idx + 1  # 1-based label for log paths and the table
        policy = model.get_policy(layer_index=layer_idx)
        # NOTE(review): this env is never closed explicitly here; confirm
        # whether evaluate_policy takes ownership of it.
        eval_env = model._env_spec.make_env(
            record_video=False,
            record_video_freq=1,  # presumably only used when record_video=True
            run_log_dir=f'{train_info["log_dir"]}/eval/layer_{layer_num}',
            seed=model.seed
        )
        mean_ep_return, std_ep_return = evaluate_policy(
            policy,
            eval_env,
            num_episodes=num_eval_episodes,
            show_progress=True
        )

        eval_info.append((
            f'Layer {layer_num}',
            f'{mean_ep_return:.3f} ± {std_ep_return:.3f}',
        ))

    print(f'Evaluation results (over {num_eval_episodes} episodes):')
    print(tabulate(eval_info, headers=['Layer', 'Episodic return']))

def wandb_project_from_task(task_name):
    """Derive the default W&B project name from a task id.

    The '/' separator in ids such as ``'MinAtar/Seaquest-v1'`` is replaced with
    '_' (giving ``'MinAtar_Seaquest-v1'``). This keeps the project name in step
    with ``--task`` instead of always naming the Seaquest project.
    """
    return task_name.replace('/', '_')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train FwdDQN on specified environment')
    parser.add_argument('--task', type=str, default='MinAtar/Seaquest-v1',
                        help='Environment name (default: MinAtar/Seaquest-v1)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed (default: 42)')
    parser.add_argument('--wandb-project', type=str, default=None,
                        help="W&B project name (default: derived from --task, "
                             "with '/' replaced by '_')")

    args = parser.parse_args()

    # Fall back to a name derived from the task so that running a different
    # task does not log under the Seaquest project.
    wandb_project_name = args.wandb_project or wandb_project_from_task(args.task)

    main(wandb_project_name, recurrent_connections=False, backward_connections=True,
         task_name=args.task, seed=args.seed)
