import torch
import argparse

import os.path
import sys
import inspect
# Prepend the parent of this script's directory to sys.path so that the
# project-local `utils` package imported below can be found when the script
# is launched directly (e.g. `python this_script.py`).
_frame_file = inspect.getfile(inspect.currentframe())
currentdir = os.path.dirname(os.path.abspath(_frame_file))
parentdir = os.path.split(currentdir)[0]
sys.path.insert(0, parentdir)
del _frame_file
from utils.flows import Vanilla_Flow, TAF, mTAF, gTAF

# Name of the compute device to use: "cuda" when a CUDA device is available,
# otherwise "cpu". (The original also called torch.device(...) here, but the
# returned object was discarded, so those calls had no effect.)
device = "cuda" if torch.cuda.is_available() else "cpu"


# Command-line options. --marginals selects which flow variant is trained
# below; restricting it to the known variants makes argparse reject typos
# instead of letting the script finish without training anything.
# --df and --num_heavy are forwarded to the flow constructor and also encode
# the experiment setting in the output file names; --model_nr is forwarded
# to the constructor as model_nr.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--marginals",
    type=str,
    default="normal",
    choices=["normal", "TAF", "mTAF", "mTAF(fix)", "gTAF"],
    help="flow variant to train",
)
parser.add_argument("--df", type=int, default=2)
parser.add_argument("--num_heavy", type=int, default=8)
parser.add_argument("--model_nr", type=int, default=0)
args = parser.parse_args()

# Passed (as a string) as the first argument of every flow constructor;
# presumably the data dimensionality — confirm against utils.flows.
D = 16

# Suffix identifying this configuration in model file names,
# e.g. "df2h8" with the default arguments.
setting = f"df{args.df}h{args.num_heavy}"

# Not referenced anywhere else in this script.
loss_val_list, loss_trn_list = [], []

# Build, train and reload the flow variant selected by --marginals. Every
# variant is constructed with the same arguments and saves its permutation
# under models/<prefix>_<setting> before training.
if args.marginals == "normal":
    vanilla = Vanilla_Flow(str(D), args.num_heavy, args.df, track_results=True, model_nr=args.model_nr)
    vanilla.save_permutation("models/vanilla_" + setting)
    vanilla.train()
    vanilla.load_model()

elif args.marginals == "TAF":
    student = TAF(str(D), args.num_heavy, args.df, track_results=True, model_nr=args.model_nr)
    student.save_permutation("models/taf_" + setting)
    student.train()
    student.load_model()

elif args.marginals in ("mTAF", "mTAF(fix)"):
    # "mTAF(fix)" runs the same pipeline as "mTAF" but trains with lr_df=0.0
    # (presumably freezing the df parameters) and saves under its own prefix.
    fix_df = args.marginals == "mTAF(fix)"
    tad = mTAF(str(D), args.num_heavy, args.df, track_results=True, model_nr=args.model_nr)
    # tail estimation is only needed once when doing multiple experiments.
    # NOTE(review): this runs only for model_nr == 1, while --model_nr
    # defaults to 0, so a default single run never calls tail_estimation();
    # confirm that config() does not depend on a prior estimation.
    if args.model_nr == 1:
        tad.tail_estimation()
    tad.config()
    if fix_df:
        tad.save_permutation("models/mtaffix_" + setting)
        tad.train(lr_df=0.0)
    else:
        tad.save_permutation("models/mtaf_" + setting)
        tad.train()
    tad.load_model()

elif args.marginals == "gTAF":
    gtaf = gTAF(str(D), args.num_heavy, args.df, track_results=True, model_nr=args.model_nr)
    gtaf.config()
    gtaf.save_permutation("models/gtaf_" + setting)
    gtaf.train()
    gtaf.load_model()

else:
    # Previously an unknown value fell through silently and the script
    # exited successfully without training; fail loudly instead.
    parser.error(f"unknown --marginals value: {args.marginals!r}")

