#!/usr/bin/env python
import argparse
from curses import raw
import shutil
import gym
import os

import ray
from ray import tune
from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.env.policy_server_input import PolicyServerInput
from ray.rllib.examples.custom_metrics_and_callbacks import MyCallbacks
from ray.tune.logger import pretty_print
import yaml
from .rlserver import INPUT_SIZE

SERVER_ADDRESS = "localhost"
CHECKPOINT_FILE = "./weights/last_checkpoint_{}.out"
CHECKPOINTS_DIR = "./weights/checkpoints/"
version = 0


def save_cpp_model(export_dir):
    """Export the current trainer's model into ``<export_dir>/<version>``.

    Relies on the module-level ``trainer`` and ``version`` globals. Any
    directory already present at the target location is deleted first, so
    the export always lands in a fresh directory.

    Args:
        export_dir: Base directory under which the versioned export
            directory is created.
    """
    target_dir = f'{export_dir}/{version}'

    # Drop a previous export of the same version before writing a new one.
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)

    trainer.export_model(export_dir=target_dir, export_formats="model")



def reload_weights(trainer, checkpoint_path, ray_weights_path):
    """Restore ``trainer`` from the last recorded checkpoint, if there is one.

    ``checkpoint_path`` is a small text file whose content is the path of the
    checkpoint to restore. Nothing is restored (and the module-level
    ``version`` is left untouched) when that file is missing, when
    ``ray_weights_path`` is missing or empty, or when the recorded checkpoint
    path does not exist.

    After a successful restore, the module-level ``version`` is set to the
    highest integer-named entry found in ``ray_weights_path``. Entries whose
    names are not plain integers are ignored; if there are none, ``version``
    keeps its previous value.

    Args:
        trainer: Object exposing ``restore(path)``.
        checkpoint_path: Path of the text file holding the checkpoint path.
        ray_weights_path: Directory containing the versioned model exports.

    Raises:
        Exception: Any error raised while reading the pointer file or
            restoring is logged and re-raised unchanged.
    """
    global version

    try:
        if not (
            os.path.exists(checkpoint_path)
            and os.path.exists(ray_weights_path)
            and os.listdir(ray_weights_path)
        ):
            return

        # Close the file deterministically and drop surrounding whitespace:
        # a trailing newline in the pointer file would otherwise make a
        # perfectly valid checkpoint path look missing.
        with open(checkpoint_path) as pointer_file:
            restore_path = pointer_file.read().strip()

        if not os.path.exists(restore_path):
            print('could not find a version to restore', restore_path)
            return

        print('restoring from ' + restore_path)
        trainer.restore(restore_path)
    except Exception as e:
        print('error while reloading weights', e)
        # Bare raise keeps the original traceback intact.
        raise

    if os.path.exists(ray_weights_path):
        # os.listdir() returns entries in arbitrary order, so pick the
        # newest export explicitly instead of trusting the first entry, and
        # skip stray non-numeric files that int() could not parse.
        exported_versions = [
            int(name)
            for name in os.listdir(ray_weights_path)
            if name.isdecimal()
        ]
        if exported_versions:
            version = max(exported_versions)


if __name__ == "__main__":
    # Script entry point: build a trainer fed by PolicyServerInput, restore
    # the last recorded checkpoint (if any) and export the model.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--yaml-settings",
        default='./src/settings.yml',
        help="Path of the YAML settings file.",
    )
    parser.add_argument(
        "--run",
        default="DQN",
        # NOTE(review): the trainer below is pinned to "DQN" and the config
        # carries replay-buffer / n-step options, so this value is currently
        # not used. Confirm before wiring it through.
        help="Algorithm name (currently unused; the trainer is always DQN).",
    )
    parser.add_argument(
        "--port",
        type=int,
        # _input() below reads args.port; without this argument it failed
        # with AttributeError. 9900 is the base port used by RLlib's serving
        # example; confirm it matches what the clients connect to.
        default=9900,
        help="Base port of the policy servers; each additional worker "
             "listens on the next port.",
    )
    args = parser.parse_args()

    with open(args.yaml_settings, 'r') as fh:
        yaml_settings = yaml.safe_load(fh)

    def _input(ioctx):
        """Return a PolicyServerInput for workers that must listen, else None."""
        # We are remote worker or we are local worker with num_workers=0:
        # Create a PolicyServerInput.
        if ioctx.worker_index > 0 or ioctx.worker.num_workers == 0:
            return PolicyServerInput(
                ioctx,
                SERVER_ADDRESS,
                # Worker 0 and worker 1 both map to the base port; worker k
                # (k > 1) maps to base port + k - 1.
                args.port + ioctx.worker_index -
                (1 if ioctx.worker_index > 0 else 0),
            )
        # No InputReader (PolicyServerInput) needed.
        else:
            return None

    config = {
        # Indicate that the Trainer we setup here doesn't need an actual env.
        # Allow spaces to be determined by user (see below).
        "env": None,
        "observation_space": gym.spaces.Box(float("-inf"), float("inf"), (INPUT_SIZE,)),
        "action_space": gym.spaces.Discrete(yaml_settings["num_of_qualities"]),
        "input": _input,
        "num_workers": yaml_settings["experiments"][0]["num_servers"],
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],
        # Create a "chatty" client/server or not.
        # "callbacks": MyCallbacks if args.callbacks_verbose else None,
        "framework": 'torch',
        "log_level": "ERROR",
        "model": {
            "fcnet_hiddens": [256],
            "fcnet_activation": "relu",
        },
        # "num_atoms": 1,
        # "v_min": -3000.0,
        # "v_max": 60.0,
        # "dueling": False,
        # "lr": 1e-3,
        # "timesteps_per_iteration": 4000,
        # "target_network_update_freq": 8000,
        "min_time_s_per_reporting": 60*10,
        # "min_sample_timesteps_per_reporting": 4000,
        "replay_buffer_config": {
            "capacity": 100000,
            # "replay_batch_size": 32,
            "replay_sequence_length": 1,
        },
        "learning_starts": 20000,
        "train_batch_size": 64,
        # "rollout_fragment_length": 5, # /ray/rllib/agents/trainer.py:155
        "n_step": 5, # https://paperswithcode.com/method/n-step-returns
        "ignore_worker_failures": True,
        "recreate_failed_workers": True
    }

    trainer_cls = get_trainer_class("DQN")
    # `trainer` must stay a module-level name: save_cpp_model() reads it.
    trainer = trainer_cls(config=config)

    checkpoint_path = CHECKPOINT_FILE.format("DQN")
    # Only the first experiment's settings are used here.
    ray_weights_path = yaml_settings["experiments"][0]["fingerprint"]["abr_config"]["weights_dir"]

    reload_weights(trainer, checkpoint_path, ray_weights_path)
    save_cpp_model(export_dir=ray_weights_path)
