
from pathlib import Path
import random
import matplotlib
matplotlib.use('Agg') 
import matplotlib.pyplot as plt
import json
import numpy as np
from sb3_contrib import MaskablePPO
from args import Args
from environment import BScheduler
from stable_baselines3.common.env_util import SubprocVecEnv, make_vec_env, DummyVecEnv
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
from stable_baselines3.common.callbacks import CheckpointCallback
import gymnasium as gym
import wandb
from fast_feature_extractor import OptimizedHierarchicalBPlusFeatureExtractor
import os

def get_env(args: Args):
    """Build one (non-vectorized) environment and reset it with ``args.seed``.

    Args:
        args: experiment configuration forwarded to the environment factory.

    Returns:
        The freshly created environment, already reset once.
    """
    make_env = get_env_func(args)
    environment = make_env()
    # The return value of reset() is discarded; the call is made for seeding.
    environment.reset(seed=args.seed)
    return environment

def get_env_func(args: Args):
    """Return a zero-argument factory that creates a ``BScheduler-v0`` env.

    The factory closes over ``args`` and builds a new environment each time it
    is called. That makes it suitable as the ``env_id`` callable for
    ``make_vec_env``.

    Args:
        args: experiment configuration passed to ``gym.make`` as ``args=``.

    Returns:
        A callable taking no arguments and returning a new environment.
    """
    def make_env():
        return gym.make("BScheduler-v0", args=args, render_mode=None)

    return make_env

def get_vec_env(args: Args, n_envs_override: "int | None" = None):
    """Build a ``DummyVecEnv`` of ``BScheduler-v0`` environments.

    Args:
        args: experiment configuration; ``args.n_envs`` is the default number
            of environments.
        n_envs_override: if not ``None``, the number of environments to build
            instead of ``args.n_envs``.

    Returns:
        The vectorized environment produced by ``make_vec_env``.

    Raises:
        ValueError: if the resolved number of environments is less than 1.
    """
    # Use ``is None`` rather than ``or`` so that an explicit 0 is not silently
    # replaced by ``args.n_envs``; it is rejected below instead.
    n_envs = args.n_envs if n_envs_override is None else n_envs_override
    if n_envs < 1:
        raise ValueError(f"n_envs must be >= 1, got {n_envs}")
    env_func = get_env_func(args)
    return make_vec_env(env_func, n_envs=n_envs, vec_env_cls=DummyVecEnv)

def save_args(args: Args, folder: Path):
    """Write ``args`` to ``<folder>/args.json``, creating ``folder`` if needed.

    Args:
        args: configuration object whose ``save`` method receives the target path.
        folder: destination directory (``str`` or ``Path``); missing parent
            directories are created.
    """
    target_dir = Path(folder)
    target_dir.mkdir(exist_ok=True, parents=True)
    args.save(target_dir.joinpath("args.json"))

def load_args(folder: str) -> Args:
    """Create a default ``Args`` and populate it from ``<folder>/args.json``.

    Args:
        folder: directory that contains an ``args.json`` file.

    Returns:
        The populated ``Args`` instance.
    """
    args_path = os.path.join(folder, "args.json")
    loaded = Args()
    loaded.load(args_path)
    return loaded

def load_args_env(folder: str):
    """Load saved args from ``folder`` and build a one-environment vec env.

    Args:
        folder: directory containing ``args.json``.

    Returns:
        A ``(args, env)`` tuple, where ``env`` is built with ``n_envs_override=1``.
    """
    loaded_args = load_args(folder)
    single_env = get_vec_env(loaded_args, 1)
    return loaded_args, single_env

def load_model(path_to_zip: str):
    """Restore a ``MaskablePPO`` model from a saved archive.

    Args:
        path_to_zip: path to the archive produced by ``MaskablePPO.save``
            (presumably a ``.zip`` file, as the parameter name suggests).

    Returns:
        The loaded ``MaskablePPO`` model.
    """
    return MaskablePPO.load(path_to_zip)

def _build_policy_kwargs(args: Args) -> dict:
    """Build the ``policy_kwargs`` dict passed to ``MaskablePPO``.

    Always sets ``net_arch`` from ``args.s_net_arch``. When
    ``args.feature_extractor == "tfh_fast"``, it also wires in
    ``OptimizedHierarchicalBPlusFeatureExtractor`` and its constructor kwargs.

    Raises:
        ValueError: if ``args.feature_extractor`` is neither ``"tfh_fast"``
            nor ``"none"``.
    """
    policy_kwargs = dict(
        net_arch=args.s_net_arch,
    )

    if args.feature_extractor == "tfh_fast":  # Fast Hierarchical Transformer Feature Extractor
        policy_kwargs['features_extractor_class'] = OptimizedHierarchicalBPlusFeatureExtractor
        policy_kwargs['features_extractor_kwargs'] = dict(
            feature_dim=args.s_transformer_features_dim,
            values_per_node=args.env_max_values_per_node,
            # Size of the operation vocabulary: inserts plus deletes.
            num_ops=args.env_num_inserts + args.env_num_deletes,
            num_heads=args.s_transformer_nhead,
            dropout=0.1,
            max_levels=5,
        )
    elif args.feature_extractor == 'none':
        print("Training without special feature extractor")
    else:
        raise ValueError(f"Unknown feature_extractor: {args.feature_extractor}")

    return policy_kwargs


def _build_eval_and_checkpoint_callbacks(args: Args, eval_env, tensorboard_log: Path,
                                         best_model_save_path: Path) -> list:
    """Create the periodic-evaluation and periodic-checkpoint callbacks.

    The evaluation callback runs ``args.s_n_eval_episodes`` deterministic
    episodes on ``eval_env`` every ``args.s_eval_freq`` calls. It saves the
    best model under ``best_model_save_path`` and writes its logs to
    ``tensorboard_log``. Checkpoints go to
    ``best_model_save_path / "checkpoints"`` every
    ``args.checkpoint_callback_freq`` calls.

    Returns:
        ``[eval_callback, checkpoint_callback]``, in that order.
    """
    eval_callback = MaskableEvalCallback(
        eval_env,
        best_model_save_path=best_model_save_path,
        log_path=tensorboard_log,
        eval_freq=args.s_eval_freq,
        n_eval_episodes=args.s_n_eval_episodes,
        deterministic=True,
        render=False,
    )

    checkpoint_callback = CheckpointCallback(
        save_freq=args.checkpoint_callback_freq,
        save_path=best_model_save_path / "checkpoints",
        name_prefix="rl_model_checkpoint",
    )

    return [eval_callback, checkpoint_callback]


def create_model(args: Args, env: BScheduler, eval_env: BScheduler, *, gamma: float = 0.999):
    """Seed RNGs, set up logging/callbacks and build a ``MaskablePPO`` model.

    Args:
        args: experiment configuration (seed, PPO hyper-parameters, network
            architecture, feature-extractor choice, logging options).
        env: environment to train on. The caller in this file passes the
            vectorized env from ``get_vec_env``, despite the ``BScheduler``
            annotation.
        eval_env: environment used by the evaluation callback.
        gamma: discount factor for PPO. Defaults to the previously
            hard-coded 0.999.

    Returns:
        ``(model, callbacks, best_model_save_path)``:
        - ``callbacks`` holds, in order, an optional ``WandbCallback`` (only
          when ``args.no_wandb`` is false), the evaluation callback and the
          checkpoint callback. Presumably the caller passes these to
          ``model.learn``.
        - ``best_model_save_path`` is the parent of the tensorboard log dir.
          That is ``local`` without wandb, or the wandb run dir with it.

    Raises:
        ValueError: if ``args.feature_extractor`` is unknown. This is raised
            before any files are written.
    """
    print(f"Creating model with args: {args}")
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Validate the feature-extractor choice before any side effects (wandb
    # callback, directory creation, args.json). A config typo then fails fast
    # and leaves no half-initialized run directory behind.
    policy_kwargs = _build_policy_kwargs(args)

    tensorboard_log = Path("local/logs")
    callbacks = []
    if not args.no_wandb:
        # Imported lazily: the sb3 integration is only needed when wandb logging is on.
        import wandb
        from wandb.integration.sb3 import WandbCallback
        callbacks.append(WandbCallback(verbose=2))
        tensorboard_log = Path(wandb.run.dir) / "logs"
        print(tensorboard_log)

    best_model_save_path = tensorboard_log.parent
    print("best_model_save", best_model_save_path)
    save_args(args, best_model_save_path)

    callbacks.extend(
        _build_eval_and_checkpoint_callbacks(args, eval_env, tensorboard_log, best_model_save_path)
    )

    model = MaskablePPO(
        "MlpPolicy",
        env,
        n_epochs=args.s_n_epochs,
        learning_rate=args.s_learning_rate,
        verbose=1,
        gamma=gamma,
        device=args.device,
        policy_kwargs=policy_kwargs,
        tensorboard_log=tensorboard_log,
        seed=args.seed,
        batch_size=args.s_batch_size,
    )

    print("Model policy:", model.policy)

    return model, callbacks, best_model_save_path