import os
import numpy as np
import torch
import gym
import argparse
import d4rl
import utils
import extreme_monge as XMRL
import wandb

def str2bool(value):
    """Convert a command-line token into a real ``bool``.

    argparse hands option values over as strings, and any non-empty string
    (including ``"False"``) is truthy, so boolean flags need explicit parsing.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``/``"false"``, ``"yes"``/``"no"``, ``"1"``/``"0"``
            (case-insensitive, surrounding whitespace ignored).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y"}:
        return True
    if text in {"0", "false", "f", "no", "n"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Build the command-line interface and parse it.

    Every numeric option declares an explicit ``type`` so that values given on
    the command line have the same Python type as the defaults.

    Args:
        argv: Optional list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): --device_num and --distance_freq are parsed but never read
    # in this script; kept so existing launch commands keep working.
    parser.add_argument("--device_num", default="1")                # Cuda device
    parser.add_argument("--task", default="Q(s,a)-8f(a)")           # Used in the wandb run name
    parser.add_argument("--env", default="halfcheetah-medium-v2")   # OpenAI gym environment name
    parser.add_argument("--seed", default=0, type=int)              # Sets Gym, PyTorch and Numpy seeds
    # Integer literals: argparse does not apply `type` to non-string defaults.
    parser.add_argument("--eval_freq", default=10000, type=int)     # How often (time steps) we evaluate
    parser.add_argument("--policy_timesteps", default=1000001, type=int)  # Num of time steps to update policy
    parser.add_argument("--distance_freq", default=10000, type=int)

    parser.add_argument("--expl_noise", default=0.1, type=float)    # Std of Gaussian exploration noise
    parser.add_argument("--batch_size", default=256, type=int)      # Batch size for both actor and critic
    parser.add_argument("--hidden_dim", default=1024, type=int)     # Hidden layer size for both actor and critic
    parser.add_argument("--discount", default=0.99, type=float)     # Discount factor
    parser.add_argument("--tau", default=0.005, type=float)         # Target network update rate
    parser.add_argument("--policy_noise", default=0.2, type=float)  # Noise added to target policy during critic update
    parser.add_argument("--noise_clip", default=0.5, type=float)    # Range to clip target policy noise
    parser.add_argument("--policy_freq", default=2, type=int)       # Frequency of delayed policy updates

    # Default literal left as-is so a run without --alpha is unchanged.
    parser.add_argument("--alpha", default=1, type=float)
    parser.add_argument("--normalize", default=True, type=str2bool)  # Normalize replay-buffer states
    return parser.parse_args(argv)


if __name__ == "__main__":

    args = parse_args()
    seed = args.seed

    wandb.init(project=None,
        name=f"{args.task}_{args.env}",
        entity=None,
        reinit=True,
    )
    wandb.run.save()

    models_path = f"./saved_models/{args.env}/"

    # Seed every source of randomness used by the run.
    env = gym.make(args.env)
    env.seed(seed)
    env.action_space.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "hidden_dim": args.hidden_dim,
        "batch_size": args.batch_size,
        "max_action": max_action,
        "max_steps": args.policy_timesteps,
        "discount": args.discount,
        "tau": args.tau,
        # Noise parameters are scaled by the action range.
        "policy_noise": args.policy_noise * max_action,
        "noise_clip": args.noise_clip * max_action,
        "policy_freq": args.policy_freq,
        "alpha": args.alpha,
    }

    method = XMRL.XMRL(**kwargs)

    # Each checkpoint is read from disk once and copied into both networks.
    # NOTE(review): the policy is initialised from the Beta checkpoint,
    # presumably as a warm start -- confirm this is intended.
    beta_state = torch.load(models_path + "Beta_OnestepRL.pt")
    method.beta.load_state_dict(beta_state)
    method.policy.load_state_dict(beta_state)
    critic_state = torch.load(models_path + "Critic_OnestepRL.pt")
    method.critic.load_state_dict(critic_state)
    method.critic_target.load_state_dict(critic_state)
    print("Models uploaded")

    replay_buffer = utils.ReplayBuffer(state_dim, action_dim, max_size=int(10000000))
    replay_buffer.convert_D4RL(d4rl.qlearning_dataset(env))
    if args.normalize:
        mean, std = replay_buffer.normalize_states()
    else:
        # Identity statistics when normalization is disabled.
        mean, std = 0, 1

    for t in range(args.policy_timesteps):
        # Three update phases per step, each logged under the same wandb step.
        log_dict = method.conservative_update_critic(replay_buffer, t)
        wandb.log(log_dict, step=t)
        torch.cuda.empty_cache()

        log_dict = method.conservative_update_transport(replay_buffer, t)
        wandb.log(log_dict, step=t)
        torch.cuda.empty_cache()

        log_dict = method.update_potential(replay_buffer, t)
        wandb.log(log_dict, step=t)
        torch.cuda.empty_cache()

        # Periodic evaluation (also runs at t == 0).
        if t % args.eval_freq == 0:
            d4rl_score = utils.eval_policy(method.policy, args.env, seed, mean, std)
            wandb.log({"d4rl_score": d4rl_score}, step=t)
            torch.cuda.empty_cache()
        
