import os
import argparse
from argparse import RawTextHelpFormatter
import json
import datetime
import torch
import random
import numpy as np

from utils.dataset import Dataset
from utils.make_data import make_data
from utils import log 

# Model name -> (trainer module, trainer class, whether cfg['abnormal_ct'] must be non-zero).
# Every model additionally requires cfg['normal_ct'] == 0.
# DSVDD reuses the DSAD trainer but, unlike DSAD, requires abnormal_ct == 0.
_TRAINERS = {
	"MemAE": ("trainer.T_MemAE", "T_MemAE", False),
	"DSAD": ("trainer.T_DSAD", "T_DSAD", True),
	"DSVDD": ("trainer.T_DSAD", "T_DSAD", False),
	"AE": ("trainer.T_AE", "T_AE", False),
	"ITSR": ("trainer.T_ITSR", "T_ITSR", False),
	"RVAEBFA": ("trainer.T_RVAEBFA", "T_RVAEBFA", False),
	"NCAE_UAD": ("trainer.T_NCAE_UAD", "T_NCAE_UAD", False),
}


def _prepare_dirs(cfg):
	"""Create the checkpoint and per-run log directories; return the log backup path.

	Creates ./ckpt/<model>/model and ./log/<YYYY-mm-dd-HH-MM>/model (with parents)
	when missing.

	Returns:
		The run's backup path as "./log/<timestamp>/". The trailing slash is kept
		exactly as before because cfg['backup_path'] is consumed outside this file
		(presumably by string concatenation — verify before changing the format).
	"""
	os.makedirs(os.path.join("./ckpt", str(cfg['model']), "model"), exist_ok=True)
	backup_path = "./log/" + datetime.datetime.now().strftime("%Y-%m-%d-%H-%M") + "/"
	os.makedirs(os.path.join(backup_path, "model"), exist_ok=True)
	return backup_path


def _set_seed(seed):
	"""Seed Python `random`, NumPy and torch (CPU and all CUDA devices), and force deterministic cuDNN."""
	torch.backends.cudnn.deterministic = True
	torch.backends.cudnn.benchmark = False
	np.random.seed(seed)
	random.seed(seed)
	torch.manual_seed(seed)
	torch.cuda.manual_seed_all(seed)


def _seed_worker(worker_id):
	"""DataLoader `worker_init_fn`: derive NumPy/`random` seeds from the worker's torch seed.

	Defined at module level (not nested inside `main`) so that it can be pickled
	when DataLoader workers are started with the "spawn" method.
	"""
	worker_seed = torch.initial_seed() % 2**32
	np.random.seed(worker_seed)
	random.seed(worker_seed)


def _make_generator(seed):
	"""Return the torch.Generator that drives DataLoader shuffling and worker base seeds.

	A seed of -1 means "unseeded". In that case the generator gets a
	non-deterministic seed instead of being pinned to a fixed value, so the
	shuffle order varies between runs.
	"""
	generator = torch.Generator()
	if seed == -1:
		generator.seed()
	else:
		generator.manual_seed(seed)
	return generator


def _make_loader(dataset, cfg, generator, *, shuffle, drop_last):
	"""Build a DataLoader with the project's shared settings.

	Uses cfg['batch'] as the batch size and the optional cfg['num_workers']
	(default 4) as the worker count. Persistent workers are enabled only when
	there are workers, because DataLoader rejects persistent_workers=True with
	num_workers=0.
	"""
	num_workers = cfg.get('num_workers', 4)
	return torch.utils.data.DataLoader(
		dataset=dataset,
		batch_size=cfg['batch'],
		shuffle=shuffle,
		drop_last=drop_last,
		num_workers=num_workers,
		persistent_workers=num_workers > 0,
		worker_init_fn=_seed_worker,
		generator=generator,
	)


def _resolve_trainer(cfg):
	"""Validate the model-specific config and return the trainer class for cfg['model'].

	Raises:
		ValueError: unknown model name, cfg['normal_ct'] != 0, or cfg['abnormal_ct']
			not matching the model's requirement (see _TRAINERS).
	"""
	import importlib

	model = cfg['model']
	if model not in _TRAINERS:
		raise ValueError(f"Unknown Model name: {model}")
	module_name, class_name, needs_abnormal = _TRAINERS[model]

	# Explicit checks instead of `assert`, which is stripped under `python -O`.
	if cfg['normal_ct'] != 0:
		raise ValueError(f"{model} requires normal_ct == 0, got {cfg['normal_ct']}")
	if needs_abnormal and cfg['abnormal_ct'] == 0:
		raise ValueError(f"{model} requires abnormal_ct != 0")
	if not needs_abnormal and cfg['abnormal_ct'] != 0:
		raise ValueError(f"{model} requires abnormal_ct == 0, got {cfg['abnormal_ct']}")

	# Import lazily so only the selected trainer's dependencies are loaded.
	return getattr(importlib.import_module(module_name), class_name)


def main(cfg):
	"""Run one experiment: set up dirs/logging/seeds, build loaders, train or load, then test.

	Args:
		cfg: Experiment configuration dict (as loaded from the JSON config). Keys
			read here: 'model', 'seed' (-1 = unseeded), 'batch', 'normal_ct',
			'abnormal_ct', 'train', 'load_path' (when not training), and optionally
			'num_workers' (default 4). The dict is also passed to make_data and
			the trainer.
			It is mutated in place: 'backup_path', 'Nn', 'Nab', 'gt',
			'N_train', 'N_valid' and 'N_test' are added.

	Returns:
		(train_time, test_time, AUROC). train_time is 0 when cfg['train'] is not True.

	Raises:
		ValueError: unknown model or a model-incompatible normal_ct/abnormal_ct.
	"""
	cfg['backup_path'] = _prepare_dirs(cfg)

	log.make_logger(cfg['model'], cfg['backup_path'])
	log.log(json.dumps(cfg, indent=4))

	if cfg['seed'] != -1:
		_set_seed(cfg['seed'])

	# Load Data
	(train_data, train_gt, train_abnormality,
	 valid_data, valid_gt, valid_abnormality,
	 test_data, test_gt, test_abnormality, transform) = make_data(cfg)

	# Count training entries with abnormality == 0, split by gt == 0 (Nn) and gt == 1 (Nab).
	# Summing the boolean mask gives the same count as indexing train_data by it,
	# without copying the selected samples.
	abnormality_zero = train_abnormality == 0
	cfg['Nn'] = int((abnormality_zero & (train_gt == 0)).sum())
	cfg['Nab'] = int((abnormality_zero & (train_gt == 1)).sum())
	cfg['gt'] = train_gt
	cfg['N_train'] = len(train_data)
	cfg['N_valid'] = len(valid_data)
	cfg['N_test'] = len(test_data)

	# Generate Dataset
	train_dataset = Dataset(train_data, train_gt, train_abnormality, transform)
	valid_dataset = Dataset(valid_data, valid_gt, valid_abnormality, transform)
	test_dataset = Dataset(test_data, test_gt, test_abnormality, transform)

	# Generate Loader (one generator shared by all three loaders, as before)
	generator = _make_generator(cfg['seed'])
	train_loader = _make_loader(train_dataset, cfg, generator, shuffle=True, drop_last=True)
	# NOTE(review): drop_last=True on validation discards up to batch-1 samples — confirm intended.
	valid_loader = _make_loader(valid_dataset, cfg, generator, shuffle=True, drop_last=True)
	test_loader = _make_loader(test_dataset, cfg, generator, shuffle=False, drop_last=False)

	# Generate Trainer
	Trainer = _resolve_trainer(cfg)
	trainer = Trainer(train_loader, valid_loader, test_loader, cfg)

	# Training: strict `== True` so a truthy non-bool (e.g. the string "false") does not enable it.
	train_time = 0
	if cfg['train'] == True:  # noqa: E712
		train_time = trainer.train()
	else:
		trainer.load(cfg['load_path'])

	# Test
	test_time, AUROC = trainer.test()
	print(f"AUROC: {AUROC}")

	return train_time, test_time, AUROC

if __name__ == "__main__":
	# CLI: a single --config option pointing at the JSON experiment configuration.
	parser = argparse.ArgumentParser(description="Robust q-lambertw loss function for Anomaly Detection", formatter_class=RawTextHelpFormatter)
	parser.add_argument("--config", type=str, default="./json/DSAD.json", help="Path of json file that include model information")

	args = parser.parse_args()

	# JSON text is UTF-8 by specification; don't depend on the platform's locale encoding.
	with open(args.config, "r", encoding="utf-8") as json_file:
		cfg = json.load(json_file)

	main(cfg)


