import argparse, os, torch, datetime
from libs.data import TabularDataset
from libs.eval import *
from libs.model import getmodel_ad
import warnings
# Silence FutureWarning / UserWarning chatter from third-party libraries so the
# run's stdout stays readable. Filters are registered in the same order as
# listed, each inserted at the front of warnings.filters.
for _category in (FutureWarning, UserWarning):
    warnings.filterwarnings('ignore', category=_category)
del _category

def _build_parser():
    """Build the command-line parser for this anomaly-detection run.

    Arguments are registered in a fixed order, which is the order they
    appear in ``--help`` output.
    """
    cli = argparse.ArgumentParser()

    # Hardware, reproducibility and output location.
    cli.add_argument("--gpu_id", type=int, default=0, help="gpu index")
    cli.add_argument("--seed", type=int, default=0, help="seed")
    cli.add_argument("--savepath", type=str, default="results",
                     help="path to save the results")

    # Dataset / model configuration. The whole parsed namespace is later
    # passed to getmodel_ad, so these attribute names are part of that
    # contract and must not be renamed.
    cli.add_argument("--dataname", type=str, default="2_annthyroid")
    cli.add_argument("--aug", type=str, default="attention_mask")
    cli.add_argument("--k_aug", type=float, default=0.3)
    cli.add_argument("--fusion_method", type=str, default="concat",
                     choices=['concat', 'average_concat', 'average'])
    cli.add_argument("--fusion_num", type=int, default=1)
    cli.add_argument("--score_method", type=str, default="cosine",
                     choices=['cosine', 'l2'])
    cli.add_argument("--results_fname", type=str, default="results_log")
    cli.add_argument("--batch_size", type=int, default=1)
    return cli


parser = _build_parser()
args = parser.parse_args()
print(args)

tasktype = "anomaly"
torch.manual_seed(args.seed)

# Per-seed output directory. exist_ok=True avoids the check-then-create race
# (FileExistsError) when several runs create the same directory concurrently.
savepath = os.path.join(args.savepath, f'seed={args.seed}')
os.makedirs(savepath, exist_ok=True)

# The checkpoint filename encodes the run's identifying hyper-parameters, so an
# already-finished configuration can be detected on re-launch.
save_model = os.path.join(savepath, f'data={args.dataname}..k_aug={args.k_aug}..fusion={args.fusion_method}_{args.fusion_num}_{args.score_method}.pth')
# Exposed on args because args is handed to the model factory; presumably the
# model writes its checkpoint here during fit() — confirm in libs.model.
args.save_model = save_model
# Results CSV lives at the top-level savepath (shared across datasets per seed).
save_logs = os.path.join(args.savepath, f"{args.results_fname}_seed{args.seed}.csv")

print(save_model)
# Skip training entirely when a checkpoint for this exact configuration exists.
train = not os.path.exists(save_model)
if not train:
    print("Already done!", save_model)
    
# Main: train, reload the saved checkpoint, score, and append metrics to the CSV.
if train:
    import platform

    # Pin the requested GPU only when CUDA is present. Calling
    # torch.cuda.set_device on a CPU-only host raises, which previously made
    # the 'cpu' fallback unreachable.
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu_id)
        device = 'cuda'
    else:
        device = 'cpu'
    # Host:gpu identifier. platform.node() matches os.uname().nodename on
    # POSIX and, unlike os.uname, also exists on Windows. Not referenced by
    # the code visible in this script.
    env_info = '{0}:{1}'.format(platform.node(), args.gpu_id)

    dataset = TabularDataset(args.dataname, tasktype, device=device, seed=args.seed)
    (X_train, y_train), (X_val, y_val), (X_test, y_test) = dataset._indv_dataset()
    print(X_train.shape)

    # Input and output dimensions are both the raw feature count.
    feature_dim = X_train.shape[1]
    model = getmodel_ad(tasktype, dataset, feature_dim, feature_dim, args, device)
    model.fit(X_train, y_train, X_val, y_val)

    # Reload the checkpoint at save_model (presumably the best model written
    # during fit — confirm) before scoring.
    model.load_model(save_model)

    # NOTE(review): metrics are computed on the *validation* split, which was
    # also passed to fit() above, yet they are printed and logged as TEST
    # results; X_test / y_test are never used. Confirm whether _indv_dataset()
    # returns the test split as val; otherwise evaluate on X_test / y_test.
    recon_error = model.predict(X_val)
    mse_auc, mse_ap, mse_f1 = calcuate_metric_ad(y_val, recon_error)
    print('BEST TEST AUC-PR: %.3f, TEST AUROC: %.3f, TEST F1: %.3f' % (mse_ap, mse_auc, mse_f1))

    log_results_to_csv(save_logs, args, mse_auc, mse_ap, mse_f1)
  