import os
import argparse
import time
import numpy as np
import os
import random

from distutils.util import strtobool
from d3rlpy.algos import RandomPolicy
from d3rlpy.datasets import MDPDataset
from d3rlpy.preprocessing.scalers import StandardScaler
from d3rlpy.preprocessing.action_scalers import MinMaxActionScaler
from d3rlpy.preprocessing.reward_scalers import MinMaxRewardScaler
from d3rlpy.models.encoders import VectorEncoderFactory
from d3rlpy.algos import DiscreteCQL, DiscreteSAC, DiscreteBCQ, DiscreteBC, DoubleDQN, DQN
from torch.distributions.categorical import Categorical

from torch.utils.tensorboard import SummaryWriter
import torch
import subprocess
import os
# Put this file's parent directory on sys.path so `from utils import *` below resolves.
import sys
# folder = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
# print(sys.path)
from utils import *

# Torch device used for the action-value tensors built during evaluation:
# CUDA when a GPU is visible to torch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
parser = argparse.ArgumentParser("Testing offline RL")
parser.add_argument('--data', type=str, default="../trajectories/offline_trajectories_full.h5")
parser.add_argument('--model', type=str, default="{}".format('model_sklled-smoke.pt'))
parser.add_argument('--algo', type=str, default="DiscreteSAC")
parser.add_argument('--use_gpu', type=bool, default=False)
parser.add_argument('--weight_temp', type=float, default=100.0)
parser.add_argument('--n_epochs', type=int, default=50)
parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
                        help="if toggled, this experiment will be tracked with Weights and Biases")
# parse in the hyperparameters
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--actor_learning_rate', type=float, default=0.00001)
parser.add_argument('--critic_learning_rate', type=float, default=0.00001)
parser.add_argument('--temp_learning_rate', type=float, default=0.0001)
parser.add_argument('--n_critics', type=int, default=6)
parser.add_argument('--run_name', type=str, default=None)

parser.add_argument("--disable-cartesian-product", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
            help="if toggled, cartesian product will be disabled")
parser.add_argument("--enable-bushy", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, allow bushy join trees")
# seed
parser.add_argument('--seed', type=int, default=1,
                        help="seed of the experiment")
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
                        help="if toggled, `torch.backends.cudnn.deterministic=False`")
parser.add_argument("--total-timesteps", type=int, default=int(1),
        help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=8e-4,
    help="the learning rate of the optimizer")
parser.add_argument("--num-envs", type=int, default=4,
    help="the number of parallel game environments")
parser.add_argument("--num-steps", type=int, default=128,
    help="the number of steps to run in each environment per policy rollout")
# args = parser.parse_args()
# Read data

args_parse = parser.parse_args()
algo = args_parse.algo
# dataset = MDPDataset(
#     observations=data["observations"],
#     actions=data["actions"],
#     rewards=data["rewards"],
#     terminals=data["terminals"],
#     discrete_action=False,
# )

# data = np.load(args_parse.data)
# dataset = MDPDataset(
#     observations=data["observations"],
#     actions=data["actions"],
#     rewards=data["rewards"],
#     terminals=data["terminals"],
#     discrete_action=False,
# )

dataset = MDPDataset.load(args_parse.data)
# setup CQL algorithm
batch_size = args_parse.batch_size
actor_learning_rate = args_parse.actor_learning_rate
critic_learning_rate = args_parse.critic_learning_rate
temp_learning_rate = args_parse.temp_learning_rate
n_critics = args_parse.n_critics
n_epochs = args_parse.n_epochs
weight_temp = args_parse.weight_temp
use_gpu = args_parse.use_gpu

# Instantiate the d3rlpy algorithm named by --algo. Every learned algorithm
# normalises observations with a StandardScaler fitted on the dataset and
# honours --use_gpu. Unknown names raise NotImplementedError.
if algo == "CQL" or algo == "DiscreteCQL":
    alg = DiscreteCQL(
        scaler=StandardScaler(dataset),
        use_gpu=use_gpu,
    )
elif algo == "DiscreteBC":
    # NOTE(review): d3rlpy's DiscreteBC documents a single `encoder_factory`
    # argument; confirm these actor_/critic_ factories are not silently
    # ignored. Left unchanged so the built network keeps matching checkpoints
    # trained with this same configuration.
    alg = DiscreteBC(
        scaler=StandardScaler(dataset),
        use_gpu=use_gpu,
        actor_encoder_factory=VectorEncoderFactory(use_batch_norm=True, dropout_rate=0.5, hidden_units=[16, 16]),
        critic_encoder_factory=VectorEncoderFactory(use_batch_norm=True, dropout_rate=0.5, hidden_units=[16, 16]),
    )
elif algo == "DiscreteSAC":
    # https://d3rlpy.readthedocs.io/en/v0.41/references/generated/d3rlpy.algos.DiscreteSAC.html#
    alg = DiscreteSAC(
        scaler=StandardScaler(dataset),
        critic_learning_rate=critic_learning_rate,
        actor_learning_rate=actor_learning_rate,
        temp_learning_rate=temp_learning_rate,
        n_critics=n_critics,
        n_epochs=n_epochs,
        use_gpu=use_gpu,
    )
elif algo == "DiscreteBCQ":
    alg = DiscreteBCQ(
        scaler=StandardScaler(dataset),
        n_critics=n_critics,
        use_gpu=use_gpu,
    )
elif algo == "DoubleDQN":
    alg = DoubleDQN(
        scaler=StandardScaler(dataset),
        use_gpu=use_gpu,
    )
elif algo == "DQN":
    alg = DQN(
        scaler=StandardScaler(dataset),
        use_gpu=use_gpu,
    )
elif algo == "RandomPolicy":
    # NOTE(review): RandomPolicy samples continuous actions (here mapped to
    # [1, 100] by the action scaler); confirm whether a discrete random
    # baseline was intended for this discrete-action evaluation.
    alg = RandomPolicy(
        distribution='uniform',
        action_scaler=MinMaxActionScaler(minimum=1, maximum=100),
        normal_std=1.0,
    )
else:
    raise NotImplementedError("Algorithm not implemented or not suitable for discrete actions or haven't been implemented!")

# Build the algorithm's implementation from the dataset; d3rlpy requires this
# before weights can be loaded.
alg.build_with_dataset(dataset)

# Every algorithm except the untrained RandomPolicy restores its weights from
# the checkpoint given by --model.
if algo not in {"RandomPolicy"}:
    model = args_parse.model
    alg.load_model(model)
    print("Model loaded:", model)

# Seed Python's `random`, NumPy's global RNG and PyTorch with --seed, and set
# cuDNN determinism from --torch-deterministic.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(args_parse.seed)
torch.backends.cudnn.deterministic = args_parse.torch_deterministic

# Query-id splits from utils; presumably keyed "train"/"val"/"test" (those keys
# are indexed when the environments are built) -- confirm in utils.
split = get_query_ids_split()


# Experiment tracking: TensorBoard logs always go to runs/<run_name>.
# Weights & Biases is used only when --track is given. `wandb` stays imported
# unconditionally so the script's final `wandb.finish()` still resolves; that
# call does nothing when no run was started.
import datetime

import wandb

now = datetime.datetime.now()
timestamp = now.strftime("%Y-%m-%d-%H-%M-%S")

run_name = args_parse.run_name if args_parse.run_name else f"{algo}_{timestamp}"

if args_parse.track:
    # Patch TensorBoard before wandb.init and before the SummaryWriter below
    # is created, so the writer's event files are mirrored to W&B.
    wandb.tensorboard.patch(root_logdir="runs")
    wandb.init(
        project='offline-rl-online-eval',
        config=vars(args_parse),
        name=run_name,
        save_code=True,
    )

writer = SummaryWriter(f"runs/{run_name}")
# Record every CLI argument as a markdown table.
writer.add_text(
    "hyperparameters",
    "|param|value|\n|-|-|\n%s"
    % ("\n".join([f"|{key}|{value}|" for key, value in vars(args_parse).items()])),
)



# Environment over the training queries (make_env returns a factory; call it).
train_env = make_env(args_parse.seed, split["train"], args_parse.disable_cartesian_product, args_parse.enable_bushy)()

# One evaluation environment per split, seeded with seed + 0/1/2 for
# train/val/test respectively.
eval_envs_map = {}
for offset, split_name in enumerate(("train", "val", "test")):
    env_factory = make_env(
        seed=args_parse.seed + offset,
        query_ids=split[split_name],
        disable_cartesian_product=args_parse.disable_cartesian_product,
        enable_bushy=args_parse.enable_bushy,
    )
    eval_envs_map[split_name] = env_factory()

# Start the game
obs, info = train_env.reset()
poss_actions_mask = info["action_mask"]
best_agent_stats = {"val_returns": -float("inf")}


def agent(obs, actions_mask):
    """Greedy action selection for `evaluate_agent`.

    Scores every action index with `alg.predict_value` and returns the
    highest-valued action allowed by `actions_mask`. If that path raises
    (e.g. algorithms without value estimates, such as DiscreteBC), falls back
    to `alg.predict`. If the predicted action is masked out, a uniformly random
    allowed action is chosen instead.

    Args:
        obs: observation forwarded unchanged to `alg.predict_value` / `alg.predict`.
        actions_mask: one entry per action; truthy means the action is allowed.

    Returns:
        The chosen action, as a numpy value from whichever branch ran.
    """
    # Coerce to bool so `~` is a logical NOT. On a list it raises, and on a 0/1
    # int array it is a bitwise NOT (-1/-2) that masks the wrong entries.
    actions_mask = np.asarray(actions_mask, dtype=bool)
    # Candidate action indices 0..n-1, shaped (1, n_actions).
    actions = np.arange(len(actions_mask)).reshape(1, -1)
    try:
        action_values = alg.predict_value(obs, actions).reshape(-1)
        action_values[~actions_mask] = -float("inf")
        policy_dist = Categorical(logits=torch.Tensor(action_values).to(device))
        action = torch.argmax(policy_dist.probs).cpu().numpy()
    except Exception:
        # Best-effort fallback kept from the original, but no longer catching
        # KeyboardInterrupt/SystemExit the way a bare `except:` did.
        action = alg.predict(obs)
        if not actions_mask[action]:
            action = np.random.choice(np.where(actions_mask)[0])
    return action


for global_step in range(args_parse.total_timesteps):
    print("Step:", global_step)
    eval_results = evaluate_agent(eval_envs_map, agent, 'cpu', writer, global_step)

# Flush pending TensorBoard events before closing the (optional) W&B run.
writer.close()
wandb.finish()