import argparse
import numpy as np
#import torch
import gym
import hirid_env
import sepsis_env
import os
import d4rl

from delphicORL.utils import utils, data, logging, scorers
from delphicORL.algos.offlineRL import TD3_BC

import d3rlpy




def train_CQL(args, env):
    """Train discrete CQL with d3rlpy on ``args.env`` and save the final policy.

    Args:
        args: parsed command-line namespace. Reads ``env``, ``seed``,
            ``cql_lr``, ``batch_size`` (falls back to 32 when the namespace
            has no such attribute), ``alpha``, ``target_update_interval``
            and ``max_timesteps``; the whole namespace is also passed to
            ``logging.setup_logging``.
        env: not used — it is replaced by the environment returned from
            ``d3rlpy.datasets.get_dataset``. Kept so callers' signatures work.

    The learned policy is written to ``<log_dir>/final.pt``.
    """
    custom_logger, log_dir = logging.setup_logging(args)
    dataset, env = d3rlpy.datasets.get_dataset(args.env)
    train_episodes, test_episodes = data.split_datasets(dataset, test_size=0.2, seed=args.seed)
    # check diff with get_atari_transitions(args.env, fraction=0.1)

    d3rlpy.seed(args.seed)
    env.seed(args.seed)

    cql = d3rlpy.algos.DiscreteCQL(
        # float()/int() coercions guard against CLI flags parsed without a type.
        learning_rate=float(args.cql_lr),
        optim_factory=d3rlpy.models.optimizers.AdamFactory(eps=1e-2 / 32),
        batch_size=int(getattr(args, 'batch_size', 32)),
        alpha=float(args.alpha),
        q_func_factory=d3rlpy.models.q_functions.QRQFunctionFactory(
            n_quantiles=200),
        target_update_interval=int(args.target_update_interval),
        use_gpu=True)

    # Action-match accuracy is always reported; the return metric is an
    # environment rollout for non-HiRID envs and an OPE estimate for HiRID.
    scor_dict = {'disc_act_acc': scorers.d3rlpy_discrete_action_match_scorer}
    if 'hirid' not in args.env:
        scor_dict['rollout/return_mean'] = scorers.d3rlpy_evaluate_on_environment(env)
    else:
        expert_trajs = data.get_imitation_dataset(args.env)
        train_trajs, _ = data.split_datasets(expert_trajs, test_size=0.2, seed=args.seed)
        episode_infos = [traj.infos for traj in train_trajs]
        scor_dict['rollout/ope_return_mean'] = scorers.d3rlpy_ope_scorer(env, episode_infos)

    n_steps = int(args.max_timesteps)
    # Split training into ~100 epochs; clamp to 1 so runs shorter than 100
    # steps don't produce a zero-length epoch.
    n_steps_per_epoch = max(1, n_steps // 100)

    cql.fit(train_episodes,
            eval_episodes=test_episodes,
            n_steps=n_steps,
            n_steps_per_epoch=n_steps_per_epoch,
            scorers=scor_dict,
            custom_logger=logging.D3RLPyLogger(custom_logger),
            save_interval=-1)

    cql.save_policy(os.path.join(log_dir, "final.pt"))


def _dm_reward_model_dir(log_dir, env_name, fo_env_name, command_name, seed):
    """Return the directory the ``ope/dm`` reward model is loaded from.

    Builds it from this run's ``log_dir`` by making three substitutions:
    ``env_name`` becomes ``fo_env_name``, ``command_name`` becomes
    ``'ope/dm'``, and ``seed_<seed>`` becomes ``seed_0``. The last path
    component is then dropped.

    NOTE(review): these are plain substring replacements, so e.g. seed 1
    would also rewrite a ``seed_10`` fragment. This assumes log paths
    contain no such collisions.
    """
    dm_run_dir = (
        log_dir.replace(env_name, fo_env_name)
        .replace(command_name, 'ope/dm')
        .replace(f'seed_{seed}', 'seed_0')
    )
    # Unlike os.path.join(*p.split('/')[:-1]), dirname keeps the leading '/'
    # of an absolute path instead of turning it into a relative one.
    return os.path.dirname(dm_run_dir)


def train_BCQ(args, env):
    """Train discrete BCQ with d3rlpy on ``args.env`` and save the final policy.

    Args:
        args: parsed command-line namespace. Reads ``env``, ``seed``,
            ``cql_lr`` (reused as BCQ's learning rate), ``batch_size``
            (falls back to 32 when absent), ``alpha`` (used as BCQ's
            ``action_flexibility``), ``target_update_interval``,
            ``max_timesteps`` and ``command_name``; the whole namespace is
            also passed to ``logging.setup_logging``.
        env: not used — it is replaced by the environment returned from
            ``d3rlpy.datasets.get_dataset``. Kept so callers' signatures work.

    For HiRID envs, evaluation uses a reward model loaded from a previous
    ``ope/dm`` run (see ``_dm_reward_model_dir``). The learned policy is
    written to ``<log_dir>/final.pt``.
    """
    custom_logger, log_dir = logging.setup_logging(args)
    dataset, env = d3rlpy.datasets.get_dataset(args.env)
    train_episodes, test_episodes = data.split_datasets(dataset, test_size=0.2, seed=args.seed)

    d3rlpy.seed(args.seed)
    env.seed(args.seed)

    bcq = d3rlpy.algos.DiscreteBCQ(
        # float()/int() coercions guard against CLI flags parsed without a type.
        learning_rate=float(args.cql_lr),
        optim_factory=d3rlpy.models.optimizers.AdamFactory(eps=1e-2 / 32),
        batch_size=int(getattr(args, 'batch_size', 32)),
        action_flexibility=float(args.alpha),
        q_func_factory=d3rlpy.models.q_functions.QRQFunctionFactory(
            n_quantiles=200),
        target_update_interval=int(args.target_update_interval),
        use_gpu=True)

    scor_dict = {'disc_act_acc': scorers.d3rlpy_discrete_action_match_scorer}
    if 'hirid' not in args.env:
        scor_dict['rollout/return_mean'] = scorers.d3rlpy_evaluate_on_environment(env)
    else:
        fo_env_name = data.get_fo_env_name(args.env)
        fo_env = gym.make(fo_env_name)
        # Reward model over the fo env's observation space and this env's
        # action space; weights come from the matching ope/dm run directory.
        reward_fn = scorers.RewardFn(fo_env.observation_space, env.action_space,
                                     demonstrations=None,
                                     custom_logger=None)
        reward_fn.load(
            _dm_reward_model_dir(log_dir, args.env, fo_env_name,
                                 args.command_name, args.seed),
            latest=True)

        expert_trajs = data.get_imitation_dataset(args.env)
        train_trajs, _ = data.split_datasets(expert_trajs, test_size=0.2, seed=args.seed)
        episode_infos = [traj.infos for traj in train_trajs]
        scor_dict['rollout/ope_return_mean'] = scorers.d3rlpy_ope_dm_scorer(reward_fn, env, episode_infos)

    n_steps = int(args.max_timesteps)
    # Split training into ~100 epochs; clamp to 1 so runs shorter than 100
    # steps don't produce a zero-length epoch.
    n_steps_per_epoch = max(1, n_steps // 100)

    bcq.fit(train_episodes,
            eval_episodes=test_episodes,
            n_steps=n_steps,
            n_steps_per_epoch=n_steps_per_epoch,
            scorers=scor_dict,
            custom_logger=logging.D3RLPyLogger(custom_logger),
            save_interval=-1)

    bcq.save_policy(os.path.join(log_dir, "final.pt"))

def _parse_int(value):
    """Parse an integer CLI value, also accepting integral scientific notation like ``1e6``."""
    try:
        return int(value)
    except ValueError:
        as_float = float(value)  # still raises ValueError for non-numeric text
        if not as_float.is_integer():
            raise argparse.ArgumentTypeError(f"expected an integer, got {value!r}")
        return int(as_float)


def _parse_bool(value):
    """Parse a boolean CLI value such as ``true``/``false``/``1``/``0``/``yes``/``no``."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    """Build the command-line parser for offline-RL training runs.

    Numeric and boolean flags carry explicit ``type`` converters so values
    given on the command line arrive as numbers/bools rather than strings.
    """
    parser = argparse.ArgumentParser()
    # Experiment
    parser.add_argument("--command_name", default="offline")                 # Policy name
    parser.add_argument("--algo", default="cql")                             # 'cql' or 'bcq'
    parser.add_argument("--env", default="hopper-medium-v0")                 # OpenAI gym environment name
    parser.add_argument("--seed", default=0, type=int)                       # Sets Gym, PyTorch and Numpy seeds
    parser.add_argument("--eval_freq", default=5e3, type=_parse_int)         # How often (time steps) we evaluate
    parser.add_argument("--max_timesteps", default=1e6, type=_parse_int)     # Max time steps to run environment
    parser.add_argument("--alpha", default=2.5, type=float)                  # CQL alpha / BCQ action_flexibility
    parser.add_argument("--normalize", default=True, type=_parse_bool)
    parser.add_argument("--target_update_interval", default=2000, type=_parse_int)
    parser.add_argument("--cql_lr", default=5e-5, type=float)
    parser.add_argument("--batch_size", default=32, type=_parse_int)         # read by train_CQL / train_BCQ

    # Logging
    parser.add_argument("--extra_log_rep", type=str, default="")
    return parser


def main(argv=None):
    """Parse ``argv`` (defaults to ``sys.argv[1:]``), seed, and run the selected trainer.

    Raises:
        NotImplementedError: if ``--algo`` is neither ``'cql'`` nor ``'bcq'``.
    """
    args = build_parser().parse_args(argv)

    print("---------------------------------------")
    print(f"Policy: {args.command_name}, Env: {args.env}, Seed: {args.seed}")
    print("---------------------------------------")

    env = gym.make(args.env)
    utils.set_seeds(args.seed, env)

    if args.algo == 'cql':
        args.command_name += f'/CQL/alpha_{args.alpha}'
        train_CQL(args, env)
    elif args.algo == 'bcq':
        args.command_name += f'/BCQ/action_flexibility_{args.alpha}'
        train_BCQ(args, env)
    else:
        raise NotImplementedError(f"Unknown --algo {args.algo!r}; expected 'cql' or 'bcq'.")


if __name__ == "__main__":
    main()