import os
from os.path import join as pjoin
import random
import argparse
from datetime import datetime

import yaml
import torch
import numpy as np

from models.MAML import MAML, Aggregated_Penalty_MAML, Meta_Separated_Penalty_MAML, Separated_Penalty_MAML, Constrained_MAML


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--task_name", type=str, default="sine")
    parser.add_argument("--algo_name", type=str, default="maml")

    # Hyperparameters
    parser.add_argument("--inner_lr", type=float, default=0.001)
    parser.add_argument("--meta_lr", type=float, default=0.001)
    parser.add_argument("--K", type=int, default=5)
    parser.add_argument("--inner_steps", type=int, default=1)
    parser.add_argument("--num_iterations", type=int, default=70000)
    parser.add_argument("--num_eval_tasks", type=int, default=25)
    parser.add_argument("--num_plot_tasks", type=int, default=5)

    # Seed
    parser.add_argument("--seed", default=0)
    
    # Log dir
    parser.add_argument("--logdir", type=str, default="results")
    parser.add_argument("--exp_id", type=str, default="debug")
    
    # for Penalty MAML
    parser.add_argument("--lamb", type=float, default=1)
    parser.add_argument("--num_ve_iters", type=int, default=5)
    parser.add_argument("--alpha", type=float, default=0.04)
    parser.add_argument("--p_norm", type=int, default=2)

    # for Constrained_MAML
    parser.add_argument("--radius", type=float, default=0.05)
    
    # test mode
    parser.add_argument('--mode', type=str, default="skewed")
    
    args = parser.parse_args()

    seed = int(args.seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    
    results_path = pjoin(args.logdir, args.task_name, args.algo_name, args.exp_id)    
    config = yaml.load(open(pjoin(results_path, "config.yaml"), "r"), Loader=yaml.SafeLoader)

    num_ve_iterations = args.num_ve_iters
    radius = 10000.
    alpha = args.alpha
    lam = 0.
    norm = args.p_norm
    inner_steps = args.inner_steps

    if args.algo_name == "maml":
        algo = MAML(args.task_name, inner_lr=args.inner_lr, meta_lr=args.meta_lr, K=args.K, inner_steps=inner_steps, 
                    results_path=results_path)
    elif args.algo_name == "apmaml":
        algo = Aggregated_Penalty_MAML(args.task_name, inner_lr=args.inner_lr, meta_lr=args.meta_lr, K=args.K, inner_steps=inner_steps, 
                                       results_path=results_path, num_ve_iterations=num_ve_iterations, lam=lam, norm=norm, alpha=alpha)
    elif args.algo_name == "mspmaml":
        algo = Meta_Separated_Penalty_MAML(args.task_name, inner_lr=args.inner_lr, meta_lr=args.meta_lr, K=args.K,
                                            inner_steps=inner_steps, results_path=results_path, 
                                            num_ve_iterations=num_ve_iterations, lam=lam, norm=norm, alpha=alpha)
    elif args.algo_name == "spmaml":
        algo = Separated_Penalty_MAML(args.task_name, inner_lr=args.inner_lr, meta_lr=args.meta_lr, K=args.K, inner_steps=inner_steps, 
                                      results_path=results_path, num_ve_iterations=num_ve_iterations, lam=lam, norm=norm, alpha=alpha)
    elif args.algo_name == "cmaml":
        algo = Constrained_MAML(args.task_name, inner_lr=args.inner_lr, meta_lr=args.meta_lr, K=args.K, inner_steps=inner_steps,
                                 results_path=results_path, radius=radius, num_ve_iterations=num_ve_iterations)
    else:
        raise NotImplementedError
    
    algo.model.load_state_dict(torch.load(pjoin(results_path, "model.pt")))
    algo.evaluate(args.num_eval_tasks, K=args.K, n_steps=args.inner_steps, lr=args.inner_lr, mode=args.mode, save=False)
    # algo.plot(args.num_plot_tasks, K=args.K, lr=args.inner_lr)
