import numpy as np
import gymnasium as gym
from stable_baselines3 import PPO
import wandb
# from wandb.integration.sb3 import WandbCallback
from stable_baselines3.common.logger import configure
from stable_baselines3.envs import HopperGrad, HalfCheetahGrad, Walker2dGrad
    
from stable_baselines3.common.env_util import make_vec_env
import torch
import sys
import argparse
import re

# Per-environment-family lookup tables, keyed by the base name the script
# derives from the environment id (e.g. 'HalfCheetah').

# Reward terms that are handed to PPO as `reward_components` and logged.
reward_component_list = dict(
    HalfCheetah=['reward_forward', 'reward_ctrl'],
    Hopper=['reward_forward', 'reward_survive', 'reward_ctrl'],
    Walker2d=['reward_forward', 'reward_ctrl', 'reward_survive'],
)

# Default weights passed to the outer loop as `oracle_weight` when no
# --oracle-weights override is given.
# NOTE(review): Hopper and Walker2d list their components in different orders
# but share the same weight vector -- confirm the weights are meant to line up
# positionally with the component names above.
oracle_weight_list = dict(
    HalfCheetah=[0.4, 1],
    Hopper=[1.0, 1.0, 0.002],
    Walker2d=[1.0, 1.0, 0.002],
)

# Extra keys appended to the reward components to form PPO's `record_keys`.
# Every family records the same two keys; each gets its own list object.
other_logging_list = {
    env_family: ['x_position', 'control_ok']
    for env_family in ('HalfCheetah', 'Hopper', 'Walker2d')
}

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``"False"`` -- becomes ``True``.  This
    converter interprets the usual spellings explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Build the experiment's argument parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per option (dashes become
        underscores, e.g. ``--weight-type`` -> ``weight_type``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", type=str, default='cuda:0')
    # simplev1, simplev2, direct, regulatedv1 ... regulatedv8, oracle
    parser.add_argument("--weight-type", type=str, default='simplev1')
    parser.add_argument("--reward-L2", type=float, default=0.)
    # if, meta, sparse, oracle, default (selects OracleOuterLoop2), random
    parser.add_argument("--algorithm", type=str, default='if')
    parser.add_argument("--env", type=str, default='HalfCheetahSparse-v2')
    parser.add_argument("--share-extractor", type=_str2bool, default=False)
    parser.add_argument("--exploration", type=str, default='single') # none, single, random, maxdis
    parser.add_argument("--explore-frequency", type=int, default=25)
    parser.add_argument("--run-name", type=str, default='run')
    parser.add_argument("--n-epochs-outer", type=int, default=10)
    parser.add_argument("--n-epochs-inner", type=int, default=10)
    parser.add_argument("--n-envs", type=int, default=64)
    parser.add_argument("--inner-loop-frequency", type=int, default=1)
    parser.add_argument("--condition", type=str, default='performance') # none, performance
    parser.add_argument("--explore-choice", type=str, default='explore') # topk, addition, explore
    parser.add_argument("--explore-tmp", type=float, default=10.)
    parser.add_argument("--large-model", type=_str2bool, default=False)
    parser.add_argument("--outer-loop-frequency", type=int, default=5)
    parser.add_argument("--entropy", type=float, default=0.)
    parser.add_argument("--n-steps", type=int, default=512)
    parser.add_argument("--buffer-num", type=int, default=1)
    parser.add_argument("--reset-method", type=str, default='policy') # none, policy, action, logstd, policy_logstd, full
    parser.add_argument("--outer-loop-lr", type=float, default=1e-3)
    parser.add_argument("--start-with-oracle", type=_str2bool, default=False)
    parser.add_argument('--critic-head-num', type=int, default=0)
    parser.add_argument('--extractor-layer', type=int, default=2)
    parser.add_argument('--weight-history', type=str, default='traj') # coeff, traj, space
    parser.add_argument('--use-gradient', action='store_true')
    # Weights separated by whitespace, e.g. "1 1 0.002"; empty keeps the per-env default.
    parser.add_argument('--oracle-weights', type=str, default='')
    parser.add_argument('--env-cl-mode', type=str, default='none') # none, reward_density, episode_length, objective_difficulty
    return parser.parse_args()

# Parse the command line once and lift the frequently used options into
# module-level names that the rest of the script refers to.
args = parse_args()
seed = args.seed
device = args.device
weight_type = args.weight_type
reward_L2 = args.reward_L2
algorithm = args.algorithm
env_name = args.env
share_extractor = args.share_extractor
exploration = args.exploration
explore_frequency = args.explore_frequency
outer_loop_frequency = args.outer_loop_frequency
run_name = args.run_name
n_epochs_outer = args.n_epochs_outer
use_gradient = args.use_gradient

# Derive the environment family from the id: split in front of every capital
# letter and drop the last piece, e.g.
# 'HalfCheetahSparse-v2' -> ['Half', 'Cheetah', 'Sparse-v2'] -> 'HalfCheetah'.
parts = re.findall('[A-Z][^A-Z]*', env_name)
env_base = "".join(parts[:-1]) if len(parts) > 0 else env_name

# Fail fast, before any simulator is created, with the same "not supported"
# style used for the other option checks instead of a bare KeyError.
if env_base not in reward_component_list:
    raise ValueError(
        f"Environment {env_name} not supported: derived base name {env_base!r} "
        f"is not one of {sorted(reward_component_list)}"
    )
reward_components = reward_component_list[env_base]
# One slot per reward component plus one extra (presumably the task reward --
# confirm against the outer-loop implementation).
n_reward_components = len(reward_components) + 1
full_run_name = f"{run_name}-{algorithm}-{weight_type}-{exploration}-{seed}"

if env_name == 'Minitaur4Sparse-v1':
    # NOTE(review): this id yields the base name 'Minitaur4', which has no
    # entry in reward_component_list, so the validation above rejects it
    # (the original script hit a KeyError right after building the env).
    env = make_vec_env('MinitaurSparse-v1', n_envs=args.n_envs, seed=seed)
else:
    if args.env_cl_mode == 'none':
        env = make_vec_env(env_name, n_envs=args.n_envs, seed=seed)
    else:
        # Curriculum mode is forwarded to each environment's constructor.
        env = make_vec_env(env_name, n_envs=args.n_envs, seed=seed, env_kwargs={'mode': args.env_cl_mode})
        env.env_method("switch_difficulty") # initialize corresponding difficulty level

# Start the Weights & Biases run.  The displayed run name joins the run
# label, algorithm, weight type and exploration mode with dashes (the seed
# is not part of it).
wandb_run_name = "-".join((run_name, algorithm, weight_type, exploration))
wandb.init(
    project="project_name",
    name=wandb_run_name,
    config=args,
    monitor_gym=True,  # auto-upload the videos of agents playing the game
    save_code=False,  # optional
)

# Write the training logger's output into the wandb run directory using the
# "wandb" output format.
tmp_path = wandb.run.dir
new_logger = configure(tmp_path, ["wandb"])


# ---- Inner-loop (PPO) hyperparameters and model construction ----
n_steps = args.n_steps

# Number of critic heads:
#   --critic-head-num 1 -> 1 head in total, for the shaped reward;
#   --critic-head-num 2 -> 2 heads in total, one for the shaped reward and one
#                          for the task reward;
#   anything else       -> n+1 heads in total, one for the task reward and one
#                          for each of the n reward components.
if args.critic_head_num in (1, 2):
    policy_kwargs = dict(num_heads=args.critic_head_num)
else:
    policy_kwargs = dict(num_heads=len(reward_components) + 1)

n_epochs = args.n_epochs_inner
vf_coef = 0.5
ent_coef = args.entropy
max_grad_norm = 0.5
learning_rate = 3.0e-4

log_interval = 25
# Walker environments override several settings regardless of the CLI values.
if env_name.startswith('Walker'):
    outer_loop_frequency = 15
    explore_frequency *= 2
    n_epochs = 20
    log_interval = 20

# Only the 'meta' algorithm asks PPO to keep old models around.
store_old_models = args.algorithm == 'meta'

model = PPO(
    "MlpPolicy",
    env,
    policy_kwargs=policy_kwargs,
    # The following keyword arguments look specific to this project's PPO
    # variant (reward components, recorded info keys, model history, buffers,
    # critic heads); their semantics live in that implementation.
    reward_components=reward_components,
    record_keys=reward_components + other_logging_list[env_base],
    store_old_models=store_old_models,
    buffer_num=args.buffer_num, # int(args.outer_loop_frequency/self.inner_loop_frequency),
    critic_head_num=args.critic_head_num,
    # Standard optimisation settings.
    n_steps=n_steps,
    batch_size=n_steps, # minibatch size
    n_epochs=n_epochs,
    learning_rate=learning_rate,
    vf_coef=vf_coef,
    ent_coef=ent_coef,
    max_grad_norm=max_grad_norm,
    # Run bookkeeping.
    seed=seed,
    device=device, # TODO
    verbose=1,
    tensorboard_log=f"runs/{full_run_name}",
)

# Route the model's log output through the logger configured above.
model.set_logger(new_logger)
from outer_loop.optimizer import SparseOuterLoop, OracleOuterLoop, OracleOuterLoop2, RandomOuterLoop, MetaGradientOuterLoop, ImplicitFunctionOuterLoop
from outer_loop.weight_net import SimpleRewardWeightv1, SimpleRewardWeightv2, DirectRewardWeight, RegulatedRewardWeightv1, RegulatedRewardWeightv2, \
    RegulatedRewardWeightv3, RegulatedRewardWeightv4, OracleRewardWeight, RegulatedRewardWeightv5, RegulatedRewardWeightv6, RegulatedRewardWeightv7, RegulatedRewardWeightv8
from outer_loop.exploration import SinglePrediction, RandomExploration, MaxDistanceExploration

# Map every accepted --algorithm value to its outer-loop optimiser class.
_OUTER_LOOP_CLASSES = {
    'if': ImplicitFunctionOuterLoop,
    'meta': MetaGradientOuterLoop,
    'sparse': SparseOuterLoop,
    'oracle': OracleOuterLoop,
    'default': OracleOuterLoop2,
    # Alias for 'default', named after the class it selects, so that
    # `--algorithm oracle2` no longer fails with "not supported".
    'oracle2': OracleOuterLoop2,
    'random': RandomOuterLoop,
}

if algorithm not in _OUTER_LOOP_CLASSES:
    raise ValueError(f"Algorithm {algorithm} not supported")
outer_loop_class = _OUTER_LOOP_CLASSES[algorithm]

# Upper bound handed to the exploration module; identical for all weight types.
upper = 1.0

# Per weight type:
#   (weight network class,
#    total training timesteps,
#    lower bound handed to the exploration module,
#    offset added to n_reward_components to get the exploration weight dim)
_WEIGHT_TYPE_SETTINGS = {
    'simplev1':    (SimpleRewardWeightv1,    10_000_000,  0.0,  0),
    'simplev2':    (SimpleRewardWeightv2,    20_000_000, -1.0,  0),
    'direct':      (DirectRewardWeight,      20_000_000,  0.0,  0),
    'regulatedv1': (RegulatedRewardWeightv1, 20_000_000,  0.0,  0),
    'regulatedv2': (RegulatedRewardWeightv2, 20_000_000,  0.0,  0),
    'regulatedv3': (RegulatedRewardWeightv3,  5_000_000,  0.0,  0),
    'regulatedv4': (RegulatedRewardWeightv4, 20_000_000, -1.0,  0),
    'regulatedv5': (RegulatedRewardWeightv5,  5_000_000,  0.0,  0),
    'regulatedv6': (RegulatedRewardWeightv6, 20_000_000, -1.0,  0),
    'regulatedv7': (RegulatedRewardWeightv7,  5_000_000,  0.0, -1),
    'regulatedv8': (RegulatedRewardWeightv8, 20_000_000, -1.0, -1),
    'oracle':      (OracleRewardWeight,      20_000_000,  0.0,  0),
}

if weight_type not in _WEIGHT_TYPE_SETTINGS:
    raise ValueError(f"Weight type {weight_type} not supported")

weight_net_class, total_timesteps, lower, _explore_dim_offset = _WEIGHT_TYPE_SETTINGS[weight_type]
explore_reward_dim = n_reward_components + _explore_dim_offset

# Exploration strategies selectable through --exploration; 'none' disables
# the module entirely.  All strategies take the same constructor arguments.
_EXPLORATION_CLASSES = {
    'single': SinglePrediction,
    'random': RandomExploration,
    'maxdis': MaxDistanceExploration,
}

if exploration == 'none':
    exploration_module = None
elif exploration in _EXPLORATION_CLASSES:
    exploration_module = _EXPLORATION_CLASSES[exploration](
        weight_dim=explore_reward_dim,
        hidden_dim=64,
        device=device,
        lower=lower,
        upper=upper,
    )
else:
    raise ValueError(f"Exploration {exploration} not supported")

# With --weight-history space, draw 1000 samples from the observation space
# and hand them to the outer loop as one float32 tensor on the target device.
if args.weight_history == 'space':
    # Stack into a single ndarray first: building a tensor straight from a
    # Python list of ndarrays is a known slow path in PyTorch.
    obs = np.stack([env.observation_space.sample() for _ in range(1000)])
    obs_tensor = torch.tensor(obs, device=device, dtype=torch.float32)
else:
    obs_tensor = None

# Oracle weights: start from the per-environment default and let a non-blank
# --oracle-weights string override it.
weight_list = oracle_weight_list.get(env_base, [])
if args.oracle_weights.strip():
    # Split on any run of whitespace (commas are accepted too) so repeated or
    # trailing separators no longer produce empty tokens that crash float().
    weight_list = [float(token) for token in args.oracle_weights.replace(',', ' ').split()]
print(weight_list)

# Build the outer-loop optimiser selected by --algorithm around the PPO model.
outer_loop = outer_loop_class(
    model,
    n_reward_components=n_reward_components, # len(reward_components) + 1
    learning_rate=1e-3,
    n_epochs=n_epochs_outer,
    device=device,
    save_path=wandb.run.dir,
    reward_L2=reward_L2,
    weight_net_class=weight_net_class,
    share_extractor=share_extractor,
    exploration_module=exploration_module,
    explore_choice=args.explore_choice,
    explore_tmp=args.explore_tmp,
    oracle_weight=weight_list,
    reset_method=args.reset_method,
    log_episode_reward=log_interval,
    lr=args.outer_loop_lr,
    start_with_oracle=args.start_with_oracle,
    extractor_layer=args.extractor_layer,
    weight_history=args.weight_history,
    obs_space_tensor=obs_tensor,
    use_gradient=use_gradient,
)

# Give the model access to the outer loop's forward function.
model.add_reward_weight_forward(outer_loop.forward)

try:
    outer_loop.learn(
        total_timesteps=total_timesteps,
        n_steps=n_steps*args.n_envs, # rollout length times number of parallel envs
        inner_loop_freq=args.inner_loop_frequency,
        outer_loop_freq=outer_loop_frequency,
        progress_bar=True,
        test_env=None,
        explore_frequency=explore_frequency,
        condition=args.condition,
    )
finally:
    # Release the vectorised environments even when training aborts.
    env.close()

# Reached only after a successful run, so a crashed run is not reported to
# wandb as having finished cleanly.
wandb.finish()