import numpy as np
import torch
import torch.nn as nn
import gym
import hirid_env
import sepsis_env
import argparse
import os
import d4rl
import numpy as np

import d3rlpy
from sklearn.model_selection import train_test_split

from delphicORL.utils import utils, logging, scorers, data
from delphicORL.algos.offlineRL.delphic_cql import DelphicCQL
from delphicORL.algos.confounding.worldmodel_learner import WorldModelLearner


# Root directory holding pretrained world-model checkpoints, laid out as
# <DATAPATH>/<env>/seed_<seed>/<run>/final.th. May be overridden with the
# DELPHIC_ORL_DATAPATH environment variable; the default '' makes the layout
# relative to the current working directory.
DATAPATH = os.environ.get("DELPHIC_ORL_DATAPATH", "")


def train_delphic_ORL(args, env, wmlearner=None):
    """Train a DelphicCQL agent on the offline dataset named by ``args.env``.

    The CQL critic receives the world model's ``predict_delphic_uncertainty``
    as its uncertainty function, weighted by ``args.penalty_hyp``.

    Args:
        args: Parsed command-line namespace. Reads ``env``, ``seed``,
            ``cql_lr``, ``batch_size``, ``alpha``, ``target_update_interval``,
            ``penalty_hyp`` and ``max_timesteps`` (plus whatever
            ``logging.setup_logging`` reads).
        env: Environment created by the caller. NOTE: it is superseded by the
            environment returned from ``d3rlpy.datasets.get_dataset``, which is
            the one that gets seeded and used by the scorers.
        wmlearner: World-model learner whose ``worldmodel`` exposes
            ``predict_delphic_uncertainty``. When omitted, the module-level
            ``wmlearner`` (created in the ``__main__`` section) is used, which
            preserves the original calling convention.

    Raises:
        NameError: if no ``wmlearner`` is passed and none exists at module level.
    """
    if wmlearner is None:
        # Fall back to the global set up by the script entry point; fail early
        # instead of after logging/dataset setup.
        wmlearner = globals().get('wmlearner')
        if wmlearner is None:
            raise NameError(
                "train_delphic_ORL needs a world-model learner: pass "
                "`wmlearner=` or define a module-level `wmlearner`.")

    custom_logger, _ = logging.setup_logging(args)
    # The dataset loader returns its own environment; it replaces ``env``.
    dataset, env = d3rlpy.datasets.get_dataset(args.env)
    train_episodes, test_episodes = data.split_datasets(dataset, test_size=0.2, seed=args.seed)

    d3rlpy.seed(args.seed)
    env.seed(args.seed)

    cql = DelphicCQL(
        # Coerce explicitly: CLI-supplied values may arrive as strings.
        learning_rate=float(args.cql_lr),
        optim_factory=d3rlpy.models.optimizers.AdamFactory(eps=1e-2 / 32),
        batch_size=args.batch_size,
        alpha=float(args.alpha),
        q_func_factory=d3rlpy.models.q_functions.QRQFunctionFactory(
                           n_quantiles=200),
        target_update_interval=int(args.target_update_interval),
        penalty_hyperparameter=args.penalty_hyp,
        compute_uncertainty=wmlearner.worldmodel.predict_delphic_uncertainty,
        use_gpu=True)

    scor_dict = {'disc_act_acc': scorers.d3rlpy_discrete_action_match_scorer}
    if 'hirid' not in args.env:
        # Environments that can be simulated are scored by online rollouts.
        scor_dict['rollout/return_mean'] = scorers.d3rlpy_evaluate_on_environment(env)
    else:
        # HiRID cannot be rolled out, so score with off-policy evaluation over
        # the training split of the expert trajectories.
        expert_trajs = data.get_imitation_dataset(args.env)
        train_trajs, _ = data.split_datasets(expert_trajs, test_size=0.2, seed=args.seed)
        episode_infos = [traj.infos for traj in train_trajs]
        scor_dict['rollout/ope_return_mean'] = scorers.d3rlpy_ope_scorer(env, episode_infos)

    n_steps = int(args.max_timesteps)
    cql.fit(train_episodes,
            eval_episodes=test_episodes,
            n_steps=n_steps,
            # ~100 epochs; clamp so tiny runs do not get a zero-length epoch.
            n_steps_per_epoch=max(1, int(args.max_timesteps / 100)),
            scorers=scor_dict,
            custom_logger=logging.D3RLPyLogger(custom_logger))


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    # Experiment
    parser.add_argument("--command_name", default="offline_conf")   # Run name; '/penalty_<p>' is appended below
    parser.add_argument("--env", default="hopper-medium-v0")        # Gym environment / dataset name
    parser.add_argument("--seed", default=0, type=int)              # Sets Gym, PyTorch and Numpy seeds
    # NOTE: argparse does not apply `type` to non-string defaults, so the
    # defaults of the next two stay floats (5000.0, 1000000.0).
    parser.add_argument("--eval_freq", default=5e3, type=int)       # How often (time steps) we evaluate
    parser.add_argument("--max_timesteps", default=1e6, type=int)   # Total gradient steps for training

    # TD3 (apart from --alpha, not read by train_delphic_ORL)
    parser.add_argument("--expl_noise", default=0.1, type=float)    # Std of Gaussian exploration noise
    parser.add_argument("--batch_size", default=256, type=int)      # Batch size for both actor and critic
    parser.add_argument("--discount", default=0.99, type=float)     # Discount factor
    parser.add_argument("--tau", default=0.005, type=float)         # Target network update rate
    parser.add_argument("--policy_noise", default=0.2, type=float)  # Noise added to target policy during critic update
    parser.add_argument("--noise_clip", default=0.5, type=float)    # Range to clip target policy noise
    parser.add_argument("--policy_freq", default=2, type=int)       # Frequency of delayed policy updates
    parser.add_argument("--alpha", default=2.5, type=float)         # Passed to DelphicCQL as `alpha`
    # NOTE(review): a CLI value here is kept as a string (so "False" is truthy).
    parser.add_argument("--normalize", default=True)

    # CQL
    parser.add_argument("--target_update_interval", default=2000, type=int)
    parser.add_argument("--cql_lr", default=5e-5, type=float)

    # Uncertainty penalty
    parser.add_argument("--penalty_hyp", default=1.0, type=float)

    # Logging
    parser.add_argument("--extra_log_rep", type=str, default="")

    args = parser.parse_args()

    print("---------------------------------------")
    print(f"Policy: {args.command_name}, Env: {args.env}, Seed: {args.seed}")
    print("---------------------------------------")

    env = gym.make(args.env)
    utils.set_seeds(args.seed, env)

    ############################
    # World model providing the delphic uncertainty penalty. It must stay a
    # module-level name: train_delphic_ORL looks up `wmlearner` globally.

    expert_trajs = data.get_imitation_dataset(args.env)
    # NOTE(review): unlike the splits in train_delphic_ORL, this one passes no
    # test_size/seed — confirm it matches the split used to train the loaded
    # world model.
    expert_trajs, test_trajs = data.split_datasets(expert_trajs)

    wmlearner = WorldModelLearner(
            observation_space=env.observation_space,
            action_space=env.action_space,
            demonstrations=expert_trajs,
            test_demonstrations=test_trajs,
            lstm=True,
            max_len=20,
            batch_size=8,
            wm_target_dim=1,
            wm_klweight=1e-2,
            no_train_q_func=True
        )

    # Replace the freshly built world model with the pretrained checkpoint
    # <DATAPATH>/<env>/seed_<seed>/<run>/final.th, taking the
    # lexicographically last run directory.
    q_model_path = os.path.join(DATAPATH, args.env, f'seed_{args.seed}')
    print(q_model_path)
    run_dirs = sorted(
        entry for entry in os.listdir(q_model_path)
        if os.path.isdir(os.path.join(q_model_path, entry)))
    if not run_dirs:
        raise FileNotFoundError(f"No world-model run directories found in {q_model_path!r}")
    # torch.load unpickles the checkpoint; only load trusted files.
    wmlearner.worldmodel = torch.load(os.path.join(q_model_path, run_dirs[-1], 'final.th'))

    args.command_name += f'/penalty_{args.penalty_hyp}'
    train_delphic_ORL(args, env)
    