import os
import numpy as np
import torch
import gym
import argparse
import d4rl
import utils
import extreme_monge as XMRL
import wandb

def str2bool(value):
    """Convert a command-line token to a bool.

    Without an explicit converter argparse stores the raw string, and any
    non-empty string (including "False") is truthy.

    Args:
        value: A bool (returned unchanged) or a string token.

    Returns:
        True or False.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    # The empty string is accepted as False because it was the only falsy
    # value that could be passed on the command line before typed parsing.
    if token in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Build the argument parser for the critic-training script.

    Every numeric option declares its ``type`` so that values given on the
    command line are numbers rather than strings, and integer options use
    integer defaults (argparse only applies ``type`` to string defaults).

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): --device_num, --eval_freq, --distance_freq and
    # --expl_noise are parsed but not read anywhere in this script; they are
    # kept so existing command lines continue to work.
    parser.add_argument("--device_num", default="1", help="Cuda device")
    parser.add_argument("--task", default="Critic_OnestepRL")
    parser.add_argument("--env", default="halfcheetah-expert-v2",
                        help="OpenAI gym environment name")
    parser.add_argument("--seed", default=0, type=int,
                        help="Sets Gym, PyTorch and Numpy seeds")
    parser.add_argument("--eval_freq", default=50_000, type=int,
                        help="How often (time steps) we evaluate")
    parser.add_argument("--critic_timesteps", default=2_000_000, type=int,
                        help="Num of time steps to update critic")
    parser.add_argument("--distance_freq", default=10_000, type=int)

    parser.add_argument("--expl_noise", default=0.1, type=float,
                        help="Std of Gaussian exploration noise")
    parser.add_argument("--batch_size", default=512, type=int,
                        help="Batch size for both actor and critic")
    parser.add_argument("--hidden_dim", default=1024, type=int,
                        help="Hidden layer size for both actor and critic")
    parser.add_argument("--discount", default=0.99, type=float,
                        help="Discount factor")
    parser.add_argument("--tau", default=0.005, type=float,
                        help="Target network update rate")
    parser.add_argument("--policy_noise", default=0.2, type=float,
                        help="Noise added to target policy during critic update")
    parser.add_argument("--noise_clip", default=0.5, type=float,
                        help="Range to clip target policy noise")
    parser.add_argument("--policy_freq", default=2, type=int,
                        help="Frequency of delayed policy updates")

    parser.add_argument("--alpha", default=2.5, type=float)
    parser.add_argument("--normalize", default=True, type=str2bool)
    return parser


def main(argv=None):
    """Train the critic on a D4RL dataset and save it to disk.

    Loads the pre-trained ``Beta_OnestepRL.pt`` weights into ``method.beta``,
    evaluates that policy once, runs ``critic_timesteps`` critic updates while
    logging each step to wandb, then saves the critic and its optimizer state
    under ``./saved_models/<env>/``.

    Args:
        argv: Optional list of argument strings; ``None`` makes argparse read
            ``sys.argv``.
    """
    args = build_parser().parse_args(argv)
    seed = args.seed

    print("---------------------------------------")
    print("Training is started")
    print("---------------------------------------")

    models_path = f"./saved_models/{args.env}/"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(models_path, exist_ok=True)
    os.makedirs("./results", exist_ok=True)

    wandb.init(
        project=None,
        name=f"{args.task}_{args.env}",
        entity=None,
        reinit=True,
    )
    wandb.run.save()

    env = gym.make(args.env)

    # Seed every source of randomness used by the script.
    env.seed(seed)
    env.action_space.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "hidden_dim": args.hidden_dim,
        "batch_size": args.batch_size,
        "max_action": max_action,
        "discount": args.discount,
        "tau": args.tau,
        # Noise parameters are scaled by the action bound.
        "policy_noise": args.policy_noise * max_action,
        "noise_clip": args.noise_clip * max_action,
        "policy_freq": args.policy_freq,
        "alpha": args.alpha,
    }

    method = XMRL.XMRL(**kwargs)

    replay_buffer = utils.ReplayBuffer(state_dim, action_dim)
    replay_buffer.convert_D4RL(d4rl.qlearning_dataset(env))
    if args.normalize:
        mean, std = replay_buffer.normalize_states()
    else:
        # Identity normalisation.
        mean, std = 0, 1

    # The behaviour policy must already have been trained and saved.
    method.beta.load_state_dict(torch.load(models_path + "Beta_OnestepRL.pt"))
    # The returned score is not used below; the call is kept so the
    # evaluation itself still runs exactly as before.
    utils.eval_policy(method.beta, args.env, seed, mean, std)

    for t in range(args.critic_timesteps):
        log_dict = method.update_critic(replay_buffer, t)
        wandb.log(log_dict, step=t)

    torch.save(method.critic.state_dict(), models_path + "Critic_OnestepRL.pt")
    torch.save(method.critic_optimizer.state_dict(),
               models_path + "Critic_OnestepRL_optimizer.pt")


if __name__ == "__main__":
    main()

