import argparse
import os
import time

from algorithms import PolicyGradient, OffPolicyGradient
from data_processors import IdentityDataProcessor
from envs import *
from policies import *
from art import *

def sanitize_filename(filename):
    """Return *filename* with every non-printable-ASCII character removed.

    Only characters from space (0x20) through tilde (0x7E) are kept. Path
    separators such as "/" are printable, so multi-component paths survive.
    """
    return "".join(ch for ch in filename if " " <= ch <= "~")


parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# Command-line options as (flag, add_argument keyword options) pairs.
# The tuple order is the registration order, which is also the order the
# options appear in `--help`.
_CLI_OPTIONS = (
    ("--dir", dict(help="Directory in which save the results.", type=str, default="")),
    ("--ite", dict(help="How many iterations the algorithm must do.", type=int, default=100)),
    ("--alg", dict(help="The algorithm to use.", type=str, default="pg", choices=["pg", "off_pg"])),
    ("--window_length", dict(help="The window length for off-policy gradient.", type=int, default=5)),
    ("--var", dict(help="The exploration amount.", type=float, default=1)),
    ("--pol", dict(help="The policy used.", type=str, default="linear", choices=["linear", "nn", "big_nn"])),
    (
        "--env",
        dict(
            help="The environment.",
            type=str,
            default="swimmer",
            choices=["swimmer", "half_cheetah", "reacher", "humanoid", "ant", "hopper", "lqr", "pendulum", "cartpole"],
        ),
    ),
    ("--horizon", dict(help="The horizon amount.", type=int, default=100)),
    ("--gamma", dict(help="The gamma amount.", type=float, default=1)),
    ("--lr", dict(help="The lr amount.", type=float, default=1e-3)),
    ("--lr_strategy", dict(help="The strategy employed for the lr.", type=str, default="adam", choices=["adam", "constant"])),
    ("--n_workers", dict(help="How many parallel cores.", type=int, default=1)),
    ("--batch", dict(help="The batch size.", type=int, default=100)),
    ("--clip", dict(help="Whether to clip the action in the environment.", type=int, default=1, choices=[0, 1])),
    ("--n_trials", dict(help="How many runs of the same experiment to perform.", type=int, default=1)),
    ("--lqr_state_dim", dict(help="State dimension for the LQR environment.", type=int, default=1)),
    ("--lqr_action_dim", dict(help="Action dimension for the LQR environment.", type=int, default=2)),
    ("--test", dict(help="Whether to run in test mode.", type=int, default=0, choices=[0, 1])),
    # No default: the value stays None unless given explicitly.
    ("--weight_type", dict(help="The type of weight to use in the off-policy gradient.", type=str, choices=["BH", "MIW", "RTPG"])),
    ("--from_seed", dict(help="Initial seed for trials.", type=int, default=0)),
)

for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)


args = parser.parse_args()

# "big_nn" requests the deep policy with a larger architecture: remember the
# request in `huge` and treat it as a plain "nn" policy from here on.
huge = args.pol == "big_nn"
if huge:
    args.pol = "nn"

# Both supported algorithms work with Gaussian policies, so map the
# user-facing policy names onto their Gaussian counterparts.
if args.alg in ["pg", "off_pg"]:
    if args.pol == "linear":
        args.pol = "gaussian"
    elif args.pol == "nn":
        args.pol = "deep_gaussian"

# Exploration variance as it appears in output directory names: the decimal
# point is dropped (0.5 -> "05"), and whole numbers are rendered without a
# fractional part (1.0 -> "1"). Non-integral values >= 1 keep their digits
# rather than being truncated by int(), which would give e.g. 1.5 and 1 the
# same directory and make their results collide.
if args.var < 1 or not float(args.var).is_integer():
    string_var = str(args.var).replace(".", "")
else:
    string_var = str(int(args.var))

# Build
base_dir = args.dir

for i in range(args.from_seed, args.n_trials + args.from_seed):
    # Seed numpy before the environment and policy are built, so any numpy
    # randomness they use is reproducible per trial.
    np.random.seed(i)

    # The directory name encodes the experiment configuration.
    dir_name = f"{args.alg}_{args.ite}_{args.env}_{args.horizon}_{args.lr_strategy}_"
    dir_name += f"{str(args.lr).replace('.', '')}_{args.pol}_batch_{args.batch}_"
    dir_name += "clip_" if args.clip else "noclip_"
    if args.alg == "off_pg":
        dir_name += f"window_{args.window_length}_{args.weight_type}_"

    # Environment.
    # MULTI_LINEAR is forwarded to the linear policies below. Every
    # environment sets it to True except LQR, which uses it only when there is
    # more than one action dimension.
    clip = bool(args.clip)
    MULTI_LINEAR = True
    if args.env == "swimmer":
        env = Swimmer(horizon=args.horizon, gamma=args.gamma, render=False, clip=clip)
    elif args.env == "half_cheetah":
        env = HalfCheetah(horizon=args.horizon, gamma=args.gamma, render=False, clip=clip)
    elif args.env == "reacher":
        # The parser defines no `--costs` option, so reading `args.costs`
        # directly raised AttributeError. A missing attribute means "no costs".
        if getattr(args, "costs", False):
            raise NotImplementedError
        env = Reacher(horizon=args.horizon, gamma=args.gamma, render=False, clip=clip)
    elif args.env == "humanoid":
        env = Humanoid(horizon=args.horizon, gamma=args.gamma, render=False, clip=clip)
    elif args.env == "ant":
        env = Ant(horizon=args.horizon, gamma=args.gamma, render=False, clip=clip)
    elif args.env == "hopper":
        env = Hopper(horizon=args.horizon, gamma=args.gamma, render=False, clip=clip)
    elif args.env == "lqr":
        env = LQR.generate(
            s_dim=args.lqr_state_dim,
            a_dim=args.lqr_action_dim,
            horizon=args.horizon,
            gamma=args.gamma,
            scale_matrix=0.9,
        )
        MULTI_LINEAR = bool(args.lqr_action_dim > 1)
    elif args.env == "pendulum":
        env = Pendulum(horizon=args.horizon, gamma=args.gamma, render=False, clip=clip)
    elif args.env == "cartpole":
        env = ContCartPole(horizon=args.horizon, gamma=args.gamma)
    else:
        raise ValueError("Invalid env name.")
    s_dim = env.state_dim
    a_dim = env.action_dim

    # Data processor (identity feature map over the state dimension).
    dp = IdentityDataProcessor(dim_feat=env.state_dim)

    # Policy. "linear" and "nn" only remain here if the algorithm was not one
    # of those whose policy names are remapped to Gaussian variants.
    if args.pol == "linear":
        tot_params = s_dim * a_dim
        pol = OldLinearPolicy(
            parameters=np.zeros(tot_params),
            dim_state=s_dim,
            dim_action=a_dim,
            multi_linear=MULTI_LINEAR,
        )
    elif args.pol == "gaussian":
        tot_params = s_dim * a_dim
        pol = LinearGaussianPolicy(
            parameters=np.zeros(tot_params),
            dim_state=s_dim,
            dim_action=a_dim,
            std_dev=np.sqrt(args.var),
            std_decay=0,
            std_min=1e-5,
            multi_linear=MULTI_LINEAR,
        )
    elif args.pol == "deep_gaussian":
        # Seed torch as well: the network is initialised with
        # torch.nn.init.xavier_uniform_, presumably via torch's global RNG,
        # which np.random.seed does not cover.
        torch.manual_seed(i)
        # `huge` (--pol big_nn) selects the larger 100-50-25 hidden
        # architecture. Previously it was built and discarded, and every deep
        # policy got [32, 32].
        hidden_neurons = [100, 50, 25] if huge else [32, 32]
        pol = DeepGaussian(
            dim_state=s_dim,
            dim_action=a_dim,
            hidden_neurons=hidden_neurons,
            param_init=None,
            bias=False,
            activation=torch.tanh,
            init=torch.nn.init.xavier_uniform_,
            std_dev=np.sqrt(args.var),
            std_decay=0,
            std_min=1e-6,
            n_workers=args.n_workers,
        )
        tot_params = pol.tot_params
    elif args.pol == "nn":
        raise ValueError("Invalid nn policy name.")
    else:
        raise ValueError("Invalid policy name.")
    dir_name += f"{tot_params}_var_{string_var}"
    # os.path.join supplies the separator that plain concatenation dropped
    # when --dir had no trailing "/". An empty --dir still yields a relative
    # path.
    dir_name = sanitize_filename(os.path.join(base_dir, dir_name, f"{dir_name}_trial_{i}"))

    # Algorithm: initial parameter vector handed to the learner.
    if args.pol == "gaussian":
        init_theta = [0] * tot_params
    elif args.pol == "deep_gaussian":
        init_theta = pol.get_parameters().detach().numpy()
    else:
        init_theta = np.random.normal(0, 1, tot_params)

    # Settings shared by both algorithms.
    common_parameters = dict(
        lr=[args.lr],
        lr_strategy=args.lr_strategy,
        initial_theta=init_theta,
        ite=args.ite,
        batch_size=args.batch,
        env=env,
        policy=pol,
        data_processor=dp,
        directory=dir_name,
        verbose=False,
        natural=False,
        checkpoint_freq=100,
        n_jobs=args.n_workers,
    )
    if args.alg == "pg":
        alg = PolicyGradient(estimator_type="GPOMDP", **common_parameters)
    elif args.alg == "off_pg":
        alg = OffPolicyGradient(
            window_length=args.window_length,
            test=bool(args.test),
            weight_type=args.weight_type,
            **common_parameters,
        )
    else:
        raise ValueError("Invalid algorithm name.")

    print(text2art(f"== {args.alg} TEST on {args.env} =="))
    print(text2art(f"Trial {i}"))
    print(args)
    print(text2art("Learn Start"))
    start = time.perf_counter()
    alg.learn()
    end_time = time.perf_counter() - start
    alg.save_results()
    print(alg.performance_idx)
    print(f"Time elapsed = {end_time}")
