import os
from os.path import join as pjoin
import random
import argparse
from datetime import datetime

import yaml
import torch
import numpy as np

from models.MAML import MAML, Aggregated_Penalty_MAML, Meta_Separated_Penalty_MAML, Separated_Penalty_MAML, Constrained_MAML

def parse_args(argv=None):
    """Parse the command-line options of the experiment launcher.

    Args:
        argv: Optional list of argument strings. ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task_name", type=str, default="sine")
    parser.add_argument("--task_dist", type=str, default="skewed")
    parser.add_argument("--algo_name", type=str, default="maml")

    # Hyperparameters
    parser.add_argument("--inner_lr", type=float, default=0.001)
    parser.add_argument("--meta_lr", type=float, default=0.001)
    parser.add_argument("--K", type=int, default=5)
    parser.add_argument("--inner_steps", type=int, default=1)
    parser.add_argument("--num_iterations", type=int, default=50000)
    parser.add_argument("--num_eval_tasks", type=int, default=100)
    parser.add_argument("--num_plot_tasks", type=int, default=5)

    # Seed. Declared as int so the default (0) and a value given on the
    # command line have the same type everywhere it is used (RNG seeding,
    # results path, dumped config).
    parser.add_argument("--seed", type=int, default=0)

    # Log dir
    parser.add_argument("--logdir", type=str, default="results")
    parser.add_argument("--exp_id", type=str, default="debug")

    # for Penalty MAML
    parser.add_argument("--lamb", type=float, default=1)
    parser.add_argument("--num_ve_iters", type=int, default=5)
    parser.add_argument("--alpha", type=float, default=0.04)
    parser.add_argument("--p_norm", type=int, default=2)

    # for Constrained_MAML
    parser.add_argument("--radius", type=float, default=0.05)

    return parser.parse_args(argv)


def build_results_path(args):
    """Return ``<logdir>/<task_name>/<task_dist>/<algo_name>/<seed>``.

    The seed is converted to ``str`` explicitly because ``os.path.join``
    raises ``TypeError`` on non-string components.
    """
    # Previous layout, kept for reference:
    # pjoin(args.logdir, args.task_name, args.algo_name, args.exp_id,
    #       datetime.now().strftime("%Y_%m_%d_%H_%M_%S"))
    return pjoin(args.logdir, args.task_name, args.task_dist, args.algo_name, str(args.seed))


def build_algorithm(args, results_path):
    """Instantiate the algorithm selected by ``args.algo_name``.

    Args:
        args: Namespace produced by ``parse_args``.
        results_path: Directory handed to the algorithm as ``results_path``.

    Returns:
        The constructed algorithm object.

    Raises:
        NotImplementedError: If ``args.algo_name`` is not one of
            ``maml``, ``apmaml``, ``mspmaml``, ``spmaml``, ``cmaml``.
    """
    # Keyword arguments shared by every algorithm.
    common = dict(inner_lr=args.inner_lr, meta_lr=args.meta_lr, K=args.K, results_path=results_path)
    # Keyword arguments shared by the three penalty variants.
    penalty = dict(num_ve_iterations=args.num_ve_iters, lam=args.lamb, norm=args.p_norm, alpha=args.alpha)

    # NOTE(review): every variant except plain MAML is built with
    # inner_steps=1 regardless of --inner_steps, and mspmaml/spmaml are not
    # given mode=args.task_dist. Both are kept exactly as they were; confirm
    # against the model classes whether this is intentional.
    name = args.algo_name
    if name == "maml":
        return MAML(args.task_name, inner_steps=args.inner_steps, mode=args.task_dist, **common)
    if name == "apmaml":
        return Aggregated_Penalty_MAML(args.task_name, inner_steps=1, mode=args.task_dist, **common, **penalty)
    if name == "mspmaml":
        return Meta_Separated_Penalty_MAML(args.task_name, inner_steps=1, **common, **penalty)
    if name == "spmaml":
        return Separated_Penalty_MAML(args.task_name, inner_steps=1, **common, **penalty)
    if name == "cmaml":
        return Constrained_MAML(args.task_name, inner_steps=1, radius=args.radius,
                                num_ve_iterations=args.num_ve_iters, mode=args.task_dist, **common)
    raise NotImplementedError(f"Unknown --algo_name: {name!r}")


def main(argv=None):
    """Run one experiment: seed the RNGs, dump the config, train, evaluate.

    Args:
        argv: Optional list of argument strings forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    # Seed the Python, NumPy and PyTorch random number generators.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    results_path = build_results_path(args)
    # exist_ok avoids the check-then-create race when several runs start at once.
    os.makedirs(results_path, exist_ok=True)

    # Record the full set of options next to the results.
    with open(pjoin(results_path, "config.yaml"), 'w') as f:
        yaml.dump(dict(vars(args)), f)

    algo = build_algorithm(args, results_path)

    algo.train(args.num_iterations)
    algo.evaluate(args.num_eval_tasks, K=args.K, n_steps=10, lr=args.inner_lr, mode=args.task_dist)
    # algo.plot(args.num_plot_tasks, K=args.K, lr=args.inner_lr)


if __name__ == "__main__":
    main()
