import os
import numpy as np
import torch
import gym
import argparse
import d4rl
import utils
import extreme_monge as XMRL
import wandb

def parse_bool(value):
    """Convert a command-line token into a real ``bool``.

    argparse hands option values over as strings, and any non-empty string
    (including ``"False"``) is truthy, so a flag declared without a converter
    cannot be switched off in an intuitive way.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``/``"false"``, ``"1"``/``"0"``, ``"yes"``/``"no"``.

    Returns:
        The corresponding boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    # The empty string is accepted as False because, before this converter
    # existed, `--normalize ""` was the only spelling that disabled the flag.
    if text in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_arg_parser():
    """Build the command-line parser for this training script.

    Every numeric option declares an explicit ``type`` so that values supplied
    on the command line arrive as numbers rather than strings, and integer
    options use integer defaults (argparse does not convert non-string
    defaults).

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device_num", default="1")                           # Cuda device (parsed but not used below in this script)
    parser.add_argument("--task", default="Beta_OnestepRL")                    # Used in the wandb run name
    parser.add_argument("--env", default="halfcheetah-expert-v2")              # OpenAI gym environment name
    parser.add_argument("--seed", default=0, type=int)                         # Sets Gym, PyTorch and Numpy seeds
    parser.add_argument("--eval_freq", default=int(5e4), type=int)             # How often (time steps) we evaluate
    parser.add_argument("--beta_timesteps", default=int(5e5), type=int)        # Num of time steps to update beta
    parser.add_argument("--distance_freq", default=int(1e4), type=int)         # Parsed but not used below in this script

    parser.add_argument("--expl_noise", default=0.1, type=float)               # Std of Gaussian exploration noise (not used below in this script)
    parser.add_argument("--batch_size", default=512, type=int)                 # Batch size for both actor and critic
    parser.add_argument("--hidden_dim", default=1024, type=int)                # Hidden layer size for both actor and critic
    parser.add_argument("--discount", default=0.99, type=float)                # Discount factor
    parser.add_argument("--tau", default=0.005, type=float)                    # Target network update rate
    parser.add_argument("--policy_noise", default=0.2, type=float)             # Noise added to target policy during critic update
    parser.add_argument("--noise_clip", default=0.5, type=float)               # Range to clip target policy noise
    parser.add_argument("--policy_freq", default=2, type=int)                  # Frequency of delayed policy updates

    parser.add_argument("--alpha", default=2.5, type=float)                    # Forwarded to XMRL as "alpha"
    parser.add_argument("--normalize", default=True, type=parse_bool)          # Whether to normalize the replay-buffer states
    return parser


if __name__ == "__main__":

    args = build_arg_parser().parse_args()

    seed = int(args.seed)

    print("---------------------------------------")
    print("Training is started")
    print("---------------------------------------")

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    models_path = f"./saved_models/{args.env}/"
    os.makedirs(models_path, exist_ok=True)
    os.makedirs("./results", exist_ok=True)

    wandb.init(project=None,
        name=f"{args.task}_{args.env}",
        entity=None,
        reinit=True,
    )
    wandb.run.save()

    env = gym.make(args.env)

    # Seed every source of randomness this script touches.
    env.seed(seed)
    env.action_space.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "hidden_dim": args.hidden_dim,
        "batch_size": args.batch_size,
        "max_action": max_action,
        "discount": args.discount,
        "tau": args.tau,
        # Noise settings are given relative to the action range and scaled here.
        "policy_noise": args.policy_noise * max_action,
        "noise_clip": args.noise_clip * max_action,
        "policy_freq": args.policy_freq,
        "alpha": args.alpha,
    }

    method = XMRL.XMRL(**kwargs)

    # Fill the replay buffer from the D4RL offline dataset of this environment.
    replay_buffer = utils.ReplayBuffer(state_dim, action_dim)
    replay_buffer.convert_D4RL(d4rl.qlearning_dataset(env))
    if args.normalize:
        mean, std = replay_buffer.normalize_states()
    else:
        # Identity statistics so evaluation receives the same argument shape.
        mean, std = 0, 1

    for t in range(int(args.beta_timesteps)):
        log_dict = method.update_beta(replay_buffer)
        wandb.log(log_dict, step=t)

        # Periodically evaluate the current beta policy and log its score
        # at the same wandb step as the training metrics.
        if (t + 1) % args.eval_freq == 0:
            d4rl_score = utils.eval_policy(method.beta, args.env, seed, mean, std)
            wandb.log({"d4rl_score": d4rl_score}, step=t)

    torch.save(method.beta.state_dict(), os.path.join(models_path, "Beta_OnestepRL.pt"))

