import argparse
import json
import os
import time
import torch
import ray
from ray import tune
import pickle
from datetime import datetime
from ray.tune.schedulers import ASHAScheduler, MedianStoppingRule
from ray.rllib.models import ModelCatalog
from ray.rllib.agents.impala.impala import ImpalaTrainer
from ray.rllib.agents.impala.vtrace_torch_policy import VTraceTorchPolicy
from ray.rllib.agents.dqn.r2d2 import R2D2Trainer
from ray.rllib.agents.dqn.r2d2_torch_policy import R2D2TorchPolicy
from actors.ppo.ppo import return_ppo_workflow
from models.ppo_models import (
    PPOAttention,
    FeedForwardPPO,
    RNNModel,
    RNNRewModel,
)
import numpy as np
from actors.r2d2.r2d2 import return_r2d2_workflow

from config import return_config

from actors.ppo.ppopolicy import PPOTorchCustomPolicy
from actors.dqnpolicy import build_custom_dqn_policy
from actors.r2d2.r2d2policy import R2D2TorchCustomPolicy

# from actors.alphastar.alpha_star import AlphaStarTrainer

from custom_scenario import scenario_generator
from action_dist import TorchCategoricalS, TorchCustomCategorical
from callbacks import (
    make_default_fort_attack_callback,
    make_selfplay_sequential,
)
import os

from ray.tune.logger import DEFAULT_LOGGERS

from test_utilities import RolloutSaver, rollout___
from starlette.requests import Request
from ray import serve

# Restrict CUDA to device indices 0-7 for this process.
# NOTE(review): this unconditionally overwrites any CUDA_VISIBLE_DEVICES value
# inherited from the calling environment — confirm that is intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"


# To run this script under a virtual X display (e.g. on a machine without one):
#   xvfb-run -s "-screen 0 1400x900x24" python <your_script.py>
from arguments import create_args, create_parser


if __name__ == "__main__":
    # Parse the run options; everything below is driven by `args`.
    args = create_args()
    # The two framework flags are mutually exclusive. Raise explicitly instead
    # of using `assert`, which is stripped when Python runs with -O.
    if args.torch and args.mixed_torch_tf:
        raise ValueError("Use either --torch or --mixed-torch-tf, not both!")
    ray.init()
    if args.serve:
        # Start Ray Serve with its HTTP endpoint bound to 0.0.0.0:23052.
        serve.start(http_options={"host": "0.0.0.0", "port": 23052})

    # Pick the policy class that matches the requested algorithm.
    if args.algo == "ppo":
        TorchCustomPolicy = PPOTorchCustomPolicy
    elif args.algo == "dqn":
        TorchCustomPolicy = build_custom_dqn_policy(args)
    elif args.algo == "qmix":
        # QMIX is given no custom policy class.
        TorchCustomPolicy = None
    elif args.algo == "impala":
        TorchCustomPolicy = VTraceTorchPolicy
    elif args.algo == "r2d2":
        # The stock R2D2 policy is used here; R2D2TorchCustomPolicy is imported
        # but intentionally not wired in.
        TorchCustomPolicy = R2D2TorchPolicy
    elif args.algo == "alphastar":
        # The AlphaStarTrainer import is commented out at the top of this file,
        # so fail with a clear message instead of a NameError.
        raise NotImplementedError(
            "algo 'alphastar' is unavailable: the AlphaStarTrainer import is "
            "commented out in this script."
        )
    else:
        # Without this branch TorchCustomPolicy would be unbound below.
        raise ValueError(
            f"Unsupported algo {args.algo!r}; expected one of: "
            "ppo, dqn, qmix, impala, r2d2"
        )

    # Build the environment config and the multi-agent policy setup.
    env_config, policy_mapping_fn, policies = scenario_generator(
        args, TorchCustomPolicy
    )

    # Pick the trainer / training workflow that matches the requested algorithm.
    if args.algo == "ppo":
        custom_training_workflow = return_ppo_workflow(args, TorchCustomPolicy)
    elif args.algo == "impala":
        custom_training_workflow = ImpalaTrainer
    elif args.algo == "r2d2":
        custom_training_workflow = return_r2d2_workflow(args, TorchCustomPolicy)
    elif args.algo in ("dqn", "qmix", "alphastar"):
        # The builders these branches relied on (return_dqn_workflow,
        # return_qmix_workflow, AlphaStarLeagueBuilder) are not imported in this
        # file, so calling them would raise NameError. Fail with a clear message.
        raise NotImplementedError(
            f"No training workflow is available for algo {args.algo!r}: its "
            "workflow builder is not imported in this script."
        )
    else:
        # Without this branch custom_training_workflow would be unbound below.
        raise ValueError(
            f"Unsupported algo {args.algo!r}; expected one of: ppo, impala, r2d2"
        )

    # Choose the callback factory: an explicit "sequential" self-play request
    # takes priority, otherwise fall back to the scenario's default callback.
    if args.callback == "sequential":
        custom_callback = make_selfplay_sequential(args)
    elif args.scenario == "fort_attack":
        custom_callback = make_default_fort_attack_callback(args)
    elif args.scenario == "atari":
        # make_default_atari_callback is not imported in this file, so calling
        # it would raise NameError. Fail with a clear message instead.
        raise NotImplementedError(
            "No default callback is available for scenario 'atari': "
            "make_default_atari_callback is not imported in this script."
        )
    else:
        # Without this branch custom_callback would be unbound below.
        raise ValueError(
            f"No callback for callback={args.callback!r}, "
            f"scenario={args.scenario!r}"
        )

    # Register the custom models and action distributions with ModelCatalog
    # under the string names listed here.
    for model_name, model_cls in (
        ("rnnModel", RNNModel),
        ("RNNRewModel", RNNRewModel),
        ("PPOAttention", PPOAttention),
        ("FeedForwardPPO", FeedForwardPPO),
    ):
        ModelCatalog.register_custom_model(model_name, model_cls)

    for dist_name, dist_cls in (
        ("five_categorical", TorchCustomCategorical),
        ("regular_categorical", TorchCategoricalS),
    ):
        ModelCatalog.register_custom_action_dist(dist_name, dist_cls)

    # Both supported schedulers track the same metric over training iterations.
    scheduler_common = {
        "time_attr": "training_iteration",
        "metric": "policy_reward_mean/good_policy",
        "mode": "max",
    }
    if args.scheduler == "ASHA":
        scheduler = ASHAScheduler(
            max_t=args.stop_iters,
            grace_period=100,
            reduction_factor=3,
            brackets=2,
            **scheduler_common,
        )
    elif args.scheduler == "medianstopping":
        scheduler = MedianStoppingRule(
            grace_period=50,
            min_samples_required=3,
            min_time_slice=0,
            hard_stop=True,
            **scheduler_common,
        )
    else:
        # Any other value means no trial scheduler is passed to Tune.
        scheduler = None

    # Assemble the trainer config from the pieces built above.
    config = return_config(
        args, env_config, custom_callback, policies, policy_mapping_fn
    )

    # Stopping criterion handed to Tune: a fixed number of training iterations.
    stop = {"training_iteration": args.stop_iters}

    # Experiment name: passed to Tune and used in the rollout file name.
    if args.scenario == "fort_attack":
        # Last path component of the experiment YAML, up to its first dot,
        # e.g. "configs/foo.yaml" -> "foo".
        def_name = args.experiment_yaml.split("/")[-1].split(".")[0]
        name = f"fortattack_{def_name}"
    elif args.scenario == "atari":
        name = "atari"
    else:
        # Keep `name` bound for any other scenario instead of failing later
        # with a NameError.
        name = str(args.scenario)

    if not args.test:
        # Training mode: hand the selected workflow to Tune.
        tune_kwargs = {
            "config": config,
            "stop": stop,
            "num_samples": args.num_samples,
            "checkpoint_freq": 50,
            "restore": args.checkpoint,
            "name": name,
            "local_dir": args.local_dir,
            "scheduler": scheduler,
        }
        results = tune.run(custom_training_workflow, **tune_kwargs)
    else:
        # Test mode: apply the evaluation-related overrides on top of the
        # config built for this run.
        for option, setting in (
            ("evaluation_duration", 1),
            ("create_env_on_driver", True),
            ("evaluation_num_workers", 0),
            ("entropy_coeff_schedule", None),
        ):
            config[option] = setting

        # Only the fort_attack scenario is treated as multi-agent.
        multiagent = args.scenario == "fort_attack"

        if args.serve:

            @serve.deployment(
                route_prefix="/fortattack-ppo", ray_actor_options={"num_gpus": 0.5}
            )
            class ServePPOModel:
                """Ray Serve deployment that answers action queries over HTTP.

                On construction it builds a trainer from the module-level
                ``config`` and loads the "good_policy" weights from the pickle
                file at ``args.checkpoint``.
                """

                def __init__(self) -> None:
                    """Build the trainer and load the "good_policy" weights."""
                    config["num_workers"] = 0
                    config["evaluation_num_workers"] = 0
                    # pop() rather than del: these keys may be absent from the
                    # config, and del would then raise KeyError.
                    config.pop("evaluation_interval", None)
                    config.pop("evaluation_num_episodes", None)
                    self.trainer = custom_training_workflow(
                        env="custom_env", config=config
                    )
                    # NOTE(security): unpickling can execute arbitrary code;
                    # only load checkpoint files from a trusted source.
                    with open(args.checkpoint, "rb") as f:
                        model = pickle.load(f)
                    # The "worker" entry is itself a pickled blob holding the
                    # per-policy state.
                    value = model["worker"]
                    weights = pickle.loads(value)
                    weights = weights["state"]["good_policy"]["weights"]

                    # Replace only the weights, keeping the rest of the policy
                    # state as created by the trainer.
                    main_state = self.trainer.get_policy("good_policy").get_state()
                    main_state["weights"] = weights
                    self.trainer.get_policy("good_policy").set_state(main_state)

                async def __call__(self, request: Request):
                    """Compute one action for the JSON payload in ``request``.

                    The payload must contain the keys "obs", "agent_state",
                    "prev_action", "prev_reward", "policy_id",
                    "concept_update", "do_update" and "concepts_to_update".

                    Returns a JSON-serializable dict with "action", "state"
                    and "info" (holding only "concepts_after_softmax").
                    """
                    json_input = await request.json()
                    a_obs = json_input["obs"]
                    agent_state = json_input["agent_state"]
                    prev_action = json_input["prev_action"]
                    prev_reward = json_input["prev_reward"]
                    policy_id = json_input["policy_id"]
                    concept_update = json_input["concept_update"]
                    do_update = json_input["do_update"]
                    concepts_to_update = json_input["concepts_to_update"]
                    # Convert to a float32 tensor reshaped to (1, 1, -1).
                    concept_update = torch.tensor(concept_update, dtype=torch.float32)
                    concept_update = concept_update.reshape((1, 1, -1))
                    # The first two state entries arrive as JSON lists; turn
                    # them back into tensors.
                    agent_state[0] = torch.tensor(agent_state[0])
                    agent_state[1] = torch.tensor(agent_state[1])
                    print(a_obs)

                    a_action, p_state, info = self.trainer.compute_single_action_(
                        a_obs,
                        state=agent_state,
                        prev_action=prev_action,
                        prev_reward=prev_reward,
                        policy_id=policy_id,
                        concept_update=concept_update,
                        do_update=do_update,
                        concepts_to_update=concepts_to_update,
                    )
                    # Convert outputs to plain lists so they can be returned
                    # as JSON.
                    a_action = a_action.tolist()
                    p_state[0] = p_state[0].tolist()
                    p_state[1] = p_state[1].tolist()

                    info_ = {
                        "concepts_after_softmax": info[
                            "concepts_after_softmax"
                        ].tolist(),
                    }
                    return {
                        "action": a_action,
                        "state": p_state,
                        "info": info_,
                    }

            ServePPOModel.deploy()

            # This loop never exits, so in serve mode nothing after this branch
            # is reached; presumably it keeps the driver process (and with it
            # the deployment) alive until the script is killed.
            while True:
                time.sleep(30)
                print("ray serve is still going, kill this script to end it")

        # Build the trainer in this process for the evaluation rollout.
        agent = custom_training_workflow(env="custom_env", config=config)

        if args.checkpoint:
            # NOTE(review): unpickling can execute arbitrary code; only load
            # checkpoint files from a trusted source.
            with open(args.checkpoint, "rb") as ckpt_file:
                checkpoint = pickle.load(ckpt_file)

            # The "worker" entry is itself a pickled blob with per-policy state.
            worker_state = pickle.loads(checkpoint["worker"])
            good_weights = worker_state["state"]["good_policy"]["weights"]

            # Swap in only the weights of "good_policy"; the rest of the policy
            # state stays as created by the trainer.
            policy_state = agent.get_policy("good_policy").get_state()
            policy_state["weights"] = good_weights
            agent.get_policy("good_policy").set_state(policy_state)

        num_steps = int(args.steps)
        num_episodes = int(args.episodes)

        # Resolve the optional video output directory. Expand "~" before
        # creating it so the directory created is the one handed to the rollout.
        video_dir = None
        if args.video_dir:
            video_dir = os.path.expanduser(args.video_dir)
            os.makedirs(video_dir, exist_ok=True)

        # The rollout always runs, whether or not a video directory is set.
        # Results are written to "<rollout_dir><name>.pkl" by RolloutSaver.
        with RolloutSaver(
            f"{args.rollout_dir}{name}.pkl",
            args.use_shelve,
            write_update_file=args.track_progress,
            target_steps=num_steps,
            target_episodes=num_episodes,
            save_info=args.save_info,
        ) as saver:
            rollout___(
                agent,
                "custom_env",
                num_steps,
                num_episodes,
                saver,
                not args.render,
                video_dir,
                multiagent,
                args,
            )

        agent.stop()

    # Shut down the Ray runtime that ray.init() started at the top of main.
    ray.shutdown()
