import argparse
from datetime import datetime
import torch
import os

from models import *

from train.train_data import train_data
from utils.utils import make_reproducibility
from data_utils.dataloader import DataLoader

# Command-line interface for a single VAE simulation / training run.
# NOTE: argparse applies `type` only to values given as strings (command-line
# input or string defaults); a non-string `default` is used as-is. Float
# options therefore use float literals so the parsed value is always a float.
parser = argparse.ArgumentParser(description="VAE simulation")

# --- model & hyperparameters ---
parser.add_argument('--model',          type=str,   default='pvae', help='Model type: vae, pvae, t3vae, lvae, ae')
parser.add_argument('--nu',             type=float, default=10.0,   help='Model hyperparameter nu')
parser.add_argument('--m_dim',          type=int,   default=1,      help='Latent dimension m')
parser.add_argument('--recon_sigma',    type=float, default=1.0,    help='sigma value in decoder')
parser.add_argument('--reg_weight',     type=float, default=0.01,   help='weight for regularizer term (beta)')
parser.add_argument('--simul_type',     type=str,   default="t1",   help="Simulation distribution : [dist_type][dim] e.g.) t1, p1, t2, data..")

# --- optimisation ---
parser.add_argument('--epochs',         type=int,   default=200,    help='Train epoch')
parser.add_argument('--num_layers',     type=int,   default=64,     help='Number of nodes in layers of neural networks')
parser.add_argument('--batch_size',     type=int,   default=256,    help='Batch size')
parser.add_argument('--lr',             type=float, default=1e-3,   help='Learning rate')
parser.add_argument('--eps',            type=float, default=1e-8,   help="Epsilon for Adam optimizer")
parser.add_argument('--weight_decay',   type=float, default=1e-4,   help='Weight decay')

# --- run control ---
parser.add_argument('--seed',           type=int,   default=1,      help="Seed for sampling train data")
parser.add_argument('--cuda_num',       type=int,   default=3,      help="GPU device number")
parser.add_argument('--exp_name',       type=str,   default='',     help="Specific exp name")
parser.add_argument('--patience',       type=int,   default=100,    help="Patience for Early stopping")

args = parser.parse_args()
# --- runtime setup: compute device, seeding, datasets ---
# Use the requested GPU when CUDA is present; otherwise fall back to CPU.
cuda_num = args.cuda_num
if torch.cuda.is_available():
    device = torch.device(f'cuda:{cuda_num}')
else:
    device = torch.device("cpu")

# Seeding is done before the data loader is built; the loader is also
# handed the same seed.
seed = args.seed
make_reproducibility(seed)

# Build the train / validation / test splits for the chosen simulation.
simul_type = args.simul_type
dataloader = DataLoader(simul_type, seed=seed, batch_size=args.batch_size, device=device)
loaded = dataloader.load_data()
n_dim, train_dataloader, val_dataloader, test_dataloader = loaded

m_dim = args.m_dim

# --- model construction ---
# Keyword arguments shared by every model constructor; model-specific ones
# (nu, num_hidden) are added per branch.
common_kwargs = {
    "n_dim": n_dim,
    "m_dim": m_dim,
    "recon_sigma": args.recon_sigma,
    "reg_weight": args.reg_weight,
    "device": device,
}

if args.model == 'vae':
    model = VAE.VAE(**common_kwargs)
elif args.model == 'pvae':
    # Only ParetoVAE receives the hidden width from --num_layers here.
    model = ParetoVAE.ParetoVAE(nu=args.nu, num_hidden=args.num_layers, **common_kwargs)
elif args.model == 't3vae':
    model = t3VAE.t3VAE(nu=args.nu, **common_kwargs)
elif args.model == 'lvae':
    model = LVAE.LVAE(**common_kwargs)
elif args.model == 'ae':
    model = AE.AE(**common_kwargs)
else:
    raise NotImplementedError(
        f"Unknown model type {args.model!r}; expected one of: vae, pvae, t3vae, lvae, ae"
    )

model = model.to(device)

# --- training & checkpointing ---
# NOTE(review): train_data's return value is bound to `best_model`, but the
# checkpoint below is taken from `model`. That is only the best (e.g.
# early-stopped) model if train_data updates `model` in place — confirm, and
# save `best_model` instead if it is a separate module.
best_model = train_data(
    model, device,
    train_dataloader, val_dataloader, test_dataloader,
    args.epochs, args.lr, args.eps, args.weight_decay, args.batch_size, args.seed, args)

save_dir = f"{model.model_name}/checkpoints"
os.makedirs(save_dir, exist_ok=True)

# Encode the run's key hyperparameters in the checkpoint file name.
# (Previously `lr:` was filled with args.m_dim by mistake.)
# NOTE: ':' is not a legal file-name character on Windows.
ckpt_name = (
    f"{model.model_name}_mdim:{args.m_dim}_lr:{args.lr}"
    f"_layers:{args.num_layers}_w:{args.reg_weight}_seed:{args.seed}.pt"
)
save_path = os.path.join(save_dir, ckpt_name)
torch.save(model.state_dict(), save_path)
print(f"Saved {model.__class__.__name__} to {save_path}")

